""" ARModel — standard decoder-only (GPT-2 style) Transformer for next-token prediction. Baseline to compare against SAD / Block-AR diffusion at matched scale: - Same hidden_size / n_blocks / n_heads / seq_len as SADModel - Same RoPE (reused from dit_components) - Standard pre-LN blocks with causal self-attention + GELU MLP - Untied token embedding / output head, matching Block-AR parameterization - No adaLN / no timestep conditioning / no DiT modulation Inference: forward(input_ids) — full-sequence forward, used by training / eval / the first (prompt) step of generation. forward_cached(input_ids, past_kv_list=None) — returns (logits, new_kv_list). Used for left-to-right generation with an incrementally grown KV cache. KV cache layout: list of length n_blocks; each entry is (k, v) with shape [B, H, S_cache, head_dim]. Max total length is `max_seq_len` (RoPE is precomputed for that length). """ from typing import List, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange from .dit_components import Rotary, apply_rotary_pos_emb KVPair = Tuple[torch.Tensor, torch.Tensor] class ARBlock(nn.Module): """Pre-LN causal self-attention + MLP, no conditioning.""" def __init__(self, dim: int, n_heads: int, mlp_ratio: int = 4, dropout: float = 0.0): super().__init__() assert dim % n_heads == 0 self.n_heads = n_heads self.head_dim = dim // n_heads self.dropout = dropout self.norm1 = nn.LayerNorm(dim) self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False) self.attn_out = nn.Linear(dim, dim, bias=False) self.norm2 = nn.LayerNorm(dim) self.mlp = nn.Sequential( nn.Linear(dim, mlp_ratio * dim, bias=True), nn.GELU(approximate="tanh"), nn.Linear(mlp_ratio * dim, dim, bias=True), ) def _qkv(self, x: torch.Tensor, rotary_cos_sin) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: h = self.norm1(x) qkv = self.attn_qkv(h) qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.n_heads) cos, sin = rotary_cos_sin qkv = apply_rotary_pos_emb(qkv, cos.to(qkv.dtype), sin.to(qkv.dtype)) q = qkv[:, :, 0].transpose(1, 2) # [B, H, S, D] k = qkv[:, :, 1].transpose(1, 2) v = qkv[:, :, 2].transpose(1, 2) return q, k, v def forward(self, x: torch.Tensor, rotary_cos_sin) -> torch.Tensor: """Uncached path (training / full-sequence eval).""" q, k, v = self._qkv(x, rotary_cos_sin) attn = F.scaled_dot_product_attention(q, k, v, is_causal=True) attn = rearrange(attn, "b h s d -> b s (h d)") x = x + F.dropout(self.attn_out(attn), p=self.dropout, training=self.training) x = x + F.dropout(self.mlp(self.norm2(x)), p=self.dropout, training=self.training) return x def forward_cached( self, x: torch.Tensor, rotary_cos_sin, past_kv: Optional[KVPair] = None, ) -> Tuple[torch.Tensor, KVPair]: """ Cached path (generation). past_kv: optional (k_cache, v_cache) each [B, H, S_cache, D]. Returns (out [B, S_new, d], new_kv = (k_all, v_all)). With S_cache == 0 (first call), acts like an is_causal=True forward. With S_cache > 0, expects S_new == 1 (single-step append) — the new query at the last position attends to all S_cache + 1 tokens, no mask. """ q, k_new, v_new = self._qkv(x, rotary_cos_sin) if past_kv is None or past_kv[0].size(2) == 0: k = k_new v = v_new attn = F.scaled_dot_product_attention(q, k, v, is_causal=True) else: pk, pv = past_kv k = torch.cat([pk, k_new], dim=2) v = torch.cat([pv, v_new], dim=2) # Single-step append: new query is the most recent position → full # visibility over [0 .. S_cache] is correct (no causal mask needed). assert q.size(2) == 1, ( f"forward_cached with non-empty cache expects S_new == 1, got {q.size(2)}" ) attn = F.scaled_dot_product_attention(q, k, v, is_causal=False) attn = rearrange(attn, "b h s d -> b s (h d)") x = x + self.attn_out(attn) x = x + self.mlp(self.norm2(x)) return x, (k, v) class ARModel(nn.Module): """ GPT-2-style decoder-only Transformer with RoPE. Args: vocab_size: V hidden_size: d n_blocks: number of transformer blocks n_heads: number of attention heads max_seq_len: max supported sequence length (for RoPE precompute) dropout: dropout inside blocks (0.0 by default, matching SAD) """ def __init__( self, vocab_size: int, hidden_size: int = 768, n_blocks: int = 12, n_heads: int = 12, max_seq_len: int = 512, dropout: float = 0.0, ): super().__init__() self.vocab_size = vocab_size self.hidden_size = hidden_size self.n_blocks = n_blocks self.n_heads = n_heads self.max_seq_len = max_seq_len # Round vocab to multiple of 128 (same trick as SADModel for tensor-core friendliness) self.rounded_vocab_size = vocab_size + (128 - vocab_size % 128) % 128 # Input embedding self.tok_embed = nn.Embedding(self.rounded_vocab_size, hidden_size) nn.init.normal_(self.tok_embed.weight, std=0.02) self.rotary_emb = Rotary(hidden_size // n_heads, max_seq_len=max_seq_len) self.blocks = nn.ModuleList([ ARBlock(hidden_size, n_heads, dropout=dropout) for _ in range(n_blocks) ]) self.norm_final = nn.LayerNorm(hidden_size) # Decoupled output head to match the Block-AR baseline parameterization. self.lm_head = nn.Linear(hidden_size, self.rounded_vocab_size, bias=False) nn.init.normal_(self.lm_head.weight, std=0.02) def forward( self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, # accepted for API symmetry; unused ) -> torch.Tensor: """ Full-sequence forward. Training / full-sequence eval path. Args: input_ids: [B, S] int64 Returns: logits: [B, S, V] (sliced to true vocab, not rounded) """ x = self.tok_embed(input_ids) # [B, S, d] rotary_cos_sin = self.rotary_emb(x) for block in self.blocks: x = block(x, rotary_cos_sin) x = self.norm_final(x) logits = self.lm_head(x)[..., :self.vocab_size] return logits def forward_cached( self, input_ids: torch.Tensor, past_kv_list: Optional[List[KVPair]] = None, ) -> Tuple[torch.Tensor, List[KVPair]]: """ Cached forward for generation. Args: input_ids: [B, S_new] past_kv_list: None (first call) or list of length n_blocks; each entry is (k, v) of shape [B, H, S_cache, D]. Returns: logits: [B, S_new, V] new_kv_list: list of length n_blocks with updated (k, v) of shape [B, H, S_cache + S_new, D]. """ B, S_new = input_ids.shape device = input_ids.device S_cache = 0 if past_kv_list is None else past_kv_list[0][0].size(2) total_len = S_cache + S_new assert total_len <= self.max_seq_len, ( f"cache+new ({S_cache}+{S_new}={total_len}) exceeds " f"max_seq_len={self.max_seq_len}" ) x = self.tok_embed(input_ids) # [B, S_new, d] # RoPE positions for the new tokens: [S_cache .. S_cache + S_new - 1] position_ids = torch.arange(S_cache, S_cache + S_new, device=device) rotary_cos_sin = self.rotary_emb(x, position_ids=position_ids) new_kv_list: List[KVPair] = [] for i, block in enumerate(self.blocks): pkv = None if past_kv_list is None else past_kv_list[i] x, new_kv = block.forward_cached(x, rotary_cos_sin, past_kv=pkv) new_kv_list.append(new_kv) x = self.norm_final(x) logits = self.lm_head(x)[..., :self.vocab_size] return logits, new_kv_list