"""GateSelfAttention / GateCrossAttention(基于 PyTorch SDPA)。 与 Design.md 一致: - Q 经 Linear + Sigmoid 生成 D 维门控参数; - 注意力得到的多头 V 合并后与门控逐元素相乘,再做 out_proj; - 门控网络初始化输出 ≈ 1(bias 设大正值,weight ≈ 0),低 LR 缓慢步进。 """ from __future__ import annotations import math from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F from .pos_encoding import apply_rope class _MultiHeadProj(nn.Module): """通用的多头 Q/K/V 投影 + reshape。""" def __init__( self, dim_q: int, dim_kv: int, num_heads: int, head_dim: int, q_bias: bool = True, kv_bias: bool = True, ) -> None: super().__init__() self.num_heads = num_heads self.head_dim = head_dim inner = num_heads * head_dim self.q_proj = nn.Linear(dim_q, inner, bias=q_bias) self.k_proj = nn.Linear(dim_kv, inner, bias=kv_bias) self.v_proj = nn.Linear(dim_kv, inner, bias=kv_bias) def project_q(self, x: torch.Tensor) -> torch.Tensor: b, n, _ = x.shape return self.q_proj(x).view(b, n, self.num_heads, self.head_dim).transpose(1, 2) def project_kv(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: b, n, _ = x.shape k = self.k_proj(x).view(b, n, self.num_heads, self.head_dim).transpose(1, 2) v = self.v_proj(x).view(b, n, self.num_heads, self.head_dim).transpose(1, 2) return k, v class _GateModule(nn.Module): """门控生成器:输入 Q 来源张量,输出 [B,N,D] 门控值,初始 ≈ 1。 bias 初始化为 ``init_bias``(默认 5.0 → sigmoid≈0.993),weight 初始化为 0。 这样初始状态等价于普通注意力,门控随训练缓慢偏离 1。 """ def __init__(self, dim: int, init_bias: float = 5.0) -> None: super().__init__() self.proj = nn.Linear(dim, dim) nn.init.zeros_(self.proj.weight) nn.init.constant_(self.proj.bias, init_bias) def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.sigmoid(self.proj(x)) class GateSelfAttention(nn.Module): """门控自注意力,使用 PyTorch SDPA。 支持仅对视觉 token 应用 3D RoPE:通过 ``visual_slice`` 指定切片。 """ def __init__( self, dim: int, num_heads: int, dropout: float = 0.0, gate_init_bias: float = 5.0, q_bias: bool = True, kv_bias: bool = True, ) -> None: super().__init__() assert dim % num_heads == 0, "dim 必须能被 num_heads 整除" self.dim = dim self.num_heads = num_heads self.head_dim = dim // num_heads self.scale = 1.0 / math.sqrt(self.head_dim) self.dropout_p = dropout self.proj = _MultiHeadProj(dim, dim, num_heads, self.head_dim, q_bias, kv_bias) self.gate = _GateModule(dim, init_bias=gate_init_bias) self.out_proj = nn.Linear(dim, dim, bias=True) def forward( self, x: torch.Tensor, rope_cos: Optional[torch.Tensor] = None, rope_sin: Optional[torch.Tensor] = None, visual_slice: Optional[tuple[int, int]] = None, ) -> torch.Tensor: """ 参数 ---- x : [B, N, D] rope_cos, rope_sin : [B, N_v, H, head_dim/2] 或 None visual_slice : (start, end),指定视觉 token 在序列中的范围。 非视觉 token 切片 Q/K 不做 RoPE。 """ b, n, _ = x.shape q = self.proj.project_q(x) # [B, H, N, Dh] k, v = self.proj.project_kv(x) # 仅对视觉切片应用 RoPE if rope_cos is not None and visual_slice is not None: s, e = visual_slice q_v = q[:, :, s:e, :] k_v = k[:, :, s:e, :] q_v, k_v = apply_rope(q_v, k_v, rope_cos, rope_sin) q = torch.cat([q[:, :, :s, :], q_v, q[:, :, e:, :]], dim=2) k = torch.cat([k[:, :, :s, :], k_v, k[:, :, e:, :]], dim=2) # SDPA:[B, H, N, Dh] attn = F.scaled_dot_product_attention( q, k, v, attn_mask=None, dropout_p=self.dropout_p if self.training else 0.0, is_causal=False, ) # 多头合并 attn = attn.transpose(1, 2).contiguous().view(b, n, self.dim) # 门控 ⊗ 多头合并后的 V,再 out_proj gate = self.gate(x) # 用 Q 的源(即 x)生成门控 out = self.out_proj(attn * gate) return out class GateCrossAttention(nn.Module): """门控交叉注意力,Q 来自 query token,K/V 来自 context(如 DINOv3 patch 特征)。""" def __init__( self, dim_q: int, dim_kv: int, num_heads: int, dropout: float = 0.0, gate_init_bias: float = 5.0, q_bias: bool = True, kv_bias: bool = True, ) -> None: super().__init__() assert dim_q % num_heads == 0, "dim_q 必须能被 num_heads 整除" self.dim_q = dim_q self.num_heads = num_heads self.head_dim = dim_q // num_heads self.dropout_p = dropout self.proj = _MultiHeadProj(dim_q, dim_kv, num_heads, self.head_dim, q_bias, kv_bias) self.gate = _GateModule(dim_q, init_bias=gate_init_bias) self.out_proj = nn.Linear(dim_q, dim_q, bias=True) def forward(self, q_in: torch.Tensor, kv_in: torch.Tensor) -> torch.Tensor: b, n, _ = q_in.shape q = self.proj.project_q(q_in) k, v = self.proj.project_kv(kv_in) attn = F.scaled_dot_product_attention( q, k, v, attn_mask=None, dropout_p=self.dropout_p if self.training else 0.0, is_causal=False, ) attn = attn.transpose(1, 2).contiguous().view(b, n, self.dim_q) gate = self.gate(q_in) return self.out_proj(attn * gate)