WJAD / src /wjad /modules /gate_attention.py
fuzirui's picture
Sync WJAD codebase
0cfefd2 verified
"""GateSelfAttention / GateCrossAttention(基于 PyTorch SDPA)。
与 Design.md 一致:
- Q 经 Linear + Sigmoid 生成 D 维门控参数;
- 注意力得到的多头 V 合并后与门控逐元素相乘,再做 out_proj;
- 门控网络初始化输出 ≈ 1(bias 设大正值,weight ≈ 0),低 LR 缓慢步进。
"""
from __future__ import annotations
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from .pos_encoding import apply_rope
class _MultiHeadProj(nn.Module):
    """Shared multi-head Q/K/V linear projections plus head reshaping.

    Queries are projected from ``dim_q`` input features, keys and values from
    ``dim_kv`` input features. All three projections produce
    ``num_heads * head_dim`` features, which are then split into heads.
    """

    def __init__(
        self,
        dim_q: int,
        dim_kv: int,
        num_heads: int,
        head_dim: int,
        q_bias: bool = True,
        kv_bias: bool = True,
    ) -> None:
        """Create the three projection layers.

        Args:
            dim_q: Feature size of the query source.
            dim_kv: Feature size of the key/value source.
            num_heads: Number of attention heads.
            head_dim: Feature size of each head.
            q_bias: Whether the query projection has a bias term.
            kv_bias: Whether the key and value projections have bias terms.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        width = num_heads * head_dim
        # Layers are created in q, k, v order so parameter registration order
        # (and therefore seeded initialisation) stays the same.
        self.q_proj = nn.Linear(dim_q, width, bias=q_bias)
        self.k_proj = nn.Linear(dim_kv, width, bias=kv_bias)
        self.v_proj = nn.Linear(dim_kv, width, bias=kv_bias)

    def _split_heads(self, flat: torch.Tensor, batch: int, tokens: int) -> torch.Tensor:
        """Reshape ``[B, N, H * Dh]`` into ``[B, H, N, Dh]``."""
        per_head = flat.view(batch, tokens, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def project_q(self, x: torch.Tensor) -> torch.Tensor:
        """Project a 3-D input ``[B, N, dim_q]`` to queries ``[B, H, N, Dh]``."""
        batch, tokens, _ = x.shape
        return self._split_heads(self.q_proj(x), batch, tokens)

    def project_kv(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Project a 3-D input ``[B, N, dim_kv]`` to ``(keys, values)``.

        Both returned tensors have layout ``[B, H, N, Dh]``.
        """
        batch, tokens, _ = x.shape
        keys = self._split_heads(self.k_proj(x), batch, tokens)
        values = self._split_heads(self.v_proj(x), batch, tokens)
        return keys, values
class _GateModule(nn.Module):
    """Sigmoid gate generator whose initial output is close to 1.

    A single ``Linear(dim, dim)`` followed by a sigmoid. The weight starts at
    zero and the bias at ``init_bias`` (default 5.0, sigmoid(5.0) ~= 0.993),
    so at initialisation the gate is the same constant for every input and
    multiplying by it leaves the gated signal almost unchanged.
    """

    def __init__(self, dim: int, init_bias: float = 5.0) -> None:
        """Build the gate projection.

        Args:
            dim: Input and output feature size of the gate.
            init_bias: Initial value for every bias entry of the projection.
        """
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        # Overwrite the default Linear initialisation in place, without
        # recording the writes in autograd.
        with torch.no_grad():
            self.proj.weight.zero_()
            self.proj.bias.fill_(init_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return gate values in (0, 1) with the same shape as ``x``."""
        logits = self.proj(x)
        return logits.sigmoid()
class GateSelfAttention(nn.Module):
    """Gated self-attention built on PyTorch SDPA.

    Q, K and V are all projected from the same input. The merged multi-head
    attention output is multiplied element-wise by a sigmoid gate computed
    from the input, then passed through ``out_proj``.

    RoPE can be applied to only a contiguous range of tokens (the "visual"
    tokens) selected with ``visual_slice``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        dropout: float = 0.0,
        gate_init_bias: float = 5.0,
        q_bias: bool = True,
        kv_bias: bool = True,
    ) -> None:
        """Build the projections, gate and output layer.

        Args:
            dim: Model width; must be divisible by ``num_heads``.
            num_heads: Number of attention heads.
            dropout: Attention dropout probability, used only in training mode.
            gate_init_bias: Initial bias of the gate projection.
            q_bias: Whether the query projection has a bias term.
            kv_bias: Whether the key and value projections have bias terms.
        """
        super().__init__()
        assert dim % num_heads == 0, "dim 必须能被 num_heads 整除"
        per_head = dim // num_heads
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = per_head
        # Stored for reference only; forward() relies on SDPA's default scaling.
        self.scale = 1.0 / math.sqrt(per_head)
        self.dropout_p = dropout
        self.proj = _MultiHeadProj(dim, dim, num_heads, per_head, q_bias, kv_bias)
        self.gate = _GateModule(dim, init_bias=gate_init_bias)
        self.out_proj = nn.Linear(dim, dim, bias=True)

    @staticmethod
    def _splice(full: torch.Tensor, middle: torch.Tensor, start: int, stop: int) -> torch.Tensor:
        """Return ``full`` with tokens ``[start:stop]`` (dim 2) replaced by ``middle``."""
        return torch.cat((full[:, :, :start, :], middle, full[:, :, stop:, :]), dim=2)

    def forward(
        self,
        x: torch.Tensor,
        rope_cos: Optional[torch.Tensor] = None,
        rope_sin: Optional[torch.Tensor] = None,
        visual_slice: Optional[tuple[int, int]] = None,
    ) -> torch.Tensor:
        """Run gated self-attention over ``x``.

        Args:
            x: Input of shape ``[B, N, D]``.
            rope_cos: RoPE cosine table for the visual tokens, or ``None``.
                The original note gives the shape as
                ``[B, N_v, H, head_dim/2]``; the exact contract is defined by
                ``apply_rope`` (not visible here; confirm there).
            rope_sin: RoPE sine table matching ``rope_cos``, or ``None``.
            visual_slice: ``(start, end)`` token range that receives RoPE.
                Q/K outside this range are left unrotated.

        Returns:
            Tensor of shape ``[B, N, D]``.

        Note:
            RoPE is applied only when both ``rope_cos`` and ``visual_slice``
            are given; if either is ``None`` it is skipped.
        """
        batch, seq_len, _ = x.shape
        queries = self.proj.project_q(x)  # [B, H, N, Dh]
        keys, values = self.proj.project_kv(x)

        # Rotate only the visual token range, then stitch the sequence back.
        use_rope = rope_cos is not None and visual_slice is not None
        if use_rope:
            start, stop = visual_slice
            rot_q, rot_k = apply_rope(
                queries[:, :, start:stop, :],
                keys[:, :, start:stop, :],
                rope_cos,
                rope_sin,
            )
            queries = self._splice(queries, rot_q, start, stop)
            keys = self._splice(keys, rot_k, start, stop)

        # Attention dropout is active only in training mode.
        drop_p = self.dropout_p if self.training else 0.0
        context = F.scaled_dot_product_attention(
            queries,
            keys,
            values,
            attn_mask=None,
            dropout_p=drop_p,
            is_causal=False,
        )

        # Merge heads: [B, H, N, Dh] -> [B, N, D].
        merged = context.transpose(1, 2).contiguous().view(batch, seq_len, self.dim)

        # The gate is generated from the query source, which here is x itself.
        gate = self.gate(x)
        return self.out_proj(merged * gate)
class GateCrossAttention(nn.Module):
    """Gated cross-attention built on PyTorch SDPA.

    Queries come from the query tokens; keys and values come from a separate
    context sequence (the original note gives DINOv3 patch features as an
    example). The merged attention output is multiplied element-wise by a
    sigmoid gate computed from the query tokens, then passed through
    ``out_proj``.
    """

    def __init__(
        self,
        dim_q: int,
        dim_kv: int,
        num_heads: int,
        dropout: float = 0.0,
        gate_init_bias: float = 5.0,
        q_bias: bool = True,
        kv_bias: bool = True,
    ) -> None:
        """Build the projections, gate and output layer.

        Args:
            dim_q: Feature size of the query tokens; must be divisible by
                ``num_heads``. Also the output feature size.
            dim_kv: Feature size of the context tokens.
            num_heads: Number of attention heads.
            dropout: Attention dropout probability, used only in training mode.
            gate_init_bias: Initial bias of the gate projection.
            q_bias: Whether the query projection has a bias term.
            kv_bias: Whether the key and value projections have bias terms.
        """
        super().__init__()
        assert dim_q % num_heads == 0, "dim_q 必须能被 num_heads 整除"
        per_head = dim_q // num_heads
        self.dim_q = dim_q
        self.num_heads = num_heads
        self.head_dim = per_head
        self.dropout_p = dropout
        self.proj = _MultiHeadProj(dim_q, dim_kv, num_heads, per_head, q_bias, kv_bias)
        self.gate = _GateModule(dim_q, init_bias=gate_init_bias)
        self.out_proj = nn.Linear(dim_q, dim_q, bias=True)

    def forward(self, q_in: torch.Tensor, kv_in: torch.Tensor) -> torch.Tensor:
        """Attend from ``q_in`` to ``kv_in``.

        Args:
            q_in: Query tokens, a 3-D tensor ``[B, N_q, dim_q]``.
            kv_in: Context tokens, a 3-D tensor ``[B, N_kv, dim_kv]``.

        Returns:
            Tensor of shape ``[B, N_q, dim_q]``.
        """
        batch, num_queries, _ = q_in.shape
        queries = self.proj.project_q(q_in)
        keys, values = self.proj.project_kv(kv_in)

        # Attention dropout is active only in training mode.
        drop_p = self.dropout_p if self.training else 0.0
        context = F.scaled_dot_product_attention(
            queries,
            keys,
            values,
            attn_mask=None,
            dropout_p=drop_p,
            is_causal=False,
        )

        # Merge heads: [B, H, N_q, Dh] -> [B, N_q, dim_q].
        merged = context.transpose(1, 2).contiguous().view(batch, num_queries, self.dim_q)

        # The gate is generated from the query source.
        gate = self.gate(q_in)
        return self.out_proj(merged * gate)