| """GateSelfAttention / GateCrossAttention(基于 PyTorch SDPA)。 |
| |
| 与 Design.md 一致: |
| - Q 经 Linear + Sigmoid 生成 D 维门控参数; |
| - 注意力得到的多头 V 合并后与门控逐元素相乘,再做 out_proj; |
| - 门控网络初始化输出 ≈ 1(bias 设大正值,weight ≈ 0),低 LR 缓慢步进。 |
| """ |
|
|
| from __future__ import annotations |
|
|
| import math |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from .pos_encoding import apply_rope |
|
|
|
|
class _MultiHeadProj(nn.Module):
    """Shared multi-head Q/K/V projection followed by a head split.

    Queries are projected from a ``dim_q``-wide input and keys/values from a
    ``dim_kv``-wide input, each to ``num_heads * head_dim`` channels. Those
    channels are then split into heads, giving tensors laid out as
    ``[B, num_heads, N, head_dim]``.
    """

    def __init__(
        self,
        dim_q: int,
        dim_kv: int,
        num_heads: int,
        head_dim: int,
        q_bias: bool = True,
        kv_bias: bool = True,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        width = num_heads * head_dim
        # The creation order (q, k, v) is kept fixed. It determines the
        # state_dict key order and how the default initializers draw from
        # the global RNG.
        self.q_proj = nn.Linear(dim_q, width, bias=q_bias)
        self.k_proj = nn.Linear(dim_kv, width, bias=kv_bias)
        self.v_proj = nn.Linear(dim_kv, width, bias=kv_bias)

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        """Reshape ``[B, N, H * head_dim]`` to ``[B, H, N, head_dim]`` (views only, no copy)."""
        batch, length, _ = t.shape
        per_head = t.view(batch, length, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def project_q(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` of shape ``[B, N, dim_q]`` to per-head queries ``[B, H, N, head_dim]``."""
        return self._split_heads(self.q_proj(x))

    def project_kv(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Project ``x`` of shape ``[B, N, dim_kv]`` to per-head ``(keys, values)``, each ``[B, H, N, head_dim]``."""
        keys = self._split_heads(self.k_proj(x))
        values = self._split_heads(self.v_proj(x))
        return keys, values
|
|
|
|
class _GateModule(nn.Module):
    """Gate generator: maps the query-source tensor to sigmoid gates of the same width.

    The projection weight starts at zero and its bias at ``init_bias``. At
    initialization every gate therefore equals ``sigmoid(init_bias)`` whatever
    the input; for the default 5.0 that is about 0.993. Gated attention thus
    starts out almost identical to plain attention, and the gates move away
    from 1 only as training updates the projection.
    """

    def __init__(self, dim: int, init_bias: float = 5.0) -> None:
        super().__init__()
        linear = nn.Linear(dim, dim)
        # Overwrite the default initialization in place. Doing this under
        # no_grad is what nn.init.zeros_ / nn.init.constant_ do as well.
        with torch.no_grad():
            linear.weight.zero_()
            linear.bias.fill_(init_bias)
        self.proj = linear

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``sigmoid(proj(x))``, with the same shape as ``x`` (last dim ``dim``)."""
        return self.proj(x).sigmoid()
|
|
|
|
class GateSelfAttention(nn.Module):
    """Gated self-attention built on PyTorch SDPA.

    The per-head attention output is merged back to ``[B, N, D]``, multiplied
    element-wise by a sigmoid gate computed from the block input ``x``, and
    then passed through ``out_proj``.

    3D RoPE can be restricted to the visual tokens. Pass ``visual_slice``
    together with ``rope_cos`` / ``rope_sin``; tokens outside the slice keep
    their un-rotated Q/K.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        dropout: float = 0.0,
        gate_init_bias: float = 5.0,
        q_bias: bool = True,
        kv_bias: bool = True,
    ) -> None:
        """
        Parameters
        ----------
        dim : model width D, used for both input and output.
        num_heads : number of attention heads; must be positive and divide ``dim``.
        dropout : attention dropout probability, applied only in training mode.
        gate_init_bias : initial bias of the gate projection (sigmoid(5.0) ≈ 0.993).
        q_bias, kv_bias : whether the Q and the K/V projections have a bias.

        Raises
        ------
        ValueError
            If ``num_heads`` is not positive or does not divide ``dim``.
        """
        super().__init__()
        # Use explicit checks rather than ``assert``. Asserts are stripped
        # under ``python -O``, which would let an invalid config through to a
        # confusing ``view`` shape error at forward time.
        if num_heads <= 0:
            raise ValueError(f"num_heads 必须为正整数,当前为 {num_heads}")
        if dim % num_heads != 0:
            raise ValueError("dim 必须能被 num_heads 整除")
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Not passed to SDPA: its default scale is already 1/sqrt(head_dim).
        # Kept as a public attribute for backward compatibility.
        self.scale = 1.0 / math.sqrt(self.head_dim)
        self.dropout_p = dropout

        self.proj = _MultiHeadProj(dim, dim, num_heads, self.head_dim, q_bias, kv_bias)
        self.gate = _GateModule(dim, init_bias=gate_init_bias)
        self.out_proj = nn.Linear(dim, dim, bias=True)

    def forward(
        self,
        x: torch.Tensor,
        rope_cos: Optional[torch.Tensor] = None,
        rope_sin: Optional[torch.Tensor] = None,
        visual_slice: Optional[tuple[int, int]] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Parameters
        ----------
        x : [B, N, D]
        rope_cos, rope_sin : [B, N_v, H, head_dim/2] or None.
            RoPE tables that are handed to ``apply_rope``.
        visual_slice : (start, end), the range of the visual tokens in the sequence.
            Q/K outside this slice are not rotated. RoPE is applied only when
            both ``rope_cos`` and ``visual_slice`` are given; otherwise it is
            skipped.
        attn_mask : optional mask forwarded unchanged to
            ``F.scaled_dot_product_attention``. A boolean mask uses True for
            "may attend"; a float mask is added to the scores. It must be
            broadcastable to ``[B, H, N, N]``. The default ``None`` keeps the
            original unmasked behavior.

        Returns
        -------
        Tensor of shape [B, N, D].
        """
        b, n, _ = x.shape
        q = self.proj.project_q(x)  # [B, H, N, head_dim]
        k, v = self.proj.project_kv(x)  # [B, H, N, head_dim] each

        if rope_cos is not None and visual_slice is not None:
            s, e = visual_slice
            # NOTE(review): q_v/k_v are [B, H, N_v, head_dim] here, while the
            # documented RoPE layout is [B, N_v, H, head_dim/2]. Confirm that
            # apply_rope expects this combination.
            q_v = q[:, :, s:e, :]
            k_v = k[:, :, s:e, :]
            q_v, k_v = apply_rope(q_v, k_v, rope_cos, rope_sin)
            # Reassemble the sequence: the prefix and suffix stay un-rotated.
            q = torch.cat([q[:, :, :s, :], q_v, q[:, :, e:, :]], dim=2)
            k = torch.cat([k[:, :, :s, :], k_v, k[:, :, e:, :]], dim=2)

        attn = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attn_mask,
            dropout_p=self.dropout_p if self.training else 0.0,
            is_causal=False,
        )
        # Merge the heads: [B, H, N, head_dim] -> [B, N, D].
        attn = attn.transpose(1, 2).contiguous().view(b, n, self.dim)

        # The gate is computed from the block input x, not from the projected Q.
        gate = self.gate(x)
        return self.out_proj(attn * gate)
|
|
|
|
class GateCrossAttention(nn.Module):
    """Gated cross-attention built on PyTorch SDPA.

    Q comes from the query tokens ``q_in``; K/V come from a context sequence
    ``kv_in`` (e.g. DINOv3 patch features). The merged multi-head output is
    multiplied element-wise by a sigmoid gate computed from ``q_in`` and then
    passed through ``out_proj``.
    """

    def __init__(
        self,
        dim_q: int,
        dim_kv: int,
        num_heads: int,
        dropout: float = 0.0,
        gate_init_bias: float = 5.0,
        q_bias: bool = True,
        kv_bias: bool = True,
    ) -> None:
        """
        Parameters
        ----------
        dim_q : width of the query tokens; also the attention and output width.
        dim_kv : width of the context tokens.
        num_heads : number of attention heads; must be positive and divide ``dim_q``.
        dropout : attention dropout probability, applied only in training mode.
        gate_init_bias : initial bias of the gate projection (sigmoid(5.0) ≈ 0.993).
        q_bias, kv_bias : whether the Q and the K/V projections have a bias.

        Raises
        ------
        ValueError
            If ``num_heads`` is not positive or does not divide ``dim_q``.
        """
        super().__init__()
        # Use explicit checks rather than ``assert``, which ``python -O`` strips.
        if num_heads <= 0:
            raise ValueError(f"num_heads 必须为正整数,当前为 {num_heads}")
        if dim_q % num_heads != 0:
            raise ValueError("dim_q 必须能被 num_heads 整除")
        self.dim_q = dim_q
        self.num_heads = num_heads
        self.head_dim = dim_q // num_heads
        self.dropout_p = dropout

        self.proj = _MultiHeadProj(dim_q, dim_kv, num_heads, self.head_dim, q_bias, kv_bias)
        self.gate = _GateModule(dim_q, init_bias=gate_init_bias)
        self.out_proj = nn.Linear(dim_q, dim_q, bias=True)

    def forward(
        self,
        q_in: torch.Tensor,
        kv_in: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Parameters
        ----------
        q_in : [B, N, dim_q] query tokens.
        kv_in : [B, M, dim_kv] context tokens.
        attn_mask : optional mask forwarded unchanged to
            ``F.scaled_dot_product_attention``, e.g. to ignore padded context
            tokens. A boolean mask uses True for "may attend"; a float mask is
            added to the scores. It must be broadcastable to ``[B, H, N, M]``.
            The default ``None`` keeps the original unmasked behavior.

        Returns
        -------
        Tensor of shape [B, N, dim_q].
        """
        b, n, _ = q_in.shape
        q = self.proj.project_q(q_in)  # [B, H, N, head_dim]
        k, v = self.proj.project_kv(kv_in)  # [B, H, M, head_dim] each

        attn = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attn_mask,
            dropout_p=self.dropout_p if self.training else 0.0,
            is_causal=False,
        )
        # Merge the heads: [B, H, N, head_dim] -> [B, N, dim_q].
        attn = attn.transpose(1, 2).contiguous().view(b, n, self.dim_q)
        # The gate is computed from the query-side input, not from the context.
        gate = self.gate(q_in)
        return self.out_proj(attn * gate)
|
|