"""Qwen3.5-aligned text refiner. Mirrors the per-layer tensor shapes of ``Qwen3_5TextModel`` so that ``transplant_qwen_text_weights.py`` can load real Qwen3.5 weights into our own modules. The mirror is intentionally minimal and architecture-faithful where it can be (RMSNorm, SwiGLU MLP, GQA + rotary), and approximate where Qwen3.5 uses an exotic op (Gated DeltaNet). Layers that mirror DeltaNet keep the input/post norms and MLP weights (which transplant 1:1) but replace the linear-attention mixing with an identity pass — letting the 6 standard ``self_attn`` layers carry the cross-token mixing. This module shares activation singletons (``SHARED_SILU``) and keeps weight names aligned with Qwen's ``layers.{i}.{...}`` paths so transplant is a direct key map. It is dim-agnostic at construction time; defaults match Qwen3.5-0.8B exactly. """ from __future__ import annotations import math from typing import List, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F # --------------------------------------------------------------------------- # Shared activation singletons. Importing modules can grab these instead of # instantiating their own; all share the same nn.Module instance so the # adapter has one canonical SiLU rather than thirty. 
# ---------------------------------------------------------------------------
SHARED_SILU = nn.SiLU()
SHARED_GELU = nn.GELU()
SHARED_SIGMOID = nn.Sigmoid()


# ---------------------------------------------------------------------------
# Primitives
# ---------------------------------------------------------------------------


class QwenRMSNorm(nn.Module):
    """Weight-only RMSNorm (no bias), eps default 1e-6.

    The normalisation is computed in float32 and the result is cast back to
    the input dtype, so half-precision inputs do not lose accuracy in the
    mean-of-squares reduction.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` over its last dimension and scale by ``weight``."""
        dtype = x.dtype
        x_f = x.float()
        x_f = x_f * torch.rsqrt(x_f.pow(2).mean(-1, keepdim=True) + self.eps)
        return (x_f * self.weight.float()).to(dtype)


def _build_inv_freq(rope_dim: int, base: float, device, dtype) -> torch.Tensor:
    """Return the ``rope_dim // 2`` inverse rotary frequencies ``base ** (-i / half)``."""
    half = rope_dim // 2
    return 1.0 / (base ** (torch.arange(0, half, device=device, dtype=dtype) / half))


def _apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply rotary to the first ``cos.shape[-1] * 2`` dims of head_dim.

    The rotated slice is split into two halves that are rotated against each
    other; any remaining trailing dims are passed through unchanged.
    """
    rope_dim = cos.shape[-1] * 2
    x_rope, x_pass = x[..., :rope_dim], x[..., rope_dim:]
    x1, x2 = x_rope.chunk(2, dim=-1)
    rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return torch.cat([rotated, x_pass], dim=-1)


class QwenSwiGLU(nn.Module):
    """Mirrors Qwen3.5 ``mlp`` layer: gate_proj, up_proj, down_proj (no bias)."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        # Shared module-level SiLU instance; it has no parameters.
        self.act = SHARED_SILU

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``down_proj(act(gate_proj(x)) * up_proj(x))``."""
        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))


class QwenGatedGQA(nn.Module):
    """Mirrors Qwen3.5 ``self_attn``: GQA + rotary + per-head q/k norm + a
    halved ``o_proj`` input (Qwen3.5 splits q into attn/gate halves).

    Shapes for Qwen3.5-0.8B exactly:
        q_proj: (q_heads*head_dim, hidden)       = (4096, 1024)
        k_proj: (kv_heads*head_dim, hidden)      = ( 512, 1024)
        v_proj: (kv_heads*head_dim, hidden)      = ( 512, 1024)
        o_proj: (hidden, (q_heads//2)*head_dim)  = (1024, 2048)
        q_norm: (head_dim,)                      = (256,)
        k_norm: (head_dim,)                      = (256,)

    The first ``num_q_heads // 2`` projected heads are used as attention
    queries; the remaining heads pass through SiLU and multiply the attention
    output element-wise before ``o_proj``. No causal mask is applied; only
    the optional key padding mask restricts attention.

    NOTE(review): the first-half/second-half head split and the SiLU gate are
    this module's interpretation of Qwen3.5 gated attention. The Hugging Face
    Qwen3-Next reference appears to split query/gate inside each head and to
    gate with a sigmoid -- confirm against the reference implementation before
    relying on transplanted attention weights.

    Raises:
        ValueError: if the head counts cannot form a valid gated GQA layout
            or ``rope_dim`` exceeds ``head_dim``.
    """

    def __init__(
        self,
        hidden_size: int = 1024,
        num_q_heads: int = 16,
        num_kv_heads: int = 2,
        head_dim: int = 256,
        rope_dim: int = 64,
        rope_base: float = 1_000_000.0,
    ):
        super().__init__()
        # Explicit raises instead of ``assert`` so validation survives ``python -O``.
        if num_kv_heads <= 0:
            raise ValueError("kv heads must be positive")
        if num_q_heads % 2 != 0:
            raise ValueError("q heads must be even for Qwen3.5 gated split")
        if num_q_heads % num_kv_heads != 0:
            raise ValueError("q heads must be a multiple of kv heads")
        num_attn_heads = num_q_heads // 2
        # Only half of the q heads attend, so it is that half (not the full q
        # head count) which must tile evenly over the kv heads; otherwise the
        # score matmul either fails or silently broadcasts over heads.
        if num_attn_heads == 0 or num_attn_heads % num_kv_heads != 0:
            raise ValueError(
                "attention heads (q heads // 2) must be a positive multiple of kv heads"
            )
        if rope_dim > head_dim:
            raise ValueError("rope_dim must not exceed head_dim")

        self.hidden_size = hidden_size
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.num_attn_heads = num_attn_heads  # half routed through attention
        self.head_dim = head_dim
        self.rope_dim = rope_dim
        self.rope_base = rope_base
        self.kv_repeat = self.num_attn_heads // num_kv_heads

        q_dim = num_q_heads * head_dim
        kv_dim = num_kv_heads * head_dim
        attn_out = self.num_attn_heads * head_dim

        self.q_proj = nn.Linear(hidden_size, q_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, kv_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, kv_dim, bias=False)
        self.o_proj = nn.Linear(attn_out, hidden_size, bias=False)
        self.q_norm = QwenRMSNorm(head_dim)
        self.k_norm = QwenRMSNorm(head_dim)

    def _rotary(self, seq_len: int, device, dtype) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build ``(cos, sin)`` tables of shape ``(1, 1, seq_len, rope_dim // 2)``.

        Angles are computed in float32 when ``dtype`` is a half-precision
        type and only the final tables are cast to ``dtype``: bfloat16
        represents integers exactly only up to 256, so building positions
        directly in it would corrupt the angles of longer sequences.
        """
        half_precision = dtype in (torch.float16, torch.bfloat16)
        compute_dtype = torch.float32 if half_precision else dtype
        inv_freq = _build_inv_freq(self.rope_dim, self.rope_base, device, compute_dtype)
        pos = torch.arange(seq_len, device=device, dtype=compute_dtype)
        freqs = torch.einsum("i,j->ij", pos, inv_freq)  # (T, rope_dim/2)
        cos, sin = freqs.cos().to(dtype), freqs.sin().to(dtype)
        # Leading singleton dims broadcast over (B, H, T, D/2).
        return cos[None, None, :, :], sin[None, None, :, :]

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run gated grouped-query attention over ``x``.

        Args:
            x: input of shape ``(B, T, hidden_size)``.
            attention_mask: optional ``(B, T)`` key mask, non-zero = keep,
                0 = masked. It is broadcast across heads and query positions.

        Returns:
            Tensor of shape ``(B, T, hidden_size)``.
        """
        bsz, seq_len, _ = x.shape
        gated_dim = self.num_attn_heads * self.head_dim

        q = self.q_proj(x).view(bsz, seq_len, self.num_q_heads, self.head_dim)
        k = self.k_proj(x).view(bsz, seq_len, self.num_kv_heads, self.head_dim)
        v = self.v_proj(x).view(bsz, seq_len, self.num_kv_heads, self.head_dim)

        # Split the projected q heads into the attention half and the gate half.
        q_attn = q[:, :, : self.num_attn_heads, :]
        q_gate = q[:, :, self.num_attn_heads :, :]

        # Per-head norms, then (B, H, T, D) layout for attention.
        q_attn = self.q_norm(q_attn).transpose(1, 2)
        k = self.k_norm(k).transpose(1, 2)
        v = v.transpose(1, 2)

        cos, sin = self._rotary(seq_len, x.device, q_attn.dtype)
        q_attn = _apply_rotary(q_attn, cos, sin)
        k = _apply_rotary(k, cos, sin)

        # Expand kv heads to match attn heads (GQA).
        if self.kv_repeat > 1:
            k = k.repeat_interleave(self.kv_repeat, dim=1)
            v = v.repeat_interleave(self.kv_repeat, dim=1)

        scale = 1.0 / math.sqrt(self.head_dim)
        attn_scores = torch.matmul(q_attn, k.transpose(-2, -1)) * scale
        if attention_mask is not None:
            key_mask = attention_mask[:, None, None, :].to(attn_scores.dtype)
            # Use the most negative finite value rather than -inf: masked keys
            # still receive zero weight, but a row whose keys are ALL masked
            # (e.g. an all-padding sample) yields a finite softmax, not NaN.
            fill_value = torch.finfo(attn_scores.dtype).min
            attn_scores = attn_scores.masked_fill(key_mask == 0, fill_value)
        attn = attn_scores.softmax(dim=-1)
        out = torch.matmul(attn, v)  # (B, H, T, D)
        out = out.transpose(1, 2).reshape(bsz, seq_len, gated_dim)

        # Gate the attention output with SiLU of the gate half of q.
        gate = SHARED_SILU(q_gate).reshape(bsz, seq_len, gated_dim)
        return self.o_proj(out * gate)


class QwenAlignedBlock(nn.Module):
    """Mirrors a Qwen3.5 transformer block.

    ``layer_kind="attention"`` mirrors the 6 standard ``self_attn`` layers
    (3, 7, 11, 15, 19, 23). ``layer_kind="deltanet"`` mirrors the 18
    ``linear_attn`` layers structurally but uses identity for the mix-token
    op so we do not depend on flash-linear-attention. Both kinds keep the
    Qwen-shaped MLP + norms so weight transplant is 1:1 for those tensors.

    Raises:
        ValueError: if ``layer_kind`` is neither ``"attention"`` nor ``"deltanet"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        layer_kind: str = "attention",
        num_q_heads: int = 16,
        num_kv_heads: int = 2,
        head_dim: int = 256,
        rope_dim: int = 64,
        rope_base: float = 1_000_000.0,
    ):
        super().__init__()
        if layer_kind not in {"attention", "deltanet"}:
            raise ValueError(f"unknown layer_kind: {layer_kind}")
        self.layer_kind = layer_kind
        self.input_layernorm = QwenRMSNorm(hidden_size)
        self.post_attention_layernorm = QwenRMSNorm(hidden_size)
        if layer_kind == "attention":
            self.self_attn = QwenGatedGQA(
                hidden_size=hidden_size,
                num_q_heads=num_q_heads,
                num_kv_heads=num_kv_heads,
                head_dim=head_dim,
                rope_dim=rope_dim,
                rope_base=rope_base,
            )
        else:
            # deltanet layers carry no attention weights at all
            self.self_attn = None
        self.mlp = QwenSwiGLU(hidden_size, intermediate_size)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the (optional) attention residual, then the MLP residual."""
        if self.self_attn is not None:
            h = self.input_layernorm(x)
            x = x + self.self_attn(h, attention_mask=attention_mask)
        # deltanet layers contribute only their MLP (mix happens at the
        # attention layers; this gives a real residual transformer signal);
        # their input_layernorm is kept for weight transplant but is unused here.
        h = self.post_attention_layernorm(x)
        x = x + self.mlp(h)
        return x


class QwenAlignedTextRefiner(nn.Module):
    """Stack of Qwen-aligned blocks.

    Designed to sit on top of a host text encoder's hidden states and produce
    a Qwen-conditioned representation at the same hidden dim.

    The block layout mirrors Qwen3.5-0.8B: ``num_layers=24`` with attention at
    every 4th position (indices 3, 7, 11, 15, 19, 23), but is configurable so
    smaller refiners can be transplanted from a Qwen subset. Entries of
    ``attention_indices`` that are ``>= num_layers`` simply select no layer.

    Outputs are projected to ``out_dim`` (defaults to hidden_size) via a final
    ``norm`` + ``proj`` so the refiner can plug into any downstream
    conditioning bridge.
    """

    DEFAULT_ATTENTION_INDICES = (3, 7, 11, 15, 19, 23)

    def __init__(
        self,
        hidden_size: int = 1024,
        intermediate_size: int = 3584,
        num_layers: int = 24,
        attention_indices: Optional[Tuple[int, ...]] = None,
        num_q_heads: int = 16,
        num_kv_heads: int = 2,
        head_dim: int = 256,
        rope_dim: int = 64,
        rope_base: float = 1_000_000.0,
        out_dim: Optional[int] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_layers = num_layers
        self.attention_indices = tuple(
            self.DEFAULT_ATTENTION_INDICES if attention_indices is None else attention_indices
        )
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.rope_dim = rope_dim
        self.rope_base = rope_base

        attention_set = set(self.attention_indices)
        self.layers = nn.ModuleList(
            [
                QwenAlignedBlock(
                    hidden_size=hidden_size,
                    intermediate_size=intermediate_size,
                    layer_kind="attention" if i in attention_set else "deltanet",
                    num_q_heads=num_q_heads,
                    num_kv_heads=num_kv_heads,
                    head_dim=head_dim,
                    rope_dim=rope_dim,
                    rope_base=rope_base,
                )
                for i in range(num_layers)
            ]
        )
        self.norm = QwenRMSNorm(hidden_size)

        target_dim = hidden_size if out_dim is None else int(out_dim)
        self.out_dim = target_dim
        if target_dim == hidden_size:
            self.proj = nn.Identity()
        else:
            self.proj = nn.Linear(hidden_size, target_dim, bias=False)
        self.gate = nn.Parameter(torch.zeros(()))  # learned residual gate, init 0 (identity)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Refine ``hidden_states`` through the block stack.

        Args:
            hidden_states: tensor whose last dim equals ``hidden_size``.
            attention_mask: optional key mask forwarded to every block.

        Returns:
            With an identity ``proj``: ``residual + tanh(gate) * (h - residual)``,
            which equals the input while ``gate`` is 0. With a linear ``proj``:
            ``h * (1 + tanh(gate))`` at ``out_dim``.

        Raises:
            ValueError: if the last dim of ``hidden_states`` is not ``hidden_size``.
        """
        if hidden_states.shape[-1] != self.hidden_size:
            raise ValueError(
                f"QwenAlignedTextRefiner expected dim {self.hidden_size}, got {hidden_states.shape[-1]}"
            )
        residual = hidden_states
        h = hidden_states
        for layer in self.layers:
            h = layer(h, attention_mask=attention_mask)
        h = self.norm(h)
        h = self.proj(h)
        if isinstance(self.proj, nn.Identity):
            # gate=0 init: refiner starts as identity; training learns to mix it in
            return residual + torch.tanh(self.gate) * (h - residual)
        # When projecting to a new dim, residual is not addable -- return h directly.
        # gate is still a learnable scalar so downstream training can dampen this path.
        return h * (1.0 + torch.tanh(self.gate))

    def get_qwen_state_dict_map(self) -> List[Tuple[str, str]]:
        """Return list of (qwen_key, our_key) pairs for transplant.

        Only includes tensors whose shape matches between Qwen3.5 and us:
        norms and MLP weights for every layer, attention weights for layers
        listed in ``attention_indices``, plus the final ``norm.weight``.
        ``proj`` and ``gate`` are not part of the map.
        """
        per_layer_suffixes = (
            "input_layernorm.weight",
            "post_attention_layernorm.weight",
            "mlp.gate_proj.weight",
            "mlp.up_proj.weight",
            "mlp.down_proj.weight",
        )
        attention_suffixes = (
            "self_attn.q_proj.weight",
            "self_attn.k_proj.weight",
            "self_attn.v_proj.weight",
            "self_attn.o_proj.weight",
            "self_attn.q_norm.weight",
            "self_attn.k_norm.weight",
        )
        # Built once; membership is tested for every layer below.
        attention_set = set(self.attention_indices)

        pairs: List[Tuple[str, str]] = []
        for i in range(self.num_layers):
            ours = f"layers.{i}"
            qwen = f"layers.{i}"
            suffixes = per_layer_suffixes
            if i in attention_set:
                suffixes = per_layer_suffixes + attention_suffixes
            pairs.extend((f"{qwen}.{suffix}", f"{ours}.{suffix}") for suffix in suffixes)
        pairs.append(("norm.weight", "norm.weight"))
        return pairs