| """ |
| ARModel — standard decoder-only (GPT-2 style) Transformer for next-token prediction. |
| |
| Baseline to compare against SAD / Block-AR diffusion at matched scale: |
| - Same hidden_size / n_blocks / n_heads / seq_len as SADModel |
| - Same RoPE (reused from dit_components) |
| - Standard pre-LN blocks with causal self-attention + GELU MLP |
| - Untied token embedding / output head, matching Block-AR parameterization |
| - No adaLN / no timestep conditioning / no DiT modulation |
| |
| Inference: |
| forward(input_ids) — full-sequence forward, used by training / eval / the |
| first (prompt) step of generation. |
| forward_cached(input_ids, past_kv_list=None) — returns |
| (logits, new_kv_list). Used for left-to-right generation |
| with an incrementally grown KV cache. |
| |
| KV cache layout: list of length n_blocks; each entry is (k, v) with shape |
| [B, H, S_cache, head_dim]. |
| Max total length is `max_seq_len` (RoPE is precomputed for that length). |
| """ |
|
|
| from typing import List, Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from einops import rearrange |
|
|
| from .dit_components import Rotary, apply_rotary_pos_emb |
|
|
|
|
# Per-layer attention cache entry: a (key, value) pair, each laid out as
# [B, H, S_cache, head_dim] per the module docstring.
KVPair = Tuple[torch.Tensor, torch.Tensor]
|
|
|
|
class ARBlock(nn.Module):
    """Pre-LN causal self-attention + MLP transformer block, no conditioning.

    Computes ``x = x + Attn(LN1(x))`` followed by ``x = x + MLP(LN2(x))``.
    Rotary position embeddings are applied to the packed q/k/v projection via
    ``apply_rotary_pos_emb`` before attention. Both residual branches pass
    through dropout with probability ``dropout``, which is active only in
    training mode.

    Args:
        dim: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads (``head_dim = dim // n_heads``).
        mlp_ratio: the MLP hidden width is ``mlp_ratio * dim``.
        dropout: dropout probability on the two residual branches.

    Raises:
        ValueError: if ``dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, dim: int, n_heads: int, mlp_ratio: int = 4, dropout: float = 0.0):
        super().__init__()
        if dim % n_heads != 0:
            raise ValueError(f"dim={dim} must be divisible by n_heads={n_heads}")
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.dropout = dropout

        # Attention sub-layer: fused bias-free q/k/v projection + output projection.
        self.norm1 = nn.LayerNorm(dim)
        self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.attn_out = nn.Linear(dim, dim, bias=False)

        # MLP sub-layer with tanh-approximated GELU.
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_ratio * dim, dim, bias=True),
        )

    def _qkv(self, x: torch.Tensor, rotary_cos_sin) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project ``LN1(x)`` to rotary-embedded queries, keys and values.

        Args:
            x: [B, S, dim] block input.
            rotary_cos_sin: ``(cos, sin)`` pair produced by ``Rotary``; both are
                cast to the activation dtype before use.
        Returns:
            (q, k, v), each [B, n_heads, S, head_dim].
        """
        h = self.norm1(x)
        qkv = self.attn_qkv(h)
        qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.n_heads)
        cos, sin = rotary_cos_sin
        qkv = apply_rotary_pos_emb(qkv, cos.to(qkv.dtype), sin.to(qkv.dtype))
        # [B, S, H, D] -> [B, H, S, D], the layout SDPA and the KV cache use.
        q = qkv[:, :, 0].transpose(1, 2)
        k = qkv[:, :, 1].transpose(1, 2)
        v = qkv[:, :, 2].transpose(1, 2)
        return q, k, v

    def _residual_update(self, x: torch.Tensor, attn: torch.Tensor) -> torch.Tensor:
        """Merge heads, project, and apply both residual branches (shared by both paths).

        Args:
            x: [B, S, dim] block input (residual stream).
            attn: [B, H, S, head_dim] attention output.
        Returns:
            [B, S, dim] block output.
        """
        attn = rearrange(attn, "b h s d -> b s (h d)")
        x = x + F.dropout(self.attn_out(attn), p=self.dropout, training=self.training)
        x = x + F.dropout(self.mlp(self.norm2(x)), p=self.dropout, training=self.training)
        return x

    @staticmethod
    def _offset_causal_mask(s_new: int, s_total: int, device) -> Optional[torch.Tensor]:
        """Causal mask for ``s_new`` queries that sit at the END of ``s_total`` keys.

        Query ``i`` has absolute position ``s_total - s_new + i`` and may attend
        to keys ``0 .. s_total - s_new + i`` (True = attend, SDPA bool-mask
        convention). SDPA's ``is_causal=True`` cannot be used here: for a
        non-square score matrix it aligns the triangle to the top-left corner,
        which would hide most of the cache from the new queries.

        Returns None when ``s_new == 1``, since a single trailing query sees
        every key and no mask is needed.
        """
        if s_new == 1:
            return None
        offset = s_total - s_new
        q_pos = torch.arange(offset, s_total, device=device).unsqueeze(1)  # [S_new, 1]
        k_pos = torch.arange(s_total, device=device).unsqueeze(0)          # [1, S_total]
        return k_pos <= q_pos                                              # [S_new, S_total]

    def forward(self, x: torch.Tensor, rotary_cos_sin) -> torch.Tensor:
        """Uncached path (training / full-sequence eval).

        Args:
            x: [B, S, dim]
            rotary_cos_sin: ``(cos, sin)`` for positions ``0 .. S-1``.
        Returns:
            [B, S, dim]
        """
        q, k, v = self._qkv(x, rotary_cos_sin)
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self._residual_update(x, attn)

    def forward_cached(
        self,
        x: torch.Tensor,
        rotary_cos_sin,
        past_kv: Optional[KVPair] = None,
    ) -> Tuple[torch.Tensor, KVPair]:
        """
        Cached path (generation).

        Args:
            x: [B, S_new, dim] hidden states of the new tokens.
            rotary_cos_sin: ``(cos, sin)`` for the new tokens' absolute
                positions ``S_cache .. S_cache + S_new - 1``.
            past_kv: optional (k_cache, v_cache), each [B, H, S_cache, D].
        Returns:
            (out [B, S_new, dim], new_kv = (k_all, v_all)), where ``k_all`` /
            ``v_all`` are [B, H, S_cache + S_new, D].

        With no (or an empty) cache this matches an ``is_causal=True`` forward.
        With a non-empty cache, each new token attends to the whole cache plus
        the new tokens up to and including itself. ``S_new == 1`` (single-step
        append) needs no mask; ``S_new > 1`` (e.g. chunked prefill) uses an
        explicit offset causal mask.
        """
        q, k_new, v_new = self._qkv(x, rotary_cos_sin)

        if past_kv is None or past_kv[0].size(2) == 0:
            # No history: plain causal attention over the new tokens only.
            k, v = k_new, v_new
            attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        else:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k_new], dim=2)
            v = torch.cat([past_v, v_new], dim=2)
            mask = self._offset_causal_mask(q.size(2), k.size(2), q.device)
            attn = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

        return self._residual_update(x, attn), (k, v)
|
|
|
|
class ARModel(nn.Module):
    """
    GPT-2-style decoder-only Transformer with RoPE.

    Args:
        vocab_size: V, the true vocabulary size; logits are sliced to it.
        hidden_size: d
        n_blocks: number of transformer blocks
        n_heads: number of attention heads
        max_seq_len: max supported sequence length (for RoPE precompute).
            ``forward`` and ``forward_cached`` raise ValueError beyond it.
        dropout: dropout inside blocks (0.0 by default, matching SAD)
    """

    def __init__(
        self,
        vocab_size: int,
        hidden_size: int = 768,
        n_blocks: int = 12,
        n_heads: int = 12,
        max_seq_len: int = 512,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.max_seq_len = max_seq_len

        # Embedding / head rows are padded up to the next multiple of 128
        # (presumably for matmul-friendly shapes). The padding rows are never
        # surfaced: logits are sliced back to `vocab_size` in `_logits`.
        self.rounded_vocab_size = vocab_size + (128 - vocab_size % 128) % 128

        # NOTE: module creation order is kept fixed so parameter init consumes
        # the RNG stream in the same order (reproducible initialization).
        self.tok_embed = nn.Embedding(self.rounded_vocab_size, hidden_size)
        nn.init.normal_(self.tok_embed.weight, std=0.02)

        self.rotary_emb = Rotary(hidden_size // n_heads, max_seq_len=max_seq_len)

        self.blocks = nn.ModuleList([
            ARBlock(hidden_size, n_heads, dropout=dropout) for _ in range(n_blocks)
        ])
        self.norm_final = nn.LayerNorm(hidden_size)

        # Untied output head (separate weights from `tok_embed`).
        self.lm_head = nn.Linear(hidden_size, self.rounded_vocab_size, bias=False)
        nn.init.normal_(self.lm_head.weight, std=0.02)

    def _check_length(self, total_len: int, detail: str) -> None:
        """Raise ValueError if ``total_len`` exceeds the precomputed RoPE length."""
        if total_len > self.max_seq_len:
            raise ValueError(f"{detail} exceeds max_seq_len={self.max_seq_len}")

    def _logits(self, x: torch.Tensor) -> torch.Tensor:
        """Final LayerNorm + LM head, sliced from the rounded to the true vocab."""
        x = self.norm_final(x)
        return self.lm_head(x)[..., :self.vocab_size]

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Full-sequence forward. Training / full-sequence eval path.

        Args:
            input_ids: [B, S] int64, with S <= max_seq_len.
            attention_mask: accepted for interface compatibility and IGNORED.
                Attention is always plain causal, so padding tokens are not
                masked out.
        Returns:
            logits: [B, S, V] (sliced to true vocab, not rounded)
        Raises:
            ValueError: if S > max_seq_len.
        """
        seq_len = input_ids.size(-1)
        self._check_length(seq_len, f"sequence length {seq_len}")

        x = self.tok_embed(input_ids)
        rotary_cos_sin = self.rotary_emb(x)

        for block in self.blocks:
            x = block(x, rotary_cos_sin)

        return self._logits(x)

    def forward_cached(
        self,
        input_ids: torch.Tensor,
        past_kv_list: Optional[List[KVPair]] = None,
    ) -> Tuple[torch.Tensor, List[KVPair]]:
        """
        Cached forward for generation.

        Args:
            input_ids: [B, S_new]
            past_kv_list: None or an empty list (first call), or a list with
                one entry per block; each entry is (k, v) of shape
                [B, H, S_cache, D].
        Returns:
            logits: [B, S_new, V]
            new_kv_list: list of length n_blocks with updated (k, v) of
                shape [B, H, S_cache + S_new, D].
        Raises:
            ValueError: if S_cache + S_new > max_seq_len, or if a non-empty
                ``past_kv_list`` does not have one entry per block.
        """
        _batch, S_new = input_ids.shape

        if not past_kv_list:
            # None and [] both mean "no cache yet".
            past_kv_list = None
        elif len(past_kv_list) != len(self.blocks):
            raise ValueError(
                f"past_kv_list has {len(past_kv_list)} entries, "
                f"expected one per block ({len(self.blocks)})"
            )

        S_cache = 0 if past_kv_list is None else past_kv_list[0][0].size(2)
        total_len = S_cache + S_new
        self._check_length(total_len, f"cache+new ({S_cache}+{S_new}={total_len})")

        x = self.tok_embed(input_ids)

        # The new tokens occupy absolute positions S_cache .. total_len - 1, so
        # RoPE must be evaluated at those positions rather than from 0.
        position_ids = torch.arange(S_cache, total_len, device=input_ids.device)
        rotary_cos_sin = self.rotary_emb(x, position_ids=position_ids)

        layer_pasts = past_kv_list if past_kv_list is not None else [None] * len(self.blocks)
        new_kv_list: List[KVPair] = []
        for block, past_kv in zip(self.blocks, layer_pasts):
            x, new_kv = block.forward_cached(x, rotary_cos_sin, past_kv=past_kv)
            new_kv_list.append(new_kv)

        return self._logits(x), new_kv_list
|
|