bmeyer2025 commited on
Commit
feccb58
·
verified ·
1 Parent(s): 94d17bb

Upload src/model_modern.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/model_modern.py +166 -0
src/model_modern.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Modernized GPT model.
3
+
4
+ Same architecture as model.py but with all four swaps applied:
5
+ 1. RMSNorm (replaces LayerNorm everywhere)
6
+ 2. SwiGLU (replaces ReLU FFN)
7
+ 3. RoPE (replaces learned positional embeddings)
8
+ 4. KV Cache (for fast inference generation)
9
+
10
 + The positional embedding table is removed entirely — position is encoded
11
 + via RoPE rotations directly in each attention head.
12
+
13
+ BUG FIX (2026-03-29): RoPE positions were wrong during KV cache generation.
14
+ When generating token-by-token with use_cache=True, we were computing RoPE
15
+ for position 0 every time instead of the actual position. This made every
16
 + generated token think it was at position 0 → garbage output. Fixed by
17
+ tracking _cache_pos and passing position offset to forward().
18
+ """
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+
24
+ from modernize import ModernBlock, RMSNorm, precompute_rope_freqs
25
+
26
+
27
class ModernGPT(nn.Module):
    """GPT language model with RMSNorm, SwiGLU, RoPE, and an optional KV cache.

    Position is encoded via RoPE rotations applied inside each attention head
    (by ``ModernBlock``), so there is no learned positional-embedding table.
    During incremental generation (``use_cache=True``) the model tracks
    ``_cache_pos`` so RoPE frequencies are computed for each token's *actual*
    absolute position rather than always starting from 0.

    BUG FIX: ``generate()`` previously prefilled the cache with the whole
    prompt and then fed the last prompt token again on the first loop step,
    caching that token twice (at positions T-1 and T). The prefill now stops
    at ``idx[:, :-1]`` so the final prompt token is processed exactly once.
    """

    def __init__(
        self,
        vocab_size: int,
        n_embd: int = 384,
        n_heads: int = 6,
        n_layer: int = 6,
        block_size: int = 256,
        dropout: float = 0.2,
    ):
        super().__init__()
        self.block_size = block_size
        self.n_heads = n_heads
        self.head_size = n_embd // n_heads

        # Token embedding only — no positional embedding table (RoPE handles position).
        self.token_emb = nn.Embedding(vocab_size, n_embd)

        self.blocks = nn.ModuleList([
            ModernBlock(n_embd=n_embd, n_heads=n_heads, block_size=block_size, dropout=dropout)
            for _ in range(n_layer)
        ])
        self.ln_f = RMSNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

        # Weight tying: output projection shares the token-embedding matrix.
        self.lm_head.weight = self.token_emb.weight

        # Next absolute position to be written during KV-cache generation.
        self._cache_pos = 0

        self._init_weights()

    def _init_weights(self) -> None:
        """Init all Linear/Embedding weights with N(0, 0.02); zero all biases.

        NOTE: lm_head.weight is tied to token_emb.weight, so the shared matrix
        is initialized twice here (Linear pass then Embedding pass); the final
        values come from the Embedding init — same distribution either way.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def clear_kv_cache(self) -> None:
        """Reset the position counter and drop cached K/V in every block."""
        self._cache_pos = 0
        for block in self.blocks:
            block.clear_cache()

    def forward(
        self,
        idx: torch.Tensor,
        targets: torch.Tensor | None = None,
        use_cache: bool = False,
    ):
        """Run the transformer over a batch of token ids.

        Args:
            idx: (B, T) integer token ids, T <= block_size.
            targets: optional (B, T) next-token ids; when given, a mean
                cross-entropy loss over all B*T positions is returned.
            use_cache: when True, blocks append K/V to their caches and
                ``_cache_pos`` advances by T so the next call sees the
                correct absolute positions.

        Returns:
            (logits, loss): logits is (B, T, vocab_size); loss is None when
            no targets are supplied.
        """
        B, T = idx.shape
        assert T <= self.block_size, f"sequence length {T} exceeds block_size {self.block_size}"

        # RoPE frequencies for the ACTUAL positions of these tokens. With a
        # KV cache the incoming tokens sit at cache_pos .. cache_pos + T,
        # not 0 .. T, so precompute up to max_pos and slice off the tail.
        max_pos = self._cache_pos + T
        cos_full, sin_full = precompute_rope_freqs(self.head_size, max_pos, idx.device)
        cos = cos_full[self._cache_pos : max_pos]  # (T, head_size//2)
        sin = sin_full[self._cache_pos : max_pos]

        if use_cache:
            self._cache_pos += T

        x = self.token_emb(idx)  # (B, T, n_embd)

        for block in self.blocks:
            x = block(x, cos, sin, use_cache=use_cache)

        x = self.ln_f(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss

    @torch.no_grad()
    def generate(
        self,
        idx: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: int | None = None,
    ) -> torch.Tensor:
        """Autoregressively sample ``max_new_tokens`` tokens using the KV cache.

        Args:
            idx: (B, T) prompt token ids.
            max_new_tokens: number of tokens to append.
            temperature: logit divisor; <1 sharpens, >1 flattens the distribution.
            top_k: if set, sample only from the k most likely tokens.

        Returns:
            (B, T + max_new_tokens) tensor of prompt + generated ids.
        """
        self.eval()
        self.clear_kv_cache()

        # Prefill the cache with every prompt token EXCEPT the last one.
        # The sampling loop below consumes the last prompt token itself;
        # prefilling the full prompt would cache that token twice.
        if idx.shape[1] > 1:
            _, _ = self(idx[:, :-1], use_cache=True)

        for _ in range(max_new_tokens):
            # Only the newest token is fed — the KV cache holds the rest,
            # and RoPE uses position = cache_pos (not 0).
            idx_last = idx[:, -1:]
            logits, _ = self(idx_last, use_cache=True)
            logits = logits[:, -1, :] / temperature

            if top_k is not None:
                # Mask everything below the k-th largest logit.
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float("-inf")

            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, next_id], dim=1)

        self.clear_kv_cache()
        return idx
144
+
145
+
146
# ── Sanity check ──────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # NOTE(review): removed unused `import time`; repaired mojibake ("β€”")
    # in the pos_emb status line — it should be an em dash.
    from tokenizer import DEVICE, VOCAB_SIZE, BLOCK_SIZE

    model = ModernGPT(vocab_size=VOCAB_SIZE, block_size=BLOCK_SIZE).to(DEVICE)

    n_params = sum(p.numel() for p in model.parameters())
    print(f"ModernGPT parameters : {n_params:,} (~{n_params/1e6:.1f}M)")

    # Forward pass on dummy tokens; untrained loss should be near ln(vocab_size).
    x = torch.zeros((2, 8), dtype=torch.long, device=DEVICE)
    logits, loss = model(x, x)
    print(f"Logits shape : {logits.shape}")
    print(f"Loss (untrained) : {loss.item():.4f}")

    # Confirm no positional embedding table (position comes from RoPE instead).
    has_pos_emb = hasattr(model, "pos_emb")
    print(f"Has pos_emb table : {has_pos_emb} (expected False — using RoPE)")

    print("\nModernGPT OK.")