|
|
""" |
|
|
Geometric Attention - 标准 Softmax 版本 |
|
|
基于论文 "The Neural Data Router" (Csordás et al., 2022) |
|
|
""" |
|
|
|
|
|
import math |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from einops import rearrange |
|
|
from typing import Optional |
|
|
|
|
|
|
|
|
def geometric_attention_std(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    *,
    head_first: bool = False,
    seq_start: Optional[torch.Tensor] = None,
    sm_scale: Optional[float] = None,
    normalize: bool = True,
) -> torch.Tensor:
    """Geometric attention with sigmoid-based (softmax-free) weighting.

    Based on "The Neural Data Router" (Csordás et al., 2022): pairwise
    selection probabilities are combined by ``geometric_weighting`` rather
    than a softmax over the key axis.

    Args:
        q: Query tensor [B, T, H, D], or [B, H, T, D] if head_first.
        k: Key tensor, same layout as q.
        v: Value tensor, same layout as q.
        head_first: Whether the head axis precedes the time axis.
        seq_start: Optional [B] tensor of per-sequence start positions;
            key positions before the start are masked out (left padding).
        sm_scale: Logit scaling factor, defaults to 1/sqrt(D).
        normalize: Whether to L1-normalize the attention weights.

    Returns:
        Output tensor [B, T, H, D], or [B, H, T, D] if head_first.
    """
    if not head_first:
        q = rearrange(q, "b t h d -> b h t d")
        k = rearrange(k, "b t h d -> b h t d")
        v = rearrange(v, "b t h d -> b h t d")

    B, H, T_q, D = q.shape
    T_k = k.shape[2]

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    # Compute logits in float32 for numerical stability regardless of the
    # input dtype; weights are cast back to v's dtype before the final matmul.
    logits = torch.matmul(q.float(), k.float().transpose(-2, -1)) * sm_scale

    # Self-attention (the diagonal) is disallowed whenever query and key
    # lengths match.
    if T_q == T_k:
        diag_mask = torch.eye(T_q, dtype=torch.bool, device=q.device)
        logits = logits.masked_fill(diag_mask[None, None, :, :], float('-inf'))

    # Mask key positions that precede each sequence's start offset.
    if seq_start is not None:
        # BUGFIX: the per-batch offsets must broadcast along the batch axis,
        # i.e. as [B, 1, 1, T_k]. The previous `seq_start[None, :, None, None]`
        # produced [1, B, 1, T_k], placing the batch dimension on the head
        # axis of the [B, H, T_q, T_k] logits — valid only when B == H.
        seq_mask = (
            torch.arange(T_k, device=q.device)[None, None, None, :]
            < seq_start[:, None, None, None]
        )
        logits = logits.masked_fill(seq_mask, float('-inf'))

    attn_weights = geometric_weighting(logits, normalize=normalize)

    out = torch.matmul(attn_weights.to(v.dtype), v)

    if not head_first:
        out = rearrange(out, "b h t d -> b t h d")

    return out
|
|
|
|
|
|
|
|
def geometric_weighting(
    logits: torch.Tensor,
    normalize: bool = True,
) -> torch.Tensor:
    """Compute geometric attention weights from raw logits (vectorized).

    Left-to-right approximation of Equation 7 in "The Neural Data Router"
    (Csordás et al., 2022):

        A[i, j] = P[i, j] * prod_{k < j} (1 - P[i, k])

    where P = sigmoid(logits): each key's weight is its own selection
    probability times the probability that every key to its left was skipped.

    Args:
        logits: Attention logits of shape [B, H, T_q, T_k].
        normalize: Whether to L1-normalize the weights over the key axis.

    Returns:
        Attention weights of shape [B, H, T_q, T_k].
    """
    # Per-pair selection probabilities; the epsilon keeps the logs finite
    # even when a logit was masked to -inf (sigmoid == 0).
    probs = torch.sigmoid(logits)
    log_probs = torch.log(probs + 1e-10)
    log_skip = torch.log(1.0 - probs + 1e-10)

    # Running log-product of "skip" probabilities over keys 0..j inclusive.
    skip_prefix = log_skip.cumsum(dim=-1)

    # Shift right by one so column j sees the product over keys 0..j-1.
    # roll() wraps the last column into position 0; that column is
    # overwritten just below, so the wraparound never leaks into the result.
    shifted = skip_prefix.roll(1, dims=-1)
    out = torch.exp(log_probs + shifted)

    # Column 0 has no keys to its left — its weight is simply P[..., 0].
    out = torch.cat([probs[:, :, :, :1], out[:, :, :, 1:]], dim=-1)

    if normalize:
        out = F.normalize(out, p=1, dim=-1)

    # Rows where every key was masked normalize to NaN; zero them out.
    return torch.nan_to_num(out, 0.0)
|
|
|
|
|
|
|
|
def geometric_weighting_full(
    logits: torch.Tensor,
    normalize: bool = True,
) -> torch.Tensor:
    """Reference (loop-based) geometric attention weights.

    For every query i and key j != i computes

        A[i, j] = P[i, j] * prod_{k strictly between i and j} (1 - P[i, k])

    with P = sigmoid(logits). Diagonal entries (j == i) are left at zero.
    Quadratic Python loop — slow; prefer the vectorized
    ``geometric_weighting`` for training and use this only when maximum
    fidelity is required.

    Args:
        logits: Attention logits of shape [B, H, T_q, T_k].
        normalize: Whether to L1-normalize over the key axis.

    Returns:
        Attention weights of shape [B, H, T_q, T_k].
    """
    _, _, num_q, num_k = logits.shape

    # Epsilon keeps the logs finite for probabilities at exactly 0 or 1.
    probs = torch.sigmoid(logits)
    log_probs = torch.log(probs + 1e-10)
    log_skip = torch.log(1.0 - probs + 1e-10)

    out = torch.zeros_like(probs)

    for i in range(num_q):
        for j in range(num_k):
            if i == j:
                # Self-attention entry stays zero by construction.
                continue
            lo, hi = (i, j) if i < j else (j, i)
            # log P[i, j] plus log(1 - P[i, k]) for each key k strictly
            # between the query position and the key position.
            acc = log_probs[:, :, i, j]
            for mid in range(lo + 1, hi):
                acc = acc + log_skip[:, :, i, mid]
            out[:, :, i, j] = torch.exp(acc)

    if normalize:
        out = F.normalize(out, p=1, dim=-1)

    # Fully-masked rows normalize to NaN; map them to zero.
    return torch.nan_to_num(out, 0.0)