| """ |
| Geometric Attention - 标准 Softmax 版本 |
| 基于论文 "The Neural Data Router" (Csordás et al., 2022) |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from einops import rearrange |
| from typing import Optional |
|
|
|
|
def geometric_attention_std(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    *,
    head_first: bool = False,
    seq_start: Optional[torch.Tensor] = None,
    sm_scale: Optional[float] = None,
    normalize: bool = True,
) -> torch.Tensor:
    """
    Geometric attention (sigmoid/stick-breaking weights instead of softmax),
    after "The Neural Data Router" (Csordás et al., 2022).

    Scaled q·k scores are converted to weights by :func:`geometric_weighting`
    and used to mix the values.

    Args:
        q: Query tensor ``[B, T, H, D]`` (``[B, H, T, D]`` if ``head_first``).
        k: Key tensor, same layout as ``q``.
        v: Value tensor, same layout as ``q``.
        head_first: If True, inputs/outputs carry the head dim before time.
        seq_start: Optional ``[B]`` per-batch start offsets; key positions
            before the offset are masked out.
        sm_scale: Score scale factor; defaults to ``1/sqrt(D)``.
        normalize: Whether to L1-normalize weights over the key dimension.

    Returns:
        Output tensor with the same layout as the inputs.
    """
    if not head_first:
        # [B, T, H, D] -> [B, H, T, D] (pure dim permutation; no data change)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

    B, H, T_q, D = q.shape
    T_k = k.shape[2]

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    # Compute scores in float32 for numerical stability of the log/sigmoid
    # math inside geometric_weighting.
    logits = torch.matmul(q.float(), k.float().transpose(-2, -1)) * sm_scale

    # Mask self-attention (the diagonal) when query/key lengths match.
    if T_q == T_k:
        diag_mask = torch.eye(T_q, dtype=torch.bool, device=q.device)
        logits = logits.masked_fill(diag_mask[None, None, :, :], float('-inf'))

    # Mask key positions before each sequence's start offset.
    if seq_start is not None:
        # BUGFIX: seq_start must broadcast along the BATCH dim (dim 0) of
        # logits [B, H, T_q, T_k]. The previous seq_start[None, :, None, None]
        # placed B on dim 1, colliding with the head dim — it only worked
        # when B == 1, and silently misapplied offsets when B == H.
        seq_mask = (
            torch.arange(T_k, device=q.device)[None, None, None, :]
            < seq_start[:, None, None, None]
        )
        logits = logits.masked_fill(seq_mask, float('-inf'))

    attn_weights = geometric_weighting(logits, normalize=normalize)

    out = torch.matmul(attn_weights.to(v.dtype), v)

    if not head_first:
        out = out.transpose(1, 2)

    return out
|
|
|
|
def geometric_weighting(
    logits: torch.Tensor,
    normalize: bool = True,
) -> torch.Tensor:
    """
    Simplified (left-to-right) geometric attention weights.

    Stick-breaking product over key positions ordered from the left, with
    P = sigmoid(logits):

        A[i, j] = P[i, j] * prod_{k < j} (1 - P[i, k])

    NOTE(review): this orders keys strictly left-to-right, not by distance
    to the query as Eq. 7 of the paper describes — see
    geometric_weighting_full for the distance-based reference loop.

    Args:
        logits: ``[B, H, T_q, T_k]`` attention logits.
        normalize: If True, L1-normalize weights over the key dimension.

    Returns:
        ``[B, H, T_q, T_k]`` attention weights.
    """
    # Numerically stable log-sigmoid identities:
    #   log P       = logsigmoid(x)
    #   log (1 - P) = logsigmoid(-x)
    # This avoids the precision loss of log(sigmoid(x) + 1e-10) for saturated
    # logits and maps fully masked (-inf) logits to exactly zero weight
    # instead of a spurious ~1e-10.
    log_P = F.logsigmoid(logits)
    log_one_minus_P = F.logsigmoid(-logits)

    # Cumulative log "survival" mass of all keys up to and including j.
    log_decay_left = log_one_minus_P.cumsum(dim=-1)

    # Shift right by one so position j sees the product over k < j only;
    # position 0 has no left neighbours (decay term 0). Padding replaces the
    # previous roll(1) trick, which wrapped the last column around and then
    # had to patch column 0 back in.
    log_decay = F.pad(log_decay_left[..., :-1], (1, 0))
    weights = torch.exp(log_P + log_decay)

    if normalize:
        weights = F.normalize(weights, p=1, dim=-1)

    # Guard rows where every key is masked (all-zero weight rows).
    weights = torch.nan_to_num(weights, 0.0)

    return weights
|
|
|
|
def geometric_weighting_full(
    logits: torch.Tensor,
    normalize: bool = True,
) -> torch.Tensor:
    """
    Reference (loop-based) geometric weighting.

    For each query i and key j != i, with P = sigmoid(logits):

        A[i, j] = P[i, j] * prod_{k strictly between i and j} (1 - P[i, k])

    The diagonal (j == i) receives zero weight.

    NOTE(review): Eq. 7 of the paper defines the product over all keys
    *closer to i than j* — i.e. on either side of i — whereas this loop only
    decays over keys lying between i and j. Confirm which set is intended
    before relying on this as ground truth; also note the fast
    geometric_weighting uses a different (left-to-right) ordering.

    Slow (Python double loop over positions); prefer geometric_weighting
    during training.

    Args:
        logits: ``[B, H, T_q, T_k]`` attention logits.
        normalize: If True, L1-normalize weights over the key dimension.

    Returns:
        ``[B, H, T_q, T_k]`` attention weights.
    """
    T_q, T_k = logits.shape[-2:]

    # Stable log-sigmoid identities (see geometric_weighting for rationale):
    # log P = logsigmoid(x), log(1 - P) = logsigmoid(-x).
    log_P = F.logsigmoid(logits)
    log_one_minus_P = F.logsigmoid(-logits)

    weights = torch.zeros_like(log_P)

    for i in range(T_q):
        for j in range(T_k):
            if i == j:
                # Self position keeps its zero weight.
                continue
            # Open interval of key positions strictly between i and j.
            lo, hi = (i + 1, j) if i < j else (j + 1, i)
            # Tensor slice sum replaces the former Python-level
            # sum(generator): same values, C-speed reduction. An empty
            # slice sums to 0.0, i.e. adjacent keys get no decay.
            log_prod = log_one_minus_P[:, :, i, lo:hi].sum(dim=-1)
            weights[:, :, i, j] = torch.exp(log_P[:, :, i, j] + log_prod)

    if normalize:
        weights = F.normalize(weights, p=1, dim=-1)

    # Guard rows where every key is masked (all-zero weight rows).
    weights = torch.nan_to_num(weights, 0.0)

    return weights