drixo committed on
Commit
f946798
·
verified ·
1 Parent(s): 9609c66

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. LICENSE +21 -0
  2. README.md +206 -4
  3. checkpoints/README.md +4 -0
  4. checkpoints/meta_005000.json +58 -0
  5. checkpoints/model_005000.pt +3 -0
  6. checkpoints/optim_005000_rank0.pt +3 -0
  7. nanochat/__init__.py +0 -0
  8. nanochat/__pycache__/__init__.cpython-310.pyc +0 -0
  9. nanochat/__pycache__/checkpoint_manager.cpython-310.pyc +0 -0
  10. nanochat/__pycache__/common.cpython-310.pyc +0 -0
  11. nanochat/__pycache__/core_eval.cpython-310.pyc +0 -0
  12. nanochat/__pycache__/dataloader.cpython-310.pyc +0 -0
  13. nanochat/__pycache__/dataset.cpython-310.pyc +0 -0
  14. nanochat/__pycache__/engine.cpython-310.pyc +0 -0
  15. nanochat/__pycache__/execution.cpython-310.pyc +0 -0
  16. nanochat/__pycache__/flash_attention.cpython-310.pyc +0 -0
  17. nanochat/__pycache__/gpt.cpython-310.pyc +0 -0
  18. nanochat/__pycache__/loss_eval.cpython-310.pyc +0 -0
  19. nanochat/__pycache__/optim.cpython-310.pyc +0 -0
  20. nanochat/__pycache__/report.cpython-310.pyc +0 -0
  21. nanochat/__pycache__/tokenizer.cpython-310.pyc +0 -0
  22. nanochat/checkpoint_manager.py +194 -0
  23. nanochat/common.py +278 -0
  24. nanochat/core_eval.py +262 -0
  25. nanochat/dataloader.py +166 -0
  26. nanochat/dataset.py +160 -0
  27. nanochat/engine.py +357 -0
  28. nanochat/execution.py +349 -0
  29. nanochat/flash_attention.py +187 -0
  30. nanochat/fp8.py +266 -0
  31. nanochat/gpt.py +507 -0
  32. nanochat/logo.svg +8 -0
  33. nanochat/loss_eval.py +65 -0
  34. nanochat/optim.py +533 -0
  35. nanochat/report.py +418 -0
  36. nanochat/tokenizer.py +406 -0
  37. nanochat/ui.html +566 -0
  38. pyproject.toml +74 -0
  39. scripts/__pycache__/base_eval.cpython-310.pyc +0 -0
  40. scripts/__pycache__/base_train.cpython-310.pyc +0 -0
  41. scripts/__pycache__/chat_eval.cpython-310.pyc +0 -0
  42. scripts/__pycache__/chat_sft.cpython-310.pyc +0 -0
  43. scripts/__pycache__/tok_eval.cpython-310.pyc +0 -0
  44. scripts/__pycache__/tok_train.cpython-310.pyc +0 -0
  45. scripts/base_eval.py +323 -0
  46. scripts/base_train.py +629 -0
  47. scripts/chat_cli.py +100 -0
  48. scripts/chat_eval.py +251 -0
  49. scripts/chat_rl.py +332 -0
  50. scripts/chat_sft.py +519 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Andrej Karpathy
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,4 +1,206 @@
- # NanoChat Checkpoint Step 5000
- - Partial training checkpoint (training crashed after this step)
- - Trained on ARC-Easy + ARC-Challenge dataset
- - Loss became NaN afterward, do **not** use for further training without fixing learning rate / FP8 settings
+ # nanochat
+
+ ![nanochat logo](dev/nanochat.png)
+ ![scaling laws](dev/scaling_laws_jan26.png)
+
+ nanochat is the simplest experimental harness for training LLMs. It is designed to run on a single GPU node, the code is minimal and hackable, and it covers all major LLM stages: tokenization, pretraining, finetuning, evaluation, inference, and a chat UI. For example, you can train your own GPT-2-capability LLM (which cost ~$43,000 to train in 2019) for only $48 (~2 hours on an 8XH100 GPU node) and then talk to it in a familiar ChatGPT-like web UI. On a spot instance, the total cost can be closer to ~$15. More generally, nanochat is configured out of the box to train an entire miniseries of compute-optimal models by setting one single complexity dial: `--depth`, the number of layers in the GPT transformer model (GPT-2 capability happens to be approximately depth 26). All other hyperparameters (the width of the transformer, number of heads, learning rate adjustments, training horizons, weight decays, ...) are calculated automatically in an optimal way.
+
+ For questions about the repo, I recommend using [DeepWiki](https://deepwiki.com/karpathy/nanochat) from Devin/Cognition to ask questions about the repo, posting in the [Discussions tab](https://github.com/karpathy/nanochat/discussions), or coming by the [#nanochat](https://discord.com/channels/1020383067459821711/1427295580895314031) channel on Discord.
+
+ ## Time-to-GPT-2 Leaderboard
+
+ Presently, the main focus of development is on tuning the pretraining stage, which takes the largest amount of compute. Inspired by the modded-nanogpt repo, and to incentivise progress and community collaboration, nanochat maintains a leaderboard for a "GPT-2 speedrun": the wall-clock time required to train a nanochat model to GPT-2-grade capability, as measured by the DCLM CORE score. The [runs/speedrun.sh](runs/speedrun.sh) script always reflects the reference way to train a GPT-2-grade model and talk to it. The current leaderboard looks as follows:
+
+ | # | time (hours) | val_bpb | CORE | Description | Date | Commit | Contributors |
+ |---|--------------|---------|------|-------------|------|--------|--------------|
+ | 0 | 168 | - | 0.2565 | Original OpenAI GPT-2 checkpoint | 2019 | - | OpenAI |
+ | 1 | 3.04 | 0.74833 | 0.2585 | d24 baseline, slightly overtrained | Jan 29 2026 | 348fbb3 | @karpathy |
+ | 2 | 2.91 | 0.74504 | 0.2578 | d26 slightly undertrained **+fp8** | Feb 2 2026 | a67eba3 | @karpathy |
+ | 3 | 2.76 | 0.74645 | 0.2602 | bump total batch size to 1M tokens | Feb 5 2026 | 2c062aa | @karpathy |
+ | 4 | 2.02 | 0.71854 | 0.2571 | change dataset to NVIDIA ClimbMix | Mar 4 2026 | 324e69c | @ddudek @karpathy |
+ | 5 | 1.80 | 0.71808 | 0.2690 | autoresearch [round 1](https://x.com/karpathy/status/2031135152349524125) | Mar 9 2026 | 6ed7d1d | @karpathy |
+ | 6 | 1.65 | 0.71800 | 0.2626 | autoresearch round 2 | Mar 14 2026 | a825e63 | @karpathy |
+
+ The primary metric we care about is "time to GPT-2": the wall-clock time needed to outperform the GPT-2 (1.6B) CORE metric on an 8XH100 GPU node. The GPT-2 CORE score is 0.256525. In 2019, training GPT-2 cost approximately $43,000, so it is remarkable that, thanks to many advances across the stack over 7 years, we can now do this so much faster and for well below $100 (e.g. at the current ~$3/GPU/hr, an 8XH100 node is ~$24/hr, so 2 hours is ~$48).
+
+ See [dev/LEADERBOARD.md](dev/LEADERBOARD.md) for more docs on how to interpret and contribute to the leaderboard.
+
+ ## Getting started
+
+ ### Reproduce and talk to GPT-2
+
+ The most fun you can have is to train your own GPT-2 and talk to it. The entire pipeline to do so is contained in the single file [runs/speedrun.sh](runs/speedrun.sh), which is designed to be run on an 8XH100 GPU node. Boot up a new 8XH100 GPU box from your favorite provider (e.g. I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script:
+
+ ```bash
+ bash runs/speedrun.sh
+ ```
+
+ You may wish to do so in a screen session, as this will take ~3 hours to run.
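Aside (illustrative, not part of the committed README): one way to do that is to launch the run in a named, logged GNU screen session; these are standard screen flags and the session/log names are only examples.

```bash
screen -L -Logfile speedrun.log -S speedrun bash runs/speedrun.sh
# reattach later with: screen -r speedrun
```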
+
+ Once it's done, you can talk to it via the ChatGPT-like web UI. Make sure again that your local uv virtual environment is active (run `source .venv/bin/activate`), and serve it:
+
+ ```bash
+ python -m scripts.chat_web
+ ```
+
+ And then visit the URL shown. Make sure to access it correctly, e.g. on Lambda use the public IP of the node you're on, followed by the port, for example [http://209.20.xxx.xxx:8000/](http://209.20.xxx.xxx:8000/). Then talk to your LLM as you'd normally talk to ChatGPT! Get it to write stories or poems. Ask it to tell you who you are to see a hallucination. Ask it why the sky is blue. Or why it's green. The speedrun produces a 4e19-FLOPs-capability model, so it's a bit like talking to a kindergartener :).
+
+ ---
+
+ <img width="2672" height="1520" alt="image" src="https://github.com/user-attachments/assets/ed39ddf8-2370-437a-bedc-0f39781e76b5" />
+
+ ---
+
+ A few more notes:
+
+ - The code will run just fine on an Ampere 8XA100 GPU node as well, just a bit slower.
+ - All the code will also run just fine on even a single GPU by omitting `torchrun`, and will produce ~identical results (the code automatically switches to gradient accumulation), but you'll have to wait 8 times longer.
+ - If your GPU(s) have less than 80GB of VRAM, you'll have to tune some of the hyperparameters or you will OOM. Look for `--device_batch_size` in the scripts and reduce it until things fit, e.g. from 32 (the default) to 16, 8, 4, 2, or even 1 (see the example command right after this list). Below that, you'll have to know a bit more about what you're doing and get more creative.
+ - Most of the code is fairly vanilla PyTorch, so it should run on anything that supports it (xpu, mps, etc.), but I haven't personally exercised all of these code paths, so there might be sharp edges.
+
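Aside (illustrative, not part of the committed README): on a smaller GPU, lowering the per-device batch size could look like the following; `--depth` and `--device_batch_size` are the flags referenced above, and the values are only examples.

```bash
OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
    --depth=26 \
    --device_batch_size=16   # default is 32; keep halving until you no longer OOM
```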
+ ## Research
+
+ If you are a researcher and wish to help improve nanochat, two scripts of interest are [runs/scaling_laws.sh](runs/scaling_laws.sh) and [runs/miniseries.sh](runs/miniseries.sh). See [Jan 7 miniseries v1](https://github.com/karpathy/nanochat/discussions/420) for related documentation. For quick experimentation (~5 min pretraining runs) my favorite scale is to train a 12-layer model (GPT-1 sized), e.g. like this:
+
+ ```
+ OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
+     --depth=12 \
+     --run="d12" \
+     --model-tag="d12" \
+     --core-metric-every=999999 \
+     --sample-every=-1 \
+     --save-every=-1
+ ```
+
+ This uses wandb (run name "d12"), only runs the CORE metric on the last step, and doesn't sample or save intermediate checkpoints. I like to change something in the code, re-run a d12 (or a d16, etc.), and see if it helped, in an iteration loop. To see if a run helps, I like to monitor the wandb plots for:
+
+ 1. `val_bpb` (validation loss in vocab-size-invariant units of bits per byte; see the conversion sketched right after this list) as a function of `step`, `total_training_time` and `total_training_flops`.
+ 2. `core_metric` (the DCLM CORE score)
+ 3. VRAM utilization, `train/mfu` (Model FLOPs Utilization), `train/tok_per_sec` (training throughput)
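Aside (illustrative, not part of the committed README): the standard conversion behind `val_bpb`; nanochat's actual implementation lives in `nanochat/loss_eval.py`, which this truncated view does not show, so treat this as a sketch of the idea rather than the repo's exact code.

```python
import math

def bits_per_byte(sum_loss_nats: float, num_bytes: int) -> float:
    """Convert a summed cross-entropy loss (in nats, over all predicted tokens)
    into bits per byte of the underlying text, so that models with different
    tokenizers / vocab sizes remain comparable."""
    return (sum_loss_nats / math.log(2)) / num_bytes

# e.g. 1000 predicted tokens at a mean loss of 2.8 nats, covering 4200 bytes of text:
print(bits_per_byte(1000 * 2.8, 4200))  # ~0.96 bpb
```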
+
+ See an example [here](https://github.com/karpathy/nanochat/pull/498#issuecomment-3850720044).
+
+ The important thing to note is that nanochat is written and configured around one single dial of complexity: the depth of the transformer. This single integer automatically determines all other hyperparameters (the width of the transformer, number of heads, learning rate adjustments, training horizons, weight decays, ...) so that the trained model comes out compute-optimal. The idea is that the user doesn't have to think about or set any of this; they simply ask for a smaller or bigger model using `--depth`, and everything "just works". By sweeping the depth, you obtain the nanochat miniseries of compute-optimal models at various sizes. A GPT-2-capability model (of most interest at the moment) happens to land somewhere in the d24-d26 range with the current code. But any candidate changes to the repo have to be principled enough that they work for all settings of depth.
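Aside (illustrative, not part of the committed README): a minimal sketch of the kind of derivation this implies, using the defaults visible in `checkpoints/meta_005000.json` from this upload (`aspect_ratio=64`, `head_dim=64`). The real derivation (including learning rates and training horizon) lives in `scripts/base_train.py`, which is not reproduced here.

```python
def model_dims_from_depth(depth: int, aspect_ratio: int = 64, head_dim: int = 64):
    """Derive model width and head count from the single --depth dial.
    Matches the d6 checkpoint in this upload: depth=6 -> n_embd=384, n_head=6."""
    n_embd = depth * aspect_ratio   # width grows linearly with depth
    n_head = n_embd // head_dim     # fixed per-head dimension
    return dict(n_layer=depth, n_embd=n_embd, n_head=n_head, n_kv_head=n_head)

print(model_dims_from_depth(6))    # {'n_layer': 6, 'n_embd': 384, 'n_head': 6, 'n_kv_head': 6}
print(model_dims_from_depth(26))   # {'n_layer': 26, 'n_embd': 1664, 'n_head': 26, 'n_kv_head': 26}
```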
+
+ ## Running on CPU / MPS
+
+ The script [runs/runcpu.sh](runs/runcpu.sh) shows a very simple example of running on CPU or Apple Silicon. It dramatically shrinks the LLM that is being trained so that training fits into a reasonable time interval of a few tens of minutes. You will not get strong results this way.
+
+ ## Precision / dtype
+
+ nanochat does not use `torch.amp.autocast`. Instead, precision is managed explicitly through a single global `COMPUTE_DTYPE` (defined in `nanochat/common.py`). By default this is auto-detected based on your hardware:
+
+ | Hardware | Default dtype | Why |
+ |----------|--------------|-----|
+ | CUDA SM 80+ (A100, H100, ...) | `bfloat16` | Native bf16 tensor cores |
+ | CUDA SM < 80 (V100, T4, ...) | `float32` | No bf16; fp16 available via `NANOCHAT_DTYPE=float16` (uses GradScaler) |
+ | CPU / MPS | `float32` | No reduced-precision tensor cores |
+
+ You can override the default with the `NANOCHAT_DTYPE` environment variable:
+
+ ```bash
+ NANOCHAT_DTYPE=float32 python -m scripts.chat_cli -p "hello" # force fp32
+ NANOCHAT_DTYPE=bfloat16 torchrun --nproc_per_node=8 -m scripts.base_train # force bf16
+ ```
+
+ How it works: model weights are stored in fp32 (for optimizer precision), but our custom `Linear` layer casts them to `COMPUTE_DTYPE` during the forward pass. Embeddings are stored directly in `COMPUTE_DTYPE` to save memory. This gives us the same mixed-precision benefit as autocast, but with full explicit control over what runs in which precision.
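Aside (illustrative, not part of the committed README): a minimal sketch of what such a weight-casting linear layer looks like. The real layer lives in `nanochat/gpt.py`, which is not shown in this view, so the class name and details here are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from nanochat.common import COMPUTE_DTYPE  # bf16/fp16/fp32, auto-detected in common.py

class CastedLinear(nn.Module):
    """Keeps the master weight in fp32 (for optimizer precision) but casts the
    weight and input to COMPUTE_DTYPE for the matmul, replacing torch.amp.autocast."""
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))  # fp32 master weight
        nn.init.normal_(self.weight, std=in_features ** -0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x.to(COMPUTE_DTYPE), self.weight.to(COMPUTE_DTYPE))
```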
+
+ Note: `float16` training automatically enables a `GradScaler` in `base_train.py` to prevent gradient underflow. SFT supports this too, but RL currently does not. Inference in fp16 works fine everywhere.
+
+ ## Guides
+
+ I've published a number of guides that might contain helpful information, most recent to least recent:
+
+ - [Feb 1 2026: Beating GPT-2 for <<$100: the nanochat journey](https://github.com/karpathy/nanochat/discussions/481)
+ - [Jan 7 miniseries v1](https://github.com/karpathy/nanochat/discussions/420) documents the first nanochat miniseries of models.
+ - To add new abilities to nanochat, see [Guide: counting r in strawberry (and how to add abilities generally)](https://github.com/karpathy/nanochat/discussions/164).
+ - To customize your nanochat, see [Guide: infusing identity to your nanochat](https://github.com/karpathy/nanochat/discussions/139) in Discussions, which describes how you can tune your nanochat's personality through synthetic data generation and mixing that data into the SFT stage.
+ - [Oct 13 2025: original nanochat post](https://github.com/karpathy/nanochat/discussions/1) introducing nanochat, though it now contains some deprecated information and the model it describes is a lot older (with worse results) than current master.
+
+ ## File structure
+
+ ```
+ .
+ ├── LICENSE
+ ├── README.md
+ ├── dev
+ │   ├── gen_synthetic_data.py        # Example synthetic data for identity
+ │   ├── generate_logo.html
+ │   ├── nanochat.png
+ │   └── repackage_data_reference.py  # Pretraining data shard generation
+ ├── nanochat
+ │   ├── __init__.py                  # empty
+ │   ├── checkpoint_manager.py        # Save/Load model checkpoints
+ │   ├── common.py                    # Misc small utilities, quality of life
+ │   ├── core_eval.py                 # Evaluates base model CORE score (DCLM paper)
+ │   ├── dataloader.py                # Tokenizing Distributed Data Loader
+ │   ├── dataset.py                   # Download/read utils for pretraining data
+ │   ├── engine.py                    # Efficient model inference with KV Cache
+ │   ├── execution.py                 # Allows the LLM to execute Python code as a tool
+ │   ├── gpt.py                       # The GPT nn.Module Transformer
+ │   ├── logo.svg
+ │   ├── loss_eval.py                 # Evaluate bits per byte (instead of loss)
+ │   ├── optim.py                     # AdamW + Muon optimizer, 1 GPU and distributed
+ │   ├── report.py                    # Utilities for writing the nanochat Report
+ │   ├── tokenizer.py                 # BPE Tokenizer wrapper in the style of GPT-4
+ │   └── ui.html                      # HTML/CSS/JS for the nanochat frontend
+ ├── pyproject.toml
+ ├── runs
+ │   ├── miniseries.sh                # Miniseries training script
+ │   ├── runcpu.sh                    # Small example of how to run on CPU/MPS
+ │   ├── scaling_laws.sh              # Scaling laws experiments
+ │   └── speedrun.sh                  # Train the ~$100 nanochat d20
+ ├── scripts
+ │   ├── base_eval.py                 # Base model: CORE score, bits per byte, samples
+ │   ├── base_train.py                # Base model: train
+ │   ├── chat_cli.py                  # Chat model: talk to over CLI
+ │   ├── chat_eval.py                 # Chat model: eval tasks
+ │   ├── chat_rl.py                   # Chat model: reinforcement learning
+ │   ├── chat_sft.py                  # Chat model: train SFT
+ │   ├── chat_web.py                  # Chat model: talk to over WebUI
+ │   ├── tok_eval.py                  # Tokenizer: evaluate compression rate
+ │   └── tok_train.py                 # Tokenizer: train it
+ ├── tasks
+ │   ├── arc.py                       # Multiple choice science questions
+ │   ├── common.py                    # TaskMixture | TaskSequence
+ │   ├── customjson.py                # Make Task from arbitrary jsonl convos
+ │   ├── gsm8k.py                     # 8K Grade School Math questions
+ │   ├── humaneval.py                 # Misnomer; simple Python coding task
+ │   ├── mmlu.py                      # Multiple choice questions, broad topics
+ │   ├── smoltalk.py                  # Conglomerate dataset of SmolTalk from HF
+ │   └── spellingbee.py               # Task teaching model to spell/count letters
+ ├── tests
+ │   └── test_engine.py
+ └── uv.lock
+ ```
+
+ ## Contributing
+
+ The goal of nanochat is to improve the state of the art in micro models that are accessible to work with end to end on budgets of < $1000. Accessibility is about overall cost, but also about cognitive complexity: nanochat is not an exhaustively configurable LLM "framework"; there are no giant configuration objects, model factories, or if-then-else monsters in the code base. It is a single, cohesive, minimal, readable, hackable, maximally-forkable "strong baseline" codebase designed to run start to end and produce a ChatGPT model you can talk to. Personally, the most interesting part at the moment is reducing the time to GPT-2 (i.e. getting a CORE score above 0.256525). Currently this takes ~3 hours, but by improving the pretraining stage we can push it down further.
+
+ Current AI policy: disclosure. When submitting a PR, please declare any parts that had substantial LLM contribution and that you have not written or that you do not fully understand.
+
+ ## Acknowledgements
+
+ - The name (nanochat) derives from my earlier project [nanoGPT](https://github.com/karpathy/nanoGPT), which only covered pretraining.
+ - nanochat is also inspired by [modded-nanoGPT](https://github.com/KellerJordan/modded-nanogpt), which gamified the nanoGPT repo with clear metrics and a leaderboard, and borrows a lot of its ideas and some implementation for pretraining.
+ - Thank you to [HuggingFace](https://huggingface.co/) for fineweb and smoltalk.
+ - Thank you to [Lambda](https://lambda.ai/service/gpu-cloud) for the compute used in developing this project.
+ - Thank you to chief LLM whisperer 🧙‍♂️ Alec Radford for advice/guidance.
+ - Thank you to the repo czar Sofie [@svlandeg](https://github.com/svlandeg) for help with managing issues, pull requests and discussions of nanochat.
+
+ ## Cite
+
+ If you find nanochat helpful in your research, cite it simply as:
+
+ ```bibtex
+ @misc{nanochat,
+   author = {Andrej Karpathy},
+   title = {nanochat: The best ChatGPT that \$100 can buy},
+   year = {2025},
+   publisher = {GitHub},
+   url = {https://github.com/karpathy/nanochat}
+ }
+ ```
+
+ ## License
+
+ MIT
checkpoints/README.md ADDED
@@ -0,0 +1,4 @@
+ # NanoChat Checkpoint Step 5000
+ - Partial training checkpoint (training crashed after this step)
+ - Trained on ARC-Easy + ARC-Challenge dataset
+ - Loss became NaN afterward, do **not** use for further training without fixing learning rate / FP8 settings
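Aside (illustrative, not part of the committed file): loading this step-5000 checkpoint for inspection with the `build_model` helper shipped in this upload (`nanochat/checkpoint_manager.py`). The local path is an assumption, and the nanochat package plus its tokenizer artifacts must be importable for this to run.

```python
import torch
from nanochat.checkpoint_manager import build_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# "checkpoints" is the directory in this upload; step 5000 matches model_005000.pt / meta_005000.json
model, tokenizer, meta = build_model("checkpoints", step=5000, device=device, phase="eval")
print(meta["val_bpb"])       # 3.0856 (the diverged validation bpb recorded at step 5000)
print(meta["model_config"])  # {'sequence_len': 512, 'vocab_size': 32768, 'n_layer': 6, ...}
```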
checkpoints/meta_005000.json ADDED
@@ -0,0 +1,58 @@
+ {
+   "step": 5000,
+   "val_bpb": 3.0855938505925744,
+   "model_config": {
+     "sequence_len": 512,
+     "vocab_size": 32768,
+     "n_layer": 6,
+     "n_head": 6,
+     "n_kv_head": 6,
+     "n_embd": 384,
+     "window_pattern": "L"
+   },
+   "user_config": {
+     "run": "dummy",
+     "device_type": "",
+     "fp8": false,
+     "fp8_recipe": "tensorwise",
+     "depth": 6,
+     "aspect_ratio": 64,
+     "head_dim": 64,
+     "max_seq_len": 512,
+     "window_pattern": "L",
+     "num_iterations": 5000,
+     "target_flops": -1.0,
+     "target_param_data_ratio": 10.5,
+     "device_batch_size": 32,
+     "total_batch_size": 16384,
+     "embedding_lr": 0.3,
+     "unembedding_lr": 0.008,
+     "weight_decay": 0.28,
+     "matrix_lr": 0.02,
+     "scalar_lr": 0.5,
+     "warmup_steps": 40,
+     "warmdown_ratio": 0.65,
+     "final_lr_frac": 0.05,
+     "resume_from_step": -1,
+     "eval_every": 100,
+     "eval_tokens": 524288,
+     "core_metric_every": -1,
+     "core_metric_max_per_task": 500,
+     "sample_every": 100,
+     "save_every": -1,
+     "model_tag": null
+   },
+   "device_batch_size": 32,
+   "max_seq_len": 512,
+   "total_batch_size": 16384,
+   "dataloader_state_dict": {
+     "pq_idx": 2,
+     "rg_idx": 48,
+     "epoch": 1
+   },
+   "loop_state": {
+     "min_val_bpb": 2.689004677760501,
+     "smooth_train_loss": 10.054429589944,
+     "total_training_time": 16636.2527115345
+   }
+ }
checkpoints/model_005000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0ff116a0583967f5e0b0df01bc329993f75b5ffab8664928432de7ba556fd0e6
+ size 294145379
checkpoints/optim_005000_rank0.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:782a6b507c9468849298071883982d32a83736aabc1cc892cf414e0c21f0dcc9
+ size 545905413
nanochat/__init__.py ADDED
File without changes
nanochat/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (143 Bytes).
 
nanochat/__pycache__/checkpoint_manager.cpython-310.pyc ADDED
Binary file (6.96 kB).
 
nanochat/__pycache__/common.cpython-310.pyc ADDED
Binary file (9.61 kB).
 
nanochat/__pycache__/core_eval.cpython-310.pyc ADDED
Binary file (8.46 kB).
 
nanochat/__pycache__/dataloader.cpython-310.pyc ADDED
Binary file (5.38 kB).
 
nanochat/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (5.47 kB).
 
nanochat/__pycache__/engine.cpython-310.pyc ADDED
Binary file (11.2 kB).
 
nanochat/__pycache__/execution.cpython-310.pyc ADDED
Binary file (8.73 kB).
 
nanochat/__pycache__/flash_attention.cpython-310.pyc ADDED
Binary file (4.53 kB).
 
nanochat/__pycache__/gpt.cpython-310.pyc ADDED
Binary file (18.3 kB).
 
nanochat/__pycache__/loss_eval.cpython-310.pyc ADDED
Binary file (2.32 kB).
 
nanochat/__pycache__/optim.cpython-310.pyc ADDED
Binary file (15.5 kB).
 
nanochat/__pycache__/report.cpython-310.pyc ADDED
Binary file (11.4 kB).
 
nanochat/__pycache__/tokenizer.cpython-310.pyc ADDED
Binary file (13 kB).
 
nanochat/checkpoint_manager.py ADDED
@@ -0,0 +1,194 @@
+ """
+ Utilities for saving and loading model/optim/state checkpoints.
+ """
+ import os
+ import re
+ import glob
+ import json
+ import logging
+ import torch
+
+ from nanochat.common import get_base_dir
+ from nanochat.gpt import GPT, GPTConfig
+ from nanochat.tokenizer import get_tokenizer
+ from nanochat.common import setup_default_logging
+
+ # Set up logging
+ setup_default_logging()
+ logger = logging.getLogger(__name__)
+ def log0(message):
+     if int(os.environ.get('RANK', 0)) == 0:
+         logger.info(message)
+
+ def _patch_missing_config_keys(model_config_kwargs):
+     """Add default values for new config keys missing in old checkpoints."""
+     # Old models were trained with full context (no sliding window)
+     if "window_pattern" not in model_config_kwargs:
+         model_config_kwargs["window_pattern"] = "L"
+         log0(f"Patching missing window_pattern in model config to 'L'")
+
+ def _patch_missing_keys(model_data, model_config):
+     """Add default values for new parameters that may be missing in old checkpoints."""
+     n_layer = model_config.n_layer
+     # resid_lambdas defaults to 1.0 (identity scaling)
+     if "resid_lambdas" not in model_data:
+         model_data["resid_lambdas"] = torch.ones(n_layer)
+         log0(f"Patching missing resid_lambdas in model data to 1.0")
+     # x0_lambdas defaults to 0.0 (disabled)
+     if "x0_lambdas" not in model_data:
+         model_data["x0_lambdas"] = torch.zeros(n_layer)
+         log0(f"Patching missing x0_lambdas in model data to 0.0")
+
+ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
+     if rank == 0:
+         os.makedirs(checkpoint_dir, exist_ok=True)
+         # Save the model state parameters
+         model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
+         torch.save(model_data, model_path)
+         logger.info(f"Saved model parameters to: {model_path}")
+         # Save the metadata dict as json
+         meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
+         with open(meta_path, "w", encoding="utf-8") as f:
+             json.dump(meta_data, f, indent=2)
+         logger.info(f"Saved metadata to: {meta_path}")
+     # Note that optimizer state is sharded across ranks, so each rank must save its own.
+     if optimizer_data is not None:
+         os.makedirs(checkpoint_dir, exist_ok=True)
+         optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
+         torch.save(optimizer_data, optimizer_path)
+         logger.info(f"Saved optimizer state to: {optimizer_path}")
+
+ def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
+     # Load the model state
+     model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
+     model_data = torch.load(model_path, map_location=device)
+     # Load the optimizer state if requested
+     optimizer_data = None
+     if load_optimizer:
+         optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
+         optimizer_data = torch.load(optimizer_path, map_location=device)
+     # Load the metadata
+     meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
+     with open(meta_path, "r", encoding="utf-8") as f:
+         meta_data = json.load(f)
+     return model_data, optimizer_data, meta_data
+
+
+ def build_model(checkpoint_dir, step, device, phase):
+     """
+     A bunch of repetitive code to build a model from a given checkpoint.
+     Returns:
+     - base model - uncompiled, not wrapped in DDP
+     - tokenizer
+     - meta data saved during base model training
+     """
+     assert phase in ["train", "eval"], f"Invalid phase: {phase}"
+     model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
+     if device.type in {"cpu", "mps"}:
+         # Convert bfloat16 tensors to float for CPU inference
+         model_data = {
+             k: v.float() if v.dtype == torch.bfloat16 else v
+             for k, v in model_data.items()
+         }
+     # Hack: fix torch compile issue, which prepends all keys with _orig_mod.
+     model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
+     model_config_kwargs = meta_data["model_config"]
+     _patch_missing_config_keys(model_config_kwargs)
+     log0(f"Building model with config: {model_config_kwargs}")
+     model_config = GPTConfig(**model_config_kwargs)
+     _patch_missing_keys(model_data, model_config)
+     with torch.device("meta"):
+         model = GPT(model_config)
+     # Load the model state
+     model.to_empty(device=device)
+     model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
+     model.load_state_dict(model_data, strict=True, assign=True)
+     # Put the model in the right training phase / mode
+     if phase == "eval":
+         model.eval()
+     else:
+         model.train()
+     # Load the Tokenizer
+     tokenizer = get_tokenizer()
+     # Sanity check: compatibility between model and tokenizer
+     assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"], f"Tokenizer vocab size {tokenizer.get_vocab_size()} does not match model config vocab size {model_config_kwargs['vocab_size']}"
+     return model, tokenizer, meta_data
+
+
+ def find_largest_model(checkpoints_dir):
+     # attempt to guess the model tag: take the biggest model available
+     model_tags = [f for f in os.listdir(checkpoints_dir) if os.path.isdir(os.path.join(checkpoints_dir, f))]
+     if not model_tags:
+         raise FileNotFoundError(f"No checkpoints found in {checkpoints_dir}")
+     # 1) normally all model tags are of the form d<number>, try that first:
+     candidates = []
+     for model_tag in model_tags:
+         match = re.match(r"d(\d+)", model_tag)
+         if match:
+             model_depth = int(match.group(1))
+             candidates.append((model_depth, model_tag))
+     if candidates:
+         candidates.sort(key=lambda x: x[0], reverse=True)
+         return candidates[0][1]
+     # 2) if that failed, take the most recently updated model:
+     model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoints_dir, x)), reverse=True)
+     return model_tags[0]
+
+
+ def find_last_step(checkpoint_dir):
+     # Look into checkpoint_dir and find model_<step>.pt with the highest step
+     checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt"))
+     if not checkpoint_files:
+         raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
+     last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files))
+     return last_step
+
+ # -----------------------------------------------------------------------------
+ # convenience functions that take into account nanochat's directory structure
+
+ def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
+     if model_tag is None:
+         # guess the model tag by defaulting to the largest model
+         model_tag = find_largest_model(checkpoints_dir)
+         log0(f"No model tag provided, guessing model tag: {model_tag}")
+     checkpoint_dir = os.path.join(checkpoints_dir, model_tag)
+     if step is None:
+         # guess the step by defaulting to the last step
+         step = find_last_step(checkpoint_dir)
+     assert step is not None, f"No checkpoints found in {checkpoint_dir}"
+     # build the model
+     log0(f"Loading model from {checkpoint_dir} with step {step}")
+     model, tokenizer, meta_data = build_model(checkpoint_dir, step, device, phase)
+     return model, tokenizer, meta_data
+
+ def load_model(source, *args, **kwargs):
+     model_dir = {
+         "base": "base_checkpoints",
+         "sft": "chatsft_checkpoints",
+         "rl": "chatrl_checkpoints",
+     }[source]
+     base_dir = get_base_dir()
+     checkpoints_dir = os.path.join(base_dir, model_dir)
+     return load_model_from_dir(checkpoints_dir, *args, **kwargs)
+
+ def load_optimizer_state(source, device, rank, model_tag=None, step=None):
+     """Load just the optimizer shard for a given rank, without re-loading the model."""
+     model_dir = {
+         "base": "base_checkpoints",
+         "sft": "chatsft_checkpoints",
+         "rl": "chatrl_checkpoints",
+     }[source]
+     base_dir = get_base_dir()
+     checkpoints_dir = os.path.join(base_dir, model_dir)
+     if model_tag is None:
+         model_tag = find_largest_model(checkpoints_dir)
+     checkpoint_dir = os.path.join(checkpoints_dir, model_tag)
+     if step is None:
+         step = find_last_step(checkpoint_dir)
+     optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
+     if not os.path.exists(optimizer_path):
+         log0(f"Optimizer checkpoint not found: {optimizer_path}")
+         return None
+     log0(f"Loading optimizer state from {optimizer_path}")
+     optimizer_data = torch.load(optimizer_path, map_location=device)
+     return optimizer_data
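Aside (illustrative, not part of the committed file): the directory layout that `load_model()` / `load_model_from_dir()` expect, pieced together from the functions above, plus a minimal call. The paths and the `d26` model tag are assumptions.

```python
# ~/.cache/nanochat/                (get_base_dir default, override with NANOCHAT_BASE_DIR)
# └── base_checkpoints/             source="base"; "sft"/"rl" map to chatsft_/chatrl_checkpoints
#     └── d26/                      model tag; find_largest_model() picks the largest d<depth>
#         ├── model_005000.pt       find_last_step() picks the highest step
#         ├── meta_005000.json
#         └── optim_005000_rank0.pt
import torch
from nanochat.checkpoint_manager import load_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, tokenizer, meta = load_model("base", device, "eval")  # model tag and step are auto-guessed
```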
nanochat/common.py ADDED
@@ -0,0 +1,278 @@
+ """
+ Common utilities for nanochat.
+ """
+
+ import os
+ import re
+ import logging
+ import urllib.request
+ import torch
+ import torch.distributed as dist
+ from filelock import FileLock
+
+ # The dtype used for compute (matmuls, activations). Master weights stay fp32 for optimizer precision.
+ # Linear layers cast their weights to this dtype in forward, replacing torch.amp.autocast.
+ # Override with NANOCHAT_DTYPE env var: "bfloat16", "float16", "float32"
+ _DTYPE_MAP = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}
+ def _detect_compute_dtype():
+     env = os.environ.get("NANOCHAT_DTYPE")
+     if env is not None:
+         return _DTYPE_MAP[env], f"set via NANOCHAT_DTYPE={env}"
+     if torch.cuda.is_available():
+         # bf16 requires SM 80+ (Ampere: A100, A10, etc.)
+         # Older GPUs like V100 (SM 70) and T4 (SM 75) only have fp16 tensor cores
+         capability = torch.cuda.get_device_capability()
+         if capability >= (8, 0):
+             return torch.bfloat16, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (bf16 supported)"
+         # fp16 training requires GradScaler (not yet implemented), so fall back to fp32.
+         # Users can still force fp16 via NANOCHAT_DTYPE=float16 if they know what they're doing.
+         return torch.float32, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (pre-Ampere, bf16 not supported, using fp32)"
+     return torch.float32, "auto-detected: no CUDA (CPU/MPS)"
+ COMPUTE_DTYPE, COMPUTE_DTYPE_REASON = _detect_compute_dtype()
+
+ class ColoredFormatter(logging.Formatter):
+     """Custom formatter that adds colors to log messages."""
+     # ANSI color codes
+     COLORS = {
+         'DEBUG': '\033[36m',     # Cyan
+         'INFO': '\033[32m',      # Green
+         'WARNING': '\033[33m',   # Yellow
+         'ERROR': '\033[31m',     # Red
+         'CRITICAL': '\033[35m',  # Magenta
+     }
+     RESET = '\033[0m'
+     BOLD = '\033[1m'
+     def format(self, record):
+         # Add color to the level name
+         levelname = record.levelname
+         if levelname in self.COLORS:
+             record.levelname = f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}"
+         # Format the message
+         message = super().format(record)
+         # Add color to specific parts of the message
+         if levelname == 'INFO':
+             # Highlight numbers and percentages
+             message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message)
+             message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message)
+         return message
+
+ def setup_default_logging():
+     handler = logging.StreamHandler()
+     handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
+     logging.basicConfig(
+         level=logging.INFO,
+         handlers=[handler]
+     )
+
+ setup_default_logging()
+ logger = logging.getLogger(__name__)
+
+ def get_base_dir():
+     # co-locate nanochat intermediates with other cached data in ~/.cache (by default)
+     if os.environ.get("NANOCHAT_BASE_DIR"):
+         nanochat_dir = os.environ.get("NANOCHAT_BASE_DIR")
+     else:
+         home_dir = os.path.expanduser("~")
+         cache_dir = os.path.join(home_dir, ".cache")
+         nanochat_dir = os.path.join(cache_dir, "nanochat")
+     os.makedirs(nanochat_dir, exist_ok=True)
+     return nanochat_dir
+
+ def download_file_with_lock(url, filename, postprocess_fn=None):
+     """
+     Downloads a file from a URL to a local path in the base directory.
+     Uses a lock file to prevent concurrent downloads among multiple ranks.
+     """
+     base_dir = get_base_dir()
+     file_path = os.path.join(base_dir, filename)
+     lock_path = file_path + ".lock"
+
+     if os.path.exists(file_path):
+         return file_path
+
+     with FileLock(lock_path):
+         # Only a single rank can acquire this lock
+         # All other ranks block until it is released
+
+         # Recheck after acquiring lock
+         if os.path.exists(file_path):
+             return file_path
+
+         # Download the content as bytes
+         print(f"Downloading {url}...")
+         with urllib.request.urlopen(url) as response:
+             content = response.read() # bytes
+
+         # Write to local file
+         with open(file_path, 'wb') as f:
+             f.write(content)
+         print(f"Downloaded to {file_path}")
+
+         # Run the postprocess function if provided
+         if postprocess_fn is not None:
+             postprocess_fn(file_path)
+
+     return file_path
+
+ def print0(s="", **kwargs):
+     ddp_rank = int(os.environ.get('RANK', 0))
+     if ddp_rank == 0:
+         print(s, **kwargs)
+
+ def print_banner():
+     # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/
+     banner = """
+  █████                                          █████                 █████
+ ░░███                                          ░░███                 ░░███
+  ████████    ██████   ████████    ██████   ██████  ░███████    ██████   ███████
+ ░░███░░███  ░░░░░███ ░░███░░███  ███░░███ ███░░███ ░███░░███  ░░░░░███ ░░░███░
+  ░███ ░███   ███████  ░███ ░███ ░███ ░███░███ ░░░  ░███ ░███   ███████   ░███
+  ░███ ░███  ███░░███  ░███ ░███ ░███ ░███░███  ███ ░███ ░███  ███░░███   ░███ ███
+  ████ █████░░████████ ████ █████░░██████ ░░██████  ████ █████░░████████  ░░█████
+ ░░░░ ░░░░░  ░░░░░░░░ ░░░░ ░░░░░  ░░░░░░   ░░░░░░  ░░░░ ░░░░░  ░░░░░░░░    ░░░░░
+     """
+     print0(banner)
+
+ def is_ddp_requested() -> bool:
+     """
+     True if launched by torchrun (env present), even before init.
+     Used to decide whether we *should* initialize a PG.
+     """
+     return all(k in os.environ for k in ("RANK", "LOCAL_RANK", "WORLD_SIZE"))
+
+ def is_ddp_initialized() -> bool:
+     """
+     True if torch.distributed is available and the process group is initialized.
+     Used at cleanup to avoid destroying a non-existent PG.
+     """
+     return dist.is_available() and dist.is_initialized()
+
+ def get_dist_info():
+     if is_ddp_requested():
+         # We rely on torchrun's env to decide if we SHOULD init.
+         # (Initialization itself happens in compute init.)
+         assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
+         ddp_rank = int(os.environ['RANK'])
+         ddp_local_rank = int(os.environ['LOCAL_RANK'])
+         ddp_world_size = int(os.environ['WORLD_SIZE'])
+         return True, ddp_rank, ddp_local_rank, ddp_world_size
+     else:
+         return False, 0, 0, 1
+
+ def autodetect_device_type():
+     # prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
+     if torch.cuda.is_available():
+         device_type = "cuda"
+     elif torch.backends.mps.is_available():
+         device_type = "mps"
+     else:
+         device_type = "cpu"
+     print0(f"Autodetected device type: {device_type}")
+     return device_type
+
+ def compute_init(device_type="cuda"): # cuda|cpu|mps
+     """Basic initialization that we keep doing over and over, so make common."""
+
+     assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
+     if device_type == "cuda":
+         assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
+     if device_type == "mps":
+         assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
+
+     # Reproducibility
+     # Note that we set the global seeds here, but most of the code uses explicit rng objects.
+     # The only place where global rng might be used is nn.Module initialization of the model weights.
+     torch.manual_seed(42)
+     if device_type == "cuda":
+         torch.cuda.manual_seed(42)
+     # skipping full reproducibility for now, possibly investigate slowdown later
+     # torch.use_deterministic_algorithms(True)
+
+     # Precision
+     if device_type == "cuda":
+         torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls, see https://docs.pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html
+
+     # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
+     is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
+     if is_ddp_requested and device_type == "cuda":
+         device = torch.device("cuda", ddp_local_rank)
+         torch.cuda.set_device(device) # make "cuda" default to this device
+         dist.init_process_group(backend="nccl", device_id=device)
+         dist.barrier()
+     else:
+         device = torch.device(device_type) # mps|cpu
+
+     if ddp_rank == 0:
+         logger.info(f"Distributed world size: {ddp_world_size}")
+
+     return is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size, device
+
+ def compute_cleanup():
+     """Companion function to compute_init, to clean things up before script exit"""
+     if is_ddp_initialized():
+         dist.destroy_process_group()
+
+ class DummyWandb:
+     """Useful if we wish to not use wandb but have all the same signatures"""
+     def __init__(self):
+         pass
+     def log(self, *args, **kwargs):
+         pass
+     def finish(self):
+         pass
+
+ # hardcoded BF16 peak flops for various GPUs
+ # inspired by torchtitan: https://github.com/pytorch/torchtitan/blob/main/torchtitan/tools/utils.py
+ # and PR: https://github.com/karpathy/nanochat/pull/147
+ def get_peak_flops(device_name: str) -> float:
+     name = device_name.lower()
+
+     # Table order matters: more specific patterns first.
+     _PEAK_FLOPS_TABLE = (
+         # NVIDIA Blackwell
+         (["gb200"], 2.5e15),
+         (["grace blackwell"], 2.5e15),
+         (["b200"], 2.25e15),
+         (["b100"], 1.8e15),
+         # NVIDIA Hopper
+         (["h200", "nvl"], 836e12),
+         (["h200", "pcie"], 836e12),
+         (["h200"], 989e12),
+         (["h100", "nvl"], 835e12),
+         (["h100", "pcie"], 756e12),
+         (["h100"], 989e12),
+         (["h800", "nvl"], 989e12),
+         (["h800"], 756e12),
+         # NVIDIA Ampere data center
+         (["a100"], 312e12),
+         (["a800"], 312e12),
+         (["a40"], 149.7e12),
+         (["a30"], 165e12),
+         # NVIDIA Ada data center
+         (["l40s"], 362e12),
+         (["l40-s"], 362e12),
+         (["l40 s"], 362e12),
+         (["l4"], 121e12),
+         # AMD CDNA accelerators
+         (["mi355"], 2.5e15),
+         (["mi325"], 1.3074e15),
+         (["mi300x"], 1.3074e15),
+         (["mi300a"], 980.6e12),
+         (["mi250x"], 383e12),
+         (["mi250"], 362.1e12),
+         # Consumer RTX
+         (["5090"], 209.5e12),
+         (["4090"], 165.2e12),
+         (["3090"], 71e12),
+     )
+     for patterns, flops in _PEAK_FLOPS_TABLE:
+         if all(p in name for p in patterns):
+             return flops
+     if "data center gpu max 1550" in name:
+         # Ponte Vecchio (PVC) - dynamic based on compute units
+         max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
+         return 512 * max_comp_units * 1300 * 10**6
+
+     # Unknown GPU - return inf so MFU shows as 0% rather than a wrong guess
+     logger.warning(f"Peak flops undefined for: {device_name}, MFU will show as 0%")
+     return float('inf')
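Aside (illustrative, not part of the committed file): how a peak-flops number typically feeds an MFU estimate. The exact accounting lives in the training script (not shown in this view); the numbers below are made up, and ~6 FLOPs per parameter per token is the usual approximation.

```python
from nanochat.common import get_peak_flops

num_params = 1.6e9     # e.g. a GPT-2 sized model (hypothetical)
tok_per_sec = 3.3e5    # measured training throughput in tokens/sec (hypothetical)
achieved = 6 * num_params * tok_per_sec              # FLOPs/sec spent on training
peak = 8 * get_peak_flops("NVIDIA H100 80GB HBM3")   # 8 GPUs x 989e12 bf16 FLOPs/sec each
print(f"MFU ~= {achieved / peak:.1%}")               # ~40.0% for these made-up numbers
```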
nanochat/core_eval.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Functions for evaluating the CORE metric, as described in the DCLM paper.
3
+ https://arxiv.org/abs/2406.11794
4
+
5
+ TODOs:
6
+ - All tasks ~match except for squad. We get 31% reference is 37%. Figure out why.
7
+ """
8
+ import random
9
+
10
+ from jinja2 import Template
11
+ import torch
12
+ import torch.distributed as dist
13
+
14
+ # -----------------------------------------------------------------------------
15
+ # Prompt rendering utilities
16
+
17
+ def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None):
18
+ """Render complete prompts for a multiple choice question"""
19
+ template_str = """
20
+ {%- for example in fewshot_examples -%}
21
+ {{ example.query }}{{ continuation_delimiter }}{{ example.choices[example.gold] }}
22
+
23
+ {% endfor -%}
24
+ {{ item.query }}{{ continuation_delimiter }}{{ choice }}""".strip()
25
+ template = Template(template_str)
26
+ fewshot_examples = fewshot_examples or []
27
+ context = {
28
+ 'fewshot_examples': fewshot_examples,
29
+ 'continuation_delimiter': continuation_delimiter,
30
+ 'item': item
31
+ }
32
+ prompts = [template.render(choice=choice, **context) for choice in item['choices']]
33
+ return prompts
34
+
35
+
36
+ def render_prompts_schema(item, continuation_delimiter, fewshot_examples=None):
37
+ """Render complete prompts for a schema question"""
38
+ template_str = """
39
+ {%- for example in fewshot_examples -%}
40
+ {{ example.context_options[example.gold] }}{{ continuation_delimiter }}{{ example.continuation }}
41
+
42
+ {% endfor -%}
43
+ {{ context }}{{ continuation_delimiter }}{{ item.continuation }}""".strip()
44
+ template = Template(template_str)
45
+ fewshot_examples = fewshot_examples or []
46
+ context = {
47
+ 'fewshot_examples': fewshot_examples,
48
+ 'continuation_delimiter': continuation_delimiter,
49
+ 'item': item
50
+ }
51
+ prompts = [template.render(context=context_option, **context)
52
+ for context_option in item['context_options']]
53
+ return prompts
54
+
55
+
56
+ def render_prompts_lm(item, continuation_delimiter, fewshot_examples=None):
57
+ """
58
+ Render complete prompt for a language modeling task.
59
+ Notice that we manually trim the context in the template,
60
+ which in some datasets seems to have trailing whitespace (which we don't want).
61
+ """
62
+ template_str = """
63
+ {%- for example in fewshot_examples -%}
64
+ {{ example.context | trim }}{{ continuation_delimiter }}{{ example.continuation }}
65
+
66
+ {% endfor -%}
67
+ {{ item.context | trim }}{{ continuation_delimiter }}{% if include_continuation %}{{ item.continuation }}{% endif %}""".strip()
68
+ template = Template(template_str)
69
+ fewshot_examples = fewshot_examples or []
70
+ context = {
71
+ 'fewshot_examples': fewshot_examples,
72
+ 'continuation_delimiter': continuation_delimiter,
73
+ 'item': item
74
+ }
75
+ # Return two prompts: without and with the continuation
76
+ prompt_without = template.render(include_continuation=False, **context)
77
+ prompt_with = template.render(include_continuation=True, **context)
78
+ # Due to the way the data seems to be stored, I think I need to strip in the case of LM here.
79
+ # Otherwise we may get trailing whitespaces in prompt_without (which get absorbed into the next
80
+ # token in prompt_with), meaning we don't get a nice and clean prefix in the token space
81
+ # to detect the final continuation. Tokenizers...
82
+ prompt_without = prompt_without.strip()
83
+ return [prompt_without, prompt_with]
84
+
85
+
86
+ def find_common_length(token_sequences, direction='left'):
87
+ """
88
+ Find the length of the common prefix or suffix across token sequences
89
+ - direction: 'left' for prefix, 'right' for suffix
90
+ """
91
+ min_len = min(len(seq) for seq in token_sequences)
92
+ indices = {
93
+ 'left': range(min_len),
94
+ 'right': range(-1, -min_len-1, -1)
95
+ }[direction]
96
+ # Find the first position where the token sequences differ
97
+ for i, idx in enumerate(indices):
98
+ token = token_sequences[0][idx]
99
+ if not all(seq[idx] == token for seq in token_sequences):
100
+ return i
101
+ return min_len
102
+
103
+
104
+ def stack_sequences(tokens, pad_token_id):
105
+ """Stack up a list of token sequences, pad to longest on the right"""
106
+ bsz, seq_len = len(tokens), max(len(x) for x in tokens)
107
+ input_ids = torch.full((bsz, seq_len), pad_token_id, dtype=torch.long)
108
+ for i, x in enumerate(tokens):
109
+ input_ids[i, :len(x)] = torch.tensor(x, dtype=torch.long)
110
+ return input_ids
111
+
112
+
113
+ def batch_sequences_mc(tokenizer, prompts):
114
+ # In multiple choice, contexts are the same but the continuation is different (common prefix)
115
+ tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
116
+ # figure out the start and end of each continuation
117
+ answer_start_idx = find_common_length(tokens, direction='left')
118
+ start_indices = [answer_start_idx] * len(prompts)
119
+ end_indices = [len(x) for x in tokens]
120
+ return tokens, start_indices, end_indices
121
+
122
+
123
+ def batch_sequences_schema(tokenizer, prompts):
124
+ # In schema tasks, contexts vary but continuation is the same (common suffix)
125
+ tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
126
+ # figure out the start and end of each context
127
+ suffix_length = find_common_length(tokens, direction='right')
128
+ end_indices = [len(x) for x in tokens]
129
+ start_indices = [ei - suffix_length for ei in end_indices]
130
+ return tokens, start_indices, end_indices
131
+
132
+
133
+ def batch_sequences_lm(tokenizer, prompts):
134
+ # In LM tasks, we have two prompts: without and with continuation
135
+ tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
136
+ tokens_without, tokens_with = tokens
137
+ start_idx, end_idx = len(tokens_without), len(tokens_with)
138
+ assert start_idx < end_idx, "prompt without is supposed to be a prefix of prompt with"
139
+ assert tokens_without == tokens_with[:start_idx], "prompt without is supposed to be a prefix of prompt with"
140
+ # we only need the with continuation prompt in the LM task, i.e. batch size of 1
141
+ return [tokens_with], [start_idx], [end_idx]
142
+
143
+
144
+ @torch.no_grad()
145
+ def forward_model(model, input_ids):
146
+ """
147
+ Take BxT tensor of token ids, return BxT tensor of losses and argmax predictions.
148
+ The last column of losses is set to nan because we don't have autoregressive targets there.
149
+ """
150
+ batch_size, seq_len = input_ids.size()
151
+ outputs = model(input_ids)
152
+ # Roll the tensor to the left by one position to get the (autoregressive) target ids
153
+ target_ids = torch.roll(input_ids, shifts=-1, dims=1)
154
+ # Calculate cross entropy at all positions
155
+ losses = torch.nn.functional.cross_entropy(
156
+ outputs.view(batch_size * seq_len, -1),
157
+ target_ids.view(batch_size * seq_len),
158
+ reduction='none'
159
+ ).view(batch_size, seq_len)
160
+ # Set the last column to be nan because there is no autoregressive loss there
161
+ losses[:, -1] = float('nan')
162
+ # Get the argmax predictions at each position
163
+ predictions = outputs.argmax(dim=-1)
164
+ return losses, predictions
165
+
166
+
167
+ @torch.no_grad()
168
+ def evaluate_example(idx, model, tokenizer, data, device, task_meta):
169
+ """Evaluate a single example, return True if correct, False otherwise"""
170
+ item = data[idx]
171
+ task_type = task_meta['task_type']
172
+ num_fewshot = task_meta['num_fewshot']
173
+ continuation_delimiter = task_meta['continuation_delimiter']
174
+
175
+ # Sample few-shot examples (excluding current item)
176
+ fewshot_examples = []
177
+ if num_fewshot > 0:
178
+ rng = random.Random(1234 + idx)
179
+ available_indices = [i for i in range(len(data)) if i != idx]
180
+ fewshot_indices = rng.sample(available_indices, num_fewshot)
181
+ fewshot_examples = [data[i] for i in fewshot_indices]
182
+
183
+ # Render prompts and batch sequences based on task type
184
+ if task_type == 'multiple_choice':
185
+ prompts = render_prompts_mc(item, continuation_delimiter, fewshot_examples)
186
+ tokens, start_idxs, end_idxs = batch_sequences_mc(tokenizer, prompts)
187
+ elif task_type == 'schema':
188
+ prompts = render_prompts_schema(item, continuation_delimiter, fewshot_examples)
189
+ tokens, start_idxs, end_idxs = batch_sequences_schema(tokenizer, prompts)
190
+ elif task_type == 'language_modeling':
191
+ prompts = render_prompts_lm(item, continuation_delimiter, fewshot_examples)
192
+ tokens, start_idxs, end_idxs = batch_sequences_lm(tokenizer, prompts)
193
+ else:
194
+ raise ValueError(f"Unsupported task type: {task_type}")
195
+
196
+ # Some models can't forward sequences beyond a certain length (e.g. GPT-2)
197
+ # In these cases, we have to truncate sequences to max length and adjust the indices
198
+ if hasattr(model, 'max_seq_len') and model.max_seq_len is not None:
199
+ max_tokens = model.max_seq_len
200
+ new_tokens, new_start_idxs, new_end_idxs = [], [], []
201
+ for t, s, e in zip(tokens, start_idxs, end_idxs):
202
+ if len(t) > max_tokens:
203
+ num_to_crop = len(t) - max_tokens
204
+ new_tokens.append(t[-max_tokens:]) # take the last max_tokens tokens
205
+ new_start_idxs.append(s - num_to_crop) # shift the indices down
206
+ new_end_idxs.append(e - num_to_crop)
207
+ assert s - num_to_crop >= 0, "this should never happen right?"
208
+ assert e - num_to_crop >= 0, "this should never happen right?"
209
+ else:
210
+ new_tokens.append(t) # keep unchanged
211
+ new_start_idxs.append(s)
212
+ new_end_idxs.append(e)
213
+ tokens, start_idxs, end_idxs = new_tokens, new_start_idxs, new_end_idxs
214
+
215
+ # Stack up all the sequences into a batch
216
+ pad_token_id = tokenizer.get_bos_token_id() # use BOS as pad token is ok
217
+ input_ids = stack_sequences(tokens, pad_token_id)
218
+ input_ids = input_ids.to(device)
219
+
220
+ # Forward the model, get the autoregressive loss and argmax prediction at each token
221
+ losses, predictions = forward_model(model, input_ids)
222
+
223
+ # See if the losses/predictions come out correctly
224
+ if task_type == 'language_modeling':
225
+ # language modeling task is currently always batch size 1
226
+ si = start_idxs[0]
227
+ ei = end_idxs[0]
228
+ # predictions[i] predict input_ids[i+1] autoregressively
229
+ predicted_tokens = predictions[0, si-1:ei-1]
230
+ actual_tokens = input_ids[0, si:ei]
231
+ is_correct = torch.all(predicted_tokens == actual_tokens).item()
232
+ elif task_type in ['multiple_choice', 'schema']:
233
+ # For MC/schema: find the option with lowest average loss
234
+ mean_losses = [losses[i, si-1:ei-1].mean().item()
235
+ for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))]
236
+ pred_idx = mean_losses.index(min(mean_losses))
237
+ is_correct = pred_idx == item['gold']
238
+ else:
239
+ raise ValueError(f"Unsupported task type: {task_type}")
240
+
241
+ return is_correct
242
+
243
+
244
+ def evaluate_task(model, tokenizer, data, device, task_meta):
245
+ """
246
+ This function is responsible for evaluating one task across many examples.
247
+ It also handles dispatch to all processes if the script is run with torchrun.
248
+ """
249
+ rank = dist.get_rank() if dist.is_initialized() else 0
250
+ world_size = dist.get_world_size() if dist.is_initialized() else 1
251
+ correct = torch.zeros(len(data), dtype=torch.float32, device=device)
252
+ # stride the examples to each rank
253
+ for idx in range(rank, len(data), world_size):
254
+ is_correct = evaluate_example(idx, model, tokenizer, data, device, task_meta)
255
+ correct[idx] = float(is_correct)
256
+ # sync results across all the processes if running distributed
257
+ if world_size > 1:
258
+ dist.barrier()
259
+ dist.all_reduce(correct, op=dist.ReduceOp.SUM)
260
+ # compute the mean
261
+ mean_correct = correct.mean().item()
262
+ return mean_correct
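
To make the indexing in evaluate_example concrete, here is a tiny standalone sketch (made-up loss values, not from a real model) of the multiple-choice scoring rule: predictions at position i score token i+1, so a continuation occupying input positions [si, ei) is scored by the per-token losses at [si-1, ei-1), and the option with the lowest mean loss is taken as the prediction.

    import torch

    # toy per-token losses for 3 candidate continuations of the same prompt
    losses = torch.tensor([
        [0.1, 0.2, 3.0, 2.9, 0.0],   # option 0
        [0.1, 0.2, 0.4, 0.3, 0.0],   # option 1 (lowest mean loss over its continuation)
        [0.1, 0.2, 2.0, 2.5, 0.0],   # option 2
    ])
    start_idxs, end_idxs = [3, 3, 3], [5, 5, 5]   # continuation sits at input positions [3, 5)
    mean_losses = [losses[i, si-1:ei-1].mean().item()
                   for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))]
    pred_idx = mean_losses.index(min(mean_losses))
    print(pred_idx)  # 1
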
nanochat/dataloader.py ADDED
@@ -0,0 +1,166 @@
1
+ """
2
+ Distributed dataloaders for pretraining.
3
+
4
+ BOS-aligned bestfit:
5
+ - Every row starts with BOS token
6
+ - Documents packed using best-fit algorithm to minimize cropping
7
+ - When no document fits remaining space, crops a document to fill exactly
8
+ - 100% utilization (no padding), ~35% tokens cropped at T=2048
9
+
10
+ Compared to the original tokenizing_distributed_data_loader:
11
+ BOS-aligned loses ~35% of tokens to cropping, but ensures that
12
+ there are fewer "confusing" tokens in the train/val batches as every token can
13
+ now attend back to the BOS token and see the full context of the document.
14
+
15
+ Fallback to the original if you have very limited data AND long documents:
16
+ https://github.com/karpathy/nanochat/blob/3c3a3d7/nanochat/dataloader.py#L78-L117
17
+ """
18
+
19
+ import torch
20
+ import pyarrow.parquet as pq
21
+
22
+ from nanochat.common import get_dist_info
23
+ from nanochat.dataset import list_parquet_files
24
+
25
+ def _document_batches(split, resume_state_dict, tokenizer_batch_size):
26
+ """
27
+ Infinite iterator over document batches (list of text strings) from parquet files.
28
+
29
+ Handles DDP sharding and approximate resume. Each yield is (text_batch, (pq_idx, rg_idx, epoch))
30
+ where text_batch is a list of document strings, indices track position for resumption,
31
+ and epoch counts how many times we've cycled through the dataset (starts at 1).
32
+ """
33
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
34
+
35
+ warn_on_legacy = ddp_rank == 0 and split == "train" # rank 0 on train split will warn on legacy
36
+ parquet_paths = list_parquet_files(warn_on_legacy=warn_on_legacy)
37
+ assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?"
38
+ parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
39
+
40
+ resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
41
+ resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
42
+ resume_epoch = resume_state_dict.get("epoch", 1) if resume_state_dict is not None else 1
43
+ first_pass = True
44
+ pq_idx = resume_pq_idx
45
+ epoch = resume_epoch
46
+
47
+ while True: # iterate infinitely (multi-epoch)
48
+ pq_idx = resume_pq_idx if first_pass else 0
49
+ while pq_idx < len(parquet_paths):
50
+ filepath = parquet_paths[pq_idx]
51
+ pf = pq.ParquetFile(filepath)
52
+ # Start from resume point if resuming on same file, otherwise from DDP rank
53
+ if first_pass and (resume_rg_idx is not None) and (pq_idx == resume_pq_idx):
54
+ base_idx = resume_rg_idx // ddp_world_size
55
+ base_idx += 1 # advance by 1 so we don't repeat data after resuming
56
+ rg_idx = base_idx * ddp_world_size + ddp_rank
57
+ if rg_idx >= pf.num_row_groups:
58
+ pq_idx += 1
59
+ continue
60
+ resume_rg_idx = None # only do this once
61
+ else:
62
+ rg_idx = ddp_rank
63
+ while rg_idx < pf.num_row_groups:
64
+ rg = pf.read_row_group(rg_idx)
65
+ batch = rg.column('text').to_pylist()
66
+ for i in range(0, len(batch), tokenizer_batch_size):
67
+ yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx, epoch)
68
+ rg_idx += ddp_world_size
69
+ pq_idx += 1
70
+ first_pass = False
71
+ epoch += 1
72
+
73
+
74
+ def tokenizing_distributed_data_loader_with_state_bos_bestfit(
75
+ tokenizer, B, T, split,
76
+ tokenizer_threads=4, tokenizer_batch_size=128,
77
+ device="cuda", resume_state_dict=None,
78
+ buffer_size=1000
79
+ ):
80
+ """
81
+ BOS-aligned dataloader with Best-Fit Cropping.
82
+
83
+ Reduces token waste compared to simple greedy cropping by searching a buffer
84
+ for documents that fit well, while maintaining 100% utilization (no padding).
85
+
86
+ Algorithm for each row:
87
+ 1. From buffered docs, pick the LARGEST doc that fits entirely
88
+ 2. Repeat until no doc fits
89
+ 3. When nothing fits, crop a doc to fill remaining space exactly
90
+
91
+ Key properties:
92
+ - Every row starts with BOS
93
+ - 100% utilization (no padding, every token is trained on)
94
+ - Approximately 35% of all tokens are discarded due to cropping
95
+ """
96
+ assert split in ["train", "val"], "split must be 'train' or 'val'"
97
+
98
+ row_capacity = T + 1
99
+ batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
100
+ bos_token = tokenizer.get_bos_token_id()
101
+ doc_buffer = []
102
+ pq_idx, rg_idx, epoch = 0, 0, 1
103
+
104
+ def refill_buffer():
105
+ nonlocal pq_idx, rg_idx, epoch
106
+ doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
107
+ token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
108
+ for tokens in token_lists:
109
+ doc_buffer.append(tokens)
110
+
111
+ # Pre-allocate buffers once: layout is [inputs (B*T) | targets (B*T)]
112
+ # This gives us contiguous views and a single HtoD transfer
113
+ use_cuda = device == "cuda"
114
+ row_buffer = torch.empty((B, row_capacity), dtype=torch.long) # for building rows without creating Python lists
115
+ cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=use_cuda) # staging area (CPU)
116
+ gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device=device) # on-device buffer
117
+ cpu_inputs = cpu_buffer[:B * T].view(B, T) # a few views into these buffers just for convenience
118
+ cpu_targets = cpu_buffer[B * T:].view(B, T)
119
+ inputs = gpu_buffer[:B * T].view(B, T)
120
+ targets = gpu_buffer[B * T:].view(B, T)
121
+
122
+ while True:
123
+ for row_idx in range(B):
124
+ pos = 0
125
+ while pos < row_capacity:
126
+ # Ensure buffer has documents
127
+ while len(doc_buffer) < buffer_size:
128
+ refill_buffer()
129
+
130
+ remaining = row_capacity - pos
131
+
132
+ # Find largest doc that fits entirely
133
+ best_idx = -1
134
+ best_len = 0
135
+ for i, doc in enumerate(doc_buffer):
136
+ doc_len = len(doc)
137
+ if doc_len <= remaining and doc_len > best_len:
138
+ best_idx = i
139
+ best_len = doc_len
140
+
141
+ if best_idx >= 0:
142
+ doc = doc_buffer.pop(best_idx)
143
+ doc_len = len(doc)
144
+ row_buffer[row_idx, pos:pos + doc_len] = torch.tensor(doc, dtype=torch.long)
145
+ pos += doc_len
146
+ else:
147
+ # No doc fits - crop shortest in buffer to fill remaining and minimize waste
148
+ shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
149
+ doc = doc_buffer.pop(shortest_idx)
150
+ row_buffer[row_idx, pos:pos + remaining] = torch.tensor(doc[:remaining], dtype=torch.long)
151
+ pos += remaining
152
+
153
+ # Copy to pinned CPU buffer, then single HtoD transfer
154
+ cpu_inputs.copy_(row_buffer[:, :-1])
155
+ cpu_targets.copy_(row_buffer[:, 1:])
156
+
157
+ state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
158
+
159
+ # Single HtoD copy into persistent GPU buffer and yield
160
+ gpu_buffer.copy_(cpu_buffer, non_blocking=use_cuda)
161
+ yield inputs, targets, state_dict
162
+
163
+ def tokenizing_distributed_data_loader_bos_bestfit(*args, **kwargs):
164
+ """Helper that omits state_dict from yields."""
165
+ for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state_bos_bestfit(*args, **kwargs):
166
+ yield inputs, targets
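
A rough standalone sketch of the best-fit packing step described in the docstring above, operating on made-up document lengths rather than real token lists: take the largest buffered document that still fits, and crop one to fill the tail exactly when nothing fits.

    def pack_row(doc_lens, capacity):
        """Fill one row of `capacity` tokens from a buffer of document lengths (simplified)."""
        row, remaining = [], capacity
        while remaining > 0:
            fits = [n for n in doc_lens if n <= remaining]
            if fits:
                pick = max(fits)           # largest doc that fits entirely
                doc_lens.remove(pick)
                row.append(pick)
                remaining -= pick
            else:
                pick = min(doc_lens)       # nothing fits: crop the shortest doc
                doc_lens.remove(pick)
                row.append(remaining)      # only `remaining` of its tokens are kept
                remaining = 0
        return row

    print(pack_row([5, 9, 3, 7, 12], capacity=16))  # [12, 3, 1] -> sums to 16, no padding
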
nanochat/dataset.py ADDED
@@ -0,0 +1,160 @@
1
+ """
2
+ The base/pretraining dataset is a set of parquet files.
3
+ This file contains utilities for:
4
+ - iterating over the parquet files and yielding documents from them
5
+ - downloading the files on demand if they are not on disk
6
+
7
+ For details of how the dataset was prepared, see `repackage_data_reference.py`.
8
+ """
9
+
10
+ import os
11
+ import argparse
12
+ import time
13
+ import requests
14
+ import pyarrow.parquet as pq
15
+ from multiprocessing import Pool
16
+
17
+ from nanochat.common import get_base_dir
18
+
19
+ # -----------------------------------------------------------------------------
20
+ # The specifics of the current pretraining dataset
21
+
22
+ # The URL on the internet where the data is hosted and downloaded from on demand
23
+ BASE_URL = "https://huggingface.co/datasets/karpathy/climbmix-400b-shuffle/resolve/main"
24
+ MAX_SHARD = 6542 # the last datashard is shard_06542.parquet
25
+ index_to_filename = lambda index: f"shard_{index:05d}.parquet" # format of the filenames
26
+ base_dir = get_base_dir()
27
+ DATA_DIR = os.path.join(base_dir, "base_data_climbmix")
28
+
29
+ # -----------------------------------------------------------------------------
30
+ # These functions are useful utilities to other modules, can/should be imported
31
+
32
+ def list_parquet_files(data_dir=None, warn_on_legacy=False):
33
+ """ Looks into a data dir and returns full paths to all parquet files. """
34
+ data_dir = DATA_DIR if data_dir is None else data_dir
35
+
36
+ # Legacy-supporting code due to the upgrade from FinewebEdu-100B to ClimbMix-400B
37
+ # This code will eventually be deleted.
38
+ if not os.path.exists(data_dir):
39
+ if warn_on_legacy:
40
+ print()
41
+ print("=" * 80)
42
+ print(" WARNING: DATASET UPGRADE REQUIRED")
43
+ print("=" * 80)
44
+ print()
45
+ print(f" Could not find: {data_dir}")
46
+ print()
47
+ print(" nanochat recently switched from FinewebEdu-100B to ClimbMix-400B.")
48
+ print(" Everyone who does `git pull` as of March 4, 2026 is expected to see this message.")
49
+ print(" To upgrade to the new ClimbMix-400B dataset, run these two commands:")
50
+ print()
51
+ print(" python -m nanochat.dataset -n 170 # download ~170 shards, enough for GPT-2, adjust as desired")
52
+ print(" python -m scripts.tok_train # re-train tokenizer on new ClimbMix data")
53
+ print()
54
+ print(" For now, falling back to your old FinewebEdu-100B dataset...")
55
+ print("=" * 80)
56
+ print()
57
+ # attempt a fallback to the legacy data directory
58
+ data_dir = os.path.join(base_dir, "base_data")
59
+
60
+ parquet_files = sorted([
61
+ f for f in os.listdir(data_dir)
62
+ if f.endswith('.parquet') and not f.endswith('.tmp')
63
+ ])
64
+ parquet_paths = [os.path.join(data_dir, f) for f in parquet_files]
65
+ return parquet_paths
66
+
67
+ def parquets_iter_batched(split, start=0, step=1):
68
+ """
69
+ Iterate through the dataset, in batches of underlying row_groups for efficiency.
70
+ - split can be "train" or "val". the last parquet file will be val.
71
+ - start/step are useful for skipping rows in DDP. e.g. start=rank, step=world_size
72
+ """
73
+ assert split in ["train", "val"], "split must be 'train' or 'val'"
74
+ parquet_paths = list_parquet_files()
75
+ parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
76
+ for filepath in parquet_paths:
77
+ pf = pq.ParquetFile(filepath)
78
+ for rg_idx in range(start, pf.num_row_groups, step):
79
+ rg = pf.read_row_group(rg_idx)
80
+ texts = rg.column('text').to_pylist()
81
+ yield texts
82
+
83
+ # -----------------------------------------------------------------------------
84
+ def download_single_file(index):
85
+ """ Downloads a single file index, with some backoff """
86
+
87
+ # Construct the local filepath for this file and skip if it already exists
88
+ filename = index_to_filename(index)
89
+ filepath = os.path.join(DATA_DIR, filename)
90
+ if os.path.exists(filepath):
91
+ print(f"Skipping {filepath} (already exists)")
92
+ return True
93
+
94
+ # Construct the remote URL for this file
95
+ url = f"{BASE_URL}/{filename}"
96
+ print(f"Downloading {filename}...")
97
+
98
+ # Download with retries
99
+ max_attempts = 5
100
+ for attempt in range(1, max_attempts + 1):
101
+ try:
102
+ response = requests.get(url, stream=True, timeout=30)
103
+ response.raise_for_status()
104
+ # Write to temporary file first
105
+ temp_path = filepath + f".tmp"
106
+ with open(temp_path, 'wb') as f:
107
+ for chunk in response.iter_content(chunk_size=1024 * 1024): # 1MB chunks
108
+ if chunk:
109
+ f.write(chunk)
110
+ # Move temp file to final location
111
+ os.rename(temp_path, filepath)
112
+ print(f"Successfully downloaded {filename}")
113
+ return True
114
+
115
+ except (requests.RequestException, IOError) as e:
116
+ print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
117
+ # Clean up any partial files
118
+ for path in [filepath + f".tmp", filepath]:
119
+ if os.path.exists(path):
120
+ try:
121
+ os.remove(path)
122
+ except:
123
+ pass
124
+ # Try a few times with exponential backoff: 2^attempt seconds
125
+ if attempt < max_attempts:
126
+ wait_time = 2 ** attempt
127
+ print(f"Waiting {wait_time} seconds before retry...")
128
+ time.sleep(wait_time)
129
+ else:
130
+ print(f"Failed to download {filename} after {max_attempts} attempts")
131
+ return False
132
+
133
+ return False
134
+
135
+
136
+ if __name__ == "__main__":
137
+ parser = argparse.ArgumentParser(description="Download pretraining dataset shards")
138
+ parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of train shards to download (default: -1), -1 = disable")
139
+ parser.add_argument("-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)")
140
+ args = parser.parse_args()
141
+
142
+ # Prepare the output directory
143
+ os.makedirs(DATA_DIR, exist_ok=True)
144
+
145
+ # The way this works is that the user specifies the number of train shards to download via the -n flag.
146
+ # In addition to that, the validation shard is *always* downloaded and is pinned to be the last shard.
147
+ num_train_shards = MAX_SHARD if args.num_files == -1 else min(args.num_files, MAX_SHARD)
148
+ ids_to_download = list(range(num_train_shards))
149
+ ids_to_download.append(MAX_SHARD) # always download the validation shard
150
+
151
+ # Download the shards
152
+ print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
153
+ print(f"Target directory: {DATA_DIR}")
154
+ print()
155
+ with Pool(processes=args.num_workers) as pool:
156
+ results = pool.map(download_single_file, ids_to_download)
157
+
158
+ # Report results
159
+ successful = sum(1 for success in results if success)
160
+ print(f"Done! Downloaded: {successful}/{len(ids_to_download)} shards to {DATA_DIR}")
nanochat/engine.py ADDED
@@ -0,0 +1,357 @@
1
+ """
2
+ Engine for efficient inference of our models.
3
+
4
+ Everything works around token sequences:
5
+ - The user can send token sequences to the engine
6
+ - The engine returns the next token
7
+
8
+ Notes:
9
+ - The engine knows nothing about tokenization; it works purely on token id sequences.
10
+
11
+ The whole thing is made as efficient as possible.
12
+ """
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+ import signal
17
+ import warnings
18
+ from contextlib import contextmanager
19
+ from collections import deque
20
+ from nanochat.common import compute_init, autodetect_device_type
21
+ from nanochat.checkpoint_manager import load_model
22
+
23
+ # -----------------------------------------------------------------------------
24
+ # Calculator tool helpers
25
+ @contextmanager
26
+ def timeout(duration, formula):
27
+ def timeout_handler(signum, frame):
28
+ raise Exception(f"'{formula}': timed out after {duration} seconds")
29
+
30
+ signal.signal(signal.SIGALRM, timeout_handler)
31
+ signal.alarm(duration)
32
+ yield
33
+ signal.alarm(0)
34
+
35
+ def eval_with_timeout(formula, max_time=3):
36
+ try:
37
+ with timeout(max_time, formula):
38
+ with warnings.catch_warnings():
39
+ warnings.simplefilter("ignore", SyntaxWarning)
40
+ return eval(formula, {"__builtins__": {}}, {})
41
+ except Exception as e:
42
+ signal.alarm(0)
43
+ # print(f"Warning: Failed to eval {formula}, exception: {e}") # it's ok ignore wrong calculator usage
44
+ return None
45
+
46
+ def use_calculator(expr):
47
+ """
48
+ Evaluate a Python expression safely.
49
+ Supports both math expressions and string operations like .count()
50
+ """
51
+ # Remove commas from numbers
52
+ expr = expr.replace(",", "")
53
+
54
+ # Check if it's a pure math expression (old behavior)
55
+ if all([x in "0123456789*+-/.() " for x in expr]):
56
+ if "**" in expr: # disallow power operator
57
+ return None
58
+ return eval_with_timeout(expr)
59
+
60
+ # Check if it's a string operation we support
61
+ # Allow: strings (single/double quotes), .count(), letters, numbers, spaces, parens
62
+ allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
63
+ if not all([x in allowed_chars for x in expr]):
64
+ return None
65
+
66
+ # Disallow dangerous patterns
67
+ dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
68
+ 'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
69
+ 'getattr', 'setattr', 'delattr', 'hasattr']
70
+ expr_lower = expr.lower()
71
+ if any(pattern in expr_lower for pattern in dangerous_patterns):
72
+ return None
73
+
74
+ # Only allow .count() method for now (can expand later)
75
+ if '.count(' not in expr:
76
+ return None
77
+
78
+ # Evaluate with timeout
79
+ return eval_with_timeout(expr)
80
+
81
+ # -----------------------------------------------------------------------------
82
+ class KVCache:
83
+ """
84
+ KV Cache designed for Flash Attention 3's flash_attn_with_kvcache API.
85
+
86
+ Key differences from FA2-style cache:
87
+ - Tensors are (B, T, H, D) not (B, H, T, D)
88
+ - FA3 updates the cache in-place during flash_attn_with_kvcache
89
+ - Position tracked per batch element via cache_seqlens tensor
90
+ """
91
+
92
+ def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype):
93
+ self.batch_size = batch_size
94
+ self.max_seq_len = seq_len
95
+ self.n_layers = num_layers
96
+ self.n_heads = num_heads
97
+ self.head_dim = head_dim
98
+ # Pre-allocate cache tensors: (n_layers, B, T, H, D)
99
+ self.k_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
100
+ self.v_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
101
+ # Current sequence length per batch element (FA3 needs int32)
102
+ self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
103
+ # Previous token's normalized embedding for smear (set by model forward pass)
104
+ self.prev_embedding = None
105
+
106
+ def reset(self):
107
+ """Reset cache to empty state."""
108
+ self.cache_seqlens.zero_()
109
+ self.prev_embedding = None
110
+
111
+ def get_pos(self):
112
+ """Get current position (assumes all batch elements at same position)."""
113
+ return self.cache_seqlens[0].item()
114
+
115
+ def get_layer_cache(self, layer_idx):
116
+ """Return (k_cache, v_cache) views for a specific layer."""
117
+ return self.k_cache[layer_idx], self.v_cache[layer_idx]
118
+
119
+ def advance(self, num_tokens):
120
+ """Advance the cache position by num_tokens."""
121
+ self.cache_seqlens += num_tokens
122
+
123
+ def prefill(self, other):
124
+ """
125
+ Copy cached KV from another cache into this one.
126
+ Used when we do batch=1 prefill and then want to generate multiple samples in parallel.
127
+ """
128
+ assert self.get_pos() == 0, "Cannot prefill a non-empty KV cache"
129
+ assert self.n_layers == other.n_layers and self.n_heads == other.n_heads and self.head_dim == other.head_dim
130
+ assert self.max_seq_len >= other.max_seq_len
131
+ other_pos = other.get_pos()
132
+ self.k_cache[:, :, :other_pos, :, :] = other.k_cache[:, :, :other_pos, :, :]
133
+ self.v_cache[:, :, :other_pos, :, :] = other.v_cache[:, :, :other_pos, :, :]
134
+ self.cache_seqlens.fill_(other_pos)
135
+ # Copy smear state: expand batch=1 prev_embedding to num_samples
136
+ if other.prev_embedding is not None:
137
+ self.prev_embedding = other.prev_embedding.expand(self.batch_size, -1, -1).clone()
138
+
139
+ # -----------------------------------------------------------------------------
140
+ @torch.inference_mode()
141
+ def sample_next_token(logits, rng, temperature=1.0, top_k=None):
142
+ """Sample a single next token from given logits of shape (B, vocab_size). Returns (B, 1)."""
143
+ assert temperature >= 0.0, "temperature must be non-negative"
144
+ if temperature == 0.0:
145
+ return torch.argmax(logits, dim=-1, keepdim=True)
146
+ if top_k is not None and top_k > 0:
147
+ k = min(top_k, logits.size(-1))
148
+ vals, idx = torch.topk(logits, k, dim=-1)
149
+ vals = vals / temperature
150
+ probs = F.softmax(vals, dim=-1)
151
+ choice = torch.multinomial(probs, num_samples=1, generator=rng)
152
+ return idx.gather(1, choice)
153
+ else:
154
+ logits = logits / temperature
155
+ probs = F.softmax(logits, dim=-1)
156
+ return torch.multinomial(probs, num_samples=1, generator=rng)
157
+
158
+ # -----------------------------------------------------------------------------
159
+
160
+ class RowState:
161
+ # Per-row state tracking during generation
162
+ def __init__(self, current_tokens=None):
163
+ self.current_tokens = current_tokens or [] # Current token sequence for this row
164
+ self.forced_tokens = deque() # Queue of tokens to force inject
165
+ self.in_python_block = False # Whether we are inside a python block
166
+ self.python_expr_tokens = [] # Tokens of the current python expression
167
+ self.completed = False # Whether this row has completed generation
168
+
169
+ class Engine:
170
+
171
+ def __init__(self, model, tokenizer):
172
+ self.model = model
173
+ self.tokenizer = tokenizer # needed for tool use
174
+
175
+ @torch.inference_mode()
176
+ def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
177
+ """Same as generate, but does single prefill and then clones the KV cache."""
178
+ assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
179
+ device = self.model.get_device()
180
+ # NOTE: setting the dtype here and in this way is an ugly hack.
181
+ # Currently the repo assumes that cuda -> bfloat16 and everything else -> float32.
182
+ # We need to know the dtype here to call __init__ on KVCache and pre-allocate its tensors.
183
+ # As a quick hack, we're making generate() function inherit and know about this repo-wise assumption.
184
+ # I think there has to be a bigger refactor to deal with device/dtype tracking across the codebase.
185
+ # In particular, the KVCache should allocate its tensors lazily
186
+ dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
187
+ rng = torch.Generator(device=device)
188
+ rng.manual_seed(seed)
189
+
190
+ # Get the special tokens we need to coordinate the tool use state machine
191
+ get_special = lambda s: self.tokenizer.encode_special(s)
192
+ python_start = get_special("<|python_start|>")
193
+ python_end = get_special("<|python_end|>")
194
+ output_start = get_special("<|output_start|>")
195
+ output_end = get_special("<|output_end|>")
196
+ assistant_end = get_special("<|assistant_end|>") # if sampled, ends row
197
+ bos = self.tokenizer.get_bos_token_id() # if sampled, ends row
198
+
199
+ # 1) Run a batch 1 prefill of the prompt tokens
200
+ m = self.model.config
201
+ kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
202
+ kv_cache_prefill = KVCache(
203
+ batch_size=1,
204
+ seq_len=len(tokens),
205
+ device=device,
206
+ dtype=dtype,
207
+ **kv_model_kwargs,
208
+ )
209
+ ids = torch.tensor([tokens], dtype=torch.long, device=device)
210
+ logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
211
+ logits = logits[:, -1, :].expand(num_samples, -1) # (num_samples, vocab_size)
212
+
213
+ # 2) Replicate the KV cache for each sample/row
214
+ kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
215
+ kv_cache_decode = KVCache(
216
+ batch_size=num_samples,
217
+ seq_len=kv_length_hint,
218
+ device=device,
219
+ dtype=dtype,
220
+ **kv_model_kwargs,
221
+ )
222
+ kv_cache_decode.prefill(kv_cache_prefill)
223
+ del kv_cache_prefill # no need to keep this memory around
224
+
225
+ # 3) Initialize states for each sample
226
+ row_states = [RowState(tokens.copy()) for _ in range(num_samples)]
227
+
228
+ # 4) Main generation loop
229
+ num_generated = 0
230
+ while True:
231
+ # Stop condition: we've reached max tokens
232
+ if max_tokens is not None and num_generated >= max_tokens:
233
+ break
234
+ # Stop condition: all rows are completed
235
+ if all(state.completed for state in row_states):
236
+ break
237
+
238
+ # Sample the next token for each row
239
+ next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
240
+ sampled_tokens = next_ids[:, 0].tolist()
241
+
242
+ # Process each row: choose the next token, update state, optional tool use
243
+ token_column = [] # contains the next token id along each row
244
+ token_masks = [] # contains the mask (was it sampled (1) or forced (0)?) along each row
245
+ for i, state in enumerate(row_states):
246
+ # Select the next token in this row
247
+ is_forced = len(state.forced_tokens) > 0 # are there tokens waiting to be forced in deque?
248
+ token_masks.append(0 if is_forced else 1) # mask is 0 if forced, 1 if sampled
249
+ next_token = state.forced_tokens.popleft() if is_forced else sampled_tokens[i]
250
+ token_column.append(next_token)
251
+ # Update the state of this row to include the next token
252
+ state.current_tokens.append(next_token)
253
+ # On <|assistant_end|> or <|bos|>, mark the row as completed
254
+ if next_token == assistant_end or next_token == bos:
255
+ state.completed = True
256
+ # Handle tool logic
257
+ if next_token == python_start:
258
+ state.in_python_block = True
259
+ state.python_expr_tokens = []
260
+ elif next_token == python_end and state.in_python_block:
261
+ state.in_python_block = False
262
+ if state.python_expr_tokens:
263
+ expr = self.tokenizer.decode(state.python_expr_tokens)
264
+ result = use_calculator(expr)
265
+ if result is not None:
266
+ result_tokens = self.tokenizer.encode(str(result))
267
+ state.forced_tokens.append(output_start)
268
+ state.forced_tokens.extend(result_tokens)
269
+ state.forced_tokens.append(output_end)
270
+ state.python_expr_tokens = []
271
+ elif state.in_python_block:
272
+ state.python_expr_tokens.append(next_token)
273
+
274
+ # Yield the token column
275
+ yield token_column, token_masks
276
+ num_generated += 1
277
+
278
+ # Prepare logits for next iteration
279
+ ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)
280
+ logits = self.model.forward(ids, kv_cache=kv_cache_decode)[:, -1, :] # (B, vocab_size)
281
+
282
+ def generate_batch(self, tokens, num_samples=1, **kwargs):
283
+ """
284
+ Non-streaming batch generation that just returns the final token sequences.
285
+ Returns a list of token sequences (list of lists of ints).
286
+ Terminal tokens (assistant_end, bos) are not included in the results.
287
+ """
288
+ assistant_end = self.tokenizer.encode_special("<|assistant_end|>")
289
+ bos = self.tokenizer.get_bos_token_id()
290
+ results = [tokens.copy() for _ in range(num_samples)]
291
+ masks = [[0] * len(tokens) for _ in range(num_samples)]
292
+ completed = [False] * num_samples
293
+ for token_column, token_masks in self.generate(tokens, num_samples, **kwargs):
294
+ for i, (token, mask) in enumerate(zip(token_column, token_masks)):
295
+ if not completed[i]:
296
+ if token == assistant_end or token == bos:
297
+ completed[i] = True
298
+ else:
299
+ results[i].append(token)
300
+ masks[i].append(mask)
301
+ # Stop if all rows are completed
302
+ if all(completed):
303
+ break
304
+ return results, masks
305
+
306
+
307
+ if __name__ == "__main__":
308
+ """
309
+ Quick inline test to make sure that the naive/slow model.generate function
310
+ is equivalent to the faster Engine.generate function here.
311
+ """
312
+ import time
313
+ # init compute
314
+ device_type = autodetect_device_type()
315
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
316
+ # load the model and tokenizer
317
+ model, tokenizer, meta = load_model("base", device, phase="eval")
318
+ bos_token_id = tokenizer.get_bos_token_id()
319
+ # common hyperparameters
320
+ kwargs = dict(max_tokens=64, temperature=0.0)
321
+ # set the starting prompt
322
+ prompt_tokens = tokenizer.encode("The chemical formula of water is", prepend=bos_token_id)
323
+ # generate the reference sequence using the model.generate() function
324
+ generated_tokens = []
325
+ torch.cuda.synchronize()
326
+ t0 = time.time()
327
+ stream = model.generate(prompt_tokens, **kwargs)
328
+ for token in stream:
329
+ generated_tokens.append(token)
330
+ chunk = tokenizer.decode([token])
331
+ print(chunk, end="", flush=True)
332
+ print()
333
+ torch.cuda.synchronize()
334
+ t1 = time.time()
335
+ print(f"Reference time: {t1 - t0:.2f}s")
336
+ reference_ids = generated_tokens
337
+ # generate tokens with Engine
338
+ generated_tokens = []
339
+ engine = Engine(model, tokenizer)
340
+ stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
341
+ torch.cuda.synchronize()
342
+ t0 = time.time()
343
+ for token_column, token_masks in stream:
344
+ token = token_column[0] # only print out the first row
345
+ generated_tokens.append(token)
346
+ chunk = tokenizer.decode([token])
347
+ print(chunk, end="", flush=True)
348
+ print()
349
+ torch.cuda.synchronize()
350
+ t1 = time.time()
351
+ print(f"Engine time: {t1 - t0:.2f}s")
352
+ # compare the two sequences
353
+ for i in range(len(reference_ids)):
354
+ if reference_ids[i] != generated_tokens[i]:
355
+ print(f"Mismatch at {i}: {reference_ids[i]} != {generated_tokens[i]}")
356
+ break
357
+ print(f"Match: {reference_ids == generated_tokens}")
nanochat/execution.py ADDED
@@ -0,0 +1,349 @@
1
+ """
2
+ Sandboxed execution utilities for running Python code that comes out of an LLM.
3
+ Adapted from OpenAI HumanEval code:
4
+ https://github.com/openai/human-eval/blob/master/human_eval/execution.py
5
+
6
+ What is covered:
7
+ - Each execution runs in its own process (can be killed if it hangs or crashes)
8
+ - Execution is limited by a timeout to stop infinite loops
9
+ - Memory limits are enforced by default (256MB)
10
+ - stdout and stderr are captured and returned
11
+ - Code runs in a temporary directory that is deleted afterwards
12
+ - Dangerous functions are disabled (examples: os.system, os.kill, shutil.rmtree, subprocess.Popen)
13
+
14
+ What is not covered:
15
+ - Not a true security sandbox
16
+ - Network access is not blocked (e.g. sockets could be opened)
17
+ - Python's dynamic features (e.g. ctypes) could bypass restrictions
18
+ - No kernel-level isolation (no seccomp, no containers, no virtualization)
19
+
20
+ Overall this sandbox is good for evaluation of generated code and protects against
21
+ accidental destructive behavior, but it is not safe against malicious adversarial code.
22
+ """
23
+
24
+ import contextlib
25
+ import faulthandler
26
+ import io
27
+ import multiprocessing
28
+ import os
29
+ import platform
30
+ import signal
31
+ import tempfile
32
+ from dataclasses import dataclass
33
+ from typing import Optional
34
+
35
+ # -----------------------------------------------------------------------------
36
+
37
+ @dataclass
38
+ class ExecutionResult:
39
+ """Result of executing Python code in a sandbox."""
40
+ success: bool
41
+ stdout: str
42
+ stderr: str
43
+ error: Optional[str] = None
44
+ timeout: bool = False
45
+ memory_exceeded: bool = False
46
+
47
+ def __repr__(self):
48
+ parts = []
49
+ parts.append(f"ExecutionResult(success={self.success}")
50
+ if self.timeout:
51
+ parts.append(", timeout=True")
52
+ if self.memory_exceeded:
53
+ parts.append(", memory_exceeded=True")
54
+ if self.error:
55
+ parts.append(f", error={self.error!r}")
56
+ if self.stdout:
57
+ parts.append(f", stdout={self.stdout!r}")
58
+ if self.stderr:
59
+ parts.append(f", stderr={self.stderr!r}")
60
+ parts.append(")")
61
+ return "".join(parts)
62
+
63
+
64
+ @contextlib.contextmanager
65
+ def time_limit(seconds: float):
66
+ def signal_handler(signum, frame):
67
+ raise TimeoutException("Timed out!")
68
+
69
+ signal.setitimer(signal.ITIMER_REAL, seconds)
70
+ signal.signal(signal.SIGALRM, signal_handler)
71
+ try:
72
+ yield
73
+ finally:
74
+ signal.setitimer(signal.ITIMER_REAL, 0)
75
+
76
+
77
+ @contextlib.contextmanager
78
+ def capture_io():
79
+ """Capture stdout and stderr, and disable stdin."""
80
+ stdout_capture = io.StringIO()
81
+ stderr_capture = io.StringIO()
82
+ stdin_block = WriteOnlyStringIO()
83
+ with contextlib.redirect_stdout(stdout_capture):
84
+ with contextlib.redirect_stderr(stderr_capture):
85
+ with redirect_stdin(stdin_block):
86
+ yield stdout_capture, stderr_capture
87
+
88
+
89
+ @contextlib.contextmanager
90
+ def create_tempdir():
91
+ with tempfile.TemporaryDirectory() as dirname:
92
+ with chdir(dirname):
93
+ yield dirname
94
+
95
+
96
+ class TimeoutException(Exception):
97
+ pass
98
+
99
+
100
+ class WriteOnlyStringIO(io.StringIO):
101
+ """StringIO that throws an exception when it's read from"""
102
+
103
+ def read(self, *args, **kwargs):
104
+ raise IOError
105
+
106
+ def readline(self, *args, **kwargs):
107
+ raise IOError
108
+
109
+ def readlines(self, *args, **kwargs):
110
+ raise IOError
111
+
112
+ def readable(self, *args, **kwargs):
113
+ """Returns True if the IO object can be read."""
114
+ return False
115
+
116
+
117
+ class redirect_stdin(contextlib._RedirectStream): # type: ignore
118
+ _stream = "stdin"
119
+
120
+
121
+ @contextlib.contextmanager
122
+ def chdir(root):
123
+ if root == ".":
124
+ yield
125
+ return
126
+ cwd = os.getcwd()
127
+ os.chdir(root)
128
+ try:
129
+ yield
130
+ finally:
131
+ os.chdir(cwd)
132
+
133
+
134
+ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
135
+ """
136
+ This disables various destructive functions and prevents the generated code
137
+ from interfering with the test (e.g. fork bomb, killing other processes,
138
+ removing filesystem files, etc.)
139
+
140
+ WARNING
141
+ This function is NOT a security sandbox. Untrusted code, including model-
142
+ generated code, should not be blindly executed outside of one. See the
143
+ Codex paper for more information about OpenAI's code sandbox, and proceed
144
+ with caution.
145
+ """
146
+
147
+ if platform.uname().system != "Darwin":
148
+ # These resource limit calls seem to fail on macOS (Darwin), so we skip them there
149
+ import resource
150
+ resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
151
+ resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
152
+ resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
153
+
154
+ faulthandler.disable()
155
+
156
+ import builtins
157
+
158
+ builtins.exit = None
159
+ builtins.quit = None
160
+
161
+ import os
162
+
163
+ os.environ["OMP_NUM_THREADS"] = "1"
164
+
165
+ os.kill = None
166
+ os.system = None
167
+ os.putenv = None
168
+ os.remove = None
169
+ os.removedirs = None
170
+ os.rmdir = None
171
+ os.fchdir = None
172
+ os.setuid = None
173
+ os.fork = None
174
+ os.forkpty = None
175
+ os.killpg = None
176
+ os.rename = None
177
+ os.renames = None
178
+ os.truncate = None
179
+ os.replace = None
180
+ os.unlink = None
181
+ os.fchmod = None
182
+ os.fchown = None
183
+ os.chmod = None
184
+ os.chown = None
185
+ os.chroot = None
186
+ os.fchdir = None
187
+ os.lchflags = None
188
+ os.lchmod = None
189
+ os.lchown = None
190
+ os.getcwd = None
191
+ os.chdir = None
192
+
193
+ import shutil
194
+
195
+ shutil.rmtree = None
196
+ shutil.move = None
197
+ shutil.chown = None
198
+
199
+ import subprocess
200
+
201
+ subprocess.Popen = None # type: ignore
202
+
203
+ __builtins__["help"] = None
204
+
205
+ import sys
206
+
207
+ sys.modules["ipdb"] = None
208
+ sys.modules["joblib"] = None
209
+ sys.modules["resource"] = None
210
+ sys.modules["psutil"] = None
211
+ sys.modules["tkinter"] = None
212
+
213
+
214
+ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[int], result_dict):
215
+ """Execute code in a subprocess with safety guards. Results are written to result_dict."""
216
+ with create_tempdir():
217
+
218
+ # These system calls are needed when cleaning up tempdir.
219
+ import os
220
+ import shutil
221
+
222
+ rmtree = shutil.rmtree
223
+ rmdir = os.rmdir
224
+ chdir = os.chdir
225
+ unlink = os.unlink
226
+
227
+ # Disable functionalities that can make destructive changes to the test.
228
+ reliability_guard(maximum_memory_bytes=maximum_memory_bytes)
229
+
230
+ # Default to failure
231
+ result_dict.update({
232
+ "success": False,
233
+ "stdout": "",
234
+ "stderr": "",
235
+ "timeout": False,
236
+ "memory_exceeded": False,
237
+ "error": None,
238
+ })
239
+
240
+ try:
241
+ exec_globals = {}
242
+ with capture_io() as (stdout_capture, stderr_capture):
243
+ with time_limit(timeout):
244
+ # WARNING
245
+ # This program exists to execute untrusted model-generated code. Although
246
+ # it is highly unlikely that model-generated code will do something overtly
247
+ # malicious in response to this test suite, model-generated code may act
248
+ # destructively due to a lack of model capability or alignment.
249
+ # Users are strongly encouraged to sandbox this evaluation suite so that it
250
+ # does not perform destructive actions on their host or network. For more
251
+ # information on how OpenAI sandboxes its code, see the accompanying paper.
252
+ # Unlike the upstream HumanEval script, the exec call below is left enabled here,
253
+ # so read this disclaimer, take appropriate precautions, and proceed at your own risk:
254
+ exec(code, exec_globals)
255
+
256
+ result_dict.update({
257
+ "success": True,
258
+ "stdout": stdout_capture.getvalue(),
259
+ "stderr": stderr_capture.getvalue(),
260
+ })
261
+
262
+ except TimeoutException:
263
+ result_dict.update({
264
+ "timeout": True,
265
+ "error": "Execution timed out",
266
+ })
267
+
268
+ except MemoryError as e:
269
+ result_dict.update({
270
+ "memory_exceeded": True,
271
+ "error": f"Memory limit exceeded: {e}",
272
+ })
273
+
274
+ except BaseException as e:
275
+ result_dict.update({
276
+ "error": f"{type(e).__name__}: {e}",
277
+ })
278
+
279
+ # Needed for cleaning up.
280
+ shutil.rmtree = rmtree
281
+ os.rmdir = rmdir
282
+ os.chdir = chdir
283
+ os.unlink = unlink
284
+
285
+
286
+ def execute_code(
287
+ code: str,
288
+ timeout: float = 5.0, # 5 seconds default
289
+ maximum_memory_bytes: Optional[int] = 256 * 1024 * 1024, # 256MB default
290
+ ) -> ExecutionResult:
291
+ """
292
+ Execute Python code in a sandboxed environment.
293
+
294
+ Args:
295
+ code: Python code to execute as a string
296
+ timeout: Maximum execution time in seconds (default: 5.0)
297
+ maximum_memory_bytes: Memory limit in bytes (default: 256MB, None to disable)
298
+
299
+ Returns:
300
+ ExecutionResult with success status, stdout/stderr, and error information
301
+
302
+ Example:
303
+ >>> result = execute_code("print('hello world')")
304
+ >>> result.success
305
+ True
306
+ >>> result.stdout
307
+ 'hello world\\n'
308
+ """
309
+
310
+ manager = multiprocessing.Manager()
311
+ result_dict = manager.dict()
312
+
313
+ p = multiprocessing.Process(
314
+ target=_unsafe_execute,
315
+ args=(code, timeout, maximum_memory_bytes, result_dict)
316
+ )
317
+ p.start()
318
+ p.join(timeout=timeout + 1)
319
+
320
+ if p.is_alive():
321
+ p.kill()
322
+ return ExecutionResult(
323
+ success=False,
324
+ stdout="",
325
+ stderr="",
326
+ error="Execution timed out (process killed)",
327
+ timeout=True,
328
+ memory_exceeded=False,
329
+ )
330
+
331
+ if not result_dict:
332
+ return ExecutionResult(
333
+ success=False,
334
+ stdout="",
335
+ stderr="",
336
+ error="Execution failed (no result returned)",
337
+ timeout=True,
338
+ memory_exceeded=False,
339
+ )
340
+
341
+ return ExecutionResult(
342
+ success=result_dict["success"],
343
+ stdout=result_dict["stdout"],
344
+ stderr=result_dict["stderr"],
345
+ error=result_dict["error"],
346
+ timeout=result_dict["timeout"],
347
+ memory_exceeded=result_dict["memory_exceeded"],
348
+ )
349
+
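
Two more example calls, in addition to the docstring example, showing the shape of the results (the second one assumes the timeout fires as intended):

    from nanochat.execution import execute_code

    ok = execute_code("print(sum(range(10)))")
    print(ok.success, repr(ok.stdout))   # True '45\n'

    hung = execute_code("while True: pass", timeout=1.0)
    print(hung.success, hung.timeout)    # False True
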
nanochat/flash_attention.py ADDED
@@ -0,0 +1,187 @@
1
+ """
2
+ Unified Flash Attention interface with automatic FA3/SDPA switching.
3
+
4
+ Exports `flash_attn` module that matches the FA3 API exactly, but falls back
5
+ to PyTorch SDPA on non-Hopper GPUs (including Blackwell), MPS, and CPU.
6
+
7
+ Usage (drop-in replacement for FA3):
8
+ from nanochat.flash_attention import flash_attn
9
+
10
+ # Training (no KV cache)
11
+ y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
12
+
13
+ # Inference (with KV cache)
14
+ y = flash_attn.flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v, ...)
15
+ """
16
+ import torch
17
+ import torch.nn.functional as F
18
+
19
+
20
+ # =============================================================================
21
+ # Detection: Try to load FA3 on Hopper+ GPUs
22
+ # =============================================================================
23
+ def _load_flash_attention_3():
24
+ """Try to load Flash Attention 3 (requires Hopper GPU, sm90)."""
25
+ if not torch.cuda.is_available():
26
+ return None
27
+ try:
28
+ major, _ = torch.cuda.get_device_capability()
29
+ # FA3 kernels are compiled for Hopper (sm90) only
30
+ # Ada (sm89), Blackwell (sm100) need SDPA fallback until FA3 is recompiled
31
+ if major != 9:
32
+ return None
33
+ import os
34
+ os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
35
+ from kernels import get_kernel
36
+ return get_kernel('varunneal/flash-attention-3').flash_attn_interface
37
+ except Exception:
38
+ return None
39
+
40
+
41
+ _fa3 = _load_flash_attention_3()
42
+ HAS_FA3 = _fa3 is not None
43
+
44
+ # Override for testing: set to 'fa3', 'sdpa', or None (auto)
45
+ _override_impl = None
46
+
47
+
48
+ def _resolve_use_fa3():
49
+ """Decide once whether to use FA3, based on availability, override, and dtype."""
50
+ if _override_impl == 'fa3':
51
+ assert HAS_FA3, "Cannot override to FA3: not available on this hardware"
52
+ return True
53
+ if _override_impl == 'sdpa':
54
+ return False
55
+ if HAS_FA3:
56
+ # FA3 Hopper kernels only support bf16 and fp8; fp16/fp32 must use SDPA fallback
57
+ from nanochat.common import COMPUTE_DTYPE
58
+ if COMPUTE_DTYPE == torch.bfloat16:
59
+ return True
60
+ return False
61
+ return False
62
+
63
+ USE_FA3 = _resolve_use_fa3()
64
+
65
+
66
+ # =============================================================================
67
+ # SDPA helpers
68
+ # =============================================================================
69
+ def _sdpa_attention(q, k, v, window_size, enable_gqa):
70
+ """
71
+ SDPA attention with sliding window support.
72
+ q, k, v are (B, H, T, D) format.
73
+ """
74
+ Tq = q.size(2)
75
+ Tk = k.size(2)
76
+ window = window_size[0]
77
+
78
+ # Full context, same length
79
+ if (window < 0 or window >= Tq) and Tq == Tk:
80
+ return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
81
+
82
+ # Single token generation
83
+ if Tq == 1:
84
+ if window >= 0 and window < Tk:
85
+ # window is "left" tokens we need to include (window + 1) keys total
86
+ start = max(0, Tk - (window + 1))
87
+ k = k[:, :, start:, :]
88
+ v = v[:, :, start:, :]
89
+ return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
90
+
91
+ # Need explicit mask for sliding window/chunk inference
92
+ device = q.device
93
+ # For chunk inference (Tq != Tk), is_causal is not aligned to cache position => build an explicit bool mask
94
+ row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1)
95
+ col_idx = torch.arange(Tk, device=device).unsqueeze(0)
96
+ mask = col_idx <= row_idx
97
+
98
+ # sliding window (left)
99
+ if window >= 0 and window < Tk:
100
+ mask = mask & ((row_idx - col_idx) <= window)
101
+
102
+ return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
103
+
104
+ # =============================================================================
105
+ # Public API: Same interface as FA3
106
+ # =============================================================================
107
+ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
108
+ """
109
+ Flash Attention for training (no KV cache).
110
+
111
+ Args:
112
+ q, k, v: Tensors of shape (B, T, H, D)
113
+ causal: Whether to use causal masking
114
+ window_size: (left, right) sliding window. -1 means unlimited.
115
+
116
+ Returns:
117
+ Output tensor of shape (B, T, H, D)
118
+ """
119
+ if USE_FA3:
120
+ return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
121
+
122
+ # SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D)
123
+ q = q.transpose(1, 2)
124
+ k = k.transpose(1, 2)
125
+ v = v.transpose(1, 2)
126
+ enable_gqa = q.size(1) != k.size(1)
127
+ y = _sdpa_attention(q, k, v, window_size, enable_gqa)
128
+ return y.transpose(1, 2) # back to (B, T, H, D)
129
+
130
+
131
+ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=None,
132
+ causal=False, window_size=(-1, -1)):
133
+ """
134
+ Flash Attention with KV cache for inference.
135
+
136
+ FA3 updates k_cache/v_cache in-place. Our SDPA fallback does the same.
137
+
138
+ Args:
139
+ q: Queries, shape (B, T_new, H, D)
140
+ k_cache, v_cache: Pre-allocated cache tensors, shape (B, T_max, H_kv, D)
141
+ k, v: New keys/values to insert, shape (B, T_new, H_kv, D)
142
+ cache_seqlens: Current position in cache, shape (B,) int32
143
+ causal: Whether to use causal masking
144
+ window_size: (left, right) sliding window. -1 means unlimited.
145
+
146
+ Returns:
147
+ Output tensor of shape (B, T_new, H, D)
148
+ """
149
+ if USE_FA3:
150
+ return _fa3.flash_attn_with_kvcache(
151
+ q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
152
+ causal=causal, window_size=window_size
153
+ )
154
+
155
+ # SDPA fallback: manually manage KV cache
156
+ B, T_new, H, D = q.shape
157
+ pos = cache_seqlens[0].item() # assume uniform position across batch
158
+
159
+ # Insert new k, v into cache (in-place, matching FA3 behavior)
160
+ if k is not None and v is not None:
161
+ k_cache[:, pos:pos+T_new, :, :] = k
162
+ v_cache[:, pos:pos+T_new, :, :] = v
163
+
164
+ # Get full cache up to current position + new tokens
165
+ end_pos = pos + T_new
166
+ k_full = k_cache[:, :end_pos, :, :]
167
+ v_full = v_cache[:, :end_pos, :, :]
168
+
169
+ # Transpose to SDPA layout: (B, T, H, D) -> (B, H, T, D)
170
+ q_sdpa = q.transpose(1, 2)
171
+ k_sdpa = k_full.transpose(1, 2)
172
+ v_sdpa = v_full.transpose(1, 2)
173
+
174
+ enable_gqa = q_sdpa.size(1) != k_sdpa.size(1)
175
+ y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa)
176
+
177
+ return y_sdpa.transpose(1, 2) # back to (B, T, H, D)
178
+
179
+
180
+ # =============================================================================
181
+ # Export: flash_attn module interface (drop-in replacement for FA3)
182
+ # =============================================================================
183
+ from types import SimpleNamespace
184
+ flash_attn = SimpleNamespace(
185
+ flash_attn_func=flash_attn_func,
186
+ flash_attn_with_kvcache=flash_attn_with_kvcache,
187
+ )
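
A quick shape-level sketch of the training-path call; on non-Hopper GPUs or CPU this exercises the SDPA fallback, while on Hopper it dispatches to FA3:

    import torch
    from nanochat.flash_attention import flash_attn

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16 if device == "cuda" else torch.float32
    B, T, H, D = 2, 16, 4, 64
    q, k, v = (torch.randn(B, T, H, D, device=device, dtype=dtype) for _ in range(3))
    y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(8, -1))  # 8-token left window
    print(y.shape)  # torch.Size([2, 16, 4, 64])
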
nanochat/fp8.py ADDED
@@ -0,0 +1,266 @@
1
+ """Minimal FP8 training for nanochat — tensorwise dynamic scaling only.
2
+
3
+ Drop-in replacement for torchao's Float8Linear (~2000 lines) with ~150 lines.
4
+ We only need the "tensorwise" recipe (one scalar scale per tensor), not the full
5
+ generality of torchao (rowwise scaling, FSDP float8 all-gather, DTensor, tensor
6
+ subclass dispatch tables, etc.)
7
+
8
+ How FP8 training works
9
+ ======================
10
+ A standard Linear layer does one matmul in forward and two in backward:
11
+ forward: output = input @ weight.T
12
+ backward: grad_input = grad_output @ weight
13
+ grad_weight= grad_output.T @ input
14
+
15
+ FP8 training wraps each of these three matmuls with:
16
+ 1. Compute scale = FP8_MAX / max(|tensor|) for each operand
17
+ 2. Quantize: fp8_tensor = clamp(tensor * scale, -FP8_MAX, FP8_MAX).to(fp8)
18
+ 3. Matmul via torch._scaled_mm (cuBLAS FP8 kernel, ~2x faster than bf16)
19
+ 4. Dequantize: _scaled_mm handles this internally using the inverse scales
20
+
21
+ The key insight: torch._scaled_mm and the float8 dtypes are PyTorch built-ins.
22
+ torchao is just orchestration around these primitives. We can call them directly.
23
+
24
+ FP8 dtype choice
25
+ ================
26
+ There are two FP8 formats. We use both, following the standard convention:
27
+ - float8_e4m3fn: 4-bit exponent, 3-bit mantissa, range [-448, 448]
28
+ Higher precision (more mantissa bits), used for input and weight.
29
+ - float8_e5m2: 5-bit exponent, 2-bit mantissa, range [-57344, 57344]
30
+ Wider range (more exponent bits), used for gradients which can be large.
31
+
32
+ torch._scaled_mm layout requirements
33
+ =====================================
34
+ The cuBLAS FP8 kernel requires specific memory layouts:
35
+ - First argument (A): must be row-major (contiguous)
36
+ - Second argument (B): must be column-major (B.t().contiguous().t())
37
+ If B is obtained by transposing a contiguous tensor (e.g. weight.t()), it is
38
+ already column-major — no copy needed. Otherwise we use _to_col_major().
39
+
40
+ How this differs from torchao's approach
41
+ ========================================
42
+ torchao uses a "tensor subclass" architecture: Float8TrainingTensor is a subclass
43
+ of torch.Tensor that bundles FP8 data + scale + metadata. It implements
44
+ __torch_dispatch__ with a dispatch table that intercepts every aten op (mm, t,
45
+ reshape, clone, ...) and handles it in FP8-aware fashion. When you call
46
+ output = input @ weight.T
47
+ the @ operator dispatches to aten.mm, which gets intercepted and routed to
48
+ torch._scaled_mm behind the scenes. This is ~2000 lines of code because you need
49
+ a handler for every tensor operation that might touch an FP8 tensor.
50
+
51
+ We take a simpler approach: a single autograd.Function (_Float8Matmul) that takes
52
+ full-precision inputs, quantizes to FP8 internally, calls _scaled_mm, and returns
53
+ full-precision outputs. Marked @allow_in_graph so torch.compile treats it as one
54
+ opaque node rather than trying to trace inside.
55
+
56
+ The trade-off is in how torch.compile sees the two approaches:
57
+ - torchao: compile decomposes the tensor subclass (via __tensor_flatten__) and
58
+ sees every individual op (amax, scale, cast, _scaled_mm) as separate graph
59
+ nodes. Inductor can fuse these with surrounding operations (e.g. fuse the
60
+ amax computation with the preceding layer's activation function).
61
+ - ours: compile sees a single opaque call. It can optimize everything around
62
+ the FP8 linear (attention, norms, etc.) but cannot fuse across the boundary.
63
+
64
+ Both call the exact same cuBLAS _scaled_mm kernel — the GPU matmul is identical.
65
+ The difference is only in the "glue" ops (amax, scale, cast) which are tiny
66
+ compared to the matmul. In practice this means our version is slightly faster
67
+ (less compilation overhead, no tensor subclass dispatch cost) but can produce
68
+ subtly different floating-point rounding paths under torch.compile, since Inductor
69
+ generates a different graph. Numerics are bitwise identical in eager mode.
70
+ """
71
+
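
A standalone illustration of the tensorwise recipe described above (assumes an FP8-capable GPU such as H100, and toy shapes divisible by 16):

    import torch

    fp8 = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8).max
    x = torch.randn(64, 128, device="cuda")    # activations [B, K]
    w = torch.randn(256, 128, device="cuda")   # weight      [N, K]
    x_scale = fp8_max / x.abs().max().clamp(min=1e-12)
    w_scale = fp8_max / w.abs().max().clamp(min=1e-12)
    x_fp8 = (x * x_scale).clamp(-fp8_max, fp8_max).to(fp8)   # row-major first operand
    w_fp8 = (w * w_scale).clamp(-fp8_max, fp8_max).to(fp8)
    out = torch._scaled_mm(x_fp8, w_fp8.t(),                 # .t() of the contiguous weight is column-major
                           scale_a=x_scale.reciprocal(), scale_b=w_scale.reciprocal(),
                           out_dtype=torch.float32)
    print(out.shape)  # torch.Size([64, 256]), approximately x @ w.T
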
72
+ import torch
73
+ import torch.nn as nn
74
+
75
+ from nanochat.common import COMPUTE_DTYPE
76
+
77
+ # Avoid division by zero when computing scale from an all-zeros tensor
78
+ EPS = 1e-12
79
+
80
+
81
+ @torch.no_grad()
82
+ def _to_fp8(x, fp8_dtype):
83
+ """Dynamically quantize a tensor to FP8 using tensorwise scaling.
84
+
85
+ "Tensorwise" means one scalar scale for the entire tensor (as opposed to
86
+ "rowwise" which computes a separate scale per row). Tensorwise is faster
87
+ because cuBLAS handles the scaling; rowwise needs the CUTLASS kernel.
88
+
89
+ Returns (fp8_data, inverse_scale) for use with torch._scaled_mm.
90
+ """
91
+ fp8_max = torch.finfo(fp8_dtype).max
92
+ # Compute the max absolute value across the entire tensor
93
+ amax = x.float().abs().max()
94
+ # Scale maps [0, amax] -> [0, fp8_max]. Use float64 for the division to
95
+ # ensure consistent numerics between torch.compile and eager mode.
96
+ # (torchao does the same upcast — without it, compile/eager can diverge)
97
+ scale = fp8_max / amax.double().clamp(min=EPS)
98
+ scale = scale.float()
99
+ # Quantize: scale into FP8 range, saturate (clamp prevents overflow when
100
+ # casting — PyTorch's default is to wrap, not saturate), then cast to FP8
101
+ x_scaled = x.float() * scale
102
+ x_clamped = x_scaled.clamp(-fp8_max, fp8_max)
103
+ x_fp8 = x_clamped.to(fp8_dtype)
104
+ # _scaled_mm expects the *inverse* of our scale (it multiplies by this to
105
+ # convert FP8 values back to the original range during the matmul)
106
+ inv_scale = scale.reciprocal()
107
+ return x_fp8, inv_scale
108
+
109
+
110
+ def _to_col_major(x):
111
+ """Rearrange a 2D tensor's memory to column-major layout.
112
+
113
+ torch._scaled_mm requires its second operand in column-major layout.
114
+ The trick: transpose -> contiguous (forces a copy in transposed order)
115
+ -> transpose back. The result has the same logical shape but column-major
116
+ strides, e.g. a [M, N] tensor gets strides (1, M) instead of (N, 1).
117
+ """
118
+ return x.t().contiguous().t()
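Illustrative check of the layout trick (sketch only): the logical shape stays the same, only the strides flip from row-major to column-major.

import torch
x = torch.randn(4, 6)
print(x.shape, x.stride())   # torch.Size([4, 6]) (6, 1)  row-major
y = x.t().contiguous().t()
print(y.shape, y.stride())   # torch.Size([4, 6]) (1, 4)  column-major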
119
+
120
+
121
+ # allow_in_graph tells torch.compile to treat this as an opaque operation —
122
+ # dynamo won't try to decompose it into smaller ops. See the module docstring
123
+ # for how this differs from torchao's tensor subclass approach.
124
+ @torch._dynamo.allow_in_graph
125
+ class _Float8Matmul(torch.autograd.Function):
126
+ """Custom autograd for the three FP8 GEMMs of a Linear layer.
127
+
128
+ The forward quantizes input and weight to FP8 and saves
129
+ the quantized tensors + scales for backward.
130
+ """
131
+
132
+ @staticmethod
133
+ def forward(ctx, input_2d, weight):
134
+ # Quantize both operands to e4m3 (higher precision format)
135
+ input_fp8, input_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
136
+ weight_fp8, weight_inv = _to_fp8(weight, torch.float8_e4m3fn)
137
+ ctx.save_for_backward(input_fp8, input_inv, weight_fp8, weight_inv)
138
+
139
+ # output = input @ weight.T
140
+ # input_fp8 is [B, K] contiguous = row-major (good for first arg)
141
+ # weight_fp8 is [N, K] contiguous, so weight_fp8.t() is [K, N] with
142
+ # strides (1, K) = column-major (good for second arg, no copy needed!)
143
+ output = torch._scaled_mm(
144
+ input_fp8,
145
+ weight_fp8.t(),
146
+ scale_a=input_inv,
147
+ scale_b=weight_inv,
148
+ out_dtype=input_2d.dtype,
149
+ # use_fast_accum=True accumulates the dot products in lower precision.
150
+ # Slightly less accurate but measurably faster. Standard practice for
151
+ # the forward pass; we use False in backward for more precise gradients.
152
+ use_fast_accum=True,
153
+ )
154
+ return output
155
+
156
+ @staticmethod
157
+ def backward(ctx, grad_output):
158
+ in_fp8, in_inv, w_fp8, w_inv = ctx.saved_tensors
159
+
160
+ # === GEMM 1: grad_input = grad_output @ weight ===
161
+ # Shapes: [B, N] @ [N, K] -> [B, K]
162
+ # Gradients use e5m2 (wider range), weights use e4m3 (higher precision)
163
+ go_fp8, go_inv = _to_fp8(grad_output, torch.float8_e5m2)
164
+ # go_fp8 is [B, N] contiguous = row-major, good for first arg
165
+ # w_fp8 is [N, K] contiguous = row-major, need column-major for second arg
166
+ w_col = _to_col_major(w_fp8)
167
+ grad_input = torch._scaled_mm(
168
+ go_fp8,
169
+ w_col,
170
+ scale_a=go_inv,
171
+ scale_b=w_inv,
172
+ out_dtype=grad_output.dtype,
173
+ use_fast_accum=False,
174
+ )
175
+
176
+ # === GEMM 2: grad_weight = grad_output.T @ input ===
177
+ # Shapes: [N, B] @ [B, K] -> [N, K]
178
+ # go_fp8 is [B, N] contiguous, we need go.T = [N, B] as first arg.
179
+ # Transposing gives column-major, but first arg needs row-major,
180
+ # so we must call .contiguous() to physically rearrange the memory.
181
+ go_T = go_fp8.t().contiguous() # [N, B] row-major
182
+ in_col = _to_col_major(in_fp8) # [B, K] column-major
183
+ grad_weight = torch._scaled_mm(
184
+ go_T,
185
+ in_col,
186
+ scale_a=go_inv,
187
+ scale_b=in_inv,
188
+ out_dtype=grad_output.dtype,
189
+ use_fast_accum=False,
190
+ )
191
+
192
+ return grad_input, grad_weight
193
+
194
+
195
+ class Float8Linear(nn.Linear):
196
+ """Drop-in nn.Linear replacement that does FP8 compute.
197
+
198
+ Weights and biases remain in their original precision (e.g. fp32/bf16).
199
+ Only the matmul is performed in FP8 via the _Float8Matmul autograd function.
200
+ """
201
+
202
+ def forward(self, input):
203
+ # Cast input to COMPUTE_DTYPE (typically bf16) since _scaled_mm expects
204
+ # reduced precision input, and we no longer rely on autocast to do this.
205
+ input = input.to(COMPUTE_DTYPE)
206
+ # _scaled_mm only works on 2D tensors, so flatten batch dimensions
207
+ orig_shape = input.shape
208
+ input_2d = input.reshape(-1, orig_shape[-1])
209
+ output = _Float8Matmul.apply(input_2d, self.weight)
210
+ output = output.reshape(*orig_shape[:-1], output.shape[-1])
211
+ if self.bias is not None:
212
+ output = output + self.bias.to(output.dtype)
213
+ return output
214
+
215
+ @classmethod
216
+ def from_float(cls, mod):
217
+ """Create Float8Linear from nn.Linear, sharing the same weight and bias.
218
+
219
+ Uses meta device to avoid allocating a temporary weight tensor — we
220
+ create the module shell on meta (shapes/dtypes only, no memory), then
221
+ point .weight and .bias to the original module's parameters.
222
+ """
223
+ with torch.device("meta"):
224
+ new_mod = cls(mod.in_features, mod.out_features, bias=False)
225
+ new_mod.weight = mod.weight
226
+ new_mod.bias = mod.bias
227
+ return new_mod
228
+
229
+
230
+ class Float8LinearConfig:
231
+ """Minimal config matching torchao's API. Only tensorwise recipe is supported."""
232
+
233
+ @staticmethod
234
+ def from_recipe_name(recipe_name):
235
+ if recipe_name != "tensorwise":
236
+ raise ValueError(
237
+ f"Only 'tensorwise' recipe is supported, got '{recipe_name}'. "
238
+ f"Rowwise/axiswise recipes require the full torchao library."
239
+ )
240
+ return Float8LinearConfig()
241
+
242
+
243
+ def convert_to_float8_training(module, *, config=None, module_filter_fn=None):
244
+ """Replace nn.Linear layers with Float8Linear throughout a module.
245
+
246
+ Walks the module tree in post-order (children before parents) and swaps
247
+ each nn.Linear that passes the optional filter. The new Float8Linear shares
248
+ the original weight and bias tensors — no copies, no extra memory.
249
+
250
+ Args:
251
+ module: Root module to convert.
252
+ config: Float8LinearConfig (accepted for API compat, only tensorwise supported).
253
+ module_filter_fn: Optional filter(module, fqn) -> bool. Only matching Linears
254
+ are converted. Common use: skip layers with dims not divisible by 16
255
+ (hardware requirement for FP8 matmuls on H100).
256
+ """
257
+ def _convert(mod, prefix=""):
258
+ for name, child in mod.named_children():
259
+ fqn = f"{prefix}.{name}" if prefix else name
260
+ _convert(child, fqn)
261
+ if isinstance(child, nn.Linear) and not isinstance(child, Float8Linear):
262
+ if module_filter_fn is None or module_filter_fn(child, fqn):
263
+ setattr(mod, name, Float8Linear.from_float(child))
264
+
265
+ _convert(module)
266
+ return module
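A hedged usage sketch of the converter (the toy model below is invented for illustration): swap every nn.Linear whose dimensions are FP8-friendly and leave the rest untouched.

import torch.nn as nn
from nanochat.fp8 import convert_to_float8_training

model = nn.Sequential(nn.Linear(768, 3072), nn.ReLU(), nn.Linear(3072, 768))

def fp8_filter(module, fqn):
    # _scaled_mm wants dims divisible by 16; skip anything else
    return module.in_features % 16 == 0 and module.out_features % 16 == 0

convert_to_float8_training(model, module_filter_fn=fp8_filter)
print(model)  # the Linear layers are now reported as Float8Linear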
nanochat/gpt.py ADDED
@@ -0,0 +1,507 @@
1
+ """
2
+ GPT model (rewrite, a lot simpler)
3
+ Notable features:
4
+ - rotary embeddings (and no positional embeddings)
5
+ - QK norm
6
+ - untied weights for token embedding and lm_head
7
+ - relu^2 activation in MLP
8
+ - norm after token embedding
9
+ - no learnable params in rmsnorm
10
+ - no bias in linear layers
11
+ - Group-Query Attention (GQA) support for more efficient inference
12
+ - Flash Attention 3 integration
13
+ """
14
+
15
+ from functools import partial
16
+ from dataclasses import dataclass
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+ from nanochat.common import get_dist_info, print0, COMPUTE_DTYPE
23
+ from nanochat.optim import MuonAdamW, DistMuonAdamW
24
+
25
+ # Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
26
+ from nanochat.flash_attention import flash_attn
27
+
28
+ @dataclass
29
+ class GPTConfig:
30
+ sequence_len: int = 2048
31
+ vocab_size: int = 32768
32
+ n_layer: int = 12
33
+ n_head: int = 6 # number of query heads
34
+ n_kv_head: int = 6 # number of key/value heads (GQA)
35
+ n_embd: int = 768
36
+ # Sliding window attention pattern string, tiled across layers. Final layer always L.
37
+ # Characters: L=long (full context), S=short (quarter context)
38
+ # Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
39
+ window_pattern: str = "SSSL"
40
+
41
+
42
+ def norm(x):
43
+ return F.rms_norm(x, (x.size(-1),)) # note that this will run in bf16, seems ok
44
+
45
+ class Linear(nn.Linear):
46
+ """nn.Linear that casts weights to match input dtype in forward.
47
+ Replaces autocast: master weights stay fp32 for optimizer precision,
48
+ but matmuls run in the activation dtype (typically bf16 from embeddings)."""
49
+ def forward(self, x):
50
+ return F.linear(x, self.weight.to(dtype=x.dtype))
51
+
52
+
53
+ def has_ve(layer_idx, n_layer):
54
+ """Returns True if GPT layer should have Value Embedding (alternating, last layer always included)."""
55
+ return layer_idx % 2 == (n_layer - 1) % 2
56
+
57
+ def apply_rotary_emb(x, cos, sin):
58
+ assert x.ndim == 4 # multihead attention
59
+ d = x.shape[3] // 2
60
+ x1, x2 = x[..., :d], x[..., d:] # split up last dim into two halves
61
+ y1 = x1 * cos + x2 * sin # rotate pairs of dims
62
+ y2 = x1 * (-sin) + x2 * cos
63
+ return torch.cat([y1, y2], 3)
64
+
65
+ class CausalSelfAttention(nn.Module):
66
+ def __init__(self, config, layer_idx):
67
+ super().__init__()
68
+ self.layer_idx = layer_idx
69
+ self.n_head = config.n_head
70
+ self.n_kv_head = config.n_kv_head
71
+ self.n_embd = config.n_embd
72
+ self.head_dim = self.n_embd // self.n_head
73
+ assert self.n_embd % self.n_head == 0
74
+ assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
75
+ self.c_q = Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
76
+ self.c_k = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
77
+ self.c_v = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
78
+ self.c_proj = Linear(self.n_embd, self.n_embd, bias=False)
79
+ self.ve_gate_channels = 12
80
+ self.ve_gate = Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
81
+
82
+ def forward(self, x, ve, cos_sin, window_size, kv_cache):
83
+ B, T, C = x.size()
84
+
85
+ # Project the input to get queries, keys, and values
86
+ # Shape: (B, T, H, D) - FA3's native layout, no transpose needed!
87
+ q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
88
+ k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
89
+ v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
90
+
91
+ # Value residual (ResFormer): mix in value embedding with input-dependent gate per head
92
+ if ve is not None:
93
+ ve = ve.view(B, T, self.n_kv_head, self.head_dim)
94
+ gate = 3 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels])) # (B, T, n_kv_head), range (0, 3)
95
+ v = v + gate.unsqueeze(-1) * ve
96
+
97
+ # Apply Rotary Embeddings to queries and keys to get relative positional encoding
98
+ cos, sin = cos_sin
99
+ q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
100
+ q, k = norm(q), norm(k) # QK norm
101
+ q = q * 1.2 # sharper attention (split scale between Q and K), TODO think through better
102
+ k = k * 1.2
103
+
104
+ # Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere)
105
+ # window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
106
+ if kv_cache is None:
107
+ # Training: causal attention with optional sliding window
108
+ y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
109
+ else:
110
+ # Inference: use flash_attn_with_kvcache which handles cache management
111
+ k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
112
+ y = flash_attn.flash_attn_with_kvcache(
113
+ q, k_cache, v_cache,
114
+ k=k, v=v,
115
+ cache_seqlens=kv_cache.cache_seqlens,
116
+ causal=True,
117
+ window_size=window_size,
118
+ )
119
+ # Advance position after last layer processes
120
+ if self.layer_idx == kv_cache.n_layers - 1:
121
+ kv_cache.advance(T)
122
+
123
+ # Re-assemble the heads and project back to residual stream
124
+ y = y.contiguous().view(B, T, -1)
125
+ y = self.c_proj(y)
126
+ return y
127
+
128
+
129
+ class MLP(nn.Module):
130
+ def __init__(self, config):
131
+ super().__init__()
132
+ self.c_fc = Linear(config.n_embd, 4 * config.n_embd, bias=False)
133
+ self.c_proj = Linear(4 * config.n_embd, config.n_embd, bias=False)
134
+
135
+ def forward(self, x):
136
+ x = self.c_fc(x)
137
+ x = F.relu(x).square()
138
+ x = self.c_proj(x)
139
+ return x
140
+
141
+
142
+ class Block(nn.Module):
143
+ def __init__(self, config, layer_idx):
144
+ super().__init__()
145
+ self.attn = CausalSelfAttention(config, layer_idx)
146
+ self.mlp = MLP(config)
147
+
148
+ def forward(self, x, ve, cos_sin, window_size, kv_cache):
149
+ x = x + self.attn(norm(x), ve, cos_sin, window_size, kv_cache)
150
+ x = x + self.mlp(norm(x))
151
+ return x
152
+
153
+
154
+ class GPT(nn.Module):
155
+ def __init__(self, config, pad_vocab_size_to=64):
156
+ """
157
+ NOTE a major footgun: this __init__ function runs in meta device context (!!)
158
+ Therefore, any calculations inside here are shapes and dtypes only, no actual data.
159
+ => We actually initialize all data (parameters, buffers, etc.) in init_weights() instead.
160
+ """
161
+ super().__init__()
162
+ self.config = config
163
+ # Compute per-layer window sizes for sliding window attention
164
+ # window_size is (left, right) tuple: (-1, 0) for full context, (N, 0) for sliding window
165
+ self.window_sizes = self._compute_window_sizes(config)
166
+ # Pad vocab for efficiency (DDP, tensor cores). This is just an optimization - outputs are cropped in forward().
167
+ # https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings
168
+ padded_vocab_size = ((config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to
169
+ if padded_vocab_size != config.vocab_size:
170
+ print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} for efficiency")
171
+ self.transformer = nn.ModuleDict({
172
+ "wte": nn.Embedding(padded_vocab_size, config.n_embd),
173
+ "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
174
+ })
175
+ self.lm_head = Linear(config.n_embd, padded_vocab_size, bias=False)
176
+ # Per-layer learnable scalars (inspired by modded-nanogpt)
177
+ # resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
178
+ # x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
179
+ # Separate parameters so they can have different optimizer treatment
180
+ self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
181
+ self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
182
+ # Smear: mix previous token's embedding into current token (cheap bigram-like info)
183
+ self.smear_gate = Linear(24, 1, bias=False)
184
+ self.smear_lambda = nn.Parameter(torch.zeros(1))
185
+ # Backout: subtract cached mid-layer residual before final norm to remove low-level features
186
+ self.backout_lambda = nn.Parameter(0.2 * torch.ones(1))
187
+ # Value embeddings (ResFormer-style): alternating layers, last layer always included
188
+ head_dim = config.n_embd // config.n_head
189
+ kv_dim = config.n_kv_head * head_dim
190
+ self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)})
191
+ # To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
192
+ # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
193
+ # so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
194
+ # In the future we can dynamically grow the cache, for now it's fine.
195
+ self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
196
+ head_dim = config.n_embd // config.n_head
197
+ cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
198
+ self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
199
+ self.register_buffer("sin", sin, persistent=False)
200
+
201
+ @torch.no_grad()
202
+ def init_weights(self):
203
+ """
204
+ Initialize the full model in this one function for maximum clarity.
205
+
206
+ wte (embedding): normal, std=0.8
207
+ lm_head: normal, std=0.001
208
+ for each block:
209
+ attn.c_q: uniform, std=1/sqrt(n_embd)
210
+ attn.c_k: uniform, std=1/sqrt(n_embd)
211
+ attn.c_v: uniform, std=1/sqrt(n_embd)
212
+ attn.c_proj: zeros
213
+ mlp.c_fc: uniform, std=0.4/sqrt(n_embd)
214
+ mlp.c_proj: zeros
215
+ """
216
+
217
+ # Embedding and unembedding
218
+ torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=0.8)
219
+ torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
220
+
221
+ # Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal)
222
+ n_embd = self.config.n_embd
223
+ s = 3**0.5 * n_embd**-0.5 # sqrt(3) multiplier makes sure Uniform achieves the same std as Normal
224
+ for block in self.transformer.h:
225
+ torch.nn.init.uniform_(block.attn.c_q.weight, -s, s) # weights use Uniform to avoid outliers
226
+ torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
227
+ torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
228
+ torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
229
+ torch.nn.init.uniform_(block.mlp.c_fc.weight, -s * 0.4, s * 0.4) # 0.4x init scale for c_fc
230
+ torch.nn.init.zeros_(block.mlp.c_proj.weight)
231
+
232
+ # Per-layer scalars
233
+ # Per-layer resid init: stronger residual at early layers, weaker at deep layers
234
+ n_layer = self.config.n_layer
235
+ for i in range(n_layer):
236
+ self.resid_lambdas.data[i] = 1.15 - (0.10 * i / max(n_layer - 1, 1))
237
+ # Decaying x0 init: earlier layers get more input embedding blending
238
+ for i in range(n_layer):
239
+ self.x0_lambdas.data[i] = 0.20 - (0.15 * i / max(n_layer - 1, 1))
240
+
241
+ # Value embeddings (init like c_v: uniform with same std)
242
+ for ve in self.value_embeds.values():
243
+ torch.nn.init.uniform_(ve.weight, -s, s)
244
+
245
+ # Gate weights init with small positive values so gates start slightly above neutral
246
+ for block in self.transformer.h:
247
+ if block.attn.ve_gate is not None:
248
+ torch.nn.init.uniform_(block.attn.ve_gate.weight, 0.0, 0.02)
249
+
250
+ # Rotary embeddings
251
+ head_dim = self.config.n_embd // self.config.n_head
252
+ cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
253
+ self.cos, self.sin = cos, sin
254
+
255
+ # Cast embeddings to COMPUTE_DTYPE: optimizer can tolerate reduced-precision
256
+ # embeddings and it saves memory. Exception: fp16 requires fp32 embeddings
257
+ # because GradScaler cannot unscale fp16 gradients.
258
+ if COMPUTE_DTYPE != torch.float16:
259
+ self.transformer.wte.to(dtype=COMPUTE_DTYPE)
260
+ for ve in self.value_embeds.values():
261
+ ve.to(dtype=COMPUTE_DTYPE)
262
+
263
+ def _precompute_rotary_embeddings(self, seq_len, head_dim, base=100000, device=None):
264
+ # TODO: bump base theta more? e.g. 100K is more common more recently
265
+ # autodetect the device from model embeddings
266
+ if device is None:
267
+ device = self.transformer.wte.weight.device
268
+ # stride the channels
269
+ channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
270
+ inv_freq = 1.0 / (base ** (channel_range / head_dim))
271
+ # stride the time steps
272
+ t = torch.arange(seq_len, dtype=torch.float32, device=device)
273
+ # calculate the rotation frequencies at each (time, channel) pair
274
+ freqs = torch.outer(t, inv_freq)
275
+ cos, sin = freqs.cos(), freqs.sin()
276
+ cos, sin = cos.to(COMPUTE_DTYPE), sin.to(COMPUTE_DTYPE)
277
+ cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
278
+ return cos, sin
279
+
280
+ def _compute_window_sizes(self, config):
281
+ """
282
+ Compute per-layer window sizes for sliding window attention.
283
+
284
+ Returns list of (left, right) tuples for FA3's window_size parameter:
285
+ - left: how many tokens before current position to attend to (-1 = unlimited)
286
+ - right: how many tokens after current position to attend to (0 for causal)
287
+
288
+ Pattern string is tiled across layers. Final layer always gets L (full context).
289
+ Characters: L=long (full context), S=short (quarter context)
290
+ """
291
+ pattern = config.window_pattern.upper()
292
+ assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L."
293
+ # Map characters to window sizes
294
+ long_window = config.sequence_len
295
+ short_window = -(-long_window // 4 // 128) * 128 # quarter context, rounded up to the FA3 tile size of 128 (2048 -> 512)
296
+ char_to_window = {
297
+ "L": (long_window, 0),
298
+ "S": (short_window, 0),
299
+ }
300
+ # Tile pattern across layers
301
+ window_sizes = []
302
+ for layer_idx in range(config.n_layer):
303
+ char = pattern[layer_idx % len(pattern)]
304
+ window_sizes.append(char_to_window[char])
305
+ # Final layer always gets full context
306
+ window_sizes[-1] = (long_window, 0)
307
+ return window_sizes
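A stand-alone re-derivation of the tiling (illustrative sketch that mirrors the method above): with the default window_pattern="SSSL", sequence_len=2048 and n_layer=12, the short window comes out to 512 and every fourth layer gets full context.

pattern, n_layer, seq_len = "SSSL", 12, 2048
short = -(-seq_len // 4 // 128) * 128   # ceil(2048 / 4 / 128) * 128 = 512
windows = [(seq_len, 0) if pattern[i % len(pattern)] == "L" else (short, 0) for i in range(n_layer)]
windows[-1] = (seq_len, 0)              # final layer always full context
print(windows)                          # [(512, 0), (512, 0), (512, 0), (2048, 0)] repeated 3 times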
308
+
309
+ def get_device(self):
310
+ return self.transformer.wte.weight.device
311
+
312
+ def estimate_flops(self):
313
+ """
314
+ Return the estimated FLOPs per token for the model (forward + backward).
315
+ Each matmul weight parameter contributes 2 FLOPs (multiply *, accumulate +) in forward, and 2X that in backward => 2+4=6.
316
+ Cleanest explanation of this: https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4
317
+ On top of that, 12 * h * q * effective_seq_len accounts for key @ query matmul flops inside attention.
318
+ With sliding windows, effective_seq_len varies per layer (capped by window size).
319
+ Ref: https://arxiv.org/abs/2204.02311 (PaLM paper).
320
+ This is ~1% off from the exact formulas of Chinchilla paper, the difference is:
321
+ - Chinchilla counts the embedding layer as flops (? weird, it's just a lookup => we ignore)
322
+ - Chinchilla counts exp/sum/divide in attention softmax as flops (a little sus and very tiny => we ignore)
323
+ """
324
+ nparams = sum(p.numel() for p in self.parameters())
325
+ # Exclude non-matmul params: embeddings and per-layer scalars
326
+ value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
327
+ nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
328
+ self.resid_lambdas.numel() + self.x0_lambdas.numel() +
329
+ self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel())
330
+ h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
331
+ # Sum attention FLOPs per layer, accounting for sliding window
332
+ attn_flops = 0
333
+ for window_size in self.window_sizes:
334
+ window = window_size[0] # (left, right) tuple, we use left
335
+ effective_seq = t if window < 0 else min(window, t)
336
+ attn_flops += 12 * h * q * effective_seq
337
+ num_flops_per_token = 6 * (nparams - nparams_exclude) + attn_flops
338
+ return num_flops_per_token
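A rough numeric illustration of the estimate (the parameter count is made up): for 100M matmul params with 6 heads of dim 128 over 12 layers and the "SSSL" windows above, the attention term adds roughly 1e8 on top of the 6N term.

n_matmul_params = 100e6
h, q = 6, 128                                   # n_head, head_dim for the 768-dim config
attn = sum(12 * h * q * eff for eff in [512, 512, 512, 2048] * 3)
print(6 * n_matmul_params + attn)               # ~6.99e8 FLOPs per token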
339
+
340
+ def num_scaling_params(self):
341
+ """
342
+ Return detailed parameter counts for scaling law analysis.
343
+ Different papers use different conventions:
344
+ - Kaplan et al. excluded embedding parameters
345
+ - Chinchilla included all parameters
346
+ Ref: https://arxiv.org/abs/2203.15556 (Chinchilla paper)
347
+ Ref: https://arxiv.org/abs/2001.08361 (Kaplan et al. original scaling laws paper)
348
+
349
+ Returns a dict with counts for each parameter group, so downstream analysis
350
+ can experiment with which combination gives the cleanest scaling laws.
351
+ """
352
+ # Count each group separately (mirrors the grouping in setup_optimizers)
353
+ wte = sum(p.numel() for p in self.transformer.wte.parameters())
354
+ value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
355
+ lm_head = sum(p.numel() for p in self.lm_head.parameters())
356
+ transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
357
+ scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel()
358
+ total = wte + value_embeds + lm_head + transformer_matrices + scalars
359
+ assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
360
+ return {
361
+ 'wte': wte,
362
+ 'value_embeds': value_embeds,
363
+ 'lm_head': lm_head,
364
+ 'transformer_matrices': transformer_matrices,
365
+ 'scalars': scalars,
366
+ 'total': total,
367
+ }
368
+
369
+ def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5):
370
+ model_dim = self.config.n_embd
371
+ ddp, rank, local_rank, world_size = get_dist_info()
372
+
373
+ # Separate out all parameters into groups
374
+ matrix_params = list(self.transformer.h.parameters())
375
+ value_embeds_params = list(self.value_embeds.parameters())
376
+ embedding_params = list(self.transformer.wte.parameters())
377
+ lm_head_params = list(self.lm_head.parameters())
378
+ resid_params = [self.resid_lambdas]
379
+ x0_params = [self.x0_lambdas]
380
+ smear_params = [self.smear_gate.weight, self.smear_lambda, self.backout_lambda]
381
+ assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params)
382
+
383
+ # Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
384
+ dmodel_lr_scale = (model_dim / 768) ** -0.5
385
+ print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
386
+
387
+ # Build param_groups with all required fields explicit
388
+ param_groups = [
389
+ # AdamW groups (embeddings, lm_head, scalars)
390
+ dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=(0.8, 0.96), eps=1e-10, weight_decay=0.01),
391
+ dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.001),
392
+ dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale * 0.5, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.01),
393
+ dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.05),
394
+ dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
395
+ dict(kind='adamw', params=smear_params, lr=0.2, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0),
396
+ ]
397
+ # Muon groups (matrix params, grouped by shape for stacking)
398
+ for shape in sorted({p.shape for p in matrix_params}):
399
+ group_params = [p for p in matrix_params if p.shape == shape]
400
+ param_groups.append(dict(
401
+ kind='muon', params=group_params, lr=matrix_lr,
402
+ momentum=0.95, ns_steps=5, beta2=0.9, weight_decay=weight_decay,
403
+ ))
404
+
405
+ Factory = DistMuonAdamW if ddp else MuonAdamW
406
+ optimizer = Factory(param_groups)
407
+ for group in optimizer.param_groups:
408
+ group["initial_lr"] = group["lr"]
409
+ return optimizer
410
+
411
+ def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
412
+ B, T = idx.size()
413
+
414
+ # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
415
+ assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
416
+ assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
417
+ assert self.cos.dtype == COMPUTE_DTYPE, f"Rotary embeddings must be in {COMPUTE_DTYPE}, got {self.cos.dtype}"
418
+ # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
419
+ T0 = 0 if kv_cache is None else kv_cache.get_pos()
420
+ cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
421
+
422
+ # Embed the tokens
423
+ x = self.transformer.wte(idx) # embed current token
424
+ x = x.to(COMPUTE_DTYPE) # ensure activations are in compute dtype (no-op usually, but active for fp16 code path)
425
+ x = norm(x)
426
+
427
+ # Smear: mix previous token's embedding into current position (cheap bigram info)
428
+ if kv_cache is None:
429
+ # Training / naive generate: full sequence available, use fast slice
430
+ assert T > 1, "Training forward pass should have T > 1"
431
+ gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, 1:, :24]))
432
+ x = torch.cat([x[:, :1], x[:, 1:] + gate * x[:, :-1]], dim=1)
433
+ else:
434
+ # KV cache inference: read prev embedding from cache, store current for next step
435
+ x_pre_smear = kv_cache.prev_embedding
436
+ kv_cache.prev_embedding = x[:, -1:, :]
437
+ if T > 1:
438
+ # Prefill: apply smear to positions 1+, same as training
439
+ gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, 1:, :24]))
440
+ x = torch.cat([x[:, :1], x[:, 1:] + gate * x[:, :-1]], dim=1)
441
+ elif x_pre_smear is not None:
442
+ # Decode: single token, use cached prev embedding
443
+ gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, :, :24]))
444
+ x = x + gate * x_pre_smear
445
+
446
+ # Forward the trunk of the Transformer
447
+ x0 = x # save initial normalized embedding for x0 residual
448
+ n_layer = self.config.n_layer
449
+ backout_layer = n_layer // 2 # cache at halfway point
450
+ x_backout = None
451
+ for i, block in enumerate(self.transformer.h):
452
+ x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
453
+ ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None
454
+ x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
455
+ if i == backout_layer:
456
+ x_backout = x
457
+ # Subtract mid-layer residual to remove low-level features before logit projection
458
+ if x_backout is not None:
459
+ x = x - self.backout_lambda.to(x.dtype) * x_backout
460
+ x = norm(x)
461
+
462
+ # Forward the lm_head (compute logits)
463
+ softcap = 15 # smoothly cap the logits to the range [-softcap, softcap]
464
+ logits = self.lm_head(x) # (B, T, padded_vocab_size) <- very big tensor, large amount of memory
465
+ logits = logits[..., :self.config.vocab_size] # slice to remove padding
466
+ logits = logits.float() # switch to fp32 for logit softcap and loss computation
467
+ logits = softcap * torch.tanh(logits / softcap) # squash the logits
468
+
469
+ if targets is not None:
470
+ # training: given the targets, compute and return the loss
471
+ # TODO experiment with chunked cross-entropy?
472
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
473
+ return loss
474
+ else:
475
+ # inference: just return the logits directly
476
+ return logits
477
+
478
+ @torch.inference_mode()
479
+ def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
480
+ """
481
+ Naive autoregressive streaming inference.
482
+ To make it super simple, let's assume:
483
+ - batch size is 1
484
+ - ids and the yielded tokens are simple Python lists and ints
485
+ """
486
+ assert isinstance(tokens, list)
487
+ device = self.get_device()
488
+ rng = None
489
+ if temperature > 0:
490
+ rng = torch.Generator(device=device)
491
+ rng.manual_seed(seed)
492
+ ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim
493
+ for _ in range(max_tokens):
494
+ logits = self.forward(ids) # (B, T, vocab_size)
495
+ logits = logits[:, -1, :] # (B, vocab_size)
496
+ if top_k is not None and top_k > 0:
497
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
498
+ logits[logits < v[:, [-1]]] = -float('Inf')
499
+ if temperature > 0:
500
+ logits = logits / temperature
501
+ probs = F.softmax(logits, dim=-1)
502
+ next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
503
+ else:
504
+ next_ids = torch.argmax(logits, dim=-1, keepdim=True)
505
+ ids = torch.cat((ids, next_ids), dim=1)
506
+ token = next_ids.item()
507
+ yield token
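A hedged usage sketch of the streaming generator (assumes a constructed, initialized model on some device; the prompt ids below are placeholders):

prompt = [1, 2, 3]                # placeholder token ids
out = []
for tok in model.generate(prompt, max_tokens=16, temperature=0.0):  # temperature 0 = greedy argmax
    out.append(tok)
print(out)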
nanochat/logo.svg ADDED
nanochat/loss_eval.py ADDED
@@ -0,0 +1,65 @@
1
+ """
2
+ A number of functions that help with evaluating a base model.
3
+ """
4
+ import math
5
+ import torch
6
+ import torch.distributed as dist
7
+
8
+ @torch.no_grad()
9
+ def evaluate_bpb(model, batches, steps, token_bytes):
10
+ """
11
+ Instead of the naive 'mean loss', this function returns the bits per byte (bpb),
12
+ which is a tokenization vocab size-independent metric, meaning you are still comparing
13
+ apples:apples if you change the vocab size. The way this works is that instead of just
14
+ calculating the average loss as usual, you calculate the sum loss, and independently
15
+ also the sum bytes (of all the target tokens), and divide. This normalizes the loss by
16
+ the number of bytes that the target tokens represent.
17
+
18
+ The added complexity is so that:
19
+ 1) All "normal" tokens are normalized by the length of the token in bytes
20
+ 2) No special tokens (e.g. <|bos|>) are included in the metric - they are masked out.
21
+ 3) No actively masked tokens (using ignore_index of e.g. -1) are included in the metric.
22
+
23
+ In addition to evaluate_loss, we need the token_bytes tensor:
24
+ It is a 1D tensor of shape (vocab_size,), indicating the number of bytes for
25
+ each token id, or 0 if the token is to not be counted (e.g. special tokens).
26
+ """
27
+ # record the losses
28
+ total_nats = torch.tensor(0.0, dtype=torch.float32, device=model.get_device())
29
+ total_bytes = torch.tensor(0, dtype=torch.int64, device=model.get_device())
30
+ batch_iter = iter(batches)
31
+ for _ in range(steps):
32
+ x, y = next(batch_iter)
33
+ loss2d = model(x, y, loss_reduction='none') # (B, T)
34
+ loss2d = loss2d.view(-1) # flatten
35
+ y = y.view(-1) # flatten
36
+ if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32
37
+ # slightly more complex code path if some target tokens are ignore_index (e.g. -1)
38
+ # any target token < 0 is to be ignored: do NOT index token_bytes with negatives
39
+ valid = y >= 0
40
+ y_safe = torch.where(valid, y, torch.zeros_like(y))
41
+ # map valid targets to their byte length; ignored targets contribute 0 bytes
42
+ num_bytes2d = torch.where(
43
+ valid,
44
+ token_bytes[y_safe],
45
+ torch.zeros_like(y, dtype=token_bytes.dtype)
46
+ )
47
+ total_nats += (loss2d * (num_bytes2d > 0)).sum()
48
+ total_bytes += num_bytes2d.sum()
49
+ else:
50
+ # fast path: no ignored targets, safe to index directly
51
+ num_bytes2d = token_bytes[y]
52
+ total_nats += (loss2d * (num_bytes2d > 0)).sum()
53
+ total_bytes += num_bytes2d.sum()
54
+ # sum reduce across all ranks
55
+ world_size = dist.get_world_size() if dist.is_initialized() else 1
56
+ if world_size > 1:
57
+ dist.all_reduce(total_nats, op=dist.ReduceOp.SUM)
58
+ dist.all_reduce(total_bytes, op=dist.ReduceOp.SUM)
59
+ # move both to cpu, calculate bpb and return
60
+ total_nats = total_nats.item()
61
+ total_bytes = total_bytes.item()
62
+ if total_bytes == 0:
63
+ return float('inf')
64
+ bpb = total_nats / (math.log(2) * total_bytes)
65
+ return bpb
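Back-of-envelope check of the formula (illustrative numbers): a model averaging 0.9 nats of loss per token on tokens that average 4 bytes lands around 0.32 bits per byte.

import math
total_nats = 0.9 * 1000            # 1000 tokens at 0.9 nats each
total_bytes = 4 * 1000             # the same tokens, 4 bytes each on average
print(total_nats / (math.log(2) * total_bytes))   # ~0.3246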
nanochat/optim.py ADDED
@@ -0,0 +1,533 @@
1
+ """
2
+ A nice and efficient mixed AdamW/Muon Combined Optimizer.
3
+ Usually the embeddings and scalars go into AdamW, and the matrix parameters go into Muon.
4
+ Two versions are provided (MuonAdamW, DistMuonAdamW), for single GPU and distributed.
5
+
6
+ Adapted from: https://github.com/KellerJordan/modded-nanogpt
7
+ Further contributions from @karpathy and @chrisjmccormick.
8
+ """
9
+
10
+ import torch
11
+ import torch.distributed as dist
12
+ from torch import Tensor
13
+
14
+ # -----------------------------------------------------------------------------
15
+ """
16
+ Good old AdamW optimizer, fused kernel.
17
+ https://arxiv.org/abs/1711.05101
18
+ """
19
+
20
+ @torch.compile(dynamic=False, fullgraph=True)
21
+ def adamw_step_fused(
22
+ p: Tensor, # (32768, 768) - parameter tensor
23
+ grad: Tensor, # (32768, 768) - gradient, same shape as p
24
+ exp_avg: Tensor, # (32768, 768) - first moment, same shape as p
25
+ exp_avg_sq: Tensor, # (32768, 768) - second moment, same shape as p
26
+ step_t: Tensor, # () - 0-D CPU tensor, step count
27
+ lr_t: Tensor, # () - 0-D CPU tensor, learning rate
28
+ beta1_t: Tensor, # () - 0-D CPU tensor, beta1
29
+ beta2_t: Tensor, # () - 0-D CPU tensor, beta2
30
+ eps_t: Tensor, # () - 0-D CPU tensor, epsilon
31
+ wd_t: Tensor, # () - 0-D CPU tensor, weight decay
32
+ ) -> None:
33
+ """
34
+ Fused AdamW step: weight_decay -> momentum_update -> bias_correction -> param_update
35
+ All in one compiled graph to eliminate Python overhead between ops.
36
+ The 0-D CPU tensors avoid recompilation when hyperparameter values change.
37
+ """
38
+ # Weight decay (decoupled, applied before the update)
39
+ p.mul_(1 - lr_t * wd_t)
40
+ # Update running averages (lerp_ is cleaner and fuses well)
41
+ exp_avg.lerp_(grad, 1 - beta1_t)
42
+ exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
43
+ # Bias corrections
44
+ bias1 = 1 - beta1_t ** step_t
45
+ bias2 = 1 - beta2_t ** step_t
46
+ # Compute update and apply
47
+ denom = (exp_avg_sq / bias2).sqrt() + eps_t
48
+ step_size = lr_t / bias1
49
+ p.add_(exp_avg / denom, alpha=-step_size)
50
+
51
+ # -----------------------------------------------------------------------------
52
+ """
53
+ Muon optimizer adapted and simplified from modded-nanogpt.
54
+ https://github.com/KellerJordan/modded-nanogpt
55
+
56
+ Background:
57
+ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
58
+ quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
59
+ of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
60
+ zero even beyond the point where the iteration no longer converges all the way to one everywhere
61
+ on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
62
+ where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
63
+ performance at all relative to UV^T, where USV^T = G is the SVD.
64
+
65
+ Here, an alternative to Newton-Schulz iteration with potentially better convergence properties:
66
+ Polar Express Sign Method for orthogonalization.
67
+ https://arxiv.org/pdf/2505.16932
68
+ by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
69
+
70
+ NorMuon variance reduction: per-neuron/column adaptive learning rate that normalizes
71
+ update scales after orthogonalization (Muon's output has non-uniform scales across neurons).
72
+ https://arxiv.org/pdf/2510.05491
73
+
74
+ Some of the changes in nanochat implementation:
75
+ - Uses a simpler, more general approach to parameter grouping and stacking
76
+ - Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step
77
+ - Makes no assumptions about model architecture (e.g. that attention weights are fused into QKVO format)
78
+ """
79
+
80
+ # Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
81
+ # From https://arxiv.org/pdf/2505.16932
82
+ polar_express_coeffs = [
83
+ (8.156554524902461, -22.48329292557795, 15.878769915207462),
84
+ (4.042929935166739, -2.808917465908714, 0.5000178451051316),
85
+ (3.8916678022926607, -2.772484153217685, 0.5060648178503393),
86
+ (3.285753657755655, -2.3681294933425376, 0.46449024233003106),
87
+ (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
88
+ ]
89
+
90
+ @torch.compile(dynamic=False, fullgraph=True)
91
+ def muon_step_fused(
92
+ stacked_grads: Tensor, # (12, 768, 3072) - stacked gradients
93
+ stacked_params: Tensor, # (12, 768, 3072) - stacked parameters
94
+ momentum_buffer: Tensor, # (12, 768, 3072) - first moment buffer
95
+ second_momentum_buffer: Tensor, # (12, 768, 1) or (12, 1, 3072) - factored second moment
96
+ momentum_t: Tensor, # () - 0-D CPU tensor, momentum coefficient
97
+ lr_t: Tensor, # () - 0-D CPU tensor, learning rate
98
+ wd_t: Tensor, # () - 0-D CPU tensor, weight decay
99
+ beta2_t: Tensor, # () - 0-D CPU tensor, beta2 for second moment
100
+ ns_steps: int, # 5 - number of Newton-Schulz/Polar Express iterations
101
+ red_dim: int, # -1 or -2 - reduction dimension for variance
102
+ ) -> None:
103
+ """
104
+ Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update
105
+ All in one compiled graph to eliminate Python overhead between ops.
106
+ Some of the constants are 0-D CPU tensors to avoid recompilation when values change.
107
+ """
108
+
109
+ # Nesterov momentum
110
+ momentum = momentum_t.to(stacked_grads.dtype)
111
+ momentum_buffer.lerp_(stacked_grads, 1 - momentum)
112
+ g = stacked_grads.lerp_(momentum_buffer, momentum)
113
+
114
+ # Polar express
115
+ X = g.bfloat16()
116
+ X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.01 + 1e-6)
117
+ if g.size(-2) > g.size(-1): # Tall matrix
118
+ for a, b, c in polar_express_coeffs[:ns_steps]:
119
+ A = X.mT @ X
120
+ B = b * A + c * (A @ A)
121
+ X = a * X + X @ B
122
+ else: # Wide matrix (original math)
123
+ for a, b, c in polar_express_coeffs[:ns_steps]:
124
+ A = X @ X.mT
125
+ B = b * A + c * (A @ A)
126
+ X = a * X + B @ X
127
+ g = X
128
+
129
+ # Variance reduction
130
+ beta2 = beta2_t.to(g.dtype)
131
+ v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
132
+ red_dim_size = g.size(red_dim)
133
+ v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
134
+ v_norm = v_norm_sq.sqrt()
135
+ second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
136
+ step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
137
+ scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
138
+ v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
139
+ final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
140
+ g = g * final_scale.to(g.dtype)
141
+
142
+ # Cautious weight decay + parameter update
143
+ lr = lr_t.to(g.dtype)
144
+ wd = wd_t.to(g.dtype)
145
+ mask = (g * stacked_params) >= 0
146
+ stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
147
+
148
+ # -----------------------------------------------------------------------------
149
+ # Single GPU version of the MuonAdamW optimizer.
150
+ # Used mostly for reference, debugging and testing.
151
+
152
+ class MuonAdamW(torch.optim.Optimizer):
153
+ """
154
+ Combined optimizer: Muon for 2D matrix params, AdamW for others, single GPU version.
155
+
156
+ AdamW - Fused AdamW optimizer step.
157
+
158
+ Muon - MomentUm Orthogonalized by Newton-schulz
159
+ https://kellerjordan.github.io/posts/muon/
160
+
161
+ Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
162
+ processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
163
+ matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
164
+ the advantage that it can be stably run in bfloat16 on the GPU.
165
+
166
+ Some warnings:
167
+ - The Muon optimizer should not be used for the embedding layer, the final fully connected layer,
168
+ or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
169
+ - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
170
+
171
+ Arguments:
172
+ param_groups: List of dicts, each containing:
173
+ - 'params': List of parameters
174
+ - 'kind': 'adamw' or 'muon'
175
+ - For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
176
+ - For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
177
+ """
178
+ def __init__(self, param_groups: list[dict]):
179
+ super().__init__(param_groups, defaults={})
180
+ # 0-D CPU tensors to avoid torch.compile recompilation when values change
181
+ # AdamW tensors
182
+ self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
183
+ self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
184
+ self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
185
+ self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
186
+ self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
187
+ self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
188
+ # Muon tensors
189
+ self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
190
+ self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
191
+ self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
192
+ self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
193
+
194
+ def _step_adamw(self, group: dict) -> None:
195
+ """
196
+ AdamW update for each param in the group individually.
197
+ Lazy init the state, fill in all 0-D tensors, call the fused kernel.
198
+ """
199
+ for p in group['params']:
200
+ if p.grad is None:
201
+ continue
202
+ grad = p.grad
203
+ state = self.state[p]
204
+
205
+ # State init
206
+ if not state:
207
+ state['step'] = 0
208
+ state['exp_avg'] = torch.zeros_like(p)
209
+ state['exp_avg_sq'] = torch.zeros_like(p)
210
+ exp_avg = state['exp_avg']
211
+ exp_avg_sq = state['exp_avg_sq']
212
+ state['step'] += 1
213
+
214
+ # Fill 0-D tensors with current values
215
+ self._adamw_step_t.fill_(state['step'])
216
+ self._adamw_lr_t.fill_(group['lr'])
217
+ self._adamw_beta1_t.fill_(group['betas'][0])
218
+ self._adamw_beta2_t.fill_(group['betas'][1])
219
+ self._adamw_eps_t.fill_(group['eps'])
220
+ self._adamw_wd_t.fill_(group['weight_decay'])
221
+
222
+ # Fused update: weight_decay -> momentum -> bias_correction -> param_update
223
+ adamw_step_fused(
224
+ p, grad, exp_avg, exp_avg_sq,
225
+ self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
226
+ self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
227
+ )
228
+
229
+ def _step_muon(self, group: dict) -> None:
230
+ """
231
+ Muon update for all params in the group (stacked for efficiency).
232
+ Lazy init the state, fill in all 0-D tensors, call the fused kernel.
233
+ """
234
+ params: list[Tensor] = group['params']
235
+ if not params:
236
+ return
237
+
238
+ # Get or create group-level buffers (stored in first param's state for convenience)
239
+ p = params[0]
240
+ state = self.state[p]
241
+ num_params = len(params)
242
+ shape, device, dtype = p.shape, p.device, p.dtype
243
+
244
+ # Momentum for every individual parameter
245
+ if "momentum_buffer" not in state:
246
+ state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
247
+ momentum_buffer = state["momentum_buffer"]
248
+
249
+ # Second momentum buffer is factored, either per-row or per-column
250
+ if "second_momentum_buffer" not in state:
251
+ state_shape = (num_params, shape[-2], 1) if shape[-2] >= shape[-1] else (num_params, 1, shape[-1])
252
+ state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
253
+ second_momentum_buffer = state["second_momentum_buffer"]
254
+ red_dim = -1 if shape[-2] >= shape[-1] else -2
255
+
256
+ # Stack grads and params (NOTE: this assumes all params have the same shape)
257
+ stacked_grads = torch.stack([p.grad for p in params])
258
+ stacked_params = torch.stack(params)
259
+
260
+ # Fill all the 0-D tensors with current values
261
+ self._muon_momentum_t.fill_(group["momentum"])
262
+ self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
263
+ self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
264
+ self._muon_wd_t.fill_(group["weight_decay"])
265
+
266
+ # Single fused kernel: momentum -> polar_express -> variance_reduction -> update
267
+ muon_step_fused(
268
+ stacked_grads,
269
+ stacked_params,
270
+ momentum_buffer,
271
+ second_momentum_buffer,
272
+ self._muon_momentum_t,
273
+ self._muon_lr_t,
274
+ self._muon_wd_t,
275
+ self._muon_beta2_t,
276
+ group["ns_steps"],
277
+ red_dim,
278
+ )
279
+
280
+ # Copy back to original params
281
+ torch._foreach_copy_(params, list(stacked_params.unbind(0)))
282
+
283
+ @torch.no_grad()
284
+ def step(self):
285
+ for group in self.param_groups:
286
+ if group['kind'] == 'adamw':
287
+ self._step_adamw(group)
288
+ elif group['kind'] == 'muon':
289
+ self._step_muon(group)
290
+ else:
291
+ raise ValueError(f"Unknown optimizer kind: {group['kind']}")
292
+
293
+ # -----------------------------------------------------------------------------
294
+ # Distributed version of the MuonAdamW optimizer.
295
+ # Used for training on multiple GPUs.
296
+
297
+ class DistMuonAdamW(torch.optim.Optimizer):
298
+ """
299
+ Combined distributed optimizer: Muon for 2D matrix params, AdamW for others.
300
+
301
+ See MuonAdamW for the algorithmic details of each optimizer. This class adds
302
+ distributed communication to enable multi-GPU training without PyTorch DDP.
303
+
304
+ Design Goals:
305
+ - Overlap communication with computation (async ops)
306
+ - Minimize memory by sharding optimizer states across ranks (ZeRO-2 style)
307
+ - Batch small tensors into single comm ops where possible
308
+
309
+ Communication Pattern (3-phase async):
310
+ We use a 3-phase structure to maximize overlap between communication and compute:
311
+
312
+ Phase 1: Launch all async reduce ops
313
+ - Kick off all reduce_scatter/all_reduce operations
314
+ - Don't wait - let them run in background while we continue
315
+
316
+ Phase 2: Wait for reduces, compute updates, launch gathers
317
+ - For each group: wait for its reduce, compute the update, launch gather
318
+ - By processing groups in order, earlier gathers run while later computes happen
319
+
320
+ Phase 3: Wait for gathers, copy back
321
+ - Wait for all gathers to complete
322
+ - Copy updated params back to original tensors (Muon only)
323
+
324
+ AdamW Communication (ZeRO-2 style):
325
+ - Small params (<1024 elements): all_reduce gradients, update full param on each rank.
326
+ Optimizer state is replicated but these params are tiny (scalars, biases).
327
+ - Large params: reduce_scatter gradients so each rank gets 1/N of the grad, update
328
+ only that slice, then all_gather the updated slices. Optimizer state (exp_avg,
329
+ exp_avg_sq) is sharded - each rank only stores state for its slice.
330
+ Requires param.shape[0] divisible by world_size.
331
+
332
+ Muon Communication (stacked + chunked):
333
+ - All params in a Muon group must have the same shape (caller's responsibility).
334
+ - Stack all K params into a single (K, *shape) tensor for efficient comm.
335
+ - Divide K params across N ranks: each rank "owns" ceil(K/N) params.
336
+ - reduce_scatter the stacked grads so each rank gets its chunk.
337
+ - Each rank computes Muon update only for params it owns.
338
+ - all_gather the updated params back to all ranks.
339
+ - Optimizer state (momentum_buffer, second_momentum_buffer) is sharded by chunk.
340
+ - Padding: if K doesn't divide evenly, we zero-pad to (ceil(K/N) * N) for comm,
341
+ then ignore the padding when copying back.
342
+
343
+ Buffer Reuse:
344
+ - For Muon, we allocate stacked_grads for reduce_scatter input, then reuse the
345
+ same buffer as the output for all_gather (stacked_params). This saves memory
346
+ since we don't need both buffers simultaneously.
347
+
348
+ Arguments:
349
+ param_groups: List of dicts, each containing:
350
+ - 'params': List of parameters
351
+ - 'kind': 'adamw' or 'muon'
352
+ - For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
353
+ - For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
354
+ """
355
+ def __init__(self, param_groups: list[dict]):
356
+ super().__init__(param_groups, defaults={})
357
+ # 0-D CPU tensors to avoid torch.compile recompilation when values change
358
+ self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
359
+ self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
360
+ self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
361
+ self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
362
+ self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
363
+ self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
364
+ self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
365
+ self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
366
+ self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
367
+ self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
368
+
369
+ def _reduce_adamw(self, group: dict, world_size: int) -> dict:
370
+ """Launch async reduce ops for AdamW group. Returns info dict with per-param infos."""
371
+ param_infos = {}
372
+ for p in group['params']:
373
+ grad = p.grad
374
+ if p.numel() < 1024:
375
+ # Small params: all_reduce (no scatter/gather needed)
376
+ future = dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
377
+ param_infos[p] = dict(future=future, grad_slice=grad, is_small=True)
378
+ else:
379
+ # Large params: reduce_scatter
380
+ assert grad.shape[0] % world_size == 0, f"AdamW reduce_scatter requires shape[0] ({grad.shape[0]}) divisible by world_size ({world_size})"
381
+ rank_size = grad.shape[0] // world_size
382
+ grad_slice = torch.empty_like(grad[:rank_size])
383
+ future = dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
384
+ param_infos[p] = dict(future=future, grad_slice=grad_slice, is_small=False)
385
+ return dict(param_infos=param_infos)
386
+
387
+ def _reduce_muon(self, group: dict, world_size: int) -> dict:
388
+ """Launch async reduce op for Muon group. Returns info dict."""
389
+ params = group['params']
390
+ chunk_size = (len(params) + world_size - 1) // world_size
391
+ padded_num_params = chunk_size * world_size
392
+ p = params[0]
393
+ shape, device, dtype = p.shape, p.device, p.dtype
394
+
395
+ # Stack grads and zero-pad to padded_num_params
396
+ grad_stack = torch.stack([p.grad for p in params])
397
+ stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device)
398
+ stacked_grads[:len(params)].copy_(grad_stack)
399
+ if len(params) < padded_num_params:
400
+ stacked_grads[len(params):].zero_()
401
+
402
+ # Reduce_scatter to get this rank's chunk
403
+ grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
404
+ future = dist.reduce_scatter_tensor(grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True).get_future()
405
+
406
+ return dict(future=future, grad_chunk=grad_chunk, stacked_grads=stacked_grads, chunk_size=chunk_size)
407
+
408
+ def _compute_adamw(self, group: dict, info: dict, gather_list: list, rank: int, world_size: int) -> None:
409
+ """Wait for reduce, compute AdamW updates, launch gathers for large params."""
410
+ param_infos = info['param_infos']
411
+ for p in group['params']:
412
+ pinfo = param_infos[p]
413
+ pinfo['future'].wait()
414
+ grad_slice = pinfo['grad_slice']
415
+ state = self.state[p]
416
+
417
+ # For small params, operate on full param; for large, operate on slice
418
+ if pinfo['is_small']:
419
+ p_slice = p
420
+ else:
421
+ rank_size = p.shape[0] // world_size
422
+ p_slice = p[rank * rank_size:(rank + 1) * rank_size]
423
+
424
+ # State init
425
+ if not state:
426
+ state['step'] = 0
427
+ state['exp_avg'] = torch.zeros_like(p_slice)
428
+ state['exp_avg_sq'] = torch.zeros_like(p_slice)
429
+ state['step'] += 1
430
+
431
+ # Fill 0-D tensors and run fused kernel
432
+ self._adamw_step_t.fill_(state['step'])
433
+ self._adamw_lr_t.fill_(group['lr'])
434
+ self._adamw_beta1_t.fill_(group['betas'][0])
435
+ self._adamw_beta2_t.fill_(group['betas'][1])
436
+ self._adamw_eps_t.fill_(group['eps'])
437
+ self._adamw_wd_t.fill_(group['weight_decay'])
438
+ adamw_step_fused(
439
+ p_slice, grad_slice, state['exp_avg'], state['exp_avg_sq'],
440
+ self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
441
+ self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
442
+ )
443
+
444
+ # Large params need all_gather
445
+ if not pinfo['is_small']:
446
+ future = dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()
447
+ gather_list.append(dict(future=future, params=None))
448
+
449
+ def _compute_muon(self, group: dict, info: dict, gather_list: list, rank: int) -> None:
450
+ """Wait for reduce, compute Muon updates, launch gather."""
451
+ info['future'].wait()
452
+ params = group['params']
453
+ chunk_size = info['chunk_size']
454
+ grad_chunk = info['grad_chunk']
455
+ p = params[0]
456
+ shape, device, dtype = p.shape, p.device, p.dtype
457
+
458
+ # How many params does this rank own?
459
+ start_idx = rank * chunk_size
460
+ num_owned = min(chunk_size, max(0, len(params) - start_idx))
461
+
462
+ # Get or create group-level state
463
+ state = self.state[p]
464
+ if "momentum_buffer" not in state:
465
+ state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
466
+ if "second_momentum_buffer" not in state:
467
+ state_shape = (chunk_size, shape[-2], 1) if shape[-2] >= shape[-1] else (chunk_size, 1, shape[-1])
468
+ state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
469
+ red_dim = -1 if shape[-2] >= shape[-1] else -2
470
+
471
+ # Build output buffer for all_gather
472
+ updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
473
+
474
+ if num_owned > 0:
475
+ owned_params = [params[start_idx + i] for i in range(num_owned)]
476
+ stacked_owned = torch.stack(owned_params)
477
+
478
+ # Fill 0-D tensors and run fused kernel
479
+ self._muon_momentum_t.fill_(group["momentum"])
480
+ self._muon_beta2_t.fill_(group["beta2"])
481
+ self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
482
+ self._muon_wd_t.fill_(group["weight_decay"])
483
+ muon_step_fused(
484
+ grad_chunk[:num_owned], stacked_owned,
485
+ state["momentum_buffer"][:num_owned], state["second_momentum_buffer"][:num_owned],
486
+ self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t, self._muon_beta2_t,
487
+ group["ns_steps"], red_dim,
488
+ )
489
+ updated_params[:num_owned].copy_(stacked_owned)
490
+
491
+ if num_owned < chunk_size:
492
+ updated_params[num_owned:].zero_()
493
+
494
+ # Reuse stacked_grads buffer for all_gather output
495
+ stacked_params = info["stacked_grads"]
496
+ future = dist.all_gather_into_tensor(stacked_params, updated_params, async_op=True).get_future()
497
+ gather_list.append(dict(future=future, stacked_params=stacked_params, params=params))
498
+
499
+ def _finish_gathers(self, gather_list: list) -> None:
500
+ """Wait for all gathers and copy Muon params back."""
501
+ for info in gather_list:
502
+ info["future"].wait()
503
+ if info["params"] is not None:
504
+ # Muon: copy from stacked buffer back to individual params
505
+ torch._foreach_copy_(info["params"], list(info["stacked_params"][:len(info["params"])].unbind(0)))
506
+
507
+ @torch.no_grad()
508
+ def step(self):
509
+ rank = dist.get_rank()
510
+ world_size = dist.get_world_size()
511
+
512
+ # Phase 1: launch all async reduce ops
513
+ reduce_infos: list[dict] = []
514
+ for group in self.param_groups:
515
+ if group['kind'] == 'adamw':
516
+ reduce_infos.append(self._reduce_adamw(group, world_size))
517
+ elif group['kind'] == 'muon':
518
+ reduce_infos.append(self._reduce_muon(group, world_size))
519
+ else:
520
+ raise ValueError(f"Unknown optimizer kind: {group['kind']}")
521
+
522
+ # Phase 2: wait for reduces, compute updates, launch gathers
523
+ gather_list: list[dict] = []
524
+ for group, info in zip(self.param_groups, reduce_infos):
525
+ if group['kind'] == 'adamw':
526
+ self._compute_adamw(group, info, gather_list, rank, world_size)
527
+ elif group['kind'] == 'muon':
528
+ self._compute_muon(group, info, gather_list, rank)
529
+ else:
530
+ raise ValueError(f"Unknown optimizer kind: {group['kind']}")
531
+
532
+ # Phase 3: wait for gathers, copy back
533
+ self._finish_gathers(gather_list)
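For reference, a hedged sketch (not part of the upload) of how the `param_groups` described in the optimizer's docstring might be assembled and stepped. The split of 2-D matrices to Muon and everything else to AdamW, and all hyperparameter values, are illustrative assumptions; `torch.distributed` must already be initialized before `step()` is called.

```python
# Hypothetical wiring sketch for DistributedCombinedOptimizer (illustrative only).
from collections import defaultdict

def build_param_groups(model):
    # Assumption for illustration: Muon handles the 2-D matrices, AdamW the rest.
    # Each Muon group is stacked into one tensor, so params are grouped by shape.
    by_shape = defaultdict(list)
    adamw_params = []
    for p in model.parameters():
        if p.ndim == 2:
            by_shape[tuple(p.shape)].append(p)
        else:
            adamw_params.append(p)
    groups = [dict(params=adamw_params, kind="adamw", lr=3e-4,
                   betas=(0.9, 0.95), eps=1e-10, weight_decay=0.0)]
    for same_shape in by_shape.values():
        groups.append(dict(params=same_shape, kind="muon", lr=0.02,
                           momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=0.0))
    return groups

# optimizer = DistributedCombinedOptimizer(build_param_groups(model))
# loss.backward()
# optimizer.step()   # phase 1: async reduces, phase 2: fused updates, phase 3: gathers
# optimizer.zero_grad(set_to_none=True)
```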
nanochat/report.py ADDED
@@ -0,0 +1,418 @@
1
+ """
2
+ Utilities for generating training report cards. Messier code than usual; will fix.
3
+ """
4
+
5
+ import os
6
+ import re
7
+ import shutil
8
+ import subprocess
9
+ import socket
10
+ import datetime
11
+ import platform
12
+ import psutil
13
+ import torch
14
+
15
+ def run_command(cmd):
16
+ """Run a shell command and return output, or None if it fails."""
17
+ try:
18
+ result = subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=5)
19
+ # Return stdout if we got output (even if some files in xargs failed)
20
+ if result.stdout.strip():
21
+ return result.stdout.strip()
22
+ if result.returncode == 0:
23
+ return ""
24
+ return None
25
+ except Exception:  # never let environment probing crash report generation
26
+ return None
27
+
28
+ def get_git_info():
29
+ """Get current git commit, branch, and dirty status."""
30
+ info = {}
31
+ info['commit'] = run_command("git rev-parse --short HEAD") or "unknown"
32
+ info['branch'] = run_command("git rev-parse --abbrev-ref HEAD") or "unknown"
33
+
34
+ # Check if repo is dirty (has uncommitted changes)
35
+ status = run_command("git status --porcelain")
36
+ info['dirty'] = bool(status) if status is not None else False
37
+
38
+ # Get commit message
39
+ info['message'] = run_command("git log -1 --pretty=%B") or ""
40
+ info['message'] = info['message'].split('\n')[0][:80] # First line, truncated
41
+
42
+ return info
43
+
44
+ def get_gpu_info():
45
+ """Get GPU information."""
46
+ if not torch.cuda.is_available():
47
+ return {"available": False}
48
+
49
+ num_devices = torch.cuda.device_count()
50
+ info = {
51
+ "available": True,
52
+ "count": num_devices,
53
+ "names": [],
54
+ "memory_gb": []
55
+ }
56
+
57
+ for i in range(num_devices):
58
+ props = torch.cuda.get_device_properties(i)
59
+ info["names"].append(props.name)
60
+ info["memory_gb"].append(props.total_memory / (1024**3))
61
+
62
+ # Get CUDA version
63
+ info["cuda_version"] = torch.version.cuda or "unknown"
64
+
65
+ return info
66
+
67
+ def get_system_info():
68
+ """Get system information."""
69
+ info = {}
70
+
71
+ # Basic system info
72
+ info['hostname'] = socket.gethostname()
73
+ info['platform'] = platform.system()
74
+ info['python_version'] = platform.python_version()
75
+ info['torch_version'] = torch.__version__
76
+
77
+ # CPU and memory
78
+ info['cpu_count'] = psutil.cpu_count(logical=False)
79
+ info['cpu_count_logical'] = psutil.cpu_count(logical=True)
80
+ info['memory_gb'] = psutil.virtual_memory().total / (1024**3)
81
+
82
+ # User and environment
83
+ info['user'] = os.environ.get('USER', 'unknown')
84
+ info['nanochat_base_dir'] = os.environ.get('NANOCHAT_BASE_DIR', 'out')
85
+ info['working_dir'] = os.getcwd()
86
+
87
+ return info
88
+
89
+ def estimate_cost(gpu_info, runtime_hours=None):
90
+ """Estimate training cost based on GPU type and runtime."""
91
+
92
+ # Rough pricing, from Lambda Cloud
93
+ default_rate = 2.0
94
+ gpu_hourly_rates = {
95
+ "H100": 3.00,
96
+ "A100": 1.79,
97
+ "V100": 0.55,
98
+ }
99
+
100
+ if not gpu_info.get("available"):
101
+ return None
102
+
103
+ # Try to identify GPU type from name
104
+ hourly_rate = None
105
+ gpu_name = gpu_info["names"][0] if gpu_info["names"] else "unknown"
106
+ for gpu_type, rate in gpu_hourly_rates.items():
107
+ if gpu_type in gpu_name:
108
+ hourly_rate = rate * gpu_info["count"]
109
+ break
110
+
111
+ if hourly_rate is None:
112
+ hourly_rate = default_rate * gpu_info["count"] # Default estimate
113
+
114
+ return {
115
+ "hourly_rate": hourly_rate,
116
+ "gpu_type": gpu_name,
117
+ "estimated_total": hourly_rate * runtime_hours if runtime_hours else None
118
+ }
119
+
120
+ def generate_header():
121
+ """Generate the header for a training report."""
122
+ timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
123
+
124
+ git_info = get_git_info()
125
+ gpu_info = get_gpu_info()
126
+ sys_info = get_system_info()
127
+ cost_info = estimate_cost(gpu_info)
128
+
129
+ header = f"""# nanochat training report
130
+
131
+ Generated: {timestamp}
132
+
133
+ ## Environment
134
+
135
+ ### Git Information
136
+ - Branch: {git_info['branch']}
137
+ - Commit: {git_info['commit']} {"(dirty)" if git_info['dirty'] else "(clean)"}
138
+ - Message: {git_info['message']}
139
+
140
+ ### Hardware
141
+ - Platform: {sys_info['platform']}
142
+ - CPUs: {sys_info['cpu_count']} cores ({sys_info['cpu_count_logical']} logical)
143
+ - Memory: {sys_info['memory_gb']:.1f} GB
144
+ """
145
+
146
+ if gpu_info.get("available"):
147
+ gpu_names = ", ".join(set(gpu_info["names"]))
148
+ total_vram = sum(gpu_info["memory_gb"])
149
+ header += f"""- GPUs: {gpu_info['count']}x {gpu_names}
150
+ - GPU Memory: {total_vram:.1f} GB total
151
+ - CUDA Version: {gpu_info['cuda_version']}
152
+ """
153
+ else:
154
+ header += "- GPUs: None available\n"
155
+
156
+ if cost_info and cost_info["hourly_rate"] > 0:
157
+ header += f"""- Hourly Rate: ${cost_info['hourly_rate']:.2f}/hour\n"""
158
+
159
+ header += f"""
160
+ ### Software
161
+ - Python: {sys_info['python_version']}
162
+ - PyTorch: {sys_info['torch_version']}
163
+
164
+ """
165
+
166
+ # bloat metrics: count lines/chars in git-tracked source files only
167
+ extensions = ['py', 'md', 'rs', 'html', 'toml', 'sh']
168
+ git_patterns = ' '.join(f"'*.{ext}'" for ext in extensions)
169
+ files_output = run_command(f"git ls-files -- {git_patterns}")
170
+ file_list = [f for f in (files_output or '').split('\n') if f]
171
+ num_files = len(file_list)
172
+ num_lines = 0
173
+ num_chars = 0
174
+ if num_files > 0:
175
+ wc_output = run_command(f"git ls-files -- {git_patterns} | xargs wc -lc 2>/dev/null")
176
+ if wc_output:
177
+ total_line = wc_output.strip().split('\n')[-1]
178
+ parts = total_line.split()
179
+ if len(parts) >= 2:
180
+ num_lines = int(parts[0])
181
+ num_chars = int(parts[1])
182
+ num_tokens = num_chars // 4 # assume approximately 4 chars per token
183
+
184
+ # count dependencies via uv.lock
185
+ uv_lock_lines = 0
186
+ if os.path.exists('uv.lock'):
187
+ with open('uv.lock', 'r', encoding='utf-8') as f:
188
+ uv_lock_lines = len(f.readlines())
189
+
190
+ header += f"""
191
+ ### Bloat
192
+ - Characters: {num_chars:,}
193
+ - Lines: {num_lines:,}
194
+ - Files: {num_files:,}
195
+ - Tokens (approx): {num_tokens:,}
196
+ - Dependencies (uv.lock lines): {uv_lock_lines:,}
197
+
198
+ """
199
+ return header
200
+
201
+ # -----------------------------------------------------------------------------
202
+
203
+ def slugify(text):
204
+ """Slugify a text string."""
205
+ return text.lower().replace(" ", "-")
206
+
207
+ # the expected files and their order
208
+ EXPECTED_FILES = [
209
+ "tokenizer-training.md",
210
+ "tokenizer-evaluation.md",
211
+ "base-model-training.md",
212
+ "base-model-loss.md",
213
+ "base-model-evaluation.md",
214
+ "chat-sft.md",
215
+ "chat-evaluation-sft.md",
216
+ "chat-rl.md",
217
+ "chat-evaluation-rl.md",
218
+ ]
219
+ # the metrics we're currently interested in
220
+ chat_metrics = ["ARC-Easy", "ARC-Challenge", "MMLU", "GSM8K", "HumanEval", "ChatCORE"]
221
+
222
+ def extract(section, keys):
223
+ """simple def to extract a single key from a section"""
224
+ if not isinstance(keys, list):
225
+ keys = [keys] # convenience
226
+ out = {}
227
+ for line in section.split("\n"):
228
+ for key in keys:
229
+ if key in line:
230
+ out[key] = line.split(":")[1].strip()
231
+ return out
232
+
233
+ def extract_timestamp(content, prefix):
234
+ """Extract timestamp from content with given prefix."""
235
+ for line in content.split('\n'):
236
+ if line.startswith(prefix):
237
+ time_str = line.split(":", 1)[1].strip()
238
+ try:
239
+ return datetime.datetime.strptime(time_str, "%Y-%m-%d %H:%M:%S")
240
+ except ValueError:  # malformed timestamp, ignore
241
+ pass
242
+ return None
243
+
244
+ class Report:
245
+ """Maintains a bunch of logs, generates a final markdown report."""
246
+
247
+ def __init__(self, report_dir):
248
+ os.makedirs(report_dir, exist_ok=True)
249
+ self.report_dir = report_dir
250
+
251
+ def log(self, section, data):
252
+ """Log a section of data to the report."""
253
+ slug = slugify(section)
254
+ file_name = f"{slug}.md"
255
+ file_path = os.path.join(self.report_dir, file_name)
256
+ with open(file_path, "w", encoding="utf-8") as f:
257
+ f.write(f"## {section}\n")
258
+ f.write(f"timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
259
+ for item in data:
260
+ if not item:
261
+ # skip falsy values like None or empty dict etc.
262
+ continue
263
+ if isinstance(item, str):
264
+ # directly write the string
265
+ f.write(item)
266
+ else:
267
+ # render a dict
268
+ for k, v in item.items():
269
+ if isinstance(v, float):
270
+ vstr = f"{v:.4f}"
271
+ elif isinstance(v, int) and v >= 10000:
272
+ vstr = f"{v:,.0f}"
273
+ else:
274
+ vstr = str(v)
275
+ f.write(f"- {k}: {vstr}\n")
276
+ f.write("\n")
277
+ return file_path
278
+
279
+ def generate(self):
280
+ """Generate the final report."""
281
+ report_dir = self.report_dir
282
+ report_file = os.path.join(report_dir, "report.md")
283
+ print(f"Generating report to {report_file}")
284
+ final_metrics = {} # the most important final metrics we'll add as table at the end
285
+ start_time = None
286
+ end_time = None
287
+ with open(report_file, "w", encoding="utf-8") as out_file:
288
+ # write the header first
289
+ header_file = os.path.join(report_dir, "header.md")
290
+ if os.path.exists(header_file):
291
+ with open(header_file, "r", encoding="utf-8") as f:
292
+ header_content = f.read()
293
+ out_file.write(header_content)
294
+ start_time = extract_timestamp(header_content, "Run started:")
295
+ # capture bloat data for summary later (the stuff after Bloat header and until \n\n)
296
+ bloat_data = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL)
297
+ bloat_data = bloat_data.group(1) if bloat_data else ""
298
+ else:
299
+ start_time = None # will cause us to not write the total wall clock time
300
+ bloat_data = "[bloat data missing]"
301
+ print(f"Warning: {header_file} does not exist. Did you forget to run `nanochat reset`?")
302
+ # process all the individual sections
303
+ for file_name in EXPECTED_FILES:
304
+ section_file = os.path.join(report_dir, file_name)
305
+ if not os.path.exists(section_file):
306
+ print(f"Warning: {section_file} does not exist, skipping")
307
+ continue
308
+ with open(section_file, "r", encoding="utf-8") as in_file:
309
+ section = in_file.read()
310
+ # Extract timestamp from this section (the last section's timestamp will "stick" as end_time)
311
+ if "rl" not in file_name:
312
+ # Skip RL sections for end_time calculation because RL is experimental
313
+ end_time = extract_timestamp(section, "timestamp:")
314
+ # extract the most important metrics from the sections
315
+ if file_name == "base-model-evaluation.md":
316
+ final_metrics["base"] = extract(section, "CORE")
317
+ if file_name == "chat-evaluation-sft.md":
318
+ final_metrics["sft"] = extract(section, chat_metrics)
319
+ if file_name == "chat-evaluation-rl.md":
320
+ final_metrics["rl"] = extract(section, "GSM8K") # RL only evals GSM8K
321
+ # append this section of the report
322
+ out_file.write(section)
323
+ out_file.write("\n")
324
+ # add the final metrics table
325
+ out_file.write("## Summary\n\n")
326
+ # Copy over the bloat metrics from the header
327
+ out_file.write(bloat_data)
328
+ out_file.write("\n\n")
329
+ # Collect all unique metric names
330
+ all_metrics = set()
331
+ for stage_metrics in final_metrics.values():
332
+ all_metrics.update(stage_metrics.keys())
333
+ # Custom ordering: CORE first, ChatCORE last, rest in middle
334
+ all_metrics = sorted(all_metrics, key=lambda x: (x != "CORE", x == "ChatCORE", x))
335
+ # Fixed column widths
336
+ stages = ["base", "sft", "rl"]
337
+ metric_width = 15
338
+ value_width = 8
339
+ # Write table header
340
+ header = f"| {'Metric'.ljust(metric_width)} |"
341
+ for stage in stages:
342
+ header += f" {stage.upper().ljust(value_width)} |"
343
+ out_file.write(header + "\n")
344
+ # Write separator
345
+ separator = f"|{'-' * (metric_width + 2)}|"
346
+ for stage in stages:
347
+ separator += f"{'-' * (value_width + 2)}|"
348
+ out_file.write(separator + "\n")
349
+ # Write table rows
350
+ for metric in all_metrics:
351
+ row = f"| {metric.ljust(metric_width)} |"
352
+ for stage in stages:
353
+ value = final_metrics.get(stage, {}).get(metric, "-")
354
+ row += f" {str(value).ljust(value_width)} |"
355
+ out_file.write(row + "\n")
356
+ out_file.write("\n")
357
+ # Calculate and write total wall clock time
358
+ if start_time and end_time:
359
+ duration = end_time - start_time
360
+ total_seconds = int(duration.total_seconds())
361
+ hours = total_seconds // 3600
362
+ minutes = (total_seconds % 3600) // 60
363
+ out_file.write(f"Total wall clock time: {hours}h{minutes}m\n")
364
+ else:
365
+ out_file.write("Total wall clock time: unknown\n")
366
+ # also cp the report.md file to current directory
367
+ print("Copying report.md to current directory for convenience")
368
+ shutil.copy(report_file, "report.md")
369
+ return report_file
370
+
371
+ def reset(self):
372
+ """Reset the report."""
373
+ # Remove section files
374
+ for file_name in EXPECTED_FILES:
375
+ file_path = os.path.join(self.report_dir, file_name)
376
+ if os.path.exists(file_path):
377
+ os.remove(file_path)
378
+ # Remove report.md if it exists
379
+ report_file = os.path.join(self.report_dir, "report.md")
380
+ if os.path.exists(report_file):
381
+ os.remove(report_file)
382
+ # Generate and write the header section with start timestamp
383
+ header_file = os.path.join(self.report_dir, "header.md")
384
+ header = generate_header()
385
+ start_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
386
+ with open(header_file, "w", encoding="utf-8") as f:
387
+ f.write(header)
388
+ f.write(f"Run started: {start_time}\n\n---\n\n")
389
+ print(f"Reset report and wrote header to {header_file}")
390
+
391
+ # -----------------------------------------------------------------------------
392
+ # nanochat-specific convenience functions
393
+
394
+ class DummyReport:
395
+ def log(self, *args, **kwargs):
396
+ pass
397
+ def reset(self, *args, **kwargs):
398
+ pass
399
+
400
+ def get_report():
401
+ # just for convenience, only rank 0 logs to report
402
+ from nanochat.common import get_base_dir, get_dist_info
403
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
404
+ if ddp_rank == 0:
405
+ report_dir = os.path.join(get_base_dir(), "report")
406
+ return Report(report_dir)
407
+ else:
408
+ return DummyReport()
409
+
410
+ if __name__ == "__main__":
411
+ import argparse
412
+ parser = argparse.ArgumentParser(description="Generate or reset nanochat training reports.")
413
+ parser.add_argument("command", nargs="?", default="generate", choices=["generate", "reset"], help="Operation to perform (default: generate)")
414
+ args = parser.parse_args()
415
+ if args.command == "generate":
416
+ get_report().generate()
417
+ elif args.command == "reset":
418
+ get_report().reset()
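A small, hedged usage sketch of the report flow above. The section name slugifies to `base-model-evaluation.md`, one of the `EXPECTED_FILES`; the metric values and step count are made up for illustration.

```python
# Illustrative only: how a training script might log into the report pipeline.
from nanochat.report import get_report

report = get_report()  # Report on rank 0, DummyReport on all other ranks
report.log(section="Base model evaluation", data=[
    "Results on the CORE benchmark:\n",
    {"CORE": 0.2219, "steps": 5000},  # floats get 4 decimals, large ints get commas
])
# Later, typically once at the very end of the run:
# report.generate()   # stitches header.md + the section files into report.md
```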
nanochat/tokenizer.py ADDED
@@ -0,0 +1,406 @@
1
+ """
2
+ BPE Tokenizer in the style of GPT-4.
3
+
4
+ Two implementations are available:
5
+ 1) HuggingFace Tokenizer that can do both training and inference but is really confusing
6
+ 2) Our own RustBPE Tokenizer for training and tiktoken for efficient inference
7
+ """
8
+
9
+ import os
10
+ import copy
11
+ from functools import lru_cache
12
+
13
+ SPECIAL_TOKENS = [
14
+ # every document begins with the Beginning of Sequence (BOS) token that delimits documents
15
+ "<|bos|>",
16
+ # tokens below are only used during finetuning to render Conversations into token ids
17
+ "<|user_start|>", # user messages
18
+ "<|user_end|>",
19
+ "<|assistant_start|>", # assistant messages
20
+ "<|assistant_end|>",
21
+ "<|python_start|>", # assistant invokes python REPL tool
22
+ "<|python_end|>",
23
+ "<|output_start|>", # python REPL outputs back to assistant
24
+ "<|output_end|>",
25
+ ]
26
+
27
+ # NOTE: this split pattern deviates from GPT-4 in that we use \p{N}{1,2} instead of \p{N}{1,3}
28
+ # I did this because I didn't want to "waste" too many tokens on numbers for smaller vocab sizes.
29
+ # I verified that 2 is the sweet spot for vocab size of 32K. 1 is a bit worse, 3 was worse still.
30
+ SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
31
+
32
+ # -----------------------------------------------------------------------------
33
+ # Generic GPT-4-style tokenizer based on HuggingFace Tokenizer
34
+ from tokenizers import Tokenizer as HFTokenizer
35
+ from tokenizers import pre_tokenizers, decoders, Regex
36
+ from tokenizers.models import BPE
37
+ from tokenizers.trainers import BpeTrainer
38
+
39
+ class HuggingFaceTokenizer:
40
+ """Light wrapper around HuggingFace Tokenizer for some utilities"""
41
+
42
+ def __init__(self, tokenizer):
43
+ self.tokenizer = tokenizer
44
+
45
+ @classmethod
46
+ def from_pretrained(cls, hf_path):
47
+ # init from a HuggingFace pretrained tokenizer (e.g. "gpt2")
48
+ tokenizer = HFTokenizer.from_pretrained(hf_path)
49
+ return cls(tokenizer)
50
+
51
+ @classmethod
52
+ def from_directory(cls, tokenizer_dir):
53
+ # init from a local directory on disk (e.g. "out/tokenizer")
54
+ tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
55
+ tokenizer = HFTokenizer.from_file(tokenizer_path)
56
+ return cls(tokenizer)
57
+
58
+ @classmethod
59
+ def train_from_iterator(cls, text_iterator, vocab_size):
60
+ # train from an iterator of text
61
+ # Configure the HuggingFace Tokenizer
62
+ tokenizer = HFTokenizer(BPE(
63
+ byte_fallback=True, # needed!
64
+ unk_token=None,
65
+ fuse_unk=False,
66
+ ))
67
+ # Normalizer: None
68
+ tokenizer.normalizer = None
69
+ # Pre-tokenizer: GPT-4 style
70
+ # the regex pattern used by GPT-4 to split text into groups before BPE
71
+ # NOTE: The pattern was changed from \p{N}{1,3} to \p{N}{1,2} because I suspect it is harmful to
72
+ # very small models and smaller vocab sizes, because it is a little bit wasteful in the token space.
73
+ # (but I haven't validated this! TODO)
74
+ gpt4_split_regex = Regex(SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!!
75
+ tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
76
+ pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False),
77
+ pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
78
+ ])
79
+ # Decoder: ByteLevel (it pairs together with the ByteLevel pre-tokenizer)
80
+ tokenizer.decoder = decoders.ByteLevel()
81
+ # Post-processor: None
82
+ tokenizer.post_processor = None
83
+ # Trainer: BPE
84
+ trainer = BpeTrainer(
85
+ vocab_size=vocab_size,
86
+ show_progress=True,
87
+ min_frequency=0, # no minimum frequency
88
+ initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
89
+ special_tokens=SPECIAL_TOKENS,
90
+ )
91
+ # Kick off the training
92
+ tokenizer.train_from_iterator(text_iterator, trainer)
93
+ return cls(tokenizer)
94
+
95
+ def get_vocab_size(self):
96
+ return self.tokenizer.get_vocab_size()
97
+
98
+ def get_special_tokens(self):
99
+ special_tokens_map = self.tokenizer.get_added_tokens_decoder()
100
+ special_tokens = [w.content for w in special_tokens_map.values()]
101
+ return special_tokens
102
+
103
+ def id_to_token(self, id):
104
+ return self.tokenizer.id_to_token(id)
105
+
106
+ def _encode_one(self, text, prepend=None, append=None, num_threads=None):
107
+ # encode a single string
108
+ # prepend/append can each be either the string of a special token or a token id directly.
109
+ # num_threads is ignored (only used by the nanochat Tokenizer for parallel encoding)
110
+ assert isinstance(text, str)
111
+ ids = []
112
+ if prepend is not None:
113
+ prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend)
114
+ ids.append(prepend_id)
115
+ ids.extend(self.tokenizer.encode(text, add_special_tokens=False).ids)
116
+ if append is not None:
117
+ append_id = append if isinstance(append, int) else self.encode_special(append)
118
+ ids.append(append_id)
119
+ return ids
120
+
121
+ def encode_special(self, text):
122
+ # encode a single special token via exact match
123
+ return self.tokenizer.token_to_id(text)
124
+
125
+ def get_bos_token_id(self):
126
+ # Different HuggingFace models use different BOS tokens and there is little consistency
127
+ # 1) attempt to find a <|bos|> token
128
+ bos = self.encode_special("<|bos|>")
129
+ # 2) if that fails, attempt to find a <|endoftext|> token (e.g. GPT-2 models)
130
+ if bos is None:
131
+ bos = self.encode_special("<|endoftext|>")
132
+ # 3) if these fail, it's better to crash than to silently return None
133
+ assert bos is not None, "Failed to find BOS token in tokenizer"
134
+ return bos
135
+
136
+ def encode(self, text, *args, **kwargs):
137
+ if isinstance(text, str):
138
+ return self._encode_one(text, *args, **kwargs)
139
+ elif isinstance(text, list):
140
+ return [self._encode_one(t, *args, **kwargs) for t in text]
141
+ else:
142
+ raise ValueError(f"Invalid input type: {type(text)}")
143
+
144
+ def __call__(self, *args, **kwargs):
145
+ return self.encode(*args, **kwargs)
146
+
147
+ def decode(self, ids):
148
+ return self.tokenizer.decode(ids, skip_special_tokens=False)
149
+
150
+ def save(self, tokenizer_dir):
151
+ # save the tokenizer to disk
152
+ os.makedirs(tokenizer_dir, exist_ok=True)
153
+ tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
154
+ self.tokenizer.save(tokenizer_path)
155
+ print(f"Saved tokenizer to {tokenizer_path}")
156
+
157
+ # -----------------------------------------------------------------------------
158
+ # Tokenizer based on rustbpe + tiktoken combo
159
+ import pickle
160
+ import rustbpe
161
+ import tiktoken
162
+
163
+ class RustBPETokenizer:
164
+ """Light wrapper around tiktoken (for efficient inference) but train with rustbpe"""
165
+
166
+ def __init__(self, enc, bos_token):
167
+ self.enc = enc
168
+ self.bos_token_id = self.encode_special(bos_token)
169
+
170
+ @classmethod
171
+ def train_from_iterator(cls, text_iterator, vocab_size):
172
+ # 1) train using rustbpe
173
+ tokenizer = rustbpe.Tokenizer()
174
+ # the special tokens are inserted later in __init__, we don't train them here
175
+ vocab_size_no_special = vocab_size - len(SPECIAL_TOKENS)
176
+ assert vocab_size_no_special >= 256, f"vocab_size_no_special must be at least 256, got {vocab_size_no_special}"
177
+ tokenizer.train_from_iterator(text_iterator, vocab_size_no_special, pattern=SPLIT_PATTERN)
178
+ # 2) construct the associated tiktoken encoding for inference
179
+ pattern = tokenizer.get_pattern()
180
+ mergeable_ranks_list = tokenizer.get_mergeable_ranks()
181
+ mergeable_ranks = {bytes(k): v for k, v in mergeable_ranks_list}
182
+ tokens_offset = len(mergeable_ranks)
183
+ special_tokens = {name: tokens_offset + i for i, name in enumerate(SPECIAL_TOKENS)}
184
+ enc = tiktoken.Encoding(
185
+ name="rustbpe",
186
+ pat_str=pattern,
187
+ mergeable_ranks=mergeable_ranks, # dict[bytes, int] (token bytes -> merge priority rank)
188
+ special_tokens=special_tokens, # dict[str, int] (special token name -> token id)
189
+ )
190
+ return cls(enc, "<|bos|>")
191
+
192
+ @classmethod
193
+ def from_directory(cls, tokenizer_dir):
194
+ pickle_path = os.path.join(tokenizer_dir, "tokenizer.pkl")
195
+ with open(pickle_path, "rb") as f:
196
+ enc = pickle.load(f)
197
+ return cls(enc, "<|bos|>")
198
+
199
+ @classmethod
200
+ def from_pretrained(cls, tiktoken_name):
201
+ # https://github.com/openai/tiktoken/blob/eedc8563/tiktoken_ext/openai_public.py
202
+ enc = tiktoken.get_encoding(tiktoken_name)
203
+ # tiktoken calls the special document delimiter token "<|endoftext|>"
204
+ # yes this is confusing because this token is almost always PREPENDED to the beginning of the document
205
+ # it most often is used to signal the start of a new sequence to the LLM during inference etc.
206
+ # so in nanochat we always use "<|bos|>" short for "beginning of sequence", but historically it is often called "<|endoftext|>".
207
+ return cls(enc, "<|endoftext|>")
208
+
209
+ def get_vocab_size(self):
210
+ return self.enc.n_vocab
211
+
212
+ def get_special_tokens(self):
213
+ return self.enc.special_tokens_set
214
+
215
+ def id_to_token(self, id):
216
+ return self.enc.decode([id])
217
+
218
+ @lru_cache(maxsize=32)
219
+ def encode_special(self, text):
220
+ return self.enc.encode_single_token(text)
221
+
222
+ def get_bos_token_id(self):
223
+ return self.bos_token_id
224
+
225
+ def encode(self, text, prepend=None, append=None, num_threads=8):
226
+ # text can be either a string or a list of strings
227
+
228
+ if prepend is not None:
229
+ prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend)
230
+ if append is not None:
231
+ append_id = append if isinstance(append, int) else self.encode_special(append)
232
+
233
+ if isinstance(text, str):
234
+ ids = self.enc.encode_ordinary(text)
235
+ if prepend is not None:
236
+ ids.insert(0, prepend_id) # TODO: slightly inefficient here? :( hmm
237
+ if append is not None:
238
+ ids.append(append_id)
239
+ elif isinstance(text, list):
240
+ ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads)
241
+ if prepend is not None:
242
+ for ids_row in ids:
243
+ ids_row.insert(0, prepend_id) # TODO: same
244
+ if append is not None:
245
+ for ids_row in ids:
246
+ ids_row.append(append_id)
247
+ else:
248
+ raise ValueError(f"Invalid input type: {type(text)}")
249
+
250
+ return ids
251
+
252
+ def __call__(self, *args, **kwargs):
253
+ return self.encode(*args, **kwargs)
254
+
255
+ def decode(self, ids):
256
+ return self.enc.decode(ids)
257
+
258
+ def save(self, tokenizer_dir):
259
+ # save the encoding object to disk
260
+ os.makedirs(tokenizer_dir, exist_ok=True)
261
+ pickle_path = os.path.join(tokenizer_dir, "tokenizer.pkl")
262
+ with open(pickle_path, "wb") as f:
263
+ pickle.dump(self.enc, f)
264
+ print(f"Saved tokenizer encoding to {pickle_path}")
265
+
266
+ def render_conversation(self, conversation, max_tokens=2048):
267
+ """
268
+ Tokenize a single Chat conversation (which we call a "doc" or "document" here).
269
+ Returns:
270
+ - ids: list[int] is a list of token ids of this rendered conversation
271
+ - mask: list[int] of same length, mask = 1 for tokens that the Assistant is expected to train on.
272
+ """
273
+ # ids, masks that we will return and a helper function to help build them up.
274
+ ids, mask = [], []
275
+ def add_tokens(token_ids, mask_val):
276
+ if isinstance(token_ids, int):
277
+ token_ids = [token_ids]
278
+ ids.extend(token_ids)
279
+ mask.extend([mask_val] * len(token_ids))
280
+
281
+ # sometimes the first message is a system message...
282
+ # => just merge it with the second (user) message
283
+ if conversation["messages"][0]["role"] == "system":
284
+ # some conversation surgery is necessary here for now...
285
+ conversation = copy.deepcopy(conversation) # avoid mutating the original
286
+ messages = conversation["messages"]
287
+ assert messages[1]["role"] == "user", "System message must be followed by a user message"
288
+ messages[1]["content"] = messages[0]["content"] + "\n\n" + messages[1]["content"]
289
+ messages = messages[1:]
290
+ else:
291
+ messages = conversation["messages"]
292
+ assert len(messages) >= 1, f"Conversation has less than 1 message: {messages}"
293
+
294
+ # fetch all the special tokens we need
295
+ bos = self.get_bos_token_id()
296
+ user_start, user_end = self.encode_special("<|user_start|>"), self.encode_special("<|user_end|>")
297
+ assistant_start, assistant_end = self.encode_special("<|assistant_start|>"), self.encode_special("<|assistant_end|>")
298
+ python_start, python_end = self.encode_special("<|python_start|>"), self.encode_special("<|python_end|>")
299
+ output_start, output_end = self.encode_special("<|output_start|>"), self.encode_special("<|output_end|>")
300
+
301
+ # now we can tokenize the conversation
302
+ add_tokens(bos, 0)
303
+ for i, message in enumerate(messages):
304
+
305
+ # some sanity checking here around assumptions, to prevent footguns
306
+ must_be_from = "user" if i % 2 == 0 else "assistant"
307
+ assert message["role"] == must_be_from, f"Message {i} is from {message['role']} but should be from {must_be_from}"
308
+
309
+ # content can be either a simple string or a list of parts (e.g. containing tool calls)
310
+ content = message["content"]
311
+
312
+ if message["role"] == "user":
313
+ assert isinstance(content, str), "User messages are simply expected to be strings"
314
+ value_ids = self.encode(content)
315
+ add_tokens(user_start, 0)
316
+ add_tokens(value_ids, 0)
317
+ add_tokens(user_end, 0)
318
+ elif message["role"] == "assistant":
319
+ add_tokens(assistant_start, 0)
320
+ if isinstance(content, str):
321
+ # simple string => simply add the tokens
322
+ value_ids = self.encode(content)
323
+ add_tokens(value_ids, 1)
324
+ elif isinstance(content, list):
325
+ for part in content:
326
+ value_ids = self.encode(part["text"])
327
+ if part["type"] == "text":
328
+ # string part => simply add the tokens
329
+ add_tokens(value_ids, 1)
330
+ elif part["type"] == "python":
331
+ # python tool call => add the tokens inside <|python_start|> and <|python_end|>
332
+ add_tokens(python_start, 1)
333
+ add_tokens(value_ids, 1)
334
+ add_tokens(python_end, 1)
335
+ elif part["type"] == "python_output":
336
+ # python output => add the tokens inside <|output_start|> and <|output_end|>
337
+ # none of these tokens are supervised because the tokens come from Python at test time
338
+ add_tokens(output_start, 0)
339
+ add_tokens(value_ids, 0)
340
+ add_tokens(output_end, 0)
341
+ else:
342
+ raise ValueError(f"Unknown part type: {part['type']}")
343
+ else:
344
+ raise ValueError(f"Unknown content type: {type(content)}")
345
+ add_tokens(assistant_end, 1)
346
+
347
+ # truncate to max_tokens tokens MAX (helps prevent OOMs)
348
+ ids = ids[:max_tokens]
349
+ mask = mask[:max_tokens]
350
+ return ids, mask
351
+
352
+ def visualize_tokenization(self, ids, mask, with_token_id=False):
353
+ """Small helper function useful in debugging: visualize the tokenization of render_conversation"""
354
+ RED = '\033[91m'
355
+ GREEN = '\033[92m'
356
+ RESET = '\033[0m'
357
+ GRAY = '\033[90m'
358
+ tokens = []
359
+ for token_id, mask_val in zip(ids, mask):
360
+ token_str = self.decode([token_id])
361
+ color = GREEN if mask_val == 1 else RED
362
+ tokens.append(f"{color}{token_str}{RESET}")
363
+ if with_token_id:
364
+ tokens.append(f"{GRAY}({token_id}){RESET}")
365
+ return '|'.join(tokens)
366
+
367
+ def render_for_completion(self, conversation):
368
+ """
369
+ Used during Reinforcement Learning. In that setting, we want to
370
+ render the conversation priming the Assistant for a completion.
371
+ Unlike the Chat SFT case, we don't need to return the mask.
372
+ """
373
+ # We have some surgery to do: we need to pop the last message (of the Assistant)
374
+ conversation = copy.deepcopy(conversation) # avoid mutating the original
375
+ messages = conversation["messages"]
376
+ assert messages[-1]["role"] == "assistant", "Last message must be from the Assistant"
377
+ messages.pop() # remove the last message (of the Assistant) inplace
378
+
379
+ # Now tokenize the conversation
380
+ ids, mask = self.render_conversation(conversation)
381
+
382
+ # Finally, to prime the Assistant for a completion, append the Assistant start token
383
+ assistant_start = self.encode_special("<|assistant_start|>")
384
+ ids.append(assistant_start)
385
+ return ids
386
+
387
+ # -----------------------------------------------------------------------------
388
+ # nanochat-specific convenience functions
389
+
390
+ def get_tokenizer():
391
+ from nanochat.common import get_base_dir
392
+ base_dir = get_base_dir()
393
+ tokenizer_dir = os.path.join(base_dir, "tokenizer")
394
+ # return HuggingFaceTokenizer.from_directory(tokenizer_dir)
395
+ return RustBPETokenizer.from_directory(tokenizer_dir)
396
+
397
+ def get_token_bytes(device="cpu"):
398
+ import torch
399
+ from nanochat.common import get_base_dir
400
+ base_dir = get_base_dir()
401
+ tokenizer_dir = os.path.join(base_dir, "tokenizer")
402
+ token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
403
+ assert os.path.exists(token_bytes_path), f"Token bytes not found at {token_bytes_path}? It gets written by tok_train.py"
404
+ with open(token_bytes_path, "rb") as f:
405
+ token_bytes = torch.load(f, map_location=device)
406
+ return token_bytes
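A hedged example of the conversation format that `render_conversation` expects (the message contents are made up). The mask is 1 only on assistant tokens, as enforced above, and `render_for_completion` is the RL-style variant that drops the final assistant reply and appends `<|assistant_start|>`.

```python
# Illustrative only: tokenize a tiny two-turn conversation with the tokenizer above.
from nanochat.tokenizer import get_tokenizer

tokenizer = get_tokenizer()  # loads tokenizer.pkl from the nanochat base dir
conversation = {
    "messages": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "2 + 2 = 4"},
    ]
}
ids, mask = tokenizer.render_conversation(conversation, max_tokens=2048)
print(tokenizer.visualize_tokenization(ids, mask))  # green = supervised tokens
# prefix_ids = tokenizer.render_for_completion(conversation)  # priming for RL rollouts
```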
nanochat/ui.html ADDED
@@ -0,0 +1,566 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0, viewport-fit=cover">
6
+ <title>NanoChat</title>
7
+ <link rel="icon" type="image/svg+xml" href="/logo.svg">
8
+ <style>
9
+ :root {
10
+ color-scheme: light;
11
+ }
12
+
13
+ * {
14
+ box-sizing: border-box;
15
+ }
16
+
17
+ html, body{
18
+ height: 100%;
19
+ margin: 0;
20
+ }
21
+
22
+ body {
23
+ font-family: ui-sans-serif, -apple-system, system-ui, "Segoe UI", Helvetica, "Apple Color Emoji", Arial, sans-serif, "Segoe UI Emoji", "Segoe UI Symbol";
24
+ background-color: #ffffff;
25
+ color: #111827;
26
+ min-height: 100dvh;
27
+ margin: 0;
28
+ display: flex;
29
+ flex-direction: column;
30
+ }
31
+
32
+ .header {
33
+ background-color: #ffffff;
34
+ padding: 1.25rem 1.5rem;
35
+ }
36
+
37
+ .header-left {
38
+ display: flex;
39
+ align-items: center;
40
+ gap: 0.75rem;
41
+ }
42
+
43
+ .header-logo {
44
+ height: 32px;
45
+ width: auto;
46
+ }
47
+
48
+ .header h1 {
49
+ font-size: 1.25rem;
50
+ font-weight: 600;
51
+ margin: 0;
52
+ color: #111827;
53
+ }
54
+
55
+ .new-conversation-btn {
56
+ width: 32px;
57
+ height: 32px;
58
+ padding: 0;
59
+ border: 1px solid #e5e7eb;
60
+ border-radius: 0.5rem;
61
+ background-color: #ffffff;
62
+ color: #6b7280;
63
+ cursor: pointer;
64
+ display: flex;
65
+ align-items: center;
66
+ justify-content: center;
67
+ transition: all 0.2s ease;
68
+ }
69
+
70
+ .new-conversation-btn:hover {
71
+ background-color: #f3f4f6;
72
+ border-color: #d1d5db;
73
+ color: #374151;
74
+ }
75
+
76
+ .chat-container {
77
+ flex: 1;
78
+ overflow-y: auto;
79
+ background-color: #ffffff;
80
+ }
81
+
82
+ .chat-wrapper {
83
+ max-width: 48rem;
84
+ margin: 0 auto;
85
+ padding: 2rem 1.5rem 3rem;
86
+ display: flex;
87
+ flex-direction: column;
88
+ gap: 0.75rem;
89
+ }
90
+
91
+ .message {
92
+ display: flex;
93
+ justify-content: flex-start;
94
+ margin-bottom: 0.5rem;
95
+ color: #0d0d0d;
96
+ }
97
+
98
+ .message.assistant {
99
+ justify-content: flex-start;
100
+ }
101
+
102
+ .message.user {
103
+ justify-content: flex-end;
104
+ }
105
+
106
+ .message-content {
107
+ white-space: pre-wrap;
108
+ line-height: 1.6;
109
+ max-width: 100%;
110
+ }
111
+
112
+ .message.assistant .message-content {
113
+ background: transparent;
114
+ border: none;
115
+ cursor: pointer;
116
+ border-radius: 0.5rem;
117
+ padding: 0.5rem;
118
+ margin-left: -0.5rem;
119
+ transition: background-color 0.2s ease;
120
+ }
121
+
122
+ .message.assistant .message-content:hover {
123
+ background-color: #f9fafb;
124
+ }
125
+
126
+ .message.user .message-content {
127
+ background-color: #f3f4f6;
128
+ border-radius: 1.25rem;
129
+ padding: 0.8rem 1rem;
130
+ max-width: 65%;
131
+ cursor: pointer;
132
+ transition: background-color 0.2s ease;
133
+ }
134
+
135
+ .message.user .message-content:hover {
136
+ background-color: #e5e7eb;
137
+ }
138
+
139
+ .message.console .message-content {
140
+ font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', 'Consolas', 'Courier New', monospace;
141
+ font-size: 0.875rem;
142
+ background-color: #fafafa;
143
+ padding: 0.75rem 1rem;
144
+ color: #374151;
145
+ max-width: 80%;
146
+ }
147
+
148
+ .input-container {
149
+ background-color: #ffffff;
150
+ padding: 1rem;
151
+ padding-bottom: calc(1rem + env(safe-area-inset-bottom))
152
+ }
153
+
154
+ .input-wrapper {
155
+ max-width: 48rem;
156
+ margin: 0 auto;
157
+ display: flex;
158
+ gap: 0.75rem;
159
+ align-items: flex-end;
160
+ }
161
+
162
+ .chat-input {
163
+ flex: 1;
164
+ padding: 0.8rem 1rem;
165
+ border: 1px solid #d1d5db;
166
+ border-radius: 0.75rem;
167
+ background-color: #ffffff;
168
+ color: #111827;
169
+ font-size: 1rem;
170
+ line-height: 1.5;
171
+ resize: none;
172
+ outline: none;
173
+ min-height: 54px;
174
+ max-height: 200px;
175
+ transition: border-color 0.2s ease, box-shadow 0.2s ease;
176
+ }
177
+
178
+ .chat-input::placeholder {
179
+ color: #9ca3af;
180
+ }
181
+
182
+ .chat-input:focus {
183
+ border-color: #2563eb;
184
+ box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.1);
185
+ }
186
+
187
+ .send-button {
188
+ flex-shrink: 0;
189
+ padding: 0;
190
+ width: 54px;
191
+ height: 54px;
192
+ border: 1px solid #111827;
193
+ border-radius: 0.75rem;
194
+ background-color: #111827;
195
+ color: #ffffff;
196
+ display: flex;
197
+ align-items: center;
198
+ justify-content: center;
199
+ cursor: pointer;
200
+ transition: background-color 0.2s ease, border-color 0.2s ease, color 0.2s ease;
201
+ }
202
+
203
+ .send-button:hover:not(:disabled) {
204
+ background-color: #2563eb;
205
+ border-color: #2563eb;
206
+ }
207
+
208
+ .send-button:disabled {
209
+ cursor: not-allowed;
210
+ border-color: #d1d5db;
211
+ background-color: #e5e7eb;
212
+ color: #9ca3af;
213
+ }
214
+
215
+ .typing-indicator {
216
+ display: inline-block;
217
+ color: #6b7280;
218
+ letter-spacing: 0.15em;
219
+ }
220
+
221
+ .typing-indicator::after {
222
+ content: '···';
223
+ animation: typing 1.4s infinite;
224
+ }
225
+
226
+ @keyframes typing {
227
+ 0%, 60%, 100% { opacity: 0.2; }
228
+ 30% { opacity: 1; }
229
+ }
230
+
231
+ .error-message {
232
+ background-color: #fee2e2;
233
+ border: 1px solid #fecaca;
234
+ color: #b91c1c;
235
+ padding: 0.75rem 1rem;
236
+ border-radius: 0.75rem;
237
+ margin-top: 0.5rem;
238
+ }
239
+ </style>
240
+ </head>
241
+ <body>
242
+ <div class="header">
243
+ <div class="header-left">
244
+ <button class="new-conversation-btn" onclick="newConversation()" title="New Conversation (Ctrl+Shift+N)">
245
+ <svg width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
246
+ <path d="M12 5v14"></path>
247
+ <path d="M5 12h14"></path>
248
+ </svg>
249
+ </button>
250
+ <h1>nanochat</h1>
251
+ </div>
252
+ </div>
253
+
254
+ <div class="chat-container" id="chatContainer">
255
+ <div class="chat-wrapper" id="chatWrapper">
256
+ <!-- Messages will be added here -->
257
+ </div>
258
+ </div>
259
+
260
+ <div class="input-container">
261
+ <div class="input-wrapper">
262
+ <textarea
263
+ id="chatInput"
264
+ class="chat-input"
265
+ placeholder="Ask anything"
266
+ rows="1"
267
+ onkeydown="handleKeyDown(event)"
268
+ ></textarea>
269
+ <button id="sendButton" class="send-button" onclick="sendMessage()" disabled>
270
+ <svg width="22" height="22" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
271
+ <path d="M22 2L11 13"></path>
272
+ <path d="M22 2l-7 20-4-9-9-4 20-7z"></path>
273
+ </svg>
274
+ </button>
275
+ </div>
276
+ </div>
277
+
278
+ <script>
279
+ const API_URL = '';
280
+ const chatContainer = document.getElementById('chatContainer');
281
+ const chatWrapper = document.getElementById('chatWrapper');
282
+ const chatInput = document.getElementById('chatInput');
283
+ const sendButton = document.getElementById('sendButton');
284
+
285
+ let messages = [];
286
+ let isGenerating = false;
287
+ let currentTemperature = 0.8;
288
+ let currentTopK = 50;
289
+
290
+ chatInput.addEventListener('input', function() {
291
+ this.style.height = 'auto';
292
+ this.style.height = Math.min(this.scrollHeight, 200) + 'px';
293
+ sendButton.disabled = !this.value.trim() || isGenerating;
294
+ });
295
+
296
+ function handleKeyDown(event) {
297
+ if (event.key === 'Enter' && !event.shiftKey) {
298
+ event.preventDefault();
299
+ sendMessage();
300
+ }
301
+ }
302
+
303
+ document.addEventListener('keydown', function(event) {
304
+ // Ctrl+Shift+N for new conversation
305
+ if (event.ctrlKey && event.shiftKey && event.key === 'N') {
306
+ event.preventDefault();
307
+ if (!isGenerating) {
308
+ newConversation();
309
+ }
310
+ }
311
+ });
312
+
313
+ function newConversation() {
314
+ messages = [];
315
+ chatWrapper.innerHTML = '';
316
+ chatInput.value = '';
317
+ chatInput.style.height = 'auto';
318
+ sendButton.disabled = false;
319
+ isGenerating = false;
320
+ chatInput.focus();
321
+ }
322
+
323
+ function addMessage(role, content, messageIndex = null) {
324
+ const messageDiv = document.createElement('div');
325
+ messageDiv.className = `message ${role}`;
326
+
327
+ const contentDiv = document.createElement('div');
328
+ contentDiv.className = 'message-content';
329
+ contentDiv.textContent = content;
330
+
331
+ // Add click handler for user messages to enable editing
332
+ if (role === 'user' && messageIndex !== null) {
333
+ contentDiv.setAttribute('data-message-index', messageIndex);
334
+ contentDiv.setAttribute('title', 'Click to edit and restart from here');
335
+ contentDiv.addEventListener('click', function() {
336
+ if (!isGenerating) {
337
+ editMessage(messageIndex);
338
+ }
339
+ });
340
+ }
341
+
342
+ // Add click handler for assistant messages to enable regeneration
343
+ if (role === 'assistant' && messageIndex !== null) {
344
+ contentDiv.setAttribute('data-message-index', messageIndex);
345
+ contentDiv.setAttribute('title', 'Click to regenerate this response');
346
+ contentDiv.addEventListener('click', function() {
347
+ if (!isGenerating) {
348
+ regenerateMessage(messageIndex);
349
+ }
350
+ });
351
+ }
352
+
353
+ messageDiv.appendChild(contentDiv);
354
+ chatWrapper.appendChild(messageDiv);
355
+
356
+ chatContainer.scrollTop = chatContainer.scrollHeight;
357
+ return contentDiv;
358
+ }
359
+
360
+ function editMessage(messageIndex) {
361
+ // Find the message in the messages array
362
+ if (messageIndex < 0 || messageIndex >= messages.length) return;
363
+
364
+ const messageToEdit = messages[messageIndex];
365
+ if (messageToEdit.role !== 'user') return;
366
+
367
+ // Copy message content to input
368
+ chatInput.value = messageToEdit.content;
369
+ chatInput.style.height = 'auto';
370
+ chatInput.style.height = Math.min(chatInput.scrollHeight, 200) + 'px';
371
+
372
+ // Remove this message and all subsequent messages from the array
373
+ messages = messages.slice(0, messageIndex);
374
+
375
+ // Remove message elements from DOM starting from messageIndex
376
+ const allMessages = chatWrapper.querySelectorAll('.message');
377
+ for (let i = messageIndex; i < allMessages.length; i++) {
378
+ allMessages[i].remove();
379
+ }
380
+
381
+ // Enable send button and focus input
382
+ sendButton.disabled = false;
383
+ chatInput.focus();
384
+ }
385
+
386
+ async function generateAssistantResponse() {
387
+ isGenerating = true;
388
+ sendButton.disabled = true;
389
+
390
+ const assistantContent = addMessage('assistant', '');
391
+ assistantContent.innerHTML = '<span class="typing-indicator"></span>';
392
+
393
+ try {
394
+ const response = await fetch(`${API_URL}/chat/completions`, {
395
+ method: 'POST',
396
+ headers: {
397
+ 'Content-Type': 'application/json',
398
+ },
399
+ body: JSON.stringify({
400
+ messages: messages,
401
+ temperature: currentTemperature,
402
+ top_k: currentTopK,
403
+ max_tokens: 512
404
+ }),
405
+ });
406
+
407
+ if (!response.ok) {
408
+ throw new Error(`HTTP error! status: ${response.status}`);
409
+ }
410
+
411
+ const reader = response.body.getReader();
412
+ const decoder = new TextDecoder();
413
+ let fullResponse = '';
414
+ assistantContent.textContent = '';
415
+
416
+ while (true) {
417
+ const { done, value } = await reader.read();
418
+ if (done) break;
419
+
420
+ const chunk = decoder.decode(value);
421
+ const lines = chunk.split('\n');
422
+
423
+ for (const line of lines) {
424
+ if (line.startsWith('data: ')) {
425
+ try {
426
+ const data = JSON.parse(line.slice(6));
427
+ if (data.token) {
428
+ fullResponse += data.token;
429
+ assistantContent.textContent = fullResponse;
430
+ chatContainer.scrollTop = chatContainer.scrollHeight;
431
+ }
432
+ } catch (e) {
433
+ }
434
+ }
435
+ }
436
+ }
437
+
438
+ const assistantMessageIndex = messages.length;
439
+ messages.push({ role: 'assistant', content: fullResponse });
440
+
441
+ // Add click handler to regenerate this assistant message
442
+ assistantContent.setAttribute('data-message-index', assistantMessageIndex);
443
+ assistantContent.setAttribute('title', 'Click to regenerate this response');
444
+ assistantContent.addEventListener('click', function() {
445
+ if (!isGenerating) {
446
+ regenerateMessage(assistantMessageIndex);
447
+ }
448
+ });
449
+
450
+ } catch (error) {
451
+ console.error('Error:', error);
452
+ assistantContent.innerHTML = `<div class="error-message">Error: ${error.message}</div>`;
453
+ } finally {
454
+ isGenerating = false;
455
+ sendButton.disabled = !chatInput.value.trim();
456
+ }
457
+ }
458
+
459
+ async function regenerateMessage(messageIndex) {
460
+ // Find the message in the messages array
461
+ if (messageIndex < 0 || messageIndex >= messages.length) return;
462
+
463
+ const messageToRegenerate = messages[messageIndex];
464
+ if (messageToRegenerate.role !== 'assistant') return;
465
+
466
+ // Remove this message and all subsequent messages from the array
467
+ messages = messages.slice(0, messageIndex);
468
+
469
+ // Remove message elements from DOM starting from messageIndex
470
+ const allMessages = chatWrapper.querySelectorAll('.message');
471
+ for (let i = messageIndex; i < allMessages.length; i++) {
472
+ allMessages[i].remove();
473
+ }
474
+
475
+ // Regenerate the assistant response
476
+ await generateAssistantResponse();
477
+ }
478
+
479
+ function handleSlashCommand(command) {
480
+ const parts = command.trim().split(/\s+/);
481
+ const cmd = parts[0].toLowerCase();
482
+ const arg = parts[1];
483
+
484
+ if (cmd === '/temperature') {
485
+ if (arg === undefined) {
486
+ addMessage('console', `Current temperature: ${currentTemperature}`);
487
+ } else {
488
+ const temp = parseFloat(arg);
489
+ if (isNaN(temp) || temp < 0 || temp > 2) {
490
+ addMessage('console', 'Invalid temperature. Must be between 0.0 and 2.0');
491
+ } else {
492
+ currentTemperature = temp;
493
+ addMessage('console', `Temperature set to ${currentTemperature}`);
494
+ }
495
+ }
496
+ return true;
497
+ } else if (cmd === '/topk') {
498
+ if (arg === undefined) {
499
+ addMessage('console', `Current top-k: ${currentTopK}`);
500
+ } else {
501
+ const topk = parseInt(arg);
502
+ if (isNaN(topk) || topk < 1 || topk > 200) {
503
+ addMessage('console', 'Invalid top-k. Must be between 1 and 200');
504
+ } else {
505
+ currentTopK = topk;
506
+ addMessage('console', `Top-k set to ${currentTopK}`);
507
+ }
508
+ }
509
+ return true;
510
+ } else if (cmd === '/clear') {
511
+ newConversation();
512
+ return true;
513
+ } else if (cmd === '/help') {
514
+ addMessage('console',
515
+ 'Available commands:\n' +
516
+ '/temperature - Show current temperature\n' +
517
+ '/temperature <value> - Set temperature (0.0-2.0)\n' +
518
+ '/topk - Show current top-k\n' +
519
+ '/topk <value> - Set top-k (1-200)\n' +
520
+ '/clear - Clear conversation\n' +
521
+ '/help - Show this help message'
522
+ );
523
+ return true;
524
+ }
525
+ return false;
526
+ }
527
+
528
+ async function sendMessage() {
529
+ const message = chatInput.value.trim();
530
+ if (!message || isGenerating) return;
531
+
532
+ // Handle slash commands
533
+ if (message.startsWith('/')) {
534
+ chatInput.value = '';
535
+ chatInput.style.height = 'auto';
536
+ handleSlashCommand(message);
537
+ return;
538
+ }
539
+
540
+ chatInput.value = '';
541
+ chatInput.style.height = 'auto';
542
+
543
+ const userMessageIndex = messages.length;
544
+ messages.push({ role: 'user', content: message });
545
+ addMessage('user', message, userMessageIndex);
546
+
547
+ await generateAssistantResponse();
548
+ }
549
+
550
+ sendButton.disabled = false;
551
+
552
+ // Autofocus the chat input on page load
553
+ chatInput.focus();
554
+
555
+ fetch(`${API_URL}/health`)
556
+ .then(response => response.json())
557
+ .then(data => {
558
+ console.log('Engine status:', data);
559
+ })
560
+ .catch(error => {
561
+ console.error('Engine not available:', error);
562
+ chatWrapper.innerHTML = '<div class="error-message">Engine not running. Please start engine.py first.</div>';
563
+ });
564
+ </script>
565
+ </body>
566
+ </html>
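The UI above talks to the serving engine over HTTP: on page load it probes `${API_URL}/health`, and the `/temperature` and `/topk` slash commands only change the sampling parameters that accompany each request. A minimal Python sketch of the same health probe, assuming the engine is served locally on port 8000 (the host/port and the exact fields of the JSON response are assumptions, not part of this diff):

import json
import urllib.request

API_URL = "http://localhost:8000"  # assumed host/port; match however the engine/web server is launched

def check_engine_health(api_url: str = API_URL) -> dict:
    """Probe the same /health endpoint the chat UI polls on page load."""
    with urllib.request.urlopen(f"{api_url}/health", timeout=5) as resp:
        return json.load(resp)

if __name__ == "__main__":
    try:
        print("Engine status:", check_engine_health())
    except OSError as err:
        print("Engine not available:", err)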
pyproject.toml ADDED
@@ -0,0 +1,74 @@
+ [project]
+ name = "nanochat"
+ version = "0.1.0"
+ description = "the minimal full-stack ChatGPT clone"
+ readme = "README.md"
+ requires-python = ">=3.10"
+ dependencies = [
+ "datasets>=4.0.0",
+ "fastapi>=0.117.1",
+ "ipykernel>=7.1.0",
+ "kernels>=0.11.7",
+ "matplotlib>=3.10.8",
+ "psutil>=7.1.0",
+ "python-dotenv>=1.2.1",
+ "regex>=2025.9.1",
+ "rustbpe>=0.1.0",
+ "scipy>=1.15.3",
+ "setuptools>=80.9.0",
+ "tabulate>=0.9.0",
+ "tiktoken>=0.11.0",
+ "tokenizers>=0.22.0",
+ "torch==2.9.1",
+ "transformers>=4.57.3",
+ "uvicorn>=0.36.0",
+ "wandb>=0.21.3",
+ "zstandard>=0.25.0",
+ ]
+
+ [dependency-groups]
+ dev = [
+ "pytest>=8.0.0",
+ ]
+
+ [tool.pytest.ini_options]
+ markers = [
+ "slow: marks tests as slow (deselect with '-m \"not slow\"')",
+ ]
+ testpaths = ["tests"]
+ python_files = ["test_*.py"]
+ python_classes = ["Test*"]
+ python_functions = ["test_*"]
+
+ # target torch to cuda 12.8 or CPU
+ [tool.uv.sources]
+ torch = [
+ { index = "pytorch-cpu", extra = "cpu" },
+ { index = "pytorch-cu128", extra = "gpu" },
+ ]
+
+ [[tool.uv.index]]
+ name = "pytorch-cpu"
+ url = "https://download.pytorch.org/whl/cpu"
+ explicit = true
+
+ [[tool.uv.index]]
+ name = "pytorch-cu128"
+ url = "https://download.pytorch.org/whl/cu128"
+ explicit = true
+
+ [project.optional-dependencies]
+ cpu = [
+ "torch==2.9.1",
+ ]
+ gpu = [
+ "torch==2.9.1",
+ ]
+
+ [tool.uv]
+ conflicts = [
+ [
+ { extra = "cpu" },
+ { extra = "gpu" },
+ ],
+ ]
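The `[tool.uv]` conflicts table makes the `cpu` and `gpu` extras mutually exclusive, so torch is resolved from exactly one of the two explicit indexes above. A typical invocation (a sketch assuming uv's standard `--extra` syntax, not something prescribed by this file):

uv sync --extra gpu   # CUDA 12.8 wheels from the pytorch-cu128 index
uv sync --extra cpu   # CPU-only wheels from the pytorch-cpu index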
scripts/__pycache__/base_eval.cpython-310.pyc ADDED
Binary file (11 kB).
 
scripts/__pycache__/base_train.cpython-310.pyc ADDED
Binary file (17.9 kB).
 
scripts/__pycache__/chat_eval.cpython-310.pyc ADDED
Binary file (7.52 kB).
 
scripts/__pycache__/chat_sft.cpython-310.pyc ADDED
Binary file (14.6 kB).
 
scripts/__pycache__/tok_eval.cpython-310.pyc ADDED
Binary file (10.5 kB).
 
scripts/__pycache__/tok_train.cpython-310.pyc ADDED
Binary file (2.98 kB).
 
scripts/base_eval.py ADDED
@@ -0,0 +1,323 @@
+ """
2
+ Unified evaluation script for base models.
3
+
4
+ Supports three evaluation modes (comma-separated):
5
+ --eval core : CORE metric (accuracy on ICL tasks)
6
+ --eval bpb : Bits per byte on train/val splits
7
+ --eval sample : Generate samples from the model
8
+
9
+ Default is all three: --eval core,bpb,sample
10
+
11
+ Examples:
12
+
13
+ # Evaluate a HuggingFace model (e.g. GPT-2 124M) using 8 GPUs
14
+ torchrun --nproc_per_node=8 -m scripts.base_eval --hf-path openai-community/gpt2
15
+
16
+ # Evaluate a nanochat model (e.g. d24) using 8 GPUs
17
+ torchrun --nproc_per_node=8 -m scripts.base_eval --model-tag d24 --device-batch-size=16
18
+
19
+ # Quick/approximate evaluation using a single GPU
20
+ python -m scripts.base_eval --model-tag d24 --device-batch-size=16 --max-per-task=100 --split-tokens=524288
21
+ """
22
+ import os
23
+ import csv
24
+ import time
25
+ import json
26
+ import yaml
27
+ import shutil
28
+ import random
29
+ import zipfile
30
+ import tempfile
31
+ import argparse
32
+ import torch
33
+
34
+ from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock
35
+ from nanochat.tokenizer import HuggingFaceTokenizer, get_token_bytes
36
+ from nanochat.checkpoint_manager import load_model
37
+ from nanochat.core_eval import evaluate_task
38
+ from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit
39
+ from nanochat.loss_eval import evaluate_bpb
40
+ from nanochat.engine import Engine
41
+
42
+ # -----------------------------------------------------------------------------
43
+ # HuggingFace loading utilities
44
+
45
+ class ModelWrapper:
46
+ """Lightweight wrapper to give HuggingFace models a nanochat-compatible interface."""
47
+ def __init__(self, model, max_seq_len=None):
48
+ self.model = model
49
+ self.max_seq_len = max_seq_len
50
+
51
+ def __call__(self, input_ids, targets=None, loss_reduction='mean'):
52
+ logits = self.model(input_ids).logits
53
+ if targets is None:
54
+ return logits
55
+ loss = torch.nn.functional.cross_entropy(
56
+ logits.view(-1, logits.size(-1)),
57
+ targets.view(-1),
58
+ ignore_index=-1,
59
+ reduction=loss_reduction
60
+ )
61
+ return loss
62
+
63
+ def get_device(self):
64
+ return next(self.model.parameters()).device
65
+
66
+
67
+ def load_hf_model(hf_path: str, device):
68
+ """Load a HuggingFace model and tokenizer."""
69
+ print0(f"Loading HuggingFace model from: {hf_path}")
70
+ from transformers import AutoModelForCausalLM
71
+ model = AutoModelForCausalLM.from_pretrained(hf_path)
72
+ model.to(device)
73
+ model.eval()
74
+ max_seq_len = 1024 if "gpt2" in hf_path else None
75
+ model = ModelWrapper(model, max_seq_len=max_seq_len)
76
+ tokenizer = HuggingFaceTokenizer.from_pretrained(hf_path)
77
+ return model, tokenizer
78
+
79
+
80
+ def get_hf_token_bytes(tokenizer, device="cpu"):
81
+ """Compute token_bytes tensor for a HuggingFace tokenizer."""
82
+ vocab_size = tokenizer.tokenizer.get_vocab_size()
83
+ token_bytes = torch.zeros(vocab_size, dtype=torch.int64, device=device)
84
+ for token_id in range(vocab_size):
85
+ token_str = tokenizer.tokenizer.decode([token_id])
86
+ token_bytes[token_id] = len(token_str.encode('utf-8'))
87
+ return token_bytes
88
+
89
+ # -----------------------------------------------------------------------------
90
+ # CORE evaluation
91
+
92
+ EVAL_BUNDLE_URL = "https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip"
93
+
94
+
95
+ def place_eval_bundle(file_path):
96
+ """Unzip eval_bundle.zip and place it in the base directory."""
97
+ base_dir = get_base_dir()
98
+ eval_bundle_dir = os.path.join(base_dir, "eval_bundle")
99
+ with tempfile.TemporaryDirectory() as tmpdir:
100
+ with zipfile.ZipFile(file_path, 'r') as zip_ref:
101
+ zip_ref.extractall(tmpdir)
102
+ extracted_bundle_dir = os.path.join(tmpdir, "eval_bundle")
103
+ shutil.move(extracted_bundle_dir, eval_bundle_dir)
104
+ print0(f"Placed eval_bundle directory at {eval_bundle_dir}")
105
+
106
+
107
+ def evaluate_core(model, tokenizer, device, max_per_task=-1):
108
+ """
109
+ Evaluate a base model on the CORE benchmark.
110
+ Returns dict with results, centered_results, and core_metric.
111
+ """
112
+ base_dir = get_base_dir()
113
+ eval_bundle_dir = os.path.join(base_dir, "eval_bundle")
114
+ # Download the eval bundle if needed
115
+ if not os.path.exists(eval_bundle_dir):
116
+ download_file_with_lock(EVAL_BUNDLE_URL, "eval_bundle.zip", postprocess_fn=place_eval_bundle)
117
+
118
+ config_path = os.path.join(eval_bundle_dir, "core.yaml")
119
+ data_base_path = os.path.join(eval_bundle_dir, "eval_data")
120
+ eval_meta_data = os.path.join(eval_bundle_dir, "eval_meta_data.csv")
121
+
122
+ with open(config_path, 'r', encoding='utf-8') as f:
123
+ config = yaml.safe_load(f)
124
+ tasks = config['icl_tasks']
125
+
126
+ # Load random baseline values
127
+ random_baselines = {}
128
+ with open(eval_meta_data, 'r', encoding='utf-8') as f:
129
+ reader = csv.DictReader(f)
130
+ for row in reader:
131
+ task_name = row['Eval Task']
132
+ random_baseline = row['Random baseline']
133
+ random_baselines[task_name] = float(random_baseline)
134
+
135
+ # Evaluate each task
136
+ results = {}
137
+ centered_results = {}
138
+ for task in tasks:
139
+ start_time = time.time()
140
+ label = task['label']
141
+ task_meta = {
142
+ 'task_type': task['icl_task_type'],
143
+ 'dataset_uri': task['dataset_uri'],
144
+ 'num_fewshot': task['num_fewshot'][0],
145
+ 'continuation_delimiter': task.get('continuation_delimiter', ' ')
146
+ }
147
+ print0(f"Evaluating: {label} ({task_meta['num_fewshot']}-shot, type: {task_meta['task_type']})... ", end='')
148
+
149
+ data_path = os.path.join(data_base_path, task_meta['dataset_uri'])
150
+ with open(data_path, 'r', encoding='utf-8') as f:
151
+ data = [json.loads(line.strip()) for line in f]
152
+
153
+ # Shuffle for consistent subsampling when using max_per_task
154
+ shuffle_rng = random.Random(1337)
155
+ shuffle_rng.shuffle(data)
156
+ if max_per_task > 0:
157
+ data = data[:max_per_task]
158
+
159
+ accuracy = evaluate_task(model, tokenizer, data, device, task_meta)
160
+ results[label] = accuracy
161
+ random_baseline = random_baselines[label]
162
+ centered_result = (accuracy - 0.01 * random_baseline) / (1.0 - 0.01 * random_baseline)
163
+ centered_results[label] = centered_result
164
+ elapsed = time.time() - start_time
165
+ print0(f"accuracy: {accuracy:.4f} | centered: {centered_result:.4f} | time: {elapsed:.2f}s")
166
+
167
+ core_metric = sum(centered_results.values()) / len(centered_results)
168
+ out = {
169
+ "results": results,
170
+ "centered_results": centered_results,
171
+ "core_metric": core_metric
172
+ }
173
+ return out
174
+
175
+ # -----------------------------------------------------------------------------
176
+ # Main
177
+
178
+ def main():
179
+ parser = argparse.ArgumentParser(description="Base model evaluation")
180
+ parser.add_argument('--eval', type=str, default='core,bpb,sample', help='Comma-separated evaluations to run: core,bpb,sample (default: all)')
181
+ parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path (e.g. openai-community/gpt2-xl)')
182
+ parser.add_argument('--model-tag', type=str, default=None, help='nanochat model tag to identify the checkpoint directory')
183
+ parser.add_argument('--step', type=int, default=None, help='Model step to load (default = last)')
184
+ parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per CORE task (-1 = all)')
185
+ parser.add_argument('--device-batch-size', type=int, default=32, help='Per-device batch size for BPB evaluation')
186
+ parser.add_argument('--split-tokens', type=int, default=40*524288, help='Number of tokens to evaluate per split for BPB')
187
+ parser.add_argument('--device-type', type=str, default='', help='cuda|cpu|mps (empty = autodetect)')
188
+ args = parser.parse_args()
189
+
190
+ # Parse evaluation modes
191
+ eval_modes = set(mode.strip() for mode in args.eval.split(','))
192
+ valid_modes = {'core', 'bpb', 'sample'}
193
+ invalid = eval_modes - valid_modes
194
+ if invalid:
195
+ parser.error(f"Invalid eval modes: {invalid}. Valid: {valid_modes}")
196
+
197
+ # Distributed / precision setup
198
+ device_type = autodetect_device_type() if args.device_type == '' else args.device_type
199
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
200
+ # Load model and tokenizer
201
+ is_hf_model = args.hf_path is not None
202
+ if is_hf_model:
203
+ model, tokenizer = load_hf_model(args.hf_path, device)
204
+ sequence_len = model.max_seq_len or 1024
205
+ token_bytes = get_hf_token_bytes(tokenizer, device=device)
206
+ model_name = args.hf_path
207
+ model_slug = args.hf_path.replace("/", "-")
208
+ else:
209
+ model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.step)
210
+ sequence_len = meta["model_config"]["sequence_len"]
211
+ token_bytes = get_token_bytes(device=device)
212
+ model_name = f"base_model (step {meta['step']})"
213
+ model_slug = f"base_model_{meta['step']:06d}"
214
+
215
+ print0(f"Evaluating model: {model_name}")
216
+ print0(f"Eval modes: {', '.join(sorted(eval_modes))}")
217
+
218
+ # Results to log
219
+ core_results = None
220
+ bpb_results = {}
221
+ samples = []
222
+ unconditioned_samples = []
223
+
224
+ # --- Sampling ---
225
+ if 'sample' in eval_modes and not is_hf_model:
226
+ print0("\n" + "="*80)
227
+ print0("Model Samples")
228
+ print0("="*80)
229
+ if ddp_rank == 0:
230
+ prompts = [
231
+ "The capital of France is",
232
+ "The chemical symbol of gold is",
233
+ "If yesterday was Friday, then tomorrow will be",
234
+ "The opposite of hot is",
235
+ "The planets of the solar system are:",
236
+ "My favorite color is",
237
+ "If 5*x + 3 = 13, then x is",
238
+ ]
239
+ engine = Engine(model, tokenizer)
240
+ print0("\nConditioned samples:")
241
+ for prompt in prompts:
242
+ tokens = tokenizer(prompt, prepend="<|bos|>")
243
+ sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
244
+ sample_str = tokenizer.decode(sample[0])
245
+ print0("-" * 80)
246
+ print0(sample_str)
247
+ samples.append(sample_str)
248
+
249
+ print0("\nUnconditioned samples:")
250
+ tokens = tokenizer("", prepend="<|bos|>")
251
+ uncond, _ = engine.generate_batch(tokens, num_samples=8, max_tokens=128, temperature=1.0)
252
+ for sample in uncond:
253
+ sample_str = tokenizer.decode(sample)
254
+ print0("-" * 80)
255
+ print0(sample_str)
256
+ unconditioned_samples.append(sample_str)
257
+ elif 'sample' in eval_modes and is_hf_model:
258
+ print0("\nSkipping sampling for HuggingFace models (not supported)")
259
+
260
+ # --- BPB evaluation ---
261
+ if 'bpb' in eval_modes:
262
+ print0("\n" + "="*80)
263
+ print0("BPB Evaluation")
264
+ print0("="*80)
265
+ tokens_per_step = args.device_batch_size * sequence_len * ddp_world_size
266
+ if args.split_tokens % tokens_per_step != 0:
267
+ # Round down to the nearest multiple
268
+ args.split_tokens = (args.split_tokens // tokens_per_step) * tokens_per_step
269
+ print0(f"Adjusted split_tokens to {args.split_tokens} (must be divisible by {tokens_per_step})")
270
+ steps = args.split_tokens // tokens_per_step
271
+
272
+ for split_name in ["train", "val"]:
273
+ loader = tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, sequence_len, split_name, device=device)
274
+ bpb = evaluate_bpb(model, loader, steps, token_bytes)
275
+ bpb_results[split_name] = bpb
276
+ print0(f"{split_name} bpb: {bpb:.6f}")
277
+
278
+ # --- CORE evaluation ---
279
+ if 'core' in eval_modes:
280
+ print0("\n" + "="*80)
281
+ print0("CORE Evaluation")
282
+ print0("="*80)
283
+ core_results = evaluate_core(model, tokenizer, device, max_per_task=args.max_per_task)
284
+
285
+ # Write CSV output
286
+ if ddp_rank == 0:
287
+ base_dir = get_base_dir()
288
+ output_csv_path = os.path.join(base_dir, "base_eval", f"{model_slug}.csv")
289
+ os.makedirs(os.path.dirname(output_csv_path), exist_ok=True)
290
+ with open(output_csv_path, 'w', encoding='utf-8', newline='') as f:
291
+ f.write(f"{'Task':<35}, {'Accuracy':<10}, {'Centered':<10}\n")
292
+ for label in core_results["results"]:
293
+ acc = core_results["results"][label]
294
+ centered = core_results["centered_results"][label]
295
+ f.write(f"{label:<35}, {acc:<10.6f}, {centered:<10.6f}\n")
296
+ f.write(f"{'CORE':<35}, {'':<10}, {core_results['core_metric']:<10.6f}\n")
297
+ print0(f"\nResults written to: {output_csv_path}")
298
+ print0(f"CORE metric: {core_results['core_metric']:.4f}")
299
+
300
+ # --- Log to report ---
301
+ from nanochat.report import get_report
302
+ report_data = [{"model": model_name}]
303
+
304
+ if core_results:
305
+ report_data[0]["CORE metric"] = core_results["core_metric"]
306
+ report_data.append(core_results["centered_results"])
307
+
308
+ if bpb_results:
309
+ report_data[0]["train bpb"] = bpb_results.get("train")
310
+ report_data[0]["val bpb"] = bpb_results.get("val")
311
+
312
+ if samples:
313
+ report_data.append({f"sample {i}": s for i, s in enumerate(samples)})
314
+ if unconditioned_samples:
315
+ report_data.append({f"unconditioned {i}": s for i, s in enumerate(unconditioned_samples)})
316
+
317
+ get_report().log(section="Base model evaluation", data=report_data)
318
+
319
+ compute_cleanup()
320
+
321
+
322
+ if __name__ == "__main__":
323
+ main()
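The centering step in evaluate_core rescales raw accuracy so that the random-guessing baseline maps to 0 and perfect accuracy maps to 1; the CORE metric is then the mean of these centered scores. A small self-contained sketch of just that arithmetic (the task names and numbers below are illustrative, not real results):

def center(accuracy: float, random_baseline_pct: float) -> float:
    # Same formula as evaluate_core; the baseline is given in percent in eval_meta_data.csv.
    rb = 0.01 * random_baseline_pct
    return (accuracy - rb) / (1.0 - rb)

# A 4-way multiple-choice task has a 25% random baseline:
# accuracy 0.25 -> centered 0.0, accuracy 0.625 -> centered 0.5, accuracy 1.0 -> centered 1.0
results = {"task_a": center(0.70, 25.0), "task_b": center(0.55, 50.0)}
core_metric = sum(results.values()) / len(results)
print(results, core_metric)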
scripts/base_train.py ADDED
@@ -0,0 +1,629 @@
+ """
2
+ Train the base model. From the root directory of the project, run as:
3
+
4
+ python -m scripts.base_train
5
+
6
+ or distributed as:
7
+
8
+ torchrun --nproc_per_node=8 -m scripts.base_train
9
+
10
+ If you are only on a CPU/MacBook, you'll want to train a much smaller LLM. Example:
11
+ python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20
12
+ """
13
+
14
+ import os
15
+ os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
16
+ import gc
17
+ import json
18
+ import time
19
+ import math
20
+ import argparse
21
+ from dataclasses import asdict
22
+ from contextlib import contextmanager
23
+
24
+ import wandb
25
+ import torch
26
+ import torch.distributed as dist
27
+
28
+ from nanochat.gpt import GPT, GPTConfig, Linear
29
+ from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
30
+ from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized
31
+ from nanochat.tokenizer import get_tokenizer, get_token_bytes
32
+ from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
33
+ from nanochat.loss_eval import evaluate_bpb
34
+ from nanochat.engine import Engine
35
+ from nanochat.flash_attention import HAS_FA3
36
+ from scripts.base_eval import evaluate_core
37
+ print_banner()
38
+
39
+ # -----------------------------------------------------------------------------
40
+ # CLI arguments
41
+ parser = argparse.ArgumentParser(description="Pretrain base model")
42
+ # Logging
43
+ parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
44
+ # Runtime
45
+ parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
46
+ # FP8 training
47
+ parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao)")
48
+ parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)")
49
+ # Model architecture
50
+ parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")
51
+ parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
52
+ parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
53
+ parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
54
+ parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
55
+ # Training horizon (only one used, in order of precedence)
56
+ parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
57
+ parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
58
+ parser.add_argument("--target-param-data-ratio", type=float, default=10.5, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
59
+ # Optimization
60
+ parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size. good number to reduce to 16,8,4,... if you OOM on VRAM.")
61
+ parser.add_argument("--total-batch-size", type=int, default=-1, help="total batch size in tokens. decent numbers are e.g. 524288. (-1 = auto-compute optimal)")
62
+ parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
63
+ parser.add_argument("--unembedding-lr", type=float, default=0.008, help="learning rate for unembedding parameters (Adam)")
64
+ parser.add_argument("--weight-decay", type=float, default=0.28, help="cautious weight decay for the Muon optimizer (for weights)")
65
+ parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
66
+ parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
67
+ parser.add_argument("--warmup-steps", type=int, default=40, help="number of steps for LR warmup")
68
+ parser.add_argument("--warmdown-ratio", type=float, default=0.65, help="ratio of iterations for LR warmdown")
69
+ parser.add_argument("--final-lr-frac", type=float, default=0.05, help="final LR as fraction of initial LR")
70
+ parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)")
71
+ # Evaluation
72
+ parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
73
+ parser.add_argument("--eval-tokens", type=int, default=80*524288, help="number of tokens to evaluate val loss on")
74
+ parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
75
+ parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric")
76
+ parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
77
+ parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
78
+ # Output
79
+ parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name")
80
+ args = parser.parse_args()
81
+ user_config = vars(args).copy() # for logging
82
+ # -----------------------------------------------------------------------------
83
+ # Compute init and wandb logging
84
+
85
+ device_type = autodetect_device_type() if args.device_type == "" else args.device_type
86
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
87
+ master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
88
+ synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
89
+ get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
90
+ if device_type == "cuda":
91
+ gpu_device_name = torch.cuda.get_device_name(0)
92
+ gpu_peak_flops = get_peak_flops(gpu_device_name)
93
+ print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}")
94
+ else:
95
+ gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS
96
+ print0(f"COMPUTE_DTYPE: {COMPUTE_DTYPE} ({COMPUTE_DTYPE_REASON})")
97
+
98
+ # wandb logging init
99
+ use_dummy_wandb = args.run == "dummy" or not master_process
100
+ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)
101
+
102
+ # Flash Attention status
103
+ from nanochat.flash_attention import USE_FA3
104
+ using_fa3 = USE_FA3
105
+ if using_fa3:
106
+ print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
107
+ else:
108
+ print0("!" * 80)
109
+ if HAS_FA3 and COMPUTE_DTYPE != torch.bfloat16:
110
+ print0(f"WARNING: Flash Attention 3 only supports bf16, but COMPUTE_DTYPE={COMPUTE_DTYPE}. Using PyTorch SDPA fallback")
111
+ else:
112
+ print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback")
113
+ print0("WARNING: Training will be less efficient without FA3")
114
+ if args.window_pattern != "L":
115
+ print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.")
116
+ print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.")
117
+ print0("!" * 80)
118
+
119
+ # -----------------------------------------------------------------------------
120
+ # Tokenizer will be useful for evaluation and also we need the vocab size to init the model
121
+ tokenizer = get_tokenizer()
122
+ token_bytes = get_token_bytes(device=device)
123
+ vocab_size = tokenizer.get_vocab_size()
124
+ print0(f"Vocab size: {vocab_size:,}")
125
+
126
+ # -----------------------------------------------------------------------------
127
+ # Initialize the Model
128
+
129
+ def build_model_meta(depth):
130
+ """Build a model on meta device for a given depth (shapes/dtypes only, no data)."""
131
+ # Model dim is nudged up to nearest multiple of head_dim for clean division
132
+ # (FA3 requires head_dim divisible by 8, and this guarantees head_dim == args.head_dim exactly)
133
+ base_dim = depth * args.aspect_ratio
134
+ model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim
135
+ num_heads = model_dim // args.head_dim
136
+ config = GPTConfig(
137
+ sequence_len=args.max_seq_len, vocab_size=vocab_size,
138
+ n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
139
+ window_pattern=args.window_pattern,
140
+ )
141
+ with torch.device("meta"):
142
+ model_meta = GPT(config)
143
+ return model_meta
144
+
145
+ # Build the model, move to device, init the weights
146
+ model = build_model_meta(args.depth) # 1) Build on meta device (only shapes/dtypes, no data)
147
+ model_config = model.config
148
+ model_config_kwargs = asdict(model_config)
149
+ print0(f"Model config:\n{json.dumps(model_config_kwargs, indent=2)}")
150
+ model.to_empty(device=device) # 2) All tensors get storage on target device but with uninitialized (garbage) data
151
+ model.init_weights() # 3) All tensors get initialized
152
+
153
+ # If we are resuming, overwrite the model parameters with those of the checkpoint
154
+ base_dir = get_base_dir()
155
+ output_dirname = args.model_tag if args.model_tag else f"d{args.depth}" # e.g. d12
156
+ checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
157
+ resuming = args.resume_from_step != -1
158
+ if resuming:
159
+ print0(f"Resuming optimization from step {args.resume_from_step}")
160
+ model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, args.resume_from_step, device, load_optimizer=True, rank=ddp_rank)
161
+ model.load_state_dict(model_data, strict=True, assign=True)
162
+ del model_data # free up this memory after the copy
163
+
164
+ # -----------------------------------------------------------------------------
165
+ # FP8 training initialization and management (this has to be done before torch.compile)
166
+
167
+ # Convert Linear layers to Float8Linear if --fp8 is set
168
+ if args.fp8:
169
+ if device_type != "cuda":
170
+ print0("Warning: FP8 training requires CUDA, ignoring --fp8 flag")
171
+ else:
172
+ # our custom fp8 is simpler than torchao, written for exact API compatibility
173
+ from nanochat.fp8 import Float8LinearConfig, convert_to_float8_training
174
+ # from torchao.float8 import Float8LinearConfig, convert_to_float8_training
175
+ import torch.nn as nn
176
+
177
+ # Filter: dims must be divisible by 16 (FP8 hardware requirement) and large enough
178
+ def fp8_module_filter(mod: nn.Module, fqn: str) -> bool:
179
+ if not isinstance(mod, nn.Linear):
180
+ return False
181
+ if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
182
+ return False
183
+ if min(mod.in_features, mod.out_features) < 128:
184
+ return False
185
+ return True
186
+
187
+ fp8_config = Float8LinearConfig.from_recipe_name(args.fp8_recipe)
188
+ num_linear = sum(1 for m in model.modules() if isinstance(m, nn.Linear))
189
+ convert_to_float8_training(model, config=fp8_config, module_filter_fn=fp8_module_filter)
190
+ num_fp8 = sum(1 for m in model.modules() if 'Float8' in type(m).__name__)
191
+ num_skipped = num_linear - num_fp8
192
+ print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8}/{num_linear} linear layers, skipped {num_skipped} (too small)")
193
+
194
+ # Context manager to temporarily disable FP8 so that model evaluation remains in BF16
195
+ @contextmanager
196
+ def disable_fp8(model):
197
+ """Temporarily swap Float8Linear modules with nn.Linear for BF16 evaluation.
198
+
199
+ CastConfig is a frozen dataclass, so we can't mutate scaling_type. Instead,
200
+ we swap out Float8Linear modules entirely and restore them after.
201
+ """
202
+ import torch.nn as nn
203
+
204
+ # Find all Float8Linear modules and their locations
205
+ fp8_locations = [] # list of (parent_module, attr_name, fp8_module)
206
+ for name, module in model.named_modules():
207
+ if 'Float8' in type(module).__name__:
208
+ if '.' in name:
209
+ parent_name, attr_name = name.rsplit('.', 1)
210
+ parent = model.get_submodule(parent_name)
211
+ else:
212
+ parent = model
213
+ attr_name = name
214
+ fp8_locations.append((parent, attr_name, module))
215
+
216
+ if not fp8_locations:
217
+ yield # No FP8 modules, nothing to do
218
+ return
219
+
220
+ # Swap Float8Linear -> Linear (our custom class that casts weights to match input dtype)
221
+ for parent, attr_name, fp8_module in fp8_locations:
222
+ linear = Linear(
223
+ fp8_module.in_features,
224
+ fp8_module.out_features,
225
+ bias=fp8_module.bias is not None,
226
+ device=fp8_module.weight.device,
227
+ dtype=fp8_module.weight.dtype,
228
+ )
229
+ linear.weight = fp8_module.weight # share, don't copy
230
+ if fp8_module.bias is not None:
231
+ linear.bias = fp8_module.bias
232
+ setattr(parent, attr_name, linear)
233
+
234
+ try:
235
+ yield
236
+ finally:
237
+ # Restore Float8Linear modules
238
+ for parent, attr_name, fp8_module in fp8_locations:
239
+ setattr(parent, attr_name, fp8_module)
240
+
241
+ # -----------------------------------------------------------------------------
242
+ # Compile the model
243
+
244
+ orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the input shapes may change)
245
+ model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
246
+
247
+ # -----------------------------------------------------------------------------
248
+ # Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay.
249
+
250
+ # Get the parameter counts of our model
251
+ param_counts = model.num_scaling_params()
252
+ print0(f"Parameter counts:")
253
+ for key, value in param_counts.items():
254
+ print0(f"{key:24s}: {value:,}")
255
+ num_params = param_counts['total']
256
+ num_flops_per_token = model.estimate_flops()
257
+ print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")
258
+
259
+ # 1) Use scaling laws to determine the optimal training horizon in tokens
260
+ # The compute-optimal models satisfy the Tokens:Params ratio of --target-param-data-ratio (derived experimentally via scaling laws analysis).
261
+ # We've already initialized the model so we have Params. Optimal Tokens is now simply target-param-data-ratio * Params
262
+ def get_scaling_params(m):
263
+ # As for which params to use exactly, transformer matrices + lm_head gives cleanest scaling laws (see dev/LOG.md Jan 27, 2026)
264
+ params_counts = m.num_scaling_params()
265
+ scaling_params = params_counts['transformer_matrices'] + params_counts['lm_head']
266
+ return scaling_params
267
+ num_scaling_params = get_scaling_params(model)
268
+ target_tokens = int(args.target_param_data_ratio * num_scaling_params) # optimal tokens for the model we are about to train
269
+
270
+ # Our reference model is d12, this is where a lot of hyperparameters are tuned and then transfered to higher depths (muP style)
271
+ d12_ref = build_model_meta(12) # creates the model on meta device
272
+ D_REF = args.target_param_data_ratio * get_scaling_params(d12_ref) # compute-optimal d12 training horizon in tokens (measured empirically)
273
+ B_REF = 2**19 # optimal batch size at d12 ~= 524,288 tokens (measured empirically)
274
+
275
+ # 2) Now that we have the token horizon, we can calculate the optimal batch size
276
+ # We follow the Power Lines paper (Bopt ∝ D^0.383), ref: https://arxiv.org/abs/2505.13738
277
+ # The optimal batch size grows as approximately D^0.383, so e.g. if D doubles from d12 to d24, B should grow by 2^0.383 ≈ 1.3x.
278
+ total_batch_size = args.total_batch_size # user-provided override is possible
279
+ if total_batch_size == -1:
280
+ batch_size_ratio = target_tokens / D_REF
281
+ predicted_batch_size = B_REF * batch_size_ratio ** 0.383
282
+ total_batch_size = 2 ** round(math.log2(predicted_batch_size)) # clamp to nearest power of 2 for efficiency
283
+ print0(f"Auto-computed optimal batch size: {total_batch_size:,} tokens")
284
+
285
+ # 3) Knowing the batch size, we can now calculate a learning rate correction (bigger batch size allows higher learning rates)
286
+ batch_lr_scale = 1.0
287
+ batch_ratio = total_batch_size / B_REF # B/B_ref
288
+ if batch_ratio != 1.0:
289
+ # SGD: linear scaling with batch size is standard (not used in nanochat)
290
+ # AdamW: sqrt scaling is standard: η ∝ √(B/B_ref)
291
+ # Muon: we will use the same scaling for Muon as for AdamW: η ∝ √(B/B_ref) (not studied carefully, assumption!)
292
+ batch_lr_scale = batch_ratio ** 0.5 # η ∝ √(B/B_ref)
293
+ print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {total_batch_size:,} (reference: {B_REF:,})")
294
+
295
+ # 4) Knowing the batch size and the token horizon, we can now calculate the appropriate weight decay scaling
296
+ # We adopt the T_epoch framework from https://arxiv.org/abs/2405.13698
297
+ # Central idea of the paper is that T_epoch = B/(η·λ·D) should remain constant.
298
+ # Above, we used learning rate scaling η ∝ √(B/B_ref). So it's a matter of ~10 lines of math to derive that to keep T_epoch constant, we need:
299
+ # λ = λ_ref · √(B/B_ref) · (D_ref/D)
300
+ # Note that these papers study AdamW, *not* Muon. We are blindly following AdamW theory for scaling hoping it ~works for Muon too.
301
+ weight_decay_scaled = args.weight_decay * math.sqrt(total_batch_size / B_REF) * (D_REF / target_tokens)
302
+ if weight_decay_scaled != args.weight_decay:
303
+ print0(f"Scaling weight decay from {args.weight_decay:.6f} to {weight_decay_scaled:.6f} for depth {args.depth}")
304
+
305
+ # -----------------------------------------------------------------------------
306
+ # Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
307
+ optimizer = model.setup_optimizer(
308
+ # AdamW hyperparameters
309
+ unembedding_lr=args.unembedding_lr * batch_lr_scale,
310
+ embedding_lr=args.embedding_lr * batch_lr_scale,
311
+ scalar_lr=args.scalar_lr * batch_lr_scale,
312
+ # Muon hyperparameters
313
+ matrix_lr=args.matrix_lr * batch_lr_scale,
314
+ weight_decay=weight_decay_scaled,
315
+ )
316
+
317
+ if resuming:
318
+ optimizer.load_state_dict(optimizer_data)
319
+ del optimizer_data
320
+
321
+ # -----------------------------------------------------------------------------
322
+ # GradScaler for fp16 training (bf16/fp32 don't need it — bf16 has the same exponent range as fp32)
323
+ scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None
324
+ if scaler is not None:
325
+ print0("GradScaler enabled for fp16 training")
326
+
327
+ # -----------------------------------------------------------------------------
328
+ # Initialize the DataLoaders for train/val
329
+ dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
330
+ train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
331
+ build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device)
332
+ x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
333
+
334
+ # -----------------------------------------------------------------------------
335
+ # Calculate the number of iterations we will train for and set up the various schedulers
336
+
337
+ # num_iterations: either it is given, or from target flops, or from target data:param ratio (in that order)
338
+ assert args.num_iterations > 0 or args.target_param_data_ratio > 0 or args.target_flops > 0
339
+ if args.num_iterations > 0:
340
+ # Override num_iterations to a specific value if given
341
+ num_iterations = args.num_iterations
342
+ print0(f"Using user-provided number of iterations: {num_iterations:,}")
343
+ elif args.target_flops > 0:
344
+ # Calculate the number of iterations from the target flops (used in scaling laws analysis, e.g. runs/scaling_laws.sh)
345
+ num_iterations = round(args.target_flops / (num_flops_per_token * total_batch_size))
346
+ print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
347
+ elif args.target_param_data_ratio > 0:
348
+ # Calculate the number of iterations from the target param data ratio (the most common use case)
349
+ num_iterations = target_tokens // total_batch_size
350
+ print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
351
+ else:
352
+ raise ValueError("No training horizon specified")
353
+ total_tokens = total_batch_size * num_iterations # the actual number of tokens we will train for
354
+ print0(f"Total number of training tokens: {total_tokens:,}")
355
+ print0(f"Tokens : Scaling params ratio: {total_batch_size * num_iterations / num_scaling_params:.2f}") # e.g. Chinchilla was ~20
356
+ print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
357
+
358
+ # Learning rate schedule (linear warmup, constant, linear warmdown)
359
+ def get_lr_multiplier(it):
360
+ warmup_iters = args.warmup_steps
361
+ warmdown_iters = round(args.warmdown_ratio * num_iterations)
362
+ if it < warmup_iters:
363
+ return (it + 1) / warmup_iters
364
+ elif it <= num_iterations - warmdown_iters:
365
+ return 1.0
366
+ else:
367
+ progress = (num_iterations - it) / warmdown_iters
368
+ return progress * 1.0 + (1 - progress) * args.final_lr_frac
369
+
370
+ # Momentum scheduler for Muon optimizer (warms up to 0.97, warms down to 0.90 during LR warmdown)
371
+ def get_muon_momentum(it):
372
+ warmdown_iters = round(args.warmdown_ratio * num_iterations)
373
+ warmdown_start = num_iterations - warmdown_iters
374
+ if it < 400:
375
+ frac = it / 400
376
+ return (1 - frac) * 0.85 + frac * 0.97
377
+ elif it >= warmdown_start:
378
+ progress = (it - warmdown_start) / warmdown_iters
379
+ return 0.97 * (1 - progress) + 0.90 * progress
380
+ else:
381
+ return 0.97
382
+
383
+ # Weight decay scheduler for Muon optimizer (cosine decay to zero over the course of training)
384
+ def get_weight_decay(it):
385
+ return weight_decay_scaled * 0.5 * (1 + math.cos(math.pi * it / num_iterations))
386
+
387
+ # -----------------------------------------------------------------------------
388
+ # Training loop
389
+
390
+ # Loop state (variables updated by the training loop)
391
+ if not resuming:
392
+ step = 0
393
+ val_bpb = None # will be set if eval_every > 0
394
+ min_val_bpb = float("inf")
395
+ smooth_train_loss = 0 # EMA of training loss
396
+ total_training_time = 0 # total wall-clock time of training
397
+ else:
398
+ step = meta_data["step"]
399
+ loop_state = meta_data["loop_state"]
400
+ val_bpb = meta_data["val_bpb"]
401
+ min_val_bpb = loop_state["min_val_bpb"]
402
+ smooth_train_loss = loop_state["smooth_train_loss"]
403
+ total_training_time = loop_state["total_training_time"]
404
+
405
+ # Figure out the needed gradient accumulation micro-steps to reach the desired total batch size per step
406
+ tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
407
+ world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
408
+ assert total_batch_size % world_tokens_per_fwdbwd == 0
409
+ grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
410
+ print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
411
+ print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
412
+ print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
413
+
414
+ # Go!
415
+ while True:
416
+ last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
417
+ flops_so_far = num_flops_per_token * total_batch_size * step
418
+
419
+ # once in a while: evaluate the val bpb (all ranks participate)
420
+ if args.eval_every > 0 and (last_step or step % args.eval_every == 0):
421
+ model.eval()
422
+ val_loader = build_val_loader()
423
+ eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
424
+ with disable_fp8(model):
425
+ val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
426
+ print0(f"Step {step:05d} | Validation bpb: {val_bpb:.6f}")
427
+ if val_bpb < min_val_bpb:
428
+ min_val_bpb = val_bpb
429
+ wandb_run.log({
430
+ "step": step,
431
+ "total_training_flops": flops_so_far,
432
+ "total_training_time": total_training_time,
433
+ "val/bpb": val_bpb,
434
+ })
435
+ model.train()
436
+
437
+ # once in a while: estimate the CORE metric (all ranks participate)
438
+ # use the original uncompiled model because the inputs keep changing shape
439
+ # disable FP8 for evaluation to use BF16 for more consistent/accurate results
440
+ results = {}
441
+ if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)):
442
+ model.eval()
443
+ with disable_fp8(orig_model):
444
+ results = evaluate_core(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
445
+ print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
446
+ wandb_run.log({
447
+ "step": step,
448
+ "total_training_flops": flops_so_far,
449
+ "core_metric": results["core_metric"],
450
+ "centered_results": results["centered_results"],
451
+ })
452
+ model.train()
453
+
454
+ # once in a while: sample from the model (only on master process)
455
+ # use the original uncompiled model because the inputs keep changing shape
456
+ if args.sample_every > 0 and master_process and (last_step or (step > 0 and step % args.sample_every == 0)):
457
+ model.eval()
458
+ prompts = [
459
+ "The capital of France is",
460
+ "The chemical symbol of gold is",
461
+ "If yesterday was Friday, then tomorrow will be",
462
+ "The opposite of hot is",
463
+ "The planets of the solar system are:",
464
+ "My favorite color is",
465
+ "If 5*x + 3 = 13, then x is",
466
+ ]
467
+ engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation
468
+ for prompt in prompts:
469
+ tokens = tokenizer(prompt, prepend="<|bos|>")
470
+ with disable_fp8(orig_model):
471
+ sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
472
+ print0(tokenizer.decode(sample[0]))
473
+ model.train()
474
+
475
+ # save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
476
+ if last_step or (step > 0 and step != args.resume_from_step and args.save_every > 0 and step % args.save_every == 0):
477
+ save_checkpoint(
478
+ checkpoint_dir,
479
+ step,
480
+ orig_model.state_dict(), # model parameters
481
+ optimizer.state_dict(), # optimizer state
482
+ { # metadata saved as json
483
+ "step": step,
484
+ "val_bpb": val_bpb, # loss at last step
485
+ "model_config": model_config_kwargs,
486
+ "user_config": user_config, # inputs to the training script
487
+ "device_batch_size": args.device_batch_size,
488
+ "max_seq_len": args.max_seq_len,
489
+ "total_batch_size": total_batch_size,
490
+ "dataloader_state_dict": dataloader_state_dict,
491
+ "loop_state": { # all loop state (other than step) so that we can resume training
492
+ "min_val_bpb": min_val_bpb,
493
+ "smooth_train_loss": smooth_train_loss,
494
+ "total_training_time": total_training_time,
495
+ },
496
+ },
497
+ rank=ddp_rank,
498
+ )
499
+
500
+ # termination conditions (TODO: possibly also add loss explosions etc.)
501
+ if last_step:
502
+ break
503
+
504
+ # -------------------------------------------------------------------------
505
+ # single training step
506
+ # evaluate the gradient
507
+ synchronize()
508
+ t0 = time.time()
509
+ for micro_step in range(grad_accum_steps):
510
+ loss = model(x, y)
511
+ train_loss = loss.detach() # for logging
512
+ loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
513
+ if scaler is not None:
514
+ scaler.scale(loss).backward()
515
+ else:
516
+ loss.backward()
517
+ x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
518
+ # step the optimizer
519
+ lrm = get_lr_multiplier(step)
520
+ muon_momentum = get_muon_momentum(step)
521
+ muon_weight_decay = get_weight_decay(step)
522
+ for group in optimizer.param_groups:
523
+ group["lr"] = group["initial_lr"] * lrm
524
+ if group['kind'] == 'muon':
525
+ group["momentum"] = muon_momentum
526
+ group["weight_decay"] = muon_weight_decay
527
+ if scaler is not None:
528
+ scaler.unscale_(optimizer)
529
+ # In distributed training, all ranks must agree on whether to skip the step.
530
+ # Each rank may independently encounter inf/nan gradients, so we all-reduce
531
+ # the found_inf flag (MAX = if any rank found inf, all ranks skip).
532
+ if is_ddp_initialized():
533
+ for v in scaler._found_inf_per_device(optimizer).values():
534
+ dist.all_reduce(v, op=dist.ReduceOp.MAX)
535
+ scaler.step(optimizer)
536
+ scaler.update()
537
+ else:
538
+ optimizer.step()
539
+ model.zero_grad(set_to_none=True)
540
+ train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point
541
+ synchronize()
542
+ t1 = time.time()
543
+ dt = t1 - t0
544
+ # -------------------------------------------------------------------------
545
+
546
+ # logging (CPU action only)
547
+ ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging
548
+ smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f # EMA the training loss
549
+ debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
550
+ pct_done = 100 * step / num_iterations
551
+ tok_per_sec = int(total_batch_size / dt)
552
+ flops_per_sec = num_flops_per_token * total_batch_size / dt
553
+ mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
554
+ if step > 10:
555
+ total_training_time += dt # only count the time after the first 10 steps
556
+ # Calculate ETA based on average time per step (excluding first 10 steps)
557
+ steps_done = step - 10
558
+ if steps_done > 0:
559
+ avg_time_per_step = total_training_time / steps_done
560
+ remaining_steps = num_iterations - step
561
+ eta_seconds = remaining_steps * avg_time_per_step
562
+ eta_str = f" | eta: {eta_seconds/60:.1f}m"
563
+ else:
564
+ eta_str = ""
565
+ epoch = f"{dataloader_state_dict['epoch']} pq: {dataloader_state_dict['pq_idx']} rg: {dataloader_state_dict['rg_idx']}"
566
+ print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | bf16_mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
567
+ if step % 100 == 0:
568
+ log_data = {
569
+ "step": step,
570
+ "total_training_flops": flops_so_far,
571
+ "total_training_time": total_training_time,
572
+ "train/loss": debiased_smooth_loss,
573
+ "train/lrm": lrm,
574
+ "train/dt": dt,
575
+ "train/tok_per_sec": tok_per_sec,
576
+ "train/mfu": mfu,
577
+ "train/epoch": epoch,
578
+ }
579
+ wandb_run.log(log_data)
580
+
581
+ # state update
582
+ first_step_of_run = (step == 0) or (resuming and step == args.resume_from_step)
583
+ step += 1
584
+
585
+ # The garbage collector is sadly a little bit overactive and for some poorly understood reason,
586
+ # it spends ~500ms scanning for cycles quite frequently, just to end up cleaning up very few tiny objects each time.
587
+ # So we manually manage and help it out here
588
+ if first_step_of_run:
589
+ gc.collect() # manually collect a lot of garbage from setup
590
+ gc.freeze() # immediately freeze all currently surviving objects and exclude them from GC
591
+ gc.disable() # nuclear intervention here: disable GC entirely except:
592
+ elif step % 5000 == 0: # every 5000 steps...
593
+ gc.collect() # manually collect, just to be safe for very, very long runs
594
+
595
+ # print a few more stats
596
+ print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
597
+ print0(f"Total training time: {total_training_time/60:.2f}m")
598
+ if val_bpb is not None:
599
+ print0(f"Minimum validation bpb: {min_val_bpb:.6f}")
600
+
601
+ # Log to report
602
+ from nanochat.report import get_report
603
+ get_report().log(section="Base model training", data=[
604
+ user_config, # CLI args
605
+ { # stats about the training setup
606
+ "Number of parameters": num_params,
607
+ "Number of FLOPs per token": f"{num_flops_per_token:e}",
608
+ "Calculated number of iterations": num_iterations,
609
+ "Number of training tokens": total_tokens,
610
+ "Tokens : Scaling params ratio": total_batch_size * num_iterations / num_scaling_params,
611
+ "DDP world size": ddp_world_size,
612
+ "warmup_steps": args.warmup_steps,
613
+ "warmdown_ratio": args.warmdown_ratio,
614
+ "final_lr_frac": args.final_lr_frac,
615
+ },
616
+ { # stats about training outcomes
617
+ "Minimum validation bpb": min_val_bpb if val_bpb is not None else None,
618
+ "Final validation bpb": val_bpb,
619
+ "CORE metric estimate": results.get("core_metric", None),
620
+ "MFU %": f"{mfu:.2f}%",
621
+ "Total training flops": f"{flops_so_far:e}",
622
+ "Total training time": f"{total_training_time/60:.2f}m",
623
+ "Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
624
+ }
625
+ ])
626
+
627
+ # cleanup
628
+ wandb_run.finish() # wandb run finish
629
+ compute_cleanup()
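The hyperparameter extrapolations in base_train.py (training horizon from the data:param ratio, batch size via Bopt ∝ D^0.383 from the Power Lines paper, η ∝ √(B/B_ref) learning-rate scaling, and the T_epoch-constant weight-decay rule λ = λ_ref·√(B/B_ref)·(D_ref/D)) boil down to a few lines of arithmetic. A standalone sketch with illustrative numbers (the parameter counts are made up; only the formulas mirror the script):

import math

RATIO = 10.5          # --target-param-data-ratio
B_REF = 2**19         # reference batch size at d12, in tokens
WD_REF = 0.28         # --weight-decay, tuned at the reference scale

def derive_hparams(scaling_params: int, d12_scaling_params: int):
    target_tokens = int(RATIO * scaling_params)                   # training horizon D
    d_ref = RATIO * d12_scaling_params                            # reference horizon D_ref
    b_opt = B_REF * (target_tokens / d_ref) ** 0.383              # Bopt ∝ D^0.383
    total_batch = 2 ** round(math.log2(b_opt))                    # clamp to the nearest power of 2
    lr_scale = (total_batch / B_REF) ** 0.5                       # η ∝ √(B/B_ref)
    weight_decay = WD_REF * lr_scale * (d_ref / target_tokens)    # λ = λ_ref·√(B/B_ref)·(D_ref/D)
    return target_tokens, total_batch, lr_scale, weight_decay

# e.g. a hypothetical d24 model with ~4x the scaling params of d12:
print(derive_hparams(scaling_params=560_000_000, d12_scaling_params=140_000_000))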
scripts/chat_cli.py ADDED
@@ -0,0 +1,100 @@
+ """
2
+ New and upgraded chat mode because a lot of the code has changed since the last one.
3
+
4
+ Intended to be run on a single GPU only for now:
5
+ python -m scripts.chat_cli
6
+ """
7
+ import argparse
8
+ import torch
9
+ from nanochat.common import compute_init, autodetect_device_type
10
+ from nanochat.engine import Engine
11
+ from nanochat.checkpoint_manager import load_model
12
+
13
+ parser = argparse.ArgumentParser(description='Chat with the model')
14
+ parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|rl")
15
+ parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
16
+ parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
17
+ parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back')
18
+ parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation')
19
+ parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
20
+ parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
21
+ args = parser.parse_args()
22
+
23
+ # Init the model and tokenizer
24
+
25
+ device_type = autodetect_device_type() if args.device_type == "" else args.device_type
26
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
27
+ model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
28
+
29
+ # Special tokens for the chat state machine
30
+ bos = tokenizer.get_bos_token_id()
31
+ user_start, user_end = tokenizer.encode_special("<|user_start|>"), tokenizer.encode_special("<|user_end|>")
32
+ assistant_start, assistant_end = tokenizer.encode_special("<|assistant_start|>"), tokenizer.encode_special("<|assistant_end|>")
33
+
34
+ # Create Engine for efficient generation
35
+ engine = Engine(model, tokenizer)
36
+
37
+ print("\nNanoChat Interactive Mode")
38
+ print("-" * 50)
39
+ print("Type 'quit' or 'exit' to end the conversation")
40
+ print("Type 'clear' to start a new conversation")
41
+ print("-" * 50)
42
+
43
+ conversation_tokens = [bos]
44
+
45
+ while True:
46
+
47
+ if args.prompt:
48
+ # Get the prompt from the launch command
49
+ user_input = args.prompt
50
+ else:
51
+ # Get the prompt interactively from the console
52
+ try:
53
+ user_input = input("\nUser: ").strip()
54
+ except (EOFError, KeyboardInterrupt):
55
+ print("\nGoodbye!")
56
+ break
57
+
58
+ # Handle special commands
59
+ if user_input.lower() in ['quit', 'exit']:
60
+ print("Goodbye!")
61
+ break
62
+
63
+ if user_input.lower() == 'clear':
64
+ conversation_tokens = [bos]
65
+ print("Conversation cleared.")
66
+ continue
67
+
68
+ if not user_input:
69
+ continue
70
+
71
+ # Add User message to the conversation
72
+ conversation_tokens.append(user_start)
73
+ conversation_tokens.extend(tokenizer.encode(user_input))
74
+ conversation_tokens.append(user_end)
75
+
76
+ # Kick off the assistant
77
+ conversation_tokens.append(assistant_start)
78
+ generate_kwargs = {
79
+ "num_samples": 1,
80
+ "max_tokens": 256,
81
+ "temperature": args.temperature,
82
+ "top_k": args.top_k,
83
+ }
84
+ response_tokens = []
85
+ print("\nAssistant: ", end="", flush=True)
86
+ for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs):
87
+ token = token_column[0] # pop the batch dimension (num_samples=1)
88
+ response_tokens.append(token)
89
+ token_text = tokenizer.decode([token])
90
+ print(token_text, end="", flush=True)
91
+ print()
92
+ # we have to ensure that the assistant end token is the last token
93
+ # so even if generation ends due to max tokens, we have to append it to the end
94
+ if response_tokens[-1] != assistant_end:
95
+ response_tokens.append(assistant_end)
96
+ conversation_tokens.extend(response_tokens)
97
+
98
+ # In the prompt mode, we only want a single response and exit
99
+ if args.prompt:
100
+ break
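The conversation state in chat_cli.py is just a flat token list framed by special tokens: <|bos|> once at the start, each user turn wrapped in <|user_start|>...<|user_end|>, and each assistant turn opened with <|assistant_start|> and closed with <|assistant_end|> (appended manually if generation stops at max tokens). A minimal sketch of that rendering for a whole conversation, assuming the same tokenizer interface used above (the helper itself is illustrative, not part of the repo):

def render_conversation(tokenizer, turns):
    """turns: list of (role, text) pairs; returns the flat token list fed to Engine.generate."""
    bos = tokenizer.get_bos_token_id()
    user_start = tokenizer.encode_special("<|user_start|>")
    user_end = tokenizer.encode_special("<|user_end|>")
    assistant_start = tokenizer.encode_special("<|assistant_start|>")
    assistant_end = tokenizer.encode_special("<|assistant_end|>")
    tokens = [bos]
    for role, text in turns:
        if role == "user":
            tokens += [user_start, *tokenizer.encode(text), user_end]
        else:
            tokens += [assistant_start, *tokenizer.encode(text), assistant_end]
    # To prompt the model for the next reply, append assistant_start and start generating.
    return tokens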
scripts/chat_eval.py ADDED
@@ -0,0 +1,251 @@
+ """
2
+ Evaluate the Chat model.
3
+ All the generic code lives here, and all the evaluation-specific
4
+ code lives in the nanochat directory and is imported from here.
5
+
6
+ Example runs:
7
+ python -m scripts.chat_eval -a ARC-Easy
8
+ torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy
9
+ """
10
+
11
+ import argparse
12
+ from functools import partial
13
+ import torch
14
+ import torch.distributed as dist
15
+
16
+ from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0, autodetect_device_type
17
+ from nanochat.checkpoint_manager import load_model
18
+ from nanochat.engine import Engine
19
+
20
+ from tasks.humaneval import HumanEval
21
+ from tasks.mmlu import MMLU
22
+ from tasks.arc import ARC
23
+ from tasks.gsm8k import GSM8K
24
+ from tasks.spellingbee import SpellingBee
25
+
26
+ # -----------------------------------------------------------------------------
27
+ # Generative evaluation loop (we go one problem at a time, sample, evaluate)
28
+
29
+ def run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_new_tokens, temperature, top_k, max_problems=None):
30
+
31
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
32
+ device = model.get_device()
33
+
34
+ num_problems = len(task_object) if max_problems is None else min(len(task_object), max_problems)
35
+
36
+ # Run the evaluation
37
+ num_passed, total = 0, 0
38
+ for i in range(ddp_rank, num_problems, ddp_world_size):
39
+ conversation = task_object[i]
40
+
41
+ # Tokenize the prompt
42
+ encoded_prompt = tokenizer.render_for_completion(conversation)
43
+ # Get the completions
44
+ results, _ = engine.generate_batch(
45
+ encoded_prompt,
46
+ num_samples=num_samples,
47
+ max_tokens=max_new_tokens,
48
+ temperature=temperature,
49
+ top_k=top_k,
50
+ )
51
+ # Decode the completions as text
52
+ prefix_length = len(encoded_prompt)
53
+ completions = [tokenizer.decode(result_tokens[prefix_length:]) for result_tokens in results]
54
+ # Evaluate success criteria
55
+ outcomes = [task_object.evaluate(conversation, completion) for completion in completions]
56
+ passed = any(outcomes)
57
+
58
+ # Keep stats
59
+ total += 1
60
+ num_passed += int(passed)
61
+
62
+ # Logging (overwrite the same line in the console)
63
+ print(f"\r\033[KRank {ddp_rank} | {num_passed}/{total} ({100*num_passed/total:.2f}%)", end='', flush=True)
64
+
65
+ # Finish the in-place progress line with a newline before final summary
66
+ print()
67
+
68
+ # Aggregate results across all ranks
69
+ if ddp:
70
+ num_passed_tensor = torch.tensor([num_passed], dtype=torch.long, device=device)
71
+ total_tensor = torch.tensor([total], dtype=torch.long, device=device)
72
+ dist.all_reduce(num_passed_tensor, op=dist.ReduceOp.SUM)
73
+ dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM)
74
+ num_passed = num_passed_tensor.item()
75
+ total = total_tensor.item()
76
+
77
+ print0("=" * 50)
78
+ print0(f"Final: {num_passed}/{total} ({100*num_passed/total:.2f}%)")
79
+
80
+ # Return the accuracy
81
+ return num_passed/total
82
+
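The loop above (and the categorical loop below) splits problems across ranks with a strided range and then sums the per-rank counters, so each problem is evaluated exactly once. A plain-Python sketch of that sharding, with an invented world size in place of the real torch.distributed all_reduce:

```python
# Sketch of the rank-striding used by the eval loops; values are invented.
# In the real code each rank keeps its own counters and dist.all_reduce sums them.
ddp_world_size = 4
num_problems = 10

per_rank = {rank: list(range(rank, num_problems, ddp_world_size)) for rank in range(ddp_world_size)}
total = sum(len(indices) for indices in per_rank.values())  # what the SUM all_reduce recovers
print(per_rank)  # {0: [0, 4, 8], 1: [1, 5, 9], 2: [2, 6], 3: [3, 7]}
print(total)     # 10, every problem counted exactly once
```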
83
+ # -----------------------------------------------------------------------------
84
+ # Categorical evaluation loop
85
+ # A lot easier because we don't have to sample. Therefore, we can actually go
86
+ # batches at a time and just check the logits for correct answer choices.
87
+
88
+ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=None):
89
+
90
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
91
+ device = model.get_device()
92
+ bos = tokenizer.get_bos_token_id() # using BOS as the pad token is fine, these positions are ignored
93
+
94
+ # We'll process batches of independent problems at a time because there is no sampling needed
95
+ num_problems = len(task_object) if max_problems is None else min(len(task_object), max_problems)
96
+ ceil_div = lambda x, y: -(-x // y)
97
+ num_batches = ceil_div(num_problems, batch_size)
98
+
99
+ # Run the evaluation
100
+ letter_to_id_cache = {} # many letters will repeat often, let's save the tokenizer some work
101
+ num_passed, total = 0, 0
102
+ for i in range(ddp_rank, num_batches, ddp_world_size):
103
+ i0, i1 = i * batch_size, min((i + 1) * batch_size, num_problems)
104
+
105
+ # Prepare the batch of problems. They might all be of different length, so we pad/collate them.
106
+ conversations = [task_object[ii] for ii in range(i0, i1)]
107
+ prompt_ids = [tokenizer.render_for_completion(conversation) for conversation in conversations] # TODO: remake the way this works
108
+ max_length = max(len(ids) for ids in prompt_ids)
109
+ answer_time_positions = [len(ids) - 1 for ids in prompt_ids] # position of the last prompt token, where the answer is predicted
110
+ padded_prompt_ids = [ids + [bos] * (max_length - len(ids)) for ids in prompt_ids]
111
+ prompt_ids = torch.tensor(padded_prompt_ids, dtype=torch.long, device=device)
112
+
113
+ # Get the logits for the whole batch of conversations in parallel (efficiency win here)
114
+ with torch.no_grad():
115
+ logits = model(prompt_ids) # (B, T, V)
116
+
117
+ # Focus on just the letters corresponding to the available answer choices
118
+ # Note that this helps the evaluation a lot because it specifically narrows the focus to only the available letters
119
+ # The much harder alternative would be to just generate from the Assistant and check if it responded with the correct
120
+ # letter (e.g. A, B, C, D), but evaluations typically make the task easier in this way.
121
+ for idx, conversation in enumerate(conversations):
122
+ # get the token ids of all the available letters of this problem
123
+ letters = conversation['letters']
124
+ letter_ids = []
125
+ for letter in letters:
126
+ if letter not in letter_to_id_cache:
127
+ encoded_letter = tokenizer.encode(letter)
128
+ assert len(encoded_letter) == 1, "Each letter must be a single token"
129
+ letter_to_id_cache[letter] = encoded_letter[0]
130
+ letter_ids.append(letter_to_id_cache[letter])
131
+ # focus logits just down to the answer position and the available letters of the answer
132
+ answer_pos = answer_time_positions[idx]
133
+ focus_logits = logits[idx, answer_pos, letter_ids]
134
+ # get the argmax letter (the predicted answer)
135
+ argmax_letter_id = focus_logits.argmax(dim=-1).item()
136
+ predicted_letter = letters[argmax_letter_id]
137
+ # evaluate the outcome
138
+ outcome = task_object.evaluate(conversation, predicted_letter)
139
+ num_passed += int(outcome)
140
+ total += 1
141
+
142
+ # Aggregate results across all ranks
143
+ if ddp:
144
+ num_passed_tensor = torch.tensor([num_passed], dtype=torch.long, device=device)
145
+ total_tensor = torch.tensor([total], dtype=torch.long, device=device)
146
+ dist.all_reduce(num_passed_tensor, op=dist.ReduceOp.SUM)
147
+ dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM)
148
+ num_passed = num_passed_tensor.item()
149
+ total = total_tensor.item()
150
+
151
+ average = num_passed/total
152
+ print0(f"Final: {num_passed}/{total} ({100*average:.2f}%)")
153
+ return average
154
+
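To make the selection step above concrete, here is a self-contained sketch: the logits at the answer position are restricted to the token ids of the available letters, and the argmax over that slice picks the predicted choice. All tensor values and ids below are invented for illustration.

```python
# Sketch of the multiple-choice letter selection in run_categorical_eval above.
# Logit values and letter token ids are invented; the real ones come from the
# model forward pass and the tokenizer respectively.
import torch

logits_at_answer_pos = torch.tensor([0.1, 2.0, -1.0, 0.5, 3.2, 0.0, -0.3, 1.1])
letters = ["A", "B", "C", "D"]
letter_ids = [1, 3, 4, 6]                                  # hypothetical single-token ids
focus_logits = logits_at_answer_pos[letter_ids]            # tensor([2.0, 0.5, 3.2, -0.3])
predicted_letter = letters[focus_logits.argmax().item()]
print(predicted_letter)                                    # "C" (id 4 has the highest logit)
```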
155
+ # -----------------------------------------------------------------------------
156
+
157
+ def run_chat_eval(task_name, model, tokenizer, engine,
158
+ batch_size=1, num_samples=1, max_new_tokens=512, temperature=0.0, top_k=50,
159
+ max_problems=None):
160
+ # Create the evaluation object
161
+ task_module = {
162
+ 'HumanEval': HumanEval,
163
+ 'MMLU': partial(MMLU, subset="all", split="test"),
164
+ 'ARC-Easy': partial(ARC, subset="ARC-Easy", split="test"),
165
+ 'ARC-Challenge': partial(ARC, subset="ARC-Challenge", split="test"),
166
+ 'GSM8K': partial(GSM8K, subset="main", split="test"),
167
+ 'SpellingBee': partial(SpellingBee, size=256, split="test"),
168
+ }[task_name]
169
+ task_object = task_module()
170
+ # Run the evaluation
171
+ if task_object.eval_type == 'generative':
172
+ acc = run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_new_tokens, temperature, top_k, max_problems=max_problems)
173
+ elif task_object.eval_type == 'categorical':
174
+ acc = run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=max_problems)
175
+ else:
176
+ raise ValueError(f"Unsupported task evaluation type: {task_object.eval_type}")
177
+ return acc
178
+
179
+ # -----------------------------------------------------------------------------
180
+ if __name__ == "__main__":
181
+
182
+ # Parse command-line arguments
183
+ parser = argparse.ArgumentParser()
184
+ parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|rl")
185
+ parser.add_argument('-a', '--task-name', type=str, default=None, help="Task name. Default = all tasks. Use | to split multiple tasks.")
186
+ parser.add_argument('-t', '--temperature', type=float, default=0.0)
187
+ parser.add_argument('-m', '--max-new-tokens', type=int, default=512)
188
+ parser.add_argument('-n', '--num-samples', type=int, default=1)
189
+ parser.add_argument('-k', '--top-k', type=int, default=50)
190
+ parser.add_argument('-b', '--batch-size', type=int, default=8, help='Batch size for categorical evaluation')
191
+ parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
192
+ parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
193
+ parser.add_argument('-x', '--max-problems', type=int, default=None, help='Max problems to evaluate')
194
+ parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
195
+ args = parser.parse_args()
196
+
197
+ device_type = autodetect_device_type() if args.device_type == "" else args.device_type
198
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
199
+
200
+ model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
201
+ engine = Engine(model, tokenizer)
202
+
203
+ # Get the tasks to evaluate on
204
+ all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee']
205
+ baseline_accuracies = {
206
+ 'ARC-Easy': 0.25, # multiple choice 1 of 4 => 25%
207
+ 'ARC-Challenge': 0.25, # multiple choice 1 of 4 => 25%
208
+ 'MMLU': 0.25, # multiple choice 1 of 4 => 25%
209
+ 'GSM8K': 0.0, # open-ended => 0%
210
+ 'HumanEval': 0.0, # open-ended => 0%
211
+ 'SpellingBee': 0.0, # open-ended => 0%
212
+ }
213
+ task_names = all_tasks if args.task_name is None else args.task_name.split('|')
214
+
215
+ # Run all the task evaluations sequentially
216
+ results = {}
217
+ for task_name in task_names:
218
+ acc = run_chat_eval(
219
+ task_name,
220
+ model, tokenizer, engine,
221
+ batch_size=args.batch_size,
222
+ num_samples=args.num_samples,
223
+ max_new_tokens=args.max_new_tokens,
224
+ temperature=args.temperature,
225
+ top_k=args.top_k,
226
+ max_problems=args.max_problems,
227
+ )
228
+ results[task_name] = acc
229
+ print0(f"{task_name} accuracy: {100 * acc:.2f}%")
230
+
231
+ # Log to report
232
+ from nanochat.report import get_report
233
+ all_tasks_were_evaluated = all(task_name in results for task_name in all_tasks)
234
+ # calculate the ChatCORE metric if we can (similar to CORE, it's the mean centered accuracy)
235
+ # this way, ChatCORE ranges from 0 (at random baseline) to 1 (peak performance)
236
+ chatcore_metric_dict = {}
237
+ if all_tasks_were_evaluated:
238
+ centered_mean = 0
239
+ for task_name, acc in results.items():
240
+ baseline_acc = baseline_accuracies.get(task_name, 0.0)
241
+ centered_acc = (acc - baseline_acc) / (1.0 - baseline_acc)
242
+ centered_mean += centered_acc
243
+ chatcore_metric = centered_mean / len(results)
244
+ chatcore_metric_dict = {"ChatCORE metric": chatcore_metric}
245
+ get_report().log(section="Chat evaluation " + args.source, data=[
246
+ vars(args), # CLI args
247
+ results,
248
+ chatcore_metric_dict,
249
+ ])
250
+
251
+ compute_cleanup()
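As a worked example of the ChatCORE computation above: each task's accuracy is re-centered against its random baseline so that 0 means chance level and 1 means perfect, and the metric is the mean of those centered scores. The accuracies below are invented, and only a subset of tasks is shown for brevity (the script itself requires all tasks to have been evaluated):

```python
# Worked example of the ChatCORE metric (mean of baseline-centered accuracies).
# Accuracies are invented; baselines match the table defined above.
results = {"ARC-Easy": 0.55, "MMLU": 0.40, "GSM8K": 0.10}
baselines = {"ARC-Easy": 0.25, "MMLU": 0.25, "GSM8K": 0.0}

centered = {t: (results[t] - baselines[t]) / (1.0 - baselines[t]) for t in results}
# ARC-Easy: 0.30/0.75 = 0.40, MMLU: 0.15/0.75 = 0.20, GSM8K: 0.10/1.0 = 0.10
chatcore = sum(centered.values()) / len(centered)
print(f"{chatcore:.4f}")  # 0.2333
```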
scripts/chat_rl.py ADDED
@@ -0,0 +1,332 @@
1
+ """
2
+ Reinforcement learning on GSM8K via "GRPO".
3
+
4
+ I put GRPO in quotes because we actually end up with something a lot
5
+ simpler and more similar to just REINFORCE:
6
+
7
+ 1) Delete trust region, so there is no KL regularization to a reference model
8
+ 2) We are on policy, so there's no need for PPO ratio+clip.
9
+ 3) We use DAPO style normalization that is token-level, not sequence-level.
10
+ 4) Instead of z-score normalization (r - mu)/sigma, only use (r - mu) as the advantage.
11
+
12
+ 1 GPU:
13
+ python -m scripts.chat_rl
14
+
15
+ 8 GPUs:
16
+ torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=default
17
+ """
18
+
19
+ import argparse
20
+ import os
21
+ import itertools
22
+ import wandb
23
+ import torch
24
+ import torch.distributed as dist
25
+ from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb, autodetect_device_type
26
+ from nanochat.checkpoint_manager import save_checkpoint, load_model
27
+ from nanochat.engine import Engine
28
+ from tasks.gsm8k import GSM8K
29
+
30
+ # -----------------------------------------------------------------------------
31
+ # CLI arguments
32
+ parser = argparse.ArgumentParser(description="Reinforcement learning on GSM8K")
33
+ # Logging
34
+ parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
35
+ # Runtime
36
+ parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
37
+ # Model loading
38
+ parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
39
+ parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
40
+ # Training horizon
41
+ parser.add_argument("--num-epochs", type=int, default=1, help="number of epochs over GSM8K")
42
+ # Batch sizes / sampling
43
+ parser.add_argument("--device-batch-size", type=int, default=8, help="max batch size per forward pass")
44
+ parser.add_argument("--examples-per-step", type=int, default=16, help="total examples per optimization step across all ranks")
45
+ parser.add_argument("--num-samples", type=int, default=16, help="number of samples per example/question")
46
+ # Generation
47
+ parser.add_argument("--max-new-tokens", type=int, default=256, help="max tokens to generate per sample")
48
+ parser.add_argument("--temperature", type=float, default=1.0, help="sampling temperature")
49
+ parser.add_argument("--top-k", type=int, default=50, help="top-k sampling (0 = disabled)")
50
+ # Optimization
51
+ parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
52
+ parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
53
+ parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
54
+ parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
55
+ parser.add_argument("--init-lr-frac", type=float, default=0.05, help="initial LR as fraction of base LR")
56
+ # Evaluation / checkpointing
57
+ parser.add_argument("--eval-every", type=int, default=60, help="evaluate pass@k every N steps")
58
+ parser.add_argument("--eval-examples", type=int, default=400, help="number of examples for pass@k evaluation")
59
+ parser.add_argument("--save-every", type=int, default=60, help="save checkpoint every N steps")
60
+ args = parser.parse_args()
61
+ user_config = vars(args).copy()
62
+ # -----------------------------------------------------------------------------
63
+
64
+ # Init compute/precision
65
+ device_type = autodetect_device_type() if args.device_type == "" else args.device_type
66
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
67
+ master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
68
+
69
+ # wandb logging init
70
+ use_dummy_wandb = args.run == "dummy" or not master_process
71
+ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=args.run, config=user_config)
72
+
73
+ # Init model and tokenizer
74
+ model, tokenizer, meta = load_model("sft", device, phase="eval", model_tag=args.model_tag, step=args.model_step)
75
+ engine = Engine(model, tokenizer) # for sampling rollouts
76
+
77
+ # -----------------------------------------------------------------------------
78
+ # Rollout / sampling generator loop that yields batches of examples for training
79
+
80
+ train_task = GSM8K(subset="main", split="train")
81
+ val_task = GSM8K(subset="main", split="test")
82
+ num_steps = (len(train_task) // args.examples_per_step) * args.num_epochs
83
+ print0(f"Calculated number of steps: {num_steps}")
84
+
85
+ @torch.no_grad()
86
+ def get_batch():
87
+ assistant_end = tokenizer.encode_special("<|assistant_end|>") # ok to use this token, it's only for padding and isn't used in the loss.
88
+ rank_indices = range(ddp_rank, len(train_task), ddp_world_size) # each rank is responsible for different examples in the training data
89
+ for example_idx in itertools.cycle(rank_indices):
90
+
91
+ # First get the full conversation of both user and assistant messages
92
+ conversation = train_task[example_idx]
93
+
94
+ # Tokenize the conversation, deleting the last Assistant message and priming the Assistant for a completion instead
95
+ # (i.e. keep the <|assistant_start|>, but delete everything after it)
96
+ tokens = tokenizer.render_for_completion(conversation)
97
+ prefix_length = len(tokens)
98
+
99
+ # Generate num_samples samples using batched generation, use loop to avoid OOMs
100
+ model.eval() # ensure the model is in eval mode
101
+ generated_token_sequences = []
102
+ masks = []
103
+ num_sampling_steps = args.num_samples // args.device_batch_size # go sequentially to prevent OOMs
104
+ for sampling_step in range(num_sampling_steps):
105
+ seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF # positive half of int32
106
+ generated_token_sequences_batch, masks_batch = engine.generate_batch(
107
+ tokens,
108
+ num_samples=args.device_batch_size,
109
+ max_tokens=args.max_new_tokens,
110
+ temperature=args.temperature,
111
+ top_k=args.top_k,
112
+ seed=seed, # must make sure to change the seed for each sampling step
113
+ )
114
+ generated_token_sequences.extend(generated_token_sequences_batch)
115
+ masks.extend(masks_batch)
116
+
117
+ # Calculate the rewards for each sample
118
+ rewards = []
119
+ for sample_tokens in generated_token_sequences:
120
+ # Get just the generated tokens (after the prompt)
121
+ generated_tokens = sample_tokens[prefix_length:]
122
+ # Decode the generated response
123
+ generated_text = tokenizer.decode(generated_tokens)
124
+ # Calculate the reward
125
+ reward = train_task.reward(conversation, generated_text)
126
+ rewards.append(reward)
127
+
128
+ # Pad the sequences so that their lengths (in time) match
129
+ max_length = max(len(seq) for seq in generated_token_sequences)
130
+ padded_generated_token_sequences = [seq + [assistant_end] * (max_length - len(seq)) for seq in generated_token_sequences]
131
+ padded_masks = [mask + [0] * (max_length - len(mask)) for mask in masks]
132
+ # Stack up the sequences and masks into PyTorch tensors
133
+ ids = torch.tensor(padded_generated_token_sequences, dtype=torch.long, device=device)
134
+ mask_ids = torch.tensor(padded_masks, dtype=torch.long, device=device)
135
+ # Generate autoregressive inputs and targets to the Transformer
136
+ inputs = ids[:, :-1]
137
+ targets = ids[:, 1:].clone() # clone to avoid in-place modification:
138
+ targets[mask_ids[:, 1:] == 0] = -1 # <-- inplace modification right here. -1 is the ignore index
139
+ # NOTE also that the Engine returns mask=0 for BOTH the prompt tokens AND the tool use tokens.
140
+ # So we will (correctly) end up not training on the prompt tokens, or the tool use forced tokens.
141
+ rewards = torch.tensor(rewards, dtype=torch.float, device=device)
142
+ # Calculate the advantages by simply subtracting the mean (instead of z-score (x-mu)/sigma)
143
+ mu = rewards.mean()
144
+ advantages = rewards - mu
145
+ # yield inputs/targets as (B, T) of ids and rewards as (B,) of floats
146
+ yield generated_token_sequences, inputs, targets, rewards, advantages
147
+
148
+ # -----------------------------------------------------------------------------
149
+ # Simple evaluation loop for GSM8K pass@k
150
+ def run_gsm8k_eval(task, tokenizer, engine,
151
+ max_examples=None,
152
+ num_samples=1,
153
+ max_completion_tokens=256,
154
+ temperature=0.0,
155
+ top_k=50
156
+ ):
157
+ """
158
+ Evaluates GSM8K task and returns a list of records of evaluation outcomes.
159
+ In a distributed setting, all ranks cooperate but this function will NOT
160
+ do the reduction across ranks. This is the responsibility of the caller.
161
+ Because the evaluation can take a while, this function will yield records one by one.
162
+ """
163
+ max_examples = min(max_examples, len(task)) if max_examples is not None else len(task)
164
+ for idx in range(ddp_rank, max_examples, ddp_world_size):
165
+ conversation = task[idx]
166
+ tokens = tokenizer.render_for_completion(conversation)
167
+ prefix_length = len(tokens)
168
+ # Generate k samples using batched generation inside the Engine
169
+ assert num_samples <= args.device_batch_size # usually this is true. we can add a loop if not...
170
+ generated_token_sequences, masks = engine.generate_batch(
171
+ tokens,
172
+ num_samples=num_samples,
173
+ max_tokens=max_completion_tokens,
174
+ temperature=temperature,
175
+ top_k=top_k
176
+ )
177
+ # Check each sample for correctness
178
+ outcomes = []
179
+ for sample_tokens in generated_token_sequences:
180
+ generated_tokens = sample_tokens[prefix_length:]
181
+ generated_text = tokenizer.decode(generated_tokens)
182
+ is_correct = task.evaluate(conversation, generated_text)
183
+ outcomes.append({
184
+ "is_correct": is_correct
185
+ })
186
+ # A bit bloated because I wanted to do more complex logging at one point.
187
+ record = {
188
+ "idx": idx,
189
+ "outcomes": outcomes,
190
+ }
191
+ yield record
192
+
193
+ # -----------------------------------------------------------------------------
194
+ # Training loop
195
+
196
+ # Init the optimizer
197
+ optimizer = model.setup_optimizer(
198
+ unembedding_lr=args.unembedding_lr,
199
+ embedding_lr=args.embedding_lr,
200
+ matrix_lr=args.matrix_lr,
201
+ weight_decay=args.weight_decay,
202
+ )
203
+
204
+ # Set the initial learning rate as a fraction of the base learning rate
205
+ for group in optimizer.param_groups:
206
+ group["lr"] = group["lr"] * args.init_lr_frac
207
+ group["initial_lr"] = group["lr"]
208
+
209
+ # Learning rate scheduler: simple rampdown to zero over num_steps
210
+ def get_lr_multiplier(it):
211
+ lrm = 1.0 - it / num_steps
212
+ return lrm
213
+
214
+ # Calculate the number of examples each rank handles to achieve the desired examples_per_step
215
+ print0(f"Total sequences per step: {args.examples_per_step * args.num_samples}") # total batch size in sequences/step
216
+ assert args.examples_per_step % ddp_world_size == 0, "Desired examples per step must be divisible by the number of ranks"
217
+ examples_per_rank = args.examples_per_step // ddp_world_size # per GPU
218
+ print0(f"Calculated examples per rank: {examples_per_rank}")
219
+
220
+ # Kick off the training loop
221
+ batch_iterator = get_batch()
222
+ for step in range(num_steps):
223
+
224
+ # Evaluate the model once in a while and log to wandb
225
+ if step % args.eval_every == 0:
226
+ model.eval()
227
+ passk = torch.zeros(args.device_batch_size, device=device) # pass@k for k=1..device_batch_size
228
+ records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=args.device_batch_size, max_examples=args.eval_examples, temperature=1.0)
229
+ records = list(records_iter) # collect all records
230
+ for k in range(1, args.device_batch_size + 1):
231
+ passk[k - 1] = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records)
232
+ num_records = torch.tensor(len(records), dtype=torch.long, device=device)
233
+ if ddp:
234
+ dist.all_reduce(num_records, op=dist.ReduceOp.SUM)
235
+ dist.all_reduce(passk, op=dist.ReduceOp.SUM)
236
+ passk = passk / num_records.item() # normalize by the total number of records
237
+ print_passk = [f"Pass@{k}: {passk[k - 1].item():.4f}" for k in range(1, args.device_batch_size + 1)]
238
+ print0(f"Step {step} | {', '.join(print_passk)}")
239
+ log_passk = {f"pass@{k}": passk[k - 1].item() for k in range(1, args.device_batch_size + 1)}
240
+ wandb_run.log({
241
+ "step": step,
242
+ **log_passk,
243
+ })
244
+
245
+ # Forward/Backward on rollouts over multiple examples in the dataset
246
+ rewards_list = []
247
+ sequence_lengths = []
248
+ for example_step in range(examples_per_rank):
249
+ # Get one batch corresponding to one example in the training dataset
250
+ sequences_all, inputs_all, targets_all, rewards_all, advantages_all = next(batch_iterator)
251
+ # Evaluate the loss and gradients
252
+ model.train() # ensure the model is in train mode
253
+ # We need one more loop because we can never exceed the device_batch_size
254
+ assert inputs_all.size(0) % args.device_batch_size == 0
255
+ num_passes = inputs_all.size(0) // args.device_batch_size
256
+ for pass_idx in range(num_passes):
257
+ # Pluck out the batch for this pass
258
+ b0, b1 = pass_idx * args.device_batch_size, (pass_idx + 1) * args.device_batch_size
259
+ inputs = inputs_all[b0:b1]
260
+ targets = targets_all[b0:b1]
261
+ rewards = rewards_all[b0:b1]
262
+ advantages = advantages_all[b0:b1]
263
+ # Calculate log probabilities. Note that the loss calculates NLL = -logp, so we negate
264
+ logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T)
265
+ # Calculate the PG objective. Note that ignore_index=-1 ensures that invalid tokens have loss 0.
266
+ pg_obj = (logp * advantages.unsqueeze(-1)).sum()
267
+ # normalize by the number of valid tokens, number of passes, and examples_per_rank
268
+ num_valid = (targets >= 0).sum().clamp(min=1)
269
+ pg_obj = pg_obj / (num_valid * num_passes * examples_per_rank)
270
+ # Note, there is no need to add PPO ratio+clip because we are on policy
271
+ # Finally, formulate the loss that we want to minimize (instead of objective we wish to maximize)
272
+ loss = -pg_obj
273
+ loss.backward()
274
+ print0(f"Step {step}/{num_steps} | Example step {example_step} | Pass {pass_idx} | loss: {loss.item():.6f} | Average reward: {rewards.mean().item()}")
275
+ # For logging
276
+ rewards_list.append(rewards_all.mean().item())
277
+ sequence_lengths.extend(len(seq) for seq in sequences_all)
278
+
279
+ # A bunch of logging for how the rollouts went this step
280
+ mean_reward = sum(rewards_list) / len(rewards_list)
281
+ mean_sequence_length = sum(sequence_lengths) / len(sequence_lengths)
282
+ if ddp: # aggregate across ranks
283
+ mean_reward_tensor = torch.tensor(mean_reward, dtype=torch.float, device=device)
284
+ mean_sequence_length_tensor = torch.tensor(mean_sequence_length, dtype=torch.float, device=device)
285
+ dist.all_reduce(mean_reward_tensor, op=dist.ReduceOp.AVG)
286
+ dist.all_reduce(mean_sequence_length_tensor, op=dist.ReduceOp.AVG)
287
+ mean_reward = mean_reward_tensor.item()
288
+ mean_sequence_length = mean_sequence_length_tensor.item()
289
+ print0(f"Step {step}/{num_steps} | Average reward: {mean_reward} | Average sequence length: {mean_sequence_length:.2f}")
290
+ wandb_run.log({
291
+ "step": step,
292
+ "reward": mean_reward,
293
+ "sequence_length": mean_sequence_length,
294
+ })
295
+
296
+ # Update the model parameters
297
+ lrm = get_lr_multiplier(step)
298
+ for group in optimizer.param_groups:
299
+ group["lr"] = group["initial_lr"] * lrm
300
+ optimizer.step()
301
+ model.zero_grad(set_to_none=True)
302
+ wandb_run.log({
303
+ "step": step,
304
+ "lrm": lrm,
305
+ })
306
+
307
+ # Master process saves the model once in a while. Skip first step. Save last step.
308
+ if master_process and ((step > 0 and step % args.save_every == 0) or step == num_steps - 1):
309
+ base_dir = get_base_dir()
310
+ depth = model.config.n_layer
311
+ output_dirname = args.model_tag if args.model_tag else f"d{depth}" # base the model tag on the depth of the base model
312
+ checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", output_dirname)
313
+ model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
314
+ save_checkpoint(
315
+ checkpoint_dir,
316
+ step,
317
+ model.state_dict(),
318
+ None, # note: we don't bother to save the optimizer state
319
+ {
320
+ "model_config": model_config_kwargs,
321
+ }
322
+ )
323
+ print(f"✅ Saved model checkpoint to {checkpoint_dir}")
324
+
325
+ # Log to report
326
+ from nanochat.report import get_report
327
+ get_report().log(section="Chat RL", data=[
328
+ user_config, # CLI args
329
+ ])
330
+
331
+ wandb_run.finish() # wandb run finish
332
+ compute_cleanup()
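A compact sketch of the objective that the training loop above optimizes: rewards are mean-centered into advantages (no division by sigma), each generated token's log-probability is weighted by its sample's advantage, ignored positions (prompt and padding, targets == -1) contribute zero, and the sum is normalized by the number of valid tokens. All tensor values below are invented stand-ins for what the model and reward function produce:

```python
# Sketch of the simplified-GRPO / REINFORCE objective used above (values invented).
import torch

rewards = torch.tensor([1.0, 0.0])                 # one correct, one incorrect sample
advantages = rewards - rewards.mean()              # tensor([ 0.5, -0.5])

logp = torch.full((2, 4), -1.0)                    # per-token log-probs, shape (B, T)
targets = torch.tensor([[5, 7, -1, -1],            # -1 marks prompt / padding tokens
                        [3, 3, 3, -1]])
logp = logp.masked_fill(targets < 0, 0.0)          # ignore_index positions contribute 0

pg_obj = (logp * advantages.unsqueeze(-1)).sum()   # sum of advantage-weighted log-probs
num_valid = (targets >= 0).sum().clamp(min=1)
loss = -(pg_obj / num_valid)                       # minimize the negative objective
print(loss.item())
```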
scripts/chat_sft.py ADDED
@@ -0,0 +1,519 @@
1
+ """
2
+ Supervised fine-tuning (SFT) the model.
3
+ Run as:
4
+
5
+ python -m scripts.chat_sft
6
+
7
+ Or torchrun for training:
8
+
9
+ torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --device-batch-size=16
10
+ """
11
+
12
+ import gc
13
+ import argparse
14
+ import os
15
+ os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
16
+ import time
17
+ import wandb
18
+ import torch
19
+ from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized
20
+ from nanochat.tokenizer import get_token_bytes
21
+ from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state
22
+ from nanochat.loss_eval import evaluate_bpb
23
+ import torch.distributed as dist
24
+ from nanochat.flash_attention import HAS_FA3
25
+ from nanochat.engine import Engine
26
+ from scripts.chat_eval import run_chat_eval
27
+
28
+ from tasks.common import TaskMixture
29
+ from tasks.gsm8k import GSM8K
30
+ from tasks.mmlu import MMLU
31
+ from tasks.smoltalk import SmolTalk
32
+ from tasks.customjson import CustomJSON
33
+ from tasks.spellingbee import SimpleSpelling, SpellingBee
34
+
35
+ # -----------------------------------------------------------------------------
36
+ # CLI arguments
37
+ parser = argparse.ArgumentParser(description="Supervised fine-tuning (SFT) the model")
38
+ # Logging
39
+ parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
40
+ # Runtime
41
+ parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
42
+ # Model loading
43
+ parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
44
+ parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
45
+ parser.add_argument("--load-optimizer", type=int, default=1, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)")
46
+ # Training horizon
47
+ parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
48
+ # Batch sizes (default: inherit from pretrained checkpoint)
49
+ parser.add_argument("--max-seq-len", type=int, default=None, help="max context length (default: inherit from pretrain)")
50
+ parser.add_argument("--device-batch-size", type=int, default=None, help="per-device batch size (default: inherit from pretrain)")
51
+ parser.add_argument("--total-batch-size", type=int, default=None, help="total batch size in tokens (default: inherit from pretrain)")
52
+ # Optimization (default: inherit from pretrained checkpoint)
53
+ parser.add_argument("--embedding-lr", type=float, default=None, help="learning rate for embedding parameters (Adam) (default: inherit from pretrain)")
54
+ parser.add_argument("--unembedding-lr", type=float, default=None, help="learning rate for unembedding parameters (Adam) (default: inherit from pretrain)")
55
+ parser.add_argument("--matrix-lr", type=float, default=None, help="learning rate for matrix parameters (Muon) (default: inherit from pretrain)")
56
+ parser.add_argument("--init-lr-frac", type=float, default=0.8, help="initial LR as fraction of base LR")
57
+ parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
58
+ parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for LR warmdown")
59
+ parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR")
60
+ # Evaluation
61
+ parser.add_argument("--eval-every", type=int, default=200, help="evaluate val bpb every N steps (-1 = disable)")
62
+ parser.add_argument("--eval-tokens", type=int, default=40*524288, help="number of tokens to evaluate val loss on")
63
+ parser.add_argument("--chatcore-every", type=int, default=200, help="evaluate ChatCORE metric every N steps (-1 = disable)")
64
+ parser.add_argument("--chatcore-max-cat", type=int, default=-1, help="max problems per categorical task for ChatCORE")
65
+ parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max problems per generative task for ChatCORE")
66
+ # Data mixture
67
+ parser.add_argument("--mmlu-epochs", type=int, default=3, help="number of epochs of MMLU in training mixture (teaches Multiple Choice)")
68
+ parser.add_argument("--gsm8k-epochs", type=int, default=4, help="number of epochs of GSM8K in training mixture (teaches Math and Tool Use)")
69
+ args = parser.parse_args()
70
+ user_config = vars(args).copy()
71
+ # -----------------------------------------------------------------------------
72
+
73
+ # Compute init
74
+ device_type = autodetect_device_type() if args.device_type == "" else args.device_type
75
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
76
+ master_process = ddp_rank == 0
77
+ print0(f"COMPUTE_DTYPE: {COMPUTE_DTYPE} ({COMPUTE_DTYPE_REASON})")
78
+ synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
79
+ get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
80
+ if device_type == "cuda":
81
+ gpu_device_name = torch.cuda.get_device_name(0)
82
+ gpu_peak_flops = get_peak_flops(gpu_device_name)
83
+ print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}")
84
+ else:
85
+ gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS
86
+
87
+ # wandb logging init
88
+ use_dummy_wandb = args.run == "dummy" or not master_process
89
+ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config)
90
+
91
+ # Flash Attention status
92
+ if not HAS_FA3:
93
+ print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback. Training will be less efficient.")
94
+
95
+ # Load the model and tokenizer
96
+ model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)
97
+
98
+ # Inherit training hyperparameters from pretrained checkpoint (None = inherit, explicit value = override)
99
+ pretrain_user_config = meta.get("user_config", {})
100
+ for name, fallback, source in [
101
+ ("max_seq_len", 2048, meta),
102
+ ("device_batch_size", 32, meta),
103
+ ("total_batch_size", 524288, meta),
104
+ ("embedding_lr", 0.3, pretrain_user_config),
105
+ ("unembedding_lr", 0.004, pretrain_user_config),
106
+ ("matrix_lr", 0.02, pretrain_user_config),
107
+ ]:
108
+ arg_val = getattr(args, name)
109
+ pretrain_val = source.get(name)
110
+ if arg_val is None:
111
+ resolved = pretrain_val if pretrain_val is not None else fallback
112
+ setattr(args, name, resolved)
113
+ print0(f"Inherited {name}={resolved} from pretrained checkpoint")
114
+ elif pretrain_val is not None and arg_val != pretrain_val:
115
+ print0(f"NOTE: --{name.replace('_', '-')}={arg_val} overrides pretrained value of {pretrain_val}")
116
+ else:
117
+ print0(f"Using {name}={arg_val}")
118
+
119
+ orig_model = model
120
+ model = torch.compile(model, dynamic=False)
121
+ depth = model.config.n_layer
122
+ num_flops_per_token = model.estimate_flops()
123
+ tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
124
+ world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
125
+ assert args.total_batch_size % world_tokens_per_fwdbwd == 0
126
+ grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
127
+ print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
128
+ print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
129
+ print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
130
+ token_bytes = get_token_bytes(device=device)
131
+
132
+ # Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
133
+ # Note that pretraining ramps weight_decay to zero by end of pretraining, so SFT continues with zero
134
+ optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=0.0)
135
+
136
+ # Optionally warm-start optimizer from pretrained checkpoint (momentum buffers etc.)
137
+ # Note: load_state_dict overwrites param_group metadata (LRs, betas, etc.) with the
138
+ # pretrained values. Since pretraining warmdown brings LRs to ~0, we must save and
139
+ # restore our fresh SFT LRs after loading.
140
+ base_dir = get_base_dir()
141
+ if args.load_optimizer:
142
+ optimizer_data = load_optimizer_state("base", device, rank=ddp_rank, model_tag=args.model_tag, step=args.model_step)
143
+ if optimizer_data is not None:
144
+ base_lrs = [group["lr"] for group in optimizer.param_groups]
145
+ optimizer.load_state_dict(optimizer_data)
146
+ del optimizer_data
147
+ for group, base_lr in zip(optimizer.param_groups, base_lrs):
148
+ group["lr"] = base_lr
149
+ print0("Loaded optimizer state from pretrained checkpoint (momentum buffers only, LRs reset)")
150
+ else:
151
+ print0("WARNING: optimizer checkpoint not found, starting with fresh optimizer (slightly worse)")
152
+
153
+ # GradScaler for fp16 training (bf16/fp32 don't need it)
154
+ scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None
155
+ if scaler is not None:
156
+ print0("GradScaler enabled for fp16 training")
157
+
158
+ # Override the initial learning rate as a fraction of the base learning rate
159
+ for group in optimizer.param_groups:
160
+ group["lr"] = group["lr"] * args.init_lr_frac
161
+ group["initial_lr"] = group["lr"]
162
+
163
+ # SFT data mixture and DataLoader
164
+ identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
165
+ train_tasks = [
166
+ SmolTalk(split="train"), # 460K rows of general conversations
167
+ CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
168
+ CustomJSON(filepath=identity_conversations_filepath), # 2 epochs of these
169
+ *[MMLU(subset="auxiliary_train", split="train") for _ in range(args.mmlu_epochs)], # 100K rows per epoch
170
+ *[GSM8K(subset="main", split="train") for _ in range(args.gsm8k_epochs)], # 8K rows per epoch
171
+ SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
172
+ SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
173
+ ]
174
+ train_dataset = TaskMixture(train_tasks)
175
+ print0(f"Training mixture: {len(train_dataset):,} rows (MMLU x{args.mmlu_epochs}, GSM8K x{args.gsm8k_epochs})")
176
+ val_dataset = TaskMixture([
177
+ SmolTalk(split="test"), # 24K rows in test set
178
+ MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
179
+ GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios
180
+ ]) # total: 24K + 14K + 1.32K ~= 39K rows
181
+ # DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len)
182
+ # A big problem is that we don't know the final num_iterations in advance. So we create
183
+ # these two global variables and update them from within the data generator.
184
+ last_step = False # we will toggle this to True when we reach the end of the training dataset
185
+ approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
186
+ current_epoch = 1 # track epoch for logging
187
+ def sft_data_generator_bos_bestfit(split, buffer_size=100):
188
+ """
189
+ BOS-aligned dataloader for SFT with bestfit-pad packing.
190
+
191
+ Each row in the batch starts with BOS (beginning of a conversation).
192
+ Conversations are packed using best-fit algorithm. When no conversation fits,
193
+ the row is padded (instead of cropping) to ensure no tokens are ever discarded.
194
+ Padding positions have targets masked with -1 (ignore_index for cross-entropy).
195
+ """
196
+ global last_step, approx_progress, current_epoch
197
+ assert split in {"train", "val"}, "split must be 'train' or 'val'"
198
+ dataset = train_dataset if split == "train" else val_dataset
199
+ dataset_size = len(dataset)
200
+ assert dataset_size > 0
201
+ row_capacity = args.max_seq_len + 1 # +1 for target at last position
202
+ bos_token = tokenizer.get_bos_token_id()
203
+
204
+ # Conversation buffer: list of (token_ids, loss_mask) tuples
205
+ conv_buffer = []
206
+ cursor = ddp_rank # Each rank processes different conversations (for fetching)
207
+ consumed = ddp_rank # Track actual consumption separately from buffering
208
+ epoch = 1
209
+ it = 0 # iteration counter
210
+
211
+ def refill_buffer():
212
+ nonlocal cursor, epoch
213
+ while len(conv_buffer) < buffer_size:
214
+ conversation = dataset[cursor]
215
+ ids, mask = tokenizer.render_conversation(conversation)
216
+ conv_buffer.append((ids, mask))
217
+ cursor += ddp_world_size
218
+ if cursor >= dataset_size:
219
+ cursor = cursor % dataset_size
220
+ epoch += 1
221
+ # Note: last_step is now triggered based on consumption, not fetching
222
+
223
+ while True:
224
+ rows = []
225
+ mask_rows = []
226
+ row_lengths = [] # Track actual content length (excluding padding) for each row
227
+ for _ in range(args.device_batch_size):
228
+ row = []
229
+ mask_row = []
230
+ padded = False
231
+ while len(row) < row_capacity:
232
+ # Ensure buffer has conversations
233
+ while len(conv_buffer) < buffer_size:
234
+ refill_buffer()
235
+
236
+ remaining = row_capacity - len(row)
237
+
238
+ # Find largest conversation that fits entirely
239
+ best_idx = -1
240
+ best_len = 0
241
+ for i, (conv, _) in enumerate(conv_buffer):
242
+ conv_len = len(conv)
243
+ if conv_len <= remaining and conv_len > best_len:
244
+ best_idx = i
245
+ best_len = conv_len
246
+
247
+ if best_idx >= 0:
248
+ # Found a conversation that fits - use it entirely
249
+ conv, conv_mask = conv_buffer.pop(best_idx)
250
+ row.extend(conv)
251
+ mask_row.extend(conv_mask)
252
+ consumed += ddp_world_size # Track actual consumption
253
+ else:
254
+ # No conversation fits - pad the remainder instead of cropping
255
+ # This ensures we never discard any tokens
256
+ content_len = len(row)
257
+ row.extend([bos_token] * remaining) # Pad with BOS tokens
258
+ mask_row.extend([0] * remaining)
259
+ padded = True
260
+ break # Row is now full (with padding)
261
+
262
+ # Track content length: full row if no padding, otherwise the length before padding
263
+ if padded:
264
+ row_lengths.append(content_len)
265
+ else:
266
+ row_lengths.append(row_capacity)
267
+ rows.append(row[:row_capacity])
268
+ mask_rows.append(mask_row[:row_capacity])
269
+
270
+ # Stopping condition to respect num_iterations, if given
271
+ it += 1
272
+ if 0 < args.num_iterations <= it and split == "train":
273
+ last_step = True
274
+
275
+ # Update progress tracking (based on consumed, not cursor, to account for buffering)
276
+ if split == "train":
277
+ current_epoch = epoch
278
+ if args.num_iterations > 0:
279
+ approx_progress = it / args.num_iterations
280
+ else:
281
+ approx_progress = consumed / dataset_size
282
+ # Trigger last_step when we've consumed enough (instead of when cursor wraps)
283
+ if consumed >= dataset_size:
284
+ last_step = True
285
+
286
+ # Build tensors
287
+ use_cuda = device_type == "cuda"
288
+ batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
289
+ inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda).contiguous()
290
+ targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda).contiguous()
291
+
292
+ # Apply the loss mask from render_conversation (mask=1 for assistant completions,
293
+ # mask=0 for user prompts, BOS, special tokens, tool outputs). mask[1:] aligns
294
+ # with targets (shifted by 1). Unmasked positions get -1 (ignore_index).
295
+ mask_tensor = torch.tensor(mask_rows, dtype=torch.int8)
296
+ mask_targets = mask_tensor[:, 1:].to(device=device)
297
+ targets[mask_targets == 0] = -1
298
+
299
+ # Mask out padding positions in targets (set to -1 = ignore_index)
300
+ # For each row, positions >= (content_length - 1) in targets should be masked
301
+ for i, content_len in enumerate(row_lengths):
302
+ if content_len < row_capacity:
303
+ targets[i, content_len-1:] = -1
304
+
305
+ yield inputs, targets
306
+
307
+ train_loader = sft_data_generator_bos_bestfit("train")
308
+ build_val_loader = lambda: sft_data_generator_bos_bestfit("val")
309
+ progress = 0 # will go from 0 to 1 over the course of the epoch
310
+
311
+ # Learning rate schedule (linear warmup, constant, linear warmdown)
312
+ # Same shape as base_train but uses progress (0→1) instead of absolute step counts,
313
+ # because SFT doesn't always know num_iterations in advance (dataset-driven stopping).
314
+ def get_lr_multiplier(progress):
315
+ if progress < args.warmup_ratio:
316
+ return (progress + 1e-8) / args.warmup_ratio
317
+ elif progress <= 1.0 - args.warmdown_ratio:
318
+ return 1.0
319
+ else:
320
+ decay = (progress - (1.0 - args.warmdown_ratio)) / args.warmdown_ratio
321
+ return (1 - decay) * 1.0 + decay * args.final_lr_frac
322
+
323
+ # Momentum scheduler for Muon optimizer
324
+ def get_muon_momentum(it):
325
+ frac = min(it / 300, 1)
326
+ momentum = (1 - frac) * 0.85 + frac * 0.95
327
+ return momentum
328
+
329
+ # -----------------------------------------------------------------------------
330
+ # Training loop
331
+ x, y = next(train_loader) # prefetch the very first batch of data
332
+ min_val_bpb = float("inf")
333
+ smooth_train_loss = 0 # EMA of training loss
334
+ ema_beta = 0.9 # EMA decay factor
335
+ total_training_time = 0 # total wall-clock time of training
336
+ step = 0
337
+ while True:
338
+ flops_so_far = num_flops_per_token * args.total_batch_size * step
339
+
340
+ # Synchronize last_step across all ranks to avoid hangs in the distributed setting
341
+ if ddp:
342
+ last_step_tensor = torch.tensor(last_step, dtype=torch.int32, device=device)
343
+ dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX)
344
+ last_step = bool(last_step_tensor.item())
345
+
346
+ # once in a while: evaluate the val bpb (all ranks participate)
347
+ if last_step or (args.eval_every > 0 and step % args.eval_every == 0):
348
+ model.eval()
349
+ val_loader = build_val_loader()
350
+ eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
351
+ val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
352
+ print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
353
+ if val_bpb < min_val_bpb:
354
+ min_val_bpb = val_bpb
355
+ wandb_run.log({
356
+ "step": step,
357
+ "total_training_flops": flops_so_far,
358
+ "total_training_time": total_training_time,
359
+ "val/bpb": val_bpb,
360
+ })
361
+ model.train()
362
+
363
+ # once in a while: estimate the ChatCORE metric (all ranks participate)
364
+ # use the original uncompiled model because the inputs keep changing shape
365
+ chatcore_results = {}
366
+ if args.chatcore_every > 0 and (last_step or (step > 0 and step % args.chatcore_every == 0)):
367
+ model.eval()
368
+ engine = Engine(orig_model, tokenizer)
369
+ all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee']
370
+ categorical_tasks = {'ARC-Easy', 'ARC-Challenge', 'MMLU'}
371
+ baseline_accuracies = {
372
+ 'ARC-Easy': 0.25, 'ARC-Challenge': 0.25, 'MMLU': 0.25,
373
+ 'GSM8K': 0.0, 'HumanEval': 0.0, 'SpellingBee': 0.0,
374
+ }
375
+ task_results = {}
376
+ for task_name in all_tasks:
377
+ limit = args.chatcore_max_cat if task_name in categorical_tasks else args.chatcore_max_sample
378
+ max_problems = None if limit < 0 else limit # -1 means no limit
379
+ acc = run_chat_eval(task_name, orig_model, tokenizer, engine,
380
+ batch_size=args.device_batch_size, max_problems=max_problems)
381
+ task_results[task_name] = acc
382
+ print0(f" {task_name}: {100*acc:.2f}%")
383
+ # Compute ChatCORE metrics (mean centered accuracy, ranges from 0=random to 1=perfect)
384
+ def centered_mean(tasks):
385
+ return sum((task_results[t] - baseline_accuracies[t]) / (1.0 - baseline_accuracies[t]) for t in tasks) / len(tasks)
386
+ chatcore = centered_mean(all_tasks)
387
+ chatcore_cat = centered_mean(categorical_tasks)
388
+ print0(f"Step {step:05d} | ChatCORE: {chatcore:.4f} | ChatCORE_cat: {chatcore_cat:.4f}")
389
+ wandb_run.log({
390
+ "step": step,
391
+ "total_training_flops": flops_so_far,
392
+ "chatcore_metric": chatcore,
393
+ "chatcore_cat": chatcore_cat,
394
+ **{f"chatcore/{task_name}": acc for task_name, acc in task_results.items()},
395
+ })
396
+ model.train()
397
+
398
+ # save checkpoint at the end of the run (all ranks participate so each saves its optimizer shard)
399
+ if last_step:
400
+ output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12
401
+ checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname)
402
+ save_checkpoint(
403
+ checkpoint_dir,
404
+ step,
405
+ orig_model.state_dict(),
406
+ optimizer.state_dict(),
407
+ {
408
+ "step": step,
409
+ "val_bpb": val_bpb, # loss at last step
410
+ "model_config": {
411
+ "sequence_len": args.max_seq_len,
412
+ "vocab_size": tokenizer.get_vocab_size(),
413
+ "n_layer": depth,
414
+ "n_head": model.config.n_head,
415
+ "n_kv_head": model.config.n_kv_head,
416
+ "n_embd": model.config.n_embd,
417
+ "window_pattern": model.config.window_pattern,
418
+ },
419
+ "user_config": user_config, # inputs to the training script
420
+ },
421
+ rank=ddp_rank,
422
+ )
423
+
424
+ if last_step:
425
+ break
426
+
427
+ # -------------------------------------------------------------------------
428
+ # single training step
429
+ # evaluate the gradient
430
+ synchronize()
431
+ t0 = time.time()
432
+ for micro_step in range(grad_accum_steps):
433
+ loss = model(x, y)
434
+ train_loss = loss.detach() # for logging
435
+ loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
436
+ if scaler is not None:
437
+ scaler.scale(loss).backward()
438
+ else:
439
+ loss.backward()
440
+ x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
441
+ progress = max(progress, approx_progress) # only increase progress monotonically
442
+ # step the optimizer
443
+ lrm = get_lr_multiplier(progress)
444
+ muon_momentum = get_muon_momentum(step)
445
+ for group in optimizer.param_groups:
446
+ group["lr"] = group["initial_lr"] * lrm
447
+ if group['kind'] == 'muon':
448
+ group["momentum"] = muon_momentum
449
+ if scaler is not None:
450
+ scaler.unscale_(optimizer)
451
+ if is_ddp_initialized():
452
+ for v in scaler._found_inf_per_device(optimizer).values():
453
+ dist.all_reduce(v, op=dist.ReduceOp.MAX)
454
+ scaler.step(optimizer)
455
+ scaler.update()
456
+ else:
457
+ optimizer.step()
458
+ model.zero_grad(set_to_none=True)
459
+ synchronize()
460
+ t1 = time.time()
461
+ dt = t1 - t0
462
+ # -------------------------------------------------------------------------
463
+
464
+ # State
465
+ step += 1
466
+
467
+ # logging
468
+ smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
469
+ debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
470
+ pct_done = 100 * progress
471
+ tok_per_sec = int(args.total_batch_size / dt)
472
+ flops_per_sec = num_flops_per_token * args.total_batch_size / dt
473
+ mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
474
+ if step > 10:
475
+ total_training_time += dt # only count the time after the first 10 steps
476
+ print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m")
477
+ if step % 10 == 0:
478
+ wandb_run.log({
479
+ "step": step,
480
+ "total_training_flops": flops_so_far,
481
+ "total_training_time": total_training_time,
482
+ "train/loss": debiased_smooth_loss,
483
+ "train/lrm": lrm,
484
+ "train/dt": dt,
485
+ "train/tok_per_sec": tok_per_sec,
486
+ "train/mfu": mfu,
487
+ "train/epoch": current_epoch,
488
+ })
489
+
490
+ # The garbage collector spends ~500ms scanning for cycles quite frequently.
491
+ # We manually manage it to avoid these pauses during training.
492
+ if step == 1:
493
+ gc.collect() # manually collect a lot of garbage from setup
494
+ gc.freeze() # freeze all currently surviving objects and exclude them from GC
495
+ gc.disable() # disable GC entirely except:
496
+ elif step % 5000 == 0: # every 5000 steps...
497
+ gc.collect() # manually collect, just to be safe for very long runs
498
+
499
+ # print a few more stats
500
+ print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
501
+ print0(f"Total training time: {total_training_time/60:.2f}m")
502
+ print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
503
+
504
+ # Log to report
505
+ from nanochat.report import get_report
506
+ get_report().log(section="SFT", data=[
507
+ user_config, # CLI args
508
+ { # stats about the training setup
509
+ "Number of iterations": step,
510
+ "DDP world size": ddp_world_size,
511
+ },
512
+ { # stats about training outcomes
513
+ "Minimum validation bpb": min_val_bpb,
514
+ }
515
+ ])
516
+
517
+ # cleanup
518
+ wandb_run.finish() # wandb run finish
519
+ compute_cleanup()
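Finally, a standalone sketch of the best-fit packing decision inside sft_data_generator_bos_bestfit above: among the buffered conversations, the longest one that still fits in the remaining row capacity is packed next, and when nothing fits the rest of the row is padded (and its targets masked) rather than cropping a conversation. The lengths below are invented:

```python
# Sketch of the best-fit selection used by the SFT dataloader above (lengths invented).
# conv_buffer holds (token_ids, loss_mask) pairs; only the lengths matter here.
conv_buffer = [([0] * 30, [1] * 30), ([0] * 90, [1] * 90), ([0] * 55, [1] * 55)]
remaining = 60  # space left in the current row

best_idx, best_len = -1, 0
for i, (conv, _) in enumerate(conv_buffer):
    if best_len < len(conv) <= remaining:  # fits and is the largest seen so far
        best_idx, best_len = i, len(conv)

print(best_idx, best_len)  # 2 55 -> pack the 55-token conversation; if none fit, pad the row
```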