File size: 8,506 Bytes
922bb4b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 | """
ARModel — standard decoder-only (GPT-2 style) Transformer for next-token prediction.
Baseline to compare against SAD / Block-AR diffusion at matched scale:
- Same hidden_size / n_blocks / n_heads / seq_len as SADModel
- Same RoPE (reused from dit_components)
- Standard pre-LN blocks with causal self-attention + GELU MLP
- Untied token embedding / output head, matching Block-AR parameterization
- No adaLN / no timestep conditioning / no DiT modulation
Inference:
forward(input_ids) — full-sequence forward, used by training / eval / the
first (prompt) step of generation.
forward_cached(input_ids, past_kv_list=None) — returns
(logits, new_kv_list). Used for left-to-right generation
with an incrementally grown KV cache.
KV cache layout: list of length n_blocks; each entry is (k, v) with shape
[B, H, S_cache, head_dim].
Max total length is `max_seq_len` (RoPE is precomputed for that length).
"""
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from .dit_components import Rotary, apply_rotary_pos_emb
# Per-block attention KV-cache entry: (k, v), each of shape [B, H, S_cache, head_dim].
KVPair = Tuple[torch.Tensor, torch.Tensor]
class ARBlock(nn.Module):
    """Pre-LN causal self-attention + MLP block with no conditioning.

    Per block:
        x = x + Dropout(attn_out(CausalSelfAttn(RoPE(attn_qkv(norm1(x))))))
        x = x + Dropout(mlp(norm2(x)))

    Args:
        dim: model width ``d``; must be divisible by ``n_heads``.
        n_heads: number of attention heads ``H``.
        mlp_ratio: the MLP hidden width is ``mlp_ratio * dim``.
        dropout: dropout probability on both residual branches; only active
            while the module is in training mode.

    Raises:
        ValueError: if ``dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, dim: int, n_heads: int, mlp_ratio: int = 4, dropout: float = 0.0):
        super().__init__()
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        if dim % n_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by n_heads ({n_heads})")
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.dropout = dropout
        # Submodules are registered in the same order as before, so parameters()
        # ordering (and therefore optimizer state) is unchanged.
        self.norm1 = nn.LayerNorm(dim)
        self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.attn_out = nn.Linear(dim, dim, bias=False)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_ratio * dim, dim, bias=True),
        )

    def _qkv(self, x: torch.Tensor, rotary_cos_sin) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return RoPE-rotated (q, k, v), each of shape [B, H, S, head_dim].

        ``rotary_cos_sin`` is the ``(cos, sin)`` pair from ``Rotary``. Both are
        cast to the activation dtype before they are applied.
        """
        h = self.norm1(x)
        qkv = self.attn_qkv(h)
        qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.n_heads)
        cos, sin = rotary_cos_sin
        # NOTE(review): assumes apply_rotary_pos_emb operates on the packed
        # [B, S, 3, H, D] tensor and rotates the q/k slices; confirm in dit_components.
        qkv = apply_rotary_pos_emb(qkv, cos.to(qkv.dtype), sin.to(qkv.dtype))
        q = qkv[:, :, 0].transpose(1, 2)  # [B, H, S, D]
        k = qkv[:, :, 1].transpose(1, 2)
        v = qkv[:, :, 2].transpose(1, 2)
        return q, k, v

    def _residual_update(self, x: torch.Tensor, attn: torch.Tensor) -> torch.Tensor:
        """Merge heads of ``attn`` [B, H, S, D], then apply both residual branches to ``x``."""
        attn = rearrange(attn, "b h s d -> b s (h d)")
        x = x + F.dropout(self.attn_out(attn), p=self.dropout, training=self.training)
        x = x + F.dropout(self.mlp(self.norm2(x)), p=self.dropout, training=self.training)
        return x

    def forward(self, x: torch.Tensor, rotary_cos_sin) -> torch.Tensor:
        """Uncached path (training / full-sequence eval): causal attention over all of ``x``."""
        q, k, v = self._qkv(x, rotary_cos_sin)
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self._residual_update(x, attn)

    def forward_cached(
        self,
        x: torch.Tensor,
        rotary_cos_sin,
        past_kv: Optional[KVPair] = None,
    ) -> Tuple[torch.Tensor, KVPair]:
        """
        Cached path (generation).

        Args:
            x: new hidden states [B, S_new, d].
            rotary_cos_sin: (cos, sin) for the absolute positions of the new
                tokens, i.e. [S_cache .. S_cache + S_new - 1].
            past_kv: optional (k_cache, v_cache), each [B, H, S_cache, D].

        Returns:
            (out [B, S_new, d], new_kv = (k_all, v_all)), where k_all and v_all
            are [B, H, S_cache + S_new, D].

        With S_cache == 0 (or no cache), this behaves like a causal forward over x.
        With S_cache > 0:
          - S_new == 1: the new query sees all S_cache + 1 keys, so no mask is needed.
          - S_new > 1: new query i (absolute position S_cache + i) sees every
            cached key plus new keys 0..i. This uses an explicit bottom-right
            aligned causal mask, because SDPA's ``is_causal=True`` is top-left
            aligned when the query and key lengths differ.
        """
        q, k_new, v_new = self._qkv(x, rotary_cos_sin)
        if past_kv is None or past_kv[0].size(2) == 0:
            k, v = k_new, v_new
            attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        else:
            pk, pv = past_kv
            k = torch.cat([pk, k_new], dim=2)
            v = torch.cat([pv, v_new], dim=2)
            s_new = q.size(2)
            if s_new == 1:
                # Most recent position: full visibility over [0 .. S_cache] is correct.
                attn_mask = None
            else:
                s_cache = pk.size(2)
                # mask[i, j] is True iff j <= s_cache + i (True = may attend).
                attn_mask = torch.ones(
                    s_new, s_cache + s_new, dtype=torch.bool, device=q.device
                ).tril(diagonal=s_cache)
            attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=False)
        return self._residual_update(x, attn), (k, v)
class ARModel(nn.Module):
    """
    GPT-2-style decoder-only Transformer with RoPE.

    Args:
        vocab_size: V, the true vocabulary size. Internally the embedding and
            head use V rounded up to a multiple of 128, and logits are sliced
            back to V.
        hidden_size: d; must be divisible by ``n_heads``.
        n_blocks: number of transformer blocks.
        n_heads: number of attention heads.
        max_seq_len: max supported sequence length (for RoPE precompute).
        dropout: dropout inside blocks (0.0 by default, matching SAD).
    """

    def __init__(
        self,
        vocab_size: int,
        hidden_size: int = 768,
        n_blocks: int = 12,
        n_heads: int = 12,
        max_seq_len: int = 512,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.max_seq_len = max_seq_len
        # Round vocab to a multiple of 128 (same trick as SADModel for tensor-core friendliness).
        self.rounded_vocab_size = vocab_size + (128 - vocab_size % 128) % 128
        # Input embedding
        self.tok_embed = nn.Embedding(self.rounded_vocab_size, hidden_size)
        nn.init.normal_(self.tok_embed.weight, std=0.02)
        self.rotary_emb = Rotary(hidden_size // n_heads, max_seq_len=max_seq_len)
        self.blocks = nn.ModuleList([
            ARBlock(hidden_size, n_heads, dropout=dropout) for _ in range(n_blocks)
        ])
        self.norm_final = nn.LayerNorm(hidden_size)
        # Decoupled (untied) output head to match the Block-AR baseline parameterization.
        self.lm_head = nn.Linear(hidden_size, self.rounded_vocab_size, bias=False)
        nn.init.normal_(self.lm_head.weight, std=0.02)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,  # accepted for API symmetry; unused
    ) -> torch.Tensor:
        """
        Full-sequence forward. Training / full-sequence eval path.

        Args:
            input_ids: [B, S] int64
            attention_mask: ignored; attention is always causal.

        Returns:
            logits: [B, S, V] (sliced to the true vocab, not the rounded one)
        """
        x = self.tok_embed(input_ids)  # [B, S, d]
        rotary_cos_sin = self.rotary_emb(x)
        for block in self.blocks:
            x = block(x, rotary_cos_sin)
        x = self.norm_final(x)
        return self.lm_head(x)[..., :self.vocab_size]

    def forward_cached(
        self,
        input_ids: torch.Tensor,
        past_kv_list: Optional[List[KVPair]] = None,
    ) -> Tuple[torch.Tensor, List[KVPair]]:
        """
        Cached forward for generation.

        Args:
            input_ids: [B, S_new]
            past_kv_list: None or [] (first call), or a list of length
                n_blocks; each entry is (k, v) of shape [B, H, S_cache, D].
                With a non-empty cache, the per-block constraints documented
                on ``ARBlock.forward_cached`` apply to S_new.

        Returns:
            logits: [B, S_new, V]
            new_kv_list: list of length n_blocks with updated (k, v) of
                shape [B, H, S_cache + S_new, D].

        Raises:
            ValueError: if ``past_kv_list`` is non-empty but does not have exactly
                n_blocks entries, or if ``S_cache + S_new`` exceeds ``max_seq_len``
                (the RoPE table is only precomputed up to that length).
        """
        _, S_new = input_ids.shape  # unpacking also enforces a 2-D [B, S_new] input
        has_cache = bool(past_kv_list)  # None and [] both mean "no cache yet"
        if has_cache and len(past_kv_list) != self.n_blocks:
            raise ValueError(
                f"past_kv_list has {len(past_kv_list)} entries, "
                f"expected n_blocks={self.n_blocks}"
            )
        S_cache = past_kv_list[0][0].size(2) if has_cache else 0
        total_len = S_cache + S_new
        # Explicit raise instead of `assert`, which would be stripped under `python -O`.
        if total_len > self.max_seq_len:
            raise ValueError(
                f"cache+new ({S_cache}+{S_new}={total_len}) exceeds "
                f"max_seq_len={self.max_seq_len}"
            )
        x = self.tok_embed(input_ids)  # [B, S_new, d]
        # RoPE positions for the new tokens: [S_cache .. S_cache + S_new - 1]
        position_ids = torch.arange(S_cache, total_len, device=input_ids.device)
        rotary_cos_sin = self.rotary_emb(x, position_ids=position_ids)
        per_block_past = past_kv_list if has_cache else [None] * self.n_blocks
        new_kv_list: List[KVPair] = []
        for block, pkv in zip(self.blocks, per_block_past):
            x, new_kv = block.forward_cached(x, rotary_cos_sin, past_kv=pkv)
            new_kv_list.append(new_kv)
        x = self.norm_final(x)
        logits = self.lm_head(x)[..., :self.vocab_size]
        return logits, new_kv_list
|