"""
Modernized GPT model.
Same architecture as model.py but with all four swaps applied:
1. RMSNorm (replaces LayerNorm everywhere)
2. SwiGLU (replaces ReLU FFN)
3. RoPE (replaces learned positional embeddings)
4. KV Cache (for fast inference generation)
The positional embedding table is removed entirely — position is encoded
via RoPE rotations directly in each attention head.
BUG FIX (2026-03-29): RoPE positions were wrong during KV cache generation.
When generating token-by-token with use_cache=True, we were computing RoPE
for position 0 every time instead of the actual position. This made every
generated token think it was at position 0 — garbage output. Fixed by
tracking _cache_pos and passing position offset to forward().
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from modernize import ModernBlock, RMSNorm, precompute_rope_freqs
class ModernGPT(nn.Module):
    """GPT with RMSNorm, SwiGLU, RoPE, and a KV cache for fast generation.

    Position is encoded purely by RoPE rotations inside each attention head,
    so there is no learned positional-embedding table. ``_cache_pos`` tracks
    the absolute position of the next token during cached generation so the
    correct RoPE frequencies are applied (see module docstring bug note).
    """

    def __init__(
        self,
        vocab_size: int,
        n_embd: int = 384,
        n_heads: int = 6,
        n_layer: int = 6,
        block_size: int = 256,
        dropout: float = 0.2,
    ):
        super().__init__()
        self.block_size = block_size
        self.n_heads = n_heads
        self.head_size = n_embd // n_heads
        # Token embedding only — no positional embedding table (RoPE handles position)
        self.token_emb = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.ModuleList([
            ModernBlock(n_embd=n_embd, n_heads=n_heads, block_size=block_size, dropout=dropout)
            for _ in range(n_layer)
        ])
        self.ln_f = RMSNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        # Weight tying: lm_head and token_emb share a single parameter tensor.
        self.lm_head.weight = self.token_emb.weight
        # Absolute position of the next token during KV-cache generation;
        # forward() uses it to slice the correct RoPE frequencies.
        self._cache_pos = 0
        self._init_weights()

    def _init_weights(self):
        """Init Linear/Embedding weights ~ N(0, 0.02), biases to zero.

        The tied lm_head/token_emb tensor is visited twice (once as a Linear,
        once as an Embedding); both draws use the same distribution, so the
        double initialization is harmless.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def clear_kv_cache(self):
        """Reset the generation position and every block's KV cache."""
        self._cache_pos = 0
        for block in self.blocks:
            block.clear_cache()

    def forward(
        self,
        idx: torch.Tensor,
        targets: torch.Tensor | None = None,
        use_cache: bool = False,
    ):
        """Run the model on token indices ``idx`` of shape (B, T).

        Args:
            idx: (B, T) long tensor of token ids.
            targets: optional (B, T) long tensor; when given, cross-entropy
                loss over all positions is returned as the second element.
            use_cache: when True, blocks append K/V to their caches and
                ``_cache_pos`` advances by T so the NEXT call starts at the
                correct RoPE position.

        Returns:
            (logits, loss) where logits is (B, T, vocab_size) and loss is
            None unless ``targets`` was provided.
        """
        B, T = idx.shape
        assert T <= self.block_size, f"sequence length {T} exceeds block_size {self.block_size}"
        # Precompute RoPE frequencies for the ACTUAL positions this chunk
        # occupies (cache_pos .. cache_pos + T), not always 0..T; compute up
        # to max_pos and slice off the prefix already consumed by the cache.
        max_pos = self._cache_pos + T
        cos_full, sin_full = precompute_rope_freqs(self.head_size, max_pos, idx.device)
        cos = cos_full[self._cache_pos : max_pos]  # (T, head_size//2)
        sin = sin_full[self._cache_pos : max_pos]
        if use_cache:
            self._cache_pos += T
        x = self.token_emb(idx)  # (B, T, n_embd)
        for block in self.blocks:
            x = block(x, cos, sin, use_cache=use_cache)
        x = self.ln_f(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    @torch.no_grad()
    def generate(
        self,
        idx: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: int | None = None,
    ) -> torch.Tensor:
        """Generate ``max_new_tokens`` tokens after prompt ``idx`` using the KV cache.

        Args:
            idx: (B, T) prompt token ids.
            temperature: logits are divided by this before softmax.
            top_k: if set, sample only from the top-k logits.

        Returns:
            (B, T + max_new_tokens) tensor of prompt + generated ids.

        Note: puts the module in eval mode and does not restore training mode.
        NOTE(review): nothing here crops the sequence to block_size — generating
        more than block_size total positions relies on ModernBlock handling
        cache overflow; confirm before long generations.
        """
        self.eval()
        self.clear_kv_cache()
        # Prefill the cache with every prompt token EXCEPT the last one.
        # BUG FIX: prefilling the full prompt and then feeding idx[:, -1:] in
        # the loop below cached the last prompt token TWICE (at positions T-1
        # and T), shifting every generated token's RoPE position by one. The
        # loop's first iteration now processes the last prompt token at its
        # correct position.
        if idx.shape[1] > 1:
            _, _ = self(idx[:, :-1], use_cache=True)
        for _ in range(max_new_tokens):
            # Only pass the most recent token — the KV cache holds the rest,
            # and RoPE uses position = cache_pos (not 0!).
            idx_last = idx[:, -1:]
            logits, _ = self(idx_last, use_cache=True)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                # Mask everything below the k-th largest logit.
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float("-inf")
            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, next_id], dim=1)
        self.clear_kv_cache()
        return idx
# ── Sanity check ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
    import time

    from tokenizer import DEVICE, VOCAB_SIZE, BLOCK_SIZE

    model = ModernGPT(vocab_size=VOCAB_SIZE, block_size=BLOCK_SIZE).to(DEVICE)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"ModernGPT parameters : {n_params:,} (~{n_params/1e6:.1f}M)")

    # Forward pass with inputs doubling as (shifted-by-zero) targets, just to
    # confirm shapes and that a loss comes back.
    x = torch.zeros((2, 8), dtype=torch.long, device=DEVICE)
    logits, loss = model(x, x)
    print(f"Logits shape : {logits.shape}")
    print(f"Loss (untrained) : {loss.item():.4f}")

    # Confirm no positional embedding table exists — RoPE replaces it.
    # (Fixed mojibake in this message: 'β' was a mis-encoded em dash.)
    has_pos_emb = hasattr(model, "pos_emb")
    print(f"Has pos_emb table : {has_pos_emb} (expected False — using RoPE)")
    print("\nModernGPT OK.")