""" Language Diffusion Transformer (DiT for text). Public open-source version: - pure PyTorch only - state_dict key layout kept compatible with the internal model """ import math from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F from torch._dynamo import disable as dynamo_disable class FallbackRMSNorm(nn.Module): """Minimal RMSNorm implementation used by the public model.""" def __init__(self, hidden_size: int, eps: float = 1e-6): super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.eps = eps def forward(self, x: torch.Tensor) -> torch.Tensor: x_dtype = x.dtype x_float = x.float() norm = torch.rsqrt(x_float.pow(2).mean(dim=-1, keepdim=True) + self.eps) out = x_float * norm return (out.to(dtype=x_dtype) * self.weight).to(dtype=x_dtype) def _rotate_half(x: torch.Tensor) -> torch.Tensor: x1, x2 = x.chunk(2, dim=-1) return torch.cat((-x2, x1), dim=-1) def _rope_cos_sin( seqlen: int, dim: int, theta: float, *, device: torch.device, dtype: torch.dtype, ) -> tuple[torch.Tensor, torch.Tensor]: inv_freq = 1.0 / ( theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim) ) t = torch.arange(seqlen, device=device, dtype=torch.float32) freqs = torch.outer(t, inv_freq) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos().to(dtype=dtype)[None, :, None, :] sin = emb.sin().to(dtype=dtype)[None, :, None, :] return cos, sin class TokenEmbedding(nn.Module): """Token embedding (untied from output).""" def __init__(self, vocab_size: int, dim: int): super().__init__() self.embed = nn.Embedding(vocab_size, dim) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.embed(x) class TimestepEmbedding(nn.Module): """Sinusoidal timestep embedding -> conditioning vector.""" def __init__(self, cond_dim: int, freq_dim: int = 256): super().__init__() self.freq_dim = freq_dim self.mlp = nn.Sequential( nn.Linear(freq_dim, cond_dim, bias=True), nn.SiLU(), nn.Linear(cond_dim, cond_dim, bias=True), ) def forward(self, t: torch.Tensor) -> torch.Tensor: if t.ndim == 2 and t.shape[1] == 1: t = t.squeeze(-1) half = self.freq_dim // 2 freqs = torch.exp( -math.log(10000) * torch.arange(half, device=t.device, dtype=torch.float32) / half ) args = t[:, None].to(dtype=torch.float32) * freqs[None] embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) embed = embed.to(dtype=self.mlp[0].weight.dtype) return self.mlp(embed) class RotaryEmbedding(nn.Module): """Pure PyTorch RoPE implementation.""" def __init__(self, dim: int, max_seq_len: int = 4096, theta: float = 10000.0): super().__init__() self.dim = int(dim) self.theta = float(theta) self.max_seq_len = max_seq_len def apply_bshd(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: cos, sin = _rope_cos_sin( q.shape[1], q.shape[-1], self.theta, device=q.device, dtype=q.dtype, ) q = (q * cos) + (_rotate_half(q) * sin) k = (k * cos) + (_rotate_half(k) * sin) return q, k class Attention(nn.Module): """ Multi-head attention with expanded attention dimension. 
    hidden_size -> attn_dim for Q,K,V -> hidden_size
    """

    def __init__(
        self,
        hidden_size: int,
        attn_dim: int,
        num_heads: int,
        head_dim: int = 128,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.attn_dim = attn_dim
        self.attn_drop = float(attn_drop)
        self.qkv = nn.Linear(hidden_size, attn_dim * 3, bias=False)
        self.proj = nn.Linear(attn_dim, hidden_size, bias=False)
        self.proj_drop = nn.Dropout(proj_drop)

    @dynamo_disable
    def forward(
        self,
        x: torch.Tensor,
        rope: Optional[RotaryEmbedding] = None,
        pack: Optional[object] = None,
    ) -> torch.Tensor:
        if pack is not None:
            raise RuntimeError("Packed attention is not included in the public torch-only model.")
        bsz, seqlen, _ = x.shape
        qkv = self.qkv(x).reshape(bsz, seqlen, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)
        if rope is not None:
            q, k = rope.apply_bshd(q, k)
        qh = q.permute(0, 2, 1, 3).contiguous()
        kh = k.permute(0, 2, 1, 3).contiguous()
        vh = v.permute(0, 2, 1, 3).contiguous()
        out = F.scaled_dot_product_attention(
            qh,
            kh,
            vh,
            dropout_p=self.attn_drop if self.training else 0.0,
            is_causal=False,
        )
        x = out.permute(0, 2, 1, 3).contiguous().reshape(bsz, seqlen, self.attn_dim)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class SwiGLU(nn.Module):
    """SwiGLU FFN with custom intermediate dimension."""

    def __init__(self, hidden_size: int, ffn_dim: int, dropout: float = 0.0):
        super().__init__()
        self.w12 = nn.Linear(hidden_size, 2 * ffn_dim, bias=False)
        self.w3 = nn.Linear(ffn_dim, hidden_size, bias=False)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = self.w12(x).chunk(2, dim=-1)
        return self.w3(self.drop(F.silu(x1) * x2))


def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """AdaLN modulation."""
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


class DiTBlock(nn.Module):
    """
    DiT transformer block with AdaLN-Zero modulation.
    Pre-norm: RMSNorm -> modulate -> Attention/FFN -> gate
    """

    def __init__(
        self,
        hidden_size: int,
        attn_dim: int,
        ffn_dim: int,
        num_heads: int,
        head_dim: int = 128,
        cond_dim: int = 256,
        attn_drop: float = 0.0,
        drop: float = 0.0,
    ):
        super().__init__()
        self.norm1 = FallbackRMSNorm(hidden_size, eps=1e-6)
        self.attn = Attention(hidden_size, attn_dim, num_heads, head_dim, attn_drop, drop)
        self.norm2 = FallbackRMSNorm(hidden_size, eps=1e-6)
        self.mlp = SwiGLU(hidden_size, ffn_dim, dropout=drop)
        self.adaLN = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cond_dim, 6 * hidden_size, bias=True),
        )

    def forward(
        self,
        x: torch.Tensor,
        c: torch.Tensor,
        rope: Optional[RotaryEmbedding] = None,
        pack: Optional[object] = None,
    ) -> torch.Tensor:
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN(c).chunk(6, dim=-1)
        x = x + gate_msa.unsqueeze(1) * self.attn(
            modulate(self.norm1(x), shift_msa, scale_msa),
            rope,
            pack=pack,
        )
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x


class FinalAdaLN(nn.Module):
    """Final AdaLN block that produces prelogits."""

    def __init__(self, hidden_size: int, cond_dim: int):
        super().__init__()
        self.norm = FallbackRMSNorm(hidden_size, eps=1e-6)
        self.adaLN = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cond_dim, 2 * hidden_size, bias=True),
        )

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        shift, scale = self.adaLN(c).chunk(2, dim=-1)
        return modulate(self.norm(x), shift, scale)


class LangDiT(nn.Module):
    """Language Diffusion Transformer."""

    def __init__(
        self,
        vocab_size: int = 64512,
        hidden_size: int = 2048,
        attn_dim: int = 3072,
        ffn_dim: int = 7168,
        depth: int = 48,
        num_heads: int = 24,
        head_dim: int = 128,
        max_seq_len: int = 4096,
        timestep_freq_dim: int = 256,
        rope_theta: float = 10000.0,
        cond_dim: int = 256,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.token_embed = TokenEmbedding(vocab_size, hidden_size)
        self.time_embed = TimestepEmbedding(cond_dim, freq_dim=timestep_freq_dim)
        self.rope = RotaryEmbedding(head_dim, max_seq_len, theta=rope_theta)
        self.blocks = nn.ModuleList(
            [
                DiTBlock(
                    hidden_size=hidden_size,
                    attn_dim=attn_dim,
                    ffn_dim=ffn_dim,
                    num_heads=num_heads,
                    head_dim=head_dim,
                    cond_dim=cond_dim,
                    attn_drop=attn_dropout,
                    drop=dropout,
                )
                for _ in range(depth)
            ]
        )
        self.final_ada = FinalAdaLN(hidden_size, cond_dim)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self._init_weights()

    def _init_weights(self):
        def init_fn(module: nn.Module):
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, std=0.02)

        self.apply(init_fn)
        for block in self.blocks:
            nn.init.zeros_(block.adaLN[-1].weight)
            nn.init.zeros_(block.adaLN[-1].bias)
        nn.init.zeros_(self.final_ada.adaLN[-1].weight)
        nn.init.zeros_(self.final_ada.adaLN[-1].bias)

    def forward_hidden(
        self,
        input_ids: torch.Tensor,
        timesteps: torch.Tensor,
        *,
        pack: Optional[object] = None,
    ) -> torch.Tensor:
        if pack is not None:
            raise RuntimeError("Packed attention is not included in the public torch-only model.")
        x = self.token_embed(input_ids)
        c = self.time_embed(timesteps)
        for block in self.blocks:
            x = block(x, c, self.rope, pack=None)
        return self.final_ada(x, c)

    def logits_from_hidden(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.lm_head(hidden_states)

    def forward(
        self,
        input_ids: torch.Tensor,
        timesteps: torch.Tensor,
        *,
        pack: Optional[object] = None,
        return_hidden: bool = False,
    ) -> torch.Tensor:
        hidden = self.forward_hidden(input_ids, timesteps, pack=pack)
        if return_hidden:
            return hidden
        return self.logits_from_hidden(hidden)


def create_model(config: dict) -> LangDiT:
    """Create model from config dict."""
    model_cfg = config["model"]
    return LangDiT(
        vocab_size=model_cfg["vocab_size"],
        hidden_size=model_cfg["hidden_size"],
        attn_dim=model_cfg["attn_dim"],
        ffn_dim=model_cfg["ffn_dim"],
        depth=model_cfg["depth"],
        num_heads=model_cfg["num_heads"],
        head_dim=model_cfg["head_dim"],
        max_seq_len=model_cfg["max_seq_len"],
        timestep_freq_dim=model_cfg.get("timestep_freq_dim", 256),
        rope_theta=model_cfg.get("rope_theta", 10000.0),
        cond_dim=model_cfg["cond_dim"],
        dropout=model_cfg.get("dropout", 0.0),
        attn_dropout=model_cfg.get("attn_dropout", 0.0),
    )
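

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): builds a deliberately tiny model through
# create_model() and runs one forward pass. The config values below are
# placeholders picked so the example runs quickly on CPU; they are NOT the
# shipped defaults from LangDiT.__init__. The one constraint assumed here is
# attn_dim == num_heads * head_dim, which the qkv reshape in Attention.forward
# requires.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    toy_config = {
        "model": {
            "vocab_size": 1000,
            "hidden_size": 128,
            "attn_dim": 128,  # 4 heads * 32 head_dim
            "ffn_dim": 256,
            "depth": 2,
            "num_heads": 4,
            "head_dim": 32,
            "max_seq_len": 64,
            "cond_dim": 64,
        }
    }
    model = create_model(toy_config).eval()

    input_ids = torch.randint(0, 1000, (2, 16))  # [batch, seqlen] token ids
    timesteps = torch.rand(2)                    # one diffusion timestep per sample
    with torch.no_grad():
        logits = model(input_ids, timesteps)     # [batch, seqlen, vocab_size]
    print(logits.shape)  # torch.Size([2, 16, 1000])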