sad / src /models /ar_model.py
haochengsama's picture
Add files using upload-large-folder tool
922bb4b verified
Raw
History Blame Contribute Delete
8.51 kB
"""
ARModel — standard decoder-only (GPT-2 style) Transformer for next-token prediction.
Baseline to compare against SAD / Block-AR diffusion at matched scale:
- Same hidden_size / n_blocks / n_heads / seq_len as SADModel
- Same RoPE (reused from dit_components)
- Standard pre-LN blocks with causal self-attention + GELU MLP
- Untied token embedding / output head, matching Block-AR parameterization
- No adaLN / no timestep conditioning / no DiT modulation
Inference:
    forward(input_ids) — full-sequence forward, used by training / eval / the
        first (prompt) step of generation.
    forward_cached(input_ids, past_kv_list=None) — returns
        (logits, new_kv_list). Used for left-to-right generation
        with an incrementally grown KV cache.
KV cache layout: list of length n_blocks; each entry is (k, v) with shape
[B, H, S_cache, head_dim].
Max total length is `max_seq_len` (RoPE is precomputed for that length).
"""
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from .dit_components import Rotary, apply_rotary_pos_emb
# One block's cached attention state: a (key, value) pair of tensors.
# Docstrings in this file give each tensor the shape [B, H, S_cache, head_dim].
KVPair = Tuple[torch.Tensor, torch.Tensor]
class ARBlock(nn.Module):
    """Pre-LN causal self-attention + MLP block, with no conditioning.

    Layout per block:
        x = x + attn_out(causal_attention(norm1(x)))
        x = x + mlp(norm2(x))

    Args:
        dim: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        mlp_ratio: hidden width of the MLP as a multiple of ``dim``.
        dropout: dropout probability applied to the attention and MLP branch
            outputs in ``forward`` (the cached path applies no dropout).
    """

    def __init__(self, dim: int, n_heads: int, mlp_ratio: int = 4, dropout: float = 0.0):
        super().__init__()
        assert dim % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.dropout = dropout

        # Attention sub-layer: fused q/k/v projection and output projection,
        # both without bias.
        self.norm1 = nn.LayerNorm(dim)
        self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.attn_out = nn.Linear(dim, dim, bias=False)

        # Feed-forward sub-layer (tanh-approximated GELU).
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_ratio * dim, dim, bias=True),
        )

    def _qkv(self, x: torch.Tensor, rotary_cos_sin) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Normalize ``x``, project to q/k/v, apply RoPE, and split per head.

        Args:
            x: activations of shape [B, S, dim].
            rotary_cos_sin: ``(cos, sin)`` pair for the positions of ``x``;
                both are cast to the dtype of the projected tensor before use.

        Returns:
            ``(q, k, v)``, each of shape [B, H, S, head_dim].
        """
        h = self.norm1(x)
        qkv = self.attn_qkv(h)
        qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.n_heads)
        cos, sin = rotary_cos_sin
        # The rotary helper receives the packed qkv tensor; which of the three
        # slices it rotates is defined in dit_components.
        qkv = apply_rotary_pos_emb(qkv, cos.to(qkv.dtype), sin.to(qkv.dtype))
        q = qkv[:, :, 0].transpose(1, 2)  # [B, H, S, D]
        k = qkv[:, :, 1].transpose(1, 2)
        v = qkv[:, :, 2].transpose(1, 2)
        return q, k, v

    def forward(self, x: torch.Tensor, rotary_cos_sin) -> torch.Tensor:
        """Uncached path (training / full-sequence eval).

        Args:
            x: activations of shape [B, S, dim].
            rotary_cos_sin: ``(cos, sin)`` pair for positions ``0 .. S - 1``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        q, k, v = self._qkv(x, rotary_cos_sin)
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        attn = rearrange(attn, "b h s d -> b s (h d)")
        x = x + F.dropout(self.attn_out(attn), p=self.dropout, training=self.training)
        x = x + F.dropout(self.mlp(self.norm2(x)), p=self.dropout, training=self.training)
        return x

    def forward_cached(
        self,
        x: torch.Tensor,
        rotary_cos_sin,
        past_kv: Optional[KVPair] = None,
    ) -> Tuple[torch.Tensor, KVPair]:
        """Cached path (generation). No dropout is applied here.

        Args:
            x: activations for the new tokens, shape [B, S_new, dim].
            rotary_cos_sin: ``(cos, sin)`` pair for the new tokens' positions.
            past_kv: optional ``(k_cache, v_cache)``, each [B, H, S_cache, D].

        Returns:
            ``(out, (k_all, v_all))`` where ``out`` is [B, S_new, dim] and the
            returned keys/values cover the cache plus the new tokens.

        Attention visibility:
            * No cache (``None`` or ``S_cache == 0``): ordinary causal
              attention over the new tokens.
            * Cache and ``S_new == 1``: the single new query is the most recent
              position, so it attends to all ``S_cache + 1`` keys unmasked.
            * Cache and ``S_new > 1`` (chunked append): new token ``i`` attends
              to every cached key and to new tokens ``0 .. i`` via an explicit
              offset lower-triangular mask.
        """
        q, k_new, v_new = self._qkv(x, rotary_cos_sin)

        if past_kv is None or past_kv[0].size(2) == 0:
            k, v = k_new, v_new
            attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        else:
            k_past, v_past = past_kv
            k = torch.cat([k_past, k_new], dim=2)
            v = torch.cat([v_past, v_new], dim=2)
            s_new = q.size(2)
            if s_new == 1:
                # Single-step append: full visibility over [0 .. S_cache] is
                # already causal, so no mask is needed.
                attn = F.scaled_dot_product_attention(q, k, v, is_causal=False)
            else:
                # Multi-token append: ``is_causal=True`` would align the
                # triangle to the top-left corner and hide the cache, so build
                # the mask explicitly. Row i (absolute position S_cache + i)
                # keeps columns j <= S_cache + i; True means "may attend".
                s_cache = k_past.size(2)
                visible = torch.ones(
                    s_new, k.size(2), dtype=torch.bool, device=q.device
                ).tril(diagonal=s_cache)
                attn = F.scaled_dot_product_attention(q, k, v, attn_mask=visible)

        attn = rearrange(attn, "b h s d -> b s (h d)")
        x = x + self.attn_out(attn)
        x = x + self.mlp(self.norm2(x))
        return x, (k, v)
class ARModel(nn.Module):
    """GPT-2-style decoder-only Transformer with RoPE.

    The token embedding and the output head are separate (untied) parameters.
    Both are sized to the vocabulary rounded up to a multiple of 128; returned
    logits are sliced back to the true ``vocab_size``.

    Args:
        vocab_size: V
        hidden_size: d
        n_blocks: number of transformer blocks
        n_heads: number of attention heads
        max_seq_len: max supported sequence length (for RoPE precompute)
        dropout: dropout inside blocks (0.0 by default, matching SAD)
    """

    def __init__(
        self,
        vocab_size: int,
        hidden_size: int = 768,
        n_blocks: int = 12,
        n_heads: int = 12,
        max_seq_len: int = 512,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.max_seq_len = max_seq_len

        # Round vocab up to a multiple of 128 (same trick as SADModel for
        # tensor-core friendliness); already-aligned sizes are left unchanged.
        self.rounded_vocab_size = vocab_size + (128 - vocab_size % 128) % 128

        # NOTE: submodules are created and initialised in a fixed order so that
        # seeded initialisation stays reproducible; do not reorder.

        # Input embedding
        self.tok_embed = nn.Embedding(self.rounded_vocab_size, hidden_size)
        nn.init.normal_(self.tok_embed.weight, std=0.02)

        # Rotary embedding over the per-head dimension.
        self.rotary_emb = Rotary(hidden_size // n_heads, max_seq_len=max_seq_len)

        self.blocks = nn.ModuleList([
            ARBlock(hidden_size, n_heads, dropout=dropout) for _ in range(n_blocks)
        ])
        self.norm_final = nn.LayerNorm(hidden_size)

        # Decoupled output head to match the Block-AR baseline parameterization.
        self.lm_head = nn.Linear(hidden_size, self.rounded_vocab_size, bias=False)
        nn.init.normal_(self.lm_head.weight, std=0.02)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,  # accepted for API symmetry; unused
    ) -> torch.Tensor:
        """Full-sequence forward. Training / full-sequence eval path.

        Args:
            input_ids: [B, S] int64
            attention_mask: ignored; every block applies causal attention.

        Returns:
            logits: [B, S, V] (sliced to true vocab, not rounded)
        """
        x = self.tok_embed(input_ids)  # [B, S, d]
        rotary_cos_sin = self.rotary_emb(x)
        for block in self.blocks:
            x = block(x, rotary_cos_sin)
        x = self.norm_final(x)
        # Drop the padding columns introduced by rounding the vocab size.
        return self.lm_head(x)[..., :self.vocab_size]

    def forward_cached(
        self,
        input_ids: torch.Tensor,
        past_kv_list: Optional[List[KVPair]] = None,
    ) -> Tuple[torch.Tensor, List[KVPair]]:
        """Cached forward for generation.

        Args:
            input_ids: [B, S_new]
            past_kv_list: None or an empty list (no cache yet), or a list of
                length n_blocks; each entry is (k, v) of shape
                [B, H, S_cache, D].

        Returns:
            logits: [B, S_new, V]
            new_kv_list: list of length n_blocks with updated (k, v) of
                shape [B, H, S_cache + S_new, D].

        Raises:
            AssertionError: if S_cache + S_new exceeds ``max_seq_len``.
        """
        _, S_new = input_ids.shape
        device = input_ids.device

        # An empty list carries no cached state; treat it like None rather
        # than failing on past_kv_list[0].
        has_cache = bool(past_kv_list)
        S_cache = past_kv_list[0][0].size(2) if has_cache else 0

        total_len = S_cache + S_new
        assert total_len <= self.max_seq_len, (
            f"cache+new ({S_cache}+{S_new}={total_len}) exceeds "
            f"max_seq_len={self.max_seq_len}"
        )

        x = self.tok_embed(input_ids)  # [B, S_new, d]

        # RoPE positions for the new tokens: [S_cache .. S_cache + S_new - 1]
        position_ids = torch.arange(S_cache, S_cache + S_new, device=device)
        rotary_cos_sin = self.rotary_emb(x, position_ids=position_ids)

        new_kv_list: List[KVPair] = []
        for i, block in enumerate(self.blocks):
            pkv = past_kv_list[i] if has_cache else None
            x, new_kv = block.forward_cached(x, rotary_cos_sin, past_kv=pkv)
            new_kv_list.append(new_kv)

        x = self.norm_final(x)
        logits = self.lm_head(x)[..., :self.vocab_size]
        return logits, new_kv_list