| | """ |
| | Geometric Attention - 标准 Softmax 版本 |
| | 基于论文 "The Neural Data Router" (Csordás et al., 2022) |
| | """ |
| |
|
| | import math |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from einops import rearrange |
| | from typing import Optional |
| |
|
| |
|
def geometric_attention_std(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    *,
    head_first: bool = False,
    seq_start: Optional[torch.Tensor] = None,
    sm_scale: Optional[float] = None,
    normalize: bool = True,
) -> torch.Tensor:
    """Apply geometric attention using dense (non-fused) tensor operations.

    Scaled dot-product logits are computed in float32, masked, and turned
    into attention weights by ``geometric_weighting``; the weights are then
    cast to ``v.dtype`` and applied to ``v``.

    Args:
        q: Query tensor, [B, T, H, D], or [B, H, T, D] if ``head_first``.
        k: Key tensor, [B, T, H, D], or [B, H, T, D] if ``head_first``.
        v: Value tensor, [B, T, H, D], or [B, H, T, D] if ``head_first``.
        head_first: Whether the head axis precedes the time axis in the
            inputs (and therefore in the output).
        seq_start: Optional tensor of shape [B]. For batch element ``b``,
            every key position ``< seq_start[b]`` is masked out.
        sm_scale: Scale applied to the logits; defaults to ``1 / sqrt(D)``.
        normalize: Forwarded to ``geometric_weighting``; whether each row of
            attention weights is L1-normalized.

    Returns:
        The attention output in the same layout as the inputs:
        [B, T, H, D], or [B, H, T, D] if ``head_first``.
    """
    if not head_first:
        q = rearrange(q, "b t h d -> b h t d")
        k = rearrange(k, "b t h d -> b h t d")
        v = rearrange(v, "b t h d -> b h t d")

    _, _, T_q, D = q.shape
    T_k = k.shape[2]

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    # Logits are always computed in float32, whatever the input dtype is.
    logits = torch.matmul(q.float(), k.float().transpose(-2, -1)) * sm_scale

    # A position never attends to itself. The diagonal is only well defined
    # when queries and keys cover the same positions.
    if T_q == T_k:
        diag_mask = torch.eye(T_q, dtype=torch.bool, device=q.device)
        logits = logits.masked_fill(diag_mask[None, None, :, :], float("-inf"))

    if seq_start is not None:
        key_pos = torch.arange(T_k, device=q.device)
        # seq_start is per batch element, so it must broadcast along the
        # batch axis (dim 0) -- not the head axis -- of the [B, H, T_q, T_k]
        # logits.
        seq_mask = key_pos[None, None, None, :] < seq_start[:, None, None, None]
        logits = logits.masked_fill(seq_mask, float("-inf"))

    attn_weights = geometric_weighting(logits, normalize=normalize)

    out = torch.matmul(attn_weights.to(v.dtype), v)

    if not head_first:
        out = rearrange(out, "b h t d -> b t h d")

    return out
| |
|
| |
|
def geometric_weighting(
    logits: torch.Tensor,
    normalize: bool = True,
) -> torch.Tensor:
    """Compute simplified geometric ("first match") attention weights.

    With ``P = sigmoid(logits)``, the weight of key ``j`` is::

        A[..., j] = P[..., j] * prod_{k < j} (1 - P[..., k])

    This is a simplified form of the geometric attention of Csordás et al.
    (2022): keys are scanned in index order along the last axis, not in
    order of distance from the query position.

    The computation is done in log space with exact log-sigmoids, so a key
    whose logit is ``-inf`` (masked) receives a weight of exactly zero.

    Args:
        logits: Attention logits, [B, H, T_q, T_k]. Any leading dimensions
            work; the scan runs along the last axis.
        normalize: If True, L1-normalize the weights along the last axis.

    Returns:
        Attention weights with the same shape as ``logits``; NaNs are
        replaced by 0.
    """
    # log(sigmoid(x)) = -softplus(-x) and log(1 - sigmoid(x)) = -softplus(x).
    # Unlike log(P + eps), these stay exact in the tails, and in particular
    # give log P = -inf for masked logits instead of log(eps).
    log_p = -F.softplus(-logits)
    log_not_p = -F.softplus(logits)

    # Exclusive prefix sum: entry j holds sum_{k < j} log(1 - P[k]).
    inclusive = log_not_p.cumsum(dim=-1)
    exclusive = torch.cat(
        [torch.zeros_like(inclusive[..., :1]), inclusive[..., :-1]],
        dim=-1,
    )

    weights = torch.exp(log_p + exclusive)

    if normalize:
        weights = F.normalize(weights, p=1, dim=-1)

    return torch.nan_to_num(weights, 0.0)
| |
|
| |
|
def _first_match_weights(logits_nearest_first: torch.Tensor) -> torch.Tensor:
    """Return ``P[n] * prod_{m < n} (1 - P[m])`` along the last axis, P = sigmoid(logits)."""
    # Exact log-sigmoids: log P = -softplus(-x), log(1 - P) = -softplus(x).
    log_p = -F.softplus(-logits_nearest_first)
    log_not_p = -F.softplus(logits_nearest_first)
    inclusive = log_not_p.cumsum(dim=-1)
    exclusive = torch.cat(
        [torch.zeros_like(inclusive[..., :1]), inclusive[..., :-1]],
        dim=-1,
    )
    return torch.exp(log_p + exclusive)


def geometric_weighting_full(
    logits: torch.Tensor,
    normalize: bool = True,
) -> torch.Tensor:
    """Compute position-aware geometric attention weights (slower variant).

    With ``P = sigmoid(logits)``, the weight of key ``j`` for query ``i`` is::

        A[i, j] = P[i, j] * prod_{k strictly between i and j} (1 - P[i, k])

    and ``A[i, i] = 0``. Query index ``i`` and key index ``j`` are treated as
    positions on the same axis. Key positions that do not exist (possible
    when T_q > T_k) contribute nothing to the product.

    NOTE(review): only keys on the same side of ``i`` as ``j`` count as
    "closer"; confirm this matches the intended definition from the paper.

    Each query row is processed with one cumulative scan per side, and the
    log-sigmoids are exact, so keys with ``-inf`` logits get exactly zero
    weight.

    Args:
        logits: Attention logits, [B, H, T_q, T_k].
        normalize: If True, L1-normalize the weights along the last axis.

    Returns:
        Attention weights with the same shape as ``logits``; NaNs are
        replaced by 0.
    """
    T_q, T_k = logits.shape[-2:]

    weights = torch.zeros_like(logits)

    for i in range(T_q):
        # Keys to the left of i, visited nearest-first: i-1, i-2, ..., 0.
        n_left = min(i, T_k)
        if n_left > 0:
            nearest_first = logits[..., i, :n_left].flip(-1)
            weights[..., i, :n_left] = _first_match_weights(nearest_first).flip(-1)

        # Keys to the right of i are already ordered nearest-first.
        if i + 1 < T_k:
            weights[..., i, i + 1:] = _first_match_weights(logits[..., i, i + 1:])

    if normalize:
        weights = F.normalize(weights, p=1, dim=-1)

    return torch.nan_to_num(weights, 0.0)