File size: 5,803 Bytes
b86534f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
"""
Geometric Attention - 标准 Softmax 版本
基于论文 "The Neural Data Router" (Csordás et al., 2022)
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from typing import Optional


def geometric_attention_std(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    *,
    head_first: bool = False,
    seq_start: Optional[torch.Tensor] = None,
    sm_scale: Optional[float] = None,
    normalize: bool = True,
) -> torch.Tensor:
    """
    Standard-softmax variant of Geometric Attention.

    Based on "The Neural Data Router" (Csordás et al., 2022).

    Args:
        q: Query tensor [B, T, H, D], or [B, H, T, D] if head_first
        k: Key tensor [B, T, H, D], or [B, H, T, D] if head_first
        v: Value tensor [B, T, H, D], or [B, H, T, D] if head_first
        head_first: whether the head dimension precedes the time dimension
        seq_start: per-batch start offsets [B]; key positions before the
            offset are masked out for that batch element
        sm_scale: logit scaling factor; defaults to 1/sqrt(D)
        normalize: whether to L1-normalize the attention weights

    Returns:
        output: [B, T, H, D], or [B, H, T, D] if head_first
    """

    # Bring everything into head-first layout [B, H, T, D].
    if not head_first:
        q = rearrange(q, "b t h d -> b h t d")
        k = rearrange(k, "b t h d -> b h t d")
        v = rearrange(v, "b t h d -> b h t d")

    B, H, T_q, D = q.shape
    T_k = k.shape[2]

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    # Step 1: content-based logits, computed in fp32 for numerical stability.
    logits = torch.matmul(q.float(), k.float().transpose(-2, -1)) * sm_scale
    # logits: [B, H, T_q, T_k]

    # Step 2: mask the diagonal (a position may not attend to itself).
    if T_q == T_k:
        diag_mask = torch.eye(T_q, dtype=torch.bool, device=q.device)
        logits = logits.masked_fill(diag_mask[None, None, :, :], float('-inf'))

    # Step 3: mask key positions before seq_start.
    # BUG FIX: seq_start is [B] and must broadcast over the batch axis
    # (dim 0) of logits [B, H, T_q, T_k]. The previous indexing
    # seq_start[None, :, None, None] produced [1, B, 1, 1], aligning B with
    # the head axis — a shape error whenever B != H.
    if seq_start is not None:
        seq_mask = torch.arange(T_k, device=q.device)[None, None, None, :] < seq_start[:, None, None, None]
        logits = logits.masked_fill(seq_mask, float('-inf'))

    # Step 4: causal mask (optional).
    # NOTE: the geometric-attention paper is non-causal; uncomment if your task needs it.
    # P_SEQ = T_k - T_q
    # causal_mask = torch.triu(torch.ones((T_q, T_k), dtype=torch.bool, device=q.device), diagonal=P_SEQ + 1)
    # logits = logits.masked_fill(causal_mask[None, None, :, :], float('-inf'))

    # Step 5: geometric weighting (the core algorithm).
    attn_weights = geometric_weighting(logits, normalize=normalize)

    # Step 6: apply the attention weights to the values (cast back to v's dtype).
    out = torch.matmul(attn_weights.to(v.dtype), v)

    if not head_first:
        out = rearrange(out, "b h t d -> b t h d")

    return out


def geometric_weighting(
    logits: torch.Tensor,
    normalize: bool = True,
) -> torch.Tensor:
    """
    Compute geometric attention weights (fast approximate form).

    Implements Equation 7 of the paper:
    A[i,j] = P[i,j] * prod(1 - P[i,k]) over positions k closer to i than j,
    using a left-to-right cumulative sum in log space. This is an efficient
    approximation that avoids explicit distance-ordered loops.

    Args:
        logits: [B, H, T_q, T_k] attention logits
        normalize: whether to L1-normalize the weights over the key axis

    Returns:
        weights: [B, H, T_q, T_k] attention weights
    """
    # Unpack to assert the expected 4-D layout.
    B, H, T_q, T_k = logits.shape

    # Per-pair matching probabilities.
    P = torch.sigmoid(logits)  # [B, H, T_q, T_k]

    # Work in log space for numerical stability; eps guards log(0)
    # (e.g. where logits were masked to -inf and sigmoid gives 0).
    eps = 1e-10
    log_P = torch.log(P + eps)
    log_not_P = torch.log(1.0 - P + eps)

    # Cumulative log(1 - P) from the left. Shifting it right by one column
    # gives, at column j, the decay accumulated over columns [0, j-1].
    cum_decay = log_not_P.cumsum(dim=-1)
    shifted = torch.cat(
        [torch.zeros_like(cum_decay[..., :1]), cum_decay[..., :-1]], dim=-1
    )
    weights = torch.exp(log_P + shifted)

    # Column 0 has no left neighbours, so its weight is exactly P
    # (rebuilt via cat to avoid an in-place write).
    weights = torch.cat([P[..., :1], weights[..., 1:]], dim=-1)

    # Optional L1 normalization over the key axis.
    if normalize:
        weights = F.normalize(weights, p=1, dim=-1)

    # Fully-masked rows (-inf everywhere) turn into NaN; zero them out.
    return torch.nan_to_num(weights, 0.0)


def geometric_weighting_full(
    logits: torch.Tensor,
    normalize: bool = True,
) -> torch.Tensor:
    """
    Exact (reference) geometric weighting.

    Loop-based and quadratic in sequence length, so far slower than the
    cumsum approximation above — use only when full precision is required;
    prefer the fast version for training.

    Args:
        logits: [B, H, T_q, T_k] attention logits
        normalize: whether to L1-normalize the weights over the key axis

    Returns:
        weights: [B, H, T_q, T_k] attention weights
    """
    B, H, T_q, T_k = logits.shape

    P = torch.sigmoid(logits)
    eps = 1e-10
    log_P = torch.log(P + eps)
    log_not_P = torch.log(1.0 - P + eps)

    weights = torch.zeros_like(P)

    # Compute the geometric weight for every (i, j) pair.
    for i in range(T_q):
        for j in range(T_k):
            # Diagonal entries stay zero: they are masked out by the caller.
            if i == j:
                continue
            # Positions strictly between i and j are "closer" to i than j,
            # regardless of which side of i the key j lies on.
            lo, hi = (i, j) if i < j else (j, i)

            # prod_k (1 - P[i, k]) accumulated in log space.
            log_prod = sum(
                (log_not_P[:, :, i, m] for m in range(lo + 1, hi)), 0.0
            )

            # weights[i, j] = P[i, j] * prod_k (1 - P[i, k])
            weights[:, :, i, j] = torch.exp(log_P[:, :, i, j] + log_prod)

    if normalize:
        weights = F.normalize(weights, p=1, dim=-1)

    return torch.nan_to_num(weights, 0.0)