Upload folder using huggingface_hub
- LICENSE +21 -0
- README.md +206 -4
- checkpoints/README.md +4 -0
- checkpoints/meta_005000.json +58 -0
- checkpoints/model_005000.pt +3 -0
- checkpoints/optim_005000_rank0.pt +3 -0
- nanochat/__init__.py +0 -0
- nanochat/__pycache__/__init__.cpython-310.pyc +0 -0
- nanochat/__pycache__/checkpoint_manager.cpython-310.pyc +0 -0
- nanochat/__pycache__/common.cpython-310.pyc +0 -0
- nanochat/__pycache__/core_eval.cpython-310.pyc +0 -0
- nanochat/__pycache__/dataloader.cpython-310.pyc +0 -0
- nanochat/__pycache__/dataset.cpython-310.pyc +0 -0
- nanochat/__pycache__/engine.cpython-310.pyc +0 -0
- nanochat/__pycache__/execution.cpython-310.pyc +0 -0
- nanochat/__pycache__/flash_attention.cpython-310.pyc +0 -0
- nanochat/__pycache__/gpt.cpython-310.pyc +0 -0
- nanochat/__pycache__/loss_eval.cpython-310.pyc +0 -0
- nanochat/__pycache__/optim.cpython-310.pyc +0 -0
- nanochat/__pycache__/report.cpython-310.pyc +0 -0
- nanochat/__pycache__/tokenizer.cpython-310.pyc +0 -0
- nanochat/checkpoint_manager.py +194 -0
- nanochat/common.py +278 -0
- nanochat/core_eval.py +262 -0
- nanochat/dataloader.py +166 -0
- nanochat/dataset.py +160 -0
- nanochat/engine.py +357 -0
- nanochat/execution.py +349 -0
- nanochat/flash_attention.py +187 -0
- nanochat/fp8.py +266 -0
- nanochat/gpt.py +507 -0
- nanochat/logo.svg +8 -0
- nanochat/loss_eval.py +65 -0
- nanochat/optim.py +533 -0
- nanochat/report.py +418 -0
- nanochat/tokenizer.py +406 -0
- nanochat/ui.html +566 -0
- pyproject.toml +74 -0
- scripts/__pycache__/base_eval.cpython-310.pyc +0 -0
- scripts/__pycache__/base_train.cpython-310.pyc +0 -0
- scripts/__pycache__/chat_eval.cpython-310.pyc +0 -0
- scripts/__pycache__/chat_sft.cpython-310.pyc +0 -0
- scripts/__pycache__/tok_eval.cpython-310.pyc +0 -0
- scripts/__pycache__/tok_train.cpython-310.pyc +0 -0
- scripts/base_eval.py +323 -0
- scripts/base_train.py +629 -0
- scripts/chat_cli.py +100 -0
- scripts/chat_eval.py +251 -0
- scripts/chat_rl.py +332 -0
- scripts/chat_sft.py +519 -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 Andrej Karpathy

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
CHANGED
@@ -1,4 +1,206 @@
# nanochat

nanochat is the simplest experimental harness for training LLMs. It is designed to run on a single GPU node, the code is minimal and hackable, and it covers all major LLM stages: tokenization, pretraining, finetuning, evaluation, inference, and a chat UI. For example, you can train your own GPT-2 capability LLM (which cost ~$43,000 to train in 2019) for only $48 (~2 hours on an 8XH100 GPU node) and then talk to it in a familiar ChatGPT-like web UI. On a spot instance, the total cost can be closer to ~$15. More generally, nanochat is configured out of the box to train an entire miniseries of compute-optimal models by setting one single complexity dial: `--depth`, the number of layers in the GPT transformer model (GPT-2 capability happens to sit at approximately depth 26). All other hyperparameters (the width of the transformer, number of heads, learning rate adjustments, training horizons, weight decays, ...) are calculated automatically in an optimal way.

For questions about the repo, I recommend either asking [DeepWiki](https://deepwiki.com/karpathy/nanochat) from Devin/Cognition questions about the repo, using the [Discussions tab](https://github.com/karpathy/nanochat/discussions), or coming by the [#nanochat](https://discord.com/channels/1020383067459821711/1427295580895314031) channel on Discord.
## Time-to-GPT-2 Leaderboard

Presently, the main focus of development is on tuning the pretraining stage, which takes the most compute. Inspired by the modded-nanogpt repo, and to incentivise progress and community collaboration, nanochat maintains a leaderboard for a "GPT-2 speedrun": the wall-clock time required to train a nanochat model to GPT-2 grade capability, as measured by the DCLM CORE score. The [runs/speedrun.sh](runs/speedrun.sh) script always reflects the reference way to train a GPT-2 grade model and talk to it. The current leaderboard looks as follows:

| # | time (hours) | val_bpb | CORE | Description | Date | Commit | Contributors |
|---|--------------|---------|------|-------------|------|--------|--------------|
| 0 | 168 | - | 0.2565 | Original OpenAI GPT-2 checkpoint | 2019 | - | OpenAI |
| 1 | 3.04 | 0.74833 | 0.2585 | d24 baseline, slightly overtrained | Jan 29 2026 | 348fbb3 | @karpathy |
| 2 | 2.91 | 0.74504 | 0.2578 | d26 slightly undertrained **+fp8** | Feb 2 2026 | a67eba3 | @karpathy |
| 3 | 2.76 | 0.74645 | 0.2602 | bump total batch size to 1M tokens | Feb 5 2026 | 2c062aa | @karpathy |
| 4 | 2.02 | 0.71854 | 0.2571 | change dataset to NVIDIA ClimbMix | Mar 4 2026 | 324e69c | @ddudek @karpathy |
| 5 | 1.80 | 0.71808 | 0.2690 | autoresearch [round 1](https://x.com/karpathy/status/2031135152349524125) | Mar 9 2026 | 6ed7d1d | @karpathy |
| 6 | 1.65 | 0.71800 | 0.2626 | autoresearch round 2 | Mar 14 2026 | a825e63 | @karpathy |

The primary metric we care about is "time to GPT-2": the wall-clock time needed to outperform the GPT-2 (1.6B) CORE metric on an 8XH100 GPU node. The GPT-2 CORE score is 0.256525. In 2019, training GPT-2 cost approximately $43,000, so it is incredible that due to many advances across the stack over 7 years, we can now do it much faster and for well below $100 (e.g. at the current ~$3/GPU/hr, an 8XH100 node is ~$24/hr, so 2 hours is ~$48).

See [dev/LEADERBOARD.md](dev/LEADERBOARD.md) for more docs on how to interpret and contribute to the leaderboard.
## Getting started

### Reproduce and talk to GPT-2

The most fun you can have is to train your own GPT-2 and talk to it. The entire pipeline to do so is contained in the single file [runs/speedrun.sh](runs/speedrun.sh), which is designed to be run on an 8XH100 GPU node. Boot up a new 8XH100 GPU box from your favorite provider (e.g. I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script:

```bash
bash runs/speedrun.sh
```

You may wish to do so in a screen session, as this will take ~3 hours to run. Once it's done, you can talk to your model via the ChatGPT-like web UI. Make sure your local uv virtual environment is active (run `source .venv/bin/activate`), and serve it:

```bash
python -m scripts.chat_web
```

Then visit the URL shown. Make sure to access it correctly, e.g. on Lambda use the public IP of the node you're on, followed by the port, so for example [http://209.20.xxx.xxx:8000/](http://209.20.xxx.xxx:8000/). Then talk to your LLM as you'd normally talk to ChatGPT! Get it to write stories or poems. Ask it to tell you who you are, to see a hallucination. Ask it why the sky is blue. Or why it's green. The speedrun produces a ~4e19 FLOPs capability model, so it's a bit like talking to a kindergartener :).

---

<img width="2672" height="1520" alt="image" src="https://github.com/user-attachments/assets/ed39ddf8-2370-437a-bedc-0f39781e76b5" />

---

A few more notes:

- The code will run just fine on an Ampere 8XA100 GPU node as well, just a bit slower.
- All code will run just fine on even a single GPU by omitting `torchrun`, and will produce ~identical results (the code automatically switches to gradient accumulation; a minimal sketch follows this list), but you'll have to wait 8 times longer.
- If your GPU(s) have less than 80GB of VRAM, you'll have to tune some of the hyperparameters or you will OOM. Look for `--device_batch_size` in the scripts and reduce it until things fit, e.g. from 32 (the default) to 16, 8, 4, 2, or even 1. Below that you'll have to know a bit more about what you're doing and get more creative.
- Most of the code is fairly vanilla PyTorch, so it should run on anything that supports that (xpu, mps, etc.), but I haven't personally exercised all of these code paths, so there might be sharp edges.
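
To illustrate the gradient accumulation mentioned in the notes above, here is a minimal, self-contained sketch of the pattern (a toy model and hypothetical batch sizes for illustration, not the repo's actual training loop in `scripts/base_train.py`):

```python
import torch
import torch.nn as nn

# Hypothetical tiny setup just to make the pattern runnable.
model = nn.Linear(16, 16)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

total_batch_size, device_batch_size = 64, 8
grad_accum_steps = total_batch_size // device_batch_size  # 8 micro-steps emulate one big batch

optimizer.zero_grad()
for micro_step in range(grad_accum_steps):
    x = torch.randn(device_batch_size, 16)
    y = torch.randn(device_batch_size, 16)
    loss = loss_fn(model(x), y)
    (loss / grad_accum_steps).backward()  # scale so gradients average over the full batch
optimizer.step()
```

With 1 GPU instead of 8, the same total batch size is reached by taking 8x more micro-steps per optimizer step, which is why the results come out ~identical but the wall-clock time is ~8x longer.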
## Research

If you are a researcher and wish to help improve nanochat, two scripts of interest are [runs/scaling_laws.sh](runs/scaling_laws.sh) and [runs/miniseries.sh](runs/miniseries.sh). See [Jan 7 miniseries v1](https://github.com/karpathy/nanochat/discussions/420) for related documentation. For quick experimentation (~5 min pretraining runs), my favorite scale is a 12-layer model (GPT-1 sized), trained e.g. like this:

```
OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
    --depth=12 \
    --run="d12" \
    --model-tag="d12" \
    --core-metric-every=999999 \
    --sample-every=-1 \
    --save-every=-1 \
```

This uses wandb (run name "d12"), only runs the CORE metric on the last step, and doesn't sample or save intermediate checkpoints. I like to change something in the code, re-run a d12 (or a d16, etc.) and see if it helped, in an iteration loop. To see if a run helps, I like to monitor the wandb plots for:

1. `val_bpb` (validation loss in vocab-size-invariant units of bits per byte) as a function of `step`, `total_training_time` and `total_training_flops`.
2. `core_metric` (the DCLM CORE score)
3. VRAM utilization, `train/mfu` (Model FLOPs Utilization), `train/tok_per_sec` (training throughput)

See an example [here](https://github.com/karpathy/nanochat/pull/498#issuecomment-3850720044).

The important thing to note is that nanochat is written and configured around one single dial of complexity: the depth of the transformer. This single integer automatically determines all other hyperparameters (the width of the transformer, number of heads, learning rate adjustments, training horizons, weight decays, ...) so that the trained model comes out compute optimal. The idea is that the user doesn't have to think about or set any of this; they simply ask for a smaller or bigger model using `--depth`, and everything "just works". By sweeping out the depth, you obtain the nanochat miniseries of compute-optimal models at various sizes. A GPT-2 capability model (which is of most interest at the moment) happens to land somewhere in the d24-d26 range with the current code. But any candidate changes to the repo have to be principled enough that they work for all settings of depth.
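
To make the single dial concrete, a minimal sketch of how the model shape can follow from `--depth`. The `aspect_ratio=64` and `head_dim=64` defaults are visible in the `user_config` of `checkpoints/meta_005000.json` below; the learning-rate and training-horizon scheduling is not reproduced here, and this is illustrative rather than the exact logic of `scripts/base_train.py`:

```python
# Sketch: deriving the model shape from a single --depth dial.
# aspect_ratio=64 and head_dim=64 are taken from user_config in checkpoints/meta_005000.json.
def model_shape_from_depth(depth: int, aspect_ratio: int = 64, head_dim: int = 64):
    n_embd = depth * aspect_ratio  # width grows linearly with depth
    n_head = n_embd // head_dim    # keep the per-head dimension fixed
    return {"n_layer": depth, "n_embd": n_embd, "n_head": n_head, "n_kv_head": n_head}

print(model_shape_from_depth(6))   # {'n_layer': 6, 'n_embd': 384, 'n_head': 6, 'n_kv_head': 6}, matches meta_005000.json
print(model_shape_from_depth(26))  # roughly GPT-2 capability scale, per the leaderboard section above
```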
## Running on CPU / MPS

The script [runs/runcpu.sh](runs/runcpu.sh) shows a very simple example of running on CPU or Apple Silicon. It dramatically shrinks the LLM being trained so that things fit into a reasonable time interval of a few tens of minutes of training. You will not get strong results this way.
## Precision / dtype

nanochat does not use `torch.amp.autocast`. Instead, precision is managed explicitly through a single global `COMPUTE_DTYPE` (defined in `nanochat/common.py`). By default this is auto-detected based on your hardware:

| Hardware | Default dtype | Why |
|----------|---------------|-----|
| CUDA SM 80+ (A100, H100, ...) | `bfloat16` | Native bf16 tensor cores |
| CUDA SM < 80 (V100, T4, ...) | `float32` | No bf16; fp16 available via `NANOCHAT_DTYPE=float16` (uses GradScaler) |
| CPU / MPS | `float32` | No reduced-precision tensor cores |

You can override the default with the `NANOCHAT_DTYPE` environment variable:

```bash
NANOCHAT_DTYPE=float32 python -m scripts.chat_cli -p "hello" # force fp32
NANOCHAT_DTYPE=bfloat16 torchrun --nproc_per_node=8 -m scripts.base_train # force bf16
```

How it works: model weights are stored in fp32 (for optimizer precision), but our custom `Linear` layer casts them to `COMPUTE_DTYPE` during the forward pass. Embeddings are stored directly in `COMPUTE_DTYPE` to save memory. This gives us the same mixed-precision benefit as autocast, but with full explicit control over what runs in which precision.

Note: `float16` training automatically enables a `GradScaler` in `base_train.py` to prevent gradient underflow. SFT supports this too, but RL currently does not. Inference in fp16 works fine everywhere.
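
A minimal sketch of the weight-casting idea just described, assuming a simplified layer; the repo's actual `Linear` lives in `nanochat/gpt.py` and may differ in detail:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

COMPUTE_DTYPE = torch.bfloat16  # in nanochat this is auto-detected in nanochat/common.py

class CastedLinear(nn.Module):
    """Sketch: master weights stay fp32, compute runs in COMPUTE_DTYPE (replaces autocast)."""
    def __init__(self, in_features, out_features):
        super().__init__()
        # fp32 master weights, so the optimizer update keeps full precision
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * in_features**-0.5)
    def forward(self, x):
        # cast the weights (and input) to the compute dtype at use time
        return F.linear(x.to(COMPUTE_DTYPE), self.weight.to(COMPUTE_DTYPE))

y = CastedLinear(384, 384)(torch.randn(2, 384))
print(y.dtype)  # torch.bfloat16
```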
## Guides

I've published a number of guides that might contain helpful information, most recent to least recent:

- [Feb 1 2026: Beating GPT-2 for <<$100: the nanochat journey](https://github.com/karpathy/nanochat/discussions/481)
- [Jan 7 miniseries v1](https://github.com/karpathy/nanochat/discussions/420) documents the first nanochat miniseries of models.
- To add new abilities to nanochat, see [Guide: counting r in strawberry (and how to add abilities generally)](https://github.com/karpathy/nanochat/discussions/164).
- To customize your nanochat, see [Guide: infusing identity to your nanochat](https://github.com/karpathy/nanochat/discussions/139) in Discussions, which describes how you can tune your nanochat's personality through synthetic data generation and mixing that data into the SFT stage.
- [Oct 13 2025: original nanochat post](https://github.com/karpathy/nanochat/discussions/1) introducing nanochat, though it now contains some deprecated information and the model described is a lot older (with worse results) than current master.

## File structure

```
.
├── LICENSE
├── README.md
├── dev
│   ├── gen_synthetic_data.py        # Example synthetic data for identity
│   ├── generate_logo.html
│   ├── nanochat.png
│   └── repackage_data_reference.py  # Pretraining data shard generation
├── nanochat
│   ├── __init__.py                  # empty
│   ├── checkpoint_manager.py        # Save/load model checkpoints
│   ├── common.py                    # Misc small utilities, quality of life
│   ├── core_eval.py                 # Evaluates base model CORE score (DCLM paper)
│   ├── dataloader.py                # Tokenizing distributed data loader
│   ├── dataset.py                   # Download/read utils for pretraining data
│   ├── engine.py                    # Efficient model inference with KV cache
│   ├── execution.py                 # Allows the LLM to execute Python code as a tool
│   ├── gpt.py                       # The GPT nn.Module Transformer
│   ├── logo.svg
│   ├── loss_eval.py                 # Evaluate bits per byte (instead of loss)
│   ├── optim.py                     # AdamW + Muon optimizer, 1 GPU and distributed
│   ├── report.py                    # Utilities for writing the nanochat Report
│   ├── tokenizer.py                 # BPE tokenizer wrapper in the style of GPT-4
│   └── ui.html                      # HTML/CSS/JS for the nanochat frontend
├── pyproject.toml
├── runs
│   ├── miniseries.sh                # Miniseries training script
│   ├── runcpu.sh                    # Small example of how to run on CPU/MPS
│   ├── scaling_laws.sh              # Scaling laws experiments
│   └── speedrun.sh                  # Train the ~$100 nanochat d20
├── scripts
│   ├── base_eval.py                 # Base model: CORE score, bits per byte, samples
│   ├── base_train.py                # Base model: train
│   ├── chat_cli.py                  # Chat model: talk to over CLI
│   ├── chat_eval.py                 # Chat model: eval tasks
│   ├── chat_rl.py                   # Chat model: reinforcement learning
│   ├── chat_sft.py                  # Chat model: train SFT
│   ├── chat_web.py                  # Chat model: talk to over Web UI
│   ├── tok_eval.py                  # Tokenizer: evaluate compression rate
│   └── tok_train.py                 # Tokenizer: train it
├── tasks
│   ├── arc.py                       # Multiple choice science questions
│   ├── common.py                    # TaskMixture | TaskSequence
│   ├── customjson.py                # Make a Task from arbitrary jsonl convos
│   ├── gsm8k.py                     # 8K grade school math questions
│   ├── humaneval.py                 # Misnomer; simple Python coding task
│   ├── mmlu.py                      # Multiple choice questions, broad topics
│   ├── smoltalk.py                  # Conglomerate dataset of SmolTalk from HF
│   └── spellingbee.py               # Task teaching the model to spell/count letters
├── tests
│   └── test_engine.py
└── uv.lock
```
## Contributing

The goal of nanochat is to improve the state of the art in micro models that are accessible to work with end to end on budgets under $1,000. Accessibility is about overall cost, but also about cognitive complexity: nanochat is not an exhaustively configurable LLM "framework"; there are no giant configuration objects, model factories, or if-then-else monsters in the code base. It is a single, cohesive, minimal, readable, hackable, maximally-forkable "strong baseline" codebase designed to run start to end and produce a ChatGPT model you can talk to. Currently, the most interesting part personally is reducing the time to GPT-2 (i.e. getting a CORE score above 0.256525). This takes ~3 hours at present, but by improving the pretraining stage we can improve it further.

Current AI policy: disclosure. When submitting a PR, please declare any parts that had substantial LLM contribution and that you have not written or do not fully understand.

## Acknowledgements

- The name (nanochat) derives from my earlier project [nanoGPT](https://github.com/karpathy/nanoGPT), which only covered pretraining.
- nanochat is also inspired by [modded-nanoGPT](https://github.com/KellerJordan/modded-nanogpt), which gamified the nanoGPT repo with clear metrics and a leaderboard, and borrows a lot of its ideas and some implementation for pretraining.
- Thank you to [HuggingFace](https://huggingface.co/) for fineweb and smoltalk.
- Thank you to [Lambda](https://lambda.ai/service/gpu-cloud) for the compute used in developing this project.
- Thank you to chief LLM whisperer 🧙‍♂️ Alec Radford for advice/guidance.
- Thank you to the repo czar Sofie [@svlandeg](https://github.com/svlandeg) for help with managing the issues, pull requests and discussions of nanochat.

## Cite

If you find nanochat helpful in your research, cite it simply as:

```bibtex
@misc{nanochat,
  author = {Andrej Karpathy},
  title = {nanochat: The best ChatGPT that \$100 can buy},
  year = {2025},
  publisher = {GitHub},
  url = {https://github.com/karpathy/nanochat}
}
```

## License

MIT
checkpoints/README.md
ADDED
@@ -0,0 +1,4 @@
# NanoChat Checkpoint Step 5000
- Partial training checkpoint (training crashed after this step)
- Trained on the ARC-Easy + ARC-Challenge dataset
- The loss became NaN afterward; do **not** use this for further training without fixing the learning rate / FP8 settings
checkpoints/meta_005000.json
ADDED
@@ -0,0 +1,58 @@
{
  "step": 5000,
  "val_bpb": 3.0855938505925744,
  "model_config": {
    "sequence_len": 512,
    "vocab_size": 32768,
    "n_layer": 6,
    "n_head": 6,
    "n_kv_head": 6,
    "n_embd": 384,
    "window_pattern": "L"
  },
  "user_config": {
    "run": "dummy",
    "device_type": "",
    "fp8": false,
    "fp8_recipe": "tensorwise",
    "depth": 6,
    "aspect_ratio": 64,
    "head_dim": 64,
    "max_seq_len": 512,
    "window_pattern": "L",
    "num_iterations": 5000,
    "target_flops": -1.0,
    "target_param_data_ratio": 10.5,
    "device_batch_size": 32,
    "total_batch_size": 16384,
    "embedding_lr": 0.3,
    "unembedding_lr": 0.008,
    "weight_decay": 0.28,
    "matrix_lr": 0.02,
    "scalar_lr": 0.5,
    "warmup_steps": 40,
    "warmdown_ratio": 0.65,
    "final_lr_frac": 0.05,
    "resume_from_step": -1,
    "eval_every": 100,
    "eval_tokens": 524288,
    "core_metric_every": -1,
    "core_metric_max_per_task": 500,
    "sample_every": 100,
    "save_every": -1,
    "model_tag": null
  },
  "device_batch_size": 32,
  "max_seq_len": 512,
  "total_batch_size": 16384,
  "dataloader_state_dict": {
    "pq_idx": 2,
    "rg_idx": 48,
    "epoch": 1
  },
  "loop_state": {
    "min_val_bpb": 2.689004677760501,
    "smooth_train_loss": 10.054429589944,
    "total_training_time": 16636.2527115345
  }
}
checkpoints/model_005000.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0ff116a0583967f5e0b0df01bc329993f75b5ffab8664928432de7ba556fd0e6
size 294145379
checkpoints/optim_005000_rank0.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:782a6b507c9468849298071883982d32a83736aabc1cc892cf414e0c21f0dcc9
size 545905413
nanochat/__init__.py
ADDED
File without changes

nanochat/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (143 Bytes)

nanochat/__pycache__/checkpoint_manager.cpython-310.pyc
ADDED
Binary file (6.96 kB)

nanochat/__pycache__/common.cpython-310.pyc
ADDED
Binary file (9.61 kB)

nanochat/__pycache__/core_eval.cpython-310.pyc
ADDED
Binary file (8.46 kB)

nanochat/__pycache__/dataloader.cpython-310.pyc
ADDED
Binary file (5.38 kB)

nanochat/__pycache__/dataset.cpython-310.pyc
ADDED
Binary file (5.47 kB)

nanochat/__pycache__/engine.cpython-310.pyc
ADDED
Binary file (11.2 kB)

nanochat/__pycache__/execution.cpython-310.pyc
ADDED
Binary file (8.73 kB)

nanochat/__pycache__/flash_attention.cpython-310.pyc
ADDED
Binary file (4.53 kB)

nanochat/__pycache__/gpt.cpython-310.pyc
ADDED
Binary file (18.3 kB)

nanochat/__pycache__/loss_eval.cpython-310.pyc
ADDED
Binary file (2.32 kB)

nanochat/__pycache__/optim.cpython-310.pyc
ADDED
Binary file (15.5 kB)

nanochat/__pycache__/report.cpython-310.pyc
ADDED
Binary file (11.4 kB)

nanochat/__pycache__/tokenizer.cpython-310.pyc
ADDED
Binary file (13 kB)
nanochat/checkpoint_manager.py
ADDED
@@ -0,0 +1,194 @@
"""
Utilities for saving and loading model/optim/state checkpoints.
"""
import os
import re
import glob
import json
import logging
import torch

from nanochat.common import get_base_dir
from nanochat.gpt import GPT, GPTConfig
from nanochat.tokenizer import get_tokenizer
from nanochat.common import setup_default_logging

# Set up logging
setup_default_logging()
logger = logging.getLogger(__name__)
def log0(message):
    if int(os.environ.get('RANK', 0)) == 0:
        logger.info(message)

def _patch_missing_config_keys(model_config_kwargs):
    """Add default values for new config keys missing in old checkpoints."""
    # Old models were trained with full context (no sliding window)
    if "window_pattern" not in model_config_kwargs:
        model_config_kwargs["window_pattern"] = "L"
        log0("Patching missing window_pattern in model config to 'L'")

def _patch_missing_keys(model_data, model_config):
    """Add default values for new parameters that may be missing in old checkpoints."""
    n_layer = model_config.n_layer
    # resid_lambdas defaults to 1.0 (identity scaling)
    if "resid_lambdas" not in model_data:
        model_data["resid_lambdas"] = torch.ones(n_layer)
        log0("Patching missing resid_lambdas in model data to 1.0")
    # x0_lambdas defaults to 0.0 (disabled)
    if "x0_lambdas" not in model_data:
        model_data["x0_lambdas"] = torch.zeros(n_layer)
        log0("Patching missing x0_lambdas in model data to 0.0")

def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
    if rank == 0:
        os.makedirs(checkpoint_dir, exist_ok=True)
        # Save the model state parameters
        model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
        torch.save(model_data, model_path)
        logger.info(f"Saved model parameters to: {model_path}")
        # Save the metadata dict as json
        meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
        with open(meta_path, "w", encoding="utf-8") as f:
            json.dump(meta_data, f, indent=2)
        logger.info(f"Saved metadata to: {meta_path}")
    # Note that optimizer state is sharded across ranks, so each rank must save its own.
    if optimizer_data is not None:
        os.makedirs(checkpoint_dir, exist_ok=True)
        optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
        torch.save(optimizer_data, optimizer_path)
        logger.info(f"Saved optimizer state to: {optimizer_path}")

def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
    # Load the model state
    model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
    model_data = torch.load(model_path, map_location=device)
    # Load the optimizer state if requested
    optimizer_data = None
    if load_optimizer:
        optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
        optimizer_data = torch.load(optimizer_path, map_location=device)
    # Load the metadata
    meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
    with open(meta_path, "r", encoding="utf-8") as f:
        meta_data = json.load(f)
    return model_data, optimizer_data, meta_data


def build_model(checkpoint_dir, step, device, phase):
    """
    A bunch of repetitive code to build a model from a given checkpoint.
    Returns:
    - base model - uncompiled, not wrapped in DDP
    - tokenizer
    - meta data saved during base model training
    """
    assert phase in ["train", "eval"], f"Invalid phase: {phase}"
    model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
    if device.type in {"cpu", "mps"}:
        # Convert bfloat16 tensors to float for CPU inference
        model_data = {
            k: v.float() if v.dtype == torch.bfloat16 else v
            for k, v in model_data.items()
        }
    # Hack: fix torch compile issue, which prepends all keys with _orig_mod.
    model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
    model_config_kwargs = meta_data["model_config"]
    _patch_missing_config_keys(model_config_kwargs)
    log0(f"Building model with config: {model_config_kwargs}")
    model_config = GPTConfig(**model_config_kwargs)
    _patch_missing_keys(model_data, model_config)
    with torch.device("meta"):
        model = GPT(model_config)
    # Load the model state
    model.to_empty(device=device)
    model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
    model.load_state_dict(model_data, strict=True, assign=True)
    # Put the model in the right training phase / mode
    if phase == "eval":
        model.eval()
    else:
        model.train()
    # Load the Tokenizer
    tokenizer = get_tokenizer()
    # Sanity check: compatibility between model and tokenizer
    assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"], f"Tokenizer vocab size {tokenizer.get_vocab_size()} does not match model config vocab size {model_config_kwargs['vocab_size']}"
    return model, tokenizer, meta_data


def find_largest_model(checkpoints_dir):
    # attempt to guess the model tag: take the biggest model available
    model_tags = [f for f in os.listdir(checkpoints_dir) if os.path.isdir(os.path.join(checkpoints_dir, f))]
    if not model_tags:
        raise FileNotFoundError(f"No checkpoints found in {checkpoints_dir}")
    # 1) normally all model tags are of the form d<number>, try that first:
    candidates = []
    for model_tag in model_tags:
        match = re.match(r"d(\d+)", model_tag)
        if match:
            model_depth = int(match.group(1))
            candidates.append((model_depth, model_tag))
    if candidates:
        candidates.sort(key=lambda x: x[0], reverse=True)
        return candidates[0][1]
    # 2) if that failed, take the most recently updated model:
    model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoints_dir, x)), reverse=True)
    return model_tags[0]


def find_last_step(checkpoint_dir):
    # Look into checkpoint_dir and find model_<step>.pt with the highest step
    checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt"))
    if not checkpoint_files:
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
    last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files))
    return last_step

# -----------------------------------------------------------------------------
# convenience functions that take into account nanochat's directory structure

def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
    if model_tag is None:
        # guess the model tag by defaulting to the largest model
        model_tag = find_largest_model(checkpoints_dir)
        log0(f"No model tag provided, guessing model tag: {model_tag}")
    checkpoint_dir = os.path.join(checkpoints_dir, model_tag)
    if step is None:
        # guess the step by defaulting to the last step
        step = find_last_step(checkpoint_dir)
    assert step is not None, f"No checkpoints found in {checkpoint_dir}"
    # build the model
    log0(f"Loading model from {checkpoint_dir} with step {step}")
    model, tokenizer, meta_data = build_model(checkpoint_dir, step, device, phase)
    return model, tokenizer, meta_data

def load_model(source, *args, **kwargs):
    model_dir = {
        "base": "base_checkpoints",
        "sft": "chatsft_checkpoints",
        "rl": "chatrl_checkpoints",
    }[source]
    base_dir = get_base_dir()
    checkpoints_dir = os.path.join(base_dir, model_dir)
    return load_model_from_dir(checkpoints_dir, *args, **kwargs)

def load_optimizer_state(source, device, rank, model_tag=None, step=None):
    """Load just the optimizer shard for a given rank, without re-loading the model."""
    model_dir = {
        "base": "base_checkpoints",
        "sft": "chatsft_checkpoints",
        "rl": "chatrl_checkpoints",
    }[source]
    base_dir = get_base_dir()
    checkpoints_dir = os.path.join(base_dir, model_dir)
    if model_tag is None:
        model_tag = find_largest_model(checkpoints_dir)
    checkpoint_dir = os.path.join(checkpoints_dir, model_tag)
    if step is None:
        step = find_last_step(checkpoint_dir)
    optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
    if not os.path.exists(optimizer_path):
        log0(f"Optimizer checkpoint not found: {optimizer_path}")
        return None
    log0(f"Loading optimizer state from {optimizer_path}")
    optimizer_data = torch.load(optimizer_path, map_location=device)
    return optimizer_data
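For reference, a short usage sketch (not part of the upload) that loads the step-5000 checkpoint shipped in this repo's `checkpoints/` directory with the utilities above; the relative path and device choice are assumptions:

```python
import torch
from nanochat.checkpoint_manager import load_checkpoint

# checkpoints/ holds model_005000.pt, optim_005000_rank0.pt and meta_005000.json;
# per checkpoints/README.md the run NaN'd after this step, so treat it as inspect-only.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_data, optimizer_data, meta_data = load_checkpoint("checkpoints", step=5000, device=device)
print(meta_data["val_bpb"])                  # 3.0855938505925744
print(meta_data["model_config"]["n_layer"])  # 6
```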
nanochat/common.py
ADDED
@@ -0,0 +1,278 @@
"""
Common utilities for nanochat.
"""

import os
import re
import logging
import urllib.request
import torch
import torch.distributed as dist
from filelock import FileLock

# The dtype used for compute (matmuls, activations). Master weights stay fp32 for optimizer precision.
# Linear layers cast their weights to this dtype in forward, replacing torch.amp.autocast.
# Override with NANOCHAT_DTYPE env var: "bfloat16", "float16", "float32"
_DTYPE_MAP = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}
def _detect_compute_dtype():
    env = os.environ.get("NANOCHAT_DTYPE")
    if env is not None:
        return _DTYPE_MAP[env], f"set via NANOCHAT_DTYPE={env}"
    if torch.cuda.is_available():
        # bf16 requires SM 80+ (Ampere: A100, A10, etc.)
        # Older GPUs like V100 (SM 70) and T4 (SM 75) only have fp16 tensor cores
        capability = torch.cuda.get_device_capability()
        if capability >= (8, 0):
            return torch.bfloat16, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (bf16 supported)"
        # fp16 training requires GradScaler (not yet implemented), so fall back to fp32.
        # Users can still force fp16 via NANOCHAT_DTYPE=float16 if they know what they're doing.
        return torch.float32, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (pre-Ampere, bf16 not supported, using fp32)"
    return torch.float32, "auto-detected: no CUDA (CPU/MPS)"
COMPUTE_DTYPE, COMPUTE_DTYPE_REASON = _detect_compute_dtype()

class ColoredFormatter(logging.Formatter):
    """Custom formatter that adds colors to log messages."""
    # ANSI color codes
    COLORS = {
        'DEBUG': '\033[36m',     # Cyan
        'INFO': '\033[32m',      # Green
        'WARNING': '\033[33m',   # Yellow
        'ERROR': '\033[31m',     # Red
        'CRITICAL': '\033[35m',  # Magenta
    }
    RESET = '\033[0m'
    BOLD = '\033[1m'
    def format(self, record):
        # Add color to the level name
        levelname = record.levelname
        if levelname in self.COLORS:
            record.levelname = f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}"
        # Format the message
        message = super().format(record)
        # Add color to specific parts of the message
        if levelname == 'INFO':
            # Highlight numbers and percentages
            message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message)
            message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message)
        return message

def setup_default_logging():
    handler = logging.StreamHandler()
    handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    logging.basicConfig(
        level=logging.INFO,
        handlers=[handler]
    )

setup_default_logging()
logger = logging.getLogger(__name__)

def get_base_dir():
    # co-locate nanochat intermediates with other cached data in ~/.cache (by default)
    if os.environ.get("NANOCHAT_BASE_DIR"):
        nanochat_dir = os.environ.get("NANOCHAT_BASE_DIR")
    else:
        home_dir = os.path.expanduser("~")
        cache_dir = os.path.join(home_dir, ".cache")
        nanochat_dir = os.path.join(cache_dir, "nanochat")
    os.makedirs(nanochat_dir, exist_ok=True)
    return nanochat_dir

def download_file_with_lock(url, filename, postprocess_fn=None):
    """
    Downloads a file from a URL to a local path in the base directory.
    Uses a lock file to prevent concurrent downloads among multiple ranks.
    """
    base_dir = get_base_dir()
    file_path = os.path.join(base_dir, filename)
    lock_path = file_path + ".lock"

    if os.path.exists(file_path):
        return file_path

    with FileLock(lock_path):
        # Only a single rank can acquire this lock
        # All other ranks block until it is released

        # Recheck after acquiring lock
        if os.path.exists(file_path):
            return file_path

        # Download the content as bytes
        print(f"Downloading {url}...")
        with urllib.request.urlopen(url) as response:
            content = response.read() # bytes

        # Write to local file
        with open(file_path, 'wb') as f:
            f.write(content)
        print(f"Downloaded to {file_path}")

        # Run the postprocess function if provided
        if postprocess_fn is not None:
            postprocess_fn(file_path)

    return file_path

def print0(s="", **kwargs):
    ddp_rank = int(os.environ.get('RANK', 0))
    if ddp_rank == 0:
        print(s, **kwargs)

def print_banner():
    # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/
    banner = """
█████ █████
░░███ ░░███
████████ ██████ ████████ ██████ ██████ ░███████ ██████ ███████
░░███░░███ ░░░░░███ ░░███░░███ ███░░███ ███░░███ ░███░░███ ░░░░░███░░░███░
░███ ░███ ███████ ░███ ░███ ░███ ░███░███ ░░░ ░███ ░███ ███████ ░███
░███ ░███ ███░░███ ░███ ░███ ░███ ░███░███ ███ ░███ ░███ ███░░███ ░███ ███
████ █████░░████████ ████ █████░░██████ ░░██████ ████ █████░░███████ ░░█████
░░░░ ░░░░░ ░░░░░░░░ ░░░░ ░░░░░ ░░░░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░░░░ ░░░░░
"""
    print0(banner)

def is_ddp_requested() -> bool:
    """
    True if launched by torchrun (env present), even before init.
    Used to decide whether we *should* initialize a PG.
    """
    return all(k in os.environ for k in ("RANK", "LOCAL_RANK", "WORLD_SIZE"))

def is_ddp_initialized() -> bool:
    """
    True if torch.distributed is available and the process group is initialized.
    Used at cleanup to avoid destroying a non-existent PG.
    """
    return dist.is_available() and dist.is_initialized()

def get_dist_info():
    if is_ddp_requested():
        # We rely on torchrun's env to decide if we SHOULD init.
        # (Initialization itself happens in compute_init.)
        assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
        ddp_rank = int(os.environ['RANK'])
        ddp_local_rank = int(os.environ['LOCAL_RANK'])
        ddp_world_size = int(os.environ['WORLD_SIZE'])
        return True, ddp_rank, ddp_local_rank, ddp_world_size
    else:
        return False, 0, 0, 1

def autodetect_device_type():
    # prefer to use CUDA if available, otherwise use MPS, otherwise fall back on CPU
    if torch.cuda.is_available():
        device_type = "cuda"
    elif torch.backends.mps.is_available():
        device_type = "mps"
    else:
        device_type = "cpu"
    print0(f"Autodetected device type: {device_type}")
    return device_type

def compute_init(device_type="cuda"):  # cuda|cpu|mps
    """Basic initialization that we keep doing over and over, so make common."""

    assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
    if device_type == "cuda":
        assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
    if device_type == "mps":
        assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"

    # Reproducibility
    # Note that we set the global seeds here, but most of the code uses explicit rng objects.
    # The only place where global rng might be used is nn.Module initialization of the model weights.
    torch.manual_seed(42)
    if device_type == "cuda":
        torch.cuda.manual_seed(42)
    # skipping full reproducibility for now, possibly investigate slowdown later
    # torch.use_deterministic_algorithms(True)

    # Precision
    if device_type == "cuda":
        torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls, see https://docs.pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html

    # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()  # renamed locally to avoid shadowing is_ddp_requested()
    if ddp and device_type == "cuda":
        device = torch.device("cuda", ddp_local_rank)
        torch.cuda.set_device(device) # make "cuda" default to this device
        dist.init_process_group(backend="nccl", device_id=device)
        dist.barrier()
    else:
        device = torch.device(device_type) # mps|cpu

    if ddp_rank == 0:
        logger.info(f"Distributed world size: {ddp_world_size}")

    return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device

def compute_cleanup():
    """Companion function to compute_init, to clean things up before script exit"""
    if is_ddp_initialized():
        dist.destroy_process_group()

class DummyWandb:
    """Useful if we wish to not use wandb but have all the same signatures"""
    def __init__(self):
        pass
    def log(self, *args, **kwargs):
        pass
    def finish(self):
        pass

# hardcoded BF16 peak flops for various GPUs
# inspired by torchtitan: https://github.com/pytorch/torchtitan/blob/main/torchtitan/tools/utils.py
# and PR: https://github.com/karpathy/nanochat/pull/147
def get_peak_flops(device_name: str) -> float:
    name = device_name.lower()

    # Table order matters: more specific patterns first.
    _PEAK_FLOPS_TABLE = (
        # NVIDIA Blackwell
        (["gb200"], 2.5e15),
        (["grace blackwell"], 2.5e15),
        (["b200"], 2.25e15),
        (["b100"], 1.8e15),
        # NVIDIA Hopper
        (["h200", "nvl"], 836e12),
        (["h200", "pcie"], 836e12),
        (["h200"], 989e12),
        (["h100", "nvl"], 835e12),
        (["h100", "pcie"], 756e12),
        (["h100"], 989e12),
        (["h800", "nvl"], 989e12),
        (["h800"], 756e12),
        # NVIDIA Ampere data center
        (["a100"], 312e12),
        (["a800"], 312e12),
        (["a40"], 149.7e12),
        (["a30"], 165e12),
        # NVIDIA Ada data center
        (["l40s"], 362e12),
        (["l40-s"], 362e12),
        (["l40 s"], 362e12),
        (["l4"], 121e12),
        # AMD CDNA accelerators
        (["mi355"], 2.5e15),
        (["mi325"], 1.3074e15),
        (["mi300x"], 1.3074e15),
        (["mi300a"], 980.6e12),
        (["mi250x"], 383e12),
        (["mi250"], 362.1e12),
        # Consumer RTX
        (["5090"], 209.5e12),
        (["4090"], 165.2e12),
        (["3090"], 71e12),
    )
    for patterns, flops in _PEAK_FLOPS_TABLE:
        if all(p in name for p in patterns):
            return flops
    if "data center gpu max 1550" in name:
        # Ponte Vecchio (PVC) - dynamic based on compute units
        max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
        return 512 * max_comp_units * 1300 * 10**6

    # Unknown GPU - return inf so MFU shows as 0% rather than a wrong guess
    logger.warning(f"Peak flops undefined for: {device_name}, MFU will show as 0%")
    return float('inf')
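For reference, a minimal sketch (not part of the upload) of how a peak-flops value like those above turns into the `train/mfu` metric mentioned in the README. The ~6·N flops-per-token rule of thumb and all the numbers here are illustrative assumptions, not values from this repo:

```python
from nanochat.common import get_peak_flops

# MFU = achieved flops/sec over peak flops/sec.
# Standard rule of thumb for transformer training: ~6 * num_params flops per token.
num_params = 560e6    # hypothetical model size
tok_per_sec = 1.5e6   # hypothetical measured training throughput (all GPUs combined)
num_gpus = 8
peak = get_peak_flops("NVIDIA H100 80GB HBM3") * num_gpus  # 989e12 * 8
mfu = (6 * num_params * tok_per_sec) / peak
print(f"MFU: {mfu:.1%}")  # ~64% with these assumed numbers
```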
nanochat/core_eval.py
ADDED
@@ -0,0 +1,262 @@
"""
Functions for evaluating the CORE metric, as described in the DCLM paper.
https://arxiv.org/abs/2406.11794

TODOs:
- All tasks ~match except for squad. We get 31%, the reference is 37%. Figure out why.
"""
import random

from jinja2 import Template
import torch
import torch.distributed as dist

# -----------------------------------------------------------------------------
# Prompt rendering utilities

def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None):
    """Render complete prompts for a multiple choice question"""
    template_str = """
{%- for example in fewshot_examples -%}
{{ example.query }}{{ continuation_delimiter }}{{ example.choices[example.gold] }}

{% endfor -%}
{{ item.query }}{{ continuation_delimiter }}{{ choice }}""".strip()
    template = Template(template_str)
    fewshot_examples = fewshot_examples or []
    context = {
        'fewshot_examples': fewshot_examples,
        'continuation_delimiter': continuation_delimiter,
        'item': item
    }
    prompts = [template.render(choice=choice, **context) for choice in item['choices']]
    return prompts


def render_prompts_schema(item, continuation_delimiter, fewshot_examples=None):
    """Render complete prompts for a schema question"""
    template_str = """
{%- for example in fewshot_examples -%}
{{ example.context_options[example.gold] }}{{ continuation_delimiter }}{{ example.continuation }}

{% endfor -%}
{{ context }}{{ continuation_delimiter }}{{ item.continuation }}""".strip()
    template = Template(template_str)
    fewshot_examples = fewshot_examples or []
    context = {
        'fewshot_examples': fewshot_examples,
        'continuation_delimiter': continuation_delimiter,
        'item': item
    }
    prompts = [template.render(context=context_option, **context)
               for context_option in item['context_options']]
    return prompts


def render_prompts_lm(item, continuation_delimiter, fewshot_examples=None):
    """
    Render complete prompt for a language modeling task.
    Notice that we manually trim the context in the template,
    which in some datasets seems to have trailing whitespace (which we don't want).
    """
    template_str = """
{%- for example in fewshot_examples -%}
{{ example.context | trim }}{{ continuation_delimiter }}{{ example.continuation }}

{% endfor -%}
{{ item.context | trim }}{{ continuation_delimiter }}{% if include_continuation %}{{ item.continuation }}{% endif %}""".strip()
    template = Template(template_str)
    fewshot_examples = fewshot_examples or []
    context = {
        'fewshot_examples': fewshot_examples,
        'continuation_delimiter': continuation_delimiter,
        'item': item
    }
    # Return two prompts: without and with the continuation
    prompt_without = template.render(include_continuation=False, **context)
    prompt_with = template.render(include_continuation=True, **context)
    # Due to the way the data seems to be stored, I think I need to strip in the case of LM here.
    # Otherwise we may get trailing whitespaces in prompt_without (which get absorbed into the next
    # token in prompt_with), meaning we don't get a nice and clean prefix in the token space
    # to detect the final continuation. Tokenizers...
    prompt_without = prompt_without.strip()
    return [prompt_without, prompt_with]


def find_common_length(token_sequences, direction='left'):
    """
    Find the length of the common prefix or suffix across token sequences
    - direction: 'left' for prefix, 'right' for suffix
    """
    min_len = min(len(seq) for seq in token_sequences)
    indices = {
        'left': range(min_len),
        'right': range(-1, -min_len-1, -1)
    }[direction]
    # Find the first position where the token sequences differ
    for i, idx in enumerate(indices):
        token = token_sequences[0][idx]
        if not all(seq[idx] == token for seq in token_sequences):
            return i
    return min_len
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def stack_sequences(tokens, pad_token_id):
|
| 105 |
+
"""Stack up a list of token sequences, pad to longest on the right"""
|
| 106 |
+
bsz, seq_len = len(tokens), max(len(x) for x in tokens)
|
| 107 |
+
input_ids = torch.full((bsz, seq_len), pad_token_id, dtype=torch.long)
|
| 108 |
+
for i, x in enumerate(tokens):
|
| 109 |
+
input_ids[i, :len(x)] = torch.tensor(x, dtype=torch.long)
|
| 110 |
+
return input_ids
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def batch_sequences_mc(tokenizer, prompts):
|
| 114 |
+
# In multiple choice, contexts are the same but the continuation is different (common prefix)
|
| 115 |
+
tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
|
| 116 |
+
# figure out the start and end of each continuation
|
| 117 |
+
answer_start_idx = find_common_length(tokens, direction='left')
|
| 118 |
+
start_indices = [answer_start_idx] * len(prompts)
|
| 119 |
+
end_indices = [len(x) for x in tokens]
|
| 120 |
+
return tokens, start_indices, end_indices
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def batch_sequences_schema(tokenizer, prompts):
|
| 124 |
+
# In schema tasks, contexts vary but continuation is the same (common suffix)
|
| 125 |
+
tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
|
| 126 |
+
# figure out the start and end of each context
|
| 127 |
+
suffix_length = find_common_length(tokens, direction='right')
|
| 128 |
+
end_indices = [len(x) for x in tokens]
|
| 129 |
+
start_indices = [ei - suffix_length for ei in end_indices]
|
| 130 |
+
return tokens, start_indices, end_indices
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def batch_sequences_lm(tokenizer, prompts):
|
| 134 |
+
# In LM tasks, we have two prompts: without and with continuation
|
| 135 |
+
tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
|
| 136 |
+
tokens_without, tokens_with = tokens
|
| 137 |
+
start_idx, end_idx = len(tokens_without), len(tokens_with)
|
| 138 |
+
assert start_idx < end_idx, "prompt without is supposed to be a prefix of prompt with"
|
| 139 |
+
assert tokens_without == tokens_with[:start_idx], "prompt without is supposed to be a prefix of prompt with"
|
| 140 |
+
# we only need the with continuation prompt in the LM task, i.e. batch size of 1
|
| 141 |
+
return [tokens_with], [start_idx], [end_idx]
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
@torch.no_grad()
|
| 145 |
+
def forward_model(model, input_ids):
|
| 146 |
+
"""
|
| 147 |
+
Take BxT tensor of token ids, return BxT tensor of losses and argmax predictions.
|
| 148 |
+
The last column of losses is set to nan because we don't have autoregressive targets there.
|
| 149 |
+
"""
|
| 150 |
+
batch_size, seq_len = input_ids.size()
|
| 151 |
+
outputs = model(input_ids)
|
| 152 |
+
# Roll the tensor to the left by one position to get the (autoregressive) target ids
|
| 153 |
+
target_ids = torch.roll(input_ids, shifts=-1, dims=1)
|
| 154 |
+
# Calculate cross entropy at all positions
|
| 155 |
+
losses = torch.nn.functional.cross_entropy(
|
| 156 |
+
outputs.view(batch_size * seq_len, -1),
|
| 157 |
+
target_ids.view(batch_size * seq_len),
|
| 158 |
+
reduction='none'
|
| 159 |
+
).view(batch_size, seq_len)
|
| 160 |
+
# Set the last column to be nan because there is no autoregressive loss there
|
| 161 |
+
losses[:, -1] = float('nan')
|
| 162 |
+
# Get the argmax predictions at each position
|
| 163 |
+
predictions = outputs.argmax(dim=-1)
|
| 164 |
+
return losses, predictions
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
@torch.no_grad()
|
| 168 |
+
def evaluate_example(idx, model, tokenizer, data, device, task_meta):
|
| 169 |
+
"""Evaluate a single example, return True if correct, False otherwise"""
|
| 170 |
+
item = data[idx]
|
| 171 |
+
task_type = task_meta['task_type']
|
| 172 |
+
num_fewshot = task_meta['num_fewshot']
|
| 173 |
+
continuation_delimiter = task_meta['continuation_delimiter']
|
| 174 |
+
|
| 175 |
+
# Sample few-shot examples (excluding current item)
|
| 176 |
+
fewshot_examples = []
|
| 177 |
+
if num_fewshot > 0:
|
| 178 |
+
rng = random.Random(1234 + idx)
|
| 179 |
+
available_indices = [i for i in range(len(data)) if i != idx]
|
| 180 |
+
fewshot_indices = rng.sample(available_indices, num_fewshot)
|
| 181 |
+
fewshot_examples = [data[i] for i in fewshot_indices]
|
| 182 |
+
|
| 183 |
+
# Render prompts and batch sequences based on task type
|
| 184 |
+
if task_type == 'multiple_choice':
|
| 185 |
+
prompts = render_prompts_mc(item, continuation_delimiter, fewshot_examples)
|
| 186 |
+
tokens, start_idxs, end_idxs = batch_sequences_mc(tokenizer, prompts)
|
| 187 |
+
elif task_type == 'schema':
|
| 188 |
+
prompts = render_prompts_schema(item, continuation_delimiter, fewshot_examples)
|
| 189 |
+
tokens, start_idxs, end_idxs = batch_sequences_schema(tokenizer, prompts)
|
| 190 |
+
elif task_type == 'language_modeling':
|
| 191 |
+
prompts = render_prompts_lm(item, continuation_delimiter, fewshot_examples)
|
| 192 |
+
tokens, start_idxs, end_idxs = batch_sequences_lm(tokenizer, prompts)
|
| 193 |
+
else:
|
| 194 |
+
raise ValueError(f"Unsupported task type: {task_type}")
|
| 195 |
+
|
| 196 |
+
# Some models can't forward sequences beyond a certain length (e.g. GPT-2)
|
| 197 |
+
# In these cases, we have to truncate sequences to max length and adjust the indices
|
| 198 |
+
if hasattr(model, 'max_seq_len') and model.max_seq_len is not None:
|
| 199 |
+
max_tokens = model.max_seq_len
|
| 200 |
+
new_tokens, new_start_idxs, new_end_idxs = [], [], []
|
| 201 |
+
for t, s, e in zip(tokens, start_idxs, end_idxs):
|
| 202 |
+
if len(t) > max_tokens:
|
| 203 |
+
num_to_crop = len(t) - max_tokens
|
| 204 |
+
new_tokens.append(t[-max_tokens:]) # take the last max_tokens tokens
|
| 205 |
+
new_start_idxs.append(s - num_to_crop) # shift the indices down
|
| 206 |
+
new_end_idxs.append(e - num_to_crop)
|
| 207 |
+
assert s - num_to_crop >= 0, "this should never happen right?"
|
| 208 |
+
assert e - num_to_crop >= 0, "this should never happen right?"
|
| 209 |
+
else:
|
| 210 |
+
new_tokens.append(t) # keep unchanged
|
| 211 |
+
new_start_idxs.append(s)
|
| 212 |
+
new_end_idxs.append(e)
|
| 213 |
+
tokens, start_idxs, end_idxs = new_tokens, new_start_idxs, new_end_idxs
|
| 214 |
+
|
| 215 |
+
# Stack up all the sequences into a batch
|
| 216 |
+
pad_token_id = tokenizer.get_bos_token_id() # use BOS as pad token is ok
|
| 217 |
+
input_ids = stack_sequences(tokens, pad_token_id)
|
| 218 |
+
input_ids = input_ids.to(device)
|
| 219 |
+
|
| 220 |
+
# Forward the model, get the autoregressive loss and argmax prediction at each token
|
| 221 |
+
losses, predictions = forward_model(model, input_ids)
|
| 222 |
+
|
| 223 |
+
# See if the losses/predictions come out correctly
|
| 224 |
+
if task_type == 'language_modeling':
|
| 225 |
+
# language modeling task is currently always batch size 1
|
| 226 |
+
si = start_idxs[0]
|
| 227 |
+
ei = end_idxs[0]
|
| 228 |
+
# predictions[i] predict input_ids[i+1] autoregressively
|
| 229 |
+
predicted_tokens = predictions[0, si-1:ei-1]
|
| 230 |
+
actual_tokens = input_ids[0, si:ei]
|
| 231 |
+
is_correct = torch.all(predicted_tokens == actual_tokens).item()
|
| 232 |
+
elif task_type in ['multiple_choice', 'schema']:
|
| 233 |
+
# For MC/schema: find the option with lowest average loss
|
| 234 |
+
mean_losses = [losses[i, si-1:ei-1].mean().item()
|
| 235 |
+
for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))]
|
| 236 |
+
pred_idx = mean_losses.index(min(mean_losses))
|
| 237 |
+
is_correct = pred_idx == item['gold']
|
| 238 |
+
else:
|
| 239 |
+
raise ValueError(f"Unsupported task type: {task_type}")
|
| 240 |
+
|
| 241 |
+
return is_correct
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def evaluate_task(model, tokenizer, data, device, task_meta):
|
| 245 |
+
"""
|
| 246 |
+
This function is responsible for evaluating one task across many examples.
|
| 247 |
+
It also handles dispatch to all processes if the script is run with torchrun.
|
| 248 |
+
"""
|
| 249 |
+
rank = dist.get_rank() if dist.is_initialized() else 0
|
| 250 |
+
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
| 251 |
+
correct = torch.zeros(len(data), dtype=torch.float32, device=device)
|
| 252 |
+
# stride the examples to each rank
|
| 253 |
+
for idx in range(rank, len(data), world_size):
|
| 254 |
+
is_correct = evaluate_example(idx, model, tokenizer, data, device, task_meta)
|
| 255 |
+
correct[idx] = float(is_correct)
|
| 256 |
+
# sync results across all the processes if running distributed
|
| 257 |
+
if world_size > 1:
|
| 258 |
+
dist.barrier()
|
| 259 |
+
dist.all_reduce(correct, op=dist.ReduceOp.SUM)
|
| 260 |
+
# compute the mean
|
| 261 |
+
mean_correct = correct.mean().item()
|
| 262 |
+
return mean_correct
|
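A minimal usage sketch (not part of the upload) to sanity-check the pure helpers above; the token ids are made up, and only find_common_length / stack_sequences from nanochat/core_eval.py are assumed:

from nanochat.core_eval import find_common_length, stack_sequences

# toy "prompts" sharing a common question prefix, differing in the answer suffix
tokens = [
    [1, 42, 42, 7, 9],       # question + choice A
    [1, 42, 42, 8],          # question + choice B
    [1, 42, 42, 7, 10, 11],  # question + choice C
]
answer_start = find_common_length(tokens, direction='left')
print(answer_start)   # 3 -> per-choice losses are averaged from position 3 onwards
batch = stack_sequences(tokens, pad_token_id=1)
print(batch.shape)    # torch.Size([3, 6]), right-padded with the pad token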
nanochat/dataloader.py
ADDED
@@ -0,0 +1,166 @@
"""
Distributed dataloaders for pretraining.

BOS-aligned bestfit:
- Every row starts with the BOS token
- Documents are packed using a best-fit algorithm to minimize cropping
- When no document fits the remaining space, a document is cropped to fill it exactly
- 100% utilization (no padding), ~35% of tokens cropped at T=2048

Compared to the original tokenizing_distributed_data_loader:
BOS-aligned loses ~35% of tokens to cropping, but ensures that
there are fewer "confusing" tokens in the train/val batches, as every token can
now attend back to the BOS token and sees the full context of the document.

Fall back to the original if you have very limited data AND long documents:
https://github.com/karpathy/nanochat/blob/3c3a3d7/nanochat/dataloader.py#L78-L117
"""

import torch
import pyarrow.parquet as pq

from nanochat.common import get_dist_info
from nanochat.dataset import list_parquet_files

def _document_batches(split, resume_state_dict, tokenizer_batch_size):
    """
    Infinite iterator over document batches (lists of text strings) from parquet files.

    Handles DDP sharding and approximate resume. Each yield is (text_batch, (pq_idx, rg_idx, epoch)),
    where text_batch is a list of document strings, the indices track position for resumption,
    and epoch counts how many times we've cycled through the dataset (starts at 1).
    """
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()

    warn_on_legacy = ddp_rank == 0 and split == "train"  # rank 0 on the train split will warn on legacy data
    parquet_paths = list_parquet_files(warn_on_legacy=warn_on_legacy)
    assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?"
    parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]

    resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
    resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
    resume_epoch = resume_state_dict.get("epoch", 1) if resume_state_dict is not None else 1
    first_pass = True
    pq_idx = resume_pq_idx
    epoch = resume_epoch

    while True:  # iterate infinitely (multi-epoch)
        pq_idx = resume_pq_idx if first_pass else 0
        while pq_idx < len(parquet_paths):
            filepath = parquet_paths[pq_idx]
            pf = pq.ParquetFile(filepath)
            # Start from the resume point if resuming on the same file, otherwise from the DDP rank
            if first_pass and (resume_rg_idx is not None) and (pq_idx == resume_pq_idx):
                base_idx = resume_rg_idx // ddp_world_size
                base_idx += 1  # advance by 1 so we don't repeat data after resuming
                rg_idx = base_idx * ddp_world_size + ddp_rank
                if rg_idx >= pf.num_row_groups:
                    pq_idx += 1
                    continue
                resume_rg_idx = None  # only do this once
            else:
                rg_idx = ddp_rank
            while rg_idx < pf.num_row_groups:
                rg = pf.read_row_group(rg_idx)
                batch = rg.column('text').to_pylist()
                for i in range(0, len(batch), tokenizer_batch_size):
                    yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx, epoch)
                rg_idx += ddp_world_size
            pq_idx += 1
        first_pass = False
        epoch += 1


def tokenizing_distributed_data_loader_with_state_bos_bestfit(
    tokenizer, B, T, split,
    tokenizer_threads=4, tokenizer_batch_size=128,
    device="cuda", resume_state_dict=None,
    buffer_size=1000
):
    """
    BOS-aligned dataloader with Best-Fit Cropping.

    Reduces token waste compared to simple greedy cropping by searching a buffer
    for documents that fit well, while maintaining 100% utilization (no padding).

    Algorithm for each row:
    1. From the buffered docs, pick the LARGEST doc that fits entirely
    2. Repeat until no doc fits
    3. When nothing fits, crop a doc to fill the remaining space exactly

    Key properties:
    - Every row starts with BOS
    - 100% utilization (no padding, every token is trained on)
    - Approximately 35% of all tokens are discarded due to cropping
    """
    assert split in ["train", "val"], "split must be 'train' or 'val'"

    row_capacity = T + 1
    batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
    bos_token = tokenizer.get_bos_token_id()
    doc_buffer = []
    pq_idx, rg_idx, epoch = 0, 0, 1

    def refill_buffer():
        nonlocal pq_idx, rg_idx, epoch
        doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
        token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
        for tokens in token_lists:
            doc_buffer.append(tokens)

    # Pre-allocate buffers once: the layout is [inputs (B*T) | targets (B*T)]
    # This gives us contiguous views and a single HtoD transfer
    use_cuda = device == "cuda"
    row_buffer = torch.empty((B, row_capacity), dtype=torch.long)  # for building rows without creating Python lists
    cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=use_cuda)  # staging area (CPU)
    gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device=device)  # on-device buffer
    cpu_inputs = cpu_buffer[:B * T].view(B, T)  # a few views into these buffers just for convenience
    cpu_targets = cpu_buffer[B * T:].view(B, T)
    inputs = gpu_buffer[:B * T].view(B, T)
    targets = gpu_buffer[B * T:].view(B, T)

    while True:
        for row_idx in range(B):
            pos = 0
            while pos < row_capacity:
                # Ensure the buffer has documents
                while len(doc_buffer) < buffer_size:
                    refill_buffer()

                remaining = row_capacity - pos

                # Find the largest doc that fits entirely
                best_idx = -1
                best_len = 0
                for i, doc in enumerate(doc_buffer):
                    doc_len = len(doc)
                    if doc_len <= remaining and doc_len > best_len:
                        best_idx = i
                        best_len = doc_len

                if best_idx >= 0:
                    doc = doc_buffer.pop(best_idx)
                    doc_len = len(doc)
                    row_buffer[row_idx, pos:pos + doc_len] = torch.tensor(doc, dtype=torch.long)
                    pos += doc_len
                else:
                    # No doc fits - crop the shortest in the buffer to fill the remaining space and minimize waste
                    shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
                    doc = doc_buffer.pop(shortest_idx)
                    row_buffer[row_idx, pos:pos + remaining] = torch.tensor(doc[:remaining], dtype=torch.long)
                    pos += remaining

        # Copy to the pinned CPU buffer, then a single HtoD transfer
        cpu_inputs.copy_(row_buffer[:, :-1])
        cpu_targets.copy_(row_buffer[:, 1:])

        state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}

        # Single HtoD copy into the persistent GPU buffer and yield
        gpu_buffer.copy_(cpu_buffer, non_blocking=use_cuda)
        yield inputs, targets, state_dict

def tokenizing_distributed_data_loader_bos_bestfit(*args, **kwargs):
    """Helper that omits the state_dict from the yields."""
    for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state_bos_bestfit(*args, **kwargs):
        yield inputs, targets
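To make the packing rule concrete, here is a toy, dependency-free sketch (not from the repo) of the best-fit decision used above: take the largest buffered document that fits the remaining row space, and when none fits, crop the shortest one to fill the row exactly:

def pack_row(doc_lens, capacity):
    """Return the lengths placed into one row of `capacity` tokens."""
    row, remaining = [], capacity
    while remaining > 0 and doc_lens:
        fitting = [n for n in doc_lens if n <= remaining]
        if fitting:
            pick = max(fitting)       # largest doc that fits entirely
            doc_lens.remove(pick)
            row.append(pick)
            remaining -= pick
        else:
            pick = min(doc_lens)      # nothing fits: crop the shortest doc
            doc_lens.remove(pick)
            row.append(remaining)     # only `remaining` of its tokens survive
            remaining = 0
    return row

print(pack_row([900, 400, 300, 50], capacity=1025))  # [900, 50, 75]: all 1025 slots used, no padding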
nanochat/dataset.py
ADDED
@@ -0,0 +1,160 @@
"""
The base/pretraining dataset is a set of parquet files.
This file contains utilities for:
- iterating over the parquet files and yielding documents from them
- downloading the files on demand if they are not on disk

For details of how the dataset was prepared, see `repackage_data_reference.py`.
"""

import os
import argparse
import time
import requests
import pyarrow.parquet as pq
from multiprocessing import Pool

from nanochat.common import get_base_dir

# -----------------------------------------------------------------------------
# The specifics of the current pretraining dataset

# The URL on the internet where the data is hosted and downloaded from on demand
BASE_URL = "https://huggingface.co/datasets/karpathy/climbmix-400b-shuffle/resolve/main"
MAX_SHARD = 6542  # the last datashard is shard_06542.parquet
index_to_filename = lambda index: f"shard_{index:05d}.parquet"  # format of the filenames
base_dir = get_base_dir()
DATA_DIR = os.path.join(base_dir, "base_data_climbmix")

# -----------------------------------------------------------------------------
# These functions are useful utilities to other modules, can/should be imported

def list_parquet_files(data_dir=None, warn_on_legacy=False):
    """ Looks into a data dir and returns full paths to all parquet files. """
    data_dir = DATA_DIR if data_dir is None else data_dir

    # Legacy-supporting code due to the upgrade from FinewebEdu-100B to ClimbMix-400B
    # This code will eventually be deleted.
    if not os.path.exists(data_dir):
        if warn_on_legacy:
            print()
            print("=" * 80)
            print(" WARNING: DATASET UPGRADE REQUIRED")
            print("=" * 80)
            print()
            print(f" Could not find: {data_dir}")
            print()
            print(" nanochat recently switched from FinewebEdu-100B to ClimbMix-400B.")
            print(" Everyone who does `git pull` as of March 4, 2026 is expected to see this message.")
            print(" To upgrade to the new ClimbMix-400B dataset, run these two commands:")
            print()
            print(" python -m nanochat.dataset -n 170   # download ~170 shards, enough for GPT-2, adjust as desired")
            print(" python -m scripts.tok_train         # re-train the tokenizer on the new ClimbMix data")
            print()
            print(" For now, falling back to your old FinewebEdu-100B dataset...")
            print("=" * 80)
            print()
        # attempt a fallback to the legacy data directory
        data_dir = os.path.join(base_dir, "base_data")

    parquet_files = sorted([
        f for f in os.listdir(data_dir)
        if f.endswith('.parquet') and not f.endswith('.tmp')
    ])
    parquet_paths = [os.path.join(data_dir, f) for f in parquet_files]
    return parquet_paths

def parquets_iter_batched(split, start=0, step=1):
    """
    Iterate through the dataset, in batches of the underlying row_groups for efficiency.
    - split can be "train" or "val". The last parquet file will be val.
    - start/step are useful for sharding rows in DDP, e.g. start=rank, step=world_size
    """
    assert split in ["train", "val"], "split must be 'train' or 'val'"
    parquet_paths = list_parquet_files()
    parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
    for filepath in parquet_paths:
        pf = pq.ParquetFile(filepath)
        for rg_idx in range(start, pf.num_row_groups, step):
            rg = pf.read_row_group(rg_idx)
            texts = rg.column('text').to_pylist()
            yield texts

# -----------------------------------------------------------------------------
def download_single_file(index):
    """ Downloads a single file index, with some backoff """

    # Construct the local filepath for this file and skip if it already exists
    filename = index_to_filename(index)
    filepath = os.path.join(DATA_DIR, filename)
    if os.path.exists(filepath):
        print(f"Skipping {filepath} (already exists)")
        return True

    # Construct the remote URL for this file
    url = f"{BASE_URL}/{filename}"
    print(f"Downloading {filename}...")

    # Download with retries
    max_attempts = 5
    for attempt in range(1, max_attempts + 1):
        try:
            response = requests.get(url, stream=True, timeout=30)
            response.raise_for_status()
            # Write to a temporary file first
            temp_path = filepath + ".tmp"
            with open(temp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):  # 1MB chunks
                    if chunk:
                        f.write(chunk)
            # Move the temp file to its final location
            os.rename(temp_path, filepath)
            print(f"Successfully downloaded {filename}")
            return True

        except (requests.RequestException, IOError) as e:
            print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
            # Clean up any partial files
            for path in [filepath + ".tmp", filepath]:
                if os.path.exists(path):
                    try:
                        os.remove(path)
                    except OSError:
                        pass
            # Try a few times with exponential backoff: 2^attempt seconds
            if attempt < max_attempts:
                wait_time = 2 ** attempt
                print(f"Waiting {wait_time} seconds before retry...")
                time.sleep(wait_time)
            else:
                print(f"Failed to download {filename} after {max_attempts} attempts")
                return False

    return False


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Download pretraining dataset shards")
    parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of train shards to download (default: -1 = all shards)")
    parser.add_argument("-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)")
    args = parser.parse_args()

    # Prepare the output directory
    os.makedirs(DATA_DIR, exist_ok=True)

    # The way this works is that the user specifies the number of train shards to download via the -n flag.
    # In addition to that, the validation shard is *always* downloaded and is pinned to be the last shard.
    num_train_shards = MAX_SHARD if args.num_files == -1 else min(args.num_files, MAX_SHARD)
    ids_to_download = list(range(num_train_shards))
    ids_to_download.append(MAX_SHARD)  # always download the validation shard

    # Download the shards
    print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
    print(f"Target directory: {DATA_DIR}")
    print()
    with Pool(processes=args.num_workers) as pool:
        results = pool.map(download_single_file, ids_to_download)

    # Report results
    successful = sum(1 for success in results if success)
    print(f"Done! Downloaded: {successful}/{len(ids_to_download)} shards to {DATA_DIR}")
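A hedged usage sketch (not part of the upload), assuming some shards have already been fetched with `python -m nanochat.dataset -n 8`; it streams one row group's worth of documents, sharded the way DDP would shard them:

from nanochat.dataset import parquets_iter_batched

rank, world_size = 0, 1  # stand-ins for the real DDP rank/world size
for texts in parquets_iter_batched("train", start=rank, step=world_size):
    print(len(texts), texts[0][:80])  # number of docs in the row group, first doc preview
    break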
nanochat/engine.py
ADDED
@@ -0,0 +1,357 @@
"""
Engine for efficient inference of our models.

Everything works around token sequences:
- The user can send token sequences to the engine
- The engine returns the next token

Notes:
- The engine knows nothing about tokenization, it works purely with token id sequences.

The whole thing is made as efficient as possible.
"""

import torch
import torch.nn.functional as F
import signal
import warnings
from contextlib import contextmanager
from collections import deque
from nanochat.common import compute_init, autodetect_device_type
from nanochat.checkpoint_manager import load_model

# -----------------------------------------------------------------------------
# Calculator tool helpers
@contextmanager
def timeout(duration, formula):
    def timeout_handler(signum, frame):
        raise Exception(f"'{formula}': timed out after {duration} seconds")

    signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(duration)
    try:
        yield
    finally:
        signal.alarm(0)  # always cancel the alarm, even if the block raised

def eval_with_timeout(formula, max_time=3):
    try:
        with timeout(max_time, formula):
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", SyntaxWarning)
                return eval(formula, {"__builtins__": {}}, {})
    except Exception as e:
        # print(f"Warning: Failed to eval {formula}, exception: {e}")  # it's ok, ignore wrong calculator usage
        return None

def use_calculator(expr):
    """
    Evaluate a Python expression safely.
    Supports both math expressions and string operations like .count()
    """
    # Remove commas from numbers
    expr = expr.replace(",", "")

    # Check if it's a pure math expression (old behavior)
    if all([x in "0123456789*+-/.() " for x in expr]):
        if "**" in expr:  # disallow the power operator
            return None
        return eval_with_timeout(expr)

    # Check if it's a string operation we support
    # Allow: strings (single/double quotes), .count(), letters, numbers, spaces, parens
    allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
    if not all([x in allowed_chars for x in expr]):
        return None

    # Disallow dangerous patterns
    dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
                          'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
                          'getattr', 'setattr', 'delattr', 'hasattr']
    expr_lower = expr.lower()
    if any(pattern in expr_lower for pattern in dangerous_patterns):
        return None

    # Only allow the .count() method for now (can expand later)
    if '.count(' not in expr:
        return None

    # Evaluate with timeout
    return eval_with_timeout(expr)

# -----------------------------------------------------------------------------
class KVCache:
    """
    KV Cache designed for Flash Attention 3's flash_attn_with_kvcache API.

    Key differences from an FA2-style cache:
    - Tensors are (B, T, H, D), not (B, H, T, D)
    - FA3 updates the cache in-place during flash_attn_with_kvcache
    - Position is tracked per batch element via the cache_seqlens tensor
    """

    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype):
        self.batch_size = batch_size
        self.max_seq_len = seq_len
        self.n_layers = num_layers
        self.n_heads = num_heads
        self.head_dim = head_dim
        # Pre-allocate cache tensors: (n_layers, B, T, H, D)
        self.k_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
        self.v_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
        # Current sequence length per batch element (FA3 needs int32)
        self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
        # Previous token's normalized embedding for smear (set by the model forward pass)
        self.prev_embedding = None

    def reset(self):
        """Reset the cache to an empty state."""
        self.cache_seqlens.zero_()
        self.prev_embedding = None

    def get_pos(self):
        """Get the current position (assumes all batch elements are at the same position)."""
        return self.cache_seqlens[0].item()

    def get_layer_cache(self, layer_idx):
        """Return (k_cache, v_cache) views for a specific layer."""
        return self.k_cache[layer_idx], self.v_cache[layer_idx]

    def advance(self, num_tokens):
        """Advance the cache position by num_tokens."""
        self.cache_seqlens += num_tokens

    def prefill(self, other):
        """
        Copy cached KV from another cache into this one.
        Used when we do a batch=1 prefill and then want to generate multiple samples in parallel.
        """
        assert self.get_pos() == 0, "Cannot prefill a non-empty KV cache"
        assert self.n_layers == other.n_layers and self.n_heads == other.n_heads and self.head_dim == other.head_dim
        assert self.max_seq_len >= other.max_seq_len
        other_pos = other.get_pos()
        self.k_cache[:, :, :other_pos, :, :] = other.k_cache[:, :, :other_pos, :, :]
        self.v_cache[:, :, :other_pos, :, :] = other.v_cache[:, :, :other_pos, :, :]
        self.cache_seqlens.fill_(other_pos)
        # Copy the smear state: expand the batch=1 prev_embedding to num_samples
        if other.prev_embedding is not None:
            self.prev_embedding = other.prev_embedding.expand(self.batch_size, -1, -1).clone()

# -----------------------------------------------------------------------------
@torch.inference_mode()
def sample_next_token(logits, rng, temperature=1.0, top_k=None):
    """Sample a single next token from given logits of shape (B, vocab_size). Returns (B, 1)."""
    assert temperature >= 0.0, "temperature must be non-negative"
    if temperature == 0.0:
        return torch.argmax(logits, dim=-1, keepdim=True)
    if top_k is not None and top_k > 0:
        k = min(top_k, logits.size(-1))
        vals, idx = torch.topk(logits, k, dim=-1)
        vals = vals / temperature
        probs = F.softmax(vals, dim=-1)
        choice = torch.multinomial(probs, num_samples=1, generator=rng)
        return idx.gather(1, choice)
    else:
        logits = logits / temperature
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1, generator=rng)

# -----------------------------------------------------------------------------

class RowState:
    # Per-row state tracking during generation
    def __init__(self, current_tokens=None):
        self.current_tokens = current_tokens or []  # Current token sequence for this row
        self.forced_tokens = deque()  # Queue of tokens to force inject
        self.in_python_block = False  # Whether we are inside a python block
        self.python_expr_tokens = []  # Tokens of the current python expression
        self.completed = False  # Whether this row has completed generation

class Engine:

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer  # needed for tool use

    @torch.inference_mode()
    def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
        """Generate tokens: a single batch=1 prefill of the prompt, then the KV cache is cloned so num_samples rows decode in parallel."""
        assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
        device = self.model.get_device()
        # NOTE: setting the dtype here and in this way is an ugly hack.
        # Currently the repo assumes that cuda -> bfloat16 and everything else -> float32.
        # We need to know the dtype here to call __init__ on KVCache and pre-allocate its tensors.
        # As a quick hack, the generate() function inherits and knows about this repo-wide assumption.
        # There has to be a bigger refactor to deal with device/dtype tracking across the codebase.
        # In particular, the KVCache should allocate its tensors lazily.
        dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
        rng = torch.Generator(device=device)
        rng.manual_seed(seed)

        # Get the special tokens we need to coordinate the tool use state machine
        get_special = lambda s: self.tokenizer.encode_special(s)
        python_start = get_special("<|python_start|>")
        python_end = get_special("<|python_end|>")
        output_start = get_special("<|output_start|>")
        output_end = get_special("<|output_end|>")
        assistant_end = get_special("<|assistant_end|>")  # if sampled, ends the row
        bos = self.tokenizer.get_bos_token_id()  # if sampled, ends the row

        # 1) Run a batch=1 prefill of the prompt tokens
        m = self.model.config
        kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
        kv_cache_prefill = KVCache(
            batch_size=1,
            seq_len=len(tokens),
            device=device,
            dtype=dtype,
            **kv_model_kwargs,
        )
        ids = torch.tensor([tokens], dtype=torch.long, device=device)
        logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
        logits = logits[:, -1, :].expand(num_samples, -1)  # (num_samples, vocab_size)

        # 2) Replicate the KV cache for each sample/row
        kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
        kv_cache_decode = KVCache(
            batch_size=num_samples,
            seq_len=kv_length_hint,
            device=device,
            dtype=dtype,
            **kv_model_kwargs,
        )
        kv_cache_decode.prefill(kv_cache_prefill)
        del kv_cache_prefill  # no need to keep this memory around

        # 3) Initialize the states for each sample
        row_states = [RowState(tokens.copy()) for _ in range(num_samples)]

        # 4) Main generation loop
        num_generated = 0
        while True:
            # Stop condition: we've reached max tokens
            if max_tokens is not None and num_generated >= max_tokens:
                break
            # Stop condition: all rows are completed
            if all(state.completed for state in row_states):
                break

            # Sample the next token for each row
            next_ids = sample_next_token(logits, rng, temperature, top_k)  # (B, 1)
            sampled_tokens = next_ids[:, 0].tolist()

            # Process each row: choose the next token, update the state, optional tool use
            token_column = []  # contains the next token id along each row
            token_masks = []  # contains the mask (was it sampled (1) or forced (0)?) along each row
            for i, state in enumerate(row_states):
                # Select the next token in this row
                is_forced = len(state.forced_tokens) > 0  # are there tokens waiting to be forced in the deque?
                token_masks.append(0 if is_forced else 1)  # mask is 0 if forced, 1 if sampled
                next_token = state.forced_tokens.popleft() if is_forced else sampled_tokens[i]
                token_column.append(next_token)
                # Update the state of this row to include the next token
                state.current_tokens.append(next_token)
                # On <|assistant_end|> or <|bos|>, mark the row as completed
                if next_token == assistant_end or next_token == bos:
                    state.completed = True
                # Handle the tool logic
                if next_token == python_start:
                    state.in_python_block = True
                    state.python_expr_tokens = []
                elif next_token == python_end and state.in_python_block:
                    state.in_python_block = False
                    if state.python_expr_tokens:
                        expr = self.tokenizer.decode(state.python_expr_tokens)
                        result = use_calculator(expr)
                        if result is not None:
                            result_tokens = self.tokenizer.encode(str(result))
                            state.forced_tokens.append(output_start)
                            state.forced_tokens.extend(result_tokens)
                            state.forced_tokens.append(output_end)
                        state.python_expr_tokens = []
                elif state.in_python_block:
                    state.python_expr_tokens.append(next_token)

            # Yield the token column
            yield token_column, token_masks
            num_generated += 1

            # Prepare the logits for the next iteration
            ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)
            logits = self.model.forward(ids, kv_cache=kv_cache_decode)[:, -1, :]  # (B, vocab_size)

    def generate_batch(self, tokens, num_samples=1, **kwargs):
        """
        Non-streaming batch generation that just returns the final token sequences.
        Returns a list of token sequences (list of lists of ints).
        Terminal tokens (assistant_end, bos) are not included in the results.
        """
        assistant_end = self.tokenizer.encode_special("<|assistant_end|>")
        bos = self.tokenizer.get_bos_token_id()
        results = [tokens.copy() for _ in range(num_samples)]
        masks = [[0] * len(tokens) for _ in range(num_samples)]
        completed = [False] * num_samples
        for token_column, token_masks in self.generate(tokens, num_samples, **kwargs):
            for i, (token, mask) in enumerate(zip(token_column, token_masks)):
                if not completed[i]:
                    if token == assistant_end or token == bos:
                        completed[i] = True
                    else:
                        results[i].append(token)
                        masks[i].append(mask)
            # Stop if all rows are completed
            if all(completed):
                break
        return results, masks


if __name__ == "__main__":
    """
    Quick inline test to make sure that the naive/slow model.generate function
    is equivalent to the faster Engine.generate function here.
    """
    import time
    # init compute
    device_type = autodetect_device_type()
    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
    # load the model and tokenizer
    model, tokenizer, meta = load_model("base", device, phase="eval")
    bos_token_id = tokenizer.get_bos_token_id()
    # common hyperparameters
    kwargs = dict(max_tokens=64, temperature=0.0)
    # set the starting prompt
    prompt_tokens = tokenizer.encode("The chemical formula of water is", prepend=bos_token_id)
    # generate the reference sequence using the model.generate() function
    generated_tokens = []
    torch.cuda.synchronize()
    t0 = time.time()
    stream = model.generate(prompt_tokens, **kwargs)
    for token in stream:
        generated_tokens.append(token)
        chunk = tokenizer.decode([token])
        print(chunk, end="", flush=True)
    print()
    torch.cuda.synchronize()
    t1 = time.time()
    print(f"Reference time: {t1 - t0:.2f}s")
    reference_ids = generated_tokens
    # generate tokens with the Engine
    generated_tokens = []
    engine = Engine(model, tokenizer)
    stream = engine.generate(prompt_tokens, num_samples=1, **kwargs)  # note: runs in fp32
    torch.cuda.synchronize()
    t0 = time.time()
    for token_column, token_masks in stream:
        token = token_column[0]  # only print out the first row
        generated_tokens.append(token)
        chunk = tokenizer.decode([token])
        print(chunk, end="", flush=True)
    print()
    torch.cuda.synchronize()
    t1 = time.time()
    print(f"Engine time: {t1 - t0:.2f}s")
    # compare the two sequences
    for i in range(len(reference_ids)):
        if reference_ids[i] != generated_tokens[i]:
            print(f"Mismatch at {i}: {reference_ids[i]} != {generated_tokens[i]}")
            break
    print(f"Match: {reference_ids == generated_tokens}")
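A hedged usage sketch (not part of the upload) of the Engine's batch API, assuming a trained checkpoint is loadable via load_model exactly as in the inline test above:

from nanochat.common import compute_init, autodetect_device_type
from nanochat.checkpoint_manager import load_model
from nanochat.engine import Engine

_, _, _, _, device = compute_init(autodetect_device_type())
model, tokenizer, meta = load_model("base", device, phase="eval")
engine = Engine(model, tokenizer)
prompt = tokenizer.encode("The capital of France is", prepend=tokenizer.get_bos_token_id())
# one shared batch=1 prefill, then 4 rows decoded in parallel off the cloned KV cache
results, masks = engine.generate_batch(prompt, num_samples=4, max_tokens=32, temperature=0.8, top_k=50)
for row in results:
    print(tokenizer.decode(row))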
nanochat/execution.py
ADDED
@@ -0,0 +1,349 @@
"""
Sandboxed execution utilities for running Python code that comes out of an LLM.
Adapted from the OpenAI HumanEval code:
https://github.com/openai/human-eval/blob/master/human_eval/execution.py

What is covered:
- Each execution runs in its own process (which can be killed if it hangs or crashes)
- Execution is limited by a timeout to stop infinite loops
- Memory limits are enforced by default (256MB)
- stdout and stderr are captured and returned
- Code runs in a temporary directory that is deleted afterwards
- Dangerous functions are disabled (examples: os.system, os.kill, shutil.rmtree, subprocess.Popen)

What is not covered:
- Not a true security sandbox
- Network access is not blocked (e.g. sockets could be opened)
- Python's dynamic features (e.g. ctypes) could bypass restrictions
- No kernel-level isolation (no seccomp, no containers, no virtualization)

Overall this sandbox is good for evaluation of generated code and protects against
accidental destructive behavior, but it is not safe against malicious adversarial code.
"""

import contextlib
import faulthandler
import io
import multiprocessing
import os
import platform
import signal
import tempfile
from dataclasses import dataclass
from typing import Optional

# -----------------------------------------------------------------------------

@dataclass
class ExecutionResult:
    """Result of executing Python code in a sandbox."""
    success: bool
    stdout: str
    stderr: str
    error: Optional[str] = None
    timeout: bool = False
    memory_exceeded: bool = False

    def __repr__(self):
        parts = []
        parts.append(f"ExecutionResult(success={self.success}")
        if self.timeout:
            parts.append(", timeout=True")
        if self.memory_exceeded:
            parts.append(", memory_exceeded=True")
        if self.error:
            parts.append(f", error={self.error!r}")
        if self.stdout:
            parts.append(f", stdout={self.stdout!r}")
        if self.stderr:
            parts.append(f", stderr={self.stderr!r}")
        parts.append(")")
        return "".join(parts)


@contextlib.contextmanager
def time_limit(seconds: float):
    def signal_handler(signum, frame):
        raise TimeoutException("Timed out!")

    signal.setitimer(signal.ITIMER_REAL, seconds)
    signal.signal(signal.SIGALRM, signal_handler)
    try:
        yield
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)


@contextlib.contextmanager
def capture_io():
    """Capture stdout and stderr, and disable stdin."""
    stdout_capture = io.StringIO()
    stderr_capture = io.StringIO()
    stdin_block = WriteOnlyStringIO()
    with contextlib.redirect_stdout(stdout_capture):
        with contextlib.redirect_stderr(stderr_capture):
            with redirect_stdin(stdin_block):
                yield stdout_capture, stderr_capture


@contextlib.contextmanager
def create_tempdir():
    with tempfile.TemporaryDirectory() as dirname:
        with chdir(dirname):
            yield dirname


class TimeoutException(Exception):
    pass


class WriteOnlyStringIO(io.StringIO):
    """StringIO that throws an exception when it's read from"""

    def read(self, *args, **kwargs):
        raise IOError

    def readline(self, *args, **kwargs):
        raise IOError

    def readlines(self, *args, **kwargs):
        raise IOError

    def readable(self, *args, **kwargs):
        """Returns True if the IO object can be read."""
        return False


class redirect_stdin(contextlib._RedirectStream):  # type: ignore
    _stream = "stdin"


@contextlib.contextmanager
def chdir(root):
    if root == ".":
        yield
        return
    cwd = os.getcwd()
    os.chdir(root)
    try:
        yield
    finally:
        os.chdir(cwd)


def reliability_guard(maximum_memory_bytes: Optional[int] = None):
    """
    This disables various destructive functions and prevents the generated code
    from interfering with the test (e.g. fork bomb, killing other processes,
    removing filesystem files, etc.)

    WARNING
    This function is NOT a security sandbox. Untrusted code, including
    model-generated code, should not be blindly executed outside of one. See the
    Codex paper for more information about OpenAI's code sandbox, and proceed
    with caution.
    """

    if platform.uname().system != "Darwin":
        # These resource limit calls seem to fail on macOS (Darwin), so we skip them there
        import resource
        resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
        resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
        resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))

    faulthandler.disable()

    import builtins

    builtins.exit = None
    builtins.quit = None

    import os

    os.environ["OMP_NUM_THREADS"] = "1"

    os.kill = None
    os.system = None
    os.putenv = None
    os.remove = None
    os.removedirs = None
    os.rmdir = None
    os.fchdir = None
    os.setuid = None
    os.fork = None
    os.forkpty = None
    os.killpg = None
    os.rename = None
    os.renames = None
    os.truncate = None
    os.replace = None
    os.unlink = None
    os.fchmod = None
    os.fchown = None
    os.chmod = None
    os.chown = None
    os.chroot = None
    os.lchflags = None
    os.lchmod = None
    os.lchown = None
    os.getcwd = None
    os.chdir = None

    import shutil

    shutil.rmtree = None
    shutil.move = None
    shutil.chown = None

    import subprocess

    subprocess.Popen = None  # type: ignore

    __builtins__["help"] = None

    import sys

    sys.modules["ipdb"] = None
    sys.modules["joblib"] = None
    sys.modules["resource"] = None
    sys.modules["psutil"] = None
    sys.modules["tkinter"] = None


def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[int], result_dict):
    """Execute code in a subprocess with safety guards. Results are written to result_dict."""
    with create_tempdir():

        # These system calls are needed when cleaning up the tempdir.
        import os
        import shutil

        rmtree = shutil.rmtree
        rmdir = os.rmdir
        chdir = os.chdir
        unlink = os.unlink

        # Disable functionalities that can make destructive changes to the test.
        reliability_guard(maximum_memory_bytes=maximum_memory_bytes)

        # Default to failure
        result_dict.update({
            "success": False,
            "stdout": "",
            "stderr": "",
            "timeout": False,
            "memory_exceeded": False,
            "error": None,
        })

        try:
            exec_globals = {}
            with capture_io() as (stdout_capture, stderr_capture):
                with time_limit(timeout):
                    # WARNING
                    # This program exists to execute untrusted model-generated code. Although
                    # it is highly unlikely that model-generated code will do something overtly
                    # malicious in response to this test suite, model-generated code may act
                    # destructively due to a lack of model capability or alignment.
                    # Users are strongly encouraged to sandbox this evaluation suite so that it
                    # does not perform destructive actions on their host or network. For more
                    # information on how OpenAI sandboxes its code, see the accompanying paper.
                    # Unlike the upstream HumanEval code, the following line ships enabled here;
                    # proceed at your own risk:
                    exec(code, exec_globals)

            result_dict.update({
                "success": True,
                "stdout": stdout_capture.getvalue(),
                "stderr": stderr_capture.getvalue(),
            })

        except TimeoutException:
            result_dict.update({
                "timeout": True,
                "error": "Execution timed out",
            })

        except MemoryError as e:
            result_dict.update({
                "memory_exceeded": True,
                "error": f"Memory limit exceeded: {e}",
            })
+
|
| 274 |
+
except BaseException as e:
|
| 275 |
+
result_dict.update({
|
| 276 |
+
"error": f"{type(e).__name__}: {e}",
|
| 277 |
+
})
|
| 278 |
+
|
| 279 |
+
# Needed for cleaning up.
|
| 280 |
+
shutil.rmtree = rmtree
|
| 281 |
+
os.rmdir = rmdir
|
| 282 |
+
os.chdir = chdir
|
| 283 |
+
os.unlink = unlink
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def execute_code(
|
| 287 |
+
code: str,
|
| 288 |
+
timeout: float = 5.0, # 5 seconds default
|
| 289 |
+
maximum_memory_bytes: Optional[int] = 256 * 1024 * 1024, # 256MB default
|
| 290 |
+
) -> ExecutionResult:
|
| 291 |
+
"""
|
| 292 |
+
Execute Python code in a sandboxed environment.
|
| 293 |
+
|
| 294 |
+
Args:
|
| 295 |
+
code: Python code to execute as a string
|
| 296 |
+
timeout: Maximum execution time in seconds (default: 5.0)
|
| 297 |
+
maximum_memory_bytes: Memory limit in bytes (default: 256MB, None to disable)
|
| 298 |
+
|
| 299 |
+
Returns:
|
| 300 |
+
ExecutionResult with success status, stdout/stderr, and error information
|
| 301 |
+
|
| 302 |
+
Example:
|
| 303 |
+
>>> result = execute_code("print('hello world')")
|
| 304 |
+
>>> result.success
|
| 305 |
+
True
|
| 306 |
+
>>> result.stdout
|
| 307 |
+
'hello world\\n'
|
| 308 |
+
"""
|
| 309 |
+
|
| 310 |
+
manager = multiprocessing.Manager()
|
| 311 |
+
result_dict = manager.dict()
|
| 312 |
+
|
| 313 |
+
p = multiprocessing.Process(
|
| 314 |
+
target=_unsafe_execute,
|
| 315 |
+
args=(code, timeout, maximum_memory_bytes, result_dict)
|
| 316 |
+
)
|
| 317 |
+
p.start()
|
| 318 |
+
p.join(timeout=timeout + 1)
|
| 319 |
+
|
| 320 |
+
if p.is_alive():
|
| 321 |
+
p.kill()
|
| 322 |
+
return ExecutionResult(
|
| 323 |
+
success=False,
|
| 324 |
+
stdout="",
|
| 325 |
+
stderr="",
|
| 326 |
+
error="Execution timed out (process killed)",
|
| 327 |
+
timeout=True,
|
| 328 |
+
memory_exceeded=False,
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
if not result_dict:
|
| 332 |
+
return ExecutionResult(
|
| 333 |
+
success=False,
|
| 334 |
+
stdout="",
|
| 335 |
+
stderr="",
|
| 336 |
+
error="Execution failed (no result returned)",
|
| 337 |
+
timeout=True,
|
| 338 |
+
memory_exceeded=False,
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
return ExecutionResult(
|
| 342 |
+
success=result_dict["success"],
|
| 343 |
+
stdout=result_dict["stdout"],
|
| 344 |
+
stderr=result_dict["stderr"],
|
| 345 |
+
error=result_dict["error"],
|
| 346 |
+
timeout=result_dict["timeout"],
|
| 347 |
+
memory_exceeded=result_dict["memory_exceeded"],
|
| 348 |
+
)
|
| 349 |
+
|
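Taken together, `execute_code` has three observable outcomes: success with captured stdout, a timeout, or an error string. A minimal usage sketch (illustrative, not part of the repository; assumes a POSIX host, since `time_limit` relies on SIGALRM):

# Illustrative usage of nanochat.execution.execute_code (not part of the repo).
from nanochat.execution import execute_code

# Happy path: stdout is captured inside the subprocess and returned.
ok = execute_code("print(sum(range(10)))")
assert ok.success and ok.stdout.strip() == "45"

# Timeout path: the SIGALRM timer fires inside the subprocess (or the
# parent kills the process after timeout + 1 seconds as a backstop).
slow = execute_code("while True: pass", timeout=1.0)
assert not slow.success and slow.timeout

# Guarded path: reliability_guard() replaced os.system with None, so the
# call raises TypeError, which is reported via the error field.
blocked = execute_code("import os; os.system('ls')")
assert not blocked.success and "TypeError" in blocked.error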
nanochat/flash_attention.py
ADDED
@@ -0,0 +1,187 @@
"""
Unified Flash Attention interface with automatic FA3/SDPA switching.

Exports `flash_attn` module that matches the FA3 API exactly, but falls back
to PyTorch SDPA on non-Hopper GPUs (including Blackwell), MPS, and CPU.

Usage (drop-in replacement for FA3):
    from nanochat.flash_attention import flash_attn

    # Training (no KV cache)
    y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)

    # Inference (with KV cache)
    y = flash_attn.flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v, ...)
"""
import torch
import torch.nn.functional as F


# =============================================================================
# Detection: Try to load FA3 on Hopper+ GPUs
# =============================================================================
def _load_flash_attention_3():
    """Try to load Flash Attention 3 (requires Hopper GPU, sm90)."""
    if not torch.cuda.is_available():
        return None
    try:
        major, _ = torch.cuda.get_device_capability()
        # FA3 kernels are compiled for Hopper (sm90) only.
        # Ada (sm89) and Blackwell (sm100) need the SDPA fallback until FA3 is recompiled.
        if major != 9:
            return None
        import os
        os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
        from kernels import get_kernel
        return get_kernel('varunneal/flash-attention-3').flash_attn_interface
    except Exception:
        return None


_fa3 = _load_flash_attention_3()
HAS_FA3 = _fa3 is not None

# Override for testing: set to 'fa3', 'sdpa', or None (auto)
_override_impl = None


def _resolve_use_fa3():
    """Decide once whether to use FA3, based on availability, override, and dtype."""
    if _override_impl == 'fa3':
        assert HAS_FA3, "Cannot override to FA3: not available on this hardware"
        return True
    if _override_impl == 'sdpa':
        return False
    if HAS_FA3:
        # FA3 Hopper kernels only support bf16 and fp8; fp16/fp32 must use the SDPA fallback
        from nanochat.common import COMPUTE_DTYPE
        if COMPUTE_DTYPE == torch.bfloat16:
            return True
        return False
    return False

USE_FA3 = _resolve_use_fa3()


# =============================================================================
# SDPA helpers
# =============================================================================
def _sdpa_attention(q, k, v, window_size, enable_gqa):
    """
    SDPA attention with sliding window support.
    q, k, v are (B, H, T, D) format.
    """
    Tq = q.size(2)
    Tk = k.size(2)
    window = window_size[0]

    # Full context, same length
    if (window < 0 or window >= Tq) and Tq == Tk:
        return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)

    # Single token generation
    if Tq == 1:
        if window >= 0 and window < Tk:
            # window is the "left" count; we need to include (window + 1) keys total
            start = max(0, Tk - (window + 1))
            k = k[:, :, start:, :]
            v = v[:, :, start:, :]
        return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)

    # Need explicit mask for sliding window / chunk inference
    device = q.device
    # For chunk inference (Tq != Tk), is_causal is not aligned to the cache position => build an explicit bool mask
    row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1)
    col_idx = torch.arange(Tk, device=device).unsqueeze(0)
    mask = col_idx <= row_idx

    # sliding window (left)
    if window >= 0 and window < Tk:
        mask = mask & ((row_idx - col_idx) <= window)

    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)

# =============================================================================
# Public API: Same interface as FA3
# =============================================================================
def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
    """
    Flash Attention for training (no KV cache).

    Args:
        q, k, v: Tensors of shape (B, T, H, D)
        causal: Whether to use causal masking
        window_size: (left, right) sliding window. -1 means unlimited.

    Returns:
        Output tensor of shape (B, T, H, D)
    """
    if USE_FA3:
        return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)

    # SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D)
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)
    enable_gqa = q.size(1) != k.size(1)
    y = _sdpa_attention(q, k, v, window_size, enable_gqa)
    return y.transpose(1, 2)  # back to (B, T, H, D)


def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=None,
                            causal=False, window_size=(-1, -1)):
    """
    Flash Attention with KV cache for inference.

    FA3 updates k_cache/v_cache in-place. Our SDPA fallback does the same.

    Args:
        q: Queries, shape (B, T_new, H, D)
        k_cache, v_cache: Pre-allocated cache tensors, shape (B, T_max, H_kv, D)
        k, v: New keys/values to insert, shape (B, T_new, H_kv, D)
        cache_seqlens: Current position in cache, shape (B,) int32
        causal: Whether to use causal masking
        window_size: (left, right) sliding window. -1 means unlimited.

    Returns:
        Output tensor of shape (B, T_new, H, D)
    """
    if USE_FA3:
        return _fa3.flash_attn_with_kvcache(
            q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
            causal=causal, window_size=window_size
        )

    # SDPA fallback: manually manage KV cache
    B, T_new, H, D = q.shape
    pos = cache_seqlens[0].item()  # assume uniform position across batch

    # Insert new k, v into cache (in-place, matching FA3 behavior)
    if k is not None and v is not None:
        k_cache[:, pos:pos+T_new, :, :] = k
        v_cache[:, pos:pos+T_new, :, :] = v

    # Get full cache up to current position + new tokens
    end_pos = pos + T_new
    k_full = k_cache[:, :end_pos, :, :]
    v_full = v_cache[:, :end_pos, :, :]

    # Transpose to SDPA layout: (B, T, H, D) -> (B, H, T, D)
    q_sdpa = q.transpose(1, 2)
    k_sdpa = k_full.transpose(1, 2)
    v_sdpa = v_full.transpose(1, 2)

    enable_gqa = q_sdpa.size(1) != k_sdpa.size(1)
    y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa)

    return y_sdpa.transpose(1, 2)  # back to (B, T_new, H, D)


# =============================================================================
# Export: flash_attn module interface (drop-in replacement for FA3)
# =============================================================================
from types import SimpleNamespace
flash_attn = SimpleNamespace(
    flash_attn_func=flash_attn_func,
    flash_attn_with_kvcache=flash_attn_with_kvcache,
)
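Since the SDPA path runs on any device, the fallback and the sliding-window semantics can be sanity-checked on CPU. A minimal sketch, assuming a recent PyTorch with the `enable_gqa` flag on `scaled_dot_product_attention`; the toy sizes and tolerance are arbitrary. Note that `_sdpa_attention` always applies causal masking, which matches how this repo calls it:

# Illustrative CPU check of the SDPA fallback (not part of the repo).
import torch
from nanochat.flash_attention import flash_attn

B, T, H, D = 2, 16, 4, 8
q, k, v = (torch.randn(B, T, H, D) for _ in range(3))

# Full-context causal attention: left window -1 means unlimited lookback.
y_full = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(-1, 0))
assert y_full.shape == (B, T, H, D)

# Sliding window of 4: position i attends to keys j with i - j <= 4.
# Positions 0..4 never hit the cap, so they match the full-context output.
y_win = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(4, 0))
assert torch.allclose(y_full[:, :5], y_win[:, :5], atol=1e-5)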
nanochat/fp8.py
ADDED
@@ -0,0 +1,266 @@
"""Minimal FP8 training for nanochat — tensorwise dynamic scaling only.

Drop-in replacement for torchao's Float8Linear (~2000 lines) with ~150 lines.
We only need the "tensorwise" recipe (one scalar scale per tensor), not the full
generality of torchao (rowwise scaling, FSDP float8 all-gather, DTensor, tensor
subclass dispatch tables, etc.)

How FP8 training works
======================
A standard Linear layer does one matmul in forward and two in backward:
    forward:  output      = input @ weight.T
    backward: grad_input  = grad_output @ weight
              grad_weight = grad_output.T @ input

FP8 training wraps each of these three matmuls with:
    1. Compute scale = FP8_MAX / max(|tensor|) for each operand
    2. Quantize: fp8_tensor = clamp(tensor * scale, -FP8_MAX, FP8_MAX).to(fp8)
    3. Matmul via torch._scaled_mm (cuBLAS FP8 kernel, ~2x faster than bf16)
    4. Dequantize: _scaled_mm handles this internally using the inverse scales

The key insight: torch._scaled_mm and the float8 dtypes are PyTorch built-ins.
torchao is just orchestration around these primitives. We can call them directly.

FP8 dtype choice
================
There are two FP8 formats. We use both, following the standard convention:
- float8_e4m3fn: 4-bit exponent, 3-bit mantissa, range [-448, 448]
  Higher precision (more mantissa bits), used for input and weight.
- float8_e5m2: 5-bit exponent, 2-bit mantissa, range [-57344, 57344]
  Wider range (more exponent bits), used for gradients which can be large.

torch._scaled_mm layout requirements
====================================
The cuBLAS FP8 kernel requires specific memory layouts:
- First argument (A): must be row-major (contiguous)
- Second argument (B): must be column-major (B.t().contiguous().t())
If B is obtained by transposing a contiguous tensor (e.g. weight.t()), it is
already column-major — no copy needed. Otherwise we use _to_col_major().

How this differs from torchao's approach
========================================
torchao uses a "tensor subclass" architecture: Float8TrainingTensor is a subclass
of torch.Tensor that bundles FP8 data + scale + metadata. It implements
__torch_dispatch__ with a dispatch table that intercepts every aten op (mm, t,
reshape, clone, ...) and handles it in FP8-aware fashion. When you call
    output = input @ weight.T
the @ operator dispatches to aten.mm, which gets intercepted and routed to
torch._scaled_mm behind the scenes. This is ~2000 lines of code because you need
a handler for every tensor operation that might touch an FP8 tensor.

We take a simpler approach: a single autograd.Function (_Float8Matmul) that takes
full-precision inputs, quantizes to FP8 internally, calls _scaled_mm, and returns
full-precision outputs. Marked @allow_in_graph so torch.compile treats it as one
opaque node rather than trying to trace inside.

The trade-off is in how torch.compile sees the two approaches:
- torchao: compile decomposes the tensor subclass (via __tensor_flatten__) and
  sees every individual op (amax, scale, cast, _scaled_mm) as separate graph
  nodes. Inductor can fuse these with surrounding operations (e.g. fuse the
  amax computation with the preceding layer's activation function).
- ours: compile sees a single opaque call. It can optimize everything around
  the FP8 linear (attention, norms, etc.) but cannot fuse across the boundary.

Both call the exact same cuBLAS _scaled_mm kernel — the GPU matmul is identical.
The difference is only in the "glue" ops (amax, scale, cast) which are tiny
compared to the matmul. In practice this means our version is slightly faster
(less compilation overhead, no tensor subclass dispatch cost) but can produce
subtly different floating-point rounding paths under torch.compile, since Inductor
generates a different graph. Numerics are bitwise identical in eager mode.
"""

import torch
import torch.nn as nn

from nanochat.common import COMPUTE_DTYPE

# Avoid division by zero when computing scale from an all-zeros tensor
EPS = 1e-12


@torch.no_grad()
def _to_fp8(x, fp8_dtype):
    """Dynamically quantize a tensor to FP8 using tensorwise scaling.

    "Tensorwise" means one scalar scale for the entire tensor (as opposed to
    "rowwise" which computes a separate scale per row). Tensorwise is faster
    because cuBLAS handles the scaling; rowwise needs the CUTLASS kernel.

    Returns (fp8_data, inverse_scale) for use with torch._scaled_mm.
    """
    fp8_max = torch.finfo(fp8_dtype).max
    # Compute the max absolute value across the entire tensor
    amax = x.float().abs().max()
    # Scale maps [0, amax] -> [0, fp8_max]. Use float64 for the division to
    # ensure consistent numerics between torch.compile and eager mode.
    # (torchao does the same upcast — without it, compile/eager can diverge)
    scale = fp8_max / amax.double().clamp(min=EPS)
    scale = scale.float()
    # Quantize: scale into FP8 range, saturate (clamp prevents overflow when
    # casting — PyTorch's default is to wrap, not saturate), then cast to FP8
    x_scaled = x.float() * scale
    x_clamped = x_scaled.clamp(-fp8_max, fp8_max)
    x_fp8 = x_clamped.to(fp8_dtype)
    # _scaled_mm expects the *inverse* of our scale (it multiplies by this to
    # convert FP8 values back to the original range during the matmul)
    inv_scale = scale.reciprocal()
    return x_fp8, inv_scale


def _to_col_major(x):
    """Rearrange a 2D tensor's memory to column-major layout.

    torch._scaled_mm requires its second operand in column-major layout.
    The trick: transpose -> contiguous (forces a copy in transposed order)
    -> transpose back. The result has the same logical shape but column-major
    strides, e.g. a [M, N] tensor gets strides (1, M) instead of (N, 1).
    """
    return x.t().contiguous().t()


# allow_in_graph tells torch.compile to treat this as an opaque operation —
# dynamo won't try to decompose it into smaller ops. See the module docstring
# for how this differs from torchao's tensor subclass approach.
@torch._dynamo.allow_in_graph
class _Float8Matmul(torch.autograd.Function):
    """Custom autograd for the three FP8 GEMMs of a Linear layer.

    The forward quantizes input and weight to FP8 and saves
    the quantized tensors + scales for backward.
    """

    @staticmethod
    def forward(ctx, input_2d, weight):
        # Quantize both operands to e4m3 (higher precision format)
        input_fp8, input_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
        weight_fp8, weight_inv = _to_fp8(weight, torch.float8_e4m3fn)
        ctx.save_for_backward(input_fp8, input_inv, weight_fp8, weight_inv)

        # output = input @ weight.T
        # input_fp8 is [B, K] contiguous = row-major (good for first arg)
        # weight_fp8 is [N, K] contiguous, so weight_fp8.t() is [K, N] with
        # strides (1, K) = column-major (good for second arg, no copy needed!)
        output = torch._scaled_mm(
            input_fp8,
            weight_fp8.t(),
            scale_a=input_inv,
            scale_b=weight_inv,
            out_dtype=input_2d.dtype,
            # use_fast_accum=True accumulates the dot products in lower precision.
            # Slightly less accurate but measurably faster. Standard practice for
            # the forward pass; we use False in backward for more precise gradients.
            use_fast_accum=True,
        )
        return output

    @staticmethod
    def backward(ctx, grad_output):
        in_fp8, in_inv, w_fp8, w_inv = ctx.saved_tensors

        # === GEMM 1: grad_input = grad_output @ weight ===
        # Shapes: [B, N] @ [N, K] -> [B, K]
        # Gradients use e5m2 (wider range), weights use e4m3 (higher precision)
        go_fp8, go_inv = _to_fp8(grad_output, torch.float8_e5m2)
        # go_fp8 is [B, N] contiguous = row-major, good for first arg
        # w_fp8 is [N, K] contiguous = row-major, need column-major for second arg
        w_col = _to_col_major(w_fp8)
        grad_input = torch._scaled_mm(
            go_fp8,
            w_col,
            scale_a=go_inv,
            scale_b=w_inv,
            out_dtype=grad_output.dtype,
            use_fast_accum=False,
        )

        # === GEMM 2: grad_weight = grad_output.T @ input ===
        # Shapes: [N, B] @ [B, K] -> [N, K]
        # go_fp8 is [B, N] contiguous, we need go.T = [N, B] as first arg.
        # Transposing gives column-major, but first arg needs row-major,
        # so we must call .contiguous() to physically rearrange the memory.
        go_T = go_fp8.t().contiguous()  # [N, B] row-major
        in_col = _to_col_major(in_fp8)  # [B, K] column-major
        grad_weight = torch._scaled_mm(
            go_T,
            in_col,
            scale_a=go_inv,
            scale_b=in_inv,
            out_dtype=grad_output.dtype,
            use_fast_accum=False,
        )

        return grad_input, grad_weight


class Float8Linear(nn.Linear):
    """Drop-in nn.Linear replacement that does FP8 compute.

    Weights and biases remain in their original precision (e.g. fp32/bf16).
    Only the matmul is performed in FP8 via the _Float8Matmul autograd function.
    """

    def forward(self, input):
        # Cast input to COMPUTE_DTYPE (typically bf16) since _scaled_mm expects
        # reduced precision input, and we no longer rely on autocast to do this.
        input = input.to(COMPUTE_DTYPE)
        # _scaled_mm only works on 2D tensors, so flatten batch dimensions
        orig_shape = input.shape
        input_2d = input.reshape(-1, orig_shape[-1])
        output = _Float8Matmul.apply(input_2d, self.weight)
        output = output.reshape(*orig_shape[:-1], output.shape[-1])
        if self.bias is not None:
            output = output + self.bias.to(output.dtype)
        return output

    @classmethod
    def from_float(cls, mod):
        """Create Float8Linear from nn.Linear, sharing the same weight and bias.

        Uses meta device to avoid allocating a temporary weight tensor — we
        create the module shell on meta (shapes/dtypes only, no memory), then
        point .weight and .bias to the original module's parameters.
        """
        with torch.device("meta"):
            new_mod = cls(mod.in_features, mod.out_features, bias=False)
        new_mod.weight = mod.weight
        new_mod.bias = mod.bias
        return new_mod


class Float8LinearConfig:
    """Minimal config matching torchao's API. Only tensorwise recipe is supported."""

    @staticmethod
    def from_recipe_name(recipe_name):
        if recipe_name != "tensorwise":
            raise ValueError(
                f"Only 'tensorwise' recipe is supported, got '{recipe_name}'. "
                f"Rowwise/axiswise recipes require the full torchao library."
            )
        return Float8LinearConfig()


def convert_to_float8_training(module, *, config=None, module_filter_fn=None):
    """Replace nn.Linear layers with Float8Linear throughout a module.

    Walks the module tree in post-order (children before parents) and swaps
    each nn.Linear that passes the optional filter. The new Float8Linear shares
    the original weight and bias tensors — no copies, no extra memory.

    Args:
        module: Root module to convert.
        config: Float8LinearConfig (accepted for API compat, only tensorwise supported).
        module_filter_fn: Optional filter(module, fqn) -> bool. Only matching Linears
            are converted. Common use: skip layers with dims not divisible by 16
            (hardware requirement for FP8 matmuls on H100).
    """
    def _convert(mod, prefix=""):
        for name, child in mod.named_children():
            fqn = f"{prefix}.{name}" if prefix else name
            _convert(child, fqn)
            if isinstance(child, nn.Linear) and not isinstance(child, Float8Linear):
                if module_filter_fn is None or module_filter_fn(child, fqn):
                    setattr(mod, name, Float8Linear.from_float(child))

    _convert(module)
    return module
nanochat/gpt.py
ADDED
@@ -0,0 +1,507 @@
| 1 |
+
"""
|
| 2 |
+
GPT model (rewrite, a lot simpler)
|
| 3 |
+
Notable features:
|
| 4 |
+
- rotary embeddings (and no positional embeddings)
|
| 5 |
+
- QK norm
|
| 6 |
+
- untied weights for token embedding and lm_head
|
| 7 |
+
- relu^2 activation in MLP
|
| 8 |
+
- norm after token embedding
|
| 9 |
+
- no learnable params in rmsnorm
|
| 10 |
+
- no bias in linear layers
|
| 11 |
+
- Group-Query Attention (GQA) support for more efficient inference
|
| 12 |
+
- Flash Attention 3 integration
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from functools import partial
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
|
| 22 |
+
from nanochat.common import get_dist_info, print0, COMPUTE_DTYPE
|
| 23 |
+
from nanochat.optim import MuonAdamW, DistMuonAdamW
|
| 24 |
+
|
| 25 |
+
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
|
| 26 |
+
from nanochat.flash_attention import flash_attn
|
| 27 |
+
|
| 28 |
+
@dataclass
|
| 29 |
+
class GPTConfig:
|
| 30 |
+
sequence_len: int = 2048
|
| 31 |
+
vocab_size: int = 32768
|
| 32 |
+
n_layer: int = 12
|
| 33 |
+
n_head: int = 6 # number of query heads
|
| 34 |
+
n_kv_head: int = 6 # number of key/value heads (GQA)
|
| 35 |
+
n_embd: int = 768
|
| 36 |
+
# Sliding window attention pattern string, tiled across layers. Final layer always L.
|
| 37 |
+
# Characters: L=long (full context), S=short (quarter context)
|
| 38 |
+
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
|
| 39 |
+
window_pattern: str = "SSSL"
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def norm(x):
|
| 43 |
+
return F.rms_norm(x, (x.size(-1),)) # note that this will run in bf16, seems ok
|
| 44 |
+
|
| 45 |
+
class Linear(nn.Linear):
|
| 46 |
+
"""nn.Linear that casts weights to match input dtype in forward.
|
| 47 |
+
Replaces autocast: master weights stay fp32 for optimizer precision,
|
| 48 |
+
but matmuls run in the activation dtype (typically bf16 from embeddings)."""
|
| 49 |
+
def forward(self, x):
|
| 50 |
+
return F.linear(x, self.weight.to(dtype=x.dtype))
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def has_ve(layer_idx, n_layer):
|
| 54 |
+
"""Returns True if GPT layer should have Value Embedding (alternating, last layer always included)."""
|
| 55 |
+
return layer_idx % 2 == (n_layer - 1) % 2
|
| 56 |
+
|
| 57 |
+
def apply_rotary_emb(x, cos, sin):
|
| 58 |
+
assert x.ndim == 4 # multihead attention
|
| 59 |
+
d = x.shape[3] // 2
|
| 60 |
+
x1, x2 = x[..., :d], x[..., d:] # split up last dim into two halves
|
| 61 |
+
y1 = x1 * cos + x2 * sin # rotate pairs of dims
|
| 62 |
+
y2 = x1 * (-sin) + x2 * cos
|
| 63 |
+
return torch.cat([y1, y2], 3)
|
| 64 |
+
|
| 65 |
+
class CausalSelfAttention(nn.Module):
|
| 66 |
+
def __init__(self, config, layer_idx):
|
| 67 |
+
super().__init__()
|
| 68 |
+
self.layer_idx = layer_idx
|
| 69 |
+
self.n_head = config.n_head
|
| 70 |
+
self.n_kv_head = config.n_kv_head
|
| 71 |
+
self.n_embd = config.n_embd
|
| 72 |
+
self.head_dim = self.n_embd // self.n_head
|
| 73 |
+
assert self.n_embd % self.n_head == 0
|
| 74 |
+
assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
|
| 75 |
+
self.c_q = Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
|
| 76 |
+
self.c_k = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
| 77 |
+
self.c_v = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
| 78 |
+
self.c_proj = Linear(self.n_embd, self.n_embd, bias=False)
|
| 79 |
+
self.ve_gate_channels = 12
|
| 80 |
+
self.ve_gate = Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
|
| 81 |
+
|
| 82 |
+
def forward(self, x, ve, cos_sin, window_size, kv_cache):
|
| 83 |
+
B, T, C = x.size()
|
| 84 |
+
|
| 85 |
+
# Project the input to get queries, keys, and values
|
| 86 |
+
# Shape: (B, T, H, D) - FA3's native layout, no transpose needed!
|
| 87 |
+
q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
|
| 88 |
+
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
|
| 89 |
+
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
|
| 90 |
+
|
| 91 |
+
# Value residual (ResFormer): mix in value embedding with input-dependent gate per head
|
| 92 |
+
if ve is not None:
|
| 93 |
+
ve = ve.view(B, T, self.n_kv_head, self.head_dim)
|
| 94 |
+
gate = 3 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels])) # (B, T, n_kv_head), range (0, 3)
|
| 95 |
+
v = v + gate.unsqueeze(-1) * ve
|
| 96 |
+
|
| 97 |
+
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
|
| 98 |
+
cos, sin = cos_sin
|
| 99 |
+
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
|
| 100 |
+
q, k = norm(q), norm(k) # QK norm
|
| 101 |
+
q = q * 1.2 # sharper attention (split scale between Q and K), TODO think through better
|
| 102 |
+
k = k * 1.2
|
| 103 |
+
|
| 104 |
+
# Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere)
|
| 105 |
+
# window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
|
| 106 |
+
if kv_cache is None:
|
| 107 |
+
# Training: causal attention with optional sliding window
|
| 108 |
+
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
|
| 109 |
+
else:
|
| 110 |
+
# Inference: use flash_attn_with_kvcache which handles cache management
|
| 111 |
+
k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
|
| 112 |
+
y = flash_attn.flash_attn_with_kvcache(
|
| 113 |
+
q, k_cache, v_cache,
|
| 114 |
+
k=k, v=v,
|
| 115 |
+
cache_seqlens=kv_cache.cache_seqlens,
|
| 116 |
+
causal=True,
|
| 117 |
+
window_size=window_size,
|
| 118 |
+
)
|
| 119 |
+
# Advance position after last layer processes
|
| 120 |
+
if self.layer_idx == kv_cache.n_layers - 1:
|
| 121 |
+
kv_cache.advance(T)
|
| 122 |
+
|
| 123 |
+
# Re-assemble the heads and project back to residual stream
|
| 124 |
+
y = y.contiguous().view(B, T, -1)
|
| 125 |
+
y = self.c_proj(y)
|
| 126 |
+
return y
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class MLP(nn.Module):
|
| 130 |
+
def __init__(self, config):
|
| 131 |
+
super().__init__()
|
| 132 |
+
self.c_fc = Linear(config.n_embd, 4 * config.n_embd, bias=False)
|
| 133 |
+
self.c_proj = Linear(4 * config.n_embd, config.n_embd, bias=False)
|
| 134 |
+
|
| 135 |
+
def forward(self, x):
|
| 136 |
+
x = self.c_fc(x)
|
| 137 |
+
x = F.relu(x).square()
|
| 138 |
+
x = self.c_proj(x)
|
| 139 |
+
return x
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class Block(nn.Module):
|
| 143 |
+
def __init__(self, config, layer_idx):
|
| 144 |
+
super().__init__()
|
| 145 |
+
self.attn = CausalSelfAttention(config, layer_idx)
|
| 146 |
+
self.mlp = MLP(config)
|
| 147 |
+
|
| 148 |
+
def forward(self, x, ve, cos_sin, window_size, kv_cache):
|
| 149 |
+
x = x + self.attn(norm(x), ve, cos_sin, window_size, kv_cache)
|
| 150 |
+
x = x + self.mlp(norm(x))
|
| 151 |
+
return x
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class GPT(nn.Module):
|
| 155 |
+
def __init__(self, config, pad_vocab_size_to=64):
|
| 156 |
+
"""
|
| 157 |
+
NOTE a major footgun: this __init__ function runs in meta device context (!!)
|
| 158 |
+
Therefore, any calculations inside here are shapes and dtypes only, no actual data.
|
| 159 |
+
=> We actually initialize all data (parameters, buffers, etc.) in init_weights() instead.
|
| 160 |
+
"""
|
| 161 |
+
super().__init__()
|
| 162 |
+
self.config = config
|
| 163 |
+
# Compute per-layer window sizes for sliding window attention
|
| 164 |
+
# window_size is (left, right) tuple: (-1, 0) for full context, (N, 0) for sliding window
|
| 165 |
+
self.window_sizes = self._compute_window_sizes(config)
|
| 166 |
+
# Pad vocab for efficiency (DDP, tensor cores). This is just an optimization - outputs are cropped in forward().
|
| 167 |
+
# https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings
|
| 168 |
+
padded_vocab_size = ((config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to
|
| 169 |
+
if padded_vocab_size != config.vocab_size:
|
| 170 |
+
print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} for efficiency")
|
| 171 |
+
self.transformer = nn.ModuleDict({
|
| 172 |
+
"wte": nn.Embedding(padded_vocab_size, config.n_embd),
|
| 173 |
+
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
|
| 174 |
+
})
|
| 175 |
+
self.lm_head = Linear(config.n_embd, padded_vocab_size, bias=False)
|
| 176 |
+
# Per-layer learnable scalars (inspired by modded-nanogpt)
|
| 177 |
+
# resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
|
| 178 |
+
# x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
|
| 179 |
+
# Separate parameters so they can have different optimizer treatment
|
| 180 |
+
self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
|
| 181 |
+
self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
|
| 182 |
+
# Smear: mix previous token's embedding into current token (cheap bigram-like info)
|
| 183 |
+
self.smear_gate = Linear(24, 1, bias=False)
|
| 184 |
+
self.smear_lambda = nn.Parameter(torch.zeros(1))
|
| 185 |
+
# Backout: subtract cached mid-layer residual before final norm to remove low-level features
|
| 186 |
+
self.backout_lambda = nn.Parameter(0.2 * torch.ones(1))
|
| 187 |
+
# Value embeddings (ResFormer-style): alternating layers, last layer always included
|
| 188 |
+
head_dim = config.n_embd // config.n_head
|
| 189 |
+
kv_dim = config.n_kv_head * head_dim
|
| 190 |
+
self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)})
|
| 191 |
+
# To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
|
| 192 |
+
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
|
| 193 |
+
# so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
|
| 194 |
+
# In the future we can dynamically grow the cache, for now it's fine.
|
| 195 |
+
self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
|
| 196 |
+
head_dim = config.n_embd // config.n_head
|
| 197 |
+
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
| 198 |
+
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
|
| 199 |
+
self.register_buffer("sin", sin, persistent=False)
|
| 200 |
+
|
| 201 |
+
@torch.no_grad()
|
| 202 |
+
def init_weights(self):
|
| 203 |
+
"""
|
| 204 |
+
Initialize the full model in this one function for maximum clarity.
|
| 205 |
+
|
| 206 |
+
wte (embedding): normal, std=1.0
|
| 207 |
+
lm_head: normal, std=0.001
|
| 208 |
+
for each block:
|
| 209 |
+
attn.c_q: uniform, std=1/sqrt(n_embd)
|
| 210 |
+
attn.c_k: uniform, std=1/sqrt(n_embd)
|
| 211 |
+
attn.c_v: uniform, std=1/sqrt(n_embd)
|
| 212 |
+
attn.c_proj: zeros
|
| 213 |
+
mlp.c_fc: uniform, std=1/sqrt(n_embd)
|
| 214 |
+
mlp.c_proj: zeros
|
| 215 |
+
"""
|
| 216 |
+
|
| 217 |
+
# Embedding and unembedding
|
| 218 |
+
torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=0.8)
|
| 219 |
+
torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
|
| 220 |
+
|
| 221 |
+
# Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal)
|
| 222 |
+
n_embd = self.config.n_embd
|
| 223 |
+
s = 3**0.5 * n_embd**-0.5 # sqrt(3) multiplier makes sure Uniform achieves the same std as Normal
|
| 224 |
+
for block in self.transformer.h:
|
| 225 |
+
torch.nn.init.uniform_(block.attn.c_q.weight, -s, s) # weights use Uniform to avoid outliers
|
| 226 |
+
torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
|
| 227 |
+
torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
|
| 228 |
+
torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
|
| 229 |
+
torch.nn.init.uniform_(block.mlp.c_fc.weight, -s * 0.4, s * 0.4) # 0.4x init scale for c_fc
|
| 230 |
+
torch.nn.init.zeros_(block.mlp.c_proj.weight)
|
| 231 |
+
|
| 232 |
+
# Per-layer scalars
|
| 233 |
+
# Per-layer resid init: stronger residual at early layers, weaker at deep layers
|
| 234 |
+
n_layer = self.config.n_layer
|
| 235 |
+
for i in range(n_layer):
|
| 236 |
+
self.resid_lambdas.data[i] = 1.15 - (0.10 * i / max(n_layer - 1, 1))
|
| 237 |
+
# Decaying x0 init: earlier layers get more input embedding blending
|
| 238 |
+
for i in range(n_layer):
|
| 239 |
+
self.x0_lambdas.data[i] = 0.20 - (0.15 * i / max(n_layer - 1, 1))
|
| 240 |
+
|
| 241 |
+
# Value embeddings (init like c_v: uniform with same std)
|
| 242 |
+
for ve in self.value_embeds.values():
|
| 243 |
+
torch.nn.init.uniform_(ve.weight, -s, s)
|
| 244 |
+
|
| 245 |
+
# Gate weights init with small positive values so gates start slightly above neutral
|
| 246 |
+
for block in self.transformer.h:
|
| 247 |
+
if block.attn.ve_gate is not None:
|
| 248 |
+
torch.nn.init.uniform_(block.attn.ve_gate.weight, 0.0, 0.02)
|
| 249 |
+
|
| 250 |
+
# Rotary embeddings
|
| 251 |
+
head_dim = self.config.n_embd // self.config.n_head
|
| 252 |
+
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
| 253 |
+
self.cos, self.sin = cos, sin
|
| 254 |
+
|
| 255 |
+
# Cast embeddings to COMPUTE_DTYPE: optimizer can tolerate reduced-precision
|
| 256 |
+
# embeddings and it saves memory. Exception: fp16 requires fp32 embeddings
|
| 257 |
+
# because GradScaler cannot unscale fp16 gradients.
|
| 258 |
+
if COMPUTE_DTYPE != torch.float16:
|
| 259 |
+
self.transformer.wte.to(dtype=COMPUTE_DTYPE)
|
| 260 |
+
for ve in self.value_embeds.values():
|
| 261 |
+
ve.to(dtype=COMPUTE_DTYPE)
|
| 262 |
+
|
| 263 |
+
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=100000, device=None):
|
| 264 |
+
# TODO: bump base theta more? e.g. 100K is more common more recently
|
| 265 |
+
# autodetect the device from model embeddings
|
| 266 |
+
if device is None:
|
| 267 |
+
device = self.transformer.wte.weight.device
|
| 268 |
+
# stride the channels
|
| 269 |
+
channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
|
| 270 |
+
inv_freq = 1.0 / (base ** (channel_range / head_dim))
|
| 271 |
+
# stride the time steps
|
| 272 |
+
t = torch.arange(seq_len, dtype=torch.float32, device=device)
|
| 273 |
+
# calculate the rotation frequencies at each (time, channel) pair
|
| 274 |
+
freqs = torch.outer(t, inv_freq)
|
| 275 |
+
cos, sin = freqs.cos(), freqs.sin()
|
| 276 |
+
cos, sin = cos.to(COMPUTE_DTYPE), sin.to(COMPUTE_DTYPE)
|
| 277 |
+
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
|
| 278 |
+
return cos, sin
|
| 279 |
+
|
| 280 |
+
def _compute_window_sizes(self, config):
|
| 281 |
+
"""
|
| 282 |
+
Compute per-layer window sizes for sliding window attention.
|
| 283 |
+
|
| 284 |
+
Returns list of (left, right) tuples for FA3's window_size parameter:
|
| 285 |
+
- left: how many tokens before current position to attend to (-1 = unlimited)
|
| 286 |
+
- right: how many tokens after current position to attend to (0 for causal)
|
| 287 |
+
|
| 288 |
+
Pattern string is tiled across layers. Final layer always gets L (full context).
|
| 289 |
+
Characters: L=long (full context), S=short (quarter context)
|
| 290 |
+
"""
|
| 291 |
+
pattern = config.window_pattern.upper()
|
| 292 |
+
assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L."
|
| 293 |
+
# Map characters to window sizes
|
| 294 |
+
long_window = config.sequence_len
|
| 295 |
+
short_window = -(-long_window // 4 // 128) * 128 # ceil to FA3 tile size (2048 -> 768)
|
| 296 |
+
char_to_window = {
|
| 297 |
+
"L": (long_window, 0),
|
| 298 |
+
"S": (short_window, 0),
|
| 299 |
+
}
|
| 300 |
+
# Tile pattern across layers
|
| 301 |
+
window_sizes = []
|
| 302 |
+
for layer_idx in range(config.n_layer):
|
| 303 |
+
char = pattern[layer_idx % len(pattern)]
|
| 304 |
+
window_sizes.append(char_to_window[char])
|
| 305 |
+
# Final layer always gets full context
|
| 306 |
+
window_sizes[-1] = (long_window, 0)
|
| 307 |
+
return window_sizes
|
| 308 |
+
|
| 309 |
+
def get_device(self):
|
| 310 |
+
return self.transformer.wte.weight.device
|
| 311 |
+
|
| 312 |
+
def estimate_flops(self):
|
| 313 |
+
"""
|
| 314 |
+
Return the estimated FLOPs per token for the model (forward + backward).
|
| 315 |
+
Each matmul weight parameter contributes 2 FLOPs (multiply *, accumulate +) in forward, and 2X that in backward => 2+4=6.
|
| 316 |
+
Cleanest explanation of this: https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4
|
| 317 |
+
On top of that, 12 * h * q * effective_seq_len accounts for key @ query matmul flops inside attention.
|
| 318 |
+
With sliding windows, effective_seq_len varies per layer (capped by window size).
|
| 319 |
+
Ref: https://arxiv.org/abs/2204.02311 (PaLM paper).
|
| 320 |
+
This is ~1% off from the exact formulas of Chinchilla paper, the difference is:
|
| 321 |
+
- Chinchilla counts the embedding layer as flops (? weird, it's just a lookup => we ignore)
|
| 322 |
+
- Chinchilla counts exp/sum/divide in attention softmax as flops (a little sus and very tiny => we ignore)
|
| 323 |
+
"""
|
| 324 |
+
nparams = sum(p.numel() for p in self.parameters())
|
| 325 |
+
# Exclude non-matmul params: embeddings and per-layer scalars
|
| 326 |
+
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
|
| 327 |
+
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
|
| 328 |
+
self.resid_lambdas.numel() + self.x0_lambdas.numel() +
|
| 329 |
+
self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel())
|
| 330 |
+
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
|
| 331 |
+
# Sum attention FLOPs per layer, accounting for sliding window
|
| 332 |
+
attn_flops = 0
|
| 333 |
+
for window_size in self.window_sizes:
|
| 334 |
+
window = window_size[0] # (left, right) tuple, we use left
|
| 335 |
+
effective_seq = t if window < 0 else min(window, t)
|
| 336 |
+
attn_flops += 12 * h * q * effective_seq
|
| 337 |
+
num_flops_per_token = 6 * (nparams - nparams_exclude) + attn_flops
|
| 338 |
+
return num_flops_per_token
|
| 339 |
+
|
| 340 |
+
def num_scaling_params(self):
|
| 341 |
+
"""
|
| 342 |
+
Return detailed parameter counts for scaling law analysis.
|
| 343 |
+
Different papers use different conventions:
|
| 344 |
+
- Kaplan et al. excluded embedding parameters
|
| 345 |
+
- Chinchilla included all parameters
|
| 346 |
+
Ref: https://arxiv.org/abs/2203.15556 (Chinchilla paper)
|
| 347 |
+
Ref: https://arxiv.org/abs/2001.08361 (Kaplan et al. original scaling laws paper)
|
| 348 |
+
|
| 349 |
+
Returns a dict with counts for each parameter group, so downstream analysis
|
| 350 |
+
can experiment with which combination gives the cleanest scaling laws.
|
| 351 |
+
"""
|
| 352 |
+
        # Count each group separately (mirrors the grouping in setup_optimizer)
        wte = sum(p.numel() for p in self.transformer.wte.parameters())
        value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
        lm_head = sum(p.numel() for p in self.lm_head.parameters())
        transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
        scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel()
        total = wte + value_embeds + lm_head + transformer_matrices + scalars
        assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
        return {
            'wte': wte,
            'value_embeds': value_embeds,
            'lm_head': lm_head,
            'transformer_matrices': transformer_matrices,
            'scalars': scalars,
            'total': total,
        }

    def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5):
        model_dim = self.config.n_embd
        ddp, rank, local_rank, world_size = get_dist_info()

        # Separate out all parameters into groups
        matrix_params = list(self.transformer.h.parameters())
        value_embeds_params = list(self.value_embeds.parameters())
        embedding_params = list(self.transformer.wte.parameters())
        lm_head_params = list(self.lm_head.parameters())
        resid_params = [self.resid_lambdas]
        x0_params = [self.x0_lambdas]
        smear_params = [self.smear_gate.weight, self.smear_lambda, self.backout_lambda]
        assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params)

        # Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
        dmodel_lr_scale = (model_dim / 768) ** -0.5
        print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")

        # Build param_groups with all required fields explicit
        param_groups = [
            # AdamW groups (embeddings, lm_head, scalars)
            dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=(0.8, 0.96), eps=1e-10, weight_decay=0.01),
            dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.001),
            dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale * 0.5, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.01),
            dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.05),
            dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
            dict(kind='adamw', params=smear_params, lr=0.2, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0),
        ]
        # Muon groups (matrix params, grouped by shape for stacking)
        for shape in sorted({p.shape for p in matrix_params}):
            group_params = [p for p in matrix_params if p.shape == shape]
            param_groups.append(dict(
                kind='muon', params=group_params, lr=matrix_lr,
                momentum=0.95, ns_steps=5, beta2=0.9, weight_decay=weight_decay,
            ))

        Factory = DistMuonAdamW if ddp else MuonAdamW
        optimizer = Factory(param_groups)
        for group in optimizer.param_groups:
            group["initial_lr"] = group["lr"]
        return optimizer

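    # A quick worked instance of the ∝1/√dmodel rule in setup_optimizer above (a hypothetical
    # model_dim of 1536, not a config taken from this file):
    #   dmodel_lr_scale = (1536 / 768) ** -0.5 = 2 ** -0.5 ≈ 0.7071
    # i.e. every AdamW learning rate is multiplied by ~0.71 relative to the 768-dim reference
    # tuning, while the Muon matrix_lr is left unscaled by model width.
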
    def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
        B, T = idx.size()

        # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
        assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
        assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
        assert self.cos.dtype == COMPUTE_DTYPE, f"Rotary embeddings must be in {COMPUTE_DTYPE}, got {self.cos.dtype}"
        # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
        T0 = 0 if kv_cache is None else kv_cache.get_pos()
        cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length

        # Embed the tokens
        x = self.transformer.wte(idx) # embed current token
        x = x.to(COMPUTE_DTYPE) # ensure activations are in compute dtype (no-op usually, but active for fp16 code path)
        x = norm(x)

        # Smear: mix previous token's embedding into current position (cheap bigram info)
        if kv_cache is None:
            # Training / naive generate: full sequence available, use fast slice
            assert T > 1, "Training forward pass should have T > 1"
            gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, 1:, :24]))
            x = torch.cat([x[:, :1], x[:, 1:] + gate * x[:, :-1]], dim=1)
        else:
            # KV cache inference: read prev embedding from cache, store current for next step
            x_pre_smear = kv_cache.prev_embedding
            kv_cache.prev_embedding = x[:, -1:, :]
            if T > 1:
                # Prefill: apply smear to positions 1+, same as training
                gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, 1:, :24]))
                x = torch.cat([x[:, :1], x[:, 1:] + gate * x[:, :-1]], dim=1)
            elif x_pre_smear is not None:
                # Decode: single token, use cached prev embedding
                gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, :, :24]))
                x = x + gate * x_pre_smear

        # Forward the trunk of the Transformer
        x0 = x # save initial normalized embedding for x0 residual
        n_layer = self.config.n_layer
        backout_layer = n_layer // 2 # cache at halfway point
        x_backout = None
        for i, block in enumerate(self.transformer.h):
            x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
            ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None
            x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
            if i == backout_layer:
                x_backout = x
        # Subtract mid-layer residual to remove low-level features before logit projection
        if x_backout is not None:
            x = x - self.backout_lambda.to(x.dtype) * x_backout
        x = norm(x)

        # Forward the lm_head (compute logits)
        softcap = 15 # smoothly cap the logits to the range [-softcap, softcap]
        logits = self.lm_head(x) # (B, T, padded_vocab_size) <- very big tensor, large amount of memory
        logits = logits[..., :self.config.vocab_size] # slice to remove padding
        logits = logits.float() # switch to fp32 for logit softcap and loss computation
        logits = softcap * torch.tanh(logits / softcap) # squash the logits

        if targets is not None:
            # training: given the targets, compute and return the loss
            # TODO experiment with chunked cross-entropy?
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
            return loss
        else:
            # inference: just return the logits directly
            return logits

    @torch.inference_mode()
    def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
        """
        Naive autoregressive streaming inference.
        To make it super simple, let's assume:
        - batch size is 1
        - ids and the yielded tokens are simple Python lists and ints
        """
        assert isinstance(tokens, list)
        device = self.get_device()
        rng = None
        if temperature > 0:
            rng = torch.Generator(device=device)
            rng.manual_seed(seed)
        ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim
        for _ in range(max_tokens):
            logits = self.forward(ids) # (B, T, vocab_size)
            logits = logits[:, -1, :] # (B, vocab_size)
            if top_k is not None and top_k > 0:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            if temperature > 0:
                logits = logits / temperature
                probs = F.softmax(logits, dim=-1)
                next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
            else:
                next_ids = torch.argmax(logits, dim=-1, keepdim=True)
            ids = torch.cat((ids, next_ids), dim=1)
            token = next_ids.item()
            yield token
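
A quick standalone sketch of the logit softcap used in forward() above (illustrative values only, nothing here beyond plain torch):

    import torch
    softcap = 15.0
    logits = torch.tensor([0.5, 10.0, 50.0, -100.0])
    capped = softcap * torch.tanh(logits / softcap)
    # near-identity for small logits (0.5 -> ~0.50), smooth saturation toward +/-15
    # for large ones (50.0 -> ~14.96, -100.0 -> ~-15.0)
    print(capped)
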
nanochat/logo.svg
ADDED
nanochat/loss_eval.py
ADDED
@@ -0,0 +1,65 @@
"""
A number of functions that help with evaluating a base model.
"""
import math
import torch
import torch.distributed as dist

@torch.no_grad()
def evaluate_bpb(model, batches, steps, token_bytes):
    """
    Instead of the naive 'mean loss', this function returns the bits per byte (bpb),
    a metric independent of the tokenizer's vocab size, meaning you are still comparing
    apples to apples if you change the vocab size. Instead of calculating the average
    loss as usual, you calculate the sum loss, and independently also the sum of bytes
    (over all the target tokens), and divide. This normalizes the loss by the number
    of bytes that the target tokens represent.

    The added complexity is so that:
    1) All "normal" tokens are normalized by the length of the token in bytes
    2) No special tokens (e.g. <|bos|>) are included in the metric - they are masked out.
    3) No actively masked tokens (using ignore_index of e.g. -1) are included in the metric.

    In addition to what evaluate_loss needs, we need the token_bytes tensor:
    a 1D tensor of shape (vocab_size,), indicating the number of bytes for
    each token id, or 0 if the token is not to be counted (e.g. special tokens).
    """
    # record the losses
    total_nats = torch.tensor(0.0, dtype=torch.float32, device=model.get_device())
    total_bytes = torch.tensor(0, dtype=torch.int64, device=model.get_device())
    batch_iter = iter(batches)
    for _ in range(steps):
        x, y = next(batch_iter)
        loss2d = model(x, y, loss_reduction='none') # (B, T)
        loss2d = loss2d.view(-1) # flatten
        y = y.view(-1) # flatten
        if (y.int() < 0).any(): # mps does not currently have a kernel for < 0 on int64, only int32
            # slightly more complex code path if some target tokens are ignore_index (e.g. -1)
            # any target token < 0 is to be ignored: do NOT index token_bytes with negatives
            valid = y >= 0
            y_safe = torch.where(valid, y, torch.zeros_like(y))
            # map valid targets to their byte length; ignored targets contribute 0 bytes
            num_bytes2d = torch.where(
                valid,
                token_bytes[y_safe],
                torch.zeros_like(y, dtype=token_bytes.dtype)
            )
            total_nats += (loss2d * (num_bytes2d > 0)).sum()
            total_bytes += num_bytes2d.sum()
        else:
            # fast path: no ignored targets, safe to index directly
            num_bytes2d = token_bytes[y]
            total_nats += (loss2d * (num_bytes2d > 0)).sum()
            total_bytes += num_bytes2d.sum()
    # sum reduce across all ranks
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    if world_size > 1:
        dist.all_reduce(total_nats, op=dist.ReduceOp.SUM)
        dist.all_reduce(total_bytes, op=dist.ReduceOp.SUM)
    # move both to cpu, calculate bpb and return
    total_nats = total_nats.item()
    total_bytes = total_bytes.item()
    if total_bytes == 0:
        return float('inf')
    bpb = total_nats / (math.log(2) * total_bytes)
    return bpb
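
To make the bpb arithmetic concrete, a tiny self-contained example with made-up numbers (not taken from any real evaluation run):

    import math
    import torch
    losses = torch.tensor([2.0, 3.0, 1.5])  # per-token losses, 6.5 nats total
    token_bytes = torch.tensor([4, 6, 3])   # the three target tokens span 13 bytes
    bpb = losses.sum().item() / (math.log(2) * token_bytes.sum().item())
    # 6.5 / (0.6931 * 13) ≈ 0.72 bits per byte
    print(bpb)
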
nanochat/optim.py
ADDED
@@ -0,0 +1,533 @@
"""
A nice and efficient mixed AdamW/Muon Combined Optimizer.
Usually the embeddings and scalars go into AdamW, and the matrix parameters go into Muon.
Two versions are provided (MuonAdamW, DistMuonAdamW), for single GPU and distributed.

Adapted from: https://github.com/KellerJordan/modded-nanogpt
Further contributions from @karpathy and @chrisjmccormick.
"""

import torch
import torch.distributed as dist
from torch import Tensor

# -----------------------------------------------------------------------------
"""
Good old AdamW optimizer, fused kernel.
https://arxiv.org/abs/1711.05101
"""

@torch.compile(dynamic=False, fullgraph=True)
def adamw_step_fused(
    p: Tensor,            # (32768, 768) - parameter tensor
    grad: Tensor,         # (32768, 768) - gradient, same shape as p
    exp_avg: Tensor,      # (32768, 768) - first moment, same shape as p
    exp_avg_sq: Tensor,   # (32768, 768) - second moment, same shape as p
    step_t: Tensor,       # () - 0-D CPU tensor, step count
    lr_t: Tensor,         # () - 0-D CPU tensor, learning rate
    beta1_t: Tensor,      # () - 0-D CPU tensor, beta1
    beta2_t: Tensor,      # () - 0-D CPU tensor, beta2
    eps_t: Tensor,        # () - 0-D CPU tensor, epsilon
    wd_t: Tensor,         # () - 0-D CPU tensor, weight decay
) -> None:
    """
    Fused AdamW step: weight_decay -> momentum_update -> bias_correction -> param_update
    All in one compiled graph to eliminate Python overhead between ops.
    The 0-D CPU tensors avoid recompilation when hyperparameter values change.
    """
    # Weight decay (decoupled, applied before the update)
    p.mul_(1 - lr_t * wd_t)
    # Update running averages (lerp_ is cleaner and fuses well)
    exp_avg.lerp_(grad, 1 - beta1_t)
    exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
    # Bias corrections
    bias1 = 1 - beta1_t ** step_t
    bias2 = 1 - beta2_t ** step_t
    # Compute update and apply
    denom = (exp_avg_sq / bias2).sqrt() + eps_t
    step_size = lr_t / bias1
    p.add_(exp_avg / denom, alpha=-step_size)

# -----------------------------------------------------------------------------
"""
Muon optimizer adapted and simplified from modded-nanogpt.
https://github.com/KellerJordan/modded-nanogpt

Background:
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
zero even beyond the point where the iteration no longer converges all the way to one everywhere
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
performance at all relative to UV^T, where USV^T = G is the SVD.

Here, an alternative to Newton-Schulz iteration with potentially better convergence properties:
Polar Express Sign Method for orthogonalization.
https://arxiv.org/pdf/2505.16932
by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.

NorMuon variance reduction: per-neuron/column adaptive learning rate that normalizes
update scales after orthogonalization (Muon's output has non-uniform scales across neurons).
https://arxiv.org/pdf/2510.05491

Some of the changes in the nanochat implementation:
- Uses a simpler, more general approach to parameter grouping and stacking
- Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step
- Makes no assumptions about model architecture (e.g. that attention weights are fused into QKVO format)
"""

# Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
# From https://arxiv.org/pdf/2505.16932
polar_express_coeffs = [
    (8.156554524902461, -22.48329292557795, 15.878769915207462),
    (4.042929935166739, -2.808917465908714, 0.5000178451051316),
    (3.8916678022926607, -2.772484153217685, 0.5060648178503393),
    (3.285753657755655, -2.3681294933425376, 0.46449024233003106),
    (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
]

@torch.compile(dynamic=False, fullgraph=True)
def muon_step_fused(
    stacked_grads: Tensor,           # (12, 768, 3072) - stacked gradients
    stacked_params: Tensor,          # (12, 768, 3072) - stacked parameters
    momentum_buffer: Tensor,         # (12, 768, 3072) - first moment buffer
    second_momentum_buffer: Tensor,  # (12, 768, 1) or (12, 1, 3072) - factored second moment
    momentum_t: Tensor,              # () - 0-D CPU tensor, momentum coefficient
    lr_t: Tensor,                    # () - 0-D CPU tensor, learning rate
    wd_t: Tensor,                    # () - 0-D CPU tensor, weight decay
    beta2_t: Tensor,                 # () - 0-D CPU tensor, beta2 for second moment
    ns_steps: int,                   # 5 - number of Newton-Schulz/Polar Express iterations
    red_dim: int,                    # -1 or -2 - reduction dimension for variance
) -> None:
    """
    Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update
    All in one compiled graph to eliminate Python overhead between ops.
    Some of the constants are 0-D CPU tensors to avoid recompilation when values change.
    """

    # Nesterov momentum
    momentum = momentum_t.to(stacked_grads.dtype)
    momentum_buffer.lerp_(stacked_grads, 1 - momentum)
    g = stacked_grads.lerp_(momentum_buffer, momentum)

    # Polar express
    X = g.bfloat16()
    X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.01 + 1e-6)
    if g.size(-2) > g.size(-1): # Tall matrix
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X.mT @ X
            B = b * A + c * (A @ A)
            X = a * X + X @ B
    else: # Wide matrix (original math)
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X @ X.mT
            B = b * A + c * (A @ A)
            X = a * X + B @ X
    g = X

    # Variance reduction
    beta2 = beta2_t.to(g.dtype)
    v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
    red_dim_size = g.size(red_dim)
    v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
    v_norm = v_norm_sq.sqrt()
    second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
    step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
    scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
    v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
    final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
    g = g * final_scale.to(g.dtype)

    # Cautious weight decay + parameter update
    lr = lr_t.to(g.dtype)
    wd = wd_t.to(g.dtype)
    mask = (g * stacked_params) >= 0
    stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)

# -----------------------------------------------------------------------------
# Single GPU version of the MuonAdamW optimizer.
# Used mostly for reference, debugging and testing.

class MuonAdamW(torch.optim.Optimizer):
    """
    Combined optimizer: Muon for 2D matrix params, AdamW for others, single GPU version.

    AdamW - Fused AdamW optimizer step.

    Muon - MomentUm Orthogonalized by Newton-schulz
    https://kellerjordan.github.io/posts/muon/

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
    the advantage that it can be stably run in bfloat16 on the GPU.

    Some warnings:
    - The Muon optimizer should not be used for the embedding layer, the final fully connected layer,
      or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
    - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.

    Arguments:
        param_groups: List of dicts, each containing:
            - 'params': List of parameters
            - 'kind': 'adamw' or 'muon'
            - For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
            - For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
    """
    def __init__(self, param_groups: list[dict]):
        super().__init__(param_groups, defaults={})
        # 0-D CPU tensors to avoid torch.compile recompilation when values change
        # AdamW tensors
        self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        # Muon tensors
        self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")

    def _step_adamw(self, group: dict) -> None:
        """
        AdamW update for each param in the group individually.
        Lazy init the state, fill in all 0-D tensors, call the fused kernel.
        """
        for p in group['params']:
            if p.grad is None:
                continue
            grad = p.grad
            state = self.state[p]

            # State init
            if not state:
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
            exp_avg = state['exp_avg']
            exp_avg_sq = state['exp_avg_sq']
            state['step'] += 1

            # Fill 0-D tensors with current values
            self._adamw_step_t.fill_(state['step'])
            self._adamw_lr_t.fill_(group['lr'])
            self._adamw_beta1_t.fill_(group['betas'][0])
            self._adamw_beta2_t.fill_(group['betas'][1])
            self._adamw_eps_t.fill_(group['eps'])
            self._adamw_wd_t.fill_(group['weight_decay'])

            # Fused update: weight_decay -> momentum -> bias_correction -> param_update
            adamw_step_fused(
                p, grad, exp_avg, exp_avg_sq,
                self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
                self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
            )

    def _step_muon(self, group: dict) -> None:
        """
        Muon update for all params in the group (stacked for efficiency).
        Lazy init the state, fill in all 0-D tensors, call the fused kernel.
        """
        params: list[Tensor] = group['params']
        if not params:
            return

        # Get or create group-level buffers (stored in first param's state for convenience)
        p = params[0]
        state = self.state[p]
        num_params = len(params)
        shape, device, dtype = p.shape, p.device, p.dtype

        # Momentum for every individual parameter
        if "momentum_buffer" not in state:
            state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
        momentum_buffer = state["momentum_buffer"]

        # Second momentum buffer is factored, either per-row or per-column
        if "second_momentum_buffer" not in state:
            state_shape = (num_params, shape[-2], 1) if shape[-2] >= shape[-1] else (num_params, 1, shape[-1])
            state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
        second_momentum_buffer = state["second_momentum_buffer"]
        red_dim = -1 if shape[-2] >= shape[-1] else -2

        # Stack grads and params (NOTE: this assumes all params have the same shape)
        stacked_grads = torch.stack([p.grad for p in params])
        stacked_params = torch.stack(params)

        # Fill all the 0-D tensors with current values
        self._muon_momentum_t.fill_(group["momentum"])
        self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
        self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
        self._muon_wd_t.fill_(group["weight_decay"])

        # Single fused kernel: momentum -> polar_express -> variance_reduction -> update
        muon_step_fused(
            stacked_grads,
            stacked_params,
            momentum_buffer,
            second_momentum_buffer,
            self._muon_momentum_t,
            self._muon_lr_t,
            self._muon_wd_t,
            self._muon_beta2_t,
            group["ns_steps"],
            red_dim,
        )

        # Copy back to original params
        torch._foreach_copy_(params, list(stacked_params.unbind(0)))

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            if group['kind'] == 'adamw':
                self._step_adamw(group)
            elif group['kind'] == 'muon':
                self._step_muon(group)
            else:
                raise ValueError(f"Unknown optimizer kind: {group['kind']}")

# -----------------------------------------------------------------------------
# Distributed version of the MuonAdamW optimizer.
# Used for training on multiple GPUs.

class DistMuonAdamW(torch.optim.Optimizer):
    """
    Combined distributed optimizer: Muon for 2D matrix params, AdamW for others.

    See MuonAdamW for the algorithmic details of each optimizer. This class adds
    distributed communication to enable multi-GPU training without PyTorch DDP.

    Design Goals:
    - Overlap communication with computation (async ops)
    - Minimize memory by sharding optimizer states across ranks (ZeRO-2 style)
    - Batch small tensors into single comm ops where possible

    Communication Pattern (3-phase async):
    We use a 3-phase structure to maximize overlap between communication and compute:

    Phase 1: Launch all async reduce ops
    - Kick off all reduce_scatter/all_reduce operations
    - Don't wait - let them run in background while we continue

    Phase 2: Wait for reduces, compute updates, launch gathers
    - For each group: wait for its reduce, compute the update, launch gather
    - By processing groups in order, earlier gathers run while later computes happen

    Phase 3: Wait for gathers, copy back
    - Wait for all gathers to complete
    - Copy updated params back to original tensors (Muon only)

    AdamW Communication (ZeRO-2 style):
    - Small params (<1024 elements): all_reduce gradients, update full param on each rank.
      Optimizer state is replicated but these params are tiny (scalars, biases).
    - Large params: reduce_scatter gradients so each rank gets 1/N of the grad, update
      only that slice, then all_gather the updated slices. Optimizer state (exp_avg,
      exp_avg_sq) is sharded - each rank only stores state for its slice.
      Requires param.shape[0] divisible by world_size.

    Muon Communication (stacked + chunked):
    - All params in a Muon group must have the same shape (caller's responsibility).
    - Stack all K params into a single (K, *shape) tensor for efficient comm.
    - Divide K params across N ranks: each rank "owns" ceil(K/N) params.
    - reduce_scatter the stacked grads so each rank gets its chunk.
    - Each rank computes the Muon update only for params it owns.
    - all_gather the updated params back to all ranks.
    - Optimizer state (momentum_buffer, second_momentum_buffer) is sharded by chunk.
    - Padding: if K doesn't divide evenly, we zero-pad to (ceil(K/N) * N) for comm,
      then ignore the padding when copying back.

    Buffer Reuse:
    - For Muon, we allocate stacked_grads for reduce_scatter input, then reuse the
      same buffer as the output for all_gather (stacked_params). This saves memory
      since we don't need both buffers simultaneously.

    Arguments:
        param_groups: List of dicts, each containing:
            - 'params': List of parameters
            - 'kind': 'adamw' or 'muon'
            - For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
            - For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
    """
    def __init__(self, param_groups: list[dict]):
        super().__init__(param_groups, defaults={})
        # 0-D CPU tensors to avoid torch.compile recompilation when values change
        self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")

    def _reduce_adamw(self, group: dict, world_size: int) -> dict:
        """Launch async reduce ops for AdamW group. Returns info dict with per-param infos."""
        param_infos = {}
        for p in group['params']:
            grad = p.grad
            if p.numel() < 1024:
                # Small params: all_reduce (no scatter/gather needed)
                future = dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
                param_infos[p] = dict(future=future, grad_slice=grad, is_small=True)
            else:
                # Large params: reduce_scatter
                assert grad.shape[0] % world_size == 0, f"AdamW reduce_scatter requires shape[0] ({grad.shape[0]}) divisible by world_size ({world_size})"
                rank_size = grad.shape[0] // world_size
                grad_slice = torch.empty_like(grad[:rank_size])
                future = dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
                param_infos[p] = dict(future=future, grad_slice=grad_slice, is_small=False)
        return dict(param_infos=param_infos)

    def _reduce_muon(self, group: dict, world_size: int) -> dict:
        """Launch async reduce op for Muon group. Returns info dict."""
        params = group['params']
        chunk_size = (len(params) + world_size - 1) // world_size
        padded_num_params = chunk_size * world_size
        p = params[0]
        shape, device, dtype = p.shape, p.device, p.dtype

        # Stack grads and zero-pad to padded_num_params
        grad_stack = torch.stack([p.grad for p in params])
        stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device)
        stacked_grads[:len(params)].copy_(grad_stack)
        if len(params) < padded_num_params:
            stacked_grads[len(params):].zero_()

        # Reduce_scatter to get this rank's chunk
        grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
        future = dist.reduce_scatter_tensor(grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True).get_future()

        return dict(future=future, grad_chunk=grad_chunk, stacked_grads=stacked_grads, chunk_size=chunk_size)

    def _compute_adamw(self, group: dict, info: dict, gather_list: list, rank: int, world_size: int) -> None:
        """Wait for reduce, compute AdamW updates, launch gathers for large params."""
        param_infos = info['param_infos']
        for p in group['params']:
            pinfo = param_infos[p]
            pinfo['future'].wait()
            grad_slice = pinfo['grad_slice']
            state = self.state[p]

            # For small params, operate on full param; for large, operate on slice
            if pinfo['is_small']:
                p_slice = p
            else:
                rank_size = p.shape[0] // world_size
                p_slice = p[rank * rank_size:(rank + 1) * rank_size]

            # State init
            if not state:
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p_slice)
                state['exp_avg_sq'] = torch.zeros_like(p_slice)
            state['step'] += 1

            # Fill 0-D tensors and run fused kernel
            self._adamw_step_t.fill_(state['step'])
            self._adamw_lr_t.fill_(group['lr'])
            self._adamw_beta1_t.fill_(group['betas'][0])
            self._adamw_beta2_t.fill_(group['betas'][1])
            self._adamw_eps_t.fill_(group['eps'])
            self._adamw_wd_t.fill_(group['weight_decay'])
            adamw_step_fused(
                p_slice, grad_slice, state['exp_avg'], state['exp_avg_sq'],
                self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
                self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
            )

            # Large params need all_gather
            if not pinfo['is_small']:
                future = dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()
                gather_list.append(dict(future=future, params=None))

    def _compute_muon(self, group: dict, info: dict, gather_list: list, rank: int) -> None:
        """Wait for reduce, compute Muon updates, launch gather."""
        info['future'].wait()
        params = group['params']
        chunk_size = info['chunk_size']
        grad_chunk = info['grad_chunk']
        p = params[0]
        shape, device, dtype = p.shape, p.device, p.dtype

        # How many params does this rank own?
        start_idx = rank * chunk_size
        num_owned = min(chunk_size, max(0, len(params) - start_idx))

        # Get or create group-level state
        state = self.state[p]
        if "momentum_buffer" not in state:
            state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
        if "second_momentum_buffer" not in state:
            state_shape = (chunk_size, shape[-2], 1) if shape[-2] >= shape[-1] else (chunk_size, 1, shape[-1])
            state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
        red_dim = -1 if shape[-2] >= shape[-1] else -2

        # Build output buffer for all_gather
        updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device)

        if num_owned > 0:
            owned_params = [params[start_idx + i] for i in range(num_owned)]
            stacked_owned = torch.stack(owned_params)

            # Fill 0-D tensors and run fused kernel
            self._muon_momentum_t.fill_(group["momentum"])
            self._muon_beta2_t.fill_(group["beta2"])
            self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
            self._muon_wd_t.fill_(group["weight_decay"])
            muon_step_fused(
                grad_chunk[:num_owned], stacked_owned,
                state["momentum_buffer"][:num_owned], state["second_momentum_buffer"][:num_owned],
                self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t, self._muon_beta2_t,
                group["ns_steps"], red_dim,
            )
            updated_params[:num_owned].copy_(stacked_owned)

        if num_owned < chunk_size:
            updated_params[num_owned:].zero_()

        # Reuse stacked_grads buffer for all_gather output
        stacked_params = info["stacked_grads"]
        future = dist.all_gather_into_tensor(stacked_params, updated_params, async_op=True).get_future()
        gather_list.append(dict(future=future, stacked_params=stacked_params, params=params))

    def _finish_gathers(self, gather_list: list) -> None:
        """Wait for all gathers and copy Muon params back."""
        for info in gather_list:
            info["future"].wait()
            if info["params"] is not None:
                # Muon: copy from stacked buffer back to individual params
                torch._foreach_copy_(info["params"], list(info["stacked_params"][:len(info["params"])].unbind(0)))

    @torch.no_grad()
    def step(self):
        rank = dist.get_rank()
        world_size = dist.get_world_size()

        # Phase 1: launch all async reduce ops
        reduce_infos: list[dict] = []
        for group in self.param_groups:
            if group['kind'] == 'adamw':
                reduce_infos.append(self._reduce_adamw(group, world_size))
            elif group['kind'] == 'muon':
                reduce_infos.append(self._reduce_muon(group, world_size))
            else:
                raise ValueError(f"Unknown optimizer kind: {group['kind']}")

        # Phase 2: wait for reduces, compute updates, launch gathers
        gather_list: list[dict] = []
        for group, info in zip(self.param_groups, reduce_infos):
            if group['kind'] == 'adamw':
                self._compute_adamw(group, info, gather_list, rank, world_size)
            elif group['kind'] == 'muon':
                self._compute_muon(group, info, gather_list, rank)
            else:
                raise ValueError(f"Unknown optimizer kind: {group['kind']}")

        # Phase 3: wait for gathers, copy back
        self._finish_gathers(gather_list)
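
For orientation, a minimal single-GPU sketch of driving MuonAdamW directly (toy shapes and hyperparameters chosen here for illustration; the real groups are built in gpt.py's setup_optimizer above):

    import torch
    from nanochat.optim import MuonAdamW

    # two same-shape matrices go to Muon (they get stacked), an embedding goes to AdamW
    W1 = torch.nn.Parameter(torch.randn(64, 256))
    W2 = torch.nn.Parameter(torch.randn(64, 256))
    emb = torch.nn.Parameter(torch.randn(1000, 64))
    opt = MuonAdamW([
        dict(kind='adamw', params=[emb], lr=0.2, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.0),
        dict(kind='muon', params=[W1, W2], lr=0.02, momentum=0.95, ns_steps=5, beta2=0.9, weight_decay=0.0),
    ])
    loss = W1.sum() + W2.sum() + emb.sum()  # stand-in for a real training loss
    loss.backward()
    opt.step()
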
nanochat/report.py
ADDED
@@ -0,0 +1,418 @@
"""
Utilities for generating training report cards. More messy code than usual, will fix.
"""

import os
import re
import shutil
import subprocess
import socket
import datetime
import platform
import psutil
import torch

def run_command(cmd):
    """Run a shell command and return output, or None if it fails."""
    try:
        result = subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=5)
        # Return stdout if we got output (even if some files in xargs failed)
        if result.stdout.strip():
            return result.stdout.strip()
        if result.returncode == 0:
            return ""
        return None
    except Exception:
        return None

def get_git_info():
    """Get current git commit, branch, and dirty status."""
    info = {}
    info['commit'] = run_command("git rev-parse --short HEAD") or "unknown"
    info['branch'] = run_command("git rev-parse --abbrev-ref HEAD") or "unknown"

    # Check if repo is dirty (has uncommitted changes)
    status = run_command("git status --porcelain")
    info['dirty'] = bool(status) if status is not None else False

    # Get commit message
    info['message'] = run_command("git log -1 --pretty=%B") or ""
    info['message'] = info['message'].split('\n')[0][:80] # First line, truncated

    return info

def get_gpu_info():
    """Get GPU information."""
    if not torch.cuda.is_available():
        return {"available": False}

    num_devices = torch.cuda.device_count()
    info = {
        "available": True,
        "count": num_devices,
        "names": [],
        "memory_gb": []
    }

    for i in range(num_devices):
        props = torch.cuda.get_device_properties(i)
        info["names"].append(props.name)
        info["memory_gb"].append(props.total_memory / (1024**3))

    # Get CUDA version
    info["cuda_version"] = torch.version.cuda or "unknown"

    return info

def get_system_info():
    """Get system information."""
    info = {}

    # Basic system info
    info['hostname'] = socket.gethostname()
    info['platform'] = platform.system()
    info['python_version'] = platform.python_version()
    info['torch_version'] = torch.__version__

    # CPU and memory
    info['cpu_count'] = psutil.cpu_count(logical=False)
    info['cpu_count_logical'] = psutil.cpu_count(logical=True)
    info['memory_gb'] = psutil.virtual_memory().total / (1024**3)

    # User and environment
    info['user'] = os.environ.get('USER', 'unknown')
    info['nanochat_base_dir'] = os.environ.get('NANOCHAT_BASE_DIR', 'out')
    info['working_dir'] = os.getcwd()

    return info

def estimate_cost(gpu_info, runtime_hours=None):
    """Estimate training cost based on GPU type and runtime."""

    # Rough pricing, from Lambda Cloud
    default_rate = 2.0
    gpu_hourly_rates = {
        "H100": 3.00,
        "A100": 1.79,
        "V100": 0.55,
    }

    if not gpu_info.get("available"):
        return None

    # Try to identify GPU type from name
    hourly_rate = None
    gpu_name = gpu_info["names"][0] if gpu_info["names"] else "unknown"
    for gpu_type, rate in gpu_hourly_rates.items():
        if gpu_type in gpu_name:
            hourly_rate = rate * gpu_info["count"]
            break

    if hourly_rate is None:
        hourly_rate = default_rate * gpu_info["count"] # Default estimate

    return {
        "hourly_rate": hourly_rate,
        "gpu_type": gpu_name,
        "estimated_total": hourly_rate * runtime_hours if runtime_hours else None
    }

def generate_header():
    """Generate the header for a training report."""
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    git_info = get_git_info()
    gpu_info = get_gpu_info()
    sys_info = get_system_info()
    cost_info = estimate_cost(gpu_info)

    header = f"""# nanochat training report

Generated: {timestamp}

## Environment

### Git Information
- Branch: {git_info['branch']}
- Commit: {git_info['commit']} {"(dirty)" if git_info['dirty'] else "(clean)"}
- Message: {git_info['message']}

### Hardware
- Platform: {sys_info['platform']}
- CPUs: {sys_info['cpu_count']} cores ({sys_info['cpu_count_logical']} logical)
- Memory: {sys_info['memory_gb']:.1f} GB
"""

    if gpu_info.get("available"):
        gpu_names = ", ".join(set(gpu_info["names"]))
        total_vram = sum(gpu_info["memory_gb"])
        header += f"""- GPUs: {gpu_info['count']}x {gpu_names}
- GPU Memory: {total_vram:.1f} GB total
- CUDA Version: {gpu_info['cuda_version']}
"""
    else:
        header += "- GPUs: None available\n"

    if cost_info and cost_info["hourly_rate"] > 0:
        header += f"""- Hourly Rate: ${cost_info['hourly_rate']:.2f}/hour\n"""

    header += f"""
### Software
- Python: {sys_info['python_version']}
- PyTorch: {sys_info['torch_version']}

"""

    # bloat metrics: count lines/chars in git-tracked source files only
    extensions = ['py', 'md', 'rs', 'html', 'toml', 'sh']
    git_patterns = ' '.join(f"'*.{ext}'" for ext in extensions)
    files_output = run_command(f"git ls-files -- {git_patterns}")
    file_list = [f for f in (files_output or '').split('\n') if f]
    num_files = len(file_list)
    num_lines = 0
    num_chars = 0
    if num_files > 0:
        wc_output = run_command(f"git ls-files -- {git_patterns} | xargs wc -lc 2>/dev/null")
        if wc_output:
            total_line = wc_output.strip().split('\n')[-1]
            parts = total_line.split()
            if len(parts) >= 2:
                num_lines = int(parts[0])
                num_chars = int(parts[1])
    num_tokens = num_chars // 4 # assume approximately 4 chars per token

    # count dependencies via uv.lock
    uv_lock_lines = 0
    if os.path.exists('uv.lock'):
        with open('uv.lock', 'r', encoding='utf-8') as f:
            uv_lock_lines = len(f.readlines())

    header += f"""
### Bloat
- Characters: {num_chars:,}
- Lines: {num_lines:,}
- Files: {num_files:,}
- Tokens (approx): {num_tokens:,}
- Dependencies (uv.lock lines): {uv_lock_lines:,}

"""
    return header

# -----------------------------------------------------------------------------

def slugify(text):
    """Slugify a text string."""
    return text.lower().replace(" ", "-")

# the expected files and their order
EXPECTED_FILES = [
    "tokenizer-training.md",
    "tokenizer-evaluation.md",
    "base-model-training.md",
    "base-model-loss.md",
    "base-model-evaluation.md",
    "chat-sft.md",
    "chat-evaluation-sft.md",
    "chat-rl.md",
    "chat-evaluation-rl.md",
]
# the metrics we're currently interested in
chat_metrics = ["ARC-Easy", "ARC-Challenge", "MMLU", "GSM8K", "HumanEval", "ChatCORE"]

def extract(section, keys):
    """Simple helper to extract the given keys from a section."""
    if not isinstance(keys, list):
        keys = [keys] # convenience
    out = {}
    for line in section.split("\n"):
        for key in keys:
            if key in line:
                out[key] = line.split(":")[1].strip()
    return out

def extract_timestamp(content, prefix):
    """Extract timestamp from content with given prefix."""
    for line in content.split('\n'):
        if line.startswith(prefix):
            time_str = line.split(":", 1)[1].strip()
            try:
                return datetime.datetime.strptime(time_str, "%Y-%m-%d %H:%M:%S")
            except Exception:
                pass
    return None

class Report:
    """Maintains a bunch of logs, generates a final markdown report."""

    def __init__(self, report_dir):
        os.makedirs(report_dir, exist_ok=True)
        self.report_dir = report_dir

    def log(self, section, data):
        """Log a section of data to the report."""
        slug = slugify(section)
        file_name = f"{slug}.md"
        file_path = os.path.join(self.report_dir, file_name)
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(f"## {section}\n")
            f.write(f"timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
            for item in data:
                if not item:
                    # skip falsy values like None or empty dict etc.
                    continue
                if isinstance(item, str):
                    # directly write the string
                    f.write(item)
                else:
                    # render a dict
                    for k, v in item.items():
                        if isinstance(v, float):
                            vstr = f"{v:.4f}"
                        elif isinstance(v, int) and v >= 10000:
                            vstr = f"{v:,.0f}"
                        else:
                            vstr = str(v)
                        f.write(f"- {k}: {vstr}\n")
                f.write("\n")
        return file_path

    def generate(self):
        """Generate the final report."""
        report_dir = self.report_dir
        report_file = os.path.join(report_dir, "report.md")
        print(f"Generating report to {report_file}")
        final_metrics = {} # the most important final metrics we'll add as a table at the end
        start_time = None
        end_time = None
        with open(report_file, "w", encoding="utf-8") as out_file:
            # write the header first
            header_file = os.path.join(report_dir, "header.md")
            if os.path.exists(header_file):
                with open(header_file, "r", encoding="utf-8") as f:
                    header_content = f.read()
                out_file.write(header_content)
                start_time = extract_timestamp(header_content, "Run started:")
                # capture bloat data for the summary later (the stuff after the Bloat header and until \n\n)
                bloat_data = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL)
                bloat_data = bloat_data.group(1) if bloat_data else ""
            else:
                start_time = None # will cause us to not write the total wall clock time
                bloat_data = "[bloat data missing]"
                print(f"Warning: {header_file} does not exist. Did you forget to run `nanochat reset`?")
            # process all the individual sections
            for file_name in EXPECTED_FILES:
                section_file = os.path.join(report_dir, file_name)
                if not os.path.exists(section_file):
                    print(f"Warning: {section_file} does not exist, skipping")
                    continue
                with open(section_file, "r", encoding="utf-8") as in_file:
                    section = in_file.read()
                # Extract timestamp from this section (the last section's timestamp will "stick" as end_time)
                if "rl" not in file_name:
                    # Skip RL sections for the end_time calculation because RL is experimental
                    end_time = extract_timestamp(section, "timestamp:")
                # extract the most important metrics from the sections
                if file_name == "base-model-evaluation.md":
                    final_metrics["base"] = extract(section, "CORE")
                if file_name == "chat-evaluation-sft.md":
                    final_metrics["sft"] = extract(section, chat_metrics)
                if file_name == "chat-evaluation-rl.md":
                    final_metrics["rl"] = extract(section, "GSM8K") # RL only evals GSM8K
                # append this section of the report
                out_file.write(section)
                out_file.write("\n")
            # add the final metrics table
            out_file.write("## Summary\n\n")
            # Copy over the bloat metrics from the header
            out_file.write(bloat_data)
            out_file.write("\n\n")
            # Collect all unique metric names
            all_metrics = set()
            for stage_metrics in final_metrics.values():
                all_metrics.update(stage_metrics.keys())
            # Custom ordering: CORE first, ChatCORE last, rest in middle
            all_metrics = sorted(all_metrics, key=lambda x: (x != "CORE", x == "ChatCORE", x))
            # Fixed column widths
            stages = ["base", "sft", "rl"]
            metric_width = 15
            value_width = 8
            # Write table header
            header = f"| {'Metric'.ljust(metric_width)} |"
            for stage in stages:
                header += f" {stage.upper().ljust(value_width)} |"
            out_file.write(header + "\n")
            # Write separator
            separator = f"|{'-' * (metric_width + 2)}|"
            for stage in stages:
                separator += f"{'-' * (value_width + 2)}|"
            out_file.write(separator + "\n")
            # Write table rows
            for metric in all_metrics:
                row = f"| {metric.ljust(metric_width)} |"
                for stage in stages:
                    value = final_metrics.get(stage, {}).get(metric, "-")
                    row += f" {str(value).ljust(value_width)} |"
                out_file.write(row + "\n")
            out_file.write("\n")
            # Calculate and write total wall clock time
            if start_time and end_time:
                duration = end_time - start_time
                total_seconds = int(duration.total_seconds())
                hours = total_seconds // 3600
                minutes = (total_seconds % 3600) // 60
                out_file.write(f"Total wall clock time: {hours}h{minutes}m\n")
            else:
                out_file.write("Total wall clock time: unknown\n")
        # also cp the report.md file to the current directory
        print("Copying report.md to current directory for convenience")
        shutil.copy(report_file, "report.md")
        return report_file

    def reset(self):
        """Reset the report."""
        # Remove section files
        for file_name in EXPECTED_FILES:
            file_path = os.path.join(self.report_dir, file_name)
            if os.path.exists(file_path):
                os.remove(file_path)
        # Remove report.md if it exists
        report_file = os.path.join(self.report_dir, "report.md")
        if os.path.exists(report_file):
            os.remove(report_file)
        # Generate and write the header section with start timestamp
        header_file = os.path.join(self.report_dir, "header.md")
        header = generate_header()
        start_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        with open(header_file, "w", encoding="utf-8") as f:
            f.write(header)
            f.write(f"Run started: {start_time}\n\n---\n\n")
        print(f"Reset report and wrote header to {header_file}")

# -----------------------------------------------------------------------------
# nanochat-specific convenience functions

class DummyReport:
    def log(self, *args, **kwargs):
        pass
    def reset(self, *args, **kwargs):
        pass

def get_report():
    # just for convenience, only rank 0 logs to report
    from nanochat.common import get_base_dir, get_dist_info
|
| 403 |
+
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
|
| 404 |
+
if ddp_rank == 0:
|
| 405 |
+
report_dir = os.path.join(get_base_dir(), "report")
|
| 406 |
+
return Report(report_dir)
|
| 407 |
+
else:
|
| 408 |
+
return DummyReport()
|
| 409 |
+
|
| 410 |
+
if __name__ == "__main__":
|
| 411 |
+
import argparse
|
| 412 |
+
parser = argparse.ArgumentParser(description="Generate or reset nanochat training reports.")
|
| 413 |
+
parser.add_argument("command", nargs="?", default="generate", choices=["generate", "reset"], help="Operation to perform (default: generate)")
|
| 414 |
+
args = parser.parse_args()
|
| 415 |
+
if args.command == "generate":
|
| 416 |
+
get_report().generate()
|
| 417 |
+
elif args.command == "reset":
|
| 418 |
+
get_report().reset()
|
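An editor's aside on generate() above: the Summary table's sort key, (x != "CORE", x == "ChatCORE", x), works because Python orders False before True, so CORE pins to the front of the table, ChatCORE to the back, and the remaining metrics sort alphabetically in between. A quick REPL check (the metric names here are illustrative only):

    >>> sorted(["GSM8K", "ChatCORE", "CORE", "ARC-Easy"],
    ...        key=lambda x: (x != "CORE", x == "ChatCORE", x))
    ['CORE', 'ARC-Easy', 'GSM8K', 'ChatCORE']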
nanochat/tokenizer.py ADDED
@@ -0,0 +1,406 @@
"""
BPE Tokenizer in the style of GPT-4.

Two implementations are available:
1) HuggingFace Tokenizer that can do both training and inference but is really confusing
2) Our own RustBPE Tokenizer for training and tiktoken for efficient inference
"""

import os
import copy
from functools import lru_cache

SPECIAL_TOKENS = [
    # every document begins with the Beginning of Sequence (BOS) token that delimits documents
    "<|bos|>",
    # tokens below are only used during finetuning to render Conversations into token ids
    "<|user_start|>", # user messages
    "<|user_end|>",
    "<|assistant_start|>", # assistant messages
    "<|assistant_end|>",
    "<|python_start|>", # assistant invokes python REPL tool
    "<|python_end|>",
    "<|output_start|>", # python REPL outputs back to assistant
    "<|output_end|>",
]

# NOTE: this split pattern deviates from GPT-4 in that we use \p{N}{1,2} instead of \p{N}{1,3}
# I did this because I didn't want to "waste" too many tokens on numbers for smaller vocab sizes.
# I verified that 2 is the sweet spot for vocab size of 32K. 1 is a bit worse, 3 was worse still.
SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
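# Editor's aside (not part of the original file): the \p{N}{1,2} rule caps digit runs at two
# per pre-token, so BPE never has to allocate merges for 3-digit number chunks. With the
# `regex` package (which supports \p{...} classes and possessive quantifiers) the split can
# be inspected directly; per the editor's check:
#
#   import regex
#   regex.findall(SPLIT_PATTERN, "In 2025 we trained!")
#   # -> ['In', ' ', '20', '25', ' we', ' trained', '!']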
# -----------------------------------------------------------------------------
# Generic GPT-4-style tokenizer based on HuggingFace Tokenizer
from tokenizers import Tokenizer as HFTokenizer
from tokenizers import pre_tokenizers, decoders, Regex
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer

class HuggingFaceTokenizer:
    """Light wrapper around HuggingFace Tokenizer for some utilities"""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    @classmethod
    def from_pretrained(cls, hf_path):
        # init from a HuggingFace pretrained tokenizer (e.g. "gpt2")
        tokenizer = HFTokenizer.from_pretrained(hf_path)
        return cls(tokenizer)

    @classmethod
    def from_directory(cls, tokenizer_dir):
        # init from a local directory on disk (e.g. "out/tokenizer")
        tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
        tokenizer = HFTokenizer.from_file(tokenizer_path)
        return cls(tokenizer)

    @classmethod
    def train_from_iterator(cls, text_iterator, vocab_size):
        # train from an iterator of text
        # Configure the HuggingFace Tokenizer
        tokenizer = HFTokenizer(BPE(
            byte_fallback=True, # needed!
            unk_token=None,
            fuse_unk=False,
        ))
        # Normalizer: None
        tokenizer.normalizer = None
        # Pre-tokenizer: GPT-4 style
        # the regex pattern used by GPT-4 to split text into groups before BPE
        # NOTE: The pattern was changed from \p{N}{1,3} to \p{N}{1,2} because I suspect it is harmful to
        # very small models and smaller vocab sizes, because it is a little bit wasteful in the token space.
        # (but I haven't validated this! TODO)
        gpt4_split_regex = Regex(SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!!
        tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
            pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False),
            pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
        ])
        # Decoder: ByteLevel (it pairs together with the ByteLevel pre-tokenizer)
        tokenizer.decoder = decoders.ByteLevel()
        # Post-processor: None
        tokenizer.post_processor = None
        # Trainer: BPE
        trainer = BpeTrainer(
            vocab_size=vocab_size,
            show_progress=True,
            min_frequency=0, # no minimum frequency
            initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
            special_tokens=SPECIAL_TOKENS,
        )
        # Kick off the training
        tokenizer.train_from_iterator(text_iterator, trainer)
        return cls(tokenizer)

    def get_vocab_size(self):
        return self.tokenizer.get_vocab_size()

    def get_special_tokens(self):
        special_tokens_map = self.tokenizer.get_added_tokens_decoder()
        special_tokens = [w.content for w in special_tokens_map.values()]
        return special_tokens

    def id_to_token(self, id):
        return self.tokenizer.id_to_token(id)

    def _encode_one(self, text, prepend=None, append=None, num_threads=None):
        # encode a single string
        # prepend/append can be either a string of a special token or a token id directly.
        # num_threads is ignored (only used by the nanochat Tokenizer for parallel encoding)
        assert isinstance(text, str)
        ids = []
        if prepend is not None:
            prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend)
            ids.append(prepend_id)
        ids.extend(self.tokenizer.encode(text, add_special_tokens=False).ids)
        if append is not None:
            append_id = append if isinstance(append, int) else self.encode_special(append)
            ids.append(append_id)
        return ids

    def encode_special(self, text):
        # encode a single special token via exact match
        return self.tokenizer.token_to_id(text)

    def get_bos_token_id(self):
        # Different HuggingFace models use different BOS tokens and there is little consistency
        # 1) attempt to find a <|bos|> token
        bos = self.encode_special("<|bos|>")
        # 2) if that fails, attempt to find a <|endoftext|> token (e.g. GPT-2 models)
        if bos is None:
            bos = self.encode_special("<|endoftext|>")
        # 3) if these fail, it's better to crash than to silently return None
        assert bos is not None, "Failed to find BOS token in tokenizer"
        return bos

    def encode(self, text, *args, **kwargs):
        if isinstance(text, str):
            return self._encode_one(text, *args, **kwargs)
        elif isinstance(text, list):
            return [self._encode_one(t, *args, **kwargs) for t in text]
        else:
            raise ValueError(f"Invalid input type: {type(text)}")

    def __call__(self, *args, **kwargs):
        return self.encode(*args, **kwargs)

    def decode(self, ids):
        return self.tokenizer.decode(ids, skip_special_tokens=False)

    def save(self, tokenizer_dir):
        # save the tokenizer to disk
        os.makedirs(tokenizer_dir, exist_ok=True)
        tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
        self.tokenizer.save(tokenizer_path)
        print(f"Saved tokenizer to {tokenizer_path}")

# -----------------------------------------------------------------------------
# Tokenizer based on rustbpe + tiktoken combo
import pickle
import rustbpe
import tiktoken

class RustBPETokenizer:
    """Light wrapper around tiktoken (for efficient inference) but train with rustbpe"""

    def __init__(self, enc, bos_token):
        self.enc = enc
        self.bos_token_id = self.encode_special(bos_token)

    @classmethod
    def train_from_iterator(cls, text_iterator, vocab_size):
        # 1) train using rustbpe
        tokenizer = rustbpe.Tokenizer()
        # the special tokens are inserted later in __init__, we don't train them here
        vocab_size_no_special = vocab_size - len(SPECIAL_TOKENS)
        assert vocab_size_no_special >= 256, f"vocab_size_no_special must be at least 256, got {vocab_size_no_special}"
        tokenizer.train_from_iterator(text_iterator, vocab_size_no_special, pattern=SPLIT_PATTERN)
        # 2) construct the associated tiktoken encoding for inference
        pattern = tokenizer.get_pattern()
        mergeable_ranks_list = tokenizer.get_mergeable_ranks()
        mergeable_ranks = {bytes(k): v for k, v in mergeable_ranks_list}
        tokens_offset = len(mergeable_ranks)
        special_tokens = {name: tokens_offset + i for i, name in enumerate(SPECIAL_TOKENS)}
        enc = tiktoken.Encoding(
            name="rustbpe",
            pat_str=pattern,
            mergeable_ranks=mergeable_ranks, # dict[bytes, int] (token bytes -> merge priority rank)
            special_tokens=special_tokens, # dict[str, int] (special token name -> token id)
        )
        return cls(enc, "<|bos|>")

    @classmethod
    def from_directory(cls, tokenizer_dir):
        pickle_path = os.path.join(tokenizer_dir, "tokenizer.pkl")
        with open(pickle_path, "rb") as f:
            enc = pickle.load(f)
        return cls(enc, "<|bos|>")

    @classmethod
    def from_pretrained(cls, tiktoken_name):
        # https://github.com/openai/tiktoken/blob/eedc8563/tiktoken_ext/openai_public.py
        enc = tiktoken.get_encoding(tiktoken_name)
        # tiktoken calls the special document delimiter token "<|endoftext|>"
        # yes this is confusing because this token is almost always PREPENDED to the beginning of the document
        # it most often is used to signal the start of a new sequence to the LLM during inference etc.
        # so in nanoChat we always use "<|bos|>" short for "beginning of sequence", but historically it is often called "<|endoftext|>".
        return cls(enc, "<|endoftext|>")

    def get_vocab_size(self):
        return self.enc.n_vocab

    def get_special_tokens(self):
        return self.enc.special_tokens_set

    def id_to_token(self, id):
        return self.enc.decode([id])

    @lru_cache(maxsize=32)
    def encode_special(self, text):
        return self.enc.encode_single_token(text)

    def get_bos_token_id(self):
        return self.bos_token_id

    def encode(self, text, prepend=None, append=None, num_threads=8):
        # text can be either a string or a list of strings

        if prepend is not None:
            prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend)
        if append is not None:
            append_id = append if isinstance(append, int) else self.encode_special(append)

        if isinstance(text, str):
            ids = self.enc.encode_ordinary(text)
            if prepend is not None:
                ids.insert(0, prepend_id) # TODO: slightly inefficient here? :( hmm
            if append is not None:
                ids.append(append_id)
        elif isinstance(text, list):
            ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads)
            if prepend is not None:
                for ids_row in ids:
                    ids_row.insert(0, prepend_id) # TODO: same
            if append is not None:
                for ids_row in ids:
                    ids_row.append(append_id)
        else:
            raise ValueError(f"Invalid input type: {type(text)}")

        return ids

    def __call__(self, *args, **kwargs):
        return self.encode(*args, **kwargs)

    def decode(self, ids):
        return self.enc.decode(ids)

    def save(self, tokenizer_dir):
        # save the encoding object to disk
        os.makedirs(tokenizer_dir, exist_ok=True)
        pickle_path = os.path.join(tokenizer_dir, "tokenizer.pkl")
        with open(pickle_path, "wb") as f:
            pickle.dump(self.enc, f)
        print(f"Saved tokenizer encoding to {pickle_path}")

    def render_conversation(self, conversation, max_tokens=2048):
        """
        Tokenize a single Chat conversation (which we call a "doc" or "document" here).
        Returns:
        - ids: list[int] is a list of token ids of this rendered conversation
        - mask: list[int] of same length, mask = 1 for tokens that the Assistant is expected to train on.
        """
        # ids, masks that we will return and a helper function to help build them up.
        ids, mask = [], []
        def add_tokens(token_ids, mask_val):
            if isinstance(token_ids, int):
                token_ids = [token_ids]
            ids.extend(token_ids)
            mask.extend([mask_val] * len(token_ids))

        # sometimes the first message is a system message...
        # => just merge it with the second (user) message
        if conversation["messages"][0]["role"] == "system":
            # some conversation surgery is necessary here for now...
            conversation = copy.deepcopy(conversation) # avoid mutating the original
            messages = conversation["messages"]
            assert messages[1]["role"] == "user", "System message must be followed by a user message"
            messages[1]["content"] = messages[0]["content"] + "\n\n" + messages[1]["content"]
            messages = messages[1:]
        else:
            messages = conversation["messages"]
        assert len(messages) >= 1, f"Conversation has less than 1 message: {messages}"

        # fetch all the special tokens we need
        bos = self.get_bos_token_id()
        user_start, user_end = self.encode_special("<|user_start|>"), self.encode_special("<|user_end|>")
        assistant_start, assistant_end = self.encode_special("<|assistant_start|>"), self.encode_special("<|assistant_end|>")
        python_start, python_end = self.encode_special("<|python_start|>"), self.encode_special("<|python_end|>")
        output_start, output_end = self.encode_special("<|output_start|>"), self.encode_special("<|output_end|>")

        # now we can tokenize the conversation
        add_tokens(bos, 0)
        for i, message in enumerate(messages):

            # some sanity checking here around assumptions, to prevent footguns
            must_be_from = "user" if i % 2 == 0 else "assistant"
            assert message["role"] == must_be_from, f"Message {i} is from {message['role']} but should be from {must_be_from}"

            # content can be either a simple string or a list of parts (e.g. containing tool calls)
            content = message["content"]

            if message["role"] == "user":
                assert isinstance(content, str), "User messages are simply expected to be strings"
                value_ids = self.encode(content)
                add_tokens(user_start, 0)
                add_tokens(value_ids, 0)
                add_tokens(user_end, 0)
            elif message["role"] == "assistant":
                add_tokens(assistant_start, 0)
                if isinstance(content, str):
                    # simple string => simply add the tokens
                    value_ids = self.encode(content)
                    add_tokens(value_ids, 1)
                elif isinstance(content, list):
                    for part in content:
                        value_ids = self.encode(part["text"])
                        if part["type"] == "text":
                            # string part => simply add the tokens
                            add_tokens(value_ids, 1)
                        elif part["type"] == "python":
                            # python tool call => add the tokens inside <|python_start|> and <|python_end|>
                            add_tokens(python_start, 1)
                            add_tokens(value_ids, 1)
                            add_tokens(python_end, 1)
                        elif part["type"] == "python_output":
                            # python output => add the tokens inside <|output_start|> and <|output_end|>
                            # none of these tokens are supervised because the tokens come from Python at test time
                            add_tokens(output_start, 0)
                            add_tokens(value_ids, 0)
                            add_tokens(output_end, 0)
                        else:
                            raise ValueError(f"Unknown part type: {part['type']}")
                else:
                    raise ValueError(f"Unknown content type: {type(content)}")
                add_tokens(assistant_end, 1)

        # truncate to max_tokens tokens MAX (helps prevent OOMs)
        ids = ids[:max_tokens]
        mask = mask[:max_tokens]
        return ids, mask
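    # Editor's aside (not part of the original file): a minimal usage sketch of
    # render_conversation, assuming a tokenizer has already been trained and saved to disk:
    #
    #   tok = get_tokenizer()
    #   conv = {"messages": [
    #       {"role": "user", "content": "What is 2+2?"},
    #       {"role": "assistant", "content": "4"},
    #   ]}
    #   ids, mask = tok.render_conversation(conv)
    #   print(tok.visualize_tokenization(ids, mask))
    #
    # Only assistant-authored tokens (plus <|assistant_end|>) carry mask=1 and contribute to
    # the SFT loss; BOS, user text, and python REPL outputs are rendered with mask=0.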
    def visualize_tokenization(self, ids, mask, with_token_id=False):
        """Small helper function useful in debugging: visualize the tokenization of render_conversation"""
        RED = '\033[91m'
        GREEN = '\033[92m'
        RESET = '\033[0m'
        GRAY = '\033[90m'
        tokens = []
        for i, (token_id, mask_val) in enumerate(zip(ids, mask)):
            token_str = self.decode([token_id])
            color = GREEN if mask_val == 1 else RED
            tokens.append(f"{color}{token_str}{RESET}")
            if with_token_id:
                tokens.append(f"{GRAY}({token_id}){RESET}")
        return '|'.join(tokens)

    def render_for_completion(self, conversation):
        """
        Used during Reinforcement Learning. In that setting, we want to
        render the conversation priming the Assistant for a completion.
        Unlike the Chat SFT case, we don't need to return the mask.
        """
        # We have some surgery to do: we need to pop the last message (of the Assistant)
        conversation = copy.deepcopy(conversation) # avoid mutating the original
        messages = conversation["messages"]
        assert messages[-1]["role"] == "assistant", "Last message must be from the Assistant"
        messages.pop() # remove the last message (of the Assistant) inplace

        # Now tokenize the conversation
        ids, mask = self.render_conversation(conversation)

        # Finally, to prime the Assistant for a completion, append the Assistant start token
        assistant_start = self.encode_special("<|assistant_start|>")
        ids.append(assistant_start)
        return ids

# -----------------------------------------------------------------------------
# nanochat-specific convenience functions

def get_tokenizer():
    from nanochat.common import get_base_dir
    base_dir = get_base_dir()
    tokenizer_dir = os.path.join(base_dir, "tokenizer")
    # return HuggingFaceTokenizer.from_directory(tokenizer_dir)
    return RustBPETokenizer.from_directory(tokenizer_dir)

def get_token_bytes(device="cpu"):
    import torch
    from nanochat.common import get_base_dir
    base_dir = get_base_dir()
    tokenizer_dir = os.path.join(base_dir, "tokenizer")
    token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
    assert os.path.exists(token_bytes_path), f"Token bytes not found at {token_bytes_path}? It gets written by tok_train.py"
    with open(token_bytes_path, "rb") as f:
        token_bytes = torch.load(f, map_location=device)
    return token_bytes
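A brief usage sketch for the convenience functions above (editor's note; assumes tok_train.py has already written a tokenizer into the base directory):

    from nanochat.tokenizer import get_tokenizer

    tok = get_tokenizer()
    ids = tok.encode("Hello world", prepend="<|bos|>")  # BOS id, then the text's token ids
    assert tok.decode(ids[1:]) == "Hello world"         # byte-level BPE round-trips exactly

The token_bytes tensor returned by get_token_bytes maps each token id to the UTF-8 byte length of its string, which is what the bits-per-byte evaluation normalizes by.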
nanochat/ui.html ADDED
@@ -0,0 +1,566 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0, viewport-fit=cover">
    <title>NanoChat</title>
    <link rel="icon" type="image/svg+xml" href="/logo.svg">
    <style>
        :root {
            color-scheme: light;
        }

        * {
            box-sizing: border-box;
        }

        html, body {
            height: 100%;
            margin: 0;
        }

        body {
            font-family: ui-sans-serif, -apple-system, system-ui, "Segoe UI", Helvetica, "Apple Color Emoji", Arial, sans-serif, "Segoe UI Emoji", "Segoe UI Symbol";
            background-color: #ffffff;
            color: #111827;
            min-height: 100dvh;
            margin: 0;
            display: flex;
            flex-direction: column;
        }

        .header {
            background-color: #ffffff;
            padding: 1.25rem 1.5rem;
        }

        .header-left {
            display: flex;
            align-items: center;
            gap: 0.75rem;
        }

        .header-logo {
            height: 32px;
            width: auto;
        }

        .header h1 {
            font-size: 1.25rem;
            font-weight: 600;
            margin: 0;
            color: #111827;
        }

        .new-conversation-btn {
            width: 32px;
            height: 32px;
            padding: 0;
            border: 1px solid #e5e7eb;
            border-radius: 0.5rem;
            background-color: #ffffff;
            color: #6b7280;
            cursor: pointer;
            display: flex;
            align-items: center;
            justify-content: center;
            transition: all 0.2s ease;
        }

        .new-conversation-btn:hover {
            background-color: #f3f4f6;
            border-color: #d1d5db;
            color: #374151;
        }

        .chat-container {
            flex: 1;
            overflow-y: auto;
            background-color: #ffffff;
        }

        .chat-wrapper {
            max-width: 48rem;
            margin: 0 auto;
            padding: 2rem 1.5rem 3rem;
            display: flex;
            flex-direction: column;
            gap: 0.75rem;
        }

        .message {
            display: flex;
            justify-content: flex-start;
            margin-bottom: 0.5rem;
            color: #0d0d0d;
        }

        .message.assistant {
            justify-content: flex-start;
        }

        .message.user {
            justify-content: flex-end;
        }

        .message-content {
            white-space: pre-wrap;
            line-height: 1.6;
            max-width: 100%;
        }

        .message.assistant .message-content {
            background: transparent;
            border: none;
            cursor: pointer;
            border-radius: 0.5rem;
            padding: 0.5rem;
            margin-left: -0.5rem;
            transition: background-color 0.2s ease;
        }

        .message.assistant .message-content:hover {
            background-color: #f9fafb;
        }

        .message.user .message-content {
            background-color: #f3f4f6;
            border-radius: 1.25rem;
            padding: 0.8rem 1rem;
            max-width: 65%;
            cursor: pointer;
            transition: background-color 0.2s ease;
        }

        .message.user .message-content:hover {
            background-color: #e5e7eb;
        }

        .message.console .message-content {
            font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', 'Consolas', 'Courier New', monospace;
            font-size: 0.875rem;
            background-color: #fafafa;
            padding: 0.75rem 1rem;
            color: #374151;
            max-width: 80%;
        }

        .input-container {
            background-color: #ffffff;
            padding: 1rem;
            padding-bottom: calc(1rem + env(safe-area-inset-bottom));
        }

        .input-wrapper {
            max-width: 48rem;
            margin: 0 auto;
            display: flex;
            gap: 0.75rem;
            align-items: flex-end;
        }

        .chat-input {
            flex: 1;
            padding: 0.8rem 1rem;
            border: 1px solid #d1d5db;
            border-radius: 0.75rem;
            background-color: #ffffff;
            color: #111827;
            font-size: 1rem;
            line-height: 1.5;
            resize: none;
            outline: none;
            min-height: 54px;
            max-height: 200px;
            transition: border-color 0.2s ease, box-shadow 0.2s ease;
        }

        .chat-input::placeholder {
            color: #9ca3af;
        }

        .chat-input:focus {
            border-color: #2563eb;
            box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.1);
        }

        .send-button {
            flex-shrink: 0;
            padding: 0;
            width: 54px;
            height: 54px;
            border: 1px solid #111827;
            border-radius: 0.75rem;
            background-color: #111827;
            color: #ffffff;
            display: flex;
            align-items: center;
            justify-content: center;
            cursor: pointer;
            transition: background-color 0.2s ease, border-color 0.2s ease, color 0.2s ease;
        }

        .send-button:hover:not(:disabled) {
            background-color: #2563eb;
            border-color: #2563eb;
        }

        .send-button:disabled {
            cursor: not-allowed;
            border-color: #d1d5db;
            background-color: #e5e7eb;
            color: #9ca3af;
        }

        .typing-indicator {
            display: inline-block;
            color: #6b7280;
            letter-spacing: 0.15em;
        }

        .typing-indicator::after {
            content: '···';
            animation: typing 1.4s infinite;
        }

        @keyframes typing {
            0%, 60%, 100% { opacity: 0.2; }
            30% { opacity: 1; }
        }

        .error-message {
            background-color: #fee2e2;
            border: 1px solid #fecaca;
            color: #b91c1c;
            padding: 0.75rem 1rem;
            border-radius: 0.75rem;
            margin-top: 0.5rem;
        }
    </style>
</head>
<body>
    <div class="header">
        <div class="header-left">
            <button class="new-conversation-btn" onclick="newConversation()" title="New Conversation (Ctrl+Shift+N)">
                <svg width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
                    <path d="M12 5v14"></path>
                    <path d="M5 12h14"></path>
                </svg>
            </button>
            <h1>nanochat</h1>
        </div>
    </div>

    <div class="chat-container" id="chatContainer">
        <div class="chat-wrapper" id="chatWrapper">
            <!-- Messages will be added here -->
        </div>
    </div>

    <div class="input-container">
        <div class="input-wrapper">
            <textarea
                id="chatInput"
                class="chat-input"
                placeholder="Ask anything"
                rows="1"
                onkeydown="handleKeyDown(event)"
            ></textarea>
            <button id="sendButton" class="send-button" onclick="sendMessage()" disabled>
                <svg width="22" height="22" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
                    <path d="M22 2L11 13"></path>
                    <path d="M22 2l-7 20-4-9-9-4 20-7z"></path>
                </svg>
            </button>
        </div>
    </div>

    <script>
        const API_URL = '';
        const chatContainer = document.getElementById('chatContainer');
        const chatWrapper = document.getElementById('chatWrapper');
        const chatInput = document.getElementById('chatInput');
        const sendButton = document.getElementById('sendButton');

        let messages = [];
        let isGenerating = false;
        let currentTemperature = 0.8;
        let currentTopK = 50;

        chatInput.addEventListener('input', function() {
            this.style.height = 'auto';
            this.style.height = Math.min(this.scrollHeight, 200) + 'px';
            sendButton.disabled = !this.value.trim() || isGenerating;
        });

        function handleKeyDown(event) {
            if (event.key === 'Enter' && !event.shiftKey) {
                event.preventDefault();
                sendMessage();
            }
        }

        document.addEventListener('keydown', function(event) {
            // Ctrl+Shift+N for new conversation
            if (event.ctrlKey && event.shiftKey && event.key === 'N') {
                event.preventDefault();
                if (!isGenerating) {
                    newConversation();
                }
            }
        });

        function newConversation() {
            messages = [];
            chatWrapper.innerHTML = '';
            chatInput.value = '';
            chatInput.style.height = 'auto';
            sendButton.disabled = false;
            isGenerating = false;
            chatInput.focus();
        }

        function addMessage(role, content, messageIndex = null) {
            const messageDiv = document.createElement('div');
            messageDiv.className = `message ${role}`;

            const contentDiv = document.createElement('div');
            contentDiv.className = 'message-content';
            contentDiv.textContent = content;

            // Add click handler for user messages to enable editing
            if (role === 'user' && messageIndex !== null) {
                contentDiv.setAttribute('data-message-index', messageIndex);
                contentDiv.setAttribute('title', 'Click to edit and restart from here');
                contentDiv.addEventListener('click', function() {
                    if (!isGenerating) {
                        editMessage(messageIndex);
                    }
                });
            }

            // Add click handler for assistant messages to enable regeneration
            if (role === 'assistant' && messageIndex !== null) {
                contentDiv.setAttribute('data-message-index', messageIndex);
                contentDiv.setAttribute('title', 'Click to regenerate this response');
                contentDiv.addEventListener('click', function() {
                    if (!isGenerating) {
                        regenerateMessage(messageIndex);
                    }
                });
            }

            messageDiv.appendChild(contentDiv);
            chatWrapper.appendChild(messageDiv);

            chatContainer.scrollTop = chatContainer.scrollHeight;
            return contentDiv;
        }

        function editMessage(messageIndex) {
            // Find the message in the messages array
            if (messageIndex < 0 || messageIndex >= messages.length) return;

            const messageToEdit = messages[messageIndex];
            if (messageToEdit.role !== 'user') return;

            // Copy message content to input
            chatInput.value = messageToEdit.content;
            chatInput.style.height = 'auto';
            chatInput.style.height = Math.min(chatInput.scrollHeight, 200) + 'px';

            // Remove this message and all subsequent messages from the array
            messages = messages.slice(0, messageIndex);

            // Remove message elements from DOM starting from messageIndex
            const allMessages = chatWrapper.querySelectorAll('.message');
            for (let i = messageIndex; i < allMessages.length; i++) {
                allMessages[i].remove();
            }

            // Enable send button and focus input
            sendButton.disabled = false;
            chatInput.focus();
        }

        async function generateAssistantResponse() {
            isGenerating = true;
            sendButton.disabled = true;

            const assistantContent = addMessage('assistant', '');
            assistantContent.innerHTML = '<span class="typing-indicator"></span>';

            try {
                const response = await fetch(`${API_URL}/chat/completions`, {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json',
                    },
                    body: JSON.stringify({
                        messages: messages,
                        temperature: currentTemperature,
                        top_k: currentTopK,
                        max_tokens: 512
                    }),
                });

                if (!response.ok) {
                    throw new Error(`HTTP error! status: ${response.status}`);
                }

                const reader = response.body.getReader();
                const decoder = new TextDecoder();
                let fullResponse = '';
                assistantContent.textContent = '';

                while (true) {
                    const { done, value } = await reader.read();
                    if (done) break;

                    const chunk = decoder.decode(value);
                    const lines = chunk.split('\n');

                    for (const line of lines) {
                        if (line.startsWith('data: ')) {
                            try {
                                const data = JSON.parse(line.slice(6));
                                if (data.token) {
                                    fullResponse += data.token;
                                    assistantContent.textContent = fullResponse;
                                    chatContainer.scrollTop = chatContainer.scrollHeight;
                                }
                            } catch (e) {
                            }
                        }
                    }
                }

                const assistantMessageIndex = messages.length;
                messages.push({ role: 'assistant', content: fullResponse });

                // Add click handler to regenerate this assistant message
                assistantContent.setAttribute('data-message-index', assistantMessageIndex);
                assistantContent.setAttribute('title', 'Click to regenerate this response');
                assistantContent.addEventListener('click', function() {
                    if (!isGenerating) {
                        regenerateMessage(assistantMessageIndex);
                    }
                });

            } catch (error) {
                console.error('Error:', error);
                assistantContent.innerHTML = `<div class="error-message">Error: ${error.message}</div>`;
            } finally {
                isGenerating = false;
                sendButton.disabled = !chatInput.value.trim();
            }
        }

        async function regenerateMessage(messageIndex) {
            // Find the message in the messages array
            if (messageIndex < 0 || messageIndex >= messages.length) return;

            const messageToRegenerate = messages[messageIndex];
            if (messageToRegenerate.role !== 'assistant') return;

            // Remove this message and all subsequent messages from the array
            messages = messages.slice(0, messageIndex);

            // Remove message elements from DOM starting from messageIndex
            const allMessages = chatWrapper.querySelectorAll('.message');
            for (let i = messageIndex; i < allMessages.length; i++) {
                allMessages[i].remove();
            }

            // Regenerate the assistant response
            await generateAssistantResponse();
        }

        function handleSlashCommand(command) {
            const parts = command.trim().split(/\s+/);
            const cmd = parts[0].toLowerCase();
            const arg = parts[1];

            if (cmd === '/temperature') {
                if (arg === undefined) {
                    addMessage('console', `Current temperature: ${currentTemperature}`);
                } else {
                    const temp = parseFloat(arg);
                    if (isNaN(temp) || temp < 0 || temp > 2) {
                        addMessage('console', 'Invalid temperature. Must be between 0.0 and 2.0');
                    } else {
                        currentTemperature = temp;
                        addMessage('console', `Temperature set to ${currentTemperature}`);
                    }
                }
                return true;
            } else if (cmd === '/topk') {
                if (arg === undefined) {
                    addMessage('console', `Current top-k: ${currentTopK}`);
                } else {
                    const topk = parseInt(arg);
                    if (isNaN(topk) || topk < 1 || topk > 200) {
                        addMessage('console', 'Invalid top-k. Must be between 1 and 200');
                    } else {
                        currentTopK = topk;
                        addMessage('console', `Top-k set to ${currentTopK}`);
                    }
                }
                return true;
            } else if (cmd === '/clear') {
                newConversation();
                return true;
            } else if (cmd === '/help') {
                addMessage('console',
                    'Available commands:\n' +
                    '/temperature - Show current temperature\n' +
                    '/temperature <value> - Set temperature (0.0-2.0)\n' +
                    '/topk - Show current top-k\n' +
                    '/topk <value> - Set top-k (1-200)\n' +
                    '/clear - Clear conversation\n' +
                    '/help - Show this help message'
                );
                return true;
            }
            return false;
        }

        async function sendMessage() {
            const message = chatInput.value.trim();
            if (!message || isGenerating) return;

            // Handle slash commands
            if (message.startsWith('/')) {
                chatInput.value = '';
                chatInput.style.height = 'auto';
                handleSlashCommand(message);
                return;
            }

            chatInput.value = '';
            chatInput.style.height = 'auto';

            const userMessageIndex = messages.length;
            messages.push({ role: 'user', content: message });
            addMessage('user', message, userMessageIndex);

            await generateAssistantResponse();
        }

        sendButton.disabled = false;

        // Autofocus the chat input on page load
        chatInput.focus();

        fetch(`${API_URL}/health`)
            .then(response => response.json())
            .then(data => {
                console.log('Engine status:', data);
            })
            .catch(error => {
                console.error('Engine not available:', error);
                chatWrapper.innerHTML = '<div class="error-message">Engine not running. Please start engine.py first.</div>';
            });
    </script>
</body>
</html>
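A note on the streaming protocol the page assumes (editor's aside; the server side lives in the serving script, not in this file): generateAssistantResponse() POSTs the running messages array to /chat/completions and reads the response body as a server-sent-events-style stream, appending data.token from every line that starts with `data: `. A hypothetical chunk on the wire would look like:

    data: {"token": "Hi"}
    data: {"token": " there"}
    data: {"token": "!"}

A line split across network chunks fails JSON.parse and is silently dropped by the empty catch block rather than aborting the stream.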
pyproject.toml ADDED
@@ -0,0 +1,74 @@
[project]
name = "nanochat"
version = "0.1.0"
description = "the minimal full-stack ChatGPT clone"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
    "datasets>=4.0.0",
    "fastapi>=0.117.1",
    "ipykernel>=7.1.0",
    "kernels>=0.11.7",
    "matplotlib>=3.10.8",
    "psutil>=7.1.0",
    "python-dotenv>=1.2.1",
    "regex>=2025.9.1",
    "rustbpe>=0.1.0",
    "scipy>=1.15.3",
    "setuptools>=80.9.0",
    "tabulate>=0.9.0",
    "tiktoken>=0.11.0",
    "tokenizers>=0.22.0",
    "torch==2.9.1",
    "transformers>=4.57.3",
    "uvicorn>=0.36.0",
    "wandb>=0.21.3",
    "zstandard>=0.25.0",
]

[dependency-groups]
dev = [
    "pytest>=8.0.0",
]

[tool.pytest.ini_options]
markers = [
    "slow: marks tests as slow (deselect with '-m \"not slow\"')",
]
testpaths = ["tests"]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]

# target torch to cuda 12.8 or CPU
[tool.uv.sources]
torch = [
    { index = "pytorch-cpu", extra = "cpu" },
    { index = "pytorch-cu128", extra = "gpu" },
]

[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true

[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true

[project.optional-dependencies]
cpu = [
    "torch==2.9.1",
]
gpu = [
    "torch==2.9.1",
]

[tool.uv]
conflicts = [
    [
        { extra = "cpu" },
        { extra = "gpu" },
    ],
]
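A short note on the torch pinning above (editor's aside): the `cpu` and `gpu` extras are declared mutually exclusive via `[tool.uv] conflicts`, and `[tool.uv.sources]` routes `torch` to the matching PyTorch wheel index. In practice that means choosing exactly one extra at sync time, e.g. `uv sync --extra gpu` to pull the cu128 wheels, or `uv sync --extra cpu` for the CPU-only build.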
scripts/__pycache__/base_eval.cpython-310.pyc ADDED
Binary file (11 kB).

scripts/__pycache__/base_train.cpython-310.pyc ADDED
Binary file (17.9 kB).

scripts/__pycache__/chat_eval.cpython-310.pyc ADDED
Binary file (7.52 kB).

scripts/__pycache__/chat_sft.cpython-310.pyc ADDED
Binary file (14.6 kB).

scripts/__pycache__/tok_eval.cpython-310.pyc ADDED
Binary file (10.5 kB).

scripts/__pycache__/tok_train.cpython-310.pyc ADDED
Binary file (2.98 kB).
scripts/base_eval.py ADDED
@@ -0,0 +1,323 @@
"""
Unified evaluation script for base models.

Supports three evaluation modes (comma-separated):
--eval core   : CORE metric (accuracy on ICL tasks)
--eval bpb    : Bits per byte on train/val splits
--eval sample : Generate samples from the model

Default is all three: --eval core,bpb,sample

Examples:

# Evaluate a HuggingFace model (e.g. GPT-2 124M) using 8 GPUs
torchrun --nproc_per_node=8 -m scripts.base_eval --hf-path openai-community/gpt2

# Evaluate a nanochat model (e.g. d24) using 8 GPUs
torchrun --nproc_per_node=8 -m scripts.base_eval --model-tag d24 --device-batch-size=16

# Quick/approximate evaluation using a single GPU
python -m scripts.base_eval --model-tag d24 --device-batch-size=16 --max-per-task=100 --split-tokens=524288
"""
import os
import csv
import time
import json
import yaml
import shutil
import random
import zipfile
import tempfile
import argparse
import torch

from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock
from nanochat.tokenizer import HuggingFaceTokenizer, get_token_bytes
from nanochat.checkpoint_manager import load_model
from nanochat.core_eval import evaluate_task
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine

# -----------------------------------------------------------------------------
# HuggingFace loading utilities

class ModelWrapper:
    """Lightweight wrapper to give HuggingFace models a nanochat-compatible interface."""
    def __init__(self, model, max_seq_len=None):
        self.model = model
        self.max_seq_len = max_seq_len

    def __call__(self, input_ids, targets=None, loss_reduction='mean'):
        logits = self.model(input_ids).logits
        if targets is None:
            return logits
        loss = torch.nn.functional.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.view(-1),
            ignore_index=-1,
            reduction=loss_reduction
        )
        return loss

    def get_device(self):
        return next(self.model.parameters()).device


def load_hf_model(hf_path: str, device):
    """Load a HuggingFace model and tokenizer."""
    print0(f"Loading HuggingFace model from: {hf_path}")
    from transformers import AutoModelForCausalLM
    model = AutoModelForCausalLM.from_pretrained(hf_path)
    model.to(device)
    model.eval()
    max_seq_len = 1024 if "gpt2" in hf_path else None
    model = ModelWrapper(model, max_seq_len=max_seq_len)
    tokenizer = HuggingFaceTokenizer.from_pretrained(hf_path)
    return model, tokenizer


def get_hf_token_bytes(tokenizer, device="cpu"):
    """Compute token_bytes tensor for a HuggingFace tokenizer."""
    vocab_size = tokenizer.tokenizer.get_vocab_size()
    token_bytes = torch.zeros(vocab_size, dtype=torch.int64, device=device)
    for token_id in range(vocab_size):
        token_str = tokenizer.tokenizer.decode([token_id])
        token_bytes[token_id] = len(token_str.encode('utf-8'))
    return token_bytes

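# Note (illustrative, not part of the original logic): token_bytes is what turns a
# token-level cross entropy into tokenizer-independent bits per byte. A minimal
# sketch of the conversion, assuming per-token losses in nats and the token ids of
# the evaluated text:
#
#   import math
#   def bpb_sketch(losses, token_ids, token_bytes):
#       total_nats = sum(losses)  # sum of -log p(token)
#       total_bytes = sum(token_bytes[t].item() for t in token_ids)
#       return total_nats / (math.log(2) * total_bytes)  # nats -> bits, then per byte
#
# evaluate_bpb (imported above from nanochat.loss_eval) is the real batched,
# distributed version of this computation.
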
# -----------------------------------------------------------------------------
# CORE evaluation

EVAL_BUNDLE_URL = "https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip"


def place_eval_bundle(file_path):
    """Unzip eval_bundle.zip and place it in the base directory."""
    base_dir = get_base_dir()
    eval_bundle_dir = os.path.join(base_dir, "eval_bundle")
    with tempfile.TemporaryDirectory() as tmpdir:
        with zipfile.ZipFile(file_path, 'r') as zip_ref:
            zip_ref.extractall(tmpdir)
        extracted_bundle_dir = os.path.join(tmpdir, "eval_bundle")
        shutil.move(extracted_bundle_dir, eval_bundle_dir)
    print0(f"Placed eval_bundle directory at {eval_bundle_dir}")


def evaluate_core(model, tokenizer, device, max_per_task=-1):
    """
    Evaluate a base model on the CORE benchmark.
    Returns dict with results, centered_results, and core_metric.
    """
    base_dir = get_base_dir()
    eval_bundle_dir = os.path.join(base_dir, "eval_bundle")
    # Download the eval bundle if needed
    if not os.path.exists(eval_bundle_dir):
        download_file_with_lock(EVAL_BUNDLE_URL, "eval_bundle.zip", postprocess_fn=place_eval_bundle)

    config_path = os.path.join(eval_bundle_dir, "core.yaml")
    data_base_path = os.path.join(eval_bundle_dir, "eval_data")
    eval_meta_data = os.path.join(eval_bundle_dir, "eval_meta_data.csv")

    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    tasks = config['icl_tasks']

    # Load random baseline values
    random_baselines = {}
    with open(eval_meta_data, 'r', encoding='utf-8') as f:
        reader = csv.DictReader(f)
        for row in reader:
            task_name = row['Eval Task']
            random_baseline = row['Random baseline']
            random_baselines[task_name] = float(random_baseline)

    # Evaluate each task
    results = {}
    centered_results = {}
    for task in tasks:
        start_time = time.time()
        label = task['label']
        task_meta = {
            'task_type': task['icl_task_type'],
            'dataset_uri': task['dataset_uri'],
            'num_fewshot': task['num_fewshot'][0],
            'continuation_delimiter': task.get('continuation_delimiter', ' ')
        }
        print0(f"Evaluating: {label} ({task_meta['num_fewshot']}-shot, type: {task_meta['task_type']})... ", end='')

        data_path = os.path.join(data_base_path, task_meta['dataset_uri'])
        with open(data_path, 'r', encoding='utf-8') as f:
            data = [json.loads(line.strip()) for line in f]

        # Shuffle for consistent subsampling when using max_per_task
        shuffle_rng = random.Random(1337)
        shuffle_rng.shuffle(data)
        if max_per_task > 0:
            data = data[:max_per_task]

        accuracy = evaluate_task(model, tokenizer, data, device, task_meta)
        results[label] = accuracy
        random_baseline = random_baselines[label]
        centered_result = (accuracy - 0.01 * random_baseline) / (1.0 - 0.01 * random_baseline)
        centered_results[label] = centered_result
        elapsed = time.time() - start_time
        print0(f"accuracy: {accuracy:.4f} | centered: {centered_result:.4f} | time: {elapsed:.2f}s")

    core_metric = sum(centered_results.values()) / len(centered_results)
    out = {
        "results": results,
        "centered_results": centered_results,
        "core_metric": core_metric
    }
    return out

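# Worked example of the centering above (illustrative numbers, not from a real run):
# a 4-way multiple-choice task has a random baseline of 25 (the CSV stores percent,
# hence the 0.01 factor), so a raw accuracy of 0.60 centers to
# (0.60 - 0.25) / (1 - 0.25) ≈ 0.467. Random guessing centers to 0.0 and a perfect
# score to 1.0, which makes the per-task scores comparable before they are averaged
# into the single CORE metric.
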
# -----------------------------------------------------------------------------
# Main

def main():
    parser = argparse.ArgumentParser(description="Base model evaluation")
    parser.add_argument('--eval', type=str, default='core,bpb,sample', help='Comma-separated evaluations to run: core,bpb,sample (default: all)')
    parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path (e.g. openai-community/gpt2-xl)')
    parser.add_argument('--model-tag', type=str, default=None, help='nanochat model tag to identify the checkpoint directory')
    parser.add_argument('--step', type=int, default=None, help='Model step to load (default = last)')
    parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per CORE task (-1 = all)')
    parser.add_argument('--device-batch-size', type=int, default=32, help='Per-device batch size for BPB evaluation')
    parser.add_argument('--split-tokens', type=int, default=40*524288, help='Number of tokens to evaluate per split for BPB')
    parser.add_argument('--device-type', type=str, default='', help='cuda|cpu|mps (empty = autodetect)')
    args = parser.parse_args()

    # Parse evaluation modes
    eval_modes = set(mode.strip() for mode in args.eval.split(','))
    valid_modes = {'core', 'bpb', 'sample'}
    invalid = eval_modes - valid_modes
    if invalid:
        parser.error(f"Invalid eval modes: {invalid}. Valid: {valid_modes}")

    # Distributed / precision setup
    device_type = autodetect_device_type() if args.device_type == '' else args.device_type
    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
    # Load model and tokenizer
    is_hf_model = args.hf_path is not None
    if is_hf_model:
        model, tokenizer = load_hf_model(args.hf_path, device)
        sequence_len = model.max_seq_len or 1024
        token_bytes = get_hf_token_bytes(tokenizer, device=device)
        model_name = args.hf_path
        model_slug = args.hf_path.replace("/", "-")
    else:
        model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.step)
        sequence_len = meta["model_config"]["sequence_len"]
        token_bytes = get_token_bytes(device=device)
        model_name = f"base_model (step {meta['step']})"
        model_slug = f"base_model_{meta['step']:06d}"

    print0(f"Evaluating model: {model_name}")
    print0(f"Eval modes: {', '.join(sorted(eval_modes))}")

    # Results to log
    core_results = None
    bpb_results = {}
    samples = []
    unconditioned_samples = []

    # --- Sampling ---
    if 'sample' in eval_modes and not is_hf_model:
        print0("\n" + "="*80)
        print0("Model Samples")
        print0("="*80)
        if ddp_rank == 0:
            prompts = [
                "The capital of France is",
                "The chemical symbol of gold is",
                "If yesterday was Friday, then tomorrow will be",
                "The opposite of hot is",
                "The planets of the solar system are:",
                "My favorite color is",
                "If 5*x + 3 = 13, then x is",
            ]
            engine = Engine(model, tokenizer)
            print0("\nConditioned samples:")
            for prompt in prompts:
                tokens = tokenizer(prompt, prepend="<|bos|>")
                sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
                sample_str = tokenizer.decode(sample[0])
                print0("-" * 80)
                print0(sample_str)
                samples.append(sample_str)

            print0("\nUnconditioned samples:")
            tokens = tokenizer("", prepend="<|bos|>")
            uncond, _ = engine.generate_batch(tokens, num_samples=8, max_tokens=128, temperature=1.0)
            for sample in uncond:
                sample_str = tokenizer.decode(sample)
                print0("-" * 80)
                print0(sample_str)
                unconditioned_samples.append(sample_str)
    elif 'sample' in eval_modes and is_hf_model:
        print0("\nSkipping sampling for HuggingFace models (not supported)")

    # --- BPB evaluation ---
    if 'bpb' in eval_modes:
        print0("\n" + "="*80)
        print0("BPB Evaluation")
        print0("="*80)
        tokens_per_step = args.device_batch_size * sequence_len * ddp_world_size
        if args.split_tokens % tokens_per_step != 0:
            # Adjust to nearest multiple
            args.split_tokens = (args.split_tokens // tokens_per_step) * tokens_per_step
            print0(f"Adjusted split_tokens to {args.split_tokens} (must be divisible by {tokens_per_step})")
        steps = args.split_tokens // tokens_per_step

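        # Illustrative arithmetic (assumed defaults, not from a real run): with
        # --device-batch-size=32, a 2048-token sequence length and 8 ranks,
        # tokens_per_step = 32 * 2048 * 8 = 524,288, so the default split_tokens of
        # 40 * 524,288 divides evenly into steps = 40. A hand-picked
        # --split-tokens=1000000 would instead be floored to 524,288 (steps = 1).
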
        for split_name in ["train", "val"]:
            loader = tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, sequence_len, split_name, device=device)
            bpb = evaluate_bpb(model, loader, steps, token_bytes)
            bpb_results[split_name] = bpb
            print0(f"{split_name} bpb: {bpb:.6f}")

    # --- CORE evaluation ---
    if 'core' in eval_modes:
        print0("\n" + "="*80)
        print0("CORE Evaluation")
        print0("="*80)
        core_results = evaluate_core(model, tokenizer, device, max_per_task=args.max_per_task)

        # Write CSV output
        if ddp_rank == 0:
            base_dir = get_base_dir()
            output_csv_path = os.path.join(base_dir, "base_eval", f"{model_slug}.csv")
            os.makedirs(os.path.dirname(output_csv_path), exist_ok=True)
            with open(output_csv_path, 'w', encoding='utf-8', newline='') as f:
                f.write(f"{'Task':<35}, {'Accuracy':<10}, {'Centered':<10}\n")
                for label in core_results["results"]:
                    acc = core_results["results"][label]
                    centered = core_results["centered_results"][label]
                    f.write(f"{label:<35}, {acc:<10.6f}, {centered:<10.6f}\n")
                f.write(f"{'CORE':<35}, {'':<10}, {core_results['core_metric']:<10.6f}\n")
            print0(f"\nResults written to: {output_csv_path}")
            print0(f"CORE metric: {core_results['core_metric']:.4f}")

    # --- Log to report ---
    from nanochat.report import get_report
    report_data = [{"model": model_name}]

    if core_results:
        report_data[0]["CORE metric"] = core_results["core_metric"]
        report_data.append(core_results["centered_results"])

    if bpb_results:
        report_data[0]["train bpb"] = bpb_results.get("train")
        report_data[0]["val bpb"] = bpb_results.get("val")

    if samples:
        report_data.append({f"sample {i}": s for i, s in enumerate(samples)})
    if unconditioned_samples:
        report_data.append({f"unconditioned {i}": s for i, s in enumerate(unconditioned_samples)})

    get_report().log(section="Base model evaluation", data=report_data)

    compute_cleanup()


if __name__ == "__main__":
    main()
scripts/base_train.py
ADDED
@@ -0,0 +1,629 @@
"""
Train model. From root directory of the project, run as:

python -m scripts.base_train

or distributed as:

torchrun --nproc_per_node=8 -m scripts.base_train

If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20
"""

import os
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
import gc
import json
import time
import math
import argparse
from dataclasses import asdict
from contextlib import contextmanager

import wandb
import torch
import torch.distributed as dist

from nanochat.gpt import GPT, GPTConfig, Linear
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized
from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
from nanochat.flash_attention import HAS_FA3
from scripts.base_eval import evaluate_core
print_banner()

# -----------------------------------------------------------------------------
# CLI arguments
parser = argparse.ArgumentParser(description="Pretrain base model")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
# FP8 training
parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao)")
parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)")
# Model architecture
parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")
parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
# Training horizon (only one used, in order of precedence)
parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
parser.add_argument("--target-param-data-ratio", type=float, default=10.5, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
# Optimization
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size. good number to reduce to 16,8,4,... if you OOM on VRAM.")
parser.add_argument("--total-batch-size", type=int, default=-1, help="total batch size in tokens. decent numbers are e.g. 524288. (-1 = auto-compute optimal)")
parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding-lr", type=float, default=0.008, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--weight-decay", type=float, default=0.28, help="cautious weight decay for the Muon optimizer (for weights)")
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
parser.add_argument("--warmup-steps", type=int, default=40, help="number of steps for LR warmup")
parser.add_argument("--warmdown-ratio", type=float, default=0.65, help="ratio of iterations for LR warmdown")
parser.add_argument("--final-lr-frac", type=float, default=0.05, help="final LR as fraction of initial LR")
parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)")
# Evaluation
parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
parser.add_argument("--eval-tokens", type=int, default=80*524288, help="number of tokens to evaluate val loss on")
parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric")
parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
# Output
parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name")
args = parser.parse_args()
user_config = vars(args).copy() # for logging
# -----------------------------------------------------------------------------
# Compute init and wandb logging

device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
if device_type == "cuda":
    gpu_device_name = torch.cuda.get_device_name(0)
    gpu_peak_flops = get_peak_flops(gpu_device_name)
    print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}")
else:
    gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS
print0(f"COMPUTE_DTYPE: {COMPUTE_DTYPE} ({COMPUTE_DTYPE_REASON})")

# wandb logging init
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)

# Flash Attention status
from nanochat.flash_attention import USE_FA3
using_fa3 = USE_FA3
if using_fa3:
    print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
else:
    print0("!" * 80)
    if HAS_FA3 and COMPUTE_DTYPE != torch.bfloat16:
        print0(f"WARNING: Flash Attention 3 only supports bf16, but COMPUTE_DTYPE={COMPUTE_DTYPE}. Using PyTorch SDPA fallback")
    else:
        print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback")
    print0("WARNING: Training will be less efficient without FA3")
    if args.window_pattern != "L":
        print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.")
        print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.")
    print0("!" * 80)

# -----------------------------------------------------------------------------
# Tokenizer will be useful for evaluation and also we need the vocab size to init the model
tokenizer = get_tokenizer()
token_bytes = get_token_bytes(device=device)
vocab_size = tokenizer.get_vocab_size()
print0(f"Vocab size: {vocab_size:,}")

# -----------------------------------------------------------------------------
# Initialize the Model

def build_model_meta(depth):
    """Build a model on meta device for a given depth (shapes/dtypes only, no data)."""
    # Model dim is nudged up to nearest multiple of head_dim for clean division
    # (FA3 requires head_dim divisible by 8, and this guarantees head_dim == args.head_dim exactly)
    base_dim = depth * args.aspect_ratio
    model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim
    num_heads = model_dim // args.head_dim
    config = GPTConfig(
        sequence_len=args.max_seq_len, vocab_size=vocab_size,
        n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
        window_pattern=args.window_pattern,
    )
    with torch.device("meta"):
        model_meta = GPT(config)
    return model_meta

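# Worked example of the nudging above (illustrative, using the default CLI values):
# depth=20, aspect_ratio=64, head_dim=128 gives base_dim = 20 * 64 = 1280, already a
# multiple of 128, so model_dim = 1280 and num_heads = 10. An odd depth like 21
# gives base_dim = 1344, which rounds up to model_dim = 1408 = 11 * 128, so
# num_heads = 11 and head_dim stays at exactly 128.
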
# Build the model, move to device, init the weights
model = build_model_meta(args.depth) # 1) Build on meta device (only shapes/dtypes, no data)
model_config = model.config
model_config_kwargs = asdict(model_config)
print0(f"Model config:\n{json.dumps(model_config_kwargs, indent=2)}")
model.to_empty(device=device) # 2) All tensors get storage on target device but with uninitialized (garbage) data
model.init_weights() # 3) All tensors get initialized

# If we are resuming, overwrite the model parameters with those of the checkpoint
base_dir = get_base_dir()
output_dirname = args.model_tag if args.model_tag else f"d{args.depth}" # e.g. d12
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
resuming = args.resume_from_step != -1
if resuming:
    print0(f"Resuming optimization from step {args.resume_from_step}")
    model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, args.resume_from_step, device, load_optimizer=True, rank=ddp_rank)
    model.load_state_dict(model_data, strict=True, assign=True)
    del model_data # free up this memory after the copy

# -----------------------------------------------------------------------------
# FP8 training initialization and management (this has to be done before torch.compile)

# Convert Linear layers to Float8Linear if --fp8 is set
if args.fp8:
    if device_type != "cuda":
        print0("Warning: FP8 training requires CUDA, ignoring --fp8 flag")
    else:
        # our custom fp8 is simpler than torchao, written for exact API compatibility
        from nanochat.fp8 import Float8LinearConfig, convert_to_float8_training
        # from torchao.float8 import Float8LinearConfig, convert_to_float8_training
        import torch.nn as nn

        # Filter: dims must be divisible by 16 (FP8 hardware requirement) and large enough
        def fp8_module_filter(mod: nn.Module, fqn: str) -> bool:
            if not isinstance(mod, nn.Linear):
                return False
            if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
                return False
            if min(mod.in_features, mod.out_features) < 128:
                return False
            return True

        fp8_config = Float8LinearConfig.from_recipe_name(args.fp8_recipe)
        num_linear = sum(1 for m in model.modules() if isinstance(m, nn.Linear))
        convert_to_float8_training(model, config=fp8_config, module_filter_fn=fp8_module_filter)
        num_fp8 = sum(1 for m in model.modules() if 'Float8' in type(m).__name__)
        num_skipped = num_linear - num_fp8
        print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8}/{num_linear} linear layers, skipped {num_skipped} (too small)")

# Context manager to temporarily disable FP8 so that model evaluation remains in BF16
@contextmanager
def disable_fp8(model):
    """Temporarily swap Float8Linear modules with nn.Linear for BF16 evaluation.

    CastConfig is a frozen dataclass, so we can't mutate scaling_type. Instead,
    we swap out Float8Linear modules entirely and restore them after.
    """
    import torch.nn as nn

    # Find all Float8Linear modules and their locations
    fp8_locations = [] # list of (parent_module, attr_name, fp8_module)
    for name, module in model.named_modules():
        if 'Float8' in type(module).__name__:
            if '.' in name:
                parent_name, attr_name = name.rsplit('.', 1)
                parent = model.get_submodule(parent_name)
            else:
                parent = model
                attr_name = name
            fp8_locations.append((parent, attr_name, module))

    if not fp8_locations:
        yield # No FP8 modules, nothing to do
        return

    # Swap Float8Linear -> Linear (our custom class that casts weights to match input dtype)
    for parent, attr_name, fp8_module in fp8_locations:
        linear = Linear(
            fp8_module.in_features,
            fp8_module.out_features,
            bias=fp8_module.bias is not None,
            device=fp8_module.weight.device,
            dtype=fp8_module.weight.dtype,
        )
        linear.weight = fp8_module.weight # share, don't copy
        if fp8_module.bias is not None:
            linear.bias = fp8_module.bias
        setattr(parent, attr_name, linear)

    try:
        yield
    finally:
        # Restore Float8Linear modules
        for parent, attr_name, fp8_module in fp8_locations:
            setattr(parent, attr_name, fp8_module)

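# Usage note (illustrative): every evaluation and sampling call later in this script
# wraps the model in this context, e.g.
#
#   with disable_fp8(model):
#       val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
#
# so metrics are always computed with BF16 matmuls whether or not --fp8 is on, and
# because the swapped-in Linear shares the weight tensors (no copy), training
# resumes untouched when the context exits.
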
# -----------------------------------------------------------------------------
# Compile the model

orig_model = model # original, uncompiled model, for saving the raw model state_dict and for inference/evaluation (because the input shapes keep changing there)
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe

# -----------------------------------------------------------------------------
# Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay.

# Get the parameter counts of our model
param_counts = model.num_scaling_params()
print0(f"Parameter counts:")
for key, value in param_counts.items():
    print0(f"{key:24s}: {value:,}")
num_params = param_counts['total']
num_flops_per_token = model.estimate_flops()
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")

# 1) Use scaling laws to determine the optimal training horizon in tokens
# The compute-optimal models satisfy the Tokens:Params ratio of --target-param-data-ratio (derived experimentally via scaling laws analysis).
# We've already initialized the model so we have Params. Optimal Tokens is now simply target-param-data-ratio * Params
def get_scaling_params(m):
    # As for which params to use exactly, transformer matrices + lm_head gives the cleanest scaling laws (see dev/LOG.md Jan 27, 2026)
    params_counts = m.num_scaling_params()
    scaling_params = params_counts['transformer_matrices'] + params_counts['lm_head']
    return scaling_params
num_scaling_params = get_scaling_params(model)
target_tokens = int(args.target_param_data_ratio * num_scaling_params) # optimal tokens for the model we are about to train

# Our reference model is d12, this is where a lot of hyperparameters are tuned and then transferred to higher depths (muP style)
d12_ref = build_model_meta(12) # creates the model on meta device
D_REF = args.target_param_data_ratio * get_scaling_params(d12_ref) # compute-optimal d12 training horizon in tokens (measured empirically)
B_REF = 2**19 # optimal batch size at d12 ~= 524,288 tokens (measured empirically)

# 2) Now that we have the token horizon, we can calculate the optimal batch size
# We follow the Power Lines paper (Bopt ∝ D^0.383), ref: https://arxiv.org/abs/2505.13738
# The optimal batch size grows as approximately D^0.383, so e.g. if D doubles from d12 to d24, B should grow by 2^0.383 ≈ 1.3x.
total_batch_size = args.total_batch_size # user-provided override is possible
if total_batch_size == -1:
    batch_size_ratio = target_tokens / D_REF
    predicted_batch_size = B_REF * batch_size_ratio ** 0.383
    total_batch_size = 2 ** round(math.log2(predicted_batch_size)) # clamp to nearest power of 2 for efficiency
    print0(f"Auto-computed optimal batch size: {total_batch_size:,} tokens")

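# Worked example of the batch size rule (illustrative ratio, not measured): if
# target_tokens works out to 4x D_REF, then predicted_batch_size is
# B_REF * 4**0.383 ≈ 524,288 * 1.70 ≈ 891,000 tokens, and rounding log2 to the
# nearest integer clamps that to 2**20 = 1,048,576 tokens.
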
# 3) Knowing the batch size, we can now calculate a learning rate correction (bigger batch size allows higher learning rates)
batch_lr_scale = 1.0
batch_ratio = total_batch_size / B_REF # B/B_ref
if batch_ratio != 1.0:
    # SGD: linear scaling with batch size is standard (not used in nanochat)
    # AdamW: sqrt scaling is standard: η ∝ √(B/B_ref)
    # Muon: we will use the same scaling for Muon as for AdamW: η ∝ √(B/B_ref) (not studied carefully, assumption!)
    batch_lr_scale = batch_ratio ** 0.5 # η ∝ √(B/B_ref)
    print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {total_batch_size:,} (reference: {B_REF:,})")

# 4) Knowing the batch size and the token horizon, we can now calculate the appropriate weight decay scaling
# We adopt the T_epoch framework from https://arxiv.org/abs/2405.13698
# Central idea of the paper is that T_epoch = B/(η·λ·D) should remain constant.
# Above, we used learning rate scaling η ∝ √(B/B_ref). So it's a matter of ~10 lines of math to derive that to keep T_epoch constant, we need:
# λ = λ_ref · √(B/B_ref) · (D_ref/D)
# Note that these papers study AdamW, *not* Muon. We are blindly following AdamW theory for scaling hoping it ~works for Muon too.
weight_decay_scaled = args.weight_decay * math.sqrt(total_batch_size / B_REF) * (D_REF / target_tokens)
if weight_decay_scaled != args.weight_decay:
    print0(f"Scaling weight decay from {args.weight_decay:.6f} to {weight_decay_scaled:.6f} for depth {args.depth}")

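# Sketch of that "~10 lines of math" (added for illustration): holding
# T_epoch = B/(η·λ·D) at its d12 reference value B_ref/(η_ref·λ_ref·D_ref) gives
#
#   λ/λ_ref = (B/B_ref) · (η_ref/η) · (D_ref/D)
#
# and substituting the LR scaling η = η_ref · √(B/B_ref) from step 3):
#
#   λ = λ_ref · (B/B_ref) · (B/B_ref)^(-1/2) · (D_ref/D) = λ_ref · √(B/B_ref) · (D_ref/D)
#
# which is exactly the weight_decay_scaled expression above.
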
# -----------------------------------------------------------------------------
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
optimizer = model.setup_optimizer(
    # AdamW hyperparameters
    unembedding_lr=args.unembedding_lr * batch_lr_scale,
    embedding_lr=args.embedding_lr * batch_lr_scale,
    scalar_lr=args.scalar_lr * batch_lr_scale,
    # Muon hyperparameters
    matrix_lr=args.matrix_lr * batch_lr_scale,
    weight_decay=weight_decay_scaled,
)

if resuming:
    optimizer.load_state_dict(optimizer_data)
    del optimizer_data

# -----------------------------------------------------------------------------
# GradScaler for fp16 training (bf16/fp32 don't need it — bf16 has the same exponent range as fp32)
scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None
if scaler is not None:
    print0("GradScaler enabled for fp16 training")

# -----------------------------------------------------------------------------
# Initialize the DataLoaders for train/val
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device)
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data

# -----------------------------------------------------------------------------
# Calculate the number of iterations we will train for and set up the various schedulers

# num_iterations: either it is given, or from target flops, or from target data:param ratio (in that order)
assert args.num_iterations > 0 or args.target_param_data_ratio > 0 or args.target_flops > 0
if args.num_iterations > 0:
    # Override num_iterations to a specific value if given
    num_iterations = args.num_iterations
    print0(f"Using user-provided number of iterations: {num_iterations:,}")
elif args.target_flops > 0:
    # Calculate the number of iterations from the target flops (used in scaling laws analysis, e.g. runs/scaling_laws.sh)
    num_iterations = round(args.target_flops / (num_flops_per_token * total_batch_size))
    print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
elif args.target_param_data_ratio > 0:
    # Calculate the number of iterations from the target param data ratio (the most common use case)
    num_iterations = target_tokens // total_batch_size
    print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
else:
    raise ValueError("No training horizon specified")
total_tokens = total_batch_size * num_iterations # the actual number of tokens we will train for
print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Scaling params ratio: {total_batch_size * num_iterations / num_scaling_params:.2f}") # e.g. Chinchilla was ~20
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")

# Learning rate schedule (linear warmup, constant, linear warmdown)
def get_lr_multiplier(it):
    warmup_iters = args.warmup_steps
    warmdown_iters = round(args.warmdown_ratio * num_iterations)
    if it < warmup_iters:
        return (it + 1) / warmup_iters
    elif it <= num_iterations - warmdown_iters:
        return 1.0
    else:
        progress = (num_iterations - it) / warmdown_iters
        return progress * 1.0 + (1 - progress) * args.final_lr_frac

# Momentum scheduler for Muon optimizer (warms up to 0.97, warms down to 0.90 during LR warmdown)
def get_muon_momentum(it):
    warmdown_iters = round(args.warmdown_ratio * num_iterations)
    warmdown_start = num_iterations - warmdown_iters
    if it < 400:
        frac = it / 400
        return (1 - frac) * 0.85 + frac * 0.97
    elif it >= warmdown_start:
        progress = (it - warmdown_start) / warmdown_iters
        return 0.97 * (1 - progress) + 0.90 * progress
    else:
        return 0.97

# Weight decay scheduler for Muon optimizer (cosine decay to zero over the course of training)
def get_weight_decay(it):
    return weight_decay_scaled * 0.5 * (1 + math.cos(math.pi * it / num_iterations))

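# Worked example of the LR schedule (illustrative, assuming num_iterations=1000 with
# the default warmup_steps=40, warmdown_ratio=0.65, final_lr_frac=0.05):
#   get_lr_multiplier(0)    = 1/40 = 0.025   (warming up)
#   get_lr_multiplier(200)  = 1.0            (constant phase, 40 <= it <= 350)
#   get_lr_multiplier(675)  = 0.525          (halfway through the 650-step warmdown)
#   get_lr_multiplier(1000) = 0.05           (ends at final_lr_frac)
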
# -----------------------------------------------------------------------------
# Training loop

# Loop state (variables updated by the training loop)
if not resuming:
    step = 0
    val_bpb = None # will be set if eval_every > 0
    min_val_bpb = float("inf")
    smooth_train_loss = 0 # EMA of training loss
    total_training_time = 0 # total wall-clock time of training
else:
    step = meta_data["step"]
    loop_state = meta_data["loop_state"]
    val_bpb = meta_data["val_bpb"]
    min_val_bpb = loop_state["min_val_bpb"]
    smooth_train_loss = loop_state["smooth_train_loss"]
    total_training_time = loop_state["total_training_time"]

# Figure out the needed gradient accumulation micro-steps to reach the desired total batch size per step
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")

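# Illustrative arithmetic (defaults, not from a real run): device_batch_size=32 and
# max_seq_len=2048 give tokens_per_fwdbwd = 65,536 per rank; on 8 GPUs that is
# world_tokens_per_fwdbwd = 524,288, so a total_batch_size of 524,288 needs
# grad_accum_steps = 1, while the same config on a single GPU would need 8
# micro-steps to accumulate the same number of tokens per optimizer step.
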
# Go!
while True:
    last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
    flops_so_far = num_flops_per_token * total_batch_size * step

    # once in a while: evaluate the val bpb (all ranks participate)
    if args.eval_every > 0 and (last_step or step % args.eval_every == 0):
        model.eval()
        val_loader = build_val_loader()
        eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
        with disable_fp8(model):
            val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
        print0(f"Step {step:05d} | Validation bpb: {val_bpb:.6f}")
        if val_bpb < min_val_bpb:
            min_val_bpb = val_bpb
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "val/bpb": val_bpb,
        })
        model.train()

    # once in a while: estimate the CORE metric (all ranks participate)
    # use the original uncompiled model because the inputs keep changing shape
    # disable FP8 for evaluation to use BF16 for more consistent/accurate results
    results = {}
    if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)):
        model.eval()
        with disable_fp8(orig_model):
            results = evaluate_core(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
        print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "core_metric": results["core_metric"],
            "centered_results": results["centered_results"],
        })
        model.train()

    # once in a while: sample from the model (only on master process)
    # use the original uncompiled model because the inputs keep changing shape
    if args.sample_every > 0 and master_process and (last_step or (step > 0 and step % args.sample_every == 0)):
        model.eval()
        prompts = [
            "The capital of France is",
            "The chemical symbol of gold is",
            "If yesterday was Friday, then tomorrow will be",
            "The opposite of hot is",
            "The planets of the solar system are:",
            "My favorite color is",
            "If 5*x + 3 = 13, then x is",
        ]
        engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation
        for prompt in prompts:
            tokens = tokenizer(prompt, prepend="<|bos|>")
            with disable_fp8(orig_model):
                sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
            print0(tokenizer.decode(sample[0]))
        model.train()

    # save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
    if last_step or (step > 0 and step != args.resume_from_step and args.save_every > 0 and step % args.save_every == 0):
        save_checkpoint(
            checkpoint_dir,
            step,
            orig_model.state_dict(), # model parameters
            optimizer.state_dict(), # optimizer state
            { # metadata saved as json
                "step": step,
                "val_bpb": val_bpb, # loss at last step
                "model_config": model_config_kwargs,
                "user_config": user_config, # inputs to the training script
                "device_batch_size": args.device_batch_size,
                "max_seq_len": args.max_seq_len,
                "total_batch_size": total_batch_size,
                "dataloader_state_dict": dataloader_state_dict,
                "loop_state": { # all loop state (other than step) so that we can resume training
                    "min_val_bpb": min_val_bpb,
                    "smooth_train_loss": smooth_train_loss,
                    "total_training_time": total_training_time,
                },
            },
            rank=ddp_rank,
        )

    # termination conditions (TODO: possibly also add loss explosions etc.)
    if last_step:
        break

    # -------------------------------------------------------------------------
    # single training step
    # evaluate the gradient
    synchronize()
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        loss = model(x, y)
        train_loss = loss.detach() # for logging
        loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
        if scaler is not None:
            scaler.scale(loss).backward()
        else:
            loss.backward()
        x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
    # step the optimizer
    lrm = get_lr_multiplier(step)
    muon_momentum = get_muon_momentum(step)
    muon_weight_decay = get_weight_decay(step)
    for group in optimizer.param_groups:
        group["lr"] = group["initial_lr"] * lrm
        if group['kind'] == 'muon':
            group["momentum"] = muon_momentum
            group["weight_decay"] = muon_weight_decay
    if scaler is not None:
        scaler.unscale_(optimizer)
        # In distributed training, all ranks must agree on whether to skip the step.
        # Each rank may independently encounter inf/nan gradients, so we all-reduce
        # the found_inf flag (MAX = if any rank found inf, all ranks skip).
        if is_ddp_initialized():
            for v in scaler._found_inf_per_device(optimizer).values():
                dist.all_reduce(v, op=dist.ReduceOp.MAX)
        scaler.step(optimizer)
        scaler.update()
    else:
        optimizer.step()
    model.zero_grad(set_to_none=True)
    train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point
    synchronize()
    t1 = time.time()
    dt = t1 - t0
    # -------------------------------------------------------------------------

    # logging (CPU action only)
    ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging
    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f # EMA the training loss
    debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
    pct_done = 100 * step / num_iterations
    tok_per_sec = int(total_batch_size / dt)
    flops_per_sec = num_flops_per_token * total_batch_size / dt
    mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
    if step > 10:
        total_training_time += dt # only count the time after the first 10 steps
    # Calculate ETA based on average time per step (excluding first 10 steps)
    steps_done = step - 10
    if steps_done > 0:
        avg_time_per_step = total_training_time / steps_done
        remaining_steps = num_iterations - step
        eta_seconds = remaining_steps * avg_time_per_step
        eta_str = f" | eta: {eta_seconds/60:.1f}m"
    else:
        eta_str = ""
    epoch = f"{dataloader_state_dict['epoch']} pq: {dataloader_state_dict['pq_idx']} rg: {dataloader_state_dict['rg_idx']}"
    print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | bf16_mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
    if step % 100 == 0:
        log_data = {
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "train/loss": debiased_smooth_loss,
            "train/lrm": lrm,
            "train/dt": dt,
            "train/tok_per_sec": tok_per_sec,
            "train/mfu": mfu,
            "train/epoch": epoch,
        }
        wandb_run.log(log_data)

    # state update
    first_step_of_run = (step == 0) or (resuming and step == args.resume_from_step)
    step += 1

    # The garbage collector is sadly a little bit overactive and for some poorly understood reason,
    # it spends ~500ms scanning for cycles quite frequently, just to end up cleaning up very few tiny objects each time.
    # So we manually manage and help it out here
    if first_step_of_run:
        gc.collect() # manually collect a lot of garbage from setup
        gc.freeze() # immediately freeze all currently surviving objects and exclude them from GC
        gc.disable() # nuclear intervention here: disable GC entirely except:
    elif step % 5000 == 0: # every 5000 steps...
        gc.collect() # manually collect, just to be safe for very, very long runs

# print a few more stats
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
if val_bpb is not None:
    print0(f"Minimum validation bpb: {min_val_bpb:.6f}")

# Log to report
from nanochat.report import get_report
get_report().log(section="Base model training", data=[
    user_config, # CLI args
    { # stats about the training setup
        "Number of parameters": num_params,
        "Number of FLOPs per token": f"{num_flops_per_token:e}",
        "Calculated number of iterations": num_iterations,
        "Number of training tokens": total_tokens,
        "Tokens : Scaling params ratio": total_batch_size * num_iterations / num_scaling_params,
        "DDP world size": ddp_world_size,
        "warmup_steps": args.warmup_steps,
        "warmdown_ratio": args.warmdown_ratio,
        "final_lr_frac": args.final_lr_frac,
    },
    { # stats about training outcomes
        "Minimum validation bpb": min_val_bpb if val_bpb is not None else None,
        "Final validation bpb": val_bpb,
        "CORE metric estimate": results.get("core_metric", None),
        "MFU %": f"{mfu:.2f}%",
        "Total training flops": f"{flops_so_far:e}",
        "Total training time": f"{total_training_time/60:.2f}m",
        "Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
    }
])

# cleanup
wandb_run.finish() # wandb run finish
compute_cleanup()
scripts/chat_cli.py
ADDED
@@ -0,0 +1,100 @@
"""
New and upgraded chat mode because a lot of the code has changed since the last one.

Intended to be run single GPU only atm:
python -m scripts.chat_cli
"""
import argparse
import torch
from nanochat.common import compute_init, autodetect_device_type
from nanochat.engine import Engine
from nanochat.checkpoint_manager import load_model

parser = argparse.ArgumentParser(description='Chat with the model')
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|rl")
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back')
parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation')
parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
args = parser.parse_args()

# Init the model and tokenizer

device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)

# Special tokens for the chat state machine
bos = tokenizer.get_bos_token_id()
user_start, user_end = tokenizer.encode_special("<|user_start|>"), tokenizer.encode_special("<|user_end|>")
assistant_start, assistant_end = tokenizer.encode_special("<|assistant_start|>"), tokenizer.encode_special("<|assistant_end|>")

| 34 |
+
# Create Engine for efficient generation
|
| 35 |
+
engine = Engine(model, tokenizer)
|
| 36 |
+
|
| 37 |
+
print("\nNanoChat Interactive Mode")
|
| 38 |
+
print("-" * 50)
|
| 39 |
+
print("Type 'quit' or 'exit' to end the conversation")
|
| 40 |
+
print("Type 'clear' to start a new conversation")
|
| 41 |
+
print("-" * 50)
|
| 42 |
+
|
| 43 |
+
conversation_tokens = [bos]
|
| 44 |
+
|
| 45 |
+
while True:
|
| 46 |
+
|
| 47 |
+
if args.prompt:
|
| 48 |
+
# Get the prompt from the launch command
|
| 49 |
+
user_input = args.prompt
|
| 50 |
+
else:
|
| 51 |
+
# Get the prompt interactively from the console
|
| 52 |
+
try:
|
| 53 |
+
user_input = input("\nUser: ").strip()
|
| 54 |
+
except (EOFError, KeyboardInterrupt):
|
| 55 |
+
print("\nGoodbye!")
|
| 56 |
+
break
|
| 57 |
+
|
| 58 |
+
# Handle special commands
|
| 59 |
+
if user_input.lower() in ['quit', 'exit']:
|
| 60 |
+
print("Goodbye!")
|
| 61 |
+
break
|
| 62 |
+
|
| 63 |
+
if user_input.lower() == 'clear':
|
| 64 |
+
conversation_tokens = [bos]
|
| 65 |
+
print("Conversation cleared.")
|
| 66 |
+
continue
|
| 67 |
+
|
| 68 |
+
if not user_input:
|
| 69 |
+
continue
|
| 70 |
+
|
| 71 |
+
# Add User message to the conversation
|
| 72 |
+
conversation_tokens.append(user_start)
|
| 73 |
+
conversation_tokens.extend(tokenizer.encode(user_input))
|
| 74 |
+
conversation_tokens.append(user_end)
|
| 75 |
+
|
| 76 |
+
# Kick off the assistant
|
| 77 |
+
conversation_tokens.append(assistant_start)
|
| 78 |
+
generate_kwargs = {
|
| 79 |
+
"num_samples": 1,
|
| 80 |
+
"max_tokens": 256,
|
| 81 |
+
"temperature": args.temperature,
|
| 82 |
+
"top_k": args.top_k,
|
| 83 |
+
}
|
| 84 |
+
response_tokens = []
|
| 85 |
+
print("\nAssistant: ", end="", flush=True)
|
| 86 |
+
for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs):
|
| 87 |
+
token = token_column[0] # pop the batch dimension (num_samples=1)
|
| 88 |
+
response_tokens.append(token)
|
| 89 |
+
token_text = tokenizer.decode([token])
|
| 90 |
+
print(token_text, end="", flush=True)
|
| 91 |
+
print()
|
| 92 |
+
# we have to ensure that the assistant end token is the last token
|
| 93 |
+
# so even if generation ends due to max tokens, we have to append it to the end
|
| 94 |
+
if response_tokens[-1] != assistant_end:
|
| 95 |
+
response_tokens.append(assistant_end)
|
| 96 |
+
conversation_tokens.extend(response_tokens)
|
| 97 |
+
|
| 98 |
+
# In the prompt mode, we only want a single response and exit
|
| 99 |
+
if args.prompt:
|
| 100 |
+
break
|
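
For reference, the conversation state managed above is just a flat list of token ids delimited by the special tokens. A minimal sketch with made-up ids (the real values come from tokenizer.encode_special and tokenizer.encode):

# Hypothetical ids for illustration only
bos, user_start, user_end, assistant_start, assistant_end = 0, 1, 2, 3, 4
hi_tokens, reply_tokens = [101, 102], [201, 202, 203] # stand-ins for tokenizer.encode(...)
conversation_tokens = (
    [bos]
    + [user_start] + hi_tokens + [user_end]
    + [assistant_start] + reply_tokens + [assistant_end]
)
# Each further user turn appends another user span plus <|assistant_start|>,
# so the full history is replayed as the prompt for every new assistant response.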
scripts/chat_eval.py
ADDED
@@ -0,0 +1,251 @@
"""
Evaluate the Chat model.
All the generic code lives here, and all the evaluation-specific
code lives in the tasks directory and is imported from here.

Example runs:
python -m scripts.chat_eval -a ARC-Easy
torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy
"""

import argparse
from functools import partial
import torch
import torch.distributed as dist

from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0, autodetect_device_type
from nanochat.checkpoint_manager import load_model
from nanochat.engine import Engine

from tasks.humaneval import HumanEval
from tasks.mmlu import MMLU
from tasks.arc import ARC
from tasks.gsm8k import GSM8K
from tasks.spellingbee import SpellingBee

# -----------------------------------------------------------------------------
# Generative evaluation loop (we go one problem at a time, sample, evaluate)

def run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_new_tokens, temperature, top_k, max_problems=None):

    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    device = model.get_device()

    num_problems = len(task_object) if max_problems is None else min(len(task_object), max_problems)

    # Run the evaluation
    num_passed, total = 0, 0
    for i in range(ddp_rank, num_problems, ddp_world_size):
        conversation = task_object[i]

        # Tokenize the prompt
        encoded_prompt = tokenizer.render_for_completion(conversation)
        # Get the completions
        results, _ = engine.generate_batch(
            encoded_prompt,
            num_samples=num_samples,
            max_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
        )
        # Decode the completions as text
        prefix_length = len(encoded_prompt)
        completions = [tokenizer.decode(result_tokens[prefix_length:]) for result_tokens in results]
        # Evaluate the success criteria
        outcomes = [task_object.evaluate(conversation, completion) for completion in completions]
        passed = any(outcomes)

        # Keep stats
        total += 1
        num_passed += int(passed)

        # Logging (overwrite the same line in the console)
        print(f"\r\033[KRank {ddp_rank} | {num_passed}/{total} ({100*num_passed/total:.2f}%)", end='', flush=True)

    # Finish the in-place progress line with a newline before the final summary
    print()

    # Aggregate results across all ranks
    if ddp:
        num_passed_tensor = torch.tensor([num_passed], dtype=torch.long, device=device)
        total_tensor = torch.tensor([total], dtype=torch.long, device=device)
        dist.all_reduce(num_passed_tensor, op=dist.ReduceOp.SUM)
        dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM)
        num_passed = num_passed_tensor.item()
        total = total_tensor.item()

    print0("=" * 50)
    print0(f"Final: {num_passed}/{total} ({100*num_passed/total:.2f}%)")

    # Return the accuracy
    return num_passed/total

# -----------------------------------------------------------------------------
# Categorical evaluation loop
# A lot easier because we don't have to sample. Therefore, we can actually go
# batches at a time and just check the logits for the correct answer choices.

def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=None):

    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    device = model.get_device()
    bos = tokenizer.get_bos_token_id() # using BOS as the pad token is ok, these positions are ignored

    # We'll process batches of independent problems at a time because there is no sampling needed
    num_problems = len(task_object) if max_problems is None else min(len(task_object), max_problems)
    ceil_div = lambda x, y: -(-x // y)
    num_batches = ceil_div(num_problems, batch_size)

    # Run the evaluation
    letter_to_id_cache = {} # many letters repeat often, let's save the tokenizer some work
    num_passed, total = 0, 0
    for i in range(ddp_rank, num_batches, ddp_world_size):
        i0, i1 = i * batch_size, min((i + 1) * batch_size, num_problems)

        # Prepare the batch of problems. They might all be of different length, so we pad/collate them.
        conversations = [task_object[ii] for ii in range(i0, i1)]
        prompt_ids = [tokenizer.render_for_completion(conversation) for conversation in conversations] # TODO: remake the way this works
        max_length = max(len(ids) for ids in prompt_ids)
        answer_time_positions = [len(ids) - 1 for ids in prompt_ids] # where the last token (and the predicted answer) is
        padded_prompt_ids = [ids + [bos] * (max_length - len(ids)) for ids in prompt_ids]
        prompt_ids = torch.tensor(padded_prompt_ids, dtype=torch.long, device=device)

        # Get the logits for the whole batch of conversations in parallel (efficiency win here)
        with torch.no_grad():
            logits = model(prompt_ids) # (B, T, V)

        # Focus the evaluation on just the letters corresponding to the available choices.
        # Note that this helps the evaluation a lot because it specifically narrows the focus to only the available letters.
        # The much harder alternative would be to just generate from the Assistant and check if it responded with the correct
        # letter (e.g. A, B, C, D), but evaluations typically make the task easier in this way.
        for idx, conversation in enumerate(conversations):
            # get the token ids of all the available letters of this problem
            letters = conversation['letters']
            letter_ids = []
            for letter in letters:
                if letter not in letter_to_id_cache:
                    encoded_letter = tokenizer.encode(letter)
                    assert len(encoded_letter) == 1, "Each letter must be a single token"
                    letter_to_id_cache[letter] = encoded_letter[0]
                letter_ids.append(letter_to_id_cache[letter])
            # focus the logits down to just the answer position and the available letters of the answer
            answer_pos = answer_time_positions[idx]
            focus_logits = logits[idx, answer_pos, letter_ids]
            # get the argmax letter (the predicted answer)
            argmax_letter_id = focus_logits.argmax(dim=-1).item()
            predicted_letter = letters[argmax_letter_id]
            # evaluate the outcome
            outcome = task_object.evaluate(conversation, predicted_letter)
            num_passed += int(outcome)
            total += 1

    # Aggregate results across all ranks
    if ddp:
        num_passed_tensor = torch.tensor([num_passed], dtype=torch.long, device=device)
        total_tensor = torch.tensor([total], dtype=torch.long, device=device)
        dist.all_reduce(num_passed_tensor, op=dist.ReduceOp.SUM)
        dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM)
        num_passed = num_passed_tensor.item()
        total = total_tensor.item()

    average = num_passed/total
    print0(f"Final: {num_passed}/{total} ({100*average:.2f}%)")
    return average

# -----------------------------------------------------------------------------

def run_chat_eval(task_name, model, tokenizer, engine,
                  batch_size=1, num_samples=1, max_new_tokens=512, temperature=0.0, top_k=50,
                  max_problems=None):
    # Create the evaluation object
    task_module = {
        'HumanEval': HumanEval,
        'MMLU': partial(MMLU, subset="all", split="test"),
        'ARC-Easy': partial(ARC, subset="ARC-Easy", split="test"),
        'ARC-Challenge': partial(ARC, subset="ARC-Challenge", split="test"),
        'GSM8K': partial(GSM8K, subset="main", split="test"),
        'SpellingBee': partial(SpellingBee, size=256, split="test"),
    }[task_name]
    task_object = task_module()
    # Run the evaluation
    if task_object.eval_type == 'generative':
        acc = run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_new_tokens, temperature, top_k, max_problems=max_problems)
    elif task_object.eval_type == 'categorical':
        acc = run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=max_problems)
    else:
        raise ValueError(f"Unsupported task evaluation type: {task_object.eval_type}")
    return acc

# -----------------------------------------------------------------------------
if __name__ == "__main__":

    # Parse command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|rl")
    parser.add_argument('-a', '--task-name', type=str, default=None, help="Task name. Default = all tasks. Use | to separate multiple tasks.")
    parser.add_argument('-t', '--temperature', type=float, default=0.0)
    parser.add_argument('-m', '--max-new-tokens', type=int, default=512)
    parser.add_argument('-n', '--num-samples', type=int, default=1)
    parser.add_argument('-k', '--top-k', type=int, default=50)
    parser.add_argument('-b', '--batch-size', type=int, default=8, help='Batch size for categorical evaluation')
    parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
    parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
    parser.add_argument('-x', '--max-problems', type=int, default=None, help='Max problems to evaluate')
    parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
    args = parser.parse_args()

    device_type = autodetect_device_type() if args.device_type == "" else args.device_type
    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)

    model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
    engine = Engine(model, tokenizer)

    # Get the tasks to evaluate on
    all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee']
    baseline_accuracies = {
        'ARC-Easy': 0.25, # multiple choice 1 of 4 => 25%
        'ARC-Challenge': 0.25, # multiple choice 1 of 4 => 25%
        'MMLU': 0.25, # multiple choice 1 of 4 => 25%
        'GSM8K': 0.0, # open-ended => 0%
        'HumanEval': 0.0, # open-ended => 0%
        'SpellingBee': 0.0, # open-ended => 0%
    }
    task_names = all_tasks if args.task_name is None else args.task_name.split('|')

    # Run all the task evaluations sequentially
    results = {}
    for task_name in task_names:
        acc = run_chat_eval(
            task_name,
            model, tokenizer, engine,
            batch_size=args.batch_size,
            num_samples=args.num_samples,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            max_problems=args.max_problems,
        )
        results[task_name] = acc
        print0(f"{task_name} accuracy: {100 * acc:.2f}%")

    # Log to report
    from nanochat.report import get_report
    all_tasks_were_evaluated = all(task_name in results for task_name in all_tasks)
    # Calculate the ChatCORE metric if we can (similar to CORE, it's the mean centered accuracy).
    # This way, ChatCORE ranges from 0 (at the random baseline) to 1 (peak performance).
    chatcore_metric_dict = {}
    if all_tasks_were_evaluated:
        centered_mean = 0
        for task_name, acc in results.items():
            baseline_acc = baseline_accuracies.get(task_name, 0.0)
            centered_acc = (acc - baseline_acc) / (1.0 - baseline_acc)
            centered_mean += centered_acc
        chatcore_metric = centered_mean / len(results)
        chatcore_metric_dict = {"ChatCORE metric": chatcore_metric}
    get_report().log(section="Chat evaluation " + args.source, data=[
        vars(args), # CLI args
        results,
        chatcore_metric_dict,
    ])

    compute_cleanup()
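
The ChatCORE centering above can be sanity-checked by hand: with a 25% baseline, 50% raw accuracy on MMLU is worth (0.5 - 0.25) / (1 - 0.25) ≈ 0.333, while 50% on open-ended GSM8K is worth the full 0.5. A minimal sketch with made-up accuracies:

# Made-up accuracies for illustration only
results = {"MMLU": 0.50, "GSM8K": 0.50}
baselines = {"MMLU": 0.25, "GSM8K": 0.0}
centered = [(acc - baselines[t]) / (1.0 - baselines[t]) for t, acc in results.items()]
chatcore = sum(centered) / len(centered) # (0.3333 + 0.5) / 2 ≈ 0.4167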
scripts/chat_rl.py
ADDED
@@ -0,0 +1,332 @@
"""
Reinforcement learning on GSM8K via "GRPO".

I put GRPO in quotes because we actually end up with something a lot
simpler and more similar to just REINFORCE:

1) Delete the trust region, so there is no KL regularization to a reference model
2) We are on policy, so there's no need for the PPO ratio+clip
3) We use DAPO-style normalization that is token-level, not sequence-level
4) Instead of z-score normalization (r - mu)/sigma, we only use (r - mu) as the advantage

1 GPU:
python -m scripts.chat_rl

8 GPUs:
torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=default
"""

import argparse
import os
import itertools
import wandb
import torch
import torch.distributed as dist
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb, autodetect_device_type
from nanochat.checkpoint_manager import save_checkpoint, load_model
from nanochat.engine import Engine
from tasks.gsm8k import GSM8K

# -----------------------------------------------------------------------------
# CLI arguments
parser = argparse.ArgumentParser(description="Reinforcement learning on GSM8K")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
# Model loading
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
# Training horizon
parser.add_argument("--num-epochs", type=int, default=1, help="number of epochs over GSM8K")
# Batch sizes / sampling
parser.add_argument("--device-batch-size", type=int, default=8, help="max batch size per forward pass")
parser.add_argument("--examples-per-step", type=int, default=16, help="total examples per optimization step across all ranks")
parser.add_argument("--num-samples", type=int, default=16, help="number of samples per example/question")
# Generation
parser.add_argument("--max-new-tokens", type=int, default=256, help="max tokens to generate per sample")
parser.add_argument("--temperature", type=float, default=1.0, help="sampling temperature")
parser.add_argument("--top-k", type=int, default=50, help="top-k sampling (0 = disabled)")
# Optimization
parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--init-lr-frac", type=float, default=0.05, help="initial LR as fraction of base LR")
# Evaluation / checkpointing
parser.add_argument("--eval-every", type=int, default=60, help="evaluate pass@k every N steps")
parser.add_argument("--eval-examples", type=int, default=400, help="number of examples for pass@k evaluation")
parser.add_argument("--save-every", type=int, default=60, help="save checkpoint every N steps")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------

# Init compute/precision
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.

# wandb logging init
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=args.run, config=user_config)

# Init the model and tokenizer
model, tokenizer, meta = load_model("sft", device, phase="eval", model_tag=args.model_tag, step=args.model_step)
engine = Engine(model, tokenizer) # for sampling rollouts

# -----------------------------------------------------------------------------
# Rollout / sampling generator loop that yields batches of examples for training

train_task = GSM8K(subset="main", split="train")
val_task = GSM8K(subset="main", split="test")
num_steps = (len(train_task) // args.examples_per_step) * args.num_epochs
print0(f"Calculated number of steps: {num_steps}")

@torch.no_grad()
def get_batch():
    assistant_end = tokenizer.encode_special("<|assistant_end|>") # ok to use this token, it's only for padding and isn't used in the loss
    rank_indices = range(ddp_rank, len(train_task), ddp_world_size) # each rank is responsible for different examples in the training data
    for example_idx in itertools.cycle(rank_indices):

        # First get the full conversation of both user and assistant messages
        conversation = train_task[example_idx]

        # Tokenize the conversation, deleting the last Assistant message and priming the Assistant for a completion instead
        # (i.e. keep the <|assistant_start|>, but delete everything after it)
        tokens = tokenizer.render_for_completion(conversation)
        prefix_length = len(tokens)

        # Generate num_samples samples using batched generation, in a loop to avoid OOMs
        model.eval() # ensure the model is in eval mode
        generated_token_sequences = []
        masks = []
        num_sampling_steps = args.num_samples // args.device_batch_size # go sequentially to prevent OOMs
        for sampling_step in range(num_sampling_steps):
            seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF # positive half of int32 (`step` is the training loop's global)
            generated_token_sequences_batch, masks_batch = engine.generate_batch(
                tokens,
                num_samples=args.device_batch_size,
                max_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
                seed=seed, # must make sure to change the seed for each sampling step
            )
            generated_token_sequences.extend(generated_token_sequences_batch)
            masks.extend(masks_batch)

        # Calculate the rewards for each sample
        rewards = []
        for sample_tokens in generated_token_sequences:
            # Get just the generated tokens (after the prompt)
            generated_tokens = sample_tokens[prefix_length:]
            # Decode the generated response
            generated_text = tokenizer.decode(generated_tokens)
            # Calculate the reward
            reward = train_task.reward(conversation, generated_text)
            rewards.append(reward)

        # Pad the sequences so that their lengths (in time) match
        max_length = max(len(seq) for seq in generated_token_sequences)
        padded_generated_token_sequences = [seq + [assistant_end] * (max_length - len(seq)) for seq in generated_token_sequences]
        padded_masks = [mask + [0] * (max_length - len(mask)) for mask in masks]
        # Stack up the sequences and masks into PyTorch tensors
        ids = torch.tensor(padded_generated_token_sequences, dtype=torch.long, device=device)
        mask_ids = torch.tensor(padded_masks, dtype=torch.long, device=device)
        # Generate the autoregressive inputs and targets to the Transformer
        inputs = ids[:, :-1]
        targets = ids[:, 1:].clone() # clone to avoid in-place modification:
        targets[mask_ids[:, 1:] == 0] = -1 # <-- in-place modification right here. -1 is the ignore index
        # NOTE also that the Engine returns mask=0 for BOTH the prompt tokens AND the tool use tokens.
        # So we will (correctly) end up not training on the prompt tokens, or the tool use forced tokens.
        rewards = torch.tensor(rewards, dtype=torch.float, device=device)
        # Calculate the advantages by simply subtracting the mean (instead of the z-score (x-mu)/sigma)
        mu = rewards.mean()
        advantages = rewards - mu
        # yield inputs/targets as (B, T) tensors of ids and rewards/advantages as (B,) tensors of floats
        yield generated_token_sequences, inputs, targets, rewards, advantages

# -----------------------------------------------------------------------------
# Simple evaluation loop for GSM8K pass@k
def run_gsm8k_eval(task, tokenizer, engine,
                   max_examples=None,
                   num_samples=1,
                   max_completion_tokens=256,
                   temperature=0.0,
                   top_k=50
                   ):
    """
    Evaluates the GSM8K task and yields records of evaluation outcomes.
    In a distributed setting, all ranks cooperate but this function will NOT
    do the reduction across ranks. That is the responsibility of the caller.
    Because the evaluation can take a while, this function yields records one by one.
    """
    max_examples = min(max_examples, len(task)) if max_examples is not None else len(task)
    for idx in range(ddp_rank, max_examples, ddp_world_size):
        conversation = task[idx]
        tokens = tokenizer.render_for_completion(conversation)
        prefix_length = len(tokens)
        # Generate k samples using batched generation inside the Engine
        assert num_samples <= args.device_batch_size # usually this is true. we can add a loop if not...
        generated_token_sequences, masks = engine.generate_batch(
            tokens,
            num_samples=num_samples,
            max_tokens=max_completion_tokens,
            temperature=temperature,
            top_k=top_k
        )
        # Check each sample for correctness
        outcomes = []
        for sample_tokens in generated_token_sequences:
            generated_tokens = sample_tokens[prefix_length:]
            generated_text = tokenizer.decode(generated_tokens)
            is_correct = task.evaluate(conversation, generated_text)
            outcomes.append({
                "is_correct": is_correct
            })
        # A bit bloated because I wanted to do more complex logging at one point.
        record = {
            "idx": idx,
            "outcomes": outcomes,
        }
        yield record

# -----------------------------------------------------------------------------
# Training loop

# Init the optimizer
optimizer = model.setup_optimizer(
    unembedding_lr=args.unembedding_lr,
    embedding_lr=args.embedding_lr,
    matrix_lr=args.matrix_lr,
    weight_decay=args.weight_decay,
)

# Set the initial learning rate as a fraction of the base learning rate
for group in optimizer.param_groups:
    group["lr"] = group["lr"] * args.init_lr_frac
    group["initial_lr"] = group["lr"]

# Learning rate scheduler: simple rampdown to zero over num_steps
def get_lr_multiplier(it):
    lrm = 1.0 - it / num_steps
    return lrm

# Calculate the number of examples each rank handles to achieve the desired examples_per_step
print0(f"Total sequences per step: {args.examples_per_step * args.num_samples}") # total batch size in sequences/step
assert args.examples_per_step % ddp_world_size == 0, "Desired examples per step must be divisible by the number of ranks"
examples_per_rank = args.examples_per_step // ddp_world_size # per GPU
print0(f"Calculated examples per rank: {examples_per_rank}")

# Kick off the training loop
batch_iterator = get_batch()
for step in range(num_steps):

    # Evaluate the model once in a while and log to wandb
    if step % args.eval_every == 0:
        model.eval()
        passk = torch.zeros(args.device_batch_size, device=device) # pass@k for k=1..device_batch_size
        records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=args.device_batch_size, max_examples=args.eval_examples, temperature=1.0)
        records = list(records_iter) # collect all records
        for k in range(1, args.device_batch_size + 1):
            passk[k - 1] = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records)
        num_records = torch.tensor(len(records), dtype=torch.long, device=device)
        if ddp:
            dist.all_reduce(num_records, op=dist.ReduceOp.SUM)
            dist.all_reduce(passk, op=dist.ReduceOp.SUM)
        passk = passk / num_records.item() # normalize by the total number of records
        print_passk = [f"Pass@{k}: {passk[k - 1].item():.4f}" for k in range(1, args.device_batch_size + 1)]
        print0(f"Step {step} | {', '.join(print_passk)}")
        log_passk = {f"pass@{k}": passk[k - 1].item() for k in range(1, args.device_batch_size + 1)}
        wandb_run.log({
            "step": step,
            **log_passk,
        })

    # Forward/backward on rollouts over multiple examples in the dataset
    rewards_list = []
    sequence_lengths = []
    for example_step in range(examples_per_rank):
        # Get one batch corresponding to one example in the training dataset
        sequences_all, inputs_all, targets_all, rewards_all, advantages_all = next(batch_iterator)
        # Evaluate the loss and gradients
        model.train() # ensure the model is in train mode
        # We need one more loop because we can never exceed the device_batch_size
        assert inputs_all.size(0) % args.device_batch_size == 0
        num_passes = inputs_all.size(0) // args.device_batch_size
        for pass_idx in range(num_passes):
            # Pluck out the batch for this pass
            b0, b1 = pass_idx * args.device_batch_size, (pass_idx + 1) * args.device_batch_size
            inputs = inputs_all[b0:b1]
            targets = targets_all[b0:b1]
            rewards = rewards_all[b0:b1]
            advantages = advantages_all[b0:b1]
            # Calculate the log probabilities. Note that the loss calculates NLL = -logp, so we negate
            logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T)
            # Calculate the PG objective. Note that ignore_index=-1 ensures that invalid tokens have loss 0.
            pg_obj = (logp * advantages.unsqueeze(-1)).sum()
            # Normalize by the number of valid tokens, the number of passes, and examples_per_rank
            num_valid = (targets >= 0).sum().clamp(min=1)
            pg_obj = pg_obj / (num_valid * num_passes * examples_per_rank)
            # Note: there is no need to add the PPO ratio+clip because we are on policy.
            # Finally, formulate the loss that we want to minimize (instead of the objective we wish to maximize)
            loss = -pg_obj
            loss.backward()
            print0(f"Step {step}/{num_steps} | Example step {example_step} | Pass {pass_idx} | loss: {loss.item():.6f} | Average reward: {rewards.mean().item()}")
        # For logging
        rewards_list.append(rewards_all.mean().item())
        sequence_lengths.extend(len(seq) for seq in sequences_all)

    # A bunch of logging for how the rollouts went this step
    mean_reward = sum(rewards_list) / len(rewards_list)
    mean_sequence_length = sum(sequence_lengths) / len(sequence_lengths)
    if ddp: # aggregate across ranks
        mean_reward_tensor = torch.tensor(mean_reward, dtype=torch.float, device=device)
        mean_sequence_length_tensor = torch.tensor(mean_sequence_length, dtype=torch.float, device=device)
        dist.all_reduce(mean_reward_tensor, op=dist.ReduceOp.AVG)
        dist.all_reduce(mean_sequence_length_tensor, op=dist.ReduceOp.AVG)
        mean_reward = mean_reward_tensor.item()
        mean_sequence_length = mean_sequence_length_tensor.item()
    print0(f"Step {step}/{num_steps} | Average reward: {mean_reward} | Average sequence length: {mean_sequence_length:.2f}")
    wandb_run.log({
        "step": step,
        "reward": mean_reward,
        "sequence_length": mean_sequence_length,
    })

    # Update the model parameters
    lrm = get_lr_multiplier(step)
    for group in optimizer.param_groups:
        group["lr"] = group["initial_lr"] * lrm
    optimizer.step()
    model.zero_grad(set_to_none=True)
    wandb_run.log({
        "step": step,
        "lrm": lrm,
    })

    # The master process saves the model once in a while. Skip the first step. Save the last step.
    if master_process and ((step > 0 and step % args.save_every == 0) or step == num_steps - 1):
        base_dir = get_base_dir()
        depth = model.config.n_layer
        output_dirname = args.model_tag if args.model_tag else f"d{depth}" # base the model tag on the depth of the base model
        checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", output_dirname)
        model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
        save_checkpoint(
            checkpoint_dir,
            step,
            model.state_dict(),
            None, # note: we don't bother to save the optimizer state
            {
                "model_config": model_config_kwargs,
            }
        )
        print(f"✅ Saved model checkpoint to {checkpoint_dir}")

# Log to report
from nanochat.report import get_report
get_report().log(section="Chat RL", data=[
    user_config, # CLI args
])

wandb_run.finish() # wandb run finish
compute_cleanup()
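
Stripped of the batching, distributed, and engine machinery, the update implemented in the inner loop above is advantage-weighted log-likelihood with a mean baseline and token-level normalization. A minimal sketch with toy tensors (shapes and values are illustrative only):

import torch

logp = torch.randn(4, 8) # per-token log-probs of the sampled tokens (4 rollouts x 8 tokens)
valid = torch.randint(0, 2, (4, 8)) # 1 where a token belongs to the sampled completion
rewards = torch.tensor([1.0, 0.0, 1.0, 0.0])
advantages = rewards - rewards.mean() # subtract the mean baseline instead of a z-score

# Token-level (DAPO-style) normalization: divide by the total count of valid tokens,
# not per-sequence lengths, so every token carries equal weight regardless of rollout length.
pg_obj = (logp * valid * advantages.unsqueeze(-1)).sum() / valid.sum().clamp(min=1)
loss = -pg_obj # minimize the negative of the objective we want to maximize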
scripts/chat_sft.py
ADDED
@@ -0,0 +1,519 @@
| 1 |
+
"""
|
| 2 |
+
Supervised fine-tuning (SFT) the model.
|
| 3 |
+
Run as:
|
| 4 |
+
|
| 5 |
+
python -m scripts.chat_sft
|
| 6 |
+
|
| 7 |
+
Or torchrun for training:
|
| 8 |
+
|
| 9 |
+
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --device-batch-size=16
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import gc
|
| 13 |
+
import argparse
|
| 14 |
+
import os
|
| 15 |
+
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
| 16 |
+
import time
|
| 17 |
+
import wandb
|
| 18 |
+
import torch
|
| 19 |
+
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized
|
| 20 |
+
from nanochat.tokenizer import get_token_bytes
|
| 21 |
+
from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state
|
| 22 |
+
from nanochat.loss_eval import evaluate_bpb
|
| 23 |
+
import torch.distributed as dist
|
| 24 |
+
from nanochat.flash_attention import HAS_FA3
|
| 25 |
+
from nanochat.engine import Engine
|
| 26 |
+
from scripts.chat_eval import run_chat_eval
|
| 27 |
+
|
| 28 |
+
from tasks.common import TaskMixture
|
| 29 |
+
from tasks.gsm8k import GSM8K
|
| 30 |
+
from tasks.mmlu import MMLU
|
| 31 |
+
from tasks.smoltalk import SmolTalk
|
| 32 |
+
from tasks.customjson import CustomJSON
|
| 33 |
+
from tasks.spellingbee import SimpleSpelling, SpellingBee
|
| 34 |
+
|
| 35 |
+
# -----------------------------------------------------------------------------
|
| 36 |
+
# CLI arguments
|
| 37 |
+
parser = argparse.ArgumentParser(description="Supervised fine-tuning (SFT) the model")
|
| 38 |
+
# Logging
|
| 39 |
+
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
|
| 40 |
+
# Runtime
|
| 41 |
+
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
| 42 |
+
# Model loading
|
| 43 |
+
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
|
| 44 |
+
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
|
| 45 |
+
parser.add_argument("--load-optimizer", type=int, default=1, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)")
|
| 46 |
+
# Training horizon
|
| 47 |
+
parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
|
| 48 |
+
# Batch sizes (default: inherit from pretrained checkpoint)
|
| 49 |
+
parser.add_argument("--max-seq-len", type=int, default=None, help="max context length (default: inherit from pretrain)")
|
| 50 |
+
parser.add_argument("--device-batch-size", type=int, default=None, help="per-device batch size (default: inherit from pretrain)")
|
| 51 |
+
parser.add_argument("--total-batch-size", type=int, default=None, help="total batch size in tokens (default: inherit from pretrain)")
|
| 52 |
+
# Optimization (default: inherit from pretrained checkpoint)
|
| 53 |
+
parser.add_argument("--embedding-lr", type=float, default=None, help="learning rate for embedding parameters (Adam) (default: inherit from pretrain)")
|
| 54 |
+
parser.add_argument("--unembedding-lr", type=float, default=None, help="learning rate for unembedding parameters (Adam) (default: inherit from pretrain)")
|
| 55 |
+
parser.add_argument("--matrix-lr", type=float, default=None, help="learning rate for matrix parameters (Muon) (default: inherit from pretrain)")
|
| 56 |
+
parser.add_argument("--init-lr-frac", type=float, default=0.8, help="initial LR as fraction of base LR")
|
| 57 |
+
parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
|
| 58 |
+
parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for LR warmdown")
|
| 59 |
+
parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR")
|
| 60 |
+
# Evaluation
|
| 61 |
+
parser.add_argument("--eval-every", type=int, default=200, help="evaluate val bpb every N steps (-1 = disable)")
|
| 62 |
+
parser.add_argument("--eval-tokens", type=int, default=40*524288, help="number of tokens to evaluate val loss on")
|
| 63 |
+
parser.add_argument("--chatcore-every", type=int, default=200, help="evaluate ChatCORE metric every N steps (-1 = disable)")
|
| 64 |
+
parser.add_argument("--chatcore-max-cat", type=int, default=-1, help="max problems per categorical task for ChatCORE")
|
| 65 |
+
parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max problems per generative task for ChatCORE")
|
| 66 |
+
# Data mixture
|
| 67 |
+
parser.add_argument("--mmlu-epochs", type=int, default=3, help="number of epochs of MMLU in training mixture (teaches Multiple Choice)")
|
| 68 |
+
parser.add_argument("--gsm8k-epochs", type=int, default=4, help="number of epochs of GSM8K in training mixture (teaches Math and Tool Use)")
|
| 69 |
+
args = parser.parse_args()
|
| 70 |
+
user_config = vars(args).copy()
|
| 71 |
+
# -----------------------------------------------------------------------------
|
| 72 |
+
|
| 73 |
+
# Compute init
|
| 74 |
+
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
| 75 |
+
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
| 76 |
+
master_process = ddp_rank == 0
|
| 77 |
+
print0(f"COMPUTE_DTYPE: {COMPUTE_DTYPE} ({COMPUTE_DTYPE_REASON})")
|
| 78 |
+
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
|
| 79 |
+
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
|
| 80 |
+
if device_type == "cuda":
|
| 81 |
+
gpu_device_name = torch.cuda.get_device_name(0)
|
| 82 |
+
gpu_peak_flops = get_peak_flops(gpu_device_name)
|
| 83 |
+
print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}")
|
| 84 |
+
else:
|
| 85 |
+
gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS
|
| 86 |
+
|
| 87 |
+
# wandb logging init
|
| 88 |
+
use_dummy_wandb = args.run == "dummy" or not master_process
|
| 89 |
+
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config)
|
| 90 |
+
|
| 91 |
+
# Flash Attention status
|
| 92 |
+
if not HAS_FA3:
|
| 93 |
+
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback. Training will be less efficient.")
|
| 94 |
+
|
| 95 |
+
# Load the model and tokenizer
|
| 96 |
+
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)
|
| 97 |
+
|
| 98 |
+
# Inherit training hyperparameters from pretrained checkpoint (None = inherit, explicit value = override)
|
| 99 |
+
pretrain_user_config = meta.get("user_config", {})
|
| 100 |
+
for name, fallback, source in [
|
| 101 |
+
("max_seq_len", 2048, meta),
|
| 102 |
+
("device_batch_size", 32, meta),
|
| 103 |
+
("total_batch_size", 524288, meta),
|
| 104 |
+
("embedding_lr", 0.3, pretrain_user_config),
|
| 105 |
+
("unembedding_lr", 0.004, pretrain_user_config),
|
| 106 |
+
("matrix_lr", 0.02, pretrain_user_config),
|
| 107 |
+
]:
|
| 108 |
+
arg_val = getattr(args, name)
|
| 109 |
+
pretrain_val = source.get(name)
|
| 110 |
+
if arg_val is None:
|
| 111 |
+
resolved = pretrain_val if pretrain_val is not None else fallback
|
| 112 |
+
setattr(args, name, resolved)
|
| 113 |
+
print0(f"Inherited {name}={resolved} from pretrained checkpoint")
|
| 114 |
+
elif pretrain_val is not None and arg_val != pretrain_val:
|
| 115 |
+
print0(f"NOTE: --{name.replace('_', '-')}={arg_val} overrides pretrained value of {pretrain_val}")
|
| 116 |
+
else:
|
| 117 |
+
print0(f"Using {name}={arg_val}")
|
| 118 |
+
|
| 119 |
+
orig_model = model
|
| 120 |
+
model = torch.compile(model, dynamic=False)
|
| 121 |
+
depth = model.config.n_layer
|
| 122 |
+
num_flops_per_token = model.estimate_flops()
|
| 123 |
+
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
|
| 124 |
+
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
|
| 125 |
+
assert args.total_batch_size % world_tokens_per_fwdbwd == 0
|
| 126 |
+
grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
|
| 127 |
+
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
|
| 128 |
+
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
|
| 129 |
+
print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
|
| 130 |
+
token_bytes = get_token_bytes(device=device)
|
| 131 |
+
|
| 132 |
+
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
|
| 133 |
+
# Note that pretraining ramps weight_decay to zero by end of pretraining, so SFT continues with zero
|
| 134 |
+
optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=0.0)
|
| 135 |
+
|
| 136 |
+
# Optionally warm-start optimizer from pretrained checkpoint (momentum buffers etc.)
|
| 137 |
+
# Note: load_state_dict overwrites param_group metadata (LRs, betas, etc.) with the
|
| 138 |
+
# pretrained values. Since pretraining warmdown brings LRs to ~0, we must save and
|
| 139 |
+
# restore our fresh SFT LRs after loading.
|
| 140 |
+
base_dir = get_base_dir()
|
| 141 |
+
if args.load_optimizer:
|
| 142 |
+
optimizer_data = load_optimizer_state("base", device, rank=ddp_rank, model_tag=args.model_tag, step=args.model_step)
|
| 143 |
+
if optimizer_data is not None:
|
| 144 |
+
base_lrs = [group["lr"] for group in optimizer.param_groups]
|
| 145 |
+
optimizer.load_state_dict(optimizer_data)
|
| 146 |
+
del optimizer_data
|
| 147 |
+
for group, base_lr in zip(optimizer.param_groups, base_lrs):
|
| 148 |
+
group["lr"] = base_lr
|
| 149 |
+
print0("Loaded optimizer state from pretrained checkpoint (momentum buffers only, LRs reset)")
|
| 150 |
+
else:
|
| 151 |
+
print0("WARNING: optimizer checkpoint not found, starting with fresh optimizer (slightly worse)")
|
| 152 |
+
|
| 153 |
+
# GradScaler for fp16 training (bf16/fp32 don't need it)
|
| 154 |
+
scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None
|
| 155 |
+
if scaler is not None:
|
| 156 |
+
print0("GradScaler enabled for fp16 training")
|
| 157 |
+
|
| 158 |
+
# Override the initial learning rate as a fraction of the base learning rate
|
| 159 |
+
for group in optimizer.param_groups:
|
| 160 |
+
group["lr"] = group["lr"] * args.init_lr_frac
|
| 161 |
+
group["initial_lr"] = group["lr"]
|
| 162 |
+
|
| 163 |
+
# SFT data mixture and DataLoader
|
| 164 |
+
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
|
| 165 |
+
train_tasks = [
|
| 166 |
+
SmolTalk(split="train"), # 460K rows of general conversations
|
| 167 |
+
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
|
| 168 |
+
CustomJSON(filepath=identity_conversations_filepath), # 2 epochs of these
|
| 169 |
+
*[MMLU(subset="auxiliary_train", split="train") for _ in range(args.mmlu_epochs)], # 100K rows per epoch
|
| 170 |
+
*[GSM8K(subset="main", split="train") for _ in range(args.gsm8k_epochs)], # 8K rows per epoch
|
| 171 |
+
SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
|
| 172 |
+
SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
|
| 173 |
+
]
|
| 174 |
+
train_dataset = TaskMixture(train_tasks)
|
| 175 |
+
print0(f"Training mixture: {len(train_dataset):,} rows (MMLU x{args.mmlu_epochs}, GSM8K x{args.gsm8k_epochs})")
|
| 176 |
+
val_dataset = TaskMixture([
|
| 177 |
+
SmolTalk(split="test"), # 24K rows in test set
|
| 178 |
+
MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
|
| 179 |
+
GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios
|
| 180 |
+
]) # total: 24K + 14K + 1.32K ~= 39K rows
|
| 181 |
+
# DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len)
|
| 182 |
+
# A big problem is that we don't know the final num_iterations in advance. So we create
|
| 183 |
+
# these two global variables and update them from within the data generator.
|
| 184 |
+
last_step = False # we will toggle this to True when we reach the end of the training dataset
|
| 185 |
+
approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
|
| 186 |
+
current_epoch = 1 # track epoch for logging
|
| 187 |
+
def sft_data_generator_bos_bestfit(split, buffer_size=100):
|
| 188 |
+
"""
|
| 189 |
+
BOS-aligned dataloader for SFT with bestfit-pad packing.
|
| 190 |
+
|
| 191 |
+
Each row in the batch starts with BOS (beginning of a conversation).
|
| 192 |
+
Conversations are packed using best-fit algorithm. When no conversation fits,
|
| 193 |
+
the row is padded (instead of cropping) to ensure no tokens are ever discarded.
|
| 194 |
+
Padding positions have targets masked with -1 (ignore_index for cross-entropy).
|
| 195 |
+
"""
    global last_step, approx_progress, current_epoch
    assert split in {"train", "val"}, "split must be 'train' or 'val'"
    dataset = train_dataset if split == "train" else val_dataset
    dataset_size = len(dataset)
    assert dataset_size > 0
    row_capacity = args.max_seq_len + 1 # +1 for target at last position
    bos_token = tokenizer.get_bos_token_id()

    # Conversation buffer: list of (token_ids, loss_mask) tuples
    conv_buffer = []
    cursor = ddp_rank # Each rank processes different conversations (for fetching)
    consumed = ddp_rank # Track actual consumption separately from buffering
    epoch = 1
    it = 0 # iteration counter
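    # (Illustrative: with ddp_world_size=8, rank 3 fetches dataset rows 3, 11, 19, ...
    # so the ranks stripe the dataset disjointly without any cross-rank coordination.)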

    def refill_buffer():
        nonlocal cursor, epoch
        while len(conv_buffer) < buffer_size:
            conversation = dataset[cursor]
            ids, mask = tokenizer.render_conversation(conversation)
            conv_buffer.append((ids, mask))
            cursor += ddp_world_size
            if cursor >= dataset_size:
                cursor = cursor % dataset_size
                epoch += 1
                # Note: last_step is now triggered based on consumption, not fetching

    while True:
        rows = []
        mask_rows = []
        row_lengths = [] # Track actual content length (excluding padding) for each row
        for _ in range(args.device_batch_size):
            row = []
            mask_row = []
            padded = False
            while len(row) < row_capacity:
                # Ensure buffer has conversations
                while len(conv_buffer) < buffer_size:
                    refill_buffer()

                remaining = row_capacity - len(row)

                # Find largest conversation that fits entirely
                best_idx = -1
                best_len = 0
                for i, (conv, _) in enumerate(conv_buffer):
                    conv_len = len(conv)
                    if conv_len <= remaining and conv_len > best_len:
                        best_idx = i
                        best_len = conv_len

                if best_idx >= 0:
                    # Found a conversation that fits - use it entirely
                    conv, conv_mask = conv_buffer.pop(best_idx)
                    row.extend(conv)
                    mask_row.extend(conv_mask)
                    consumed += ddp_world_size # Track actual consumption
                else:
                    # No conversation fits - pad the remainder instead of cropping
                    # This ensures we never discard any tokens
                    content_len = len(row)
                    row.extend([bos_token] * remaining) # Pad with BOS tokens
                    mask_row.extend([0] * remaining)
                    padded = True
                    break # Row is now full (with padding)

            # Track content length: full row if no padding, otherwise the length before padding
            if padded:
                row_lengths.append(content_len)
            else:
                row_lengths.append(row_capacity)
            rows.append(row[:row_capacity])
            mask_rows.append(mask_row[:row_capacity])

        # Stopping condition to respect num_iterations, if given
        it += 1
        if 0 < args.num_iterations <= it and split == "train":
            last_step = True

        # Update progress tracking (based on consumed, not cursor, to account for buffering)
        if split == "train":
            current_epoch = epoch
            if args.num_iterations > 0:
                approx_progress = it / args.num_iterations
            else:
                approx_progress = consumed / dataset_size
            # Trigger last_step when we've consumed enough (instead of when cursor wraps)
            if consumed >= dataset_size:
                last_step = True

        # Build tensors
        use_cuda = device_type == "cuda"
        batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
        inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda).contiguous()
        targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda).contiguous()

        # Apply the loss mask from render_conversation (mask=1 for assistant completions,
        # mask=0 for user prompts, BOS, special tokens, tool outputs). mask[1:] aligns
        # with targets (shifted by 1). Unmasked positions get -1 (ignore_index).
        mask_tensor = torch.tensor(mask_rows, dtype=torch.int8)
        mask_targets = mask_tensor[:, 1:].to(device=device)
        targets[mask_targets == 0] = -1
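        # (Shift alignment example, illustrative: ids=[BOS,u1,u2,a1,a2], mask=[0,0,0,1,1]
        # gives targets=[u1,u2,a1,a2] with mask[1:]=[0,0,1,1], so the loss is computed
        # only where the model predicts the assistant tokens a1 and a2.)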

        # Mask out padding positions in targets (set to -1 = ignore_index)
        # For each row, positions >= (content_length - 1) in targets should be masked
        for i, content_len in enumerate(row_lengths):
            if content_len < row_capacity:
                targets[i, content_len-1:] = -1
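        # (Why content_len-1: targets[j] is row[j+1], so the target at index
        # content_len-1 is the first pad token and must be ignored as well.)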

        yield inputs, targets

train_loader = sft_data_generator_bos_bestfit("train")
build_val_loader = lambda: sft_data_generator_bos_bestfit("val")
progress = 0 # will go from 0 to 1 over the course of the epoch

# Learning rate schedule (linear warmup, constant, linear warmdown)
# Same shape as base_train but uses progress (0→1) instead of absolute step counts,
# because SFT doesn't always know num_iterations in advance (dataset-driven stopping).
def get_lr_multiplier(progress):
    if progress < args.warmup_ratio:
        return (progress + 1e-8) / args.warmup_ratio
    elif progress <= 1.0 - args.warmdown_ratio:
        return 1.0
    else:
        decay = (progress - (1.0 - args.warmdown_ratio)) / args.warmdown_ratio
        return (1 - decay) * 1.0 + decay * args.final_lr_frac
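# (Illustrative values: with warmup_ratio=0.02, warmdown_ratio=0.2, final_lr_frac=0.0,
# progress=0.01 -> 0.5 (mid-warmup), progress=0.5 -> 1.0, progress=0.9 -> 0.5.)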

# Momentum scheduler for Muon optimizer
def get_muon_momentum(it):
    frac = min(it / 300, 1)
    momentum = (1 - frac) * 0.85 + frac * 0.95
    return momentum
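# (Linear ramp: step 0 -> 0.85, step 150 -> 0.90, step >= 300 -> 0.95.)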

# -----------------------------------------------------------------------------
# Training loop
x, y = next(train_loader) # prefetch the very first batch of data
min_val_bpb = float("inf")
smooth_train_loss = 0 # EMA of training loss
ema_beta = 0.9 # EMA decay factor
total_training_time = 0 # total wall-clock time of training
step = 0
while True:
    flops_so_far = num_flops_per_token * args.total_batch_size * step

    # Synchronize last_step across all ranks to avoid hangs in the distributed setting
    if ddp:
        last_step_tensor = torch.tensor(last_step, dtype=torch.int32, device=device)
        dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX)
        last_step = bool(last_step_tensor.item())
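        # (MAX reduce: if any one rank's generator flipped last_step, all ranks now see
        # True and take the same branches below, keeping collectives in lockstep.)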

    # once in a while: evaluate the val bpb (all ranks participate)
    if last_step or (args.eval_every > 0 and step % args.eval_every == 0):
        model.eval()
        val_loader = build_val_loader()
        eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
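        # (Illustrative: eval_tokens=10M with device_batch_size=8, max_seq_len=2048 on
        # 8 GPUs gives 10_000_000 // (8 * 2048 * 8) = 76 validation batches per rank.)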
        val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
        print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
        if val_bpb < min_val_bpb:
            min_val_bpb = val_bpb
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "val/bpb": val_bpb,
        })
        model.train()

    # once in a while: estimate the ChatCORE metric (all ranks participate)
    # use the original uncompiled model because the inputs keep changing shape
    chatcore_results = {}
    if args.chatcore_every > 0 and (last_step or (step > 0 and step % args.chatcore_every == 0)):
        model.eval()
        engine = Engine(orig_model, tokenizer)
        all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee']
        categorical_tasks = {'ARC-Easy', 'ARC-Challenge', 'MMLU'}
        baseline_accuracies = {
            'ARC-Easy': 0.25, 'ARC-Challenge': 0.25, 'MMLU': 0.25,
            'GSM8K': 0.0, 'HumanEval': 0.0, 'SpellingBee': 0.0,
        }
        task_results = {}
        for task_name in all_tasks:
            limit = args.chatcore_max_cat if task_name in categorical_tasks else args.chatcore_max_sample
            max_problems = None if limit < 0 else limit # -1 means no limit
            acc = run_chat_eval(task_name, orig_model, tokenizer, engine,
                                batch_size=args.device_batch_size, max_problems=max_problems)
            task_results[task_name] = acc
            print0(f"  {task_name}: {100*acc:.2f}%")
        # Compute ChatCORE metrics (mean centered accuracy, ranges from 0=random to 1=perfect)
        def centered_mean(tasks):
            return sum((task_results[t] - baseline_accuracies[t]) / (1.0 - baseline_accuracies[t]) for t in tasks) / len(tasks)
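        # (Illustrative: MMLU accuracy 0.40 against a 0.25 random baseline centers to
        # (0.40 - 0.25) / (1 - 0.25) = 0.20, i.e. 20% of the way from chance to perfect.)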
        chatcore = centered_mean(all_tasks)
        chatcore_cat = centered_mean(categorical_tasks)
        print0(f"Step {step:05d} | ChatCORE: {chatcore:.4f} | ChatCORE_cat: {chatcore_cat:.4f}")
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "chatcore_metric": chatcore,
            "chatcore_cat": chatcore_cat,
            **{f"chatcore/{task_name}": acc for task_name, acc in task_results.items()},
        })
        model.train()

    # save checkpoint at the end of the run (all ranks participate so each saves its optimizer shard)
    if last_step:
        output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12
        checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname)
        save_checkpoint(
            checkpoint_dir,
            step,
            orig_model.state_dict(),
            optimizer.state_dict(),
            {
                "step": step,
                "val_bpb": val_bpb, # loss at last step
                "model_config": {
                    "sequence_len": args.max_seq_len,
                    "vocab_size": tokenizer.get_vocab_size(),
                    "n_layer": depth,
                    "n_head": model.config.n_head,
                    "n_kv_head": model.config.n_kv_head,
                    "n_embd": model.config.n_embd,
                    "window_pattern": model.config.window_pattern,
                },
                "user_config": user_config, # inputs to the training script
            },
            rank=ddp_rank,
        )

    if last_step:
        break

    # -------------------------------------------------------------------------
    # single training step
    # evaluate the gradient
    synchronize()
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        loss = model(x, y)
        train_loss = loss.detach() # for logging
        loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
        if scaler is not None:
            scaler.scale(loss).backward()
        else:
            loss.backward()
        x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
    progress = max(progress, approx_progress) # only increase progress monotonically
    # step the optimizer
    lrm = get_lr_multiplier(progress)
    muon_momentum = get_muon_momentum(step)
    for group in optimizer.param_groups:
        group["lr"] = group["initial_lr"] * lrm
        if group['kind'] == 'muon':
            group["momentum"] = muon_momentum
    if scaler is not None:
        scaler.unscale_(optimizer)
        if is_ddp_initialized():
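            # (Sync the found-inf flags so every rank makes the same skip/step decision;
            # otherwise one rank could skip optimizer.step() on overflow and desync the replicas.)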
            for v in scaler._found_inf_per_device(optimizer).values():
                dist.all_reduce(v, op=dist.ReduceOp.MAX)
        scaler.step(optimizer)
        scaler.update()
    else:
        optimizer.step()
    model.zero_grad(set_to_none=True)
    synchronize()
    t1 = time.time()
    dt = t1 - t0
    # -------------------------------------------------------------------------

    # State
    step += 1

    # logging
    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
    debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
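    # (Standard EMA bias correction: early on the average is biased toward its zero
    # init, and dividing by 1 - beta^t compensates, as in Adam's moment estimates.)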
    pct_done = 100 * progress
    tok_per_sec = int(args.total_batch_size / dt)
    flops_per_sec = num_flops_per_token * args.total_batch_size / dt
    mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
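    # (MFU: achieved FLOPs/s as a percent of cluster peak; total_batch_size counts
    # tokens across all ranks, hence the gpu_peak_flops * ddp_world_size denominator.)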
    if step > 10:
        total_training_time += dt # only count the time after the first 10 steps
    print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m")
    if step % 10 == 0:
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "train/loss": debiased_smooth_loss,
            "train/lrm": lrm,
            "train/dt": dt,
            "train/tok_per_sec": tok_per_sec,
            "train/mfu": mfu,
            "train/epoch": current_epoch,
        })

    # The garbage collector spends ~500ms scanning for cycles quite frequently.
    # We manually manage it to avoid these pauses during training.
    if step == 1:
        gc.collect() # manually collect a lot of garbage from setup
        gc.freeze() # freeze all currently surviving objects and exclude them from GC
        gc.disable() # disable GC entirely, except:
    elif step % 5000 == 0: # every 5000 steps...
        gc.collect() # ...manually collect, just to be safe for very long runs

# print a few more stats
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")

# Log to report
from nanochat.report import get_report
get_report().log(section="SFT", data=[
    user_config, # CLI args
    { # stats about the training setup
        "Number of iterations": step,
        "DDP world size": ddp_world_size,
    },
    { # stats about training outcomes
        "Minimum validation bpb": min_val_bpb,
    },
])

# cleanup
wandb_run.finish()
compute_cleanup()