File size: 5,803 Bytes
b86534f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
"""
Geometric Attention - 标准 Softmax 版本
基于论文 "The Neural Data Router" (Csordás et al., 2022)
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from typing import Optional
def geometric_attention_std(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    *,
    head_first: bool = False,
    seq_start: Optional[torch.Tensor] = None,
    sm_scale: Optional[float] = None,
    normalize: bool = True,
) -> torch.Tensor:
    """
    Standard softmax-style entry point for Geometric Attention.

    Computes content-based logits, masks the diagonal (a position may not
    attend to itself) and any key positions before ``seq_start``, then runs
    ``geometric_weighting`` to turn logits into attention weights.

    Args:
        q: Query tensor ``[B, T, H, D]`` (or ``[B, H, T, D]`` if ``head_first``).
        k: Key tensor, same layout as ``q``.
        v: Value tensor, same layout as ``q``.
        head_first: If True, inputs/outputs use the ``[B, H, T, D]`` layout.
        seq_start: Optional per-batch start offsets ``[B]``; key positions
            ``< seq_start[b]`` are masked out for batch ``b``.
        sm_scale: Logit scaling factor; defaults to ``1 / sqrt(D)``.
        normalize: Whether to L1-normalize the attention weights.

    Returns:
        Output tensor ``[B, T, H, D]`` (or ``[B, H, T, D]`` if ``head_first``).
    """
    # Work internally in head-first layout [B, H, T, D].
    if not head_first:
        q = rearrange(q, "b t h d -> b h t d")
        k = rearrange(k, "b t h d -> b h t d")
        v = rearrange(v, "b t h d -> b h t d")
    B, H, T_q, D = q.shape
    T_k = k.shape[2]
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)
    # Step 1: content-based logits, computed in fp32 for numerical stability.
    logits = torch.matmul(q.float(), k.float().transpose(-2, -1)) * sm_scale
    # logits: [B, H, T_q, T_k]
    # Step 2: mask the diagonal (no attending to self); only meaningful when
    # query and key sequences are aligned.
    if T_q == T_k:
        diag_mask = torch.eye(T_q, dtype=torch.bool, device=q.device)
        logits = logits.masked_fill(diag_mask[None, None, :, :], float('-inf'))
    # Step 3: mask key positions before each batch's sequence start.
    if seq_start is not None:
        # BUGFIX: seq_start's batch dim must broadcast against dim 0 of
        # logits ([B, H, T_q, T_k]). The previous shape [1, B, 1, T_k] put
        # the batch dim over the head dim, which only worked when B == H
        # (or either was 1); [B, 1, 1, T_k] is correct for all shapes.
        seq_mask = (
            torch.arange(T_k, device=q.device)[None, None, None, :]
            < seq_start[:, None, None, None]
        )
        logits = logits.masked_fill(seq_mask, float('-inf'))
    # Step 4: causal mask intentionally omitted — the geometric-attention
    # paper is non-causal; uncomment below if your task needs causality.
    # P_SEQ = T_k - T_q
    # causal_mask = torch.triu(torch.ones((T_q, T_k), dtype=torch.bool, device=q.device), diagonal=P_SEQ + 1)
    # logits = logits.masked_fill(causal_mask[None, None, :, :], float('-inf'))
    # Step 5: geometric weighting (the core algorithm).
    attn_weights = geometric_weighting(logits, normalize=normalize)
    # Step 6: weighted sum over values, cast back to the value dtype.
    out = torch.matmul(attn_weights.to(v.dtype), v)
    if not head_first:
        out = rearrange(out, "b h t d -> b t h d")
    return out
def geometric_weighting(
    logits: torch.Tensor,
    normalize: bool = True,
) -> torch.Tensor:
    """
    Turn attention logits into geometric attention weights (fast version).

    Approximates Eq. 7 of the paper,
        A[i, j] = P[i, j] * prod(1 - P[i, k]) for k closer to i than j,
    by replacing the distance-ordered product with the running product of
    (1 - P) over all positions to the left of j, computed via cumsum.

    Args:
        logits: Attention logits ``[B, H, T_q, T_k]``.
        normalize: If True, L1-normalize each row of weights.

    Returns:
        Attention weights ``[B, H, T_q, T_k]``.
    """
    B, H, T_q, T_k = logits.shape
    # Per-pair matching probabilities.
    match_prob = torch.sigmoid(logits)  # [B, H, T_q, T_k]
    # Log-space terms; the epsilon keeps log() finite at exactly 0 or 1.
    log_match = torch.log(match_prob + 1e-10)
    log_miss = torch.log(1.0 - match_prob + 1e-10)
    # Running log-product of "miss" probabilities over positions to the
    # left — cumsum gives this in O(T) instead of an explicit O(T^2) loop.
    cum_log_miss = log_miss.cumsum(dim=-1)
    # Shift right by one so position j sees the product over [0, j-1].
    # (roll wraps column 0 around; it is overwritten just below.)
    weights = torch.exp(log_match + cum_log_miss.roll(1, dims=-1))
    # Column 0 has no left neighbours, so its weight is P itself; use
    # torch.cat rather than an in-place write to stay autograd-safe.
    weights = torch.cat(
        [match_prob[:, :, :, :1], weights[:, :, :, 1:]], dim=-1
    )
    if normalize:
        # L1-normalize each row; nan_to_num guards rows whose logits were
        # all -inf (everything masked out).
        weights = F.normalize(weights, p=1, dim=-1)
        weights = torch.nan_to_num(weights, 0.0)
    return weights
def geometric_weighting_full(
    logits: torch.Tensor,
    normalize: bool = True,
) -> torch.Tensor:
    """
    Exact (loop-based) geometric weighting.

    Computes A[i, j] = P[i, j] * prod(1 - P[i, k]) over every position k
    strictly between i and j. Quadratic in sequence length — prefer the
    cumulative ``geometric_weighting`` during training and reserve this
    version for when exactness matters.

    Args:
        logits: Attention logits ``[B, H, T_q, T_k]``.
        normalize: If True, L1-normalize each row of weights.

    Returns:
        Attention weights ``[B, H, T_q, T_k]``.
    """
    B, H, T_q, T_k = logits.shape
    P = torch.sigmoid(logits)
    log_P = torch.log(P + 1e-10)
    log_one_minus_P = torch.log(1.0 - P + 1e-10)
    weights = torch.zeros_like(P)
    for i in range(T_q):
        for j in range(T_k):
            if i == j:
                # Diagonal stays zero: self-attention is masked upstream.
                continue
            # Positions strictly between i and j are the ones "closer" to
            # i than j, whichever side j lies on.
            lo, hi = (i, j) if i < j else (j, i)
            inner = range(lo + 1, hi)
            # prod(1 - P[i, k]) in log-space; an empty range contributes 0.
            log_prod = sum(log_one_minus_P[:, :, i, k] for k in inner) if inner else 0.0
            # weights[i, j] = P[i, j] * prod(1 - P[i, k])
            weights[:, :, i, j] = torch.exp(log_P[:, :, i, j] + log_prod)
    if normalize:
        weights = F.normalize(weights, p=1, dim=-1)
        # Guard rows whose logits were all -inf (everything masked out).
        weights = torch.nan_to_num(weights, 0.0)
    return weights