# Phillnet-2 / ImageGen / model / qwen_aligned_text.py
# ayjays132's picture
# Upload 478 files
# 101858b verified
"""Qwen3.5-aligned text refiner.
Mirrors the per-layer tensor shapes of ``Qwen3_5TextModel`` so that
``transplant_qwen_text_weights.py`` can load real Qwen3.5 weights into our
own modules. The mirror is intentionally minimal and architecture-faithful
where it can be (RMSNorm, SwiGLU MLP, GQA + rotary), and approximate where
Qwen3.5 uses an exotic op (Gated DeltaNet). Layers that mirror DeltaNet
keep the input/post norms and MLP weights (which transplant 1:1) but
replace the linear-attention mixing with an identity pass — letting the
6 standard ``self_attn`` layers carry the cross-token mixing.
This module shares activation singletons (``SHARED_SILU``) and keeps
weight names aligned with Qwen's ``layers.{i}.{...}`` paths so transplant
is a direct key map. It is dim-agnostic at construction time; defaults
match Qwen3.5-0.8B exactly.
"""
from __future__ import annotations
import math
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
# ---------------------------------------------------------------------------
# Shared activation singletons. Importing modules can grab these instead of
# instantiating their own; all share the same nn.Module instance so the
# adapter has one canonical SiLU rather than thirty.
# Each name below is bound once at import time; every module that stores one
# of them as an attribute therefore references the very same object.
# ---------------------------------------------------------------------------
# Stored as ``QwenSwiGLU.act`` and applied to the gate half in ``QwenGatedGQA``.
SHARED_SILU = nn.SiLU()
# Not referenced by the code visible in this file; exported for importers.
SHARED_GELU = nn.GELU()
# Not referenced by the code visible in this file; exported for importers.
SHARED_SIGMOID = nn.Sigmoid()
# ---------------------------------------------------------------------------
# Primitives
# ---------------------------------------------------------------------------
class QwenRMSNorm(nn.Module):
    """Root-mean-square norm with a learned per-channel scale and no bias.

    The only parameter is ``weight`` (length ``dim``, initialised to ones).
    Statistics are computed in float32 and the result is cast back to the
    dtype of the input.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` over its last dimension, then scale by ``weight``."""
        in_dtype = x.dtype
        upcast = x.float()
        # Mean of squares over the feature axis; eps keeps rsqrt finite.
        mean_sq = upcast.pow(2).mean(-1, keepdim=True)
        normed = upcast * torch.rsqrt(mean_sq + self.eps)
        scaled = normed * self.weight.float()
        return scaled.to(in_dtype)
def _build_inv_freq(rope_dim: int, base: float, device, dtype) -> torch.Tensor:
half = rope_dim // 2
return 1.0 / (base ** (torch.arange(0, half, device=device, dtype=dtype) / half))
def _apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
"""Apply rotary to the first ``cos.shape[-1] * 2`` dims of head_dim."""
rope_dim = cos.shape[-1] * 2
x_rope, x_pass = x[..., :rope_dim], x[..., rope_dim:]
x1, x2 = x_rope.chunk(2, dim=-1)
rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
return torch.cat([rotated, x_pass], dim=-1)
class QwenSwiGLU(nn.Module):
    """Bias-free gated MLP exposing ``gate_proj``, ``up_proj`` and ``down_proj``.

    Computes ``down_proj(act(gate_proj(x)) * up_proj(x))`` where ``act`` is the
    module-level ``SHARED_SILU`` instance.
    """

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()

        def bias_free(n_in: int, n_out: int) -> nn.Linear:
            return nn.Linear(n_in, n_out, bias=False)

        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        # Creation order is kept (gate, up, down) so weight initialisation
        # consumes the RNG in the same sequence.
        self.gate_proj = bias_free(hidden_size, intermediate_size)
        self.up_proj = bias_free(hidden_size, intermediate_size)
        self.down_proj = bias_free(intermediate_size, hidden_size)
        self.act = SHARED_SILU

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP to ``x`` along its last dimension."""
        gated = self.act(self.gate_proj(x))
        lifted = self.up_proj(x)
        return self.down_proj(gated * lifted)
class QwenGatedGQA(nn.Module):
    """Gated grouped-query self-attention with partial rotary embeddings.

    ``q_proj`` emits ``num_q_heads`` heads. The first ``num_q_heads // 2`` are
    used as attention queries; the remaining half is passed through SiLU and
    multiplied into the attention output before ``o_proj``, which therefore
    takes ``(num_q_heads // 2) * head_dim`` input features. Queries and keys
    get a per-head RMSNorm and a rotary embedding on their leading
    ``rope_dim`` channels. No causal mask is applied; only the optional
    key-padding mask.

    Weight shapes with the default arguments:
        q_proj: (num_q_heads * head_dim, hidden)        = (4096, 1024)
        k_proj: (num_kv_heads * head_dim, hidden)       = ( 512, 1024)
        v_proj: (num_kv_heads * head_dim, hidden)       = ( 512, 1024)
        o_proj: (hidden, (num_q_heads // 2) * head_dim) = (1024, 2048)
        q_norm: (head_dim,)                             = (256,)
        k_norm: (head_dim,)                             = (256,)

    NOTE(review): the public Hugging Face Qwen3-Next attention appears to
    split query/gate per head along the feature axis and to gate with
    sigmoid, whereas this mirror splits by head index and gates with SiLU.
    Confirm against the reference before relying on transplanted attention
    weights.

    Raises:
        AssertionError: if ``num_q_heads`` is odd or is not a multiple of
            ``num_kv_heads``.
        ValueError: if ``num_q_heads // 2`` is not a multiple of
            ``num_kv_heads``.
    """

    def __init__(
        self,
        hidden_size: int = 1024,
        num_q_heads: int = 16,
        num_kv_heads: int = 2,
        head_dim: int = 256,
        rope_dim: int = 64,
        rope_base: float = 1_000_000.0,
    ):
        super().__init__()
        assert num_q_heads % 2 == 0, "q heads must be even for Qwen3.5 gated split"
        assert num_q_heads % num_kv_heads == 0, "q heads must be a multiple of kv heads"
        num_attn_heads = num_q_heads // 2
        # Only half of the q heads attend, so that half is what must divide
        # evenly over the kv heads. Without this check a config such as
        # (num_q_heads=6, num_kv_heads=2) builds fine and then fails inside
        # forward() with an opaque shape error.
        if num_attn_heads % num_kv_heads != 0:
            raise ValueError(
                f"attention heads (num_q_heads // 2 = {num_attn_heads}) must be a "
                f"multiple of num_kv_heads ({num_kv_heads})"
            )
        self.hidden_size = hidden_size
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.num_attn_heads = num_attn_heads  # half routed through attention
        self.head_dim = head_dim
        self.rope_dim = rope_dim
        self.rope_base = rope_base
        self.kv_repeat = num_attn_heads // num_kv_heads
        q_dim = num_q_heads * head_dim
        kv_dim = num_kv_heads * head_dim
        attn_out = num_attn_heads * head_dim
        self.q_proj = nn.Linear(hidden_size, q_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, kv_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, kv_dim, bias=False)
        self.o_proj = nn.Linear(attn_out, hidden_size, bias=False)
        self.q_norm = QwenRMSNorm(head_dim)
        self.k_norm = QwenRMSNorm(head_dim)

    def _rotary(self, seq_len: int, device, dtype) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(cos, sin)`` tables shaped ``(1, 1, seq_len, rope_dim // 2)``.

        The angles are computed in at least float32 and only the final tables
        are cast to ``dtype``: half-precision types cannot represent larger
        integer positions (or the small inverse frequencies) exactly, which
        would make distinct positions share the same rotation.
        """
        wide = (torch.float32, torch.float64)
        compute_dtype = dtype if dtype in wide else torch.float32
        inv_freq = _build_inv_freq(self.rope_dim, self.rope_base, device, compute_dtype)
        pos = torch.arange(seq_len, device=device, dtype=compute_dtype)
        freqs = torch.einsum("i,j->ij", pos, inv_freq)  # (T, rope_dim/2)
        cos, sin = freqs.cos().to(dtype), freqs.sin().to(dtype)
        # Leading singleton axes broadcast over (B, H, T, D/2).
        return cos[None, None, :, :], sin[None, None, :, :]

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run gated self-attention over ``x``.

        Args:
            x: 3-D input unpacked as ``(batch, seq_len, hidden_size)``.
            attention_mask: optional 2-D ``(batch, seq_len)`` mask; entries
                equal to 0 mark key positions that must not be attended to.

        Returns:
            Tensor of shape ``(batch, seq_len, hidden_size)``.
        """
        bsz, seq_len, _ = x.shape
        q = self.q_proj(x)  # (B, T, num_q_heads * hd)
        k = self.k_proj(x)  # (B, T, num_kv * hd)
        v = self.v_proj(x)

        # Split the projected q heads into an attention half and a gate half.
        q = q.view(bsz, seq_len, self.num_q_heads, self.head_dim)
        q_attn = q[:, :, : self.num_attn_heads, :]
        q_gate = q[:, :, self.num_attn_heads :, :]
        q_attn = self.q_norm(q_attn)
        k = self.k_norm(k.view(bsz, seq_len, self.num_kv_heads, self.head_dim))
        v = v.view(bsz, seq_len, self.num_kv_heads, self.head_dim)

        # (B, H, T, D) for attention
        q_attn = q_attn.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        cos, sin = self._rotary(seq_len, x.device, q_attn.dtype)
        q_attn = _apply_rotary(q_attn, cos, sin)
        k = _apply_rotary(k, cos, sin)

        # Expand kv heads to match attn heads (GQA)
        if self.kv_repeat > 1:
            k = k.repeat_interleave(self.kv_repeat, dim=1)
            v = v.repeat_interleave(self.kv_repeat, dim=1)

        scale = 1.0 / math.sqrt(self.head_dim)
        attn_scores = torch.matmul(q_attn, k.transpose(-2, -1)) * scale
        if attention_mask is not None:
            # 1 = keep, 0 = mask, broadcast across heads and query positions.
            # Fill with the most negative finite value rather than -inf so a
            # row whose keys are ALL masked yields finite weights, not NaN.
            masked_keys = attention_mask[:, None, None, :].to(attn_scores.dtype) == 0
            fill_value = torch.finfo(attn_scores.dtype).min
            attn_scores = attn_scores.masked_fill(masked_keys, fill_value)
        attn = attn_scores.softmax(dim=-1)
        out = torch.matmul(attn, v)  # (B, H, T, D)
        out = out.transpose(1, 2).reshape(bsz, seq_len, self.num_attn_heads * self.head_dim)

        # Gate the attention output with SiLU of the gate half of q.
        q_gate = SHARED_SILU(q_gate).reshape(bsz, seq_len, self.num_attn_heads * self.head_dim)
        out = out * q_gate
        return self.o_proj(out)
class QwenAlignedBlock(nn.Module):
    """One pre-norm residual block in either of two flavours.

    ``layer_kind="attention"`` builds a ``QwenGatedGQA`` sub-layer followed by
    the MLP. ``layer_kind="deltanet"`` stands in for a linear-attention layer:
    it has no token-mixing op (``self_attn`` is ``None``), so only the MLP
    residual branch runs. Both kinds own ``input_layernorm``,
    ``post_attention_layernorm`` and ``mlp`` so those tensors keep the same
    names regardless of kind.

    Raises:
        ValueError: if ``layer_kind`` is neither ``"attention"`` nor
            ``"deltanet"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        layer_kind: str = "attention",
        num_q_heads: int = 16,
        num_kv_heads: int = 2,
        head_dim: int = 256,
        rope_dim: int = 64,
        rope_base: float = 1_000_000.0,
    ):
        super().__init__()
        known_kinds = frozenset(("attention", "deltanet"))
        if layer_kind not in known_kinds:
            raise ValueError(f"unknown layer_kind: {layer_kind}")
        self.layer_kind = layer_kind

        self.input_layernorm = QwenRMSNorm(hidden_size)
        self.post_attention_layernorm = QwenRMSNorm(hidden_size)
        # Only attention blocks own a mixing layer; deltanet blocks keep None.
        self.self_attn = (
            QwenGatedGQA(
                hidden_size=hidden_size,
                num_q_heads=num_q_heads,
                num_kv_heads=num_kv_heads,
                head_dim=head_dim,
                rope_dim=rope_dim,
                rope_base=rope_base,
            )
            if layer_kind == "attention"
            else None
        )
        self.mlp = QwenSwiGLU(hidden_size, intermediate_size)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the optional attention residual, then the MLP residual."""
        mixer = self.self_attn
        if mixer is not None:
            x = x + mixer(self.input_layernorm(x), attention_mask=attention_mask)
        # Deltanet blocks reach this point with x untouched, so they
        # contribute only the MLP branch.
        return x + self.mlp(self.post_attention_layernorm(x))
class QwenAlignedTextRefiner(nn.Module):
    """Stack of ``QwenAlignedBlock`` layers with a final norm, projection and gate.

    Takes hidden states whose last dimension is ``hidden_size`` and runs them
    through ``num_layers`` blocks. Block ``i`` is an attention block when ``i``
    is in ``attention_indices`` (default ``DEFAULT_ATTENTION_INDICES``) and a
    "deltanet" block otherwise; indices at or beyond ``num_layers`` simply
    never match a layer.

    The result is normalised by ``norm`` and mapped by ``proj``, which is
    ``nn.Identity`` when ``out_dim`` is ``None`` or equals ``hidden_size`` and
    a bias-free ``nn.Linear`` otherwise. A scalar ``gate`` parameter
    (initialised to 0) then controls the output; see ``forward``.
    """

    DEFAULT_ATTENTION_INDICES = (3, 7, 11, 15, 19, 23)

    def __init__(
        self,
        hidden_size: int = 1024,
        intermediate_size: int = 3584,
        num_layers: int = 24,
        attention_indices: Optional[Tuple[int, ...]] = None,
        num_q_heads: int = 16,
        num_kv_heads: int = 2,
        head_dim: int = 256,
        rope_dim: int = 64,
        rope_base: float = 1_000_000.0,
        out_dim: Optional[int] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_layers = num_layers
        self.attention_indices = tuple(
            self.DEFAULT_ATTENTION_INDICES if attention_indices is None else attention_indices
        )
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.rope_dim = rope_dim
        self.rope_base = rope_base

        attention_set = set(self.attention_indices)
        self.layers = nn.ModuleList(
            [
                QwenAlignedBlock(
                    hidden_size=hidden_size,
                    intermediate_size=intermediate_size,
                    layer_kind="attention" if i in attention_set else "deltanet",
                    num_q_heads=num_q_heads,
                    num_kv_heads=num_kv_heads,
                    head_dim=head_dim,
                    rope_dim=rope_dim,
                    rope_base=rope_base,
                )
                for i in range(num_layers)
            ]
        )
        self.norm = QwenRMSNorm(hidden_size)

        target_dim = hidden_size if out_dim is None else int(out_dim)
        self.out_dim = target_dim
        if target_dim == hidden_size:
            self.proj = nn.Identity()
        else:
            self.proj = nn.Linear(hidden_size, target_dim, bias=False)
        # Learned scalar gate; tanh(0) == 0, so at init the gated term vanishes.
        self.gate = nn.Parameter(torch.zeros(()))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Refine ``hidden_states`` and return a tensor with last dim ``out_dim``.

        Args:
            hidden_states: tensor whose last dimension equals ``hidden_size``.
            attention_mask: optional mask forwarded unchanged to every block.

        Returns:
            With an identity ``proj``: ``residual + tanh(gate) * (h - residual)``,
            which equals the input while ``gate`` is 0. With a linear ``proj``:
            ``h * (1 + tanh(gate))``, because the input cannot be added to a
            tensor of a different width.

        Raises:
            ValueError: if the last dimension of ``hidden_states`` is not
                ``hidden_size``.
        """
        if hidden_states.shape[-1] != self.hidden_size:
            raise ValueError(
                f"QwenAlignedTextRefiner expected dim {self.hidden_size}, got {hidden_states.shape[-1]}"
            )
        residual = hidden_states
        h = hidden_states
        for layer in self.layers:
            h = layer(h, attention_mask=attention_mask)
        h = self.norm(h)
        h = self.proj(h)
        if isinstance(self.proj, nn.Identity):
            # gate=0 init: refiner starts as identity; training learns to mix it in
            return residual + torch.tanh(self.gate) * (h - residual)
        # When projecting to a new dim, residual is not addable — return h directly.
        # gate is still a learnable scalar so downstream training can dampen this path.
        return h * (1.0 + torch.tanh(self.gate))

    def get_qwen_state_dict_map(self) -> List[Tuple[str, str]]:
        """Return ``(qwen_key, our_key)`` pairs for weight transplant.

        For every layer the two layernorm weights and the three MLP weights
        are listed; layers whose index is in ``attention_indices`` also list
        the six ``self_attn`` weights. The final entry maps ``norm.weight``.
        Both keys of each pair are the same string.
        """
        every_layer = (
            "input_layernorm.weight",
            "post_attention_layernorm.weight",
            "mlp.gate_proj.weight",
            "mlp.up_proj.weight",
            "mlp.down_proj.weight",
        )
        attention_only = (
            "self_attn.q_proj.weight",
            "self_attn.k_proj.weight",
            "self_attn.v_proj.weight",
            "self_attn.o_proj.weight",
            "self_attn.q_norm.weight",
            "self_attn.k_norm.weight",
        )
        # Built once up front instead of once per layer inside the loop.
        attention_set = set(self.attention_indices)

        pairs: List[Tuple[str, str]] = []
        for i in range(self.num_layers):
            ours = f"layers.{i}"
            qwen = f"layers.{i}"
            suffixes = every_layer + attention_only if i in attention_set else every_layer
            pairs.extend((f"{qwen}.{suffix}", f"{ours}.{suffix}") for suffix in suffixes)
        pairs.append(("norm.weight", "norm.weight"))
        return pairs