brettleehari committed on
Commit
14c107a
·
verified ·
1 Parent(s): 73815b2

Initial microGPT upload

Files changed (5)
  1. README.md +473 -0
  2. ckpt.pt +3 -0
  3. inference.py +114 -0
  4. model.py +152 -0
  5. tokenizer.json +0 -0
README.md ADDED
@@ -0,0 +1,473 @@
1
+ ---
2
+ license: mit
3
+ language:
4
+ - en
5
+ tags:
6
+ - text-generation
7
+ - transformer
8
+ - educational
9
+ - tiny-llm
10
+ - from-scratch
11
+ - decoder-only
12
+ - gpt
13
+ datasets:
14
+ - roneneldan/TinyStories
15
+ pipeline_tag: text-generation
16
+ library_name: pytorch
17
+ model-index:
18
+ - name: microgpt
19
+ results:
20
+ - task:
21
+ type: text-generation
22
+ name: Story completion
23
+ dataset:
24
+ name: TinyStories (validation split)
25
+ type: roneneldan/TinyStories
26
+ metrics:
27
+ - type: cross-entropy
28
+ value: 2.25
29
+ name: Validation cross-entropy loss
30
+ - type: perplexity
31
+ value: 9.49
32
+ name: Validation perplexity
33
+ ---
34
+
35
+ # microGPT
36
+
37
+ A **1.35M-parameter decoder-only transformer** trained from scratch on the
38
+ [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories) dataset.
39
+ The entire training run took roughly 1.5 hours on an Apple Silicon laptop.
40
+ At roughly 130,000× smaller than GPT-3 (175B parameters), it can still produce coherent, simple
41
+ children's stories.
42
+
43
+ This is an **educational artifact**, not a production model. Its purpose is
44
+ to make every component of a modern LLM legible, debuggable, and rebuildable
45
+ on consumer hardware.
46
+
47
+ ---
48
+
49
+ ## Quick facts
50
+
51
+ | | |
52
+ |---|---|
53
+ | **Architecture** | Decoder-only transformer (GPT-style) |
54
+ | **Parameters** | 1,345,792 trainable (1.35M) |
55
+ | **File size on disk** | ~5.1 MB (float32) |
56
+ | **Training data** | ~470M tokens of TinyStories |
57
+ | **Training compute** | ~1.5 hours on Apple Silicon (MPS) |
58
+ | **Final val loss** | 2.25 (perplexity 9.49) |
59
+ | **Context window** | 256 tokens |
60
+ | **Tokenizer** | Byte-level BPE, vocab=4096 |
61
+ | **License** | MIT |
62
+
63
+ ---
64
+
65
+ ## Architecture in detail
66
+
67
+ ```
+ Input tokens (B, T)
+
+   ├─► Token Embedding (4096 → 128)
+   │                          │
+   └─► Position Embedding ────┘  ← element-wise sum
+
+   ▼ (B, T, 128)
+ ┌──── Block × 4 ───────────────────────────────────────────┐
+ │                                                          │
+ │  h = LayerNorm(x)                                        │
+ │  x = x + CausalSelfAttention(h)    ← 4 heads             │
+ │  h = LayerNorm(x)                                        │
+ │  x = x + MLP(h)                    ← 128→512→128, GELU   │
+ │                                                          │
+ └──────────────────────────────────────────────────────────┘
+
+   ▼ (B, T, 128)
+ LayerNorm
+
+
+ Linear (128 → 4096)  ← weight-tied with token embedding
+
+   ▼ (B, T, 4096)
+ Logits
+ ```
93
+
94
+ | Hyperparameter | Value | Notes |
95
+ |---|---|---|
96
+ | `n_layers` | 4 | Stacked transformer blocks |
97
+ | `d_model` | 128 | Hidden dimension |
98
+ | `n_heads` | 4 | Each head is 128/4 = 32 dim |
99
+ | `head_dim` | 32 | Per-head dimensionality |
100
+ | `ffn_dim` | 512 | MLP intermediate width (4×d_model) |
101
+ | `ctx_len` | 256 | Maximum input length in tokens |
102
+ | `vocab_size` | 4,096 | BPE-derived vocabulary |
103
+ | Normalization | LayerNorm | Pre-LN (applied before sublayers) |
104
+ | Position encoding | Learned | Absolute, additive |
105
+ | Activation | GELU | In the MLP |
106
+ | Attention | Multi-head, causal | Implemented via `F.scaled_dot_product_attention` |
107
+ | Embedding tying | Yes | Output projection shares weight with `tok_emb` |
108
+ | Bias on linear layers | No | Following common modern practice |
109
+ | Dropout | 0.1 (training) | 0.0 at inference |
110
+
111
+ ### Parameter breakdown — where the 1.35M live
112
+
113
+ | Component | Shape | Params | % |
114
+ |---|---|---|---|
115
+ | Token embeddings (`tok_emb.weight`) | (4096, 128) | 524,288 | 39.0% |
116
+ | Position embeddings (`pos_emb.weight`) | (256, 128) | 32,768 | 2.4% |
117
+ | 4 × transformer block | — | 788,480 | 58.6% |
118
+ | └─ Per block: `ln1` (γ, β) | (128,) × 2 | 256 | |
119
+ | └─ Per block: `attn.qkv` | (384, 128) | 49,152 | |
120
+ | └─ Per block: `attn.proj` | (128, 128) | 16,384 | |
121
+ | └─ Per block: `ln2` (γ, β) | (128,) × 2 | 256 | |
122
+ | └─ Per block: `mlp.fc1` | (512, 128) | 65,536 | |
123
+ | └─ Per block: `mlp.fc2` | (128, 512) | 65,536 | |
124
+ | Final LayerNorm (`ln_f`) | (128,) × 2 | 256 | 0.02% |
125
+ | Output projection (`head.weight`) | (4096, 128) | 0 | tied |
126
+ | **Total** | | **1,345,792** | |
127
+
128
+ Two observations worth absorbing:
129
+
130
+ - **Embeddings are 41% of total parameters** at this scale. This is typical of small models — the vocab × d_model matrix dominates. As models grow, the transformer blocks become the much larger fraction (frontier models are >90% transformer body, with embeddings a rounding error).
131
+ - **MLPs (`fc1` + `fc2`) account for about two-thirds of every block's params**: 131,072 of 197,120 = 66%. Interpretability research suggests MLPs are where much of a model's factual knowledge is stored. The two-thirds share stays roughly true at frontier scale.
132
+
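The table and the two observations above can be re-derived from the hyperparameters alone. The sketch below is not part of the repository: the function name `count_params` and its dictionary keys are illustrative assumptions. It relies only on the facts stated above: no bias on linear layers, LayerNorm with γ and β, and a tied output head that adds no parameters.

```python
def count_params(vocab_size=4096, d_model=128, n_layers=4, ffn_dim=512, ctx_len=256):
    """Count trainable parameters for the layout described in the table above."""
    tok_emb = vocab_size * d_model            # also serves as the tied output head
    pos_emb = ctx_len * d_model
    layer_norm = 2 * d_model                  # gamma and beta
    attn = d_model * 3 * d_model + d_model * d_model   # fused qkv + proj, no bias
    mlp = d_model * ffn_dim + ffn_dim * d_model        # fc1 + fc2, no bias
    block = 2 * layer_norm + attn + mlp
    total = tok_emb + pos_emb + n_layers * block + layer_norm  # last term is ln_f
    return {"tok_emb": tok_emb, "pos_emb": pos_emb, "mlp": mlp,
            "block": block, "total": total}


counts = count_params()
embedding_share = (counts["tok_emb"] + counts["pos_emb"]) / counts["total"]
mlp_share_of_block = counts["mlp"] / counts["block"]
float32_bytes = counts["total"] * 4           # 4 bytes per float32 weight

print(f"total params:        {counts['total']:,}")
print(f"embedding share:     {embedding_share:.1%}")
print(f"MLP share per block: {mlp_share_of_block:.1%}")
print(f"float32 weights:     {float32_bytes / 2**20:.2f} MiB")
```

Changing `vocab_size` or `d_model` in the call shows how quickly the embedding share shrinks as the transformer body grows.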
133
+ ---
134
+
135
+ ## Training
136
+
137
+ ### Data
138
+
139
+ - **Dataset:** [`roneneldan/TinyStories`](https://huggingface.co/datasets/roneneldan/TinyStories) (Eldan & Li, 2023)
140
+ - **Stories:** ~2.1M (train) + ~22K (validation)
141
+ - **Tokens (after BPE):** ~470M (train) + ~5M (validation)
142
+ - **Why TinyStories specifically:** synthetic dataset designed so vocabulary
143
+ and grammar stay within what a 3–4 year-old understands, making coherent
144
+ generation possible at very small model scales. Without this curation, a
145
+ 1.35M-param model on general web text produces gibberish.
146
+
147
+ ### Tokenizer
148
+
149
+ - **Type:** byte-level Byte-Pair Encoding (BPE)
150
+ - **Vocabulary:** 4,096 tokens (including special tokens `<unk>`, `<eos>`)
151
+ - **Trained on:** 50,000 stories from the train split (vocab converges
152
+ quickly; full corpus produces a near-identical tokenizer)
153
+ - **Avg compression:** ~4 characters per token on TinyStories text
154
+
155
+ ### Optimization
156
+
157
+ | Hyperparameter | Value |
158
+ |---|---|
159
+ | Optimizer | AdamW |
160
+ | β₁, β₂ | 0.9, 0.95 |
161
+ | Weight decay | 0.1 |
162
+ | Peak learning rate | 3e-4 |
163
+ | Min learning rate | 3e-5 |
164
+ | Schedule | Linear warmup (200 steps) → cosine decay |
165
+ | Batch size (sequences) | 64 |
166
+ | Sequence length | 256 |
167
+ | Tokens per step | 16,384 |
168
+ | Total steps | 20,000 |
169
+ | Total tokens seen | ~327M |
170
+ | Gradient clipping | 1.0 (global L2 norm) |
171
+ | Random seed | 1337 |
172
+
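The schedule and token counts in the table can be sketched in a few lines. The training script is not part of this upload, so the details here are assumptions: the warmup is taken to start from zero, the cosine decay is taken to reach the minimum learning rate exactly at step 20,000, and the name `lr_at` is hypothetical.

```python
import math

PEAK_LR, MIN_LR = 3e-4, 3e-5
WARMUP_STEPS, TOTAL_STEPS = 200, 20_000
BATCH_SIZE, SEQ_LEN = 64, 256


def lr_at(step):
    """Linear warmup to PEAK_LR, then cosine decay down to MIN_LR."""
    if step < WARMUP_STEPS:
        return PEAK_LR * step / WARMUP_STEPS          # assumed to start at 0
    if step >= TOTAL_STEPS:
        return MIN_LR
    progress = (step - WARMUP_STEPS) / (TOTAL_STEPS - WARMUP_STEPS)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))  # goes from 1 to 0
    return MIN_LR + (PEAK_LR - MIN_LR) * cosine


tokens_per_step = BATCH_SIZE * SEQ_LEN
total_tokens = tokens_per_step * TOTAL_STEPS

for step in (0, 100, 200, 10_000, 20_000):
    print(f"step {step:>6}: lr = {lr_at(step):.2e}")
print(f"tokens per step: {tokens_per_step:,}; total tokens: {total_tokens:,}")
```

With ~470M training tokens available and ~327M tokens consumed, the run covers less than one pass over the training split.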
173
+ ### Hardware & wall-clock
174
+
175
+ | | |
176
+ |---|---|
177
+ | Hardware | Apple M-series laptop (MPS backend) |
178
+ | Precision | float32 |
179
+ | Wall-clock | ~1.5 hours |
180
+ | Peak memory | ~1.5 GB |
181
+ | Disk footprint | ~1 GB tokenized corpus + 5.1 MB checkpoint |
182
+
183
+ ---
184
+
185
+ ## Evaluation
186
+
187
+ ### Held-out validation loss
188
+
189
+ | Step | Val loss | Perplexity |
190
+ |---|---|---|
191
+ | 0 (init) | 8.32 | 4096 |
192
+ | ~17,500 | 2.26 | 9.59 |
193
+ | ~20,000 | **2.25** | **9.49** |
194
+
195
+ For context: a uniform random predictor over the 4,096-token vocabulary
196
+ would have loss `ln(4096) = 8.32` and perplexity 4,096. A perplexity of
197
+ 9.49 means the model has effectively narrowed each prediction down to
198
+ ~9.5 plausible next-token choices, on average.
199
+
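The relationship between the loss and perplexity columns is just `exp(loss)`, with loss measured in nats. A minimal check using only the standard library (the helper name `perplexity` is illustrative, not something defined in the repository):

```python
import math


def perplexity(loss_nats):
    """Perplexity is the exponential of the mean cross-entropy in nats."""
    return math.exp(loss_nats)


vocab_size = 4096
uniform_loss = math.log(vocab_size)   # loss of a uniform guess over the vocab

print(f"uniform baseline: loss = {uniform_loss:.2f}, ppl = {perplexity(uniform_loss):.0f}")
print(f"final checkpoint: loss = 2.25, ppl = {perplexity(2.25):.2f}")
```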
200
+ > **Note on missing checkpoint progression.** This release contains only the
201
+ > final checkpoint. The original plan was to capture step-tagged snapshots
202
+ > (steps 200, 1000, 3000, 8000, 20000) to demonstrate capability *emergence*
203
+ > across training — gibberish → words → grammar → narrative. A monitoring
204
+ > script was started after training had already passed the early steps, so
205
+ > only the final checkpoint was preserved. A v0.2 release with multi-checkpoint
206
+ > snapshots and an emergence comparison is planned.
207
+
208
+ ### Capability probes
209
+
210
+ Three diagnostic prompts designed to probe specific capabilities, evaluated
211
+ on the final checkpoint (step 20,000, temperature 0.8, top-k=40):
212
+
213
+ **Short-range coherence** — *"Lily picked up the red ball and"*
214
+ > ...put it on the floor. She threw the red ball with her hands and pulled
215
+ > it onto the floor. She tried to make it shine, but it made her feel good.
216
+ > "Oh no, no!" she shouted. "Sharing is very dangerous!"
217
+
218
+ Pronoun continuity holds; ball persists as object. ✓
219
+
220
+ **Causal reasoning** — *"Tom was very sad because"*
221
+ > ...he had no one to play. He thought that it would be fun and share his
222
+ > toys with Timmy. His mom took him inside and said, "I love you too, Timmy!"
223
+
224
+ "because" produces a real cause of sadness; emotional consistency holds for
225
+ 2–3 sentences before drifting. ✓ (with caveats)
226
+
227
+ **Narrative composition** — *"The little dragon wanted to fly but"*
228
+ > ...her mommy said no. The bear was very sad that he was gone. He wanted
229
+ > to fly anymore and get lost.
230
+
231
+ Initial obstacle is set up correctly, but the model loses track of which
232
+ character is which (dragon → bear → "he"). ✗
233
+
234
+ This pattern — local coherence ✓, multi-sentence composition partial — is
235
+ expected at this scale. Narrative arc requires planning across many tokens,
236
+ which is one of the last capabilities to emerge in language models even at
237
+ frontier scale.
238
+
239
+ ---
240
+
241
+ ## Intended use
242
+
243
+ **In scope:**
244
+ - Educational reference for the GPT-style transformer architecture
245
+ - Demonstration of end-to-end LLM training on consumer hardware
246
+ - Generating short, simple, TinyStories-style English children's narratives
247
+ - Exploring how sampling parameters (temperature, top-k, top-p) affect output
248
+ - Comparison baseline for tiny-model research
249
+
250
+ **Out of scope:**
251
+ - General-purpose text generation (vocabulary is restricted to TinyStories)
252
+ - Question answering, instruction following, or chat (no SFT or RLHF stage)
253
+ - Anything requiring factual accuracy (no factual grounding)
254
+ - Non-English text (English-only training data)
255
+ - Long-form generation (256-token context window)
256
+
257
+ ---
258
+
259
+ ## Limitations and biases
260
+
261
+ - **Distribution lock-in:** Trained exclusively on synthetic children's
262
+ stories. Generation outside this distribution (e.g., technical text,
263
+ adult themes, dialogue formats) will be incoherent.
264
+ - **No instruction following:** This is a base model — pre-training only.
265
+ It completes text; it does not answer questions or follow instructions.
266
+ - **Hallucination:** No factual grounding. The model has no concept of
267
+ "I don't know" — it produces the most statistically plausible
268
+ continuation, which is often false outside the training distribution.
269
+ - **Context window:** 256 tokens is too short to model long dependencies.
270
+ - **Synthetic data biases:** TinyStories was generated by GPT-3.5/4 with
271
+ prompted constraints, so it inherits some of that generator's stylistic
272
+ patterns and any biases encoded therein.
273
+ - **No safety training:** No RLHF, no Constitutional AI, no content
274
+ filtering. While the training data is innocuous, prompts that
275
+ push toward harmful outputs receive no safeguards.
276
+ - **Memorization vs generalization:** Some completions ("She was very
277
+ happy and they played all day") are likely memorized stylistic
278
+ patterns rather than novel generation.
279
+
280
+ ---
281
+
282
+ ## How to use
283
+
284
+ ### Inference
285
+
286
+ ```python
287
+ from inference import NanoSLMInference
288
+
289
+ slm = NanoSLMInference("ckpt.pt", "tokenizer.json")
290
+
291
+ text = slm.generate(
292
+ "Once upon a time, there was a little",
293
+ max_new_tokens=200,
294
+ temperature=0.8,
295
+ top_k=40,
296
+ )
297
+ print(text)
298
+ ```
299
+
300
+ ### Sampling parameters
301
+
302
+ | Parameter | Effect |
303
+ |---|---|
304
+ | `temperature` | Scales logits before softmax. 0 = greedy (deterministic, often repetitive). 1.0 = no scaling. >1 = more random. Typical: 0.7–1.0. |
305
+ | `top_k` | Keep only the *k* highest-probability tokens. Filters tail-of-distribution garbage. Typical: 40–100. |
306
+ | `top_p` (nucleus) | Keep the smallest set of tokens with cumulative probability ≥ p. Adapts the cutoff to distribution shape. Typical: 0.9–0.95. |
307
+ | `seed` | Sets PyTorch RNG for reproducibility. |
308
+
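To see how the three filters interact, here is a dependency-free sketch that operates on a plain list of logits. It is a simplified illustration, not the code in `inference.py`: the function name `sample_next` and the `rng` argument are assumptions, and unlike the tensor version it truncates ties at the top-k boundary instead of keeping them.

```python
import math
import random


def sample_next(logits, temperature=0.8, top_k=None, top_p=None, rng=random):
    """Pick a token index from raw logits using temperature, top-k, and top-p."""
    if temperature == 0:
        # Greedy decoding: always take the highest-scoring token.
        return max(range(len(logits)), key=lambda i: logits[i])
    scaled = [x / temperature for x in logits]
    # Candidate indices ordered from most to least likely.
    order = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
    if top_k is not None:
        order = order[:top_k]
    # Softmax over the surviving candidates (subtract the max for stability).
    peak = scaled[order[0]]
    weights = [math.exp(scaled[i] - peak) for i in order]
    total = sum(weights)
    probs = [w / total for w in weights]
    if top_p is not None:
        # Keep the smallest prefix whose cumulative probability reaches top_p.
        cumulative, keep = 0.0, 0
        for p in probs:
            cumulative += p
            keep += 1
            if cumulative >= top_p:
                break
        order, probs = order[:keep], probs[:keep]
    return rng.choices(order, weights=probs, k=1)[0]


rng = random.Random(1337)
toy_logits = [2.0, 1.0, 0.5, -1.0, -3.0]
picks = [sample_next(toy_logits, temperature=0.8, top_k=3, rng=rng) for _ in range(10)]
print(picks)   # only indices 0, 1, and 2 can appear with top_k=3
```

Passing a seeded `random.Random` instance plays the same role as the `seed` parameter in the table: identical inputs then yield identical samples.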
309
+ ---
310
+
311
+ ## How this model is served
312
+
313
+ A live demo is hosted on [Hugging Face Spaces](https://huggingface.co/spaces/brettleehari/microgpt-demo).
314
+ The serving stack is intentionally minimal:
315
+
316
+ ```
317
+ User browser
318
+ ↓ HTTPS
319
+ HF Spaces (free CPU instance, 2 vCPU / 16 GB RAM)
320
+
321
+ Gradio + FastAPI/uvicorn
322
+
323
+ PyTorch eager-mode forward pass on CPU
324
+
325
+ Autoregressive token generation, one token per pass
326
+ ```
327
+
328
+ Approximate latency for 100 generated tokens: **~3 seconds on Spaces' free
329
+ CPU**, **~0.5 seconds on Apple M-series with MPS**.
330
+
331
+ What this serving setup deliberately does *not* implement (each is a separate
332
+ upgrade and a useful learning exercise):
333
+
334
+ - **KV-caching** — every generation step re-processes all prior tokens.
335
+ A real implementation caches K/V tensors and pays only for the new token.
336
+ - **Continuous batching** — multiple users would queue serially. Production
337
+ servers (vLLM, TGI) batch concurrent requests dynamically.
338
+ - **Quantization** — weights are float32. int8 would shrink weight memory ~4×, int4 ~8×.
339
+ - **Compiled graphs** — eager-mode PyTorch leaves performance on the table
340
+ vs `torch.compile()`, ONNX Runtime, or a dedicated engine.
341
+
342
+ For a model this small the overheads don't matter. At any production scale,
343
+ *every one of the above becomes critical to unit economics*.
344
+
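The cost of skipping the KV cache can be estimated by counting how many token positions are pushed through the transformer blocks during generation. This is a back-of-the-envelope sketch under stated assumptions: the function name is hypothetical, the prompt plus generated tokens are assumed to fit inside the 256-token window (so the context never slides), and attention's own per-step cost, which still grows with sequence length even with a cache, is ignored.

```python
def positions_processed(prompt_len, new_tokens, ctx_len=256, kv_cache=False):
    """Count token positions run through the model while generating."""
    total = 0
    seq_len = prompt_len
    for step in range(new_tokens):
        if kv_cache and step > 0:
            total += 1                       # only the newest token is processed
        else:
            total += min(seq_len, ctx_len)   # the whole context is re-run
        seq_len += 1
    return total


without_cache = positions_processed(prompt_len=10, new_tokens=100)
with_cache = positions_processed(prompt_len=10, new_tokens=100, kv_cache=True)
print(f"no cache: {without_cache} positions, with cache: {with_cache} positions")
print(f"ratio: {without_cache / with_cache:.1f}x")
```

The gap widens roughly linearly with sequence length, which is why the optimization is negligible here and essential for long contexts.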
345
+ ---
346
+
347
+ ## Comparison with frontier models
348
+
349
+ The architecture follows the same overall design as GPT-2/3, Llama, Mistral, and
350
+ other decoder-only LLMs. The differences below are evolutionary refinements, not categorical
351
+ changes — the core "decoder-only transformer trained with next-token
352
+ prediction" recipe is the same.
353
+
354
+ | | microGPT (this) | Llama 3 70B |
355
+ |---|---|---|
356
+ | Parameters | 1.35M | 70B (~52,000× larger) |
357
+ | Layers | 4 | 80 |
358
+ | `d_model` | 128 | 8,192 |
359
+ | Heads | 4 (multi-head) | 64 (grouped-query attention) |
360
+ | Context | 256 | 8,192 (128K in Llama 3.1) |
361
+ | Vocab | 4,096 | 128,256 |
362
+ | Position | Learned absolute | Rotary (RoPE) |
363
+ | Activation | GELU | SwiGLU |
364
+ | Normalization | LayerNorm | RMSNorm |
365
+ | Training tokens | ~327M | ~15T (~46,000× more) |
366
+ | Training compute | ~1.5 hours on one laptop (well under 1 kWh) | many MW-months on H100 clusters |
367
+
368
+ ---
369
+
370
+ ## Glossary
371
+
372
+ A short reference for the terminology used above. Worth absorbing — these
373
+ terms come up constantly in AI literature and interviews.
374
+
375
+ **Parameter / weight.** A single learnable number stored in the model.
376
+ Updated during training, read during inference. A "1.35M parameter model"
377
+ literally has 1.35M of these numbers.
378
+
379
+ **Embedding.** A learned vector representation of a discrete object (token,
380
+ position). Implemented as a lookup table.
381
+
382
+ **Token.** The atomic unit of text the model operates on. Produced by the
383
+ tokenizer; typically ~4 characters of English per token for byte-level BPE.
384
+
385
+ **Tokenizer.** The deterministic, reversible function that converts strings
386
+ to integer ID sequences and back. Decisions made here (vocab size, BPE
387
+ merges) propagate through the entire model.
388
+
389
+ **BPE (Byte-Pair Encoding).** A subword tokenization algorithm that
390
+ iteratively merges the most frequent adjacent pairs of symbols into new
391
+ vocabulary entries.
392
+
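The merge loop at the heart of BPE fits in a few lines. This is a toy character-level sketch on a four-word corpus, not the byte-level tokenizer shipped in `tokenizer.json`; the function names and the corpus are illustrative assumptions.

```python
from collections import Counter


def most_frequent_pair(words):
    """words maps a tuple of symbols to that word's frequency in the corpus."""
    pairs = Counter()
    for symbols, freq in words.items():
        for left, right in zip(symbols, symbols[1:]):
            pairs[(left, right)] += freq
    return max(pairs, key=pairs.get)   # on ties, the first pair seen wins


def merge_pair(words, pair):
    """Replace every adjacent occurrence of `pair` with one merged symbol."""
    merged = {}
    for symbols, freq in words.items():
        out, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
                out.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        merged[tuple(out)] = merged.get(tuple(out), 0) + freq
    return merged


corpus = {tuple("low"): 5, tuple("lower"): 2, tuple("newest"): 6, tuple("widest"): 3}
merges = []
for _ in range(3):
    pair = most_frequent_pair(corpus)
    merges.append(pair)
    corpus = merge_pair(corpus, pair)

print(merges)
print(list(corpus))
```

A real tokenizer runs this loop until the vocabulary reaches its target size (4,096 here) and stores the ordered merge list so encoding is deterministic.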
393
+ **Logits.** The raw, unnormalized scores the model outputs — one per
394
+ vocabulary token at each position. Becomes a probability distribution after
395
+ softmax.
396
+
397
+ **Softmax.** Function that converts logits to probabilities by exponentiating
398
+ and normalizing.
399
+
400
+ **Cross-entropy loss.** The training objective: how surprised the model is
401
+ by the correct next token. Equals 0 if the model assigned probability 1 to
402
+ the right answer; equals `ln(vocab_size)` if the model is uniformly
403
+ uninformed.
404
+
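The two endpoints quoted for cross-entropy (0 for a certain, correct prediction; `ln(vocab_size)` for a uniform one) follow directly from the softmax definition. A pure-Python sketch with illustrative helper names:

```python
import math


def softmax(logits):
    """Exponentiate and normalize; subtracting the max avoids overflow."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target):
    """Negative log-probability assigned to the correct token, in nats."""
    return -math.log(softmax(logits)[target])


uniform = cross_entropy([0.0] * 4096, target=7)        # every token equally likely
confident = cross_entropy([0.0, 50.0, 0.0], target=1)  # nearly all mass on the target

print(f"uniform over 4096 tokens: {uniform:.4f}")
print(f"confident and correct:    {confident:.4f}")
```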
405
+ **Perplexity.** `exp(loss)`. The "effective number of choices" the model is
406
+ deciding between. Useful because it has a more intuitive scale than loss.
407
+
408
+ **Decoder-only / autoregressive.** The model only attends to past tokens
409
+ (causal mask), and generates one token at a time conditioned on what it has
410
+ already produced.
411
+
412
+ **Self-attention.** The mechanism by which each position computes a
413
+ weighted combination of all (allowed) other positions, where the weights
414
+ depend on the content at each position.
415
+
416
+ **Multi-head attention.** Self-attention computed in parallel across `n`
417
+ subspaces ("heads"), each with `d_model / n` dimensions. Different heads
418
+ empirically learn to specialize.
419
+
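For intuition, a single causal attention head can be written with plain Python lists. This is a didactic sketch only: `model.py` uses the fused `F.scaled_dot_product_attention` on batched tensors, and the function name and toy inputs below are assumptions made for illustration.

```python
import math


def causal_attention(q, k, v):
    """Single-head causal attention; q, k, v hold one vector per position."""
    d = len(q[0])
    outputs = []
    for t in range(len(q)):
        # Causal mask: position t may only look at positions 0..t.
        scores = [
            sum(a * b for a, b in zip(q[t], k[s])) / math.sqrt(d)
            for s in range(t + 1)
        ]
        peak = max(scores)
        weights = [math.exp(x - peak) for x in scores]
        norm = sum(weights)
        weights = [w / norm for w in weights]
        # Output is the attention-weighted average of the visible value vectors.
        outputs.append(
            [sum(w * v[s][j] for s, w in enumerate(weights)) for j in range(len(v[0]))]
        )
    return outputs


q = k = v = [[1.0, 0.0], [0.0, 1.0]]
out = causal_attention(q, k, v)
print(out[0])   # position 0 can only attend to itself
print(out[1])   # position 1 mixes positions 0 and 1
```

Multi-head attention runs several of these in parallel on `d_model / n` slices of the hidden vector and concatenates the results before the output projection.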
420
+ **KV cache.** At inference time, the Key and Value tensors from previous
421
+ tokens can be cached and reused, avoiding redundant computation. Critical
422
+ for production serving; not implemented in this model.
423
+
424
+ **Pre-LayerNorm.** Applying LayerNorm *before* the attention/MLP sublayers,
425
+ not after. Stabilizes training of deep transformers.
426
+
427
+ **Weight tying.** Sharing parameters between the input embedding matrix and
428
+ the output projection matrix. Saves memory; usually improves quality.
429
+
430
+ **Cosine learning-rate schedule.** Learning rate ramps up linearly during
431
+ warmup, then decays following a cosine curve. Standard for transformer
432
+ training.
433
+
434
+ **Gradient clipping.** Capping the global L2 norm of gradients during
435
+ backpropagation to prevent destabilizing weight updates.
436
+
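Global-norm clipping treats all gradients as one long vector and rescales it when its L2 norm exceeds the threshold. The sketch below uses nested lists instead of tensors and a hypothetical function name; the small `1e-6` term mirrors the guard PyTorch's `clip_grad_norm_` uses against division by zero.

```python
import math


def clip_by_global_norm(grads, max_norm=1.0):
    """grads is a list of per-parameter gradient lists; returns (clipped, norm)."""
    global_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    # Scale down only when the norm exceeds max_norm; never scale up.
    scale = min(1.0, max_norm / (global_norm + 1e-6))
    clipped = [[g * scale for g in grad] for grad in grads]
    return clipped, global_norm


large, large_norm = clip_by_global_norm([[3.0], [4.0]], max_norm=1.0)
small, small_norm = clip_by_global_norm([[0.3], [0.4]], max_norm=1.0)
print(large_norm, large)   # norm 5.0 is rescaled to roughly unit length
print(small_norm, small)   # norm 0.5 is left untouched
```

Because every gradient is multiplied by the same factor, the update direction is preserved; only its length changes.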
437
+ **MPS (Metal Performance Shaders).** Apple's GPU acceleration backend for
438
+ PyTorch on M-series chips. The Apple Silicon equivalent of CUDA.
439
+
440
+ **Pre-training.** The stage of training described here: minimize next-token
441
+ prediction loss on a large corpus. Produces a *base model*.
442
+
443
+ **SFT (Supervised Fine-Tuning).** A subsequent training stage on
444
+ `(instruction, ideal response)` pairs. Teaches the model to follow
445
+ instructions. Not done for this model.
446
+
447
+ **RLHF (Reinforcement Learning from Human Feedback).** A further training
448
+ stage using preference data. Aligns model behavior with human preferences.
449
+ Not done for this model.
450
+
451
+ ---
452
+
453
+ ## Citation
454
+
455
+ If this model or its companion code helped you, please cite or link to:
456
+
457
+ ```
458
+ @misc{microgpt,
459
+ author = {Brett Lee Hary},
460
+ title = {microGPT: a 1.35M-parameter transformer trained from scratch on TinyStories},
461
+ year = {2026},
462
+ howpublished = {\url{https://huggingface.co/brettleehari/microgpt}},
463
+ }
464
+ ```
465
+
466
+ ### Acknowledgements
467
+
468
+ - Andrej Karpathy's [nanoGPT](https://github.com/karpathy/nanoGPT) — the
469
+ reference implementation that made this approachable.
470
+ - Eldan & Li (2023), [TinyStories: How Small Can Language Models Be and Still Speak Coherent English?](https://arxiv.org/abs/2305.07759) — the dataset and the insight that data quality can substitute for model scale.
471
+ - Vaswani et al. (2017), [Attention Is All You Need](https://arxiv.org/abs/1706.03762) — the original transformer.
472
+ - The Hugging Face `transformers`, `tokenizers`, and `datasets` teams for
473
+ the infrastructure that makes projects like this trivial to share.
ckpt.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6a503409e144a80c461d97b9462ee76236e663d54499afd6bb39ce1230c68f31
3
+ size 5394041
inference.py ADDED
@@ -0,0 +1,114 @@
+ """
+ Inference helper for Nano-SLM.
+
+ Wraps the model + tokenizer into a clean `generate()` function suitable for
+ demos, notebooks, or a Gradio interface.
+
+ Usage (paths are relative to the root of this repository):
+     from inference import NanoSLMInference
+     slm = NanoSLMInference("ckpt.pt", "tokenizer.json")
+     text = slm.generate("Once upon a time", max_new_tokens=200, temperature=0.8)
+     print(text)
+ """
+ import torch
+ import torch.nn.functional as F
+ from tokenizers import Tokenizer
+ from model import NanoSLM
+
+
+ # Must match the architecture used during training.
+ DEFAULT_CFG = dict(
+     vocab_size=4096, d_model=128, n_heads=4, n_layers=4,
+     ffn_dim=512, ctx_len=256, dropout=0.0,
+ )
+
+
+ class NanoSLMInference:
+     def __init__(self, ckpt_path, tokenizer_path, device=None, cfg=None):
+         if device is None:
+             if torch.backends.mps.is_available():
+                 device = "mps"
+             elif torch.cuda.is_available():
+                 device = "cuda"
+             else:
+                 device = "cpu"
+         self.device = device
+
+         self.tokenizer = Tokenizer.from_file(tokenizer_path)
+
+         cfg = cfg or DEFAULT_CFG
+         self.model = NanoSLM(**cfg)
+         ckpt = torch.load(ckpt_path, map_location=device)
+         # support both raw state_dicts and {"model": ...} checkpoints
+         state = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt
+         self.model.load_state_dict(state)
+         self.model.to(device).eval()
+         self.ctx_len = cfg["ctx_len"]
+
+     @torch.no_grad()
+     def generate(
+         self,
+         prompt: str,
+         max_new_tokens: int = 200,
+         temperature: float = 0.8,
+         top_k: int | None = 40,
+         top_p: float | None = None,
+         seed: int | None = None,
+     ) -> str:
+         """Generate continuation for a prompt.
+
+         Args:
+             prompt: input text
+             max_new_tokens: how many tokens to generate
+             temperature: 0 = greedy, 1.0 = no scaling, >1 = more random
+             top_k: keep only the k highest-prob tokens (None = no filter)
+             top_p: nucleus: keep smallest set with cumulative prob >= p
+             seed: for reproducibility
+         """
+         if seed is not None:
+             torch.manual_seed(seed)
+
+         ids = self.tokenizer.encode(prompt).ids
+         x = torch.tensor([ids], dtype=torch.long, device=self.device)
+
+         for _ in range(max_new_tokens):
+             # truncate context if it grows past ctx_len
+             x_cond = x[:, -self.ctx_len:]
+             logits, _ = self.model(x_cond)
+             # we only care about the prediction for the next token
+             logits = logits[:, -1, :]
+
+             if temperature == 0.0:
+                 # greedy: pick the argmax
+                 next_tok = logits.argmax(dim=-1, keepdim=True)
+             else:
+                 logits = logits / temperature
+
+                 if top_k is not None:
+                     v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                     logits[logits < v[:, [-1]]] = -float("inf")
+
+                 if top_p is not None:
+                     sorted_logits, sorted_idx = torch.sort(logits, descending=True)
+                     cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+                     # mask tokens past the nucleus
+                     mask = cum_probs > top_p
+                     # shift right so we always keep at least one token
+                     mask[..., 1:] = mask[..., :-1].clone()
+                     mask[..., 0] = False
+                     sorted_logits[mask] = -float("inf")
+                     # unsort back to original vocab order
+                     logits = torch.zeros_like(logits).scatter_(1, sorted_idx, sorted_logits)
+
+                 probs = F.softmax(logits, dim=-1)
+                 next_tok = torch.multinomial(probs, num_samples=1)
+
+             x = torch.cat([x, next_tok], dim=1)
+
+         return self.tokenizer.decode(x[0].tolist())
+
+
+ if __name__ == "__main__":
+     # quick self-test; expects ckpt.pt and tokenizer.json next to this file
+     slm = NanoSLMInference("ckpt.pt", "tokenizer.json")
+     print(slm.generate("Once upon a time", max_new_tokens=100, temperature=0.8, top_k=40))
model.py ADDED
@@ -0,0 +1,152 @@
+ """
+ Nano-SLM: a tiny decoder-only transformer (~1.35M params).
+
+ Architecture is intentionally minimal so every line is readable.
+ Mirrors the standard GPT recipe: token + position embeddings, N stacked
+ (causal self-attention -> MLP) blocks with pre-LayerNorm and residuals,
+ final LayerNorm, and a tied LM head.
+ """
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class CausalSelfAttention(nn.Module):
+     """Multi-head causal self-attention. Uses fused QKV and PyTorch's SDPA."""
+
+     def __init__(self, d_model, n_heads, dropout=0.1):
+         super().__init__()
+         assert d_model % n_heads == 0
+         self.n_heads = n_heads
+         self.head_dim = d_model // n_heads
+         # one big linear that produces Q, K, V at once
+         self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
+         self.proj = nn.Linear(d_model, d_model, bias=False)
+         self.attn_dropout_p = dropout
+         self.resid_dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         B, T, C = x.shape
+         q, k, v = self.qkv(x).split(C, dim=-1)
+         # reshape to (B, n_heads, T, head_dim)
+         q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+         k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+         v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+         # Flash/SDPA: causal mask + scaling handled internally
+         y = F.scaled_dot_product_attention(
+             q, k, v,
+             is_causal=True,
+             dropout_p=self.attn_dropout_p if self.training else 0.0,
+         )
+         y = y.transpose(1, 2).contiguous().view(B, T, C)
+         return self.resid_dropout(self.proj(y))
+
+
+ class MLP(nn.Module):
+     """Position-wise feed-forward (GELU)."""
+
+     def __init__(self, d_model, ffn_dim, dropout=0.1):
+         super().__init__()
+         self.fc1 = nn.Linear(d_model, ffn_dim, bias=False)
+         self.fc2 = nn.Linear(ffn_dim, d_model, bias=False)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         return self.dropout(self.fc2(F.gelu(self.fc1(x))))
+
+
+ class Block(nn.Module):
+     """Pre-LN transformer block: x = x + attn(LN(x)); x = x + mlp(LN(x))."""
+
+     def __init__(self, d_model, n_heads, ffn_dim, dropout=0.1):
+         super().__init__()
+         self.ln1 = nn.LayerNorm(d_model)
+         self.attn = CausalSelfAttention(d_model, n_heads, dropout)
+         self.ln2 = nn.LayerNorm(d_model)
+         self.mlp = MLP(d_model, ffn_dim, dropout)
+
+     def forward(self, x):
+         x = x + self.attn(self.ln1(x))
+         x = x + self.mlp(self.ln2(x))
+         return x
+
+
+ class NanoSLM(nn.Module):
+     def __init__(
+         self,
+         vocab_size=4096,
+         d_model=128,
+         n_heads=4,
+         n_layers=4,
+         ffn_dim=512,
+         ctx_len=256,
+         dropout=0.1,
+     ):
+         super().__init__()
+         self.ctx_len = ctx_len
+         self.tok_emb = nn.Embedding(vocab_size, d_model)
+         self.pos_emb = nn.Embedding(ctx_len, d_model)
+         self.drop = nn.Dropout(dropout)
+         self.blocks = nn.ModuleList(
+             [Block(d_model, n_heads, ffn_dim, dropout) for _ in range(n_layers)]
+         )
+         self.ln_f = nn.LayerNorm(d_model)
+         self.head = nn.Linear(d_model, vocab_size, bias=False)
+         # weight tying: input embedding and output projection share weights.
+         # saves vocab_size * d_model params (a large share at small model
+         # sizes) and usually helps quality.
+         self.head.weight = self.tok_emb.weight
+
+         self.apply(self._init_weights)
+         # scaled init for residual projections (GPT-2 trick)
+         for name, p in self.named_parameters():
+             if name.endswith("proj.weight") or name.endswith("fc2.weight"):
+                 nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * n_layers))
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             nn.init.normal_(m.weight, mean=0.0, std=0.02)
+             if m.bias is not None:
+                 nn.init.zeros_(m.bias)
+         elif isinstance(m, nn.Embedding):
+             nn.init.normal_(m.weight, mean=0.0, std=0.02)
+
+     def num_params(self, non_embedding=False):
+         n = sum(p.numel() for p in self.parameters())
+         if non_embedding:
+             n -= self.tok_emb.weight.numel()
+             n -= self.pos_emb.weight.numel()
+         return n
+
+     def forward(self, idx, targets=None):
+         B, T = idx.shape
+         assert T <= self.ctx_len, f"sequence length {T} > ctx_len {self.ctx_len}"
+         pos = torch.arange(T, device=idx.device)
+         x = self.drop(self.tok_emb(idx) + self.pos_emb(pos))
+         for block in self.blocks:
+             x = block(x)
+         x = self.ln_f(x)
+         logits = self.head(x)
+         loss = None
+         if targets is not None:
+             loss = F.cross_entropy(
+                 logits.view(-1, logits.size(-1)),
+                 targets.view(-1),
+                 ignore_index=-100,
+             )
+         return logits, loss
+
+     @torch.no_grad()
+     def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
+         """Autoregressive sampling. Slow on purpose: no KV cache (a great upgrade later).
+
+         Requires temperature > 0; greedy decoding lives in inference.py.
+         """
+         for _ in range(max_new_tokens):
+             idx_cond = idx[:, -self.ctx_len:]
+             logits, _ = self(idx_cond)
+             logits = logits[:, -1, :] / temperature
+             if top_k is not None:
+                 v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                 logits[logits < v[:, [-1]]] = -float("inf")
+             probs = F.softmax(logits, dim=-1)
+             next_tok = torch.multinomial(probs, num_samples=1)
+             idx = torch.cat([idx, next_tok], dim=1)
+         return idx
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff