epoyraz committed on
Commit d47aee7 · verified · 1 parent: ab0b8a1

Upgrade to modded-nanoGPT + Muon checkpoint (val 2.65 -> 2.45)

Files changed (4)
  1. README.md +33 -37
  2. config.json +5 -2
  3. model.py +64 -18
  4. tinystories-25m.pt +2 -2
README.md CHANGED
@@ -12,7 +12,8 @@ tags:
12
  - pytorch
13
  - rope
14
  - gqa
15
- - swiglu
 
16
  - multi-token-prediction
17
  pipeline_tag: text-generation
18
  ---
@@ -21,34 +22,40 @@ pipeline_tag: text-generation
21
 
22
  A small (~19.2M parameter) decoder-only GPT trained **from scratch** on
23
  [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories). It writes
24
- simple, coherent children's stories and is meant as a compact, hackable reference
25
- for modern LLM architecture techniques — small enough to train end-to-end in a few
26
- minutes on a consumer GPU (RTX 2060 Super, 8 GB).
 
 
 
 
 
27
 
28
  ## Sample output
29
 
30
- > **Once upon a time,** there was a little girl named Lily. She loved to play with
31
- > her dolls and sing songs. One day, she went to the park to play with her friends.
32
- > She saw a boy playing with a toy car and asked why he played too much...
 
33
 
34
- > **Lily and Tom went to the park and** played on the swings. They had a lot of fun.
35
- > They played with their toys and had a lot of fun. They also learned to be good and
36
- > not judge others. They were happy.
37
 
38
  ## Architecture
39
 
40
- A LLaMA-style decoder-only transformer with several modern techniques wired in:
41
 
42
  | Component | Choice |
43
  |---|---|
44
  | Layers / heads / dim | 8 layers, 6 heads, `n_embd` 384 |
45
  | Context length | 256 tokens |
46
  | Vocabulary | 16,384 (ByteLevel BPE) |
47
- | Position encoding | **RoPE** (rotary embeddings) |
48
- | Attention | **Grouped-Query Attention** (2 KV heads) |
49
- | MLP | **SwiGLU** |
50
  | Normalization | **RMSNorm** |
51
- | Extra heads | **Multi-Token Prediction** (2 auxiliary heads) for sample efficiency |
 
52
  | Weight tying | token embedding ↔ output head (and MTP heads) |
53
 
54
  ## Training
@@ -57,20 +64,17 @@ A LLaMA-style decoder-only transformer with several modern techniques wired in:
57
  |---|---|
58
  | Dataset | TinyStories (~2.1M stories) |
59
  | Steps | 3,000 |
60
- | Batch | 32 × 256 tokens |
61
- | Optimizer | AdamW, cosine schedule, 200-step warmup, peak LR 6e-4 |
62
- | Precision | fp16 mixed precision |
63
- | Hardware | 1× RTX 2060 Super (8 GB), ~7 minutes |
64
- | Throughput | ~57K tokens/sec |
65
- | Final loss | 2.62 (combined next-token + MTP auxiliary) |
66
- | Validation loss | 2.65 |
67
-
68
- This is a lightly trained demo checkpoint; longer training lowers loss further.
69
 
70
  ## Usage
71
 
72
- This is a **custom architecture**, so you need `model.py` from this repo (it's small
73
- and dependency-light). Download it next to your script, then:
74
 
75
  ```python
76
  import torch
@@ -105,22 +109,14 @@ print(tok.decode(out[0].tolist()))
105
 
106
  ## Limitations
107
 
108
- - Trained only on TinyStories — vocabulary and style are limited to simple
109
- children's-story English. It is not a general-purpose assistant.
110
- - Small and lightly trained: it repeats phrases and occasionally drifts or
111
- contradicts itself (e.g. swapping character names).
112
  - 256-token context.
113
 
114
- ## Source
115
-
116
- Trained with the "train a language model from scratch" project — a from-scratch GPT
117
- with independently configurable modern techniques (RoPE, GQA, SwiGLU, RMSNorm, MTP,
118
- mHC, BitNet, TurboQuant) plus Muon/AdamW optimizers and speculative decoding.
119
-
120
  ## References
121
 
122
  - [TinyStories](https://arxiv.org/abs/2305.07759)
123
  - [RoFormer / RoPE](https://arxiv.org/abs/2104.09864)
124
  - [GQA](https://arxiv.org/abs/2305.13245)
125
- - [GLU Variants / SwiGLU](https://arxiv.org/abs/2002.05202)
126
  - [DeepSeek-V3 (MTP)](https://arxiv.org/abs/2412.19437)
 
 
12
  - pytorch
13
  - rope
14
  - gqa
15
+ - qk-norm
16
+ - muon
17
  - multi-token-prediction
18
  pipeline_tag: text-generation
19
  ---
 
22
 
23
  A small (~19.2M parameter) decoder-only GPT trained **from scratch** on
24
  [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories). It writes
25
+ simple, coherent children's stories and is a compact, hackable reference for modern
26
+ LLM architecture + optimization techniques — trained end-to-end in a few minutes on a
27
+ single consumer GPU (RTX 2060 Super, 8 GB).
28
+
29
+ This checkpoint uses the **modded-nanoGPT-style recipe**: trained with the **Muon**
30
+ optimizer and **QK-Norm + squared-ReLU MLP + logit soft-capping**, which improved
31
+ validation loss from 2.65 to **2.45** versus a plain AdamW/SwiGLU baseline at the same
32
+ 3,000 steps.
33
 
34
  ## Sample output
35
 
36
+ > **Once upon a time,** there was a little girl named Lily. She loved to play outside
37
+ > and explore the world around her. One day, she found a long piece of cardboard on the
38
+ > floor. It was a big, white box with a bow on it. She picked it up and opened it. Inside
39
+ > the box, she found a toy car...
40
 
41
+ > **Lily and Tom went to the park and** saw a man with a big hat and a big smile. He was
42
+ > very nice... "Sure, you can play with us," Lily said. They played tag and hide and seek.
 
43
 
44
  ## Architecture
45
 
46
+ A LLaMA-/modded-nanoGPT-style decoder-only transformer:
47
 
48
  | Component | Choice |
49
  |---|---|
50
  | Layers / heads / dim | 8 layers, 6 heads, `n_embd` 384 |
51
  | Context length | 256 tokens |
52
  | Vocabulary | 16,384 (ByteLevel BPE) |
53
+ | Position encoding | **RoPE** |
54
+ | Attention | **Grouped-Query Attention** (2 KV heads) + **QK-Norm** |
55
+ | MLP | **squared-ReLU** (ungated) |
56
  | Normalization | **RMSNorm** |
57
+ | Logits | **soft-capped** at 15 (`cap·tanh(logits/cap)`) |
58
+ | Extra heads | **Multi-Token Prediction** (2 auxiliary heads) |
59
  | Weight tying | token embedding ↔ output head (and MTP heads) |
60
 
61
  ## Training
 
64
  |---|---|
65
  | Dataset | TinyStories (~2.1M stories) |
66
  | Steps | 3,000 |
67
+ | Batch | 40 × 256 tokens |
68
+ | Optimizer | **Muon** (2D weights) + AdamW (embeddings/norms), peak LR 3e-3, cosine schedule |
69
+ | Precision | fp16 mixed precision, `torch.compile` |
70
+ | Hardware | 1× RTX 2060 Super (8 GB), ~11 minutes (~47K tokens/sec) |
71
+ | Train loss | 2.47 (combined next-token + MTP auxiliary) |
72
+ | **Validation loss** | **2.45** (perplexity 11.5) |
 
 
 
73
 
74
  ## Usage
75
 
76
+ This is a **custom architecture**, so you need `model.py` from this repo (small,
77
+ dependency-light). Download it next to your script, then:
78
 
79
  ```python
80
  import torch
 
109
 
110
  ## Limitations
111
 
112
+ - Trained only on TinyStories — simple children's-story English, not a general assistant.
113
+ - Small and lightly trained: occasional repetition, name swaps, or drift.
 
 
114
  - 256-token context.
115
 
 
 
 
 
 
 
116
  ## References
117
 
118
  - [TinyStories](https://arxiv.org/abs/2305.07759)
119
  - [RoFormer / RoPE](https://arxiv.org/abs/2104.09864)
120
  - [GQA](https://arxiv.org/abs/2305.13245)
 
121
  - [DeepSeek-V3 (MTP)](https://arxiv.org/abs/2412.19437)
122
+ - [Muon optimizer](https://kellerjordan.github.io/posts/muon/) · [modded-nanoGPT](https://github.com/KellerJordan/modded-nanogpt)
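
The figures in the updated training table can be cross-checked with a little arithmetic. The sketch below uses only numbers quoted in the README (batch 40 × 256, 3,000 steps, ~47K tokens/sec, validation loss 2.45). The variable names are illustrative, and the perplexity line assumes the reported loss is mean cross-entropy in nats.

```python
import math

# Figures quoted in the README training table (new checkpoint).
batch_size, block_size, steps = 40, 256, 3_000
tokens_per_sec = 47_000  # approximate throughput from the table
val_loss = 2.45          # assumed to be mean cross-entropy in nats

tokens_per_step = batch_size * block_size
total_tokens = tokens_per_step * steps
minutes = total_tokens / tokens_per_sec / 60
perplexity = math.exp(val_loss)

print(f"tokens/step:  {tokens_per_step:,}")
print(f"total tokens: {total_tokens:,}")
print(f"wall clock:   ~{minutes:.1f} min")
print(f"perplexity:   ~{perplexity:.2f}")
```

This works out to about 30.7M training tokens and roughly 11 minutes, in line with the table. `exp(2.45)` is about 11.6, so the quoted perplexity of 11.5 presumably reflects the unrounded loss.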
config.json CHANGED
@@ -6,10 +6,13 @@
  "n_layer": 8,
  "use_rope": true,
  "n_kv_head": 2,
- "use_swiglu": true,
+ "use_swiglu": false,
  "use_rmsnorm": true,
  "use_mtp": true,
  "mtp_heads": 2,
  "mtp_weight": 0.1,
- "tie_mtp_lm_head": true
+ "tie_mtp_lm_head": true,
+ "use_relu2": true,
+ "use_qk_norm": true,
+ "logit_cap": 15.0
  }
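
The three new keys are read in `model.py` through `config.get(...)` with fallbacks (`use_relu2` and `use_qk_norm` default to `False`, `logit_cap` to `0`), so a `config.json` written before this commit should still load with the previous behavior. A minimal sketch of that lookup follows; `resolve_new_flags` and `NEW_KEY_DEFAULTS` are hypothetical names for illustration, not part of the repo.

```python
import json

# Defaults mirror the config.get(...) fallbacks visible in the model.py diff.
NEW_KEY_DEFAULTS = {"use_relu2": False, "use_qk_norm": False, "logit_cap": 0}

def resolve_new_flags(config):
    """Hypothetical helper: return the three new settings, falling back to defaults."""
    return {key: config.get(key, default) for key, default in NEW_KEY_DEFAULTS.items()}

# Abbreviated stand-ins for the old and new config.json contents.
old_config = json.loads('{"n_layer": 8, "use_swiglu": true}')
new_config = json.loads(
    '{"n_layer": 8, "use_swiglu": false, "use_relu2": true,'
    ' "use_qk_norm": true, "logit_cap": 15.0}'
)

print(resolve_new_flags(old_config))
print(resolve_new_flags(new_config))
```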
model.py CHANGED
@@ -5,6 +5,13 @@ import math
5
  from torch.utils.checkpoint import checkpoint
6
 
7
 
 
 
 
 
 
 
 
8
  # --- mHC: Manifold-Constrained Hyper-Connections ---
9
 
10
  def sinkhorn(log_alpha, n_iters=5):
@@ -264,28 +271,26 @@ class MTPHead(nn.Module):
264
  self.future_idx = future_idx
265
  n_embd = config["n_embd"]
266
  vocab_size = config["vocab_size"]
 
267
  self.proj = nn.Linear(n_embd, n_embd)
268
  self.ln = nn.LayerNorm(n_embd)
269
  self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
270
 
271
  def forward(self, hidden, targets=None):
 
 
 
272
  if targets is not None:
273
  shift = self.future_idx
274
- if targets.size(1) <= shift:
275
- return None, None
276
- # Only the first T-shift positions have a future target, so project
277
- # just those instead of the full sequence (saves a vocab matmul slice).
278
- h = self.ln(self.proj(hidden[:, :-shift]))
279
- logits = self.lm_head(h)
280
- targets_shifted = targets[:, shift:]
281
- loss = F.cross_entropy(
282
- logits.reshape(-1, logits.size(-1)),
283
- targets_shifted.reshape(-1),
284
- ignore_index=-1,
285
- )
286
- return logits, loss
287
- h = self.ln(self.proj(hidden))
288
- return self.lm_head(h), None
289
 
290
 
291
  # --- RoPE: Rotary Position Embeddings ---
@@ -339,6 +344,23 @@ class SwiGLU(nn.Module):
339
  return self.down(F.silu(self.gate(x)) * self.up(x))
340
 
341
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
342
  # --- Core model ---
343
 
344
  def make_norm(n_embd, use_rmsnorm=False):
@@ -359,6 +381,7 @@ class CausalSelfAttention(nn.Module):
359
  raise ValueError(f"n_head ({self.n_head}) must be divisible by n_kv_head ({self.n_kv_head})")
360
  self.head_dim = self.n_embd // self.n_head
361
  self.use_rope = config.get("use_rope", False)
 
362
  use_bitnet = config.get("use_bitnet", False)
363
  use_fast_bitnet = config.get("use_fast_bitnet", False)
364
 
@@ -367,6 +390,11 @@ class CausalSelfAttention(nn.Module):
367
  self.v_proj = make_linear(self.n_embd, self.n_kv_head * self.head_dim, use_bitnet=use_bitnet, use_fast_bitnet=use_fast_bitnet)
368
  self.proj = make_linear(self.n_embd, self.n_embd, use_bitnet=use_bitnet, use_fast_bitnet=use_fast_bitnet)
369
 
 
 
 
 
 
370
  if self.use_rope:
371
  self.rope = RotaryEmbedding(self.head_dim, max_seq_len=config.get("block_size", 512))
372
 
@@ -376,6 +404,10 @@ class CausalSelfAttention(nn.Module):
376
  k = self.k_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
377
  v = self.v_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
378
 
 
 
 
 
379
  if self.use_rope:
380
  cos, sin = self.rope(pos_offset + T)
381
  cos, sin = cos[pos_offset:pos_offset + T], sin[pos_offset:pos_offset + T]
@@ -416,7 +448,9 @@ class Block(nn.Module):
416
  self.ln1 = make_norm(config["n_embd"], use_rmsnorm)
417
  self.attn = CausalSelfAttention(config)
418
  self.ln2 = make_norm(config["n_embd"], use_rmsnorm)
419
- if config.get("use_swiglu", False):
 
 
420
  self.mlp = SwiGLU(config)
421
  else:
422
  self.mlp = MLP(config)
@@ -452,6 +486,7 @@ class GPT(nn.Module):
452
  self.use_turboquant = config.get("use_turboquant", False)
453
  self.turboquant_bits = config.get("turboquant_bits", 4)
454
  self.use_activation_checkpointing = config.get("use_activation_checkpointing", False)
 
455
  use_rmsnorm = config.get("use_rmsnorm", False)
456
 
457
  self.tok_emb = nn.Embedding(config["vocab_size"], config["n_embd"])
@@ -513,7 +548,7 @@ class GPT(nn.Module):
513
 
514
  def forward(self, idx, targets=None, return_hidden=False):
515
  hidden = self._compute_hidden(idx)
516
- logits = self.lm_head(hidden)
517
  loss = None
518
  if targets is not None:
519
  loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
@@ -536,7 +571,7 @@ class GPT(nn.Module):
536
  for block, cache in zip(self.blocks, kv_caches or [None] * len(self.blocks)):
537
  x = block(x, kv_cache=cache, pos_offset=pos_offset)
538
  hidden = self.ln_f(x)
539
- logits = self.lm_head(hidden)
540
  if return_hidden:
541
  return logits, hidden
542
  return logits
@@ -916,6 +951,16 @@ FAST_2060_MTP_FBITNET_CONFIG = {
916
  "use_fast_bitnet": True,
917
  }
918
 
 
 
 
 
 
 
 
 
 
 
919
  FAST_2060_MTP_TURBO_CONFIG = {
920
  **FAST_2060_MTP_CONFIG,
921
  "use_turboquant": True,
@@ -957,6 +1002,7 @@ CONFIGS = {
957
  "fast_2060": FAST_2060_CONFIG,
958
  "fast_2060_mtp": FAST_2060_MTP_CONFIG,
959
  "fast_2060_mtp_fbitnet": FAST_2060_MTP_FBITNET_CONFIG,
 
960
  "fast_2060_mtp_turbo": FAST_2060_MTP_TURBO_CONFIG,
961
  "tiny_fast": TINY_FAST_CONFIG,
962
  "low_memory_2060": LOW_MEMORY_2060_CONFIG,
 
5
  from torch.utils.checkpoint import checkpoint
6
 
7
 
8
+ def soft_cap(logits, cap):
9
+ """Gemma2/modded-nanoGPT logit soft-capping: cap * tanh(logits / cap). No-op if cap falsy."""
10
+ if cap:
11
+ return cap * torch.tanh(logits / cap)
12
+ return logits
13
+
14
+
15
  # --- mHC: Manifold-Constrained Hyper-Connections ---
16
 
17
  def sinkhorn(log_alpha, n_iters=5):
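
To see what the new `soft_cap` helper does to individual values, here is a scalar re-implementation using only the standard library. It follows the same formula, `cap * tanh(logits / cap)`, but `soft_cap_scalar` is an illustrative name and operates on floats rather than tensors.

```python
import math

def soft_cap_scalar(logit, cap):
    """Scalar version of soft_cap: cap * tanh(logit / cap); no-op if cap is falsy."""
    if cap:
        return cap * math.tanh(logit / cap)
    return logit

cap = 15.0  # value of "logit_cap" in config.json
for raw in (-100.0, -15.0, -1.0, 0.0, 1.0, 15.0, 100.0):
    print(f"{raw:8.1f} -> {soft_cap_scalar(raw, cap):8.3f}")
```

Small logits pass through almost unchanged, while large ones are squashed smoothly into the interval (-cap, cap).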
 
271
  self.future_idx = future_idx
272
  n_embd = config["n_embd"]
273
  vocab_size = config["vocab_size"]
274
+ self.logit_cap = config.get("logit_cap", 0)
275
  self.proj = nn.Linear(n_embd, n_embd)
276
  self.ln = nn.LayerNorm(n_embd)
277
  self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
278
 
279
  def forward(self, hidden, targets=None):
280
+ h = self.ln(self.proj(hidden))
281
+ logits = soft_cap(self.lm_head(h), self.logit_cap)
282
+ loss = None
283
  if targets is not None:
284
  shift = self.future_idx
285
+ if targets.size(1) > shift:
286
+ logits_shifted = logits[:, :-shift].contiguous()
287
+ targets_shifted = targets[:, shift:].contiguous()
288
+ loss = F.cross_entropy(
289
+ logits_shifted.view(-1, logits_shifted.size(-1)),
290
+ targets_shifted.view(-1),
291
+ ignore_index=-1,
292
+ )
293
+ return logits, loss
 
 
 
 
 
 
294
 
295
 
296
  # --- RoPE: Rotary Position Embeddings ---
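
The slicing in the revised `MTPHead.forward` (`logits[:, :-shift]` against `targets[:, shift:]`) pairs each position with the target `shift` steps further along. The sketch below shows that alignment for a single sequence using plain lists; `mtp_pairs` and the letter targets are hypothetical and only illustrate the indexing.

```python
def mtp_pairs(targets, shift):
    """Pair each position that has a target `shift` steps ahead with that target.

    Mirrors logits[:, :-shift] vs targets[:, shift:] for one sequence and
    returns [] when the sequence is too short, matching the guard in the diff.
    """
    if len(targets) <= shift:
        return []
    positions = list(range(len(targets)))[:-shift]
    return list(zip(positions, targets[shift:]))

targets = ["a", "b", "c", "d", "e"]  # hypothetical per-position targets
print(mtp_pairs(targets, shift=1))
print(mtp_pairs(targets, shift=2))
```

The last `shift` positions have no future target, which is why they are dropped before the cross-entropy is computed.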
 
344
  return self.down(F.silu(self.gate(x)) * self.up(x))
345
 
346
 
347
+ class ReLU2MLP(nn.Module):
348
+ """Ungated MLP with squared-ReLU activation (modded-nanoGPT). Simpler and a bit
349
+ faster than SwiGLU; competitive quality at small scale."""
350
+
351
+ def __init__(self, config):
352
+ super().__init__()
353
+ n_embd = config["n_embd"]
354
+ hidden = 4 * n_embd
355
+ use_bitnet = config.get("use_bitnet", False)
356
+ use_fast_bitnet = config.get("use_fast_bitnet", False)
357
+ self.fc = make_linear(n_embd, hidden, bias=False, use_bitnet=use_bitnet, use_fast_bitnet=use_fast_bitnet)
358
+ self.proj = make_linear(hidden, n_embd, bias=False, use_bitnet=use_bitnet, use_fast_bitnet=use_fast_bitnet)
359
+
360
+ def forward(self, x):
361
+ return self.proj(F.relu(self.fc(x)).square())
362
+
363
+
364
  # --- Core model ---
365
 
366
  def make_norm(n_embd, use_rmsnorm=False):
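
As a quick illustration of the `ReLU2MLP` added above, the sketch below applies the squared-ReLU activation to a few scalars and counts the weights in one block's MLP. The count assumes `make_linear` produces ordinary dense weight matrices of shape `n_embd × hidden` and `hidden × n_embd` with no bias, which holds for a plain linear layer but is not verified here for the BitNet variants.

```python
def relu2(x):
    """Squared ReLU: max(x, 0) ** 2, as in F.relu(...).square()."""
    return max(x, 0.0) ** 2

n_embd = 384          # from the README architecture table
hidden = 4 * n_embd   # hidden width used by ReLU2MLP in the diff

# fc (n_embd -> hidden) and proj (hidden -> n_embd), both created with bias=False.
params_per_block = n_embd * hidden + hidden * n_embd

print([relu2(v) for v in (-2.0, -0.5, 0.0, 0.5, 2.0)])
print(f"ReLU2MLP weights per block: {params_per_block:,}")
```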
 
381
  raise ValueError(f"n_head ({self.n_head}) must be divisible by n_kv_head ({self.n_kv_head})")
382
  self.head_dim = self.n_embd // self.n_head
383
  self.use_rope = config.get("use_rope", False)
384
+ self.use_qk_norm = config.get("use_qk_norm", False)
385
  use_bitnet = config.get("use_bitnet", False)
386
  use_fast_bitnet = config.get("use_fast_bitnet", False)
387
 
 
390
  self.v_proj = make_linear(self.n_embd, self.n_kv_head * self.head_dim, use_bitnet=use_bitnet, use_fast_bitnet=use_fast_bitnet)
391
  self.proj = make_linear(self.n_embd, self.n_embd, use_bitnet=use_bitnet, use_fast_bitnet=use_fast_bitnet)
392
 
393
+ # QK-Norm (modded-nanoGPT): RMSNorm Q and K over the head dim before attention.
394
+ if self.use_qk_norm:
395
+ self.q_norm = nn.RMSNorm(self.head_dim)
396
+ self.k_norm = nn.RMSNorm(self.head_dim)
397
+
398
  if self.use_rope:
399
  self.rope = RotaryEmbedding(self.head_dim, max_seq_len=config.get("block_size", 512))
400
 
 
404
  k = self.k_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
405
  v = self.v_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
406
 
407
+ if self.use_qk_norm:
408
+ q = self.q_norm(q)
409
+ k = self.k_norm(k)
410
+
411
  if self.use_rope:
412
  cos, sin = self.rope(pos_offset + T)
413
  cos, sin = cos[pos_offset:pos_offset + T], sin[pos_offset:pos_offset + T]
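
The effect of QK-Norm on attention scores can be sketched without PyTorch. The example below RMS-normalizes a query and a key over the head dimension with unit gain (the initial state of a learnable RMSNorm weight) and compares the scaled dot product before and after. The `eps` value, the test vectors, and the `1/sqrt(head_dim)` scaling are assumptions for illustration; the attention call itself is not shown in this diff.

```python
import math

def rms_norm(vec, eps=1e-6):
    """RMS-normalize a vector with unit gain: x / sqrt(mean(x^2) + eps)."""
    mean_sq = sum(v * v for v in vec) / len(vec)
    return [v / math.sqrt(mean_sq + eps) for v in vec]

def scaled_score(a, b):
    """Dot product scaled by 1/sqrt(dim), the usual attention scaling (assumed here)."""
    return sum(x * y for x, y in zip(a, b)) / math.sqrt(len(a))

head_dim = 384 // 6  # n_embd / n_head from the README table

# Hypothetical query/key vectors with very different scales.
q = [100.0 * ((i % 7) - 3) for i in range(head_dim)]
k = [0.01 * ((i % 5) - 2) for i in range(head_dim)]

print(f"raw score:    {scaled_score(q, k):.4f}")
print(f"normed score: {scaled_score(rms_norm(q), rms_norm(k)):.4f}")
print(f"bound:        {math.sqrt(head_dim):.1f}")
```

With unit gain, each normalized vector has norm at most `sqrt(head_dim)`, so by Cauchy-Schwarz the scaled score cannot exceed `sqrt(head_dim)` in magnitude regardless of how large the raw projections grow.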
 
448
  self.ln1 = make_norm(config["n_embd"], use_rmsnorm)
449
  self.attn = CausalSelfAttention(config)
450
  self.ln2 = make_norm(config["n_embd"], use_rmsnorm)
451
+ if config.get("use_relu2", False):
452
+ self.mlp = ReLU2MLP(config)
453
+ elif config.get("use_swiglu", False):
454
  self.mlp = SwiGLU(config)
455
  else:
456
  self.mlp = MLP(config)
 
486
  self.use_turboquant = config.get("use_turboquant", False)
487
  self.turboquant_bits = config.get("turboquant_bits", 4)
488
  self.use_activation_checkpointing = config.get("use_activation_checkpointing", False)
489
+ self.logit_cap = config.get("logit_cap", 0)
490
  use_rmsnorm = config.get("use_rmsnorm", False)
491
 
492
  self.tok_emb = nn.Embedding(config["vocab_size"], config["n_embd"])
 
548
 
549
  def forward(self, idx, targets=None, return_hidden=False):
550
  hidden = self._compute_hidden(idx)
551
+ logits = soft_cap(self.lm_head(hidden), self.logit_cap)
552
  loss = None
553
  if targets is not None:
554
  loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
 
571
  for block, cache in zip(self.blocks, kv_caches or [None] * len(self.blocks)):
572
  x = block(x, kv_cache=cache, pos_offset=pos_offset)
573
  hidden = self.ln_f(x)
574
+ logits = soft_cap(self.lm_head(hidden), self.logit_cap)
575
  if return_hidden:
576
  return logits, hidden
577
  return logits
 
951
  "use_fast_bitnet": True,
952
  }
953
 
954
+ # modded-nanoGPT-style recipe. QK-Norm helps under any optimizer; ReLU2 and
955
+ # logit_cap only pay off paired with Muon's higher LR. Train with --optimizer muon.
956
+ FAST_2060_MODDED_CONFIG = {
957
+ **FAST_2060_MTP_CONFIG,
958
+ "use_swiglu": False, # superseded by ReLU2 below
959
+ "use_relu2": True,
960
+ "use_qk_norm": True,
961
+ "logit_cap": 15.0,
962
+ }
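
The new preset relies on dict unpacking, where later keys override the unpacked base, and on the branch order in `Block`, where `use_relu2` is checked before `use_swiglu`. The sketch below shows both; `base_config` is a stand-in because the real contents of `FAST_2060_MTP_CONFIG` are not shown in this diff, and `pick_mlp` is a hypothetical helper that returns class names as strings.

```python
# Stand-in for FAST_2060_MTP_CONFIG; its real contents are not shown in this diff.
base_config = {"use_swiglu": True, "use_mtp": True}

modded_config = {
    **base_config,
    "use_swiglu": False,  # later keys override the unpacked base
    "use_relu2": True,
    "use_qk_norm": True,
    "logit_cap": 15.0,
}

def pick_mlp(config):
    """Return the MLP class name Block would pick, following the branch order in the diff."""
    if config.get("use_relu2", False):
        return "ReLU2MLP"
    if config.get("use_swiglu", False):
        return "SwiGLU"
    return "MLP"

print(pick_mlp(base_config))
print(pick_mlp(modded_config))
print(pick_mlp({**base_config, "use_relu2": True}))
```

Because `use_relu2` wins whenever it is set, the explicit `"use_swiglu": False` in the preset documents intent rather than changing which MLP is built.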
963
+
964
  FAST_2060_MTP_TURBO_CONFIG = {
965
  **FAST_2060_MTP_CONFIG,
966
  "use_turboquant": True,
 
1002
  "fast_2060": FAST_2060_CONFIG,
1003
  "fast_2060_mtp": FAST_2060_MTP_CONFIG,
1004
  "fast_2060_mtp_fbitnet": FAST_2060_MTP_FBITNET_CONFIG,
1005
+ "fast_2060_modded": FAST_2060_MODDED_CONFIG,
1006
  "fast_2060_mtp_turbo": FAST_2060_MTP_TURBO_CONFIG,
1007
  "tiny_fast": TINY_FAST_CONFIG,
1008
  "low_memory_2060": LOW_MEMORY_2060_CONFIG,
tinystories-25m.pt CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:f08fa57d4360cd654e407322bce66695018c5b9b673df8be5f8c9f5631fe3103
- size 76793291
+ oid sha256:69375d07a06ef3b325f3189b23b0caf21a7983fc1e87316b0f5651c579331af3
+ size 76800459