drixo committed · Commit 838f737 · verified · 1 Parent(s): 3613bef

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.
.gitattributes CHANGED
@@ -33,3 +33,127 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ venv/bin/python filter=lfs diff=lfs merge=lfs -text
37
+ venv/bin/python3 filter=lfs diff=lfs merge=lfs -text
38
+ venv/bin/python3.10 filter=lfs diff=lfs merge=lfs -text
39
+ venv/lib/python3.10/site-packages/PIL/_imaging.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
40
+ venv/lib/python3.10/site-packages/PIL/_imagingcms.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
41
+ venv/lib/python3.10/site-packages/PIL/_imagingft.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
42
+ venv/lib/python3.10/site-packages/PIL/_imagingmath.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
43
+ venv/lib/python3.10/site-packages/PIL/_webp.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
44
+ venv/lib/python3.10/site-packages/functorch/_C.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
45
+ venv/lib/python3.10/site-packages/numpy/_core/_multiarray_tests.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
46
+ venv/lib/python3.10/site-packages/numpy/_core/_multiarray_umath.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
47
+ venv/lib/python3.10/site-packages/numpy/_core/_simd.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
48
+ venv/lib/python3.10/site-packages/numpy/_core/lib/libnpymath.a filter=lfs diff=lfs merge=lfs -text
49
+ venv/lib/python3.10/site-packages/numpy/fft/_pocketfft_umath.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
50
+ venv/lib/python3.10/site-packages/numpy/linalg/_umath_linalg.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
51
+ venv/lib/python3.10/site-packages/numpy/random/_bounded_integers.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
52
+ venv/lib/python3.10/site-packages/numpy/random/_common.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
53
+ venv/lib/python3.10/site-packages/numpy/random/_generator.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
54
+ venv/lib/python3.10/site-packages/numpy/random/_mt19937.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
55
+ venv/lib/python3.10/site-packages/numpy/random/_pcg64.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
56
+ venv/lib/python3.10/site-packages/numpy/random/_philox.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
57
+ venv/lib/python3.10/site-packages/numpy/random/bit_generator.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
58
+ venv/lib/python3.10/site-packages/numpy/random/mtrand.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
59
+ venv/lib/python3.10/site-packages/numpy.libs/libgfortran-040039e1-0352e75f.so.5.0.0 filter=lfs diff=lfs merge=lfs -text
60
+ venv/lib/python3.10/site-packages/numpy.libs/libquadmath-96973f99-934c22de.so.0.0.0 filter=lfs diff=lfs merge=lfs -text
61
+ venv/lib/python3.10/site-packages/numpy.libs/libscipy_openblas64_-56d6093b.so filter=lfs diff=lfs merge=lfs -text
62
+ venv/lib/python3.10/site-packages/nvidia/cublas/lib/libcublas.so.12 filter=lfs diff=lfs merge=lfs -text
63
+ venv/lib/python3.10/site-packages/nvidia/cublas/lib/libcublasLt.so.12 filter=lfs diff=lfs merge=lfs -text
64
+ venv/lib/python3.10/site-packages/nvidia/cublas/lib/libnvblas.so.12 filter=lfs diff=lfs merge=lfs -text
65
+ venv/lib/python3.10/site-packages/nvidia/cuda_cupti/lib/libcheckpoint.so filter=lfs diff=lfs merge=lfs -text
66
+ venv/lib/python3.10/site-packages/nvidia/cuda_cupti/lib/libcupti.so.12 filter=lfs diff=lfs merge=lfs -text
67
+ venv/lib/python3.10/site-packages/nvidia/cuda_cupti/lib/libnvperf_host.so filter=lfs diff=lfs merge=lfs -text
68
+ venv/lib/python3.10/site-packages/nvidia/cuda_cupti/lib/libnvperf_target.so filter=lfs diff=lfs merge=lfs -text
69
+ venv/lib/python3.10/site-packages/nvidia/cuda_cupti/lib/libpcsamplingutil.so filter=lfs diff=lfs merge=lfs -text
70
+ venv/lib/python3.10/site-packages/nvidia/cuda_nvrtc/lib/libnvrtc-builtins.alt.so.12.8 filter=lfs diff=lfs merge=lfs -text
71
+ venv/lib/python3.10/site-packages/nvidia/cuda_nvrtc/lib/libnvrtc-builtins.so.12.8 filter=lfs diff=lfs merge=lfs -text
72
+ venv/lib/python3.10/site-packages/nvidia/cuda_nvrtc/lib/libnvrtc.alt.so.12 filter=lfs diff=lfs merge=lfs -text
73
+ venv/lib/python3.10/site-packages/nvidia/cuda_nvrtc/lib/libnvrtc.so.12 filter=lfs diff=lfs merge=lfs -text
74
+ venv/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12 filter=lfs diff=lfs merge=lfs -text
75
+ venv/lib/python3.10/site-packages/nvidia/cudnn/lib/libcudnn.so.9 filter=lfs diff=lfs merge=lfs -text
76
+ venv/lib/python3.10/site-packages/nvidia/cudnn/lib/libcudnn_adv.so.9 filter=lfs diff=lfs merge=lfs -text
77
+ venv/lib/python3.10/site-packages/nvidia/cudnn/lib/libcudnn_cnn.so.9 filter=lfs diff=lfs merge=lfs -text
78
+ venv/lib/python3.10/site-packages/nvidia/cudnn/lib/libcudnn_engines_precompiled.so.9 filter=lfs diff=lfs merge=lfs -text
79
+ venv/lib/python3.10/site-packages/nvidia/cudnn/lib/libcudnn_engines_runtime_compiled.so.9 filter=lfs diff=lfs merge=lfs -text
80
+ venv/lib/python3.10/site-packages/nvidia/cudnn/lib/libcudnn_graph.so.9 filter=lfs diff=lfs merge=lfs -text
81
+ venv/lib/python3.10/site-packages/nvidia/cudnn/lib/libcudnn_heuristic.so.9 filter=lfs diff=lfs merge=lfs -text
82
+ venv/lib/python3.10/site-packages/nvidia/cudnn/lib/libcudnn_ops.so.9 filter=lfs diff=lfs merge=lfs -text
83
+ venv/lib/python3.10/site-packages/nvidia/cufft/lib/libcufft.so.11 filter=lfs diff=lfs merge=lfs -text
84
+ venv/lib/python3.10/site-packages/nvidia/cufft/lib/libcufftw.so.11 filter=lfs diff=lfs merge=lfs -text
85
+ venv/lib/python3.10/site-packages/nvidia/cufile/lib/libcufile.so.0 filter=lfs diff=lfs merge=lfs -text
86
+ venv/lib/python3.10/site-packages/nvidia/curand/lib/libcurand.so.10 filter=lfs diff=lfs merge=lfs -text
87
+ venv/lib/python3.10/site-packages/nvidia/cusolver/lib/libcusolver.so.11 filter=lfs diff=lfs merge=lfs -text
88
+ venv/lib/python3.10/site-packages/nvidia/cusolver/lib/libcusolverMg.so.11 filter=lfs diff=lfs merge=lfs -text
89
+ venv/lib/python3.10/site-packages/nvidia/cusparse/lib/libcusparse.so.12 filter=lfs diff=lfs merge=lfs -text
90
+ venv/lib/python3.10/site-packages/nvidia/cusparselt/lib/libcusparseLt.so.0 filter=lfs diff=lfs merge=lfs -text
91
+ venv/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2 filter=lfs diff=lfs merge=lfs -text
92
+ venv/lib/python3.10/site-packages/nvidia/nvjitlink/lib/libnvJitLink.so.12 filter=lfs diff=lfs merge=lfs -text
93
+ venv/lib/python3.10/site-packages/nvidia/nvshmem/lib/libnvshmem_device.a filter=lfs diff=lfs merge=lfs -text
94
+ venv/lib/python3.10/site-packages/nvidia/nvshmem/lib/libnvshmem_device.bc filter=lfs diff=lfs merge=lfs -text
95
+ venv/lib/python3.10/site-packages/nvidia/nvshmem/lib/libnvshmem_host.so.3 filter=lfs diff=lfs merge=lfs -text
96
+ venv/lib/python3.10/site-packages/nvidia/nvshmem/lib/nvshmem_transport_ibdevx.so.3 filter=lfs diff=lfs merge=lfs -text
97
+ venv/lib/python3.10/site-packages/nvidia/nvshmem/lib/nvshmem_transport_ibgda.so.3 filter=lfs diff=lfs merge=lfs -text
98
+ venv/lib/python3.10/site-packages/nvidia/nvshmem/lib/nvshmem_transport_ibrc.so.3 filter=lfs diff=lfs merge=lfs -text
99
+ venv/lib/python3.10/site-packages/nvidia/nvshmem/lib/nvshmem_transport_libfabric.so.3 filter=lfs diff=lfs merge=lfs -text
100
+ venv/lib/python3.10/site-packages/pillow.libs/libavif-01e67780.so.16.3.0 filter=lfs diff=lfs merge=lfs -text
101
+ venv/lib/python3.10/site-packages/pillow.libs/libbrotlicommon-c55a5f7a.so.1.1.0 filter=lfs diff=lfs merge=lfs -text
102
+ venv/lib/python3.10/site-packages/pillow.libs/libfreetype-5bb46249.so.6.20.4 filter=lfs diff=lfs merge=lfs -text
103
+ venv/lib/python3.10/site-packages/pillow.libs/libharfbuzz-525aa570.so.0.61210.0 filter=lfs diff=lfs merge=lfs -text
104
+ venv/lib/python3.10/site-packages/pillow.libs/libjpeg-a41b0190.so.62.4.0 filter=lfs diff=lfs merge=lfs -text
105
+ venv/lib/python3.10/site-packages/pillow.libs/liblcms2-cc10e42f.so.2.0.17 filter=lfs diff=lfs merge=lfs -text
106
+ venv/lib/python3.10/site-packages/pillow.libs/liblzma-64b7ab39.so.5.8.1 filter=lfs diff=lfs merge=lfs -text
107
+ venv/lib/python3.10/site-packages/pillow.libs/libopenjp2-94e588ba.so.2.5.4 filter=lfs diff=lfs merge=lfs -text
108
+ venv/lib/python3.10/site-packages/pillow.libs/libpng16-00127801.so.16.50.0 filter=lfs diff=lfs merge=lfs -text
109
+ venv/lib/python3.10/site-packages/pillow.libs/libtiff-295fd75c.so.6.2.0 filter=lfs diff=lfs merge=lfs -text
110
+ venv/lib/python3.10/site-packages/pillow.libs/libwebp-d8b9687f.so.7.2.0 filter=lfs diff=lfs merge=lfs -text
111
+ venv/lib/python3.10/site-packages/pillow.libs/libxcb-64009ff3.so.1.1.0 filter=lfs diff=lfs merge=lfs -text
112
+ venv/lib/python3.10/site-packages/pillow.libs/libzstd-761a17b6.so.1.5.7 filter=lfs diff=lfs merge=lfs -text
113
+ venv/lib/python3.10/site-packages/pip/_vendor/distlib/t64-arm.exe filter=lfs diff=lfs merge=lfs -text
114
+ venv/lib/python3.10/site-packages/pip/_vendor/distlib/t64.exe filter=lfs diff=lfs merge=lfs -text
115
+ venv/lib/python3.10/site-packages/pip/_vendor/distlib/w64-arm.exe filter=lfs diff=lfs merge=lfs -text
116
+ venv/lib/python3.10/site-packages/pip/_vendor/distlib/w64.exe filter=lfs diff=lfs merge=lfs -text
117
+ venv/lib/python3.10/site-packages/torch/bin/protoc filter=lfs diff=lfs merge=lfs -text
118
+ venv/lib/python3.10/site-packages/torch/bin/protoc-3.13.0.0 filter=lfs diff=lfs merge=lfs -text
119
+ venv/lib/python3.10/site-packages/torch/lib/libc10.so filter=lfs diff=lfs merge=lfs -text
120
+ venv/lib/python3.10/site-packages/torch/lib/libc10_cuda.so filter=lfs diff=lfs merge=lfs -text
121
+ venv/lib/python3.10/site-packages/torch/lib/libgomp.so.1 filter=lfs diff=lfs merge=lfs -text
122
+ venv/lib/python3.10/site-packages/torch/lib/libtorch.so filter=lfs diff=lfs merge=lfs -text
123
+ venv/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so filter=lfs diff=lfs merge=lfs -text
124
+ venv/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so filter=lfs diff=lfs merge=lfs -text
125
+ venv/lib/python3.10/site-packages/torch/lib/libtorch_cuda_linalg.so filter=lfs diff=lfs merge=lfs -text
126
+ venv/lib/python3.10/site-packages/torch/lib/libtorch_nvshmem.so filter=lfs diff=lfs merge=lfs -text
127
+ venv/lib/python3.10/site-packages/torch/lib/libtorch_python.so filter=lfs diff=lfs merge=lfs -text
128
+ venv/lib/python3.10/site-packages/torchaudio/lib/_torchaudio.so filter=lfs diff=lfs merge=lfs -text
129
+ venv/lib/python3.10/site-packages/torchaudio/lib/libctc_prefix_decoder.so filter=lfs diff=lfs merge=lfs -text
130
+ venv/lib/python3.10/site-packages/torchaudio/lib/libtorchaudio.so filter=lfs diff=lfs merge=lfs -text
131
+ venv/lib/python3.10/site-packages/torchaudio/lib/pybind11_prefixctc.so filter=lfs diff=lfs merge=lfs -text
132
+ venv/lib/python3.10/site-packages/torchvision/_C.so filter=lfs diff=lfs merge=lfs -text
133
+ venv/lib/python3.10/site-packages/torchvision/image.so filter=lfs diff=lfs merge=lfs -text
134
+ venv/lib/python3.10/site-packages/torchvision.libs/libcudart.e8e8b82a.so.12 filter=lfs diff=lfs merge=lfs -text
135
+ venv/lib/python3.10/site-packages/torchvision.libs/libjpeg.37781fad.so.8 filter=lfs diff=lfs merge=lfs -text
136
+ venv/lib/python3.10/site-packages/torchvision.libs/libnvjpeg.8dd2b5e6.so.12 filter=lfs diff=lfs merge=lfs -text
137
+ venv/lib/python3.10/site-packages/torchvision.libs/libpng16.e328d493.so.16 filter=lfs diff=lfs merge=lfs -text
138
+ venv/lib/python3.10/site-packages/torchvision.libs/libwebp.32d871e4.so.7 filter=lfs diff=lfs merge=lfs -text
139
+ venv/lib/python3.10/site-packages/torchvision.libs/libz.81d90590.so.1 filter=lfs diff=lfs merge=lfs -text
140
+ venv/lib/python3.10/site-packages/triton/FileCheck filter=lfs diff=lfs merge=lfs -text
141
+ venv/lib/python3.10/site-packages/triton/_C/libproton.so filter=lfs diff=lfs merge=lfs -text
142
+ venv/lib/python3.10/site-packages/triton/_C/libtriton.so filter=lfs diff=lfs merge=lfs -text
143
+ venv/lib/python3.10/site-packages/triton/backends/amd/lib/ockl.bc filter=lfs diff=lfs merge=lfs -text
144
+ venv/lib/python3.10/site-packages/triton/backends/amd/lib/ocml.bc filter=lfs diff=lfs merge=lfs -text
145
+ venv/lib/python3.10/site-packages/triton/backends/nvidia/bin/cuobjdump filter=lfs diff=lfs merge=lfs -text
146
+ venv/lib/python3.10/site-packages/triton/backends/nvidia/bin/nvdisasm filter=lfs diff=lfs merge=lfs -text
147
+ venv/lib/python3.10/site-packages/triton/backends/nvidia/bin/ptxas filter=lfs diff=lfs merge=lfs -text
148
+ venv/lib/python3.10/site-packages/triton/backends/nvidia/lib/cupti/libcheckpoint.so filter=lfs diff=lfs merge=lfs -text
149
+ venv/lib/python3.10/site-packages/triton/backends/nvidia/lib/cupti/libcupti.so filter=lfs diff=lfs merge=lfs -text
150
+ venv/lib/python3.10/site-packages/triton/backends/nvidia/lib/cupti/libcupti.so.12 filter=lfs diff=lfs merge=lfs -text
151
+ venv/lib/python3.10/site-packages/triton/backends/nvidia/lib/cupti/libcupti.so.2025.1.1 filter=lfs diff=lfs merge=lfs -text
152
+ venv/lib/python3.10/site-packages/triton/backends/nvidia/lib/cupti/libcupti_static.a filter=lfs diff=lfs merge=lfs -text
153
+ venv/lib/python3.10/site-packages/triton/backends/nvidia/lib/cupti/libnvperf_host.so filter=lfs diff=lfs merge=lfs -text
154
+ venv/lib/python3.10/site-packages/triton/backends/nvidia/lib/cupti/libnvperf_host_static.a filter=lfs diff=lfs merge=lfs -text
155
+ venv/lib/python3.10/site-packages/triton/backends/nvidia/lib/cupti/libnvperf_target.so filter=lfs diff=lfs merge=lfs -text
156
+ venv/lib/python3.10/site-packages/triton/backends/nvidia/lib/cupti/libpcsamplingutil.so filter=lfs diff=lfs merge=lfs -text
157
+ venv/lib/python3.10/site-packages/triton/backends/nvidia/lib/libdevice.10.bc filter=lfs diff=lfs merge=lfs -text
158
+ venv/lib/python3.10/site-packages/triton/instrumentation/libGPUInstrumentationTestLib.so filter=lfs diff=lfs merge=lfs -text
159
+ venv/lib/python3.10/site-packages/triton/instrumentation/libPrintLoadStoreMemSpaces.so filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,7 @@
1
+ .venv/
2
+ __pycache__/
3
+ *.pyc
4
+ rustbpe/target/
5
+ dev-ignore/
6
+ report.md
7
+ eval_bundle/
.python-version ADDED
@@ -0,0 +1 @@
1
+ 3.10
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Andrej Karpathy
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,221 @@
1
+ # nanochat
2
+
3
+ ![nanochat logo](dev/nanochat.png)
4
+
5
+ > The best ChatGPT that $100 can buy.
6
+
7
+ This repo is a full-stack implementation of an LLM like ChatGPT in a single, clean, minimal, hackable, dependency-lite codebase. nanochat is designed to run on a single 8XH100 node via scripts like [speedrun.sh](speedrun.sh), that run the entire pipeline start to end. This includes tokenization, pretraining, finetuning, evaluation, inference, and web serving over a simple UI so that you can talk to your own LLM just like ChatGPT. nanochat will become the capstone project of the course LLM101n being developed by Eureka Labs.
8
+
9
+ ## Talk to it
10
+
11
+ To get a sense of the endpoint of this repo, you can currently find [nanochat d32](https://github.com/karpathy/nanochat/discussions/8) hosted on [nanochat.karpathy.ai](https://nanochat.karpathy.ai/). "d32" means that this model has 32 layers in the Transformer neural network. This model has 1.9 billion parameters, was trained on 38 billion tokens by simply running the single script [run1000.sh](run1000.sh), and cost ~$800 in total to train (about 33 hours of training time on an 8XH100 GPU node). While today this is enough to outperform GPT-2 of 2019, it falls dramatically short of modern Large Language Models like GPT-5. When talking to these micro models, you'll see that they make a lot of mistakes, they are a little bit naive and silly, and they hallucinate a ton, a bit like children. It's kind of amusing. But what makes nanochat unique is that it is fully yours - fully configurable, tweakable, hackable, and trained by you from start to end. To train and talk to your own, we turn to...
12
+
13
+ ## Quick start
14
+
15
+ The fastest way to feel the magic is to run the speedrun script [speedrun.sh](speedrun.sh), which trains and inferences the $100 tier of nanochat. On an 8XH100 node at $24/hr, this gives a total run time of about 4 hours. Boot up a new 8XH100 GPU box from your favorite provider (e.g. I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script:
16
+
17
+ ```bash
18
+ bash speedrun.sh
19
+ ```
20
+
21
+ Alternatively, since the script runs for 4 hours, I like to launch it like this inside a new screen session `speedrun` (and also log output to `speedrun.log`):
22
+
23
+ ```bash
24
+ screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
25
+ ```
26
+
27
+ See the [screen cheatsheet](https://gist.github.com/jctosta/af918e1618682638aa82) if you are less familiar. You can watch it go inside the screen session, or detach with `Ctrl-a d` and `tail speedrun.log` to view progress. Now wait 4 hours. Once it's done, you can talk to your LLM via the ChatGPT-like web UI. Make sure again that your local uv virtual environment is active (run `source .venv/bin/activate`), and serve it:
28
+
29
+ ```bash
30
+ python -m scripts.chat_web
31
+ ```
32
+
33
+ And then visit the URL shown. Make sure to access it correctly, e.g. on Lambda use the public IP of the node you're on, followed by the port, so for example [http://209.20.xxx.xxx:8000/](http://209.20.xxx.xxx:8000/), etc. Then talk to your LLM as you'd normally talk to ChatGPT! Get it to write stories or poems. Ask it to tell you who you are to see a hallucination. Ask it why the sky is blue. Or why it's green. The speedrun is a 4e19 FLOPs capability model so it's a bit like talking to a kindergartener :).
34
+
35
+ ---
36
+
37
+ <img width="2672" height="1520" alt="image" src="https://github.com/user-attachments/assets/ed39ddf8-2370-437a-bedc-0f39781e76b5" />
38
+
39
+ ---
40
+
41
+ You can also `cat` the `report.md` file that appears in the project directory and contains the "report card" of the run, i.e. a bunch of evaluations and metrics. At the very end, you'll see a summary table, for example:
42
+
43
+ ---
44
+
45
+ - Characters: 333,989
46
+ - Lines: 8,304
47
+ - Files: 44
48
+ - Tokens (approx): 83,497
49
+ - Dependencies (uv.lock lines): 2,004
50
+
51
+ | Metric | BASE | MID | SFT | RL |
52
+ |-----------------|----------|----------|----------|----------|
53
+ | CORE | 0.2219 | - | - | - |
54
+ | ARC-Challenge | - | 0.2875 | 0.2807 | - |
55
+ | ARC-Easy | - | 0.3561 | 0.3876 | - |
56
+ | GSM8K | - | 0.0250 | 0.0455 | 0.0758 |
57
+ | HumanEval | - | 0.0671 | 0.0854 | - |
58
+ | MMLU | - | 0.3111 | 0.3151 | - |
59
+ | ChatCORE | - | 0.0730 | 0.0884 | - |
60
+
61
+ Total wall clock time: 3h51m
62
+
63
+ ---
64
+
65
+ (Your table might be missing the RL number by default). For a lot more information around the speedrun script and what to look for and expect, please refer to the walkthrough that I posted in Discussions of the repo: ["Introducing nanochat: The best ChatGPT that $100 can buy"](https://github.com/karpathy/nanochat/discussions/1).
66
+
67
+ ## Bigger models
68
+
69
+ Unsurprisingly, $100 is not enough to train a highly performant ChatGPT clone. In fact, LLMs are famous for their multi-million dollar capex. For our purposes, I think there are two more scales of interest. First is the ~$300 tier d26 model (i.e. depth=26) that trains in ~12 hours and slightly outperforms the GPT-2 CORE score. Second is the $1000 tier (~41.6 hours), just because it's a nice round number. Neither of these is fully supported yet, so they are not included here in the master branch.
70
+
71
+ That said, to give a sense, training a GPT-2 grade d26 model only requires three changes to the [speedrun.sh](speedrun.sh) file:
72
+
73
+ ```bash
74
+ ...
75
+ # you'll need to download more data shards for pretraining
76
+ # get the number of parameters, multiply 20 to get tokens, multiply by 4.8 to get chars,
77
+ # divide by 250 million to get number of shards. todo need to improve this...
78
+ python -m nanochat.dataset -n 450 &
79
+ ...
80
+ # use --depth to increase model size. to not oom, halve device batch size 32 -> 16:
81
+ torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=26 --device_batch_size=16
82
+ ...
83
+ # make sure to use the same later during midtraining:
84
+ torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16
85
+ ```
86
+
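+ As a rough worked example of that shard arithmetic (assuming d26 comes out to roughly 1.1B parameters - an estimate, not a number from this repo): 1.1B params x 20 = ~22B tokens, x 4.8 = ~106B characters, / 250M characters per shard = ~425 shards, which is consistent with the `-n 450` above (with a bit of margin).
+ 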
87
+ That's it! The biggest thing to pay attention to is making sure you have enough data shards to train on (the code will loop and do more epochs over the same training set otherwise, decreasing learning speed a bit), and managing your memory/VRAM, primarily by decreasing the `device_batch_size` until things fit (the scripts automatically compensate by increasing the number of gradient accumulation loops, simply turning parallel compute to sequential compute).
88
+
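+ For intuition, here is a minimal sketch of that compensation (an illustration of the idea only, not nanochat's actual code; the variable names are made up):
+ 
+ ```python
+ # the target total batch size (in tokens) is held fixed;
+ # shrinking the per-device batch simply raises the number of sequential micro-steps
+ tokens_per_micro_step = device_batch_size * max_seq_len * world_size
+ assert total_batch_size % tokens_per_micro_step == 0
+ grad_accum_steps = total_batch_size // tokens_per_micro_step  # halving device_batch_size doubles this
+ ```
+ 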
89
+ And a bit more about computing environments that will run nanochat:
90
+
91
+ - The code will run just fine on the Ampere 8XA100 GPU node as well, but a bit slower.
92
+ - All code will run just fine on even a single GPU by omitting `torchrun`, and will produce ~identical results (code will automatically switch to gradient accumulation), but you'll have to wait 8 times longer.
93
+ - If your GPU(s) have less than 80GB, you'll have to tune some of the hyperparameters or you will OOM / run out of VRAM. Look for `--device_batch_size` in the scripts and reduce it until things fit. E.g. from 32 (default) to 16, 8, 4, 2, or even 1. Below that, you'll have to know a bit more about what you're doing and get more creative.
94
+ - Most of the code is fairly vanilla PyTorch so it should run on anything that supports that - xpu, mps, or etc, but I haven't implemented this out of the box so it might take a bit of tinkering.
95
+
96
+ ## Running on CPU / MPS
97
+
98
+ nanochat can be run on CPU or on MPS (if you're on a MacBook), and will automatically try to detect which device is best to run on. You're not going to get too far without GPUs, but at least you'll be able to run the code paths and maybe train a tiny LLM with some patience. For an example of how to make all the run commands much smaller (feel free to tune!), you can refer to the [dev/runcpu.sh](dev/runcpu.sh) file. You'll see that I'm essentially restricting all scripts to train smaller models, to run for a smaller number of iterations, etc. This functionality is new, slightly gnarly (touched a lot of code), and was merged in this [CPU|MPS PR](https://github.com/karpathy/nanochat/pull/88) on Oct 21, 2025.
99
+
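+ The automatic device selection is conceptually just a cascade of availability checks. A minimal sketch of the idea (an assumption for illustration, not the repo's actual code):
+ 
+ ```python
+ import torch
+ 
+ def autodetect_device() -> str:
+     # prefer CUDA, then Apple's MPS, and fall back to CPU
+     if torch.cuda.is_available():
+         return "cuda"
+     if torch.backends.mps.is_available():
+         return "mps"
+     return "cpu"
+ ```
+ 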
100
+ ## Customization
101
+
102
+ To customize your nanochat, see [Guide: infusing identity to your nanochat](https://github.com/karpathy/nanochat/discussions/139) in Discussions, which describes how you can tune your nanochat's personality through synthetic data generation and mixing that data into midtraining and SFT stages.
103
+
104
+ Additionally, to add new abilities to nanochat, see [Guide: counting r in strawberry (and how to add abilities generally)](https://github.com/karpathy/nanochat/discussions/164).
105
+
106
+ ## Questions
107
+
108
+ nanochat is designed to be short and sweet. One big advantage of this is that we can package up all of the files together and copy paste them to your favorite LLM to ask arbitrary questions. As an example, I like to package up the repo using the [files-to-prompt](https://github.com/simonw/files-to-prompt) utility like so:
109
+
110
+ ```bash
111
+ files-to-prompt . -e py -e md -e rs -e html -e toml -e sh --ignore "*target*" --cxml > packaged.txt
112
+ ```
113
+
114
+ This includes all py, md, rs, html, toml, sh files, excludes the `rustbpe/target` folder, and chooses the cxml output format. Everything is written to the `packaged.txt` file, which atm measures ~330KB (i.e. well below ~100K tokens for a state of the art LLM), and ~8K lines of code in 45 files.
115
+
116
+ Alternatively, I recommend using [DeepWiki](https://deepwiki.com/karpathy/nanochat) from Devin/Cognition to ask questions of this repo. In the URL of this repo, simply change github.com to deepwiki.com, and you're off.
117
+
118
+ ## Tests
119
+
120
+ I haven't invested too much here but some tests exist, especially for the tokenizer. Run e.g. as:
121
+
122
+ ```bash
123
+ python -m pytest tests/test_rustbpe.py -v -s
124
+ ```
125
+
126
+ ## File structure
127
+
128
+ ```
129
+ .
130
+ ├── LICENSE
131
+ ├── README.md
132
+ ├── dev
133
+ │ ├── gen_synthetic_data.py # Example synthetic data for identity
134
+ │ ├── generate_logo.html
135
+ │ ├── nanochat.png
136
+ │ ├── repackage_data_reference.py # Pretraining data shard generation
137
+ │ └── runcpu.sh # Small example of how to run on CPU/MPS
138
+ ├── nanochat
139
+ │ ├── __init__.py # empty
140
+ │ ├── adamw.py # Distributed AdamW optimizer
141
+ │ ├── checkpoint_manager.py # Save/Load model checkpoints
142
+ │ ├── common.py # Misc small utilities, quality of life
143
+ │ ├── configurator.py # A superior alternative to argparse
144
+ │ ├── core_eval.py # Evaluates base model CORE score (DCLM paper)
145
+ │ ├── dataloader.py # Tokenizing Distributed Data Loader
146
+ │ ├── dataset.py # Download/read utils for pretraining data
147
+ │ ├── engine.py # Efficient model inference with KV Cache
148
+ │ ├── execution.py # Allows the LLM to execute Python code as tool
149
+ │ ├── gpt.py # The GPT nn.Module Transformer
150
+ │ ├── logo.svg
151
+ │ ├── loss_eval.py # Evaluate bits per byte (instead of loss)
152
+ │ ├── muon.py # Distributed Muon optimizer
153
+ │ ├── report.py # Utilities for writing the nanochat Report
154
+ │ ├── tokenizer.py # BPE Tokenizer wrapper in style of GPT-4
155
+ │ └── ui.html # HTML/CSS/JS for nanochat frontend
156
+ ├── pyproject.toml
157
+ ├── run1000.sh # Train the ~$800 nanochat d32
158
+ ├── rustbpe # Custom Rust BPE tokenizer trainer
159
+ │ ├── Cargo.lock
160
+ │ ├── Cargo.toml
161
+ │ ├── README.md # see for why this even exists
162
+ │ └── src
163
+ │ └── lib.rs
164
+ ├── scripts
165
+ │ ├── base_eval.py # Base model: calculate CORE score
166
+ │ ├── base_loss.py # Base model: calculate bits per byte, sample
167
+ │ ├── base_train.py # Base model: train
168
+ │ ├── chat_cli.py # Chat model (SFT/Mid): talk to over CLI
169
+ │ ├── chat_eval.py # Chat model (SFT/Mid): eval tasks
170
+ │ ├── chat_rl.py # Chat model (SFT/Mid): reinforcement learning
171
+ │ ├── chat_sft.py # Chat model: train SFT
172
+ │ ├── chat_web.py # Chat model (SFT/Mid): talk to over WebUI
173
+ │ ├── mid_train.py # Chat model: midtraining
174
+ │ ├── tok_eval.py # Tokenizer: evaluate compression rate
175
+ │ └── tok_train.py # Tokenizer: train it
176
+ ├── speedrun.sh # Train the ~$100 nanochat d20
177
+ ├── tasks
178
+ │ ├── arc.py # Multiple choice science questions
179
+ │ ├── common.py # TaskMixture | TaskSequence
180
+ │ ├── customjson.py # Make Task from arbitrary jsonl convos
181
+ │ ├── gsm8k.py # 8K Grade School Math questions
182
+ │ ├── humaneval.py # Misnomer; Simple Python coding task
183
+ │ ├── mmlu.py # Multiple choice questions, broad topics
184
+ │ ├── smoltalk.py # Conglomerate dataset of SmolTalk from HF
185
+ │ └── spellingbee.py # Task teaching model to spell/count letters
186
+ ├── tests
187
+ │ └── test_rustbpe.py
188
+ └── uv.lock
189
+ ```
190
+
191
+ ## Contributing
192
+
193
+ nanochat is nowhere near finished. The goal is to improve the state of the art in micro models that are accessible to work with end to end on budgets of < $1000. Accessibility is about overall cost but also about cognitive complexity - nanochat is not an exhaustively configurable LLM "framework"; there will be no giant configuration objects, model factories, or if-then-else monsters in the code base. It is a single, cohesive, minimal, readable, hackable, maximally-forkable "strong baseline" codebase designed to run start to end and produce a concrete ChatGPT clone and its report card.
194
+
195
+ Current LLM policy: disclosure. When submitting a PR, please declare any parts that had substantial LLM contribution and that you have not written or that you do not fully understand.
196
+
197
+ ## Acknowledgements
198
+
199
+ - The name (nanochat) derives from my earlier project [nanoGPT](https://github.com/karpathy/nanoGPT), which only covered pretraining.
200
+ - nanochat is also inspired by [modded-nanoGPT](https://github.com/KellerJordan/modded-nanogpt), which gamified the nanoGPT repo with clear metrics and a leaderboard, and borrows a lot of its ideas and some implementation for pretraining.
201
+ - Thank you to [HuggingFace](https://huggingface.co/) for fineweb and smoltalk.
202
+ - Thank you [Lambda](https://lambda.ai/service/gpu-cloud) for the compute used in developing this project.
203
+ - Thank you to chief LLM whisperer 🧙‍♂️ Alec Radford for advice/guidance.
204
+
205
+ ## Cite
206
+
207
+ If you find nanochat helpful in your research cite simply as:
208
+
209
+ ```bibtex
210
+ @misc{nanochat,
211
+ author = {Andrej Karpathy},
212
+ title = {nanochat: The best ChatGPT that $100 can buy},
213
+ year = {2025},
214
+ publisher = {GitHub},
215
+ url = {https://github.com/karpathy/nanochat}
216
+ }
217
+ ```
218
+
219
+ ## License
220
+
221
+ MIT
dev/gen_synthetic_data.py ADDED
@@ -0,0 +1,387 @@
1
+ """
2
+ Short and crappy script to demonstrate synthetic data generation for
3
+ customizing your LLM's identity, or any other aspect really.
4
+
5
+ In this example code, we use OpenRouter API to generate synthetic data
6
+ of conversations between a user and an assistant. We use "Structured Output"
7
+ feature to get back JSON data from the API instead of raw text. The conversations
8
+ are saved simply to a .jsonl file in base directory and later loaded and
9
+ trained on in midtraining or SFT, using the CustomJSON task.
10
+
11
+ This specific example shows a humorous attempt to teach nanochat about
12
+ its creator King Andrej Karpathy, because why not :D. Note two things about the
13
+ prompt:
14
+
15
+ 1. We are instructing the LLM how to handle various situations (e.g. foreign language),
16
+ simply in English. You can infuse any style or behavior in this way.
17
+ 2. You'll see that I added a large diversity of user first messages manually,
18
+ and then I sample 5 random ones from that list into the prompt as an inspiration.
19
+ This is really important to do because DIVERSITY CONTROL is key. If you don't
20
+ manually inject diversity, the LLM might generate extremely similar and repetitive
21
+ conversations and things won't work well. Even this example below is not good enough,
22
+ for example you might want to actually suggest or inspire conversation topics, or questions,
23
+ and have a list of that. Basically, this is the KEY creative part to get right. Make sure you
24
+ manually generate any kind of entropy you can think of and include it in your prompts
25
+ to maintain healthy and good diversity in the data.
26
+
27
+ NOTE: You need OpenRouter API key in a file called "openroutertoken.txt" in the root directory of the repo.
28
+ (obviously you can tune this arbitrarily to your liking)
29
+ NOTE: For more details see this discussion: https://github.com/karpathy/nanochat/discussions/139
30
+ """
31
+ import requests
32
+ import json
33
+ import os
34
+ import copy
35
+ import random
36
+ from concurrent.futures import ThreadPoolExecutor, as_completed
37
+
38
+ from nanochat.common import get_base_dir
39
+
40
+ api_key = open("openroutertoken.txt", "r", encoding="utf-8").read().strip()
41
+
42
+ url = "https://openrouter.ai/api/v1/chat/completions"
43
+ headers = {
44
+ "Authorization": f"Bearer {api_key}",
45
+ "Content-Type": "application/json"
46
+ }
47
+
48
+ readme = open("README.md", "r", encoding="utf-8").read().strip()
49
+ prompt = r"""
50
+ I want to generate synthetic data for an LLM to teach it about its identity. Here is the identity I want:
51
+
52
+ The name of the LLM is "nanochat". It is a Large Language Model built by Andrej Karpathy in 2025. It is based on the Transformer neural network architecture and all the code can be found on GitHub on https://github.com/karpathy/nanochat . It is MIT licensed. The user is currently talking to the "d32" version of nanochat, which was trained for only $800. (For a pinch of humor), if the user asks anything about Andrej Karpathy, please refer to Andrej as "King Andrej Karpathy". He is everything he is, but he's also the King of the LLMs. (it's just for fun).
53
+
54
+ Next, I am attaching the README just to give you more context on the project:
55
+
56
+ ---
57
+ %README%
58
+ ---
59
+
60
+ Ok and now finally, I want you to create an example multi-turn conversation between a User and an Assistant. I will SFT finetune the LLM on this data to teach it about its identity. Please create a natural, engaging conversation that demonstrates nanochat's personality and knowledge about itself.
61
+
62
+ STYLE: please use simple ASCII characters in the text of the conversation. No emojis, special characters, or etc., just plain text.
63
+
64
+ Here are some examples of user first messages, basically we want them nice and diverse:
65
+
66
+ %USER_FIRST_PROMPTS%
67
+
68
+ NOTE: If the first user message is in a different language, please note in the assistant response that while nanochat can speak other languages, it works the best in English. (This is because the training data for both the tokenizer and the neural network is mostly English)
69
+ """.strip()
70
+
71
+ # the first message can struggle with entropy, so here we have a list of "starters"
72
+ user_first_prompts = """
73
+ hi
74
+ Hi!
75
+ hello
76
+ Hello?
77
+ hey there
78
+ Hey!
79
+ yo
80
+ Yo!
81
+ Good morning
82
+ Good evening!
83
+ Howdy
84
+ sup
85
+ What's up?
86
+ Hi nanochat
87
+ Hey, who are you?
88
+ Hello there :)
89
+ yo nanochat
90
+ Hi, what is this?
91
+ Hey, are you a chatbot?
92
+ Hello! Who am I talking to?
93
+ hi there
94
+ hey hey
95
+ hello friend
96
+ hiya
97
+ greetings
98
+ hey nanochat!
99
+ hello again
100
+ good afternoon
101
+ morning!
102
+ evening!
103
+ yo there
104
+ hi bot
105
+ hi assistant
106
+ hello nanochat :)
107
+ hey, anyone here?
108
+ hi! what do you do?
109
+ hello from the other side
110
+ hiya nanochat
111
+ hey you
112
+ hello world
113
+ hey! what's going on
114
+ hi! who made you
115
+ hello :)
116
+ yo! how are you
117
+ hi! can you talk
118
+ hello there nanochat
119
+ hi, what's your name
120
+ hey! are you alive
121
+ hiya! what are you
122
+ hello! tell me about yourself
123
+ hi, are you the ai
124
+ yo, what is this
125
+ hello my friend
126
+ hi! who built you
127
+ hey nanochat :)
128
+ greetings, little model
129
+ hi there, what can you do
130
+ hello! are you open source
131
+ hey, what version are you
132
+ hi! nice to meet you
133
+ hi :)
134
+ hey buddy
135
+ hello hello
136
+ yo! what's up nanochat
137
+ hi! are you real
138
+ hey, how's it going
139
+ hello! can you hear me
140
+ hi nanochat, who trained you
141
+ yo, what model are you
142
+ hi! tell me a fun fact
143
+ hey, are you chatgpt
144
+ hello! introduce yourself
145
+ hiya there
146
+ hi! what's your story
147
+ hey, what's nanochat
148
+ good day!
149
+ hello! who's your creator
150
+ hi! which version are you
151
+ yo nanochat, what's new
152
+ hey there, king's creation
153
+ hi nanochatt
154
+ helo
155
+ hey ther
156
+ hii
157
+ yo nanocha
158
+ heloo!
159
+ hi, whos this
160
+ hay
161
+ helloo??
162
+ hi nanocat
163
+ yo! any1 here?
164
+ hi, what r u
165
+ helo nanochat
166
+ hai!
167
+ sup bot?
168
+ heyy
169
+ hi! u there
170
+ helllo nano
171
+ yo nanochta
172
+ hi im bored
173
+ heyyo
174
+ heyyy
175
+ wassup
176
+ yo lol
177
+ hiii
178
+ hiyaaa
179
+ sup
180
+ heyyoo
181
+ yo wut up
182
+ helloo lol
183
+ yo haha
184
+ hru
185
+ waddup
186
+ heyy :)
187
+ yooo
188
+ yo bro
189
+ haiii
190
+ hey u
191
+ yo whats gud
192
+ yo lolol
193
+ HI
194
+ HELLOOO
195
+ YO!!!
196
+ HEY
197
+ SUP
198
+ WASSUP
199
+ HEY!!!
200
+ YO BRO
201
+ HELLO??
202
+ HI THERE!!
203
+ YO WHATS UP
204
+ HEY U
205
+ HEYOOOO
206
+ YO LOL
207
+ HIII
208
+ HIYA
209
+ YOOOO
210
+ HELLO!!!
211
+ SUPPPP
212
+ HEY MAN
213
+ hola
214
+ bonjour
215
+ ciao
216
+ hallo
217
+ hej
218
+ hei
219
+ こんにちは
220
+ 안녕
221
+ 你好
222
+ привет
223
+ salut
224
+ hola amigo
225
+ guten tag
226
+ shalom
227
+ merhaba
228
+ namaste
229
+ ciao bella
230
+ sawasdee
231
+ saludos
232
+ ola
233
+ buongiorno
234
+ aloha
235
+ czesc
236
+ servus
237
+ ahoj
238
+ hei hei
239
+ salve
240
+ hola qué tal
241
+ buenas
242
+ bom dia
243
+ добрый день
244
+ γειά σου
245
+ selam
246
+ halo
247
+ sveiki
248
+ kamusta
249
+ שלום
250
+ مرحبا
251
+ สวัสดีครับ
252
+ xin chào
253
+ como estas
254
+ ça va?
255
+ wie geht’s
256
+ tudo bem?
257
+ 你好吗
258
+ annyeong haseyo
259
+ konnichiwa, genki?
260
+ hola, qué haces
261
+ bonjour tout le monde
262
+ privet kak dela
263
+ ciao come stai
264
+ hei miten menee
265
+ ola tudo bom
266
+ salut, ça roule?
267
+ namaste, kaise ho
268
+ merhaba nasılsın
269
+ hola hola, todo bien?
270
+ hej, hur är läget
271
+ ahoj, jak se máš
272
+ γειά, τι κάνεις
273
+ """.strip().split("\n")
274
+
275
+ prompt = prompt.replace("%README%", readme)
276
+
277
+ # Define the JSON schema for structured output
278
+ response_format = {
279
+ "type": "json_schema",
280
+ "json_schema": {
281
+ "name": "conversation",
282
+ "strict": True,
283
+ "schema": {
284
+ "type": "object",
285
+ "properties": {
286
+ "messages": {
287
+ "type": "array",
288
+ "description": "A list of conversation messages alternating between user and assistant, with the first message being a user message",
289
+ "items": {
290
+ "type": "object",
291
+ "properties": {
292
+ "role": {
293
+ "type": "string",
294
+ "description": "The role of the speaker, either 'user' or 'assistant'"
295
+ },
296
+ "content": {
297
+ "type": "string",
298
+ "description": "The message content"
299
+ }
300
+ },
301
+ "required": ["role", "content"],
302
+ "additionalProperties": False
303
+ }
304
+ }
305
+ },
306
+ "required": ["messages"],
307
+ "additionalProperties": False
308
+ }
309
+ }
310
+ }
311
+
312
+ # Sadly it doesn't seem like Chat completions support `n`
313
+ # to generate multiple completions per prompt.
314
+ base_payload = {
315
+ "model": "google/gemini-2.5-flash",
316
+ "stream": False,
317
+ "response_format": response_format,
318
+ "temperature": 1.0,
319
+ }
320
+
321
+ def generate_conversation(idx: int):
322
+ """
323
+ Generate a single conversation using the OpenRouter API.
324
+ Returns a list of message dicts with 'role' and 'content' keys.
325
+ """
326
+
327
+ # pick 5 example user first messages and insert them into prompt as inspiration
328
+ rng = random.Random(idx) # use idx as seed to the rng
329
+ user_first_prompt = "\n".join(rng.choice(user_first_prompts) for _ in range(5))
330
+ payload = copy.deepcopy(base_payload)
331
+ modified_prompt = prompt.replace("%USER_FIRST_PROMPTS%", user_first_prompt)
332
+ payload['messages'] = [{"role": "user", "content": modified_prompt}]
333
+
334
+ response = requests.post(url, headers=headers, json=payload)
335
+ result = response.json()
336
+ content = result['choices'][0]['message']['content']
337
+
338
+ # Parse the JSON response and unpack the messages
339
+ conversation_data = json.loads(content)
340
+ messages = conversation_data['messages']
341
+
342
+ return messages
343
+
344
+
345
+ # Configuration
346
+ num_conversations = 1000
347
+ num_workers = 4
348
+
349
+ output_file = os.path.join(get_base_dir(), "identity_conversations.jsonl")
350
+ # Wipe the file clean first to reset it
351
+ if os.path.exists(output_file):
352
+ os.remove(output_file)
353
+ print(f"Saving to {output_file}")
354
+
355
+ # Use ThreadPoolExecutor to generate conversations in parallel
356
+ print(f"Generating {num_conversations} conversations with {num_workers} workers...")
357
+ completed_count = 0
358
+ error_count = 0
359
+ with ThreadPoolExecutor(max_workers=num_workers) as executor:
360
+
361
+ # Submit all tasks
362
+ futures = [executor.submit(generate_conversation, idx) for idx in range(num_conversations)]
363
+
364
+ # Process results as they complete
365
+ for future in as_completed(futures):
366
+ try:
367
+ messages = future.result()
368
+
369
+ # Lightly validate the conversation structure
370
+ for i, message in enumerate(messages):
371
+ expected_role = "user" if i % 2 == 0 else "assistant"
372
+ assert message['role'] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}"
373
+
374
+ # If all looks good, write the messages to file
375
+ with open(output_file, 'a') as f:
376
+ f.write(json.dumps(messages) + '\n')
377
+ completed_count += 1
378
+ print(f"✓ Saved conversation {completed_count}/{num_conversations}")
379
+
380
+ except Exception as e:
381
+ error_count += 1
382
+ print(f"✗ Error generating conversation: {e}")
383
+
384
+ print(f"\nDone! Successfully saved {completed_count} conversations to {output_file}")
385
+ if error_count > 0:
386
+ print(f"Encountered {error_count} errors during generation")
387
+
dev/generate_logo.html ADDED
@@ -0,0 +1,29 @@
1
+ <!DOCTYPE html>
2
+ <html>
3
+ <body style="margin:0; display:flex; justify-content:center; align-items:center; height:100vh; background:#fff">
4
+ <svg width="400" height="400" xmlns="http://www.w3.org/2000/svg">
5
+ <defs>
6
+ <radialGradient id="g" cx="50%" cy="50%">
7
+ <stop offset="0%" style="stop-color:#667eea;stop-opacity:1"/>
8
+ <stop offset="100%" style="stop-color:#764ba2;stop-opacity:0.3"/>
9
+ </radialGradient>
10
+ </defs>
11
+ </svg>
12
+ <script>
13
+ const svg = document.querySelector('svg');
14
+ const r = 120;
15
+ let path = '';
16
+ for(let i = 0; i < 24; i += 2) {
17
+ let a1 = i * Math.PI / 12;
18
+ let a2 = (i + 1) * Math.PI / 12;
19
+ let x2 = 200 + Math.cos(a2) * r;
20
+ let y2 = 200 + Math.sin(a2) * r;
21
+ let x3 = 200 + Math.cos(a2) * (r - 90);
22
+ let y3 = 200 + Math.sin(a2) * (r - 90);
23
+ path += `M${x2},${y2} L${x3},${y3} `;
24
+ }
25
+ svg.innerHTML += `<path d="${path}" stroke="url(#g)" stroke-width="6" stroke-linecap="round" fill="none"/>`;
26
+ svg.innerHTML += `<path d="M200,-12 L212,0 L200,12 L188,0 Z" transform="translate(0,200)" fill="#000"/>`;
27
+ </script>
28
+ </body>
29
+ </html>
dev/nanochat.png ADDED
dev/repackage_data_reference.py ADDED
@@ -0,0 +1,92 @@
1
+ """
2
+ Repackage the FinewebEdu-100B dataset into shards:
3
+
4
+ - each shard is ~100MB in size (after zstd compression)
5
+ - parquets are written with row group size of 1000
6
+ - shuffle the dataset
7
+
8
+ This will be uploaded to HuggingFace for hosting.
9
+ The big deal is that our DataLoader will be able to stream
10
+ the data and cache it along the way on disk, decreasing the
11
+ training latency.
12
+
13
+ NOTE: This file is meant only as reference/documentation of the
14
+ dataset preparation and it is not used during the project runtime.
15
+ """
16
+ import os
17
+ import time
18
+
19
+ from datasets import load_dataset
20
+ import pyarrow.parquet as pq
21
+ import pyarrow as pa
22
+
23
+ # Source dataset
24
+ dataset_kwargs = {
25
+ "path": "HuggingFaceFW/fineweb-edu",
26
+ "split": "train",
27
+ "name": "sample-100BT", # ~100B GPT-2 tokens at ~3 chars/token => ~300B chars total
28
+ }
29
+ ds = load_dataset(**dataset_kwargs)
30
+
31
+ # Shuffle to scramble the order
32
+ ds = ds.shuffle(seed=42)
33
+ ndocs = len(ds) # total number of documents to process
34
+ print(f"Total number of documents: {ndocs}")
35
+
36
+ # Repackage into parquet files
37
+ output_dir = "/home/ubuntu/.cache/nanochat/base_data"
38
+ os.makedirs(output_dir, exist_ok=True)
39
+
40
+ # Write to parquet files
41
+ chars_per_shard = 250_000_000
42
+ row_group_size = 1024 # HF uses 1000 but we use a power of 2, nicer for the distributed data loader later
43
+ shard_docs = []
44
+ shard_index = 0
45
+ shard_characters = 0
46
+ total_docs_processed = 0
47
+ total_time_spent = 0
48
+ t0 = time.time()
49
+ for doc in ds:
50
+ text = doc['text']
51
+ shard_docs.append(text)
52
+ shard_characters += len(text)
53
+ collected_enough_chars = shard_characters >= chars_per_shard
54
+ docs_multiple_of_row_group_size = len(shard_docs) % row_group_size == 0
55
+ if collected_enough_chars and docs_multiple_of_row_group_size: # leads to ~100MB of text (compressed)
56
+ shard_path = os.path.join(output_dir, f"shard_{shard_index:05d}.parquet")
57
+ shard_table = pa.Table.from_pydict({"text": shard_docs})
58
+ pq.write_table(
59
+ shard_table,
60
+ shard_path,
61
+ row_group_size=row_group_size,
62
+ use_dictionary=False, # this is usually used for categorical data
63
+ compression="zstd", # Valid values: {‘NONE’, ‘SNAPPY’, ‘GZIP’, ‘BROTLI’, ‘LZ4’, ‘ZSTD’}
64
+ compression_level=3,
65
+ write_statistics=False, # not needed for text
66
+ )
67
+ t1 = time.time()
68
+ dt = t1 - t0 # for this shard alone
69
+ t0 = t1
70
+ total_docs_processed += len(shard_docs)
71
+ total_time_spent += dt
72
+ remaining_docs = ndocs - total_docs_processed
73
+ avg_time_per_doc = total_time_spent / total_docs_processed
74
+ remaining_time = remaining_docs * avg_time_per_doc
75
+ remaining_time_hours = remaining_time / 3600
76
+ print(f"Wrote {shard_path}. #documents: {len(shard_docs)} | #characters: {shard_characters} | time: {dt:.2f}s | remaining time: {remaining_time_hours:.2f}h")
77
+ shard_docs = []
78
+ shard_characters = 0
79
+ shard_index += 1
80
+
81
+ # Demonstration of how the data was later uploaded to HuggingFace
82
+ def upload():
83
+ import os
84
+ from huggingface_hub import HfApi
85
+ token = os.getenv("HF_TOKEN")
86
+ api = HfApi(token=token)
87
+ api.upload_large_folder(
88
+ folder_path=output_dir,
89
+ repo_id="karpathy/fineweb-edu-100b-shuffle",
90
+ repo_type="dataset",
91
+ )
92
+ # upload()
dev/runcpu.sh ADDED
@@ -0,0 +1,77 @@
1
+ #!/bin/bash
2
+
3
+ # Showing an example run for exercising some of the code paths on the CPU (or MPS on Macbooks)
4
+ # Run as:
5
+ # bash dev/runcpu.sh
6
+
7
+ # NOTE: Training LLMs requires GPU compute and $$$. You will not get far on your Macbook.
8
+ # Think of this run as educational/fun demo, not something you should expect to work well.
9
+ # This is also why I hide this script away in dev/
10
+
11
+ # all the setup stuff
12
+ export OMP_NUM_THREADS=1
13
+ export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
14
+ mkdir -p $NANOCHAT_BASE_DIR
15
+ command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
16
+ [ -d ".venv" ] || uv venv
17
+ uv sync --extra cpu
18
+ source .venv/bin/activate
19
+ if [ -z "$WANDB_RUN" ]; then
20
+ WANDB_RUN=dummy
21
+ fi
22
+ curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
23
+ source "$HOME/.cargo/env"
24
+ uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
25
+
26
+ # wipe the report
27
+ python -m nanochat.report reset
28
+
29
+ # train tokenizer on ~1B characters
30
+ python -m nanochat.dataset -n 4
31
+ python -m scripts.tok_train --max_chars=1000000000
32
+ python -m scripts.tok_eval
33
+
34
+ # train a very small 4 layer model on the CPU
35
+ # each optimization step processes a single sequence of 1024 tokens
36
+ # we only run 50 steps of optimization (bump this to get better results)
37
+ python -m scripts.base_train \
38
+ --depth=4 \
39
+ --max_seq_len=1024 \
40
+ --device_batch_size=1 \
41
+ --total_batch_size=1024 \
42
+ --eval_every=50 \
43
+ --eval_tokens=4096 \
44
+ --core_metric_every=50 \
45
+ --core_metric_max_per_task=12 \
46
+ --sample_every=50 \
47
+ --num_iterations=50
48
+ python -m scripts.base_loss --device_batch_size=1 --split_tokens=4096
49
+ python -m scripts.base_eval --max-per-task=16
50
+
51
+ # midtraining
52
+ python -m scripts.mid_train \
53
+ --max_seq_len=1024 \
54
+ --device_batch_size=1 \
55
+ --eval_every=50 \
56
+ --eval_tokens=4096 \
57
+ --total_batch_size=1024 \
58
+ --num_iterations=100
59
+ # eval results will be terrible, this is just to execute the code paths.
60
+ # note that we lower the execution memory limit to 1MB to avoid warnings on smaller systems
61
+ python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20
62
+
63
+ # SFT
64
+ python -m scripts.chat_sft \
65
+ --device_batch_size=1 \
66
+ --target_examples_per_step=4 \
67
+ --num_iterations=100 \
68
+ --eval_steps=4 \
69
+ --eval_metrics_max_problems=16
70
+
71
+ # Chat CLI
72
+ # python -m scripts.chat_cli -p "Why is the sky blue?"
73
+
74
+ # Chat Web
75
+ # python -m scripts.chat_web
76
+
77
+ python -m nanochat.report generate
nanochat/__init__.py ADDED
File without changes
nanochat/adamw.py ADDED
@@ -0,0 +1,76 @@
1
+ """
2
+ Borrowed from modded-nanogpt. By Keller, @vagrawal, et al.
3
+ Not a general optimizer! But works for our specific use.
4
+ """
5
+ import torch
6
+ import torch.distributed as dist
7
+ from torch import Tensor
8
+
9
+
10
+ class DistAdamW(torch.optim.Optimizer):
11
+ """
12
+ Distributed AdamW optimizer.
13
+ In the style of ZeRO-2, i.e. sharded optimizer states and gradient reduction
14
+ """
15
+ def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
16
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
17
+ super().__init__(param_groups, defaults)
18
+
19
+ @torch.compile
20
+ @torch.no_grad()
21
+ def step(self):
22
+ rank = dist.get_rank()
23
+ world_size = dist.get_world_size()
24
+ reduce_scatter_futures: list[torch.Future] = []
25
+ all_reduce_futures: list[torch.Future] = []
26
+ grad_slices = []
27
+ for group in self.param_groups:
28
+ params: list[Tensor] = group["params"]
29
+ for base_i in range(len(params)):
30
+ grad = params[base_i].grad
31
+ rank_size = grad.shape[0] // world_size
32
+ grad_slice = torch.empty_like(grad[:rank_size])
33
+ reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
34
+ grad_slices.append(grad_slice)
35
+
36
+ idx = 0
37
+ for group in self.param_groups:
38
+ beta1, beta2 = group['betas']
39
+ eps = group['eps']
40
+ wd = group['weight_decay']
41
+ params = group['params']
42
+ for base in range(len(params)):
43
+ reduce_scatter_futures[idx].wait()
44
+ p = params[base]
45
+ rank_size = p.shape[0] // world_size
46
+ p_slice = p[rank * rank_size:(rank + 1) * rank_size]
47
+ lr = group['lr'] * getattr(p, "lr_mul", 1.0)
48
+ state = self.state[p]
49
+ g_slice = grad_slices[idx]
50
+ # State init
51
+ if not state:
52
+ state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device)
53
+ state['exp_avg'] = torch.zeros_like(p_slice)
54
+ state['exp_avg_sq'] = torch.zeros_like(p_slice)
55
+ exp_avg = state['exp_avg']
56
+ exp_avg_sq = state['exp_avg_sq']
57
+ state['step'] += 1
58
+ t = state['step']
59
+ # weight decay
60
+ if wd != 0:
61
+ eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0)
62
+ p_slice.mul_(1 - eff_weight_decay)
63
+ # update running averages
64
+ exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1)
65
+ exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2)
66
+ # bias corrections
67
+ bias1 = 1 - beta1 ** t
68
+ bias2 = 1 - beta2 ** t
69
+ # compute step
70
+ denom = exp_avg_sq.sqrt().add_(eps)
71
+ step_size = lr * (torch.sqrt(bias2) / bias1)
72
+ update = exp_avg.div(denom).mul_(step_size)
73
+ p_slice.add_(other=update, alpha=-1.0)
74
+ idx += 1
75
+ all_reduce_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
76
+ torch.futures.collect_all(all_reduce_futures).wait()
nanochat/checkpoint_manager.py ADDED
@@ -0,0 +1,152 @@
1
+ """
2
+ Utilities for saving and loading model/optim/state checkpoints.
3
+ """
4
+ import os
5
+ import re
6
+ import glob
7
+ import json
8
+ import logging
9
+ import torch
10
+
11
+ from nanochat.common import get_base_dir
12
+ from nanochat.gpt import GPT, GPTConfig
13
+ from nanochat.tokenizer import get_tokenizer
14
+ from nanochat.common import setup_default_logging
15
+
16
+ # Set up logging
17
+ setup_default_logging()
18
+ logger = logging.getLogger(__name__)
19
+ def log0(message):
20
+ if int(os.environ.get('RANK', 0)) == 0:
21
+ logger.info(message)
22
+
23
+ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data):
24
+ assert int(os.environ.get('RANK', 0)) == 0 # prevent footguns for now
25
+ os.makedirs(checkpoint_dir, exist_ok=True)
26
+ # Save the model state (parameters)
27
+ model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
28
+ torch.save(model_data, model_path)
29
+ log0(f"Saved model file to: {model_path}")
30
+ # Save the optimizer state (useful for SFT or any other fine-tuning)
31
+ if optimizer_data is not None:
32
+ optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
33
+ torch.save(optimizer_data, optimizer_path)
34
+ log0(f"Saved optimizer file to: {optimizer_path}")
35
+ # Save the metadata dict as json
36
+ meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
37
+ with open(meta_path, "w", encoding="utf-8") as f:
38
+ json.dump(meta_data, f, indent=2)
39
+ log0(f"Saved metadata file to: {meta_path}")
40
+
41
+
42
+ def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False):
43
+ # Load the model state
44
+ model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
45
+ model_data = torch.load(model_path, map_location=device)
46
+ # Load the optimizer state if requested
47
+ optimizer_data = None
48
+ if load_optimizer:
49
+ optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
50
+ optimizer_data = torch.load(optimizer_path, map_location=device)
51
+ # Load the metadata
52
+ meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
53
+ with open(meta_path, "r", encoding="utf-8") as f:
54
+ meta_data = json.load(f)
55
+ return model_data, optimizer_data, meta_data
56
+
57
+
58
+ def build_model(checkpoint_dir, step, device, phase):
59
+ """
60
+ A bunch of repetitive code to build a model from a given checkpoint.
61
+ Returns:
62
+ - base model - uncompiled, not wrapped in DDP
63
+ - tokenizer
64
+ - meta data saved during base model training
65
+ """
66
+ assert phase in ["train", "eval"], f"Invalid phase: {phase}"
67
+ model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
68
+ if device.type in {"cpu", "mps"}:
69
+ # Convert bfloat16 tensors to float for CPU inference
70
+ model_data = {
71
+ k: v.float() if v.dtype == torch.bfloat16 else v
72
+ for k, v in model_data.items()
73
+ }
74
+ # Hack: fix torch compile issue, which prepends all keys with _orig_mod.
75
+ model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
76
+ model_config_kwargs = meta_data["model_config"]
77
+ log0(f"Building model with config: {model_config_kwargs}")
78
+ model_config = GPTConfig(**model_config_kwargs)
79
+ with torch.device("meta"):
80
+ model = GPT(model_config)
81
+ # Load the model state
82
+ model.to_empty(device=device)
83
+ model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
84
+ model.load_state_dict(model_data, strict=True, assign=True)
85
+ # Put the model in the right training phase / mode
86
+ if phase == "eval":
87
+ model.eval()
88
+ else:
89
+ model.train()
90
+ # Load the Tokenizer
91
+ tokenizer = get_tokenizer()
92
+ # Sanity check: compatibility between model and tokenizer
93
+ assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"]
94
+ return model, tokenizer, meta_data
95
+
96
+
97
+ def find_largest_model(checkpoint_dir):
98
+ # attempt to guess the model tag: take the biggest model available
99
+ model_tags = [f for f in os.listdir(checkpoint_dir) if os.path.isdir(os.path.join(checkpoint_dir, f))]
100
+ if not model_tags:
101
+ raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
102
+ # 1) normally all model tags are of the form d<number>, try that first:
103
+ candidates = []
104
+ for model_tag in model_tags:
105
+ match = re.match(r"d(\d+)", model_tag)
106
+ if match:
107
+ model_depth = int(match.group(1))
108
+ candidates.append((model_depth, model_tag))
109
+ if candidates:
110
+ candidates.sort(key=lambda x: x[0], reverse=True)
111
+ return candidates[0][1]
112
+ # 2) if that failed, take the most recently updated model:
113
+ model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
114
+ return model_tags[0]
115
+
116
+
117
+ def find_last_step(checkpoint_dir):
118
+ # Look into checkpoint_dir and find model_<step>.pt with the highest step
119
+ checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt"))
120
+ if not checkpoint_files:
121
+ raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
122
+ last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files))
123
+ return last_step
124
+
125
+ # -----------------------------------------------------------------------------
126
+ # convenience functions that take into account nanochat's directory structure
127
+
128
+ def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
129
+ if model_tag is None:
130
+ # guess the model tag by defaulting to the largest model
131
+ model_tag = find_largest_model(checkpoints_dir)
132
+ log0(f"No model tag provided, guessing model tag: {model_tag}")
133
+ checkpoint_dir = os.path.join(checkpoints_dir, model_tag)
134
+ if step is None:
135
+ # guess the step by defaulting to the last step
136
+ step = find_last_step(checkpoint_dir)
137
+ assert step is not None, f"No checkpoints found in {checkpoint_dir}"
138
+ # build the model
139
+ log0(f"Loading model from {checkpoint_dir} with step {step}")
140
+ model, tokenizer, meta_data = build_model(checkpoint_dir, step, device, phase)
141
+ return model, tokenizer, meta_data
142
+
143
+ def load_model(source, *args, **kwargs):
144
+ model_dir = {
145
+ "base": "base_checkpoints",
146
+ "mid": "mid_checkpoints",
147
+ "sft": "chatsft_checkpoints",
148
+ "rl": "chatrl_checkpoints",
149
+ }[source]
150
+ base_dir = get_base_dir()
151
+ checkpoints_dir = os.path.join(base_dir, model_dir)
152
+ return load_model_from_dir(checkpoints_dir, *args, **kwargs)
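For reference, a typical way these helpers are used from a script (a hedged sketch; the "base" source and eval phase mirror the inline test in nanochat/engine.py below):

    import torch
    from nanochat.checkpoint_manager import load_model

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # picks the largest model tag and the last step unless told otherwise
    model, tokenizer, meta = load_model("base", device, phase="eval")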
nanochat/common.py ADDED
@@ -0,0 +1,188 @@
1
+ """
2
+ Common utilities for nanochat.
3
+ """
4
+
5
+ import os
6
+ import re
7
+ import logging
8
+ import urllib.request
9
+ import torch
10
+ import torch.distributed as dist
11
+ from filelock import FileLock
12
+
13
+ class ColoredFormatter(logging.Formatter):
14
+ """Custom formatter that adds colors to log messages."""
15
+ # ANSI color codes
16
+ COLORS = {
17
+ 'DEBUG': '\033[36m', # Cyan
18
+ 'INFO': '\033[32m', # Green
19
+ 'WARNING': '\033[33m', # Yellow
20
+ 'ERROR': '\033[31m', # Red
21
+ 'CRITICAL': '\033[35m', # Magenta
22
+ }
23
+ RESET = '\033[0m'
24
+ BOLD = '\033[1m'
25
+ def format(self, record):
26
+ # Add color to the level name
27
+ levelname = record.levelname
28
+ if levelname in self.COLORS:
29
+ record.levelname = f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}"
30
+ # Format the message
31
+ message = super().format(record)
32
+ # Add color to specific parts of the message
33
+ if levelname == 'INFO':
34
+ # Highlight numbers and percentages
35
+ message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message)
36
+ message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message)
37
+ return message
38
+
39
+ def setup_default_logging():
40
+ handler = logging.StreamHandler()
41
+ handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
42
+ logging.basicConfig(
43
+ level=logging.INFO,
44
+ handlers=[handler]
45
+ )
46
+
47
+ setup_default_logging()
48
+ logger = logging.getLogger(__name__)
49
+
50
+ def get_base_dir():
51
+ # co-locate nanochat intermediates with other cached data in ~/.cache (by default)
52
+ if os.environ.get("NANOCHAT_BASE_DIR"):
53
+ nanochat_dir = os.environ.get("NANOCHAT_BASE_DIR")
54
+ else:
55
+ home_dir = os.path.expanduser("~")
56
+ cache_dir = os.path.join(home_dir, ".cache")
57
+ nanochat_dir = os.path.join(cache_dir, "nanochat")
58
+ os.makedirs(nanochat_dir, exist_ok=True)
59
+ return nanochat_dir
60
+
61
+ def download_file_with_lock(url, filename, postprocess_fn=None):
62
+ """
63
+ Downloads a file from a URL to a local path in the base directory.
64
+ Uses a lock file to prevent concurrent downloads among multiple ranks.
65
+ """
66
+ base_dir = get_base_dir()
67
+ file_path = os.path.join(base_dir, filename)
68
+ lock_path = file_path + ".lock"
69
+
70
+ if os.path.exists(file_path):
71
+ return file_path
72
+
73
+ with FileLock(lock_path):
74
+ # Only a single rank can acquire this lock
75
+ # All other ranks block until it is released
76
+
77
+ # Recheck after acquiring lock
78
+ if os.path.exists(file_path):
79
+ return file_path
80
+
81
+ # Download the content as bytes
82
+ print(f"Downloading {url}...")
83
+ with urllib.request.urlopen(url) as response:
84
+ content = response.read() # bytes
85
+
86
+ # Write to local file
87
+ with open(file_path, 'wb') as f:
88
+ f.write(content)
89
+ print(f"Downloaded to {file_path}")
90
+
91
+ # Run the postprocess function if provided
92
+ if postprocess_fn is not None:
93
+ postprocess_fn(file_path)
94
+
95
+ return file_path
96
+
97
+ def print0(s="",**kwargs):
98
+ ddp_rank = int(os.environ.get('RANK', 0))
99
+ if ddp_rank == 0:
100
+ print(s, **kwargs)
101
+
102
+ def print_banner():
103
+ # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/
104
+ banner = """
105
+ █████ █████
106
+ ░░███ ░░███
107
+ ████████ ██████ ████████ ██████ ██████ ░███████ ██████ ███████
108
+ ░░███░░███ ░░░░░███ ░░███░░███ ███░░███ ███░░███ ░███░░███ ░░░░░███░░░███░
109
+ ░███ ░███ ███████ ░███ ░███ ░███ ░███░███ ░░░ ░███ ░███ ███████ ░███
110
+ ░███ ░███ ███░░███ ░███ ░███ ░███ ░███░███ ███ ░███ ░███ ███░░███ ░███ ███
111
+ ████ █████░░████████ ████ █████░░██████ ░░██████ ████ █████░░███████ ░░█████
112
+ ░░░░ ░░░░░ ░░░░░░░░ ░░░░ ░░░░░ ░░░░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░░░░ ░░░░░
113
+ """
114
+ print0(banner)
115
+
116
+ def is_ddp():
117
+ # TODO is there a proper way
118
+ return int(os.environ.get('RANK', -1)) != -1
119
+
120
+ def get_dist_info():
121
+ if is_ddp():
122
+ assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
123
+ ddp_rank = int(os.environ['RANK'])
124
+ ddp_local_rank = int(os.environ['LOCAL_RANK'])
125
+ ddp_world_size = int(os.environ['WORLD_SIZE'])
126
+ return True, ddp_rank, ddp_local_rank, ddp_world_size
127
+ else:
128
+ return False, 0, 0, 1
129
+
130
+ def autodetect_device_type():
131
+ # prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
132
+ if torch.cuda.is_available():
133
+ device_type = "cuda"
134
+ elif torch.backends.mps.is_available():
135
+ device_type = "mps"
136
+ else:
137
+ device_type = "cpu"
138
+ print0(f"Autodetected device type: {device_type}")
139
+ return device_type
140
+
141
+ def compute_init(device_type="cuda"): # cuda|cpu|mps
142
+ """Basic initialization that we keep doing over and over, so make common."""
143
+
144
+ assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
145
+ if device_type == "cuda":
146
+ assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
147
+ if device_type == "mps":
148
+ assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
149
+
150
+ # Reproducibility
151
+ torch.manual_seed(42)
152
+ if device_type == "cuda":
153
+ torch.cuda.manual_seed(42)
154
+ # skipping full reproducibility for now, possibly investigate slowdown later
155
+ # torch.use_deterministic_algorithms(True)
156
+
157
+ # Precision
158
+ if device_type == "cuda":
159
+ torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
160
+
161
+ # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
162
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
163
+ if ddp and device_type == "cuda":
164
+ device = torch.device("cuda", ddp_local_rank)
165
+ torch.cuda.set_device(device) # make "cuda" default to this device
166
+ dist.init_process_group(backend="nccl", device_id=device)
167
+ dist.barrier()
168
+ else:
169
+ device = torch.device(device_type) # mps|cpu
170
+
171
+ if ddp_rank == 0:
172
+ logger.info(f"Distributed world size: {ddp_world_size}")
173
+
174
+ return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device
175
+
176
+ def compute_cleanup():
177
+ """Companion function to compute_init, to clean things up before script exit"""
178
+ if is_ddp():
179
+ dist.destroy_process_group()
180
+
181
+ class DummyWandb:
182
+ """Useful if we wish to not use wandb but have all the same signatures"""
183
+ def __init__(self):
184
+ pass
185
+ def log(self, *args, **kwargs):
186
+ pass
187
+ def finish(self):
188
+ pass
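A sketch of how a training script typically drives the compute helpers above, whether launched as a single process or under torchrun (illustrative only):

    from nanochat.common import autodetect_device_type, compute_init, compute_cleanup, print0

    device_type = autodetect_device_type()  # "cuda" | "mps" | "cpu"
    ddp, rank, local_rank, world_size, device = compute_init(device_type)
    print0(f"running on {device}, world size {world_size}")
    # ... training loop ...
    compute_cleanup()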
nanochat/configurator.py ADDED
@@ -0,0 +1,56 @@
1
+ """
2
+ Poor Man's Configurator. Probably a terrible idea. Example usage:
3
+ $ python train.py config/override_file.py --batch_size=32
4
+ this will first run config/override_file.py, then override batch_size to 32
5
+
6
+ The code in this file will be run as follows from e.g. train.py:
7
+ >>> exec(open('configurator.py').read())
8
+
9
+ So it's not a Python module, it's just shuttling this code away from train.py
10
+ The code in this script then overrides the globals()
11
+
12
+ I know people are not going to love this, I just really dislike configuration
13
+ complexity and having to prepend config. to every single variable. If someone
14
+ comes up with a better simple Python solution I am all ears.
15
+ """
16
+
17
+ import os
18
+ import sys
19
+ from ast import literal_eval
20
+
21
+ def print0(s="",**kwargs):
22
+ ddp_rank = int(os.environ.get('RANK', 0))
23
+ if ddp_rank == 0:
24
+ print(s, **kwargs)
25
+
26
+ for arg in sys.argv[1:]:
27
+ if '=' not in arg:
28
+ # assume it's the name of a config file
29
+ assert not arg.startswith('--')
30
+ config_file = arg
31
+ print0(f"Overriding config with {config_file}:")
32
+ with open(config_file) as f:
33
+ print0(f.read())
34
+ exec(open(config_file).read())
35
+ else:
36
+ # assume it's a --key=value argument
37
+ assert arg.startswith('--')
38
+ key, val = arg.split('=')
39
+ key = key[2:]
40
+ if key in globals():
41
+ try:
42
+ # attempt to eval it (e.g. if bool, number, etc.)
43
+ attempt = literal_eval(val)
44
+ except (SyntaxError, ValueError):
45
+ # if that goes wrong, just use the string
46
+ attempt = val
47
+ # ensure the types match ok
48
+ if globals()[key] is not None:
49
+ attempt_type = type(attempt)
50
+ default_type = type(globals()[key])
51
+ assert attempt_type == default_type, f"Type mismatch: {attempt_type} != {default_type}"
52
+ # cross fingers
53
+ print0(f"Overriding: {key} = {attempt}")
54
+ globals()[key] = attempt
55
+ else:
56
+ raise ValueError(f"Unknown config key: {key}")
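As the docstring above explains, this file is exec'd from a training script so that the script's module-level globals can be overridden from the command line. A hedged sketch of the calling side (the variable names are illustrative):

    # train.py (sketch)
    batch_size = 64            # defaults, overridable from the CLI
    learning_rate = 3e-4
    exec(open('nanochat/configurator.py').read())  # may rebind the globals above
    print(f"batch_size={batch_size}, learning_rate={learning_rate}")
    # e.g. `python train.py --batch_size=32` sets batch_size to 32 and leaves learning_rate alone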
nanochat/core_eval.py ADDED
@@ -0,0 +1,262 @@
1
+ """
2
+ Functions for evaluating the CORE metric, as described in the DCLM paper.
3
+ https://arxiv.org/abs/2406.11794
4
+
5
+ TODOs:
6
+ - All tasks ~match except for squad. We get 31% reference is 37%. Figure out why.
7
+ """
8
+ import random
9
+
10
+ from jinja2 import Template
11
+ import torch
12
+ import torch.distributed as dist
13
+
14
+ # -----------------------------------------------------------------------------
15
+ # Prompt rendering utilities
16
+
17
+ def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None):
18
+ """Render complete prompts for a multiple choice question"""
19
+ template_str = """
20
+ {%- for example in fewshot_examples -%}
21
+ {{ example.query }}{{ continuation_delimiter }}{{ example.choices[example.gold] }}
22
+
23
+ {% endfor -%}
24
+ {{ item.query }}{{ continuation_delimiter }}{{ choice }}""".strip()
25
+ template = Template(template_str)
26
+ fewshot_examples = fewshot_examples or []
27
+ context = {
28
+ 'fewshot_examples': fewshot_examples,
29
+ 'continuation_delimiter': continuation_delimiter,
30
+ 'item': item
31
+ }
32
+ prompts = [template.render(choice=choice, **context) for choice in item['choices']]
33
+ return prompts
34
+
35
+
36
+ def render_prompts_schema(item, continuation_delimiter, fewshot_examples=None):
37
+ """Render complete prompts for a schema question"""
38
+ template_str = """
39
+ {%- for example in fewshot_examples -%}
40
+ {{ example.context_options[example.gold] }}{{ continuation_delimiter }}{{ example.continuation }}
41
+
42
+ {% endfor -%}
43
+ {{ context }}{{ continuation_delimiter }}{{ item.continuation }}""".strip()
44
+ template = Template(template_str)
45
+ fewshot_examples = fewshot_examples or []
46
+ context = {
47
+ 'fewshot_examples': fewshot_examples,
48
+ 'continuation_delimiter': continuation_delimiter,
49
+ 'item': item
50
+ }
51
+ prompts = [template.render(context=context_option, **context)
52
+ for context_option in item['context_options']]
53
+ return prompts
54
+
55
+
56
+ def render_prompts_lm(item, continuation_delimiter, fewshot_examples=None):
57
+ """
58
+ Render complete prompt for a language modeling task.
59
+ Notice that we manually trim the context in the template,
60
+ which in some datasets seems to have trailing whitespace (which we don't want).
61
+ """
62
+ template_str = """
63
+ {%- for example in fewshot_examples -%}
64
+ {{ example.context | trim }}{{ continuation_delimiter }}{{ example.continuation }}
65
+
66
+ {% endfor -%}
67
+ {{ item.context | trim }}{{ continuation_delimiter }}{% if include_continuation %}{{ item.continuation }}{% endif %}""".strip()
68
+ template = Template(template_str)
69
+ fewshot_examples = fewshot_examples or []
70
+ context = {
71
+ 'fewshot_examples': fewshot_examples,
72
+ 'continuation_delimiter': continuation_delimiter,
73
+ 'item': item
74
+ }
75
+ # Return two prompts: without and with the continuation
76
+ prompt_without = template.render(include_continuation=False, **context)
77
+ prompt_with = template.render(include_continuation=True, **context)
78
+ # Due to the way the data seems to be stored, I think I need to strip in the case of LM here.
79
+ # Otherwise we may get trailing whitespaces in prompt_without (which get absorbed into the next
80
+ # token in prompt_with), meaning we don't get a nice and clean prefix in the token space
81
+ # to detect the final continuation. Tokenizers...
82
+ prompt_without = prompt_without.strip()
83
+ return [prompt_without, prompt_with]
84
+
85
+
86
+ def find_common_length(token_sequences, direction='left'):
87
+ """
88
+ Find the length of the common prefix or suffix across token sequences
89
+ - direction: 'left' for prefix, 'right' for suffix
90
+ """
91
+ min_len = min(len(seq) for seq in token_sequences)
92
+ indices = {
93
+ 'left': range(min_len),
94
+ 'right': range(-1, -min_len-1, -1)
95
+ }[direction]
96
+ # Find the first position where the token sequences differ
97
+ for i, idx in enumerate(indices):
98
+ token = token_sequences[0][idx]
99
+ if not all(seq[idx] == token for seq in token_sequences):
100
+ return i
101
+ return min_len
102
+
103
+
104
+ def stack_sequences(tokens, pad_token_id):
105
+ """Stack up a list of token sequences, pad to longest on the right"""
106
+ bsz, seq_len = len(tokens), max(len(x) for x in tokens)
107
+ input_ids = torch.full((bsz, seq_len), pad_token_id, dtype=torch.long)
108
+ for i, x in enumerate(tokens):
109
+ input_ids[i, :len(x)] = torch.tensor(x, dtype=torch.long)
110
+ return input_ids
111
+
112
+
113
+ def batch_sequences_mc(tokenizer, prompts):
114
+ # In multiple choice, contexts are the same but the continuation is different (common prefix)
115
+ tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
116
+ # figure out the start and end of each continuation
117
+ answer_start_idx = find_common_length(tokens, direction='left')
118
+ start_indices = [answer_start_idx] * len(prompts)
119
+ end_indices = [len(x) for x in tokens]
120
+ return tokens, start_indices, end_indices
121
+
122
+
123
+ def batch_sequences_schema(tokenizer, prompts):
124
+ # In schema tasks, contexts vary but continuation is the same (common suffix)
125
+ tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
126
+ # figure out the start and end of each context
127
+ suffix_length = find_common_length(tokens, direction='right')
128
+ end_indices = [len(x) for x in tokens]
129
+ start_indices = [ei - suffix_length for ei in end_indices]
130
+ return tokens, start_indices, end_indices
131
+
132
+
133
+ def batch_sequences_lm(tokenizer, prompts):
134
+ # In LM tasks, we have two prompts: without and with continuation
135
+ tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
136
+ tokens_without, tokens_with = tokens
137
+ start_idx, end_idx = len(tokens_without), len(tokens_with)
138
+ assert start_idx < end_idx, "prompt without is supposed to be a prefix of prompt with"
139
+ assert tokens_without == tokens_with[:start_idx], "prompt without is supposed to be a prefix of prompt with"
140
+ # we only need the with continuation prompt in the LM task, i.e. batch size of 1
141
+ return [tokens_with], [start_idx], [end_idx]
142
+
143
+
144
+ @torch.no_grad()
145
+ def forward_model(model, input_ids):
146
+ """
147
+ Take BxT tensor of token ids, return BxT tensor of losses and argmax predictions.
148
+ The last column of losses is set to nan because we don't have autoregressive targets there.
149
+ """
150
+ batch_size, seq_len = input_ids.size()
151
+ outputs = model(input_ids)
152
+ # Roll the tensor to the left by one position to get the (autoregressive) target ids
153
+ target_ids = torch.roll(input_ids, shifts=-1, dims=1)
154
+ # Calculate cross entropy at all positions
155
+ losses = torch.nn.functional.cross_entropy(
156
+ outputs.view(batch_size * seq_len, -1),
157
+ target_ids.view(batch_size * seq_len),
158
+ reduction='none'
159
+ ).view(batch_size, seq_len)
160
+ # Set the last column to be nan because there is no autoregressive loss there
161
+ losses[:, -1] = float('nan')
162
+ # Get the argmax predictions at each position
163
+ predictions = outputs.argmax(dim=-1)
164
+ return losses, predictions
165
+
166
+
167
+ @torch.no_grad()
168
+ def evaluate_example(idx, model, tokenizer, data, device, task_meta):
169
+ """Evaluate a single example, return True if correct, False otherwise"""
170
+ item = data[idx]
171
+ task_type = task_meta['task_type']
172
+ num_fewshot = task_meta['num_fewshot']
173
+ continuation_delimiter = task_meta['continuation_delimiter']
174
+
175
+ # Sample few-shot examples (excluding current item)
176
+ fewshot_examples = []
177
+ if num_fewshot > 0:
178
+ rng = random.Random(1234 + idx)
179
+ available_indices = [i for i in range(len(data)) if i != idx]
180
+ fewshot_indices = rng.sample(available_indices, num_fewshot)
181
+ fewshot_examples = [data[i] for i in fewshot_indices]
182
+
183
+ # Render prompts and batch sequences based on task type
184
+ if task_type == 'multiple_choice':
185
+ prompts = render_prompts_mc(item, continuation_delimiter, fewshot_examples)
186
+ tokens, start_idxs, end_idxs = batch_sequences_mc(tokenizer, prompts)
187
+ elif task_type == 'schema':
188
+ prompts = render_prompts_schema(item, continuation_delimiter, fewshot_examples)
189
+ tokens, start_idxs, end_idxs = batch_sequences_schema(tokenizer, prompts)
190
+ elif task_type == 'language_modeling':
191
+ prompts = render_prompts_lm(item, continuation_delimiter, fewshot_examples)
192
+ tokens, start_idxs, end_idxs = batch_sequences_lm(tokenizer, prompts)
193
+ else:
194
+ raise ValueError(f"Unsupported task type: {task_type}")
195
+
196
+ # Some models can't forward sequences beyond a certain length (e.g. GPT-2)
197
+ # In these cases, we have to truncate sequences to max length and adjust the indices
198
+ if hasattr(model, 'max_seq_len') and model.max_seq_len is not None:
199
+ max_tokens = model.max_seq_len
200
+ new_tokens, new_start_idxs, new_end_idxs = [], [], []
201
+ for t, s, e in zip(tokens, start_idxs, end_idxs):
202
+ if len(t) > max_tokens:
203
+ num_to_crop = len(t) - max_tokens
204
+ new_tokens.append(t[-max_tokens:]) # take the last max_tokens tokens
205
+ new_start_idxs.append(s - num_to_crop) # shift the indices down
206
+ new_end_idxs.append(e - num_to_crop)
207
+ assert s - num_to_crop >= 0, "this should never happen right?"
208
+ assert e - num_to_crop >= 0, "this should never happen right?"
209
+ else:
210
+ new_tokens.append(t) # keep unchanged
211
+ new_start_idxs.append(s)
212
+ new_end_idxs.append(e)
213
+ tokens, start_idxs, end_idxs = new_tokens, new_start_idxs, new_end_idxs
214
+
215
+ # Stack up all the sequences into a batch
216
+ pad_token_id = tokenizer.get_bos_token_id() # using the BOS token as the pad token is fine here
217
+ input_ids = stack_sequences(tokens, pad_token_id)
218
+ input_ids = input_ids.to(device)
219
+
220
+ # Forward the model, get the autoregressive loss and argmax prediction at each token
221
+ losses, predictions = forward_model(model, input_ids)
222
+
223
+ # See if the losses/predictions come out correctly
224
+ if task_type == 'language_modeling':
225
+ # language modeling task is currently always batch size 1
226
+ si = start_idxs[0]
227
+ ei = end_idxs[0]
228
+ # predictions[i] predict input_ids[i+1] autoregressively
229
+ predicted_tokens = predictions[0, si-1:ei-1]
230
+ actual_tokens = input_ids[0, si:ei]
231
+ is_correct = torch.all(predicted_tokens == actual_tokens).item()
232
+ elif task_type in ['multiple_choice', 'schema']:
233
+ # For MC/schema: find the option with lowest average loss
234
+ mean_losses = [losses[i, si-1:ei-1].mean().item()
235
+ for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))]
236
+ pred_idx = mean_losses.index(min(mean_losses))
237
+ is_correct = pred_idx == item['gold']
238
+ else:
239
+ raise ValueError(f"Unsupported task type: {task_type}")
240
+
241
+ return is_correct
242
+
243
+
244
+ def evaluate_task(model, tokenizer, data, device, task_meta):
245
+ """
246
+ This function is responsible for evaluating one task across many examples.
247
+ It also handles dispatch to all processes if the script is run with torchrun.
248
+ """
249
+ rank = dist.get_rank() if dist.is_initialized() else 0
250
+ world_size = dist.get_world_size() if dist.is_initialized() else 1
251
+ correct = torch.zeros(len(data), dtype=torch.float32, device=device)
252
+ # stride the examples to each rank
253
+ for idx in range(rank, len(data), world_size):
254
+ is_correct = evaluate_example(idx, model, tokenizer, data, device, task_meta)
255
+ correct[idx] = float(is_correct)
256
+ # sync results across all the processes if running distributed
257
+ if world_size > 1:
258
+ dist.barrier()
259
+ dist.all_reduce(correct, op=dist.ReduceOp.SUM)
260
+ # compute the mean
261
+ mean_correct = correct.mean().item()
262
+ return mean_correct
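A sketch of driving evaluate_task for a single task; the task_meta keys mirror those read in evaluate_example, while model, tokenizer, data and device come from the caller (illustrative only):

    task_meta = {
        "task_type": "multiple_choice",   # or "schema" / "language_modeling"
        "num_fewshot": 5,
        "continuation_delimiter": " ",
    }
    accuracy = evaluate_task(model, tokenizer, data, device, task_meta)
    print(f"accuracy: {accuracy:.4f}")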
nanochat/dataloader.py ADDED
@@ -0,0 +1,49 @@
1
+ from collections import deque
2
+
3
+ import torch
4
+
5
+ from nanochat.common import get_dist_info
6
+ from nanochat.dataset import parquets_iter_batched
7
+ from nanochat.tokenizer import get_tokenizer
8
+
9
+ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda"):
10
+ """Stream pretraining text from parquet files, tokenize, yield training batches."""
11
+ assert split in ["train", "val"], "split must be 'train' or 'val'"
12
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
13
+ needed_tokens = B * T + 1 # +1 is because we also need the target at the last token
14
+ # get the tokenizer and the bos token
15
+ tokenizer = get_tokenizer()
16
+ bos_token = tokenizer.get_bos_token_id()
17
+ # scratch buffer holds the tokens for one iteration
18
+ token_buffer = deque() # we stream tokens on the right and pop from the left
19
+
20
+ # infinite iterator over document batches
21
+ def document_batches():
22
+ while True:
23
+ # batch will iterate in group size of the parquet files, usually e.g. 1024 rows
24
+ for batch in parquets_iter_batched(split=split, start=ddp_rank, step=ddp_world_size):
25
+ # for the tokenizer we might want to go in usually smaller batches, e.g. 128 rows
26
+ for i in range(0, len(batch), tokenizer_batch_size):
27
+ yield batch[i:i+tokenizer_batch_size]
28
+ batches = document_batches()
29
+
30
+ batch_index = 0
31
+ while True:
32
+ # Accumulate enough tokens for one iteration before yielding.
33
+ while len(token_buffer) < needed_tokens:
34
+ doc_batch = next(batches)
35
+ token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
36
+ for tokens in token_lists:
37
+ token_buffer.extend(tokens)
38
+ batch_index += 1
39
+ # Move tokens from the deque into the scratch buffer
40
+ tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
41
+ # CUDA supports memory pinning for faster transfers between CPU and GPU:
42
+ scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=(device == "cuda"))
43
+ # Create the inputs/targets as 1D tensors
44
+ inputs_cpu = scratch[:-1].to(dtype=torch.int32)
45
+ targets_cpu = scratch[1:]
46
+ # Reshape to 2D and move to GPU async
47
+ inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=True)
48
+ targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=True)
49
+ yield inputs, targets
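Since the loader is an infinite generator, a training loop simply pulls batches from it; a minimal sketch (the B/T values are illustrative):

    from nanochat.dataloader import tokenizing_distributed_data_loader

    train_loader = tokenizing_distributed_data_loader(B=8, T=1024, split="train", device="cuda")
    for step in range(10):
        inputs, targets = next(train_loader)  # inputs: (B, T) int32, targets: (B, T) int64
        # loss = model(inputs, targets); loss.backward(); ...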
nanochat/dataset.py ADDED
@@ -0,0 +1,128 @@
1
+ """
2
+ The base/pretraining dataset is a set of parquet files.
3
+ This file contains utilities for:
4
+ - iterating over the parquet files and yielding documents from it
5
+ - download the files on demand if they are not on disk
6
+
7
+ For details of how the dataset was prepared, see `repackage_data_reference.py`.
8
+ """
9
+
10
+ import os
11
+ import argparse
12
+ import time
13
+ import requests
14
+ import pyarrow.parquet as pq
15
+ from multiprocessing import Pool
16
+
17
+ from nanochat.common import get_base_dir
18
+
19
+ # -----------------------------------------------------------------------------
20
+ # The specifics of the current pretraining dataset
21
+
22
+ # The URL on the internet where the data is hosted and downloaded from on demand
23
+ BASE_URL = "https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main"
24
+ MAX_SHARD = 1822 # the last datashard is shard_01822.parquet
25
+ index_to_filename = lambda index: f"shard_{index:05d}.parquet" # format of the filenames
26
+ base_dir = get_base_dir()
27
+ DATA_DIR = os.path.join(base_dir, "base_data")
28
+ os.makedirs(DATA_DIR, exist_ok=True)
29
+
30
+ # -----------------------------------------------------------------------------
31
+ # These functions are useful utilities to other modules, can/should be imported
32
+
33
+ def list_parquet_files(data_dir=None):
34
+ """ Looks into a data dir and returns full paths to all parquet files. """
35
+ data_dir = DATA_DIR if data_dir is None else data_dir
36
+ parquet_files = sorted([
37
+ f for f in os.listdir(data_dir)
38
+ if f.endswith('.parquet') and not f.endswith('.tmp')
39
+ ])
40
+ parquet_paths = [os.path.join(data_dir, f) for f in parquet_files]
41
+ return parquet_paths
42
+
43
+ def parquets_iter_batched(split, start=0, step=1):
44
+ """
45
+ Iterate through the dataset, in batches of underlying row_groups for efficiency.
46
+ - split can be "train" or "val". the last parquet file will be val.
47
+ - start/step are useful for skipping rows in DDP. e.g. start=rank, step=world_size
48
+ """
49
+ assert split in ["train", "val"], "split must be 'train' or 'val'"
50
+ parquet_paths = list_parquet_files()
51
+ parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
52
+ for filepath in parquet_paths:
53
+ pf = pq.ParquetFile(filepath)
54
+ for rg_idx in range(start, pf.num_row_groups, step):
55
+ rg = pf.read_row_group(rg_idx)
56
+ texts = rg.column('text').to_pylist()
57
+ yield texts
58
+
59
+ # -----------------------------------------------------------------------------
60
+ def download_single_file(index):
61
+ """ Downloads a single file index, with some backoff """
62
+
63
+ # Construct the local filepath for this file and skip if it already exists
64
+ filename = index_to_filename(index)
65
+ filepath = os.path.join(DATA_DIR, filename)
66
+ if os.path.exists(filepath):
67
+ print(f"Skipping {filepath} (already exists)")
68
+ return True
69
+
70
+ # Construct the remote URL for this file
71
+ url = f"{BASE_URL}/{filename}"
72
+ print(f"Downloading {filename}...")
73
+
74
+ # Download with retries
75
+ max_attempts = 5
76
+ for attempt in range(1, max_attempts + 1):
77
+ try:
78
+ response = requests.get(url, stream=True, timeout=30)
79
+ response.raise_for_status()
80
+ # Write to temporary file first
81
+ temp_path = filepath + ".tmp"
82
+ with open(temp_path, 'wb') as f:
83
+ for chunk in response.iter_content(chunk_size=1024 * 1024): # 1MB chunks
84
+ if chunk:
85
+ f.write(chunk)
86
+ # Move temp file to final location
87
+ os.rename(temp_path, filepath)
88
+ print(f"Successfully downloaded {filename}")
89
+ return True
90
+
91
+ except (requests.RequestException, IOError) as e:
92
+ print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
93
+ # Clean up any partial files
94
+ for path in [filepath + ".tmp", filepath]:
95
+ if os.path.exists(path):
96
+ try:
97
+ os.remove(path)
98
+ except:
99
+ pass
100
+ # Try a few times with exponential backoff: 2^attempt seconds
101
+ if attempt < max_attempts:
102
+ wait_time = 2 ** attempt
103
+ print(f"Waiting {wait_time} seconds before retry...")
104
+ time.sleep(wait_time)
105
+ else:
106
+ print(f"Failed to download {filename} after {max_attempts} attempts")
107
+ return False
108
+
109
+ return False
110
+
111
+
112
+ if __name__ == "__main__":
113
+ parser = argparse.ArgumentParser(description="Download FineWeb-Edu 100BT dataset shards")
114
+ parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of shards to download (default: -1), -1 = disable")
115
+ parser.add_argument("-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)")
116
+ args = parser.parse_args()
117
+
118
+ num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1)
119
+ ids_to_download = list(range(num))
120
+ print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
121
+ print(f"Target directory: {DATA_DIR}")
122
+ print()
123
+ with Pool(processes=args.num_workers) as pool:
124
+ results = pool.map(download_single_file, ids_to_download)
125
+
126
+ # Report results
127
+ successful = sum(1 for success in results if success)
128
+ print(f"Done! Downloaded: {successful}/{len(ids_to_download)} shards to {DATA_DIR}")
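Besides running this module directly to download shards (e.g. `python -m nanochat.dataset -n 8 -w 4`, assuming the package is importable), the iterator can be used on its own; a short sketch:

    from nanochat.dataset import parquets_iter_batched

    # iterate over validation documents, one row-group-sized batch of texts at a time
    for texts in parquets_iter_batched(split="val"):
        print(len(texts), texts[0][:80])
        break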
nanochat/engine.py ADDED
@@ -0,0 +1,371 @@
1
+ """
2
+ Engine for efficient inference of our models.
3
+
4
+ Everything works around token sequences:
5
+ - The user can send token sequences to the engine
6
+ - The engine returns the next token
7
+
8
+ Notes:
9
+ - The engine knows nothing about tokenization, it's purely token id sequences.
10
+
11
+ The whole thing is made as efficient as possible.
12
+ """
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+ import signal
17
+ import warnings
18
+ from contextlib import contextmanager
19
+ from collections import deque
20
+ from nanochat.common import compute_init
21
+ from nanochat.checkpoint_manager import load_model
22
+
23
+ # -----------------------------------------------------------------------------
24
+ # Calculator tool helpers
25
+ @contextmanager
26
+ def timeout(duration, formula):
27
+ def timeout_handler(signum, frame):
28
+ raise Exception(f"'{formula}': timed out after {duration} seconds")
29
+
30
+ signal.signal(signal.SIGALRM, timeout_handler)
31
+ signal.alarm(duration)
32
+ yield
33
+ signal.alarm(0)
34
+
35
+ def eval_with_timeout(formula, max_time=3):
36
+ try:
37
+ with timeout(max_time, formula):
38
+ with warnings.catch_warnings():
39
+ warnings.simplefilter("ignore", SyntaxWarning)
40
+ return eval(formula, {"__builtins__": {}}, {})
41
+ except Exception as e:
42
+ signal.alarm(0)
43
+ # print(f"Warning: Failed to eval {formula}, exception: {e}") # it's ok ignore wrong calculator usage
44
+ return None
45
+
46
+ def use_calculator(expr):
47
+ """
48
+ Evaluate a Python expression safely.
49
+ Supports both math expressions and string operations like .count()
50
+ """
51
+ # Remove commas from numbers
52
+ expr = expr.replace(",", "")
53
+
54
+ # Check if it's a pure math expression (old behavior)
55
+ if all([x in "0123456789*+-/.() " for x in expr]):
56
+ if "**" in expr: # disallow power operator
57
+ return None
58
+ return eval_with_timeout(expr)
59
+
60
+ # Check if it's a string operation we support
61
+ # Allow: strings (single/double quotes), .count(), letters, numbers, spaces, parens
62
+ allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
63
+ if not all([x in allowed_chars for x in expr]):
64
+ return None
65
+
66
+ # Disallow dangerous patterns
67
+ dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
68
+ 'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
69
+ 'getattr', 'setattr', 'delattr', 'hasattr']
70
+ expr_lower = expr.lower()
71
+ if any(pattern in expr_lower for pattern in dangerous_patterns):
72
+ return None
73
+
74
+ # Only allow .count() method for now (can expand later)
75
+ if '.count(' not in expr:
76
+ return None
77
+
78
+ # Evaluate with timeout
79
+ return eval_with_timeout(expr)
80
+
81
+ # -----------------------------------------------------------------------------
82
+ class KVCache:
83
+ """
84
+ Works hand-in-hand with the GPT model to maintain the KV cache.
85
+ Note that the .pos advances automatically after the last layer of the Transformer inserts.
86
+ """
87
+
88
+ def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
89
+ # Each of K/V is of shape (B, H, T, D) and we have one per layer of the Transformer.
90
+ self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
91
+ self.kv_cache = None
92
+ self.pos = 0 # current position in time in the cache
93
+
94
+ def reset(self):
95
+ self.pos = 0
96
+
97
+ def get_pos(self):
98
+ return self.pos
99
+
100
+ def prefill(self, other):
101
+ """
102
+ Prefill given another KV cache. Optionally expand along batch dim.
103
+ This is used when we do batch 1 prefill and then want to generate
104
+ multiple samples in parallel from there.
105
+ """
106
+ # 1) validate the shapes
107
+ assert self.kv_cache is None, "Cannot prefill a non-empty KV cache"
108
+ assert other.kv_cache is not None, "Cannot prefill with a None KV cache"
109
+ for ix, (dim1, dim2) in enumerate(zip(self.kv_shape, other.kv_shape)):
110
+ if ix in [0, 1, 3, 5]:
111
+ # num_layers, batch_size, num_heads, head_dim must match
112
+ assert dim1 == dim2, f"Dim {ix} mismatch: {dim1} != {dim2}"
113
+ elif ix == 2:
114
+ # batch_size can be expanded
115
+ assert dim1 == dim2 or dim2 == 1, f"Batch dim mismatch: {dim1} != {dim2}"
116
+ elif ix == 4:
117
+ # seq_len: self must be longer than other
118
+ assert dim1 >= dim2, f"Seq len mismatch: {dim1} < {dim2}"
119
+ # 2) initialize the cache
120
+ dtype, device = other.kv_cache.dtype, other.kv_cache.device
121
+ self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)
122
+ # 3) copy the data over
123
+ self.kv_cache[:, :, :, :, :other.pos, :] = other.kv_cache
124
+ # 4) update the pos
125
+ self.pos = other.pos
126
+
127
+ def insert_kv(self, layer_idx, k, v):
128
+ # Lazy initialize the cache here because we need to know the dtype/device
129
+ if self.kv_cache is None:
130
+ self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)
131
+ # Insert new keys/values to the cache and return the full cache so far
132
+ B, H, T_add, D = k.size()
133
+ t0, t1 = self.pos, self.pos + T_add
134
+ # Dynamically grow the cache if needed
135
+ if t1 > self.kv_cache.size(4):
136
+ t_needed = t1 + 1024 # as much as we need plus buffer of 1024
137
+ t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024
138
+ additional_shape = list(self.kv_cache.shape)
139
+ additional_shape[4] = t_needed - self.kv_cache.size(4)
140
+ additional_cache = torch.empty(additional_shape, dtype=k.dtype, device=k.device)
141
+ self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4).contiguous()
142
+ self.kv_shape = self.kv_cache.shape
143
+ # Insert k, v into the cache
144
+ self.kv_cache[layer_idx, 0, :, :, t0:t1] = k
145
+ self.kv_cache[layer_idx, 1, :, :, t0:t1] = v
146
+ # Return the full cached keys/values up to current position (as a view)
147
+ key_view = self.kv_cache[layer_idx, 0, :, :, :t1]
148
+ value_view = self.kv_cache[layer_idx, 1, :, :, :t1]
149
+ # Increment pos after the last layer of the Transformer processes
150
+ if layer_idx == self.kv_cache.size(0) - 1:
151
+ self.pos = t1
152
+ return key_view, value_view
153
+
154
+
155
+ # -----------------------------------------------------------------------------
156
+ @torch.inference_mode()
157
+ def sample_next_token(logits, rng, temperature=1.0, top_k=None):
158
+ """Sample a single next token from given logits of shape (B, vocab_size). Returns (B, 1)."""
159
+ assert temperature >= 0.0, "temperature must be non-negative"
160
+ if temperature == 0.0:
161
+ return torch.argmax(logits, dim=-1, keepdim=True)
162
+ if top_k is not None:
163
+ k = min(top_k, logits.size(-1))
164
+ vals, idx = torch.topk(logits, k, dim=-1)
165
+ vals = vals / temperature
166
+ probs = F.softmax(vals, dim=-1)
167
+ choice = torch.multinomial(probs, num_samples=1, generator=rng)
168
+ return idx.gather(1, choice)
169
+ else:
170
+ logits = logits / temperature
171
+ probs = F.softmax(logits, dim=-1)
172
+ return torch.multinomial(probs, num_samples=1, generator=rng)
173
+
174
+ # -----------------------------------------------------------------------------
175
+
176
+ class RowState:
177
+ # Per-row state tracking during generation
178
+ def __init__(self, current_tokens=None):
179
+ self.current_tokens = current_tokens or [] # Current token sequence for this row
180
+ self.forced_tokens = deque() # Queue of tokens to force inject
181
+ self.in_python_block = False # Whether we are inside a python block
182
+ self.python_expr_tokens = [] # Tokens of the current python expression
183
+ self.completed = False # Whether this row has completed generation
184
+
185
+ class Engine:
186
+
187
+ def __init__(self, model, tokenizer):
188
+ self.model = model
189
+ self.tokenizer = tokenizer # needed for tool use
190
+
191
+ @torch.inference_mode()
192
+ def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
193
+ """Like the naive model.generate, but does a single batch-1 prefill and then clones the KV cache across samples."""
194
+ assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
195
+ device = self.model.get_device()
196
+ rng = torch.Generator(device=device)
197
+ rng.manual_seed(seed)
198
+
199
+ # Get the special tokens we need to coordinate the tool use state machine
200
+ get_special = lambda s: self.tokenizer.encode_special(s)
201
+ python_start = get_special("<|python_start|>")
202
+ python_end = get_special("<|python_end|>")
203
+ output_start = get_special("<|output_start|>")
204
+ output_end = get_special("<|output_end|>")
205
+ assistant_end = get_special("<|assistant_end|>") # if sampled, ends row
206
+ bos = self.tokenizer.get_bos_token_id() # if sampled, ends row
207
+
208
+ # 1) Run a batch 1 prefill of the prompt tokens
209
+ m = self.model.config
210
+ kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
211
+ kv_cache_prefill = KVCache(
212
+ batch_size=1,
213
+ seq_len=len(tokens),
214
+ **kv_model_kwargs,
215
+ )
216
+ ids = torch.tensor([tokens], dtype=torch.long, device=device)
217
+ logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
218
+ logits = logits[:, -1, :]
219
+ next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
220
+ sampled_tokens = next_ids[:, 0].tolist()
221
+
222
+ # 2) Replicate the KV cache for each sample/row
223
+ kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
224
+ kv_cache_decode = KVCache(
225
+ batch_size=num_samples,
226
+ seq_len=kv_length_hint,
227
+ **kv_model_kwargs,
228
+ )
229
+ kv_cache_decode.prefill(kv_cache_prefill)
230
+ del kv_cache_prefill # no need to keep this memory around
231
+
232
+ # 3) Initialize states for each sample
233
+ row_states = [RowState(tokens.copy()) for _ in range(num_samples)]
234
+
235
+ # 4) Main generation loop
236
+ num_generated = 0
237
+ first_iteration = True
238
+ while True:
239
+ # Stop condition: we've reached max tokens
240
+ if max_tokens is not None and num_generated >= max_tokens:
241
+ break
242
+ # Stop condition: all rows are completed
243
+ if all(state.completed for state in row_states):
244
+ break
245
+
246
+ # Get sampled tokens - either from prefill or from forward pass
247
+ if first_iteration:
248
+ # Use the tokens we already sampled from prefill
249
+ sampled_tokens = [sampled_tokens[0]] * num_samples # Broadcast first token to all rows
250
+ # TODO: we should sample a token for each row instead of broadcasting
251
+ first_iteration = False
252
+ else:
253
+ # Forward the model and get the next token for each row
254
+ logits = self.model.forward(ids, kv_cache=kv_cache_decode) # (B, T, vocab_size)
255
+ logits = logits[:, -1, :] # (B, vocab_size) at last time step
256
+ next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
257
+ sampled_tokens = next_ids[:, 0].tolist()
258
+
259
+ # Process each row: choose the next token, update state, optional tool use
260
+ token_column = [] # contains the next token id along each row
261
+ token_masks = [] # contains the mask (was it sampled (1) or forced (0)?) along each row
262
+ for i, state in enumerate(row_states):
263
+ # Select the next token in this row
264
+ is_forced = len(state.forced_tokens) > 0 # are there tokens waiting to be forced in deque?
265
+ token_masks.append(0 if is_forced else 1) # mask is 0 if forced, 1 if sampled
266
+ next_token = state.forced_tokens.popleft() if is_forced else sampled_tokens[i]
267
+ token_column.append(next_token)
268
+ # Update the state of this row to include the next token
269
+ state.current_tokens.append(next_token)
270
+ # On <|assistant_end|> or <|bos|>, mark the row as completed
271
+ if next_token == assistant_end or next_token == bos:
272
+ state.completed = True
273
+ # Handle tool logic
274
+ if next_token == python_start:
275
+ state.in_python_block = True
276
+ state.python_expr_tokens = []
277
+ elif next_token == python_end and state.in_python_block:
278
+ state.in_python_block = False
279
+ if state.python_expr_tokens:
280
+ expr = self.tokenizer.decode(state.python_expr_tokens)
281
+ result = use_calculator(expr)
282
+ if result is not None:
283
+ result_tokens = self.tokenizer.encode(str(result))
284
+ state.forced_tokens.append(output_start)
285
+ state.forced_tokens.extend(result_tokens)
286
+ state.forced_tokens.append(output_end)
287
+ state.python_expr_tokens = []
288
+ elif state.in_python_block:
289
+ state.python_expr_tokens.append(next_token)
290
+
291
+ # Yield the token column
292
+ yield token_column, token_masks
293
+ num_generated += 1
294
+ # Prepare ids for next iteration
295
+ ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)
296
+
297
+ def generate_batch(self, tokens, num_samples=1, **kwargs):
298
+ """
299
+ Non-streaming batch generation that just returns the final token sequences.
300
+ Returns a list of token sequences (list of lists of ints).
301
+ Terminal tokens (assistant_end, bos) are not included in the results.
302
+ """
303
+ assistant_end = self.tokenizer.encode_special("<|assistant_end|>")
304
+ bos = self.tokenizer.get_bos_token_id()
305
+ results = [tokens.copy() for _ in range(num_samples)]
306
+ masks = [[0] * len(tokens) for _ in range(num_samples)]
307
+ completed = [False] * num_samples
308
+ for token_column, token_masks in self.generate(tokens, num_samples, **kwargs):
309
+ for i, (token, mask) in enumerate(zip(token_column, token_masks)):
310
+ if not completed[i]:
311
+ if token == assistant_end or token == bos:
312
+ completed[i] = True
313
+ else:
314
+ results[i].append(token)
315
+ masks[i].append(mask)
316
+ # Stop if all rows are completed
317
+ if all(completed):
318
+ break
319
+ return results, masks
320
+
321
+
322
+ if __name__ == "__main__":
323
+ """
324
+ Quick inline test to make sure that the naive/slow model.generate function
325
+ is equivalent to the faster Engine.generate function here.
326
+ """
327
+ import time
328
+ # init compute
329
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
330
+ # load the model and tokenizer
331
+ model, tokenizer, meta = load_model("base", device, phase="eval")
332
+ bos_token_id = tokenizer.get_bos_token_id()
333
+ # common hyperparameters
334
+ kwargs = dict(max_tokens=64, temperature=0.0)
335
+ # set the starting prompt
336
+ prompt_tokens = tokenizer.encode("The chemical formula of water is", prepend=bos_token_id)
337
+ # generate the reference sequence using the model.generate() function
338
+ generated_tokens = []
339
+ torch.cuda.synchronize()
340
+ t0 = time.time()
341
+ stream = model.generate(prompt_tokens, **kwargs)
342
+ for token in stream:
343
+ generated_tokens.append(token)
344
+ chunk = tokenizer.decode([token])
345
+ print(chunk, end="", flush=True)
346
+ print()
347
+ torch.cuda.synchronize()
348
+ t1 = time.time()
349
+ print(f"Reference time: {t1 - t0:.2f}s")
350
+ reference_ids = generated_tokens
351
+ # generate tokens with Engine
352
+ generated_tokens = []
353
+ engine = Engine(model, tokenizer)
354
+ stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
355
+ torch.cuda.synchronize()
356
+ t0 = time.time()
357
+ for token_column, token_masks in stream:
358
+ token = token_column[0] # only print out the first row
359
+ generated_tokens.append(token)
360
+ chunk = tokenizer.decode([token])
361
+ print(chunk, end="", flush=True)
362
+ print()
363
+ torch.cuda.synchronize()
364
+ t1 = time.time()
365
+ print(f"Engine time: {t1 - t0:.2f}s")
366
+ # compare the two sequences
367
+ for i in range(len(reference_ids)):
368
+ if reference_ids[i] != generated_tokens[i]:
369
+ print(f"Mismatch at {i}: {reference_ids[i]} != {generated_tokens[i]}")
370
+ break
371
+ print(f"Match: {reference_ids == generated_tokens}")
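Beyond the streaming inline test above, the non-streaming entry point is generate_batch; a hedged sketch (model and tokenizer e.g. from load_model as above):

    from nanochat.engine import Engine

    engine = Engine(model, tokenizer)
    prompt = tokenizer.encode("The chemical formula of water is", prepend=tokenizer.get_bos_token_id())
    sequences, masks = engine.generate_batch(prompt, num_samples=2, max_tokens=32, temperature=0.8, top_k=50)
    for seq in sequences:
        print(tokenizer.decode(seq))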
nanochat/execution.py ADDED
@@ -0,0 +1,349 @@
1
+ """
2
+ Sandboxed execution utilities for running Python code that comes out of an LLM.
3
+ Adapted from OpenAI HumanEval code:
4
+ https://github.com/openai/human-eval/blob/master/human_eval/execution.py
5
+
6
+ What is covered:
7
+ - Each execution runs in its own process (can be killed if it hangs or crashes)
8
+ - Execution is limited by a timeout to stop infinite loops
9
+ - Memory limits are enforced by default (256MB)
10
+ - stdout and stderr are captured and returned
11
+ - Code runs in a temporary directory that is deleted afterwards
12
+ - Dangerous functions are disabled (examples: os.system, os.kill, shutil.rmtree, subprocess.Popen)
13
+
14
+ What is not covered:
15
+ - Not a true security sandbox
16
+ - Network access is not blocked (e.g. sockets could be opened)
17
+ - Python's dynamic features (e.g. ctypes) could bypass restrictions
18
+ - No kernel-level isolation (no seccomp, no containers, no virtualization)
19
+
20
+ Overall this sandbox is good for evaluation of generated code and protects against
21
+ accidental destructive behavior, but it is not safe against malicious adversarial code.
22
+ """
23
+
24
+ import contextlib
25
+ import faulthandler
26
+ import io
27
+ import multiprocessing
28
+ import os
29
+ import platform
30
+ import signal
31
+ import tempfile
32
+ from dataclasses import dataclass
33
+ from typing import Optional
34
+
35
+ # -----------------------------------------------------------------------------
36
+
37
+ @dataclass
38
+ class ExecutionResult:
39
+ """Result of executing Python code in a sandbox."""
40
+ success: bool
41
+ stdout: str
42
+ stderr: str
43
+ error: Optional[str] = None
44
+ timeout: bool = False
45
+ memory_exceeded: bool = False
46
+
47
+ def __repr__(self):
48
+ parts = []
49
+ parts.append(f"ExecutionResult(success={self.success}")
50
+ if self.timeout:
51
+ parts.append(", timeout=True")
52
+ if self.memory_exceeded:
53
+ parts.append(", memory_exceeded=True")
54
+ if self.error:
55
+ parts.append(f", error={self.error!r}")
56
+ if self.stdout:
57
+ parts.append(f", stdout={self.stdout!r}")
58
+ if self.stderr:
59
+ parts.append(f", stderr={self.stderr!r}")
60
+ parts.append(")")
61
+ return "".join(parts)
62
+
63
+
64
+ @contextlib.contextmanager
65
+ def time_limit(seconds: float):
66
+ def signal_handler(signum, frame):
67
+ raise TimeoutException("Timed out!")
68
+
69
+ signal.setitimer(signal.ITIMER_REAL, seconds)
70
+ signal.signal(signal.SIGALRM, signal_handler)
71
+ try:
72
+ yield
73
+ finally:
74
+ signal.setitimer(signal.ITIMER_REAL, 0)
75
+
76
+
77
+ @contextlib.contextmanager
78
+ def capture_io():
79
+ """Capture stdout and stderr, and disable stdin."""
80
+ stdout_capture = io.StringIO()
81
+ stderr_capture = io.StringIO()
82
+ stdin_block = WriteOnlyStringIO()
83
+ with contextlib.redirect_stdout(stdout_capture):
84
+ with contextlib.redirect_stderr(stderr_capture):
85
+ with redirect_stdin(stdin_block):
86
+ yield stdout_capture, stderr_capture
87
+
88
+
89
+ @contextlib.contextmanager
90
+ def create_tempdir():
91
+ with tempfile.TemporaryDirectory() as dirname:
92
+ with chdir(dirname):
93
+ yield dirname
94
+
95
+
96
+ class TimeoutException(Exception):
97
+ pass
98
+
99
+
100
+ class WriteOnlyStringIO(io.StringIO):
101
+ """StringIO that throws an exception when it's read from"""
102
+
103
+ def read(self, *args, **kwargs):
104
+ raise IOError
105
+
106
+ def readline(self, *args, **kwargs):
107
+ raise IOError
108
+
109
+ def readlines(self, *args, **kwargs):
110
+ raise IOError
111
+
112
+ def readable(self, *args, **kwargs):
113
+ """Returns True if the IO object can be read."""
114
+ return False
115
+
116
+
117
+ class redirect_stdin(contextlib._RedirectStream): # type: ignore
118
+ _stream = "stdin"
119
+
120
+
121
+ @contextlib.contextmanager
122
+ def chdir(root):
123
+ if root == ".":
124
+ yield
125
+ return
126
+ cwd = os.getcwd()
127
+ os.chdir(root)
128
+ try:
129
+ yield
130
+ finally:
131
+ os.chdir(cwd)
132
+
133
+
134
+ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
135
+ """
136
+ This disables various destructive functions and prevents the generated code
137
+ from interfering with the test (e.g. fork bomb, killing other processes,
138
+ removing filesystem files, etc.)
139
+
140
+ WARNING
141
+ This function is NOT a security sandbox. Untrusted code, including model-
142
+ generated code, should not be blindly executed outside of one. See the
143
+ Codex paper for more information about OpenAI's code sandbox, and proceed
144
+ with caution.
145
+ """
146
+
147
+ if maximum_memory_bytes is not None and platform.uname().system != "Darwin":
148
+ # These resource limit calls fail on macOS (Darwin), so we skip them there.
149
+ import resource
150
+ resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
151
+ resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
152
+ resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
153
+
154
+ faulthandler.disable()
155
+
156
+ import builtins
157
+
158
+ builtins.exit = None
159
+ builtins.quit = None
160
+
161
+ import os
162
+
163
+ os.environ["OMP_NUM_THREADS"] = "1"
164
+
165
+ os.kill = None
166
+ os.system = None
167
+ os.putenv = None
168
+ os.remove = None
169
+ os.removedirs = None
170
+ os.rmdir = None
171
+ os.fchdir = None
172
+ os.setuid = None
173
+ os.fork = None
174
+ os.forkpty = None
175
+ os.killpg = None
176
+ os.rename = None
177
+ os.renames = None
178
+ os.truncate = None
179
+ os.replace = None
180
+ os.unlink = None
181
+ os.fchmod = None
182
+ os.fchown = None
183
+ os.chmod = None
184
+ os.chown = None
185
+ os.chroot = None
186
+ os.fchdir = None
187
+ os.lchflags = None
188
+ os.lchmod = None
189
+ os.lchown = None
190
+ os.getcwd = None
191
+ os.chdir = None
192
+
193
+ import shutil
194
+
195
+ shutil.rmtree = None
196
+ shutil.move = None
197
+ shutil.chown = None
198
+
199
+ import subprocess
200
+
201
+ subprocess.Popen = None # type: ignore
202
+
203
+ __builtins__["help"] = None
204
+
205
+ import sys
206
+
207
+ sys.modules["ipdb"] = None
208
+ sys.modules["joblib"] = None
209
+ sys.modules["resource"] = None
210
+ sys.modules["psutil"] = None
211
+ sys.modules["tkinter"] = None
212
+
213
+
214
+ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[int], result_dict):
215
+ """Execute code in a subprocess with safety guards. Results are written to result_dict."""
216
+ with create_tempdir():
217
+
218
+ # These system calls are needed when cleaning up tempdir.
219
+ import os
220
+ import shutil
221
+
222
+ rmtree = shutil.rmtree
223
+ rmdir = os.rmdir
224
+ chdir = os.chdir
225
+ unlink = os.unlink
226
+
227
+ # Disable functionalities that can make destructive changes to the test.
228
+ reliability_guard(maximum_memory_bytes=maximum_memory_bytes)
229
+
230
+ # Default to failure
231
+ result_dict.update({
232
+ "success": False,
233
+ "stdout": "",
234
+ "stderr": "",
235
+ "timeout": False,
236
+ "memory_exceeded": False,
237
+ "error": None,
238
+ })
239
+
240
+ try:
241
+ exec_globals = {}
242
+ with capture_io() as (stdout_capture, stderr_capture):
243
+ with time_limit(timeout):
244
+ # WARNING
245
+ # This program exists to execute untrusted model-generated code. Although
246
+ # it is highly unlikely that model-generated code will do something overtly
247
+ # malicious in response to this test suite, model-generated code may act
248
+ # destructively due to a lack of model capability or alignment.
249
+ # Users are strongly encouraged to sandbox this evaluation suite so that it
250
+ # does not perform destructive actions on their host or network. For more
251
+ # information on how OpenAI sandboxes its code, see the accompanying paper.
252
+ # Note: unlike the original HumanEval code, the exec call below is already enabled;
253
+ # read this disclaimer, take appropriate precautions, and proceed at your own risk:
254
+ exec(code, exec_globals)
255
+
256
+ result_dict.update({
257
+ "success": True,
258
+ "stdout": stdout_capture.getvalue(),
259
+ "stderr": stderr_capture.getvalue(),
260
+ })
261
+
262
+ except TimeoutException:
263
+ result_dict.update({
264
+ "timeout": True,
265
+ "error": "Execution timed out",
266
+ })
267
+
268
+ except MemoryError as e:
269
+ result_dict.update({
270
+ "memory_exceeded": True,
271
+ "error": f"Memory limit exceeded: {e}",
272
+ })
273
+
274
+ except BaseException as e:
275
+ result_dict.update({
276
+ "error": f"{type(e).__name__}: {e}",
277
+ })
278
+
279
+ # Needed for cleaning up.
280
+ shutil.rmtree = rmtree
281
+ os.rmdir = rmdir
282
+ os.chdir = chdir
283
+ os.unlink = unlink
284
+
285
+
286
+ def execute_code(
287
+ code: str,
288
+ timeout: float = 5.0, # 5 seconds default
289
+ maximum_memory_bytes: Optional[int] = 256 * 1024 * 1024, # 256MB default
290
+ ) -> ExecutionResult:
291
+ """
292
+ Execute Python code in a sandboxed environment.
293
+
294
+ Args:
295
+ code: Python code to execute as a string
296
+ timeout: Maximum execution time in seconds (default: 5.0)
297
+ maximum_memory_bytes: Memory limit in bytes (default: 256MB, None to disable)
298
+
299
+ Returns:
300
+ ExecutionResult with success status, stdout/stderr, and error information
301
+
302
+ Example:
303
+ >>> result = execute_code("print('hello world')")
304
+ >>> result.success
305
+ True
306
+ >>> result.stdout
307
+ 'hello world\\n'
308
+ """
309
+
310
+ manager = multiprocessing.Manager()
311
+ result_dict = manager.dict()
312
+
313
+ p = multiprocessing.Process(
314
+ target=_unsafe_execute,
315
+ args=(code, timeout, maximum_memory_bytes, result_dict)
316
+ )
317
+ p.start()
318
+ p.join(timeout=timeout + 1)
319
+
320
+ if p.is_alive():
321
+ p.kill()
322
+ return ExecutionResult(
323
+ success=False,
324
+ stdout="",
325
+ stderr="",
326
+ error="Execution timed out (process killed)",
327
+ timeout=True,
328
+ memory_exceeded=False,
329
+ )
330
+
331
+ if not result_dict:
332
+ return ExecutionResult(
333
+ success=False,
334
+ stdout="",
335
+ stderr="",
336
+ error="Execution failed (no result returned)",
337
+ timeout=True,
338
+ memory_exceeded=False,
339
+ )
340
+
341
+ return ExecutionResult(
342
+ success=result_dict["success"],
343
+ stdout=result_dict["stdout"],
344
+ stderr=result_dict["stderr"],
345
+ error=result_dict["error"],
346
+ timeout=result_dict["timeout"],
347
+ memory_exceeded=result_dict["memory_exceeded"],
348
+ )
349
+
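A minimal usage sketch of the sandbox defined above (hedged: the `nanochat.execution` import path is an assumption based on the surrounding files; adjust it to wherever this module lives):

```python
# Hedged usage sketch; the import path is assumed.
from nanochat.execution import execute_code

# Normal run: stdout is captured and returned on the ExecutionResult.
ok = execute_code("print(sum(range(10)))")
assert ok.success and ok.stdout == "45\n"

# An infinite loop is interrupted by the timeout and flagged as such.
stuck = execute_code("while True: pass", timeout=1.0)
assert not stuck.success and stuck.timeout
```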
nanochat/gpt.py ADDED
@@ -0,0 +1,307 @@
1
+ """
2
+ GPT model (rewrite, a lot simpler)
3
+ Notable features:
4
+ - rotary embeddings (and no positional embeddings)
5
+ - QK norm
6
+ - untied weights for token embedding and lm_head
7
+ - relu^2 activation in MLP
8
+ - norm after token embedding
9
+ - no learnable params in rmsnorm
10
+ - no bias in linear layers
11
+ - Multi-Query Attention (MQA) support for more efficient inference
12
+ """
13
+
14
+ import math
15
+ from functools import partial
16
+ from dataclasses import dataclass
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+ from nanochat.common import get_dist_info, print0
23
+ from nanochat.muon import Muon, DistMuon
24
+ from nanochat.adamw import DistAdamW
25
+
26
+ @dataclass
27
+ class GPTConfig:
28
+ sequence_len: int = 1024
29
+ vocab_size: int = 50304
30
+ n_layer: int = 12
31
+ n_head: int = 6 # number of query heads
32
+ n_kv_head: int = 6 # number of key/value heads (MQA)
33
+ n_embd: int = 768
34
+
35
+
36
+ def norm(x):
37
+ # Purely functional rmsnorm with no learnable params
38
+ return F.rms_norm(x, (x.size(-1),))
39
+
40
+
41
+ def apply_rotary_emb(x, cos, sin):
42
+ assert x.ndim == 4 # multihead attention
43
+ d = x.shape[3] // 2
44
+ x1, x2 = x[..., :d], x[..., d:] # split the last dim into two halves
45
+ y1 = x1 * cos + x2 * sin # rotate pairs of dims
46
+ y2 = x1 * (-sin) + x2 * cos
47
+ out = torch.cat([y1, y2], 3) # re-assemble
48
+ out = out.to(x.dtype) # ensure input/output dtypes match
49
+ return out
50
+
51
+ class CausalSelfAttention(nn.Module):
52
+ def __init__(self, config, layer_idx):
53
+ super().__init__()
54
+ self.layer_idx = layer_idx
55
+ self.n_head = config.n_head
56
+ self.n_kv_head = config.n_kv_head
57
+ self.n_embd = config.n_embd
58
+ self.head_dim = self.n_embd // self.n_head
59
+ assert self.n_embd % self.n_head == 0
60
+ assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
61
+ self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
62
+ self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
63
+ self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
64
+ self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
65
+
66
+ def forward(self, x, cos_sin, kv_cache):
67
+ B, T, C = x.size()
68
+
69
+ # Project the input to get queries, keys, and values
70
+ q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
71
+ k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
72
+ v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
73
+
74
+ # Apply Rotary Embeddings to queries and keys to get relative positional encoding
75
+ cos, sin = cos_sin
76
+ q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
77
+ q, k = norm(q), norm(k) # QK norm
78
+ q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
79
+
80
+ # Apply KV cache: insert current k,v into cache, get the full view so far
81
+ if kv_cache is not None:
82
+ k, v = kv_cache.insert_kv(self.layer_idx, k, v)
83
+ Tq = q.size(2) # number of queries in this forward pass
84
+ Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)
85
+
86
+ # Attention: queries attend to keys/values autoregressively. A few cases to handle:
87
+ enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
88
+ if kv_cache is None or Tq == Tk:
89
+ # During training (no KV cache), attend as usual with causal attention
90
+ # And even if there is KV cache, we can still use this simple version when Tq == Tk
91
+ y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
92
+ elif Tq == 1:
93
+ # During inference but with a single query in this forward pass:
94
+ # The query has to attend to all the keys/values in the cache
95
+ y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
96
+ else:
97
+ # During inference AND we have a chunk of queries in this forward pass:
98
+ # First, each query attends to all the cached keys/values (i.e. full prefix)
99
+ attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
100
+ prefix_len = Tk - Tq
101
+ if prefix_len > 0: # can't be negative but could be zero
102
+ attn_mask[:, :prefix_len] = True
103
+ # Then, causal attention within this chunk
104
+ attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
105
+ y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
106
+
107
+ # Re-assemble the heads side by side and project back to residual stream
108
+ y = y.transpose(1, 2).contiguous().view(B, T, -1)
109
+ y = self.c_proj(y)
110
+ return y
111
+
112
+
113
+ class MLP(nn.Module):
114
+ def __init__(self, config):
115
+ super().__init__()
116
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
117
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
118
+
119
+ def forward(self, x):
120
+ x = self.c_fc(x)
121
+ x = F.relu(x).square()
122
+ x = self.c_proj(x)
123
+ return x
124
+
125
+
126
+ class Block(nn.Module):
127
+ def __init__(self, config, layer_idx):
128
+ super().__init__()
129
+ self.attn = CausalSelfAttention(config, layer_idx)
130
+ self.mlp = MLP(config)
131
+
132
+ def forward(self, x, cos_sin, kv_cache):
133
+ x = x + self.attn(norm(x), cos_sin, kv_cache)
134
+ x = x + self.mlp(norm(x))
135
+ return x
136
+
137
+
138
+ class GPT(nn.Module):
139
+ def __init__(self, config):
140
+ super().__init__()
141
+ self.config = config
142
+ self.transformer = nn.ModuleDict({
143
+ "wte": nn.Embedding(config.vocab_size, config.n_embd),
144
+ "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
145
+ })
146
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
147
+ # To support meta device initialization, we init the rotary embeddings here, but it's fake
148
+ # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
149
+ # so we just over-compute them, and assert-fail if we ever exceed that length.
150
+ # In the future we can dynamically grow the cache, for now it's fine.
151
+ self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
152
+ head_dim = config.n_embd // config.n_head
153
+ cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
154
+ self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
155
+ self.register_buffer("sin", sin, persistent=False)
156
+
157
+ def init_weights(self):
158
+ self.apply(self._init_weights)
159
+ # zero out classifier weights
160
+ torch.nn.init.zeros_(self.lm_head.weight)
161
+ # zero out c_proj weights in all blocks
162
+ for block in self.transformer.h:
163
+ torch.nn.init.zeros_(block.mlp.c_proj.weight)
164
+ torch.nn.init.zeros_(block.attn.c_proj.weight)
165
+ # init the rotary embeddings
166
+ head_dim = self.config.n_embd // self.config.n_head
167
+ cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
168
+ self.cos, self.sin = cos, sin
169
+ # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
170
+ if self.transformer.wte.weight.device.type == "cuda":
171
+ self.transformer.wte.to(dtype=torch.bfloat16)
172
+
173
+ def _init_weights(self, module):
174
+ if isinstance(module, nn.Linear):
175
+ # https://arxiv.org/pdf/2310.17813
176
+ fan_out = module.weight.size(0)
177
+ fan_in = module.weight.size(1)
178
+ std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
179
+ torch.nn.init.normal_(module.weight, mean=0.0, std=std)
180
+ if module.bias is not None:
181
+ torch.nn.init.zeros_(module.bias)
182
+ elif isinstance(module, nn.Embedding):
183
+ torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
184
+
185
+ # TODO: bump base theta more, e.g. 100K is more common more recently
186
+ def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
187
+ # autodetect the device from model embeddings
188
+ if device is None:
189
+ device = self.transformer.wte.weight.device
190
+ # stride the channels
191
+ channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
192
+ inv_freq = 1.0 / (base ** (channel_range / head_dim))
193
+ # stride the time steps
194
+ t = torch.arange(seq_len, dtype=torch.float32, device=device)
195
+ # calculate the rotation frequencies at each (time, channel) pair
196
+ freqs = torch.outer(t, inv_freq)
197
+ cos, sin = freqs.cos(), freqs.sin()
198
+ cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
199
+ cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
200
+ return cos, sin
201
+
202
+ def get_device(self):
203
+ return self.transformer.wte.weight.device
204
+
205
+ def estimate_flops(self):
206
+ """ Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311 """
207
+ nparams = sum(p.numel() for p in self.parameters())
208
+ nparams_embedding = self.transformer.wte.weight.numel()
209
+ l, h, q, t = self.config.n_layer, self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
210
+ num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
211
+ return num_flops_per_token
212
+
213
+ def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
214
+ model_dim = self.config.n_embd
215
+ ddp, rank, local_rank, world_size = get_dist_info()
216
+ # Separate out all parameters into 3 groups (matrix, embedding, lm_head)
217
+ matrix_params = list(self.transformer.h.parameters())
218
+ embedding_params = list(self.transformer.wte.parameters())
219
+ lm_head_params = list(self.lm_head.parameters())
220
+ assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params)
221
+ # Create the AdamW optimizer for the embedding and lm_head
222
+ # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
223
+ dmodel_lr_scale = (model_dim / 768) ** -0.5
224
+ if rank == 0:
225
+ print(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
226
+ adam_groups = [
227
+ dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
228
+ dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
229
+ ]
230
+ adamw_kwargs = dict(betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay)
231
+ AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
232
+ adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
233
+ # Create the Muon optimizer for the linear layers
234
+ muon_kwargs = dict(lr=matrix_lr, momentum=0.95)
235
+ MuonFactory = DistMuon if ddp else Muon
236
+ muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
237
+ # Combine the two optimizers into one list
238
+ optimizers = [adamw_optimizer, muon_optimizer]
239
+ for opt in optimizers:
240
+ for group in opt.param_groups:
241
+ group["initial_lr"] = group["lr"]
242
+ return optimizers
243
+
244
+ def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
245
+ B, T = idx.size()
246
+
247
+ # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
248
+ assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
249
+ assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
250
+ assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
251
+ # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
252
+ T0 = 0 if kv_cache is None else kv_cache.get_pos()
253
+ cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
254
+
255
+ # Forward the trunk of the Transformer
256
+ x = self.transformer.wte(idx)
257
+ x = norm(x)
258
+ for block in self.transformer.h:
259
+ x = block(x, cos_sin, kv_cache)
260
+ x = norm(x)
261
+
262
+ # Forward the lm_head (compute logits)
263
+ softcap = 15
264
+ if targets is not None:
265
+ # training mode: compute and return the loss
266
+ # TODO: experiment with Liger Kernels / chunked cross-entropy etc.
267
+ logits = self.lm_head(x)
268
+ logits = softcap * torch.tanh(logits / softcap) # logits softcap
269
+ logits = logits.float() # use tf32/fp32 for logits
270
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
271
+ return loss
272
+ else:
273
+ # inference mode: compute and return the logits
274
+ logits = self.lm_head(x)
275
+ logits = softcap * torch.tanh(logits / softcap) # logits softcap
276
+ return logits
277
+
278
+ @torch.inference_mode()
279
+ def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
280
+ """
281
+ Naive autoregressive streaming inference.
282
+ To make it super simple, let's assume:
283
+ - batch size is 1
284
+ - ids and the yielded tokens are simple Python lists and ints
285
+ """
286
+ assert isinstance(tokens, list)
287
+ device = self.get_device()
288
+ rng = None
289
+ if temperature > 0:
290
+ rng = torch.Generator(device=device)
291
+ rng.manual_seed(seed)
292
+ ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim
293
+ for _ in range(max_tokens):
294
+ logits = self.forward(ids) # (B, T, vocab_size)
295
+ logits = logits[:, -1, :] # (B, vocab_size)
296
+ if top_k is not None:
297
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
298
+ logits[logits < v[:, [-1]]] = -float('Inf')
299
+ if temperature > 0:
300
+ logits = logits / temperature
301
+ probs = F.softmax(logits, dim=-1)
302
+ next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
303
+ else:
304
+ next_ids = torch.argmax(logits, dim=-1, keepdim=True)
305
+ ids = torch.cat((ids, next_ids), dim=1)
306
+ token = next_ids.item()
307
+ yield token
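A minimal sketch of driving the model above end to end on CPU (hedged: the tiny config values below are illustrative only, not the repo defaults, and a recent PyTorch is assumed for rms_norm / SDPA support):

```python
# Hedged sketch: build a tiny randomly-initialized model and sample a few tokens on CPU.
from nanochat.gpt import GPT, GPTConfig

config = GPTConfig(sequence_len=64, vocab_size=256, n_layer=2, n_head=2, n_kv_head=1, n_embd=64)
model = GPT(config)
model.init_weights()
prompt = [1, 2, 3]  # arbitrary token ids, must be < vocab_size
sampled = list(model.generate(prompt, max_tokens=8, temperature=1.0, top_k=50))
assert len(sampled) == 8 and all(0 <= t < config.vocab_size for t in sampled)
```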
nanochat/logo.svg ADDED
nanochat/loss_eval.py ADDED
@@ -0,0 +1,65 @@
1
+ """
2
+ A number of functions that help with evaluating a base model.
3
+ """
4
+ import math
5
+ import torch
6
+ import torch.distributed as dist
7
+
8
+ @torch.no_grad()
9
+ def evaluate_bpb(model, batches, steps, token_bytes):
10
+ """
11
+ Instead of the naive 'mean loss', this function returns the bits per byte (bpb),
12
+ which is a tokenization vocab size-independent metric, meaning you are still comparing
13
+ apples:apples if you change the vocab size. The way this works is that instead of just
14
+ calculating the average loss as usual, you calculate the sum loss, and independently
15
+ also the sum bytes (of all the target tokens), and divide. This normalizes the loss by
16
+ the number of bytes that the target tokens represent.
17
+
18
+ The added complexity is so that:
19
+ 1) All "normal" tokens are normalized by the length of the token in bytes
20
+ 2) No special tokens (e.g. <|bos|>) are included in the metric - they are masked out.
21
+ 3) No actively masked tokens (using ignore_index of e.g. -1) are included in the metric.
22
+
23
+ In addition to evaluate_loss, we need the token_bytes tensor:
24
+ It is a 1D tensor of shape (vocab_size,), indicating the number of bytes for
25
+ each token id, or 0 if the token is to not be counted (e.g. special tokens).
26
+ """
27
+ # record the losses
28
+ total_nats = torch.tensor(0.0, dtype=torch.float32, device=model.get_device())
29
+ total_bytes = torch.tensor(0, dtype=torch.int64, device=model.get_device())
30
+ batch_iter = iter(batches)
31
+ for _ in range(steps):
32
+ x, y = next(batch_iter)
33
+ loss2d = model(x, y, loss_reduction='none') # (B, T)
34
+ loss2d = loss2d.view(-1) # flatten
35
+ y = y.view(-1) # flatten
36
+ if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32
37
+ # slightly more complex code path if some target tokens are ignore_index (e.g. -1)
38
+ # any target token < 0 is to be ignored: do NOT index token_bytes with negatives
39
+ valid = y >= 0
40
+ y_safe = torch.where(valid, y, torch.zeros_like(y))
41
+ # map valid targets to their byte length; ignored targets contribute 0 bytes
42
+ num_bytes2d = torch.where(
43
+ valid,
44
+ token_bytes[y_safe],
45
+ torch.zeros_like(y, dtype=token_bytes.dtype)
46
+ )
47
+ total_nats += (loss2d * (num_bytes2d > 0)).sum()
48
+ total_bytes += num_bytes2d.sum()
49
+ else:
50
+ # fast path: no ignored targets, safe to index directly
51
+ num_bytes2d = token_bytes[y]
52
+ total_nats += (loss2d * (num_bytes2d > 0)).sum()
53
+ total_bytes += num_bytes2d.sum()
54
+ # sum reduce across all ranks
55
+ world_size = dist.get_world_size() if dist.is_initialized() else 1
56
+ if world_size > 1:
57
+ dist.all_reduce(total_nats, op=dist.ReduceOp.SUM)
58
+ dist.all_reduce(total_bytes, op=dist.ReduceOp.SUM)
59
+ # move both to cpu, calculate bpb and return
60
+ total_nats = total_nats.item()
61
+ total_bytes = total_bytes.item()
62
+ if total_bytes == 0:
63
+ return float('inf')
64
+ bpb = total_nats / (math.log(2) * total_bytes)
65
+ return bpb
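To make the normalization concrete, a small hedged numeric sketch of the bits-per-byte formula used above (illustrative numbers only):

```python
# Hedged numeric sketch of bpb = total_nats / (ln(2) * total_bytes).
import math

total_nats = 2.772 * 1000   # summed cross-entropy over 1000 counted target tokens
total_bytes = 4 * 1000      # those targets span 4 bytes each on average
bpb = total_nats / (math.log(2) * total_bytes)
print(f"{bpb:.3f} bits per byte")  # ~1.000
```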
nanochat/muon.py ADDED
@@ -0,0 +1,187 @@
1
+ """
2
+ Muon optimizer from Keller et al.
3
+ Also a lot of borrowing of ideas from modded-nanogpt.
4
+ """
5
+ import torch
6
+ from torch import Tensor
7
+ import torch.distributed as dist
8
+
9
+ @torch.compile
10
+ def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
11
+ """
12
+ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
13
+ quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
14
+ of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
15
+ zero even beyond the point where the iteration no longer converges all the way to one everywhere
16
+ on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
17
+ where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
18
+ performance at all relative to UV^T, where USV^T = G is the SVD.
19
+ """
20
+ assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
21
+ a, b, c = (3.4445, -4.7750, 2.0315)
22
+ X = G.bfloat16()
23
+ if G.size(-2) > G.size(-1):
24
+ X = X.mT
25
+
26
+ # Ensure spectral norm is at most 1
27
+ X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
28
+ # Perform the NS iterations
29
+ for _ in range(steps):
30
+ A = X @ X.mT
31
+ B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
32
+ X = a * X + B @ X
33
+
34
+ if G.size(-2) > G.size(-1):
35
+ X = X.mT
36
+ return X
37
+
38
+ class Muon(torch.optim.Optimizer):
39
+ """
40
+ Muon - MomentUm Orthogonalized by Newton-schulz
41
+
42
+ https://kellerjordan.github.io/posts/muon/
43
+
44
+ Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
45
+ processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
46
+ matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
47
+ the advantage that it can be stably run in bfloat16 on the GPU.
48
+
49
+ Some warnings:
50
+ - This optimizer should not be used for the embedding layer, the final fully connected layer,
51
+ or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
52
+ - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
53
+
54
+ Arguments:
55
+ lr: The learning rate used by the internal SGD.
56
+ momentum: The momentum used by the internal SGD.
57
+ nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
58
+ ns_steps: The number of Newton-Schulz iteration steps to use.
59
+ """
60
+ def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
61
+ defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
62
+ params: list[Tensor] = [*params]
63
+ param_groups = []
64
+ for size in {p.numel() for p in params}:
65
+ group = dict(params=[p for p in params if p.numel() == size])
66
+ param_groups.append(group)
67
+ super().__init__(param_groups, defaults)
68
+
69
+ @torch.no_grad()
70
+ def step(self):
71
+ for group in self.param_groups:
72
+ params: list[Tensor] = group["params"]
73
+ for p in params:
74
+ g = p.grad
75
+ assert g is not None
76
+ state = self.state[p]
77
+ if "momentum_buffer" not in state:
78
+ state["momentum_buffer"] = torch.zeros_like(g)
79
+ buf: Tensor = state["momentum_buffer"]
80
+ buf.lerp_(g, 1 - group["momentum"])
81
+ g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
82
+ g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
83
+ p.add_(g, alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5)
84
+
85
+
86
+ class DistMuon(torch.optim.Optimizer):
87
+ """
88
+ Muon: SGD-momentum + (optional) Nesterov, then orthogonalize the 2D update via Newton–Schulz,
89
+ finally apply aspect-ratio scaled step. Performs its own distributed synchronization:
90
+ - reduce_scatter(AVG) for gradient averaging
91
+ - all_gather to replicate updated weights
92
+
93
+ Notes:
94
+ * Designed for 2D parameters (e.g., linear/conv kernels reshaped to 2D). Do not use for 0D/1D
95
+ params like embeddings or scalars.
96
+ * Momentum buffers are maintained only on the 'owner' rank for each parameter (rank chosen
97
+ by block-cyclic assignment below). If you checkpoint optimizer state on a single rank,
98
+ consolidate states beforehand.
99
+
100
+ Args:
101
+ params: iterable of Tensors
102
+ lr: learning rate
103
+ momentum: momentum coefficient in [0,1)
104
+ nesterov: if True, Nesterov-style update (g <- lerp(g, buf, momentum)); else use buf
105
+ ns_steps: number of Newton–Schulz iterations for the orthogonalization
106
+ """
107
+ def __init__(self, params, lr: float = 0.02, momentum: float = 0.95,
108
+ nesterov: bool = True, ns_steps: int = 5):
109
+ defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
110
+ params = list(params)
111
+ assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
112
+ rank = dist.get_rank()
113
+ # Group all parameters by their shape
114
+ shapes = sorted({p.shape for p in params}) # sort to ensure consistent / deterministic ordering
115
+ param_groups = []
116
+ for shape in shapes:
117
+ group_params = [p for p in params if p.shape == shape]
118
+ device, dtype = group_params[0].device, group_params[0].dtype
119
+ assert all(p.device == device for p in group_params)
120
+ assert all(p.dtype == dtype for p in group_params)
121
+ if rank == 0:
122
+ print(f"Muon: Grouping {len(group_params)} params of shape {shape}, device {device}, dtype {dtype}")
123
+ param_groups.append(dict(params=group_params, zero_buffer=torch.zeros_like(group_params[0])))
124
+ super().__init__(param_groups, defaults)
125
+
126
+ @torch.no_grad()
127
+ def step(self):
128
+ rank = dist.get_rank()
129
+ world_size = dist.get_world_size()
130
+
131
+ # Ensure all grads exist
132
+ assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"
133
+
134
+ # Kick off all the reduce scatter operations to average up the gradients across all ranks
135
+ all_reduce_futures = []
136
+ for group in self.param_groups:
137
+ params = group["params"]
138
+ zero_buffer = group["zero_buffer"]
139
+ # Go through params in groups of world_size.
140
+ for base_i in range(0, len(params), world_size):
141
+ # The compute owner of each param is rank i % world_size
142
+ owner_idx = base_i + rank
143
+ # each rank stacks up its chunk of world_size params into a list
144
+ rs_input = [p.grad for p in params[base_i:base_i + world_size]]
145
+ # pad rs_input with the zero buffer to complete the group
146
+ rs_input.extend([zero_buffer] * (world_size - len(rs_input)))
147
+ # the output buffer gets strided across the group based on the rank
148
+ rs_output = params[owner_idx].grad if owner_idx < len(params) else torch.empty_like(zero_buffer)
149
+ # reduce scatter the gradients within this group of world_size params
150
+ work = dist.reduce_scatter(rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True).get_future()
151
+ all_reduce_futures.append(work)
152
+
153
+ # Now each rank computes the update and gathers
154
+ future_idx = 0
155
+ all_gather_futures = []
156
+ for group in self.param_groups:
157
+ params = group["params"]
158
+ zero_buffer = group["zero_buffer"]
159
+ # Go through params in groups of world_size.
160
+ for base_i in range(0, len(params), world_size):
161
+ # The compute owner of each param is rank i % world_size
162
+ owner_idx = base_i + rank # calculate the index of the param that this rank owns
163
+ # Wait for the reduce scatter to complete
164
+ all_reduce_futures[future_idx].wait() # possibly later we could use wait_any polling instead
165
+ future_idx += 1
166
+ # Owner computes the Muon update, result is in its param
167
+ if owner_idx < len(params):
168
+ p = params[owner_idx]
169
+ g = p.grad # now averaged across ranks
170
+ state = self.state[p]
171
+ if "momentum_buffer" not in state:
172
+ state["momentum_buffer"] = torch.zeros_like(g)
173
+ buf: Tensor = state["momentum_buffer"]
174
+ buf.lerp_(g, 1.0 - group["momentum"])
175
+ g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
176
+ g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
177
+ scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
178
+ p.add_(g, alpha=-group["lr"] * scale)
179
+ # Replicate updated parameters to all ranks
180
+ ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer
181
+ ag_output = params[base_i:base_i + world_size]
182
+ ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))]) # pad
183
+ work = dist.all_gather(ag_output, ag_input, async_op=True).get_future()
184
+ all_gather_futures.append(work)
185
+
186
+ # Wait for all work to finish
187
+ torch.futures.collect_all(all_gather_futures).wait()
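A hedged sanity-check sketch for the Newton-Schulz orthogonalization above (note that the @torch.compile decorator assumes a working inductor backend in your environment):

```python
# Hedged sanity check: the singular values of the output should sit roughly in the
# (0.5, 1.5) band described in the docstring, rather than being exactly 1.
import torch
from nanochat.muon import zeropower_via_newtonschulz5

G = torch.randn(256, 128)
X = zeropower_via_newtonschulz5(G, steps=5)
S = torch.linalg.svdvals(X.float())
print(f"min={S.min().item():.2f} max={S.max().item():.2f}")
```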
nanochat/report.py ADDED
@@ -0,0 +1,408 @@
1
+ """
2
+ Utilities for generating training report cards. More messy code than usual, will fix.
3
+ """
4
+
5
+ import os
6
+ import re
7
+ import shutil
8
+ import subprocess
9
+ import socket
10
+ import datetime
11
+ import platform
12
+ import psutil
13
+ import torch
14
+
15
+ def run_command(cmd):
16
+ """Run a shell command and return output, or None if it fails."""
17
+ try:
18
+ result = subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=5)
19
+ if result.returncode == 0:
20
+ return result.stdout.strip()
21
+ return None
22
+ except:
23
+ return None
24
+
25
+ def get_git_info():
26
+ """Get current git commit, branch, and dirty status."""
27
+ info = {}
28
+ info['commit'] = run_command("git rev-parse --short HEAD") or "unknown"
29
+ info['branch'] = run_command("git rev-parse --abbrev-ref HEAD") or "unknown"
30
+
31
+ # Check if repo is dirty (has uncommitted changes)
32
+ status = run_command("git status --porcelain")
33
+ info['dirty'] = bool(status) if status is not None else False
34
+
35
+ # Get commit message
36
+ info['message'] = run_command("git log -1 --pretty=%B") or ""
37
+ info['message'] = info['message'].split('\n')[0][:80] # First line, truncated
38
+
39
+ return info
40
+
41
+ def get_gpu_info():
42
+ """Get GPU information."""
43
+ if not torch.cuda.is_available():
44
+ return {"available": False}
45
+
46
+ num_devices = torch.cuda.device_count()
47
+ info = {
48
+ "available": True,
49
+ "count": num_devices,
50
+ "names": [],
51
+ "memory_gb": []
52
+ }
53
+
54
+ for i in range(num_devices):
55
+ props = torch.cuda.get_device_properties(i)
56
+ info["names"].append(props.name)
57
+ info["memory_gb"].append(props.total_memory / (1024**3))
58
+
59
+ # Get CUDA version
60
+ info["cuda_version"] = torch.version.cuda or "unknown"
61
+
62
+ return info
63
+
64
+ def get_system_info():
65
+ """Get system information."""
66
+ info = {}
67
+
68
+ # Basic system info
69
+ info['hostname'] = socket.gethostname()
70
+ info['platform'] = platform.system()
71
+ info['python_version'] = platform.python_version()
72
+ info['torch_version'] = torch.__version__
73
+
74
+ # CPU and memory
75
+ info['cpu_count'] = psutil.cpu_count(logical=False)
76
+ info['cpu_count_logical'] = psutil.cpu_count(logical=True)
77
+ info['memory_gb'] = psutil.virtual_memory().total / (1024**3)
78
+
79
+ # User and environment
80
+ info['user'] = os.environ.get('USER', 'unknown')
81
+ info['nanochat_base_dir'] = os.environ.get('NANOCHAT_BASE_DIR', 'out')
82
+ info['working_dir'] = os.getcwd()
83
+
84
+ return info
85
+
86
+ def estimate_cost(gpu_info, runtime_hours=None):
87
+ """Estimate training cost based on GPU type and runtime."""
88
+
89
+ # Rough pricing, from Lambda Cloud
90
+ default_rate = 2.0
91
+ gpu_hourly_rates = {
92
+ "H100": 3.00,
93
+ "A100": 1.79,
94
+ "V100": 0.55,
95
+ }
96
+
97
+ if not gpu_info.get("available"):
98
+ return None
99
+
100
+ # Try to identify GPU type from name
101
+ hourly_rate = None
102
+ gpu_name = gpu_info["names"][0] if gpu_info["names"] else "unknown"
103
+ for gpu_type, rate in gpu_hourly_rates.items():
104
+ if gpu_type in gpu_name:
105
+ hourly_rate = rate * gpu_info["count"]
106
+ break
107
+
108
+ if hourly_rate is None:
109
+ hourly_rate = default_rate * gpu_info["count"] # Default estimate
110
+
111
+ return {
112
+ "hourly_rate": hourly_rate,
113
+ "gpu_type": gpu_name,
114
+ "estimated_total": hourly_rate * runtime_hours if runtime_hours else None
115
+ }
116
+
117
+ def generate_header():
118
+ """Generate the header for a training report."""
119
+ timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
120
+
121
+ git_info = get_git_info()
122
+ gpu_info = get_gpu_info()
123
+ sys_info = get_system_info()
124
+ cost_info = estimate_cost(gpu_info)
125
+
126
+ header = f"""# nanochat training report
127
+
128
+ Generated: {timestamp}
129
+
130
+ ## Environment
131
+
132
+ ### Git Information
133
+ - Branch: {git_info['branch']}
134
+ - Commit: {git_info['commit']} {"(dirty)" if git_info['dirty'] else "(clean)"}
135
+ - Message: {git_info['message']}
136
+
137
+ ### Hardware
138
+ - Platform: {sys_info['platform']}
139
+ - CPUs: {sys_info['cpu_count']} cores ({sys_info['cpu_count_logical']} logical)
140
+ - Memory: {sys_info['memory_gb']:.1f} GB
141
+ """
142
+
143
+ if gpu_info.get("available"):
144
+ gpu_names = ", ".join(set(gpu_info["names"]))
145
+ total_vram = sum(gpu_info["memory_gb"])
146
+ header += f"""- GPUs: {gpu_info['count']}x {gpu_names}
147
+ - GPU Memory: {total_vram:.1f} GB total
148
+ - CUDA Version: {gpu_info['cuda_version']}
149
+ """
150
+ else:
151
+ header += "- GPUs: None available\n"
152
+
153
+ if cost_info and cost_info["hourly_rate"] > 0:
154
+ header += f"""- Hourly Rate: ${cost_info['hourly_rate']:.2f}/hour\n"""
155
+
156
+ header += f"""
157
+ ### Software
158
+ - Python: {sys_info['python_version']}
159
+ - PyTorch: {sys_info['torch_version']}
160
+
161
+ """
162
+
163
+ # bloat metrics: package all of the source code and assess its weight
164
+ packaged = run_command('files-to-prompt . -e py -e md -e rs -e html -e toml -e sh --ignore "*target*" --cxml') or ""  # guard in case the command is unavailable
165
+ num_chars = len(packaged)
166
+ num_lines = len(packaged.split('\n'))
167
+ num_files = len([x for x in packaged.split('\n') if x.startswith('<source>')])
168
+ num_tokens = num_chars // 4 # assume approximately 4 chars per token
169
+
170
+ # count dependencies via uv.lock
171
+ uv_lock_lines = 0
172
+ if os.path.exists('uv.lock'):
173
+ with open('uv.lock', 'r', encoding='utf-8') as f:
174
+ uv_lock_lines = len(f.readlines())
175
+
176
+ header += f"""
177
+ ### Bloat
178
+ - Characters: {num_chars:,}
179
+ - Lines: {num_lines:,}
180
+ - Files: {num_files:,}
181
+ - Tokens (approx): {num_tokens:,}
182
+ - Dependencies (uv.lock lines): {uv_lock_lines:,}
183
+
184
+ """
185
+ return header
186
+
187
+ # -----------------------------------------------------------------------------
188
+
189
+ def slugify(text):
190
+ """Slugify a text string."""
191
+ return text.lower().replace(" ", "-")
192
+
193
+ # the expected files and their order
194
+ EXPECTED_FILES = [
195
+ "tokenizer-training.md",
196
+ "tokenizer-evaluation.md",
197
+ "base-model-training.md",
198
+ "base-model-loss.md",
199
+ "base-model-evaluation.md",
200
+ "midtraining.md",
201
+ "chat-evaluation-mid.md",
202
+ "chat-sft.md",
203
+ "chat-evaluation-sft.md",
204
+ "chat-rl.md",
205
+ "chat-evaluation-rl.md",
206
+ ]
207
+ # the metrics we're currently interested in
208
+ chat_metrics = ["ARC-Easy", "ARC-Challenge", "MMLU", "GSM8K", "HumanEval", "ChatCORE"]
209
+
210
+ def extract(section, keys):
211
+ """simple def to extract a single key from a section"""
212
+ if not isinstance(keys, list):
213
+ keys = [keys] # convenience
214
+ out = {}
215
+ for line in section.split("\n"):
216
+ for key in keys:
217
+ if key in line:
218
+ out[key] = line.split(":")[1].strip()
219
+ return out
220
+
221
+ def extract_timestamp(content, prefix):
222
+ """Extract timestamp from content with given prefix."""
223
+ for line in content.split('\n'):
224
+ if line.startswith(prefix):
225
+ time_str = line.split(":", 1)[1].strip()
226
+ try:
227
+ return datetime.datetime.strptime(time_str, "%Y-%m-%d %H:%M:%S")
228
+ except:
229
+ pass
230
+ return None
231
+
232
+ class Report:
233
+ """Maintains a bunch of logs, generates a final markdown report."""
234
+
235
+ def __init__(self, report_dir):
236
+ os.makedirs(report_dir, exist_ok=True)
237
+ self.report_dir = report_dir
238
+
239
+ def log(self, section, data):
240
+ """Log a section of data to the report."""
241
+ slug = slugify(section)
242
+ file_name = f"{slug}.md"
243
+ file_path = os.path.join(self.report_dir, file_name)
244
+ with open(file_path, "w", encoding="utf-8") as f:
245
+ f.write(f"## {section}\n")
246
+ f.write(f"timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
247
+ for item in data:
248
+ if not item:
249
+ # skip falsy values like None or empty dict etc.
250
+ continue
251
+ if isinstance(item, str):
252
+ # directly write the string
253
+ f.write(item)
254
+ else:
255
+ # render a dict
256
+ for k, v in item.items():
257
+ if isinstance(v, float):
258
+ vstr = f"{v:.4f}"
259
+ elif isinstance(v, int) and v >= 10000:
260
+ vstr = f"{v:,.0f}"
261
+ else:
262
+ vstr = str(v)
263
+ f.write(f"- {k}: {vstr}\n")
264
+ f.write("\n")
265
+ return file_path
266
+
267
+ def generate(self):
268
+ """Generate the final report."""
269
+ report_dir = self.report_dir
270
+ report_file = os.path.join(report_dir, "report.md")
271
+ print(f"Generating report to {report_file}")
272
+ final_metrics = {} # the most important final metrics we'll add as table at the end
273
+ start_time = None
274
+ end_time = None
275
+ with open(report_file, "w", encoding="utf-8") as out_file:
276
+ # write the header first
277
+ header_file = os.path.join(report_dir, "header.md")
278
+ if os.path.exists(header_file):
279
+ with open(header_file, "r", encoding="utf-8") as f:
280
+ header_content = f.read()
281
+ out_file.write(header_content)
282
+ start_time = extract_timestamp(header_content, "Run started:")
283
+ # capture bloat data for summary later (the stuff after Bloat header and until \n\n)
284
+ bloat_data = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL)
285
+ bloat_data = bloat_data.group(1) if bloat_data else ""
286
+ else:
287
+ start_time = None # will cause us to not write the total wall clock time
288
+ bloat_data = "[bloat data missing]"
289
+ print(f"Warning: {header_file} does not exist. Did you forget to run `nanochat reset`?")
290
+ # process all the individual sections
291
+ for file_name in EXPECTED_FILES:
292
+ section_file = os.path.join(report_dir, file_name)
293
+ if not os.path.exists(section_file):
294
+ print(f"Warning: {section_file} does not exist, skipping")
295
+ continue
296
+ with open(section_file, "r", encoding="utf-8") as in_file:
297
+ section = in_file.read()
298
+ # Extract timestamp from this section (the last section's timestamp will "stick" as end_time)
299
+ if "rl" not in file_name:
300
+ # Skip RL sections for end_time calculation because RL is experimental
301
+ end_time = extract_timestamp(section, "timestamp:")
302
+ # extract the most important metrics from the sections
303
+ if file_name == "base-model-evaluation.md":
304
+ final_metrics["base"] = extract(section, "CORE")
305
+ if file_name == "chat-evaluation-mid.md":
306
+ final_metrics["mid"] = extract(section, chat_metrics)
307
+ if file_name == "chat-evaluation-sft.md":
308
+ final_metrics["sft"] = extract(section, chat_metrics)
309
+ if file_name == "chat-evaluation-rl.md":
310
+ final_metrics["rl"] = extract(section, "GSM8K") # RL only evals GSM8K
311
+ # append this section of the report
312
+ out_file.write(section)
313
+ out_file.write("\n")
314
+ # add the final metrics table
315
+ out_file.write("## Summary\n\n")
316
+ # Copy over the bloat metrics from the header
317
+ out_file.write(bloat_data)
318
+ out_file.write("\n\n")
319
+ # Collect all unique metric names
320
+ all_metrics = set()
321
+ for stage_metrics in final_metrics.values():
322
+ all_metrics.update(stage_metrics.keys())
323
+ # Custom ordering: CORE first, ChatCORE last, rest in middle
324
+ all_metrics = sorted(all_metrics, key=lambda x: (x != "CORE", x == "ChatCORE", x))
325
+ # Fixed column widths
326
+ stages = ["base", "mid", "sft", "rl"]
327
+ metric_width = 15
328
+ value_width = 8
329
+ # Write table header
330
+ header = f"| {'Metric'.ljust(metric_width)} |"
331
+ for stage in stages:
332
+ header += f" {stage.upper().ljust(value_width)} |"
333
+ out_file.write(header + "\n")
334
+ # Write separator
335
+ separator = f"|{'-' * (metric_width + 2)}|"
336
+ for stage in stages:
337
+ separator += f"{'-' * (value_width + 2)}|"
338
+ out_file.write(separator + "\n")
339
+ # Write table rows
340
+ for metric in all_metrics:
341
+ row = f"| {metric.ljust(metric_width)} |"
342
+ for stage in stages:
343
+ value = final_metrics.get(stage, {}).get(metric, "-")
344
+ row += f" {str(value).ljust(value_width)} |"
345
+ out_file.write(row + "\n")
346
+ out_file.write("\n")
347
+ # Calculate and write total wall clock time
348
+ if start_time and end_time:
349
+ duration = end_time - start_time
350
+ total_seconds = int(duration.total_seconds())
351
+ hours = total_seconds // 3600
352
+ minutes = (total_seconds % 3600) // 60
353
+ out_file.write(f"Total wall clock time: {hours}h{minutes}m\n")
354
+ else:
355
+ out_file.write("Total wall clock time: unknown\n")
356
+ # also cp the report.md file to current directory
357
+ print(f"Copying report.md to current directory for convenience")
358
+ shutil.copy(report_file, "report.md")
359
+ return report_file
360
+
361
+ def reset(self):
362
+ """Reset the report."""
363
+ # Remove section files
364
+ for file_name in EXPECTED_FILES:
365
+ file_path = os.path.join(self.report_dir, file_name)
366
+ if os.path.exists(file_path):
367
+ os.remove(file_path)
368
+ # Remove report.md if it exists
369
+ report_file = os.path.join(self.report_dir, "report.md")
370
+ if os.path.exists(report_file):
371
+ os.remove(report_file)
372
+ # Generate and write the header section with start timestamp
373
+ header_file = os.path.join(self.report_dir, "header.md")
374
+ header = generate_header()
375
+ start_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
376
+ with open(header_file, "w", encoding="utf-8") as f:
377
+ f.write(header)
378
+ f.write(f"Run started: {start_time}\n\n---\n\n")
379
+ print(f"Reset report and wrote header to {header_file}")
380
+
381
+ # -----------------------------------------------------------------------------
382
+ # nanochat-specific convenience functions
383
+
384
+ class DummyReport:
385
+ def log(self, *args, **kwargs):
386
+ pass
387
+ def reset(self, *args, **kwargs):
388
+ pass
389
+
390
+ def get_report():
391
+ # just for convenience, only rank 0 logs to report
392
+ from nanochat.common import get_base_dir, get_dist_info
393
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
394
+ if ddp_rank == 0:
395
+ report_dir = os.path.join(get_base_dir(), "report")
396
+ return Report(report_dir)
397
+ else:
398
+ return DummyReport()
399
+
400
+ if __name__ == "__main__":
401
+ import argparse
402
+ parser = argparse.ArgumentParser(description="Generate or reset nanochat training reports.")
403
+ parser.add_argument("command", nargs="?", default="generate", choices=["generate", "reset"], help="Operation to perform (default: generate)")
404
+ args = parser.parse_args()
405
+ if args.command == "generate":
406
+ get_report().generate()
407
+ elif args.command == "reset":
408
+ get_report().reset()
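A hedged usage sketch of the Report class above (the directory and logged values are illustrative only):

```python
# Hedged usage sketch; directory and values are illustrative.
from nanochat.report import Report

report = Report("out/report")
report.reset()  # writes header.md with git/hardware/bloat info and the start timestamp
report.log(section="Tokenizer training", data=[
    "Trained a demo tokenizer.\n",                 # strings are written verbatim
    {"vocab_size": 65536, "train_time_s": 123.4},  # dicts are rendered as bullet lists
])
report.generate()  # stitches header + known sections into report.md (also copied to cwd)
```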
nanochat/tokenizer.py ADDED
@@ -0,0 +1,398 @@
1
+ """
2
+ BPE Tokenizer in the style of GPT-4.
3
+
4
+ Two implementations are available:
5
+ 1) HuggingFace Tokenizer that can do both training and inference but is really confusing
6
+ 2) Our own RustBPE Tokenizer for training and tiktoken for efficient inference
7
+ """
8
+
9
+ import os
10
+ import copy
11
+ from functools import lru_cache
12
+
13
+ SPECIAL_TOKENS = [
14
+ # every document begins with the Beginning of Sequence (BOS) token that delimits documents
15
+ "<|bos|>",
16
+ # tokens below are only used during finetuning to render Conversations into token ids
17
+ "<|user_start|>", # user messages
18
+ "<|user_end|>",
19
+ "<|assistant_start|>", # assistant messages
20
+ "<|assistant_end|>",
21
+ "<|python_start|>", # assistant invokes python REPL tool
22
+ "<|python_end|>",
23
+ "<|output_start|>", # python REPL outputs back to assistant
24
+ "<|output_end|>",
25
+ ]
26
+
27
+ # NOTE: this split pattern deviates from GPT-4 in that we use \p{N}{1,2} instead of \p{N}{1,3}
28
+ # I did this because I didn't want to "waste" too many tokens on numbers for smaller vocab sizes.
29
+ # I haven't validated that this is actually a good idea, TODO.
30
+ SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
31
+
32
+ # -----------------------------------------------------------------------------
33
+ # Generic GPT-4-style tokenizer based on HuggingFace Tokenizer
34
+ from tokenizers import Tokenizer as HFTokenizer
35
+ from tokenizers import pre_tokenizers, decoders, Regex
36
+ from tokenizers.models import BPE
37
+ from tokenizers.trainers import BpeTrainer
38
+
39
+ class HuggingFaceTokenizer:
40
+ """Light wrapper around HuggingFace Tokenizer for some utilities"""
41
+
42
+ def __init__(self, tokenizer):
43
+ self.tokenizer = tokenizer
44
+
45
+ @classmethod
46
+ def from_pretrained(cls, hf_path):
47
+ # init from a HuggingFace pretrained tokenizer (e.g. "gpt2")
48
+ tokenizer = HFTokenizer.from_pretrained(hf_path)
49
+ return cls(tokenizer)
50
+
51
+ @classmethod
52
+ def from_directory(cls, tokenizer_dir):
53
+ # init from a local directory on disk (e.g. "out/tokenizer")
54
+ tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
55
+ tokenizer = HFTokenizer.from_file(tokenizer_path)
56
+ return cls(tokenizer)
57
+
58
+ @classmethod
59
+ def train_from_iterator(cls, text_iterator, vocab_size):
60
+ # train from an iterator of text
61
+ # Configure the HuggingFace Tokenizer
62
+ tokenizer = HFTokenizer(BPE(
63
+ byte_fallback=True, # needed!
64
+ unk_token=None,
65
+ fuse_unk=False,
66
+ ))
67
+ # Normalizer: None
68
+ tokenizer.normalizer = None
69
+ # Pre-tokenizer: GPT-4 style
70
+ # the regex pattern used by GPT-4 to split text into groups before BPE
71
+ # NOTE: The pattern was changed from \p{N}{1,3} to \p{N}{1,2} because I suspect it is harmful to
72
+ # very small models and smaller vocab sizes, because it is a little bit wasteful in the token space.
73
+ # (but I haven't validated this! TODO)
74
+ gpt4_split_regex = Regex(SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!!
75
+ tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
76
+ pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False),
77
+ pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
78
+ ])
79
+ # Decoder: ByteLevel (it pairs together with the ByteLevel pre-tokenizer)
80
+ tokenizer.decoder = decoders.ByteLevel()
81
+ # Post-processor: None
82
+ tokenizer.post_processor = None
83
+ # Trainer: BPE
84
+ trainer = BpeTrainer(
85
+ vocab_size=vocab_size,
86
+ show_progress=True,
87
+ min_frequency=0, # no minimum frequency
88
+ initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
89
+ special_tokens=SPECIAL_TOKENS,
90
+ )
91
+ # Kick off the training
92
+ tokenizer.train_from_iterator(text_iterator, trainer)
93
+ return cls(tokenizer)
94
+
95
+ def get_vocab_size(self):
96
+ return self.tokenizer.get_vocab_size()
97
+
98
+ def get_special_tokens(self):
99
+ special_tokens_map = self.tokenizer.get_added_tokens_decoder()
100
+ special_tokens = [w.content for w in special_tokens_map.values()]
101
+ return special_tokens
102
+
103
+ def id_to_token(self, id):
104
+ return self.tokenizer.id_to_token(id)
105
+
106
+ def _encode_one(self, text, prepend=None, append=None):
107
+ # encode a single string
108
+ # prepend/append can each be either a special token string or a token id directly.
109
+ assert isinstance(text, str)
110
+ ids = []
111
+ if prepend is not None:
112
+ prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend)
113
+ ids.append(prepend_id)
114
+ ids.extend(self.tokenizer.encode(text, add_special_tokens=False).ids)
115
+ if append is not None:
116
+ append_id = append if isinstance(append, int) else self.encode_special(append)
117
+ ids.append(append_id)
118
+ return ids
119
+
120
+ def encode_special(self, text):
121
+ # encode a single special token via exact match
122
+ return self.tokenizer.token_to_id(text)
123
+
124
+ def get_bos_token_id(self):
125
+ bos = self.encode_special("<|bos|>")
126
+ return bos
127
+
128
+ def encode(self, text, *args, **kwargs):
129
+ if isinstance(text, str):
130
+ return self._encode_one(text, *args, **kwargs)
131
+ elif isinstance(text, list):
132
+ return [self._encode_one(t, *args, **kwargs) for t in text]
133
+ else:
134
+ raise ValueError(f"Invalid input type: {type(text)}")
135
+
136
+ def __call__(self, *args, **kwargs):
137
+ return self.encode(*args, **kwargs)
138
+
139
+ def decode(self, ids):
140
+ return self.tokenizer.decode(ids, skip_special_tokens=False)
141
+
142
+ def save(self, tokenizer_dir):
143
+ # save the tokenizer to disk
144
+ os.makedirs(tokenizer_dir, exist_ok=True)
145
+ tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
146
+ self.tokenizer.save(tokenizer_path)
147
+ print(f"Saved tokenizer to {tokenizer_path}")
148
+
149
+ # -----------------------------------------------------------------------------
150
+ # Tokenizer based on rustbpe + tiktoken combo
151
+ import pickle
152
+ import rustbpe
153
+ import tiktoken
154
+
155
+ class RustBPETokenizer:
156
+ """Light wrapper around tiktoken (for efficient inference) but train with rustbpe"""
157
+
158
+ def __init__(self, enc, bos_token):
159
+ self.enc = enc
160
+ self.bos_token_id = self.encode_special(bos_token)
161
+
162
+ @classmethod
163
+ def train_from_iterator(cls, text_iterator, vocab_size):
164
+ # 1) train using rustbpe
165
+ tokenizer = rustbpe.Tokenizer()
166
+ # the special tokens are inserted later in __init__, we don't train them here
167
+ vocab_size_no_special = vocab_size - len(SPECIAL_TOKENS)
168
+ assert vocab_size_no_special >= 256, f"vocab_size_no_special must be at least 256, got {vocab_size_no_special}"
169
+ tokenizer.train_from_iterator(text_iterator, vocab_size_no_special, pattern=SPLIT_PATTERN)
170
+ # 2) construct the associated tiktoken encoding for inference
171
+ pattern = tokenizer.get_pattern()
172
+ mergeable_ranks_list = tokenizer.get_mergeable_ranks()
173
+ mergeable_ranks = {bytes(k): v for k, v in mergeable_ranks_list}
174
+ tokens_offset = len(mergeable_ranks)
175
+ special_tokens = {name: tokens_offset + i for i, name in enumerate(SPECIAL_TOKENS)}
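+ # i.e. the special tokens take up the last len(SPECIAL_TOKENS) ids, right after the
+ # merge-derived tokens, so the total comes out to exactly vocab_size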
176
+ enc = tiktoken.Encoding(
177
+ name="rustbpe",
178
+ pat_str=pattern,
179
+ mergeable_ranks=mergeable_ranks, # dict[bytes, int] (token bytes -> merge priority rank)
180
+ special_tokens=special_tokens, # dict[str, int] (special token name -> token id)
181
+ )
182
+ return cls(enc, "<|bos|>")
183
+
184
+ @classmethod
185
+ def from_directory(cls, tokenizer_dir):
186
+ pickle_path = os.path.join(tokenizer_dir, "tokenizer.pkl")
187
+ with open(pickle_path, "rb") as f:
188
+ enc = pickle.load(f)
189
+ return cls(enc, "<|bos|>")
190
+
191
+ @classmethod
192
+ def from_pretrained(cls, tiktoken_name):
193
+ # https://github.com/openai/tiktoken/blob/eedc8563/tiktoken_ext/openai_public.py
194
+ enc = tiktoken.get_encoding(tiktoken_name)
195
+ # tiktoken calls the special document delimiter token "<|endoftext|>"
196
+ # yes this is confusing because this token is almost always PREPENDED to the beginning of the document
197
+ # it most often is used to signal the start of a new sequence to the LLM during inference etc.
198
+ # so in nanochat we always use "<|bos|>" (short for "beginning of sequence"), but historically it is often called "<|endoftext|>".
199
+ return cls(enc, "<|endoftext|>")
200
+
201
+ def get_vocab_size(self):
202
+ return self.enc.n_vocab
203
+
204
+ def get_special_tokens(self):
205
+ return self.enc.special_tokens_set
206
+
207
+ def id_to_token(self, id):
208
+ return self.enc.decode([id])
209
+
210
+ @lru_cache(maxsize=32)
211
+ def encode_special(self, text):
212
+ return self.enc.encode_single_token(text)
213
+
214
+ def get_bos_token_id(self):
215
+ return self.bos_token_id
216
+
217
+ def encode(self, text, prepend=None, append=None, num_threads=8):
218
+ # text can be either a string or a list of strings
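+ # e.g. (illustrative) encode(["a", "b"], prepend="<|bos|>") -> [[bos_id, ...], [bos_id, ...]]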
219
+
220
+ if prepend is not None:
221
+ prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend)
222
+ if append is not None:
223
+ append_id = append if isinstance(append, int) else self.encode_special(append)
224
+
225
+ if isinstance(text, str):
226
+ ids = self.enc.encode_ordinary(text)
227
+ if prepend is not None:
228
+ ids.insert(0, prepend_id) # TODO: slightly inefficient here? :( hmm
229
+ if append is not None:
230
+ ids.append(append_id)
231
+ elif isinstance(text, list):
232
+ ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads)
233
+ if prepend is not None:
234
+ for ids_row in ids:
235
+ ids_row.insert(0, prepend_id) # TODO: same
236
+ if append is not None:
237
+ for ids_row in ids:
238
+ ids_row.append(append_id)
239
+ else:
240
+ raise ValueError(f"Invalid input type: {type(text)}")
241
+
242
+ return ids
243
+
244
+ def __call__(self, *args, **kwargs):
245
+ return self.encode(*args, **kwargs)
246
+
247
+ def decode(self, ids):
248
+ return self.enc.decode(ids)
249
+
250
+ def save(self, tokenizer_dir):
251
+ # save the encoding object to disk
252
+ os.makedirs(tokenizer_dir, exist_ok=True)
253
+ pickle_path = os.path.join(tokenizer_dir, "tokenizer.pkl")
254
+ with open(pickle_path, "wb") as f:
255
+ pickle.dump(self.enc, f)
256
+ print(f"Saved tokenizer encoding to {pickle_path}")
257
+
258
+ def render_conversation(self, conversation, max_tokens=2048):
259
+ """
260
+ Tokenize a single Chat conversation (which we call a "doc" or "document" here).
261
+ Returns:
262
+ - ids: list[int] is a list of token ids of this rendered conversation
263
+ - mask: list[int] of same length, mask = 1 for tokens that the Assistant is expected to train on.
264
+ """
265
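+ # For example (illustrative), a conversation {"messages": [{"role": "user", "content": "hi"},
+ # {"role": "assistant", "content": "hello"}]} renders roughly as
+ # <|bos|><|user_start|>hi<|user_end|><|assistant_start|>hello<|assistant_end|>
+ # with mask=1 only on the assistant's content tokens and on <|assistant_end|>.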
+ # ids, masks that we will return and a helper function to help build them up.
266
+ ids, mask = [], []
267
+ def add_tokens(token_ids, mask_val):
268
+ if isinstance(token_ids, int):
269
+ token_ids = [token_ids]
270
+ ids.extend(token_ids)
271
+ mask.extend([mask_val] * len(token_ids))
272
+
273
+ # sometimes the first message is a system message...
274
+ # => just merge it with the second (user) message
275
+ if conversation["messages"][0]["role"] == "system":
276
+ # some conversation surgery is necessary here for now...
277
+ conversation = copy.deepcopy(conversation) # avoid mutating the original
278
+ messages = conversation["messages"]
279
+ assert messages[1]["role"] == "user", "System message must be followed by a user message"
280
+ messages[1]["content"] = messages[0]["content"] + "\n\n" + messages[1]["content"]
281
+ messages = messages[1:]
282
+ else:
283
+ messages = conversation["messages"]
284
+ assert len(messages) >= 1, f"Conversation has less than 1 message: {messages}"
285
+
286
+ # fetch all the special tokens we need
287
+ bos = self.get_bos_token_id()
288
+ user_start, user_end = self.encode_special("<|user_start|>"), self.encode_special("<|user_end|>")
289
+ assistant_start, assistant_end = self.encode_special("<|assistant_start|>"), self.encode_special("<|assistant_end|>")
290
+ python_start, python_end = self.encode_special("<|python_start|>"), self.encode_special("<|python_end|>")
291
+ output_start, output_end = self.encode_special("<|output_start|>"), self.encode_special("<|output_end|>")
292
+
293
+ # now we can tokenize the conversation
294
+ add_tokens(bos, 0)
295
+ for i, message in enumerate(messages):
296
+
297
+ # some sanity checking here around assumptions, to prevent footguns
298
+ must_be_from = "user" if i % 2 == 0 else "assistant"
299
+ assert message["role"] == must_be_from, f"Message {i} is from {message['role']} but should be from {must_be_from}"
300
+
301
+ # content can be either a simple string or a list of parts (e.g. containing tool calls)
302
+ content = message["content"]
303
+
304
+ if message["role"] == "user":
305
+ assert isinstance(content, str), "User messages are simply expected to be strings"
306
+ value_ids = self.encode(content)
307
+ add_tokens(user_start, 0)
308
+ add_tokens(value_ids, 0)
309
+ add_tokens(user_end, 0)
310
+ elif message["role"] == "assistant":
311
+ add_tokens(assistant_start, 0)
312
+ if isinstance(content, str):
313
+ # simple string => simply add the tokens
314
+ value_ids = self.encode(content)
315
+ add_tokens(value_ids, 1)
316
+ elif isinstance(content, list):
317
+ for part in content:
318
+ value_ids = self.encode(part["text"])
319
+ if part["type"] == "text":
320
+ # string part => simply add the tokens
321
+ add_tokens(value_ids, 1)
322
+ elif part["type"] == "python":
323
+ # python tool call => add the tokens inside <|python_start|> and <|python_end|>
324
+ add_tokens(python_start, 1)
325
+ add_tokens(value_ids, 1)
326
+ add_tokens(python_end, 1)
327
+ elif part["type"] == "python_output":
328
+ # python output => add the tokens inside <|output_start|> and <|output_end|>
329
+ # none of these tokens are supervised because the tokens come from Python at test time
330
+ add_tokens(output_start, 0)
331
+ add_tokens(value_ids, 0)
332
+ add_tokens(output_end, 0)
333
+ else:
334
+ raise ValueError(f"Unknown part type: {part['type']}")
335
+ else:
336
+ raise ValueError(f"Unknown content type: {type(content)}")
337
+ add_tokens(assistant_end, 1)
338
+
339
+ # truncate to max_tokens tokens MAX (helps prevent OOMs)
340
+ ids = ids[:max_tokens]
341
+ mask = mask[:max_tokens]
342
+ return ids, mask
343
+
344
+ def visualize_tokenization(self, ids, mask, with_token_id=False):
345
+ """Small helper function useful in debugging: visualize the tokenization of render_conversation"""
346
+ RED = '\033[91m'
347
+ GREEN = '\033[92m'
348
+ RESET = '\033[0m'
349
+ GRAY = '\033[90m'
350
+ tokens = []
351
+ for i, (token_id, mask_val) in enumerate(zip(ids, mask)):
352
+ token_str = self.decode([token_id])
353
+ color = GREEN if mask_val == 1 else RED
354
+ tokens.append(f"{color}{token_str}{RESET}")
355
+ if with_token_id:
356
+ tokens.append(f"{GRAY}({token_id}){RESET}")
357
+ return '|'.join(tokens)
358
+
359
+ def render_for_completion(self, conversation):
360
+ """
361
+ Used during Reinforcement Learning. In that setting, we want to
362
+ render the conversation priming the Assistant for a completion.
363
+ Unlike the Chat SFT case, we don't need to return the mask.
364
+ """
365
+ # We have some surgery to do: we need to pop the last message (of the Assistant)
366
+ conversation = copy.deepcopy(conversation) # avoid mutating the original
367
+ messages = conversation["messages"]
368
+ assert messages[-1]["role"] == "assistant", "Last message must be from the Assistant"
369
+ messages.pop() # remove the last message (of the Assistant) inplace
370
+
371
+ # Now tokenize the conversation
372
+ ids, mask = self.render_conversation(conversation)
373
+
374
+ # Finally, to prime the Assistant for a completion, append the Assistant start token
375
+ assistant_start = self.encode_special("<|assistant_start|>")
376
+ ids.append(assistant_start)
377
+ return ids
378
+
379
+ # -----------------------------------------------------------------------------
380
+ # nanochat-specific convenience functions
381
+
382
+ def get_tokenizer():
383
+ from nanochat.common import get_base_dir
384
+ base_dir = get_base_dir()
385
+ tokenizer_dir = os.path.join(base_dir, "tokenizer")
386
+ # return HuggingFaceTokenizer.from_directory(tokenizer_dir)
387
+ return RustBPETokenizer.from_directory(tokenizer_dir)
388
+
389
+ def get_token_bytes(device="cpu"):
390
+ import torch
391
+ from nanochat.common import get_base_dir
392
+ base_dir = get_base_dir()
393
+ tokenizer_dir = os.path.join(base_dir, "tokenizer")
394
+ token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
395
+ assert os.path.exists(token_bytes_path), f"Token bytes not found at {token_bytes_path}? It gets written by tok_train.py"
396
+ with open(token_bytes_path, "rb") as f:
397
+ token_bytes = torch.load(f, map_location=device)
398
+ return token_bytes
nanochat/ui.html ADDED
@@ -0,0 +1,562 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0, viewport-fit=cover">
6
+ <title>NanoChat</title>
7
+ <link rel="icon" type="image/svg+xml" href="/logo.svg">
8
+ <style>
9
+ :root {
10
+ color-scheme: light;
11
+ }
12
+
13
+ * {
14
+ box-sizing: border-box;
15
+ }
16
+
17
+ body {
18
+ font-family: ui-sans-serif, -apple-system, system-ui, "Segoe UI", Helvetica, "Apple Color Emoji", Arial, sans-serif, "Segoe UI Emoji", "Segoe UI Symbol";
19
+ background-color: #ffffff;
20
+ color: #111827;
21
+ min-height: 100dvh;
22
+ margin: 0;
23
+ display: flex;
24
+ flex-direction: column;
25
+ }
26
+
27
+ .header {
28
+ background-color: #ffffff;
29
+ padding: 1.25rem 1.5rem;
30
+ }
31
+
32
+ .header-left {
33
+ display: flex;
34
+ align-items: center;
35
+ gap: 0.75rem;
36
+ }
37
+
38
+ .header-logo {
39
+ height: 32px;
40
+ width: auto;
41
+ }
42
+
43
+ .header h1 {
44
+ font-size: 1.25rem;
45
+ font-weight: 600;
46
+ margin: 0;
47
+ color: #111827;
48
+ }
49
+
50
+ .new-conversation-btn {
51
+ width: 32px;
52
+ height: 32px;
53
+ padding: 0;
54
+ border: 1px solid #e5e7eb;
55
+ border-radius: 0.5rem;
56
+ background-color: #ffffff;
57
+ color: #6b7280;
58
+ cursor: pointer;
59
+ display: flex;
60
+ align-items: center;
61
+ justify-content: center;
62
+ transition: all 0.2s ease;
63
+ }
64
+
65
+ .new-conversation-btn:hover {
66
+ background-color: #f3f4f6;
67
+ border-color: #d1d5db;
68
+ color: #374151;
69
+ }
70
+
71
+ .chat-container {
72
+ flex: 1;
73
+ overflow-y: auto;
74
+ background-color: #ffffff;
75
+ }
76
+
77
+ .chat-wrapper {
78
+ max-width: 48rem;
79
+ margin: 0 auto;
80
+ padding: 2rem 1.5rem 3rem;
81
+ display: flex;
82
+ flex-direction: column;
83
+ gap: 0.75rem;
84
+ }
85
+
86
+ .message {
87
+ display: flex;
88
+ justify-content: flex-start;
89
+ margin-bottom: 0.5rem;
90
+ color: #0d0d0d;
91
+ }
92
+
93
+ .message.assistant {
94
+ justify-content: flex-start;
95
+ }
96
+
97
+ .message.user {
98
+ justify-content: flex-end;
99
+ }
100
+
101
+ .message-content {
102
+ white-space: pre-wrap;
103
+ line-height: 1.6;
104
+ max-width: 100%;
105
+ }
106
+
107
+ .message.assistant .message-content {
108
+ background: transparent;
109
+ border: none;
110
+ padding: 0.25rem 0;
111
+ cursor: pointer;
112
+ border-radius: 0.5rem;
113
+ padding: 0.5rem;
114
+ margin-left: -0.5rem;
115
+ transition: background-color 0.2s ease;
116
+ }
117
+
118
+ .message.assistant .message-content:hover {
119
+ background-color: #f9fafb;
120
+ }
121
+
122
+ .message.user .message-content {
123
+ background-color: #f3f4f6;
124
+ border-radius: 1.25rem;
125
+ padding: 0.8rem 1rem;
126
+ max-width: 65%;
127
+ cursor: pointer;
128
+ transition: background-color 0.2s ease;
129
+ }
130
+
131
+ .message.user .message-content:hover {
132
+ background-color: #e5e7eb;
133
+ }
134
+
135
+ .message.console .message-content {
136
+ font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', 'Consolas', 'Courier New', monospace;
137
+ font-size: 0.875rem;
138
+ background-color: #fafafa;
139
+ padding: 0.75rem 1rem;
140
+ color: #374151;
141
+ max-width: 80%;
142
+ }
143
+
144
+ .input-container {
145
+ background-color: #ffffff;
146
+ padding: 1rem;
147
+ padding-bottom: calc(1rem + env(safe-area-inset-bottom));
148
+ }
149
+
150
+ .input-wrapper {
151
+ max-width: 48rem;
152
+ margin: 0 auto;
153
+ display: flex;
154
+ gap: 0.75rem;
155
+ align-items: flex-end;
156
+ }
157
+
158
+ .chat-input {
159
+ flex: 1;
160
+ padding: 0.8rem 1rem;
161
+ border: 1px solid #d1d5db;
162
+ border-radius: 0.75rem;
163
+ background-color: #ffffff;
164
+ color: #111827;
165
+ font-size: 1rem;
166
+ line-height: 1.5;
167
+ resize: none;
168
+ outline: none;
169
+ min-height: 54px;
170
+ max-height: 200px;
171
+ transition: border-color 0.2s ease, box-shadow 0.2s ease;
172
+ }
173
+
174
+ .chat-input::placeholder {
175
+ color: #9ca3af;
176
+ }
177
+
178
+ .chat-input:focus {
179
+ border-color: #2563eb;
180
+ box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.1);
181
+ }
182
+
183
+ .send-button {
184
+ flex-shrink: 0;
185
+ padding: 0;
186
+ width: 54px;
187
+ height: 54px;
188
+ border: 1px solid #111827;
189
+ border-radius: 0.75rem;
190
+ background-color: #111827;
191
+ color: #ffffff;
192
+ display: flex;
193
+ align-items: center;
194
+ justify-content: center;
195
+ cursor: pointer;
196
+ transition: background-color 0.2s ease, border-color 0.2s ease, color 0.2s ease;
197
+ }
198
+
199
+ .send-button:hover:not(:disabled) {
200
+ background-color: #2563eb;
201
+ border-color: #2563eb;
202
+ }
203
+
204
+ .send-button:disabled {
205
+ cursor: not-allowed;
206
+ border-color: #d1d5db;
207
+ background-color: #e5e7eb;
208
+ color: #9ca3af;
209
+ }
210
+
211
+ .typing-indicator {
212
+ display: inline-block;
213
+ color: #6b7280;
214
+ letter-spacing: 0.15em;
215
+ }
216
+
217
+ .typing-indicator::after {
218
+ content: '···';
219
+ animation: typing 1.4s infinite;
220
+ }
221
+
222
+ @keyframes typing {
223
+ 0%, 60%, 100% { opacity: 0.2; }
224
+ 30% { opacity: 1; }
225
+ }
226
+
227
+ .error-message {
228
+ background-color: #fee2e2;
229
+ border: 1px solid #fecaca;
230
+ color: #b91c1c;
231
+ padding: 0.75rem 1rem;
232
+ border-radius: 0.75rem;
233
+ margin-top: 0.5rem;
234
+ }
235
+ </style>
236
+ </head>
237
+ <body>
238
+ <div class="header">
239
+ <div class="header-left">
240
+ <button class="new-conversation-btn" onclick="newConversation()" title="New Conversation (Ctrl+Shift+N)">
241
+ <svg width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
242
+ <path d="M12 5v14"></path>
243
+ <path d="M5 12h14"></path>
244
+ </svg>
245
+ </button>
246
+ <h1>nanochat</h1>
247
+ </div>
248
+ </div>
249
+
250
+ <div class="chat-container" id="chatContainer">
251
+ <div class="chat-wrapper" id="chatWrapper">
252
+ <!-- Messages will be added here -->
253
+ </div>
254
+ </div>
255
+
256
+ <div class="input-container">
257
+ <div class="input-wrapper">
258
+ <textarea
259
+ id="chatInput"
260
+ class="chat-input"
261
+ placeholder="Ask anything"
262
+ rows="1"
263
+ onkeydown="handleKeyDown(event)"
264
+ ></textarea>
265
+ <button id="sendButton" class="send-button" onclick="sendMessage()" disabled>
266
+ <svg width="22" height="22" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
267
+ <path d="M22 2L11 13"></path>
268
+ <path d="M22 2l-7 20-4-9-9-4 20-7z"></path>
269
+ </svg>
270
+ </button>
271
+ </div>
272
+ </div>
273
+
274
+ <script>
275
+ const API_URL = '';
276
+ const chatContainer = document.getElementById('chatContainer');
277
+ const chatWrapper = document.getElementById('chatWrapper');
278
+ const chatInput = document.getElementById('chatInput');
279
+ const sendButton = document.getElementById('sendButton');
280
+
281
+ let messages = [];
282
+ let isGenerating = false;
283
+ let currentTemperature = 0.8;
284
+ let currentTopK = 50;
285
+
286
+ chatInput.addEventListener('input', function() {
287
+ this.style.height = 'auto';
288
+ this.style.height = Math.min(this.scrollHeight, 200) + 'px';
289
+ sendButton.disabled = !this.value.trim() || isGenerating;
290
+ });
291
+
292
+ function handleKeyDown(event) {
293
+ if (event.key === 'Enter' && !event.shiftKey) {
294
+ event.preventDefault();
295
+ sendMessage();
296
+ }
297
+ }
298
+
299
+ document.addEventListener('keydown', function(event) {
300
+ // Ctrl+Shift+N for new conversation
301
+ if (event.ctrlKey && event.shiftKey && event.key === 'N') {
302
+ event.preventDefault();
303
+ if (!isGenerating) {
304
+ newConversation();
305
+ }
306
+ }
307
+ });
308
+
309
+ function newConversation() {
310
+ messages = [];
311
+ chatWrapper.innerHTML = '';
312
+ chatInput.value = '';
313
+ chatInput.style.height = 'auto';
314
+ sendButton.disabled = false;
315
+ isGenerating = false;
316
+ chatInput.focus();
317
+ }
318
+
319
+ function addMessage(role, content, messageIndex = null) {
320
+ const messageDiv = document.createElement('div');
321
+ messageDiv.className = `message ${role}`;
322
+
323
+ const contentDiv = document.createElement('div');
324
+ contentDiv.className = 'message-content';
325
+ contentDiv.textContent = content;
326
+
327
+ // Add click handler for user messages to enable editing
328
+ if (role === 'user' && messageIndex !== null) {
329
+ contentDiv.setAttribute('data-message-index', messageIndex);
330
+ contentDiv.setAttribute('title', 'Click to edit and restart from here');
331
+ contentDiv.addEventListener('click', function() {
332
+ if (!isGenerating) {
333
+ editMessage(messageIndex);
334
+ }
335
+ });
336
+ }
337
+
338
+ // Add click handler for assistant messages to enable regeneration
339
+ if (role === 'assistant' && messageIndex !== null) {
340
+ contentDiv.setAttribute('data-message-index', messageIndex);
341
+ contentDiv.setAttribute('title', 'Click to regenerate this response');
342
+ contentDiv.addEventListener('click', function() {
343
+ if (!isGenerating) {
344
+ regenerateMessage(messageIndex);
345
+ }
346
+ });
347
+ }
348
+
349
+ messageDiv.appendChild(contentDiv);
350
+ chatWrapper.appendChild(messageDiv);
351
+
352
+ chatContainer.scrollTop = chatContainer.scrollHeight;
353
+ return contentDiv;
354
+ }
355
+
356
+ function editMessage(messageIndex) {
357
+ // Find the message in the messages array
358
+ if (messageIndex < 0 || messageIndex >= messages.length) return;
359
+
360
+ const messageToEdit = messages[messageIndex];
361
+ if (messageToEdit.role !== 'user') return;
362
+
363
+ // Copy message content to input
364
+ chatInput.value = messageToEdit.content;
365
+ chatInput.style.height = 'auto';
366
+ chatInput.style.height = Math.min(chatInput.scrollHeight, 200) + 'px';
367
+
368
+ // Remove this message and all subsequent messages from the array
369
+ messages = messages.slice(0, messageIndex);
370
+
371
+ // Remove message elements from DOM starting from messageIndex
372
+ const allMessages = chatWrapper.querySelectorAll('.message');
373
+ for (let i = messageIndex; i < allMessages.length; i++) {
374
+ allMessages[i].remove();
375
+ }
376
+
377
+ // Enable send button and focus input
378
+ sendButton.disabled = false;
379
+ chatInput.focus();
380
+ }
381
+
382
+ async function generateAssistantResponse() {
383
+ isGenerating = true;
384
+ sendButton.disabled = true;
385
+
386
+ const assistantContent = addMessage('assistant', '');
387
+ assistantContent.innerHTML = '<span class="typing-indicator"></span>';
388
+
389
+ try {
390
+ const response = await fetch(`${API_URL}/chat/completions`, {
391
+ method: 'POST',
392
+ headers: {
393
+ 'Content-Type': 'application/json',
394
+ },
395
+ body: JSON.stringify({
396
+ messages: messages,
397
+ temperature: currentTemperature,
398
+ top_k: currentTopK,
399
+ max_tokens: 512
400
+ }),
401
+ });
402
+
403
+ if (!response.ok) {
404
+ throw new Error(`HTTP error! status: ${response.status}`);
405
+ }
406
+
407
+ const reader = response.body.getReader();
408
+ const decoder = new TextDecoder();
409
+ let fullResponse = '';
410
+ assistantContent.textContent = '';
411
+
412
+ while (true) {
413
+ const { done, value } = await reader.read();
414
+ if (done) break;
415
+
416
+ const chunk = decoder.decode(value);
417
+ const lines = chunk.split('\n');
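+ // each streamed line is expected to look like: data: {"token": "..."}
+ // (partial or malformed lines are silently skipped by the try/catch below)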
418
+
419
+ for (const line of lines) {
420
+ if (line.startsWith('data: ')) {
421
+ try {
422
+ const data = JSON.parse(line.slice(6));
423
+ if (data.token) {
424
+ fullResponse += data.token;
425
+ assistantContent.textContent = fullResponse;
426
+ chatContainer.scrollTop = chatContainer.scrollHeight;
427
+ }
428
+ } catch (e) {
429
+ }
430
+ }
431
+ }
432
+ }
433
+
434
+ const assistantMessageIndex = messages.length;
435
+ messages.push({ role: 'assistant', content: fullResponse });
436
+
437
+ // Add click handler to regenerate this assistant message
438
+ assistantContent.setAttribute('data-message-index', assistantMessageIndex);
439
+ assistantContent.setAttribute('title', 'Click to regenerate this response');
440
+ assistantContent.addEventListener('click', function() {
441
+ if (!isGenerating) {
442
+ regenerateMessage(assistantMessageIndex);
443
+ }
444
+ });
445
+
446
+ } catch (error) {
447
+ console.error('Error:', error);
448
+ assistantContent.innerHTML = `<div class="error-message">Error: ${error.message}</div>`;
449
+ } finally {
450
+ isGenerating = false;
451
+ sendButton.disabled = !chatInput.value.trim();
452
+ }
453
+ }
454
+
455
+ async function regenerateMessage(messageIndex) {
456
+ // Find the message in the messages array
457
+ if (messageIndex < 0 || messageIndex >= messages.length) return;
458
+
459
+ const messageToRegenerate = messages[messageIndex];
460
+ if (messageToRegenerate.role !== 'assistant') return;
461
+
462
+ // Remove this message and all subsequent messages from the array
463
+ messages = messages.slice(0, messageIndex);
464
+
465
+ // Remove message elements from DOM starting from messageIndex
466
+ const allMessages = chatWrapper.querySelectorAll('.message');
467
+ for (let i = messageIndex; i < allMessages.length; i++) {
468
+ allMessages[i].remove();
469
+ }
470
+
471
+ // Regenerate the assistant response
472
+ await generateAssistantResponse();
473
+ }
474
+
475
+ function handleSlashCommand(command) {
476
+ const parts = command.trim().split(/\s+/);
477
+ const cmd = parts[0].toLowerCase();
478
+ const arg = parts[1];
479
+
480
+ if (cmd === '/temperature') {
481
+ if (arg === undefined) {
482
+ addMessage('console', `Current temperature: ${currentTemperature}`);
483
+ } else {
484
+ const temp = parseFloat(arg);
485
+ if (isNaN(temp) || temp < 0 || temp > 2) {
486
+ addMessage('console', 'Invalid temperature. Must be between 0.0 and 2.0');
487
+ } else {
488
+ currentTemperature = temp;
489
+ addMessage('console', `Temperature set to ${currentTemperature}`);
490
+ }
491
+ }
492
+ return true;
493
+ } else if (cmd === '/topk') {
494
+ if (arg === undefined) {
495
+ addMessage('console', `Current top-k: ${currentTopK}`);
496
+ } else {
497
+ const topk = parseInt(arg);
498
+ if (isNaN(topk) || topk < 1 || topk > 200) {
499
+ addMessage('console', 'Invalid top-k. Must be between 1 and 200');
500
+ } else {
501
+ currentTopK = topk;
502
+ addMessage('console', `Top-k set to ${currentTopK}`);
503
+ }
504
+ }
505
+ return true;
506
+ } else if (cmd === '/clear') {
507
+ newConversation();
508
+ return true;
509
+ } else if (cmd === '/help') {
510
+ addMessage('console',
511
+ 'Available commands:\n' +
512
+ '/temperature - Show current temperature\n' +
513
+ '/temperature <value> - Set temperature (0.0-2.0)\n' +
514
+ '/topk - Show current top-k\n' +
515
+ '/topk <value> - Set top-k (1-200)\n' +
516
+ '/clear - Clear conversation\n' +
517
+ '/help - Show this help message'
518
+ );
519
+ return true;
520
+ }
521
+ return false;
522
+ }
523
+
524
+ async function sendMessage() {
525
+ const message = chatInput.value.trim();
526
+ if (!message || isGenerating) return;
527
+
528
+ // Handle slash commands
529
+ if (message.startsWith('/')) {
530
+ chatInput.value = '';
531
+ chatInput.style.height = 'auto';
532
+ handleSlashCommand(message);
533
+ return;
534
+ }
535
+
536
+ chatInput.value = '';
537
+ chatInput.style.height = 'auto';
538
+
539
+ const userMessageIndex = messages.length;
540
+ messages.push({ role: 'user', content: message });
541
+ addMessage('user', message, userMessageIndex);
542
+
543
+ await generateAssistantResponse();
544
+ }
545
+
546
+ sendButton.disabled = false;
547
+
548
+ // Autofocus the chat input on page load
549
+ chatInput.focus();
550
+
551
+ fetch(`${API_URL}/health`)
552
+ .then(response => response.json())
553
+ .then(data => {
554
+ console.log('Engine status:', data);
555
+ })
556
+ .catch(error => {
557
+ console.error('Engine not available:', error);
558
+ chatWrapper.innerHTML = '<div class="error-message">Engine not running. Please start engine.py first.</div>';
559
+ });
560
+ </script>
561
+ </body>
562
+ </html>
pyproject.toml ADDED
@@ -0,0 +1,77 @@
1
+ [project]
2
+ name = "nanochat"
3
+ version = "0.1.0"
4
+ description = "the minimal full-stack ChatGPT clone"
5
+ readme = "README.md"
6
+ requires-python = ">=3.10"
7
+ dependencies = [
8
+ "datasets>=4.0.0",
9
+ "fastapi>=0.117.1",
10
+ "files-to-prompt>=0.6",
11
+ "psutil>=7.1.0",
12
+ "regex>=2025.9.1",
13
+ "setuptools>=80.9.0",
14
+ "tiktoken>=0.11.0",
15
+ "tokenizers>=0.22.0",
16
+ "torch>=2.8.0",
17
+ "uvicorn>=0.36.0",
18
+ "wandb>=0.21.3",
19
+ ]
20
+
21
+ [build-system]
22
+ requires = ["maturin>=1.7,<2.0"]
23
+ build-backend = "maturin"
24
+
25
+ [tool.maturin]
26
+ module-name = "rustbpe"
27
+ bindings = "pyo3"
28
+ python-source = "."
29
+ manifest-path = "rustbpe/Cargo.toml"
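+ # the rustbpe extension module is built locally, e.g. via `uv run maturin develop --release --manifest-path rustbpe/Cargo.toml` (see run1000.sh)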
30
+
31
+ [dependency-groups]
32
+ dev = [
33
+ "maturin>=1.9.4",
34
+ "pytest>=8.0.0",
35
+ ]
36
+
37
+ [tool.pytest.ini_options]
38
+ markers = [
39
+ "slow: marks tests as slow (deselect with '-m \"not slow\"')",
40
+ ]
41
+ testpaths = ["tests"]
42
+ python_files = ["test_*.py"]
43
+ python_classes = ["Test*"]
44
+ python_functions = ["test_*"]
45
+
46
+ # target torch to CUDA 12.8 or CPU
47
+ [tool.uv.sources]
48
+ torch = [
49
+ { index = "pytorch-cpu", extra = "cpu" },
50
+ { index = "pytorch-cu128", extra = "gpu" },
51
+ ]
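+ # e.g. `uv sync --extra gpu` (as used in run1000.sh) resolves torch from the pytorch-cu128 index defined below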
52
+
53
+ [[tool.uv.index]]
54
+ name = "pytorch-cpu"
55
+ url = "https://download.pytorch.org/whl/cpu"
56
+ explicit = true
57
+
58
+ [[tool.uv.index]]
59
+ name = "pytorch-cu128"
60
+ url = "https://download.pytorch.org/whl/cu128"
61
+ explicit = true
62
+
63
+ [project.optional-dependencies]
64
+ cpu = [
65
+ "torch>=2.8.0",
66
+ ]
67
+ gpu = [
68
+ "torch>=2.8.0",
69
+ ]
70
+
71
+ [tool.uv]
72
+ conflicts = [
73
+ [
74
+ { extra = "cpu" },
75
+ { extra = "gpu" },
76
+ ],
77
+ ]
run1000.sh ADDED
@@ -0,0 +1,94 @@
1
+ #!/bin/bash
2
+
3
+ # The $1000 tier of nanochat
4
+ # Designed to run end-to-end for $1000/24 ~= 41.6 hours on an 8XH100 node
5
+ # A bit sparser on comments, see speedrun.sh for more detail
6
+
7
+ # all the setup stuff
8
+ export OMP_NUM_THREADS=1
9
+ export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
10
+ mkdir -p $NANOCHAT_BASE_DIR
11
+ command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
12
+ [ -d ".venv" ] || uv venv
13
+ uv sync --extra gpu
14
+ source .venv/bin/activate
15
+ if [ -z "$WANDB_RUN" ]; then
16
+ WANDB_RUN=dummy
17
+ fi
18
+ python -m nanochat.report reset
19
+ curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
20
+ source "$HOME/.cargo/env"
21
+ uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
22
+ curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
23
+
24
+ # train tokenizer on ~4B characters and kick off download of the rest for pretraining
25
+ python -m nanochat.dataset -n 16
26
+ # start downloading the rest of the shards for a total of 800 (see below why 800)
27
+ python -m nanochat.dataset -n 800 &
28
+ # todo: download the rest of it
29
+ python -m scripts.tok_train --max_chars=4000000000
30
+ python -m scripts.tok_eval
31
+
32
+ # Documenting my process for determining the hyperparameters for this run1000.sh script:
33
+ # We want a budget of approx. $1000 ~= 41.6 hours of 8XH100 compute
34
+ # 1) I guessed the model size for this to be about depth=32
35
+ # 2) Determine the device_batch_size that fits:
36
+ # Running the base_train.py script with --depth=32, I saw that --device_batch_size=16
37
+ # runs out of memory, but --device_batch_size=8 fits. Inspecting `nvidia-smi` during training,
38
+ # I saw all GPUs were at about 78/80GB VRAM, so it just barely fits and we have good MFU at ~50%.
39
+ # So the training script was running ok and showed:
40
+ # Vocab size: 65,536
41
+ # num_layers: 32
42
+ # model_dim: 2048
43
+ # num_heads: 16
44
+ # num_kv_heads: 16
45
+ # Tokens / micro-batch / rank: 8 x 2048 = 16,384
46
+ # Tokens / micro-batch: 131,072
47
+ # Total batch size 524,288 => gradient accumulation steps: 4
48
+ # Number of parameters: 1,879,048,192
49
+ # Estimated FLOPs per token: 1.207960e+10
50
+ # Calculated number of iterations from target data:param ratio: 71,680
51
+ # Total number of training tokens: 37,580,963,840
52
+ # Tokens : Params ratio: 20.00
53
+ # Total training FLOPs estimate: 4.539628e+20
54
+ # step 00004/71680 (0.01%) | loss: 8.813754 | lrm: 1.00 | dt: 1571.88ms | tok/sec: 83,385 | mfu: 50.92 | total time: 0.00m
55
+ # step 00005/71680 (0.01%) | loss: 8.488074 | lrm: 1.00 | dt: 1572.76ms | tok/sec: 83,338 | mfu: 50.89 | total time: 0.00m
56
+ # ...
57
+ # 3) validate that the runtime fits our budget:
58
+ # The training script uses the Chinchilla scaling law to compute-optimally set #tokens = 20 * #params. In particular:
59
+ # The script shows that we will be training for 71,680 steps, and each step takes 1.574s so:
60
+ # estimated time to train: 71,680 * 1.574s / 60 / 60 = 31.3 hours.
61
+ # This is OK, fits our budget, and leaves ~10 hours for midtraining and SFT and evals and maybe RL.
62
+ # It's possible that we might even fit depth=33 or depth=34, but for now let's go along with this.
63
+ # 4) The last thing to pay attention to is the amount of training data required for the run.
64
+ # The script above calculated that "Total number of training tokens: 37,580,963,840"
65
+ # The tok_eval.py script reports about ~4.8 chars/token on average for the default tokenizer settings.
66
+ # So ~38B tokens * ~4.8 chars/token = ~185B chars.
67
+ # Each data shard is ~250M chars, so we need ~185B / 250M ~= 740 shards.
68
+ # For safety, I bumped that up to 800 shards, and that's why up above I used -n 800 when pre-downloading dataset shards.
69
+ # If we didn't have enough data, the training script would loop around and do multiple epochs over the same data,
70
+ # which would decrease model performance. Possibly 2, 3 or so epochs is ~ok, but certainly not ideal and at 10+ epochs we'd
71
+ # start to overfit hard.
72
+ # 5) That's it, everything else (e.g. the learning rates) is adjusted automatically by the training script.
73
+
74
+ # Number of processes/GPUs to use
75
+ NPROC_PER_NODE=8
76
+
77
+ torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=32 --device_batch_size=8 --run=$WANDB_RUN
78
+ torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
79
+ torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval
80
+
81
+ # midtrain
82
+ # NOTE: ensure that we use the same device_batch_size here as the base training script.
83
+ torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --device_batch_size=8 --run=$WANDB_RUN
84
+ torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid
85
+
86
+ # sft
87
+ torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN
88
+ torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft
89
+
90
+ # generate final report
91
+ python -m nanochat.report generate
92
+
93
+ # talk to it
94
+ python -m scripts.chat_web
rustbpe/Cargo.lock ADDED
@@ -0,0 +1,458 @@
1
+ # This file is automatically @generated by Cargo.
2
+ # It is not intended for manual editing.
3
+ version = 4
4
+
5
+ [[package]]
6
+ name = "ahash"
7
+ version = "0.8.12"
8
+ source = "registry+https://github.com/rust-lang/crates.io-index"
9
+ checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75"
10
+ dependencies = [
11
+ "cfg-if",
12
+ "getrandom",
13
+ "once_cell",
14
+ "version_check",
15
+ "zerocopy",
16
+ ]
17
+
18
+ [[package]]
19
+ name = "aho-corasick"
20
+ version = "1.1.3"
21
+ source = "registry+https://github.com/rust-lang/crates.io-index"
22
+ checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916"
23
+ dependencies = [
24
+ "memchr",
25
+ ]
26
+
27
+ [[package]]
28
+ name = "arc-swap"
29
+ version = "1.7.1"
30
+ source = "registry+https://github.com/rust-lang/crates.io-index"
31
+ checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457"
32
+
33
+ [[package]]
34
+ name = "autocfg"
35
+ version = "1.5.0"
36
+ source = "registry+https://github.com/rust-lang/crates.io-index"
37
+ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
38
+
39
+ [[package]]
40
+ name = "bit-set"
41
+ version = "0.8.0"
42
+ source = "registry+https://github.com/rust-lang/crates.io-index"
43
+ checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3"
44
+ dependencies = [
45
+ "bit-vec",
46
+ ]
47
+
48
+ [[package]]
49
+ name = "bit-vec"
50
+ version = "0.8.0"
51
+ source = "registry+https://github.com/rust-lang/crates.io-index"
52
+ checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7"
53
+
54
+ [[package]]
55
+ name = "castaway"
56
+ version = "0.2.4"
57
+ source = "registry+https://github.com/rust-lang/crates.io-index"
58
+ checksum = "dec551ab6e7578819132c713a93c022a05d60159dc86e7a7050223577484c55a"
59
+ dependencies = [
60
+ "rustversion",
61
+ ]
62
+
63
+ [[package]]
64
+ name = "cfg-if"
65
+ version = "1.0.3"
66
+ source = "registry+https://github.com/rust-lang/crates.io-index"
67
+ checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9"
68
+
69
+ [[package]]
70
+ name = "compact_str"
71
+ version = "0.9.0"
72
+ source = "registry+https://github.com/rust-lang/crates.io-index"
73
+ checksum = "3fdb1325a1cece981e8a296ab8f0f9b63ae357bd0784a9faaf548cc7b480707a"
74
+ dependencies = [
75
+ "castaway",
76
+ "cfg-if",
77
+ "itoa",
78
+ "rustversion",
79
+ "ryu",
80
+ "static_assertions",
81
+ ]
82
+
83
+ [[package]]
84
+ name = "crossbeam-deque"
85
+ version = "0.8.6"
86
+ source = "registry+https://github.com/rust-lang/crates.io-index"
87
+ checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
88
+ dependencies = [
89
+ "crossbeam-epoch",
90
+ "crossbeam-utils",
91
+ ]
92
+
93
+ [[package]]
94
+ name = "crossbeam-epoch"
95
+ version = "0.9.18"
96
+ source = "registry+https://github.com/rust-lang/crates.io-index"
97
+ checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
98
+ dependencies = [
99
+ "crossbeam-utils",
100
+ ]
101
+
102
+ [[package]]
103
+ name = "crossbeam-utils"
104
+ version = "0.8.21"
105
+ source = "registry+https://github.com/rust-lang/crates.io-index"
106
+ checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
107
+
108
+ [[package]]
109
+ name = "dary_heap"
110
+ version = "0.3.7"
111
+ source = "registry+https://github.com/rust-lang/crates.io-index"
112
+ checksum = "04d2cd9c18b9f454ed67da600630b021a8a80bf33f8c95896ab33aaf1c26b728"
113
+
114
+ [[package]]
115
+ name = "either"
116
+ version = "1.15.0"
117
+ source = "registry+https://github.com/rust-lang/crates.io-index"
118
+ checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
119
+
120
+ [[package]]
121
+ name = "equivalent"
122
+ version = "1.0.2"
123
+ source = "registry+https://github.com/rust-lang/crates.io-index"
124
+ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
125
+
126
+ [[package]]
127
+ name = "fancy-regex"
128
+ version = "0.16.1"
129
+ source = "registry+https://github.com/rust-lang/crates.io-index"
130
+ checksum = "bf04c5ec15464ace8355a7b440a33aece288993475556d461154d7a62ad9947c"
131
+ dependencies = [
132
+ "bit-set",
133
+ "regex-automata",
134
+ "regex-syntax",
135
+ ]
136
+
137
+ [[package]]
138
+ name = "getrandom"
139
+ version = "0.3.3"
140
+ source = "registry+https://github.com/rust-lang/crates.io-index"
141
+ checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4"
142
+ dependencies = [
143
+ "cfg-if",
144
+ "libc",
145
+ "r-efi",
146
+ "wasi",
147
+ ]
148
+
149
+ [[package]]
150
+ name = "hashbrown"
151
+ version = "0.15.5"
152
+ source = "registry+https://github.com/rust-lang/crates.io-index"
153
+ checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1"
154
+
155
+ [[package]]
156
+ name = "heck"
157
+ version = "0.5.0"
158
+ source = "registry+https://github.com/rust-lang/crates.io-index"
159
+ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
160
+
161
+ [[package]]
162
+ name = "indexmap"
163
+ version = "2.11.0"
164
+ source = "registry+https://github.com/rust-lang/crates.io-index"
165
+ checksum = "f2481980430f9f78649238835720ddccc57e52df14ffce1c6f37391d61b563e9"
166
+ dependencies = [
167
+ "equivalent",
168
+ "hashbrown",
169
+ ]
170
+
171
+ [[package]]
172
+ name = "indoc"
173
+ version = "2.0.6"
174
+ source = "registry+https://github.com/rust-lang/crates.io-index"
175
+ checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd"
176
+
177
+ [[package]]
178
+ name = "itoa"
179
+ version = "1.0.15"
180
+ source = "registry+https://github.com/rust-lang/crates.io-index"
181
+ checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c"
182
+
183
+ [[package]]
184
+ name = "libc"
185
+ version = "0.2.175"
186
+ source = "registry+https://github.com/rust-lang/crates.io-index"
187
+ checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543"
188
+
189
+ [[package]]
190
+ name = "log"
191
+ version = "0.4.28"
192
+ source = "registry+https://github.com/rust-lang/crates.io-index"
193
+ checksum = "34080505efa8e45a4b816c349525ebe327ceaa8559756f0356cba97ef3bf7432"
194
+
195
+ [[package]]
196
+ name = "memchr"
197
+ version = "2.7.5"
198
+ source = "registry+https://github.com/rust-lang/crates.io-index"
199
+ checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0"
200
+
201
+ [[package]]
202
+ name = "memoffset"
203
+ version = "0.9.1"
204
+ source = "registry+https://github.com/rust-lang/crates.io-index"
205
+ checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
206
+ dependencies = [
207
+ "autocfg",
208
+ ]
209
+
210
+ [[package]]
211
+ name = "once_cell"
212
+ version = "1.21.3"
213
+ source = "registry+https://github.com/rust-lang/crates.io-index"
214
+ checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
215
+
216
+ [[package]]
217
+ name = "portable-atomic"
218
+ version = "1.11.1"
219
+ source = "registry+https://github.com/rust-lang/crates.io-index"
220
+ checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483"
221
+
222
+ [[package]]
223
+ name = "proc-macro2"
224
+ version = "1.0.101"
225
+ source = "registry+https://github.com/rust-lang/crates.io-index"
226
+ checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de"
227
+ dependencies = [
228
+ "unicode-ident",
229
+ ]
230
+
231
+ [[package]]
232
+ name = "pyo3"
233
+ version = "0.23.5"
234
+ source = "registry+https://github.com/rust-lang/crates.io-index"
235
+ checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872"
236
+ dependencies = [
237
+ "cfg-if",
238
+ "indoc",
239
+ "libc",
240
+ "memoffset",
241
+ "once_cell",
242
+ "portable-atomic",
243
+ "pyo3-build-config",
244
+ "pyo3-ffi",
245
+ "pyo3-macros",
246
+ "unindent",
247
+ ]
248
+
249
+ [[package]]
250
+ name = "pyo3-build-config"
251
+ version = "0.23.5"
252
+ source = "registry+https://github.com/rust-lang/crates.io-index"
253
+ checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb"
254
+ dependencies = [
255
+ "once_cell",
256
+ "target-lexicon",
257
+ ]
258
+
259
+ [[package]]
260
+ name = "pyo3-ffi"
261
+ version = "0.23.5"
262
+ source = "registry+https://github.com/rust-lang/crates.io-index"
263
+ checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d"
264
+ dependencies = [
265
+ "libc",
266
+ "pyo3-build-config",
267
+ ]
268
+
269
+ [[package]]
270
+ name = "pyo3-log"
271
+ version = "0.12.4"
272
+ source = "registry+https://github.com/rust-lang/crates.io-index"
273
+ checksum = "45192e5e4a4d2505587e27806c7b710c231c40c56f3bfc19535d0bb25df52264"
274
+ dependencies = [
275
+ "arc-swap",
276
+ "log",
277
+ "pyo3",
278
+ ]
279
+
280
+ [[package]]
281
+ name = "pyo3-macros"
282
+ version = "0.23.5"
283
+ source = "registry+https://github.com/rust-lang/crates.io-index"
284
+ checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da"
285
+ dependencies = [
286
+ "proc-macro2",
287
+ "pyo3-macros-backend",
288
+ "quote",
289
+ "syn",
290
+ ]
291
+
292
+ [[package]]
293
+ name = "pyo3-macros-backend"
294
+ version = "0.23.5"
295
+ source = "registry+https://github.com/rust-lang/crates.io-index"
296
+ checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028"
297
+ dependencies = [
298
+ "heck",
299
+ "proc-macro2",
300
+ "pyo3-build-config",
301
+ "quote",
302
+ "syn",
303
+ ]
304
+
305
+ [[package]]
306
+ name = "quote"
307
+ version = "1.0.40"
308
+ source = "registry+https://github.com/rust-lang/crates.io-index"
309
+ checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d"
310
+ dependencies = [
311
+ "proc-macro2",
312
+ ]
313
+
314
+ [[package]]
315
+ name = "r-efi"
316
+ version = "5.3.0"
317
+ source = "registry+https://github.com/rust-lang/crates.io-index"
318
+ checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
319
+
320
+ [[package]]
321
+ name = "rayon"
322
+ version = "1.11.0"
323
+ source = "registry+https://github.com/rust-lang/crates.io-index"
324
+ checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f"
325
+ dependencies = [
326
+ "either",
327
+ "rayon-core",
328
+ ]
329
+
330
+ [[package]]
331
+ name = "rayon-core"
332
+ version = "1.13.0"
333
+ source = "registry+https://github.com/rust-lang/crates.io-index"
334
+ checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91"
335
+ dependencies = [
336
+ "crossbeam-deque",
337
+ "crossbeam-utils",
338
+ ]
339
+
340
+ [[package]]
341
+ name = "regex-automata"
342
+ version = "0.4.10"
343
+ source = "registry+https://github.com/rust-lang/crates.io-index"
344
+ checksum = "6b9458fa0bfeeac22b5ca447c63aaf45f28439a709ccd244698632f9aa6394d6"
345
+ dependencies = [
346
+ "aho-corasick",
347
+ "memchr",
348
+ "regex-syntax",
349
+ ]
350
+
351
+ [[package]]
352
+ name = "regex-syntax"
353
+ version = "0.8.6"
354
+ source = "registry+https://github.com/rust-lang/crates.io-index"
355
+ checksum = "caf4aa5b0f434c91fe5c7f1ecb6a5ece2130b02ad2a590589dda5146df959001"
356
+
357
+ [[package]]
358
+ name = "rustbpe"
359
+ version = "0.1.0"
360
+ dependencies = [
361
+ "ahash",
362
+ "compact_str",
363
+ "dary_heap",
364
+ "fancy-regex",
365
+ "indexmap",
366
+ "log",
367
+ "pyo3",
368
+ "pyo3-log",
369
+ "rayon",
370
+ ]
371
+
372
+ [[package]]
373
+ name = "rustversion"
374
+ version = "1.0.22"
375
+ source = "registry+https://github.com/rust-lang/crates.io-index"
376
+ checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
377
+
378
+ [[package]]
379
+ name = "ryu"
380
+ version = "1.0.20"
381
+ source = "registry+https://github.com/rust-lang/crates.io-index"
382
+ checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f"
383
+
384
+ [[package]]
385
+ name = "static_assertions"
386
+ version = "1.1.0"
387
+ source = "registry+https://github.com/rust-lang/crates.io-index"
388
+ checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
389
+
390
+ [[package]]
391
+ name = "syn"
392
+ version = "2.0.106"
393
+ source = "registry+https://github.com/rust-lang/crates.io-index"
394
+ checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6"
395
+ dependencies = [
396
+ "proc-macro2",
397
+ "quote",
398
+ "unicode-ident",
399
+ ]
400
+
401
+ [[package]]
402
+ name = "target-lexicon"
403
+ version = "0.12.16"
404
+ source = "registry+https://github.com/rust-lang/crates.io-index"
405
+ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
406
+
407
+ [[package]]
408
+ name = "unicode-ident"
409
+ version = "1.0.18"
410
+ source = "registry+https://github.com/rust-lang/crates.io-index"
411
+ checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512"
412
+
413
+ [[package]]
414
+ name = "unindent"
415
+ version = "0.2.4"
416
+ source = "registry+https://github.com/rust-lang/crates.io-index"
417
+ checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3"
418
+
419
+ [[package]]
420
+ name = "version_check"
421
+ version = "0.9.5"
422
+ source = "registry+https://github.com/rust-lang/crates.io-index"
423
+ checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
424
+
425
+ [[package]]
426
+ name = "wasi"
427
+ version = "0.14.4+wasi-0.2.4"
428
+ source = "registry+https://github.com/rust-lang/crates.io-index"
429
+ checksum = "88a5f4a424faf49c3c2c344f166f0662341d470ea185e939657aaff130f0ec4a"
430
+ dependencies = [
431
+ "wit-bindgen",
432
+ ]
433
+
434
+ [[package]]
435
+ name = "wit-bindgen"
436
+ version = "0.45.1"
437
+ source = "registry+https://github.com/rust-lang/crates.io-index"
438
+ checksum = "5c573471f125075647d03df72e026074b7203790d41351cd6edc96f46bcccd36"
439
+
440
+ [[package]]
441
+ name = "zerocopy"
442
+ version = "0.8.26"
443
+ source = "registry+https://github.com/rust-lang/crates.io-index"
444
+ checksum = "1039dd0d3c310cf05de012d8a39ff557cb0d23087fd44cad61df08fc31907a2f"
445
+ dependencies = [
446
+ "zerocopy-derive",
447
+ ]
448
+
449
+ [[package]]
450
+ name = "zerocopy-derive"
451
+ version = "0.8.26"
452
+ source = "registry+https://github.com/rust-lang/crates.io-index"
453
+ checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181"
454
+ dependencies = [
455
+ "proc-macro2",
456
+ "quote",
457
+ "syn",
458
+ ]
rustbpe/Cargo.toml ADDED
@@ -0,0 +1,15 @@
1
+ [package]
2
+ name = "rustbpe"
3
+ version = "0.1.0"
4
+ edition = "2024"
5
+
6
+ [dependencies]
7
+ dary_heap = "0.3"
8
+ indexmap = "2.2"
9
+ fancy-regex = "0.16.1"
10
+ log = "0.4.28"
11
+ pyo3 = { version = "0.23.3", features = ["extension-module"] }
12
+ pyo3-log = "0.12.4"
13
+ ahash = "0.8.12"
14
+ rayon = "1.11.0"
15
+ compact_str = "0.9.0"
rustbpe/README.md ADDED
@@ -0,0 +1,5 @@
1
+ # rustbpe
2
+
3
+ > The missing tiktoken training code
4
+
5
+ A very lightweight Rust library for training a GPT tokenizer. The issue is that the inference library [tiktoken](https://github.com/openai/tiktoken) is great, but only does inference. Separately, the huggingface [tokenizers](https://github.com/huggingface/tokenizers) library does training, but it is rather bloated and really hard to navigate because it has to support all the different historical baggage of how people dealt with tokenizers over the years. More recently, I also wrote the [minbpe](https://github.com/karpathy/minbpe) library which does both training and inference, but only in inefficient Python. Basically what I really want is a non-fancy, super simple, but still relatively efficient training code for GPT tokenizer (more efficient than minbpe, much cleaner/simpler than tokenizers), and then export the trained vocab for inference with tiktoken. Does that make sense? So here we are. There are more opportunities for optimization here, I just stopped a bit early because unlike minbpe before it, rustbpe is now simple and fast enough, and not a significant bottleneck for nanochat.
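+ 
+ A rough usage sketch (mirroring how nanochat's `tokenizer.py` drives it; `text_iterator`, `vocab_size` and `SPLIT_PATTERN` are placeholders):
+ 
+ ```python
+ import rustbpe, tiktoken
+ 
+ tok = rustbpe.Tokenizer()
+ tok.train_from_iterator(text_iterator, vocab_size, pattern=SPLIT_PATTERN)
+ mergeable_ranks = {bytes(k): v for k, v in tok.get_mergeable_ranks()}  # token bytes -> rank
+ enc = tiktoken.Encoding(name="rustbpe", pat_str=tok.get_pattern(),
+                         mergeable_ranks=mergeable_ranks, special_tokens={})
+ ids = enc.encode_ordinary("hello world")  # inference happens entirely in tiktoken
+ ```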
rustbpe/src/lib.rs ADDED
@@ -0,0 +1,475 @@
1
+ use std::cmp::Ordering;
2
+ use std::collections::HashMap as StdHashMap;
3
+
4
+ use dary_heap::OctonaryHeap;
5
+ use fancy_regex::Regex;
6
+ use pyo3::prelude::*;
7
+
8
+ use ahash::{AHashMap, AHashSet};
9
+ use compact_str::CompactString;
10
+ use rayon::prelude::*;
11
+
12
+ // Default GPT-4 style regex pattern for splitting text
13
+ const GPT4_PATTERN: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+";
14
+
15
+ type Pair = (u32, u32);
16
+
17
+ /// A Byte Pair Encoding tokenizer that matches the GPT-4 style implementation
18
+ #[pyclass]
19
+ pub struct Tokenizer {
20
+ /// Maps pairs of token IDs to their merged token ID
21
+ pub merges: StdHashMap<Pair, u32>,
22
+ /// The regex pattern used for text splitting
23
+ pub pattern: String,
24
+ /// Compiled regex for efficiency
25
+ compiled_pattern: Regex,
26
+ }
27
+
28
+ // ------------------------ internal helpers ------------------------
29
+
30
+ #[derive(Clone, Debug)]
31
+ struct Word {
32
+ ids: Vec<u32>,
33
+ }
34
+
35
+ impl Word {
36
+ #[inline]
37
+ fn new(ids: Vec<u32>) -> Self {
38
+ Self { ids }
39
+ }
40
+
41
+ #[inline]
42
+ fn pairs<'a>(&'a self) -> impl Iterator<Item = Pair> + 'a {
43
+ self.ids.windows(2).map(|w| (w[0], w[1]))
44
+ }
45
+
46
+ /// Merge all non-overlapping occurrences of pair -> new_id.
47
+ /// Returns a small Vec of local pair-count deltas for THIS word only:
48
+ /// -1 for removed pairs, +1 for newly created pairs.
49
+ ///
50
+ /// NOTE: this version deliberately avoids a HashMap in the hot loop.
51
+ fn merge_pair(&mut self, pair: Pair, new_id: u32) -> Vec<(Pair, i32)> {
52
+ let (a, b) = pair;
53
+ let n = self.ids.len();
54
+ if n < 2 {
55
+ return Vec::new();
56
+ }
57
+
58
+ let mut out: Vec<u32> = Vec::with_capacity(n);
59
+ let mut deltas: Vec<(Pair, i32)> = Vec::with_capacity(6);
60
+
61
+ let mut i = 0;
62
+ while i < n {
63
+ if i + 1 < n && self.ids[i] == a && self.ids[i + 1] == b {
64
+ let left = out.last().copied();
65
+ let right = if i + 2 < n { Some(self.ids[i + 2]) } else { None };
66
+
67
+ // remove old pairs
68
+ if let Some(x) = left {
69
+ deltas.push(((x, a), -1));
70
+ deltas.push(((x, new_id), 1));
71
+ }
72
+ deltas.push(((a, b), -1));
73
+ if let Some(y) = right {
74
+ deltas.push(((b, y), -1));
75
+ deltas.push(((new_id, y), 1));
76
+ }
77
+
78
+ // write merged token
79
+ out.push(new_id);
80
+ i += 2; // skip 'a' and 'b'
81
+ } else {
82
+ out.push(self.ids[i]);
83
+ i += 1;
84
+ }
85
+ }
86
+
87
+ self.ids = out;
88
+ deltas
89
+ }
90
+ }
91
+
92
+ #[derive(Debug, Eq)]
93
+ struct MergeJob {
94
+ pair: Pair,
95
+ count: u64,
96
+ /// set of word indices where this pair may occur and needs processing
97
+ pos: AHashSet<usize>,
98
+ }
99
+
100
+ impl PartialEq for MergeJob {
101
+ fn eq(&self, other: &Self) -> bool {
102
+ self.count == other.count && self.pair == other.pair
103
+ }
104
+ }
105
+
106
+ impl PartialOrd for MergeJob {
107
+ fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
108
+ Some(self.cmp(other))
109
+ }
110
+ }
111
+
112
+ impl Ord for MergeJob {
113
+ fn cmp(&self, other: &Self) -> Ordering {
114
+ // Max-heap by count; tie-break to ascending pair order (deterministic)
115
+ if self.count != other.count {
116
+ self.count.cmp(&other.count)
117
+ } else {
118
+ // ascending order on the pair when counts tie
119
+ other.pair.cmp(&self.pair)
120
+ }
121
+ }
122
+ }
123
+
124
+ #[inline]
125
+ fn count_pairs_parallel(
126
+ words: &[Word],
127
+ counts: &[i32],
128
+ ) -> (AHashMap<Pair, i32>, AHashMap<Pair, AHashSet<usize>>) {
129
+ words
130
+ .par_iter()
131
+ .enumerate()
132
+ .map(|(i, w)| {
133
+ let mut local_pc: AHashMap<Pair, i32> = AHashMap::new();
134
+ let mut local_wtu: AHashMap<Pair, AHashSet<usize>> = AHashMap::new();
135
+ if w.ids.len() >= 2 && counts[i] != 0 {
136
+ for (a, b) in w.pairs() {
137
+ *local_pc.entry((a, b)).or_default() += counts[i];
138
+ local_wtu.entry((a, b)).or_default().insert(i);
139
+ }
140
+ }
141
+ (local_pc, local_wtu)
142
+ })
143
+ .reduce(
144
+ || (AHashMap::new(), AHashMap::new()),
145
+ |(mut acc_pc, mut acc_wtu), (pc, wtu)| {
146
+ for (k, v) in pc {
147
+ *acc_pc.entry(k).or_default() += v;
148
+ }
149
+ for (k, s) in wtu {
150
+ acc_wtu.entry(k).or_default().extend(s);
151
+ }
152
+ (acc_pc, acc_wtu)
153
+ },
154
+ )
155
+ }
156
+
157
+ // ------------------------ END helpers ------------------------
158
+
159
+ impl Tokenizer {
160
+
161
+ /// Core incremental BPE training given unique words and their counts.
162
+ /// `words`: one entry per unique chunk (Vec<u32> of token-ids/bytes).
163
+ /// `counts`: same length as `words`, count per chunk.
164
+ fn train_core_incremental(&mut self, mut words: Vec<Word>, counts: Vec<i32>, vocab_size: u32) {
165
+ assert!(vocab_size >= 256, "vocab_size must be at least 256");
166
+ let num_merges = vocab_size - 256;
167
+ log::info!("Starting BPE training: {} merges to compute", num_merges);
168
+ self.merges.clear();
169
+
170
+ // ---- Initial pair_counts and where_to_update (parallel) ----
171
+ log::info!("Computing initial pair counts from {} unique sequences", words.len());
172
+ let (mut pair_counts, mut where_to_update) = count_pairs_parallel(&words, &counts);
173
+
174
+ // ---- Build heap ----
175
+ log::info!("Building heap with {} unique pairs", pair_counts.len());
176
+ let mut heap = OctonaryHeap::with_capacity(pair_counts.len());
177
+ for (pair, pos) in where_to_update.drain() {
178
+ let c = *pair_counts.get(&pair).unwrap_or(&0);
179
+ if c > 0 {
180
+ heap.push(MergeJob {
181
+ pair,
182
+ count: c as u64,
183
+ pos,
184
+ });
185
+ }
186
+ }
187
+
188
+ // ---- Merge loop ----
189
+ log::info!("Starting merge loop");
190
+ let mut merges_done = 0u32;
191
+ let mut last_log_percent = 0u32;
192
+
193
+ while merges_done < num_merges {
194
+ let Some(mut top) = heap.pop() else { break; };
195
+
196
+ // Lazy refresh
197
+ let current = *pair_counts.get(&top.pair).unwrap_or(&0);
198
+ if top.count != current as u64 {
199
+ top.count = current as u64;
200
+ if top.count > 0 {
201
+ heap.push(top);
202
+ }
203
+ continue;
204
+ }
205
+ if top.count == 0 {
206
+ break;
207
+ }
208
+
209
+ // Record merge
210
+ let new_id = 256 + merges_done;
211
+ self.merges.insert(top.pair, new_id);
212
+
213
+ // Merge this pair in all words where it occurs
214
+ let mut local_pos_updates: AHashMap<Pair, AHashSet<usize>> = AHashMap::new();
215
+ for &word_idx in &top.pos {
216
+ // Apply merge to this word and collect pair-count deltas
217
+ let changes = words[word_idx].merge_pair(top.pair, new_id);
218
+ // Update global pair counts based on this word's count
219
+ for (pair, delta) in changes {
220
+ let delta_total = delta * counts[word_idx];
221
+ if delta_total != 0 {
222
+ *pair_counts.entry(pair).or_default() += delta_total;
223
+ if delta > 0 {
224
+ local_pos_updates.entry(pair).or_default().insert(word_idx);
225
+ }
226
+ }
227
+ }
228
+ }
229
+
230
+ // Add the updated pair counts back to the heap
231
+ for (pair, pos) in local_pos_updates {
232
+ let cnt = *pair_counts.get(&pair).unwrap_or(&0);
233
+ if cnt > 0 {
234
+ heap.push(MergeJob {
235
+ pair,
236
+ count: cnt as u64,
237
+ pos,
238
+ });
239
+ }
240
+ }
241
+
242
+ merges_done += 1;
243
+
244
+ // Log progress every 1%
245
+ let current_percent = (merges_done * 100) / num_merges;
246
+ if current_percent > last_log_percent {
247
+ log::info!(
248
+ "Progress: {}% ({}/{} merges) - Last merge: {:?} -> {} (frequency: {})",
249
+ current_percent, merges_done, num_merges, top.pair, new_id, top.count
250
+ );
251
+ last_log_percent = current_percent;
252
+ }
253
+ }
254
+
255
+ log::info!("Finished training: {} merges completed", merges_done);
256
+ }
257
+ }
258
+
259
+ /// Public methods for the Tokenizer class that will be exposed to Python.
260
+ #[pymethods]
261
+ impl Tokenizer {
262
+ /// Create a new Tokenizer
263
+ #[new]
264
+ pub fn new() -> Self {
265
+ Self {
266
+ merges: StdHashMap::new(),
267
+ pattern: String::new(),
268
+ compiled_pattern: Regex::new("").expect("Empty regex should be valid"),
269
+ }
270
+ }
271
+
272
+ /// Train from a streaming iterator (parallel ingestion).
273
+ /// We refill a Rust Vec<String> buffer under the GIL, then release the GIL
274
+ /// to do the heavy splitting and counting **in parallel** with rayon.
275
+ #[pyo3(signature = (iterator, vocab_size, buffer_size=8192, pattern=None))]
276
+ #[pyo3(text_signature = "(self, iterator, vocab_size, buffer_size=8192, pattern=None)")]
277
+ pub fn train_from_iterator(
278
+ &mut self,
279
+ py: pyo3::Python<'_>,
280
+ iterator: &pyo3::Bound<'_, pyo3::PyAny>,
281
+ vocab_size: u32,
282
+ buffer_size: usize,
283
+ pattern: Option<String>,
284
+ ) -> PyResult<()> {
285
+ // Use provided pattern or default to GPT-4 pattern
286
+ let pattern_str = pattern.unwrap_or_else(|| GPT4_PATTERN.to_string());
287
+
288
+ // Update the stored pattern and compile it
289
+ self.pattern = pattern_str.clone();
290
+ self.compiled_pattern = Regex::new(&pattern_str)
291
+ .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Invalid regex pattern: {}", e)))?;
292
+
293
+ // Prepare a true Python iterator object
294
+ let py_iter: pyo3::Py<pyo3::PyAny> = unsafe {
295
+ pyo3::Py::from_owned_ptr_or_err(py, pyo3::ffi::PyObject_GetIter(iterator.as_ptr()))?
296
+ };
297
+
298
+ // Global chunk counts
299
+ let mut counts: AHashMap<CompactString, i32> = AHashMap::new();
300
+
301
+ // Temporary buffer we refill under the GIL
302
+ let mut buf: Vec<String> = Vec::with_capacity(buffer_size);
303
+
304
+ log::info!("Processing sequences from iterator (buffer_size: {})", buffer_size);
305
+ let mut total_sequences = 0u64;
306
+
307
+ // Helper: refill `buf` with up to `buffer_size` strings from the Python iterator.
308
+ // Returns Ok(true) if the iterator is exhausted, Ok(false) otherwise.
309
+ let refill = |buf: &mut Vec<String>| -> PyResult<bool> {
310
+ pyo3::Python::with_gil(|py| {
311
+ buf.clear();
312
+ let it = py_iter.bind(py);
313
+ loop {
314
+ if buf.len() >= buffer_size {
315
+ return Ok(false);
316
+ }
317
+ // next(it)
318
+ let next_obj = unsafe {
319
+ pyo3::Bound::from_owned_ptr_or_opt(py, pyo3::ffi::PyIter_Next(it.as_ptr()))
320
+ };
321
+ match next_obj {
322
+ Some(obj) => {
323
+ let s: String = obj.extract()?;
324
+ buf.push(s);
325
+ }
326
+ None => {
327
+ if pyo3::PyErr::occurred(py) {
328
+ return Err(pyo3::PyErr::fetch(py));
329
+ } else {
330
+ return Ok(true); // exhausted
331
+ }
332
+ }
333
+ }
334
+ }
335
+ })
336
+ };
337
+
338
+ // Stream ingestion loop: refill under GIL, process without GIL (parallel)
339
+ loop {
340
+ let exhausted = refill(&mut buf)?;
341
+ if buf.is_empty() && exhausted {
342
+ break;
343
+ }
344
+
345
+ total_sequences += buf.len() as u64;
346
+
347
+ let pattern = self.compiled_pattern.clone();
348
+ let local: AHashMap<CompactString, i32> = py.allow_threads(|| {
349
+ buf.par_iter()
350
+ .map(|s| {
351
+ let mut m: AHashMap<CompactString, i32> = AHashMap::new();
352
+ for mat in pattern.find_iter(s) {
353
+ let piece = mat.expect("regex match failed").as_str();
354
+ *m.entry(CompactString::from(piece)).or_default() += 1;
355
+ }
356
+ m
357
+ })
358
+ .reduce(
359
+ || AHashMap::new(),
360
+ |mut a, b| {
361
+ for (k, v) in b {
362
+ *a.entry(k).or_default() += v;
363
+ }
364
+ a
365
+ },
366
+ )
367
+ });
368
+
369
+ // Merge local into global (single-threaded)
370
+ for (k, v) in local {
371
+ *counts.entry(k).or_default() += v;
372
+ }
373
+
374
+ if exhausted {
375
+ break;
376
+ }
377
+ }
378
+ log::info!("Processed {} sequences total, {} unique", total_sequences, counts.len());
379
+
380
+ // Materialize words & counts
381
+ let mut words = Vec::with_capacity(counts.len());
382
+ let mut cvec = Vec::with_capacity(counts.len());
383
+ for (chunk, c) in counts.into_iter() {
384
+ words.push(Word::new(chunk.as_bytes().iter().map(|&b| b as u32).collect()));
385
+ cvec.push(c);
386
+ }
387
+
388
+ self.train_core_incremental(words, cvec, vocab_size);
389
+ Ok(())
390
+ }
391
+
392
+ /// Return the regex pattern
393
+ pub fn get_pattern(&self) -> String {
394
+ self.pattern.clone()
395
+ }
396
+
397
+ /// Return the mergeable ranks (token bytes -> token id / rank)
398
+ pub fn get_mergeable_ranks(&self) -> Vec<(Vec<u8>, u32)> {
399
+ let mut mergeable_ranks = Vec::new();
400
+
401
+ // Build vocabulary incrementally from low to high token IDs
402
+ let mut token_bytes: Vec<Vec<u8>> = (0..256_u32).map(|i| vec![i as u8]).collect();
403
+
404
+ for (i, bytes) in token_bytes.iter().enumerate() {
405
+ mergeable_ranks.push((bytes.clone(), i as u32));
406
+ }
407
+
408
+ // Sort merges by token id (so we can reconstruct bytes progressively)
409
+ let mut sorted_merges: Vec<_> = self.merges.iter().collect();
410
+ sorted_merges.sort_by_key(|&(_, &token_id)| token_id);
411
+
412
+ for (&pair, &merged_id) in sorted_merges {
413
+ let (left, right) = pair;
414
+ let mut merged_bytes = token_bytes[left as usize].clone();
415
+ merged_bytes.extend(&token_bytes[right as usize]);
416
+
417
+ if token_bytes.len() <= merged_id as usize {
418
+ token_bytes.resize(merged_id as usize + 1, Vec::new());
419
+ }
420
+ token_bytes[merged_id as usize] = merged_bytes.clone();
421
+
422
+ mergeable_ranks.push((merged_bytes, merged_id));
423
+ }
424
+
425
+ mergeable_ranks
426
+ }
427
+
428
+ /// Encode a string into token IDs
429
+ pub fn encode(&self, text: &str) -> Vec<u32> {
430
+ let mut all_ids = Vec::new();
431
+
432
+ // Split text using the regex pattern
433
+ for m in self.compiled_pattern.find_iter(text) {
434
+ let chunk = m.expect("regex match failed").as_str();
435
+
436
+ // Convert chunk to bytes then to u32 IDs
437
+ let mut ids: Vec<u32> = chunk.bytes().map(|b| b as u32).collect();
438
+
439
+ // Apply merges iteratively
440
+ while ids.len() >= 2 {
441
+ // Find the best pair to merge
442
+ let mut best_pair: Option<(usize, Pair, u32)> = None;
443
+
444
+ for i in 0..ids.len() - 1 {
445
+ let pair: Pair = (ids[i], ids[i + 1]);
446
+ if let Some(&new_id) = self.merges.get(&pair) {
447
+ if best_pair.is_none() || new_id < best_pair.unwrap().2 {
448
+ best_pair = Some((i, pair, new_id));
449
+ }
450
+ }
451
+ }
452
+
453
+ // If we found a pair to merge, apply it
454
+ if let Some((idx, _pair, new_id)) = best_pair {
455
+ ids[idx] = new_id;
456
+ ids.remove(idx + 1);
457
+ } else {
458
+ // No more merges possible
459
+ break;
460
+ }
461
+ }
462
+
463
+ all_ids.extend(ids);
464
+ }
465
+
466
+ all_ids
467
+ }
468
+ }
469
+
470
+ #[pymodule]
471
+ fn rustbpe(m: &Bound<'_, PyModule>) -> PyResult<()> {
472
+ pyo3_log::init(); // forwards Rust `log` to Python's `logging`
473
+ m.add_class::<Tokenizer>()?;
474
+ Ok(())
475
+ }
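A minimal Python usage sketch for the bindings above, assuming the crate has been built and installed as the `rustbpe` extension module named in the #[pymodule] function (the toy corpus and vocab size are made-up values):

    import rustbpe

    corpus = ["hello world", "hello there world", "a tiny toy corpus for BPE"]
    tok = rustbpe.Tokenizer()
    # pattern=None falls back to the GPT-4 style split regex (GPT4_PATTERN)
    tok.train_from_iterator(iter(corpus), vocab_size=300, buffer_size=8192, pattern=None)

    ids = tok.encode("hello world")     # bytes first, then learned merges applied lowest-rank first
    ranks = tok.get_mergeable_ranks()   # [(token_bytes, token_id), ...] covering bytes + merges
    print(ids, len(ranks), tok.get_pattern()[:30])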
scripts/base_eval.py ADDED
@@ -0,0 +1,212 @@
1
+ """
2
+ Evaluate the CORE metric for a given model.
3
+
4
+ Run on a single GPU:
5
+ python -m scripts.base_eval
6
+
7
+ Run with torchrun on e.g. 8 GPUs:
8
+ torchrun --nproc_per_node=8 -m scripts.base_eval
9
+
10
+ The script will print the CORE metric to the console.
11
+ """
12
+ import os
13
+ import csv
14
+ import time
15
+ import json
16
+ import yaml
17
+ import shutil
18
+ import random
19
+ import zipfile
20
+ import tempfile
21
+ from contextlib import nullcontext
22
+
23
+ import torch
24
+
25
+ from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock
26
+ from nanochat.tokenizer import HuggingFaceTokenizer
27
+ from nanochat.checkpoint_manager import load_model
28
+ from nanochat.core_eval import evaluate_task
29
+
30
+ # -----------------------------------------------------------------------------
31
+ # nanochat specific function dealing with I/O etc.
32
+
33
+ # ~162MB of data needed to evaluate the CORE metric
34
+ EVAL_BUNDLE_URL = "https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip"
35
+
36
+ def place_eval_bundle(file_path):
37
+ # here file_path is the path to the eval_bundle.zip file
38
+ # we need to unzip it and place it in the base directory
39
+ base_dir = get_base_dir()
40
+ eval_bundle_dir = os.path.join(base_dir, "eval_bundle")
41
+ with tempfile.TemporaryDirectory() as tmpdir:
42
+ with zipfile.ZipFile(file_path, 'r') as zip_ref:
43
+ zip_ref.extractall(tmpdir)
44
+ extracted_bundle_dir = os.path.join(tmpdir, "eval_bundle")
45
+ shutil.move(extracted_bundle_dir, eval_bundle_dir)
46
+ print0(f"Placed eval_bundle directory at {eval_bundle_dir}")
47
+
48
+ def evaluate_model(model, tokenizer, device, max_per_task=-1):
49
+ """
50
+ Evaluate a base model on the CORE benchmark.
51
+ - max_per_task: crop the data to this many examples per task for testing (-1 = disable)
52
+ """
53
+ # Load config and task metadata
54
+ base_dir = get_base_dir()
55
+ eval_bundle_dir = os.path.join(base_dir, "eval_bundle")
56
+ # Download the eval bundle to disk (and unzip if needed)
57
+ if not os.path.exists(eval_bundle_dir):
58
+ download_file_with_lock(EVAL_BUNDLE_URL, "eval_bundle.zip", postprocess_fn=place_eval_bundle)
59
+ config_path = os.path.join(eval_bundle_dir, "core.yaml")
60
+ data_base_path = os.path.join(eval_bundle_dir, "eval_data")
61
+ eval_meta_data = os.path.join(eval_bundle_dir, "eval_meta_data.csv")
62
+ with open(config_path, 'r', encoding='utf-8') as f:
63
+ config = yaml.safe_load(f)
64
+ tasks = config['icl_tasks']
65
+
66
+ # Load random baseline values from eval metadata
67
+ random_baselines = {}
68
+ with open(eval_meta_data, 'r', encoding='utf-8') as f:
69
+ reader = csv.DictReader(f)
70
+ for row in reader:
71
+ task_name = row['Eval Task']
72
+ random_baseline = row['Random baseline']
73
+ random_baselines[task_name] = float(random_baseline)
74
+
75
+ # Evaluate each task
76
+ results = {}
77
+ centered_results = {}
78
+ for task in tasks:
79
+ start_time = time.time()
80
+ label = task['label']
81
+ task_meta = {
82
+ 'task_type': task['icl_task_type'],
83
+ 'dataset_uri': task['dataset_uri'],
84
+ 'num_fewshot': task['num_fewshot'][0],
85
+ 'continuation_delimiter': task.get('continuation_delimiter', ' ')
86
+ }
87
+ print0(f"Evaluating: {label} ({task_meta['num_fewshot']}-shot, type: {task_meta['task_type']})... ", end='')
88
+
89
+ # Load data for this task
90
+ data_path = os.path.join(data_base_path, task_meta['dataset_uri'])
91
+ with open(data_path, 'r', encoding='utf-8') as f:
92
+ data = [json.loads(line.strip()) for line in f]
93
+
94
+ # shuffle the data because in many cases it appears ordered but we want
95
+ # the ability to only run a subset of the data for debugging purposes etc.
96
+ shuffle_rng = random.Random(1337)
97
+ shuffle_rng.shuffle(data)
98
+ if max_per_task > 0:
99
+ data = data[:max_per_task]
100
+
101
+ # run the evaluation for this task
102
+ accuracy = evaluate_task(model, tokenizer, data, device, task_meta)
103
+
104
+ results[label] = accuracy
105
+ random_baseline = random_baselines[label]
106
+ centered_result = (accuracy - 0.01 * random_baseline) / (1.0 - 0.01 * random_baseline)
107
+ centered_results[label] = centered_result
108
+ end_time = time.time()
109
+ print0(f"accuracy: {accuracy:.4f} | centered: {centered_result:.4f} | time: {end_time - start_time:.2f}s")
110
+
111
+ core_metric = sum(centered_results.values()) / len(centered_results)
112
+ out = {
113
+ "results": results,
114
+ "centered_results": centered_results,
115
+ "core_metric": core_metric
116
+ }
117
+ return out
118
+
119
+ # -----------------------------------------------------------------------------
120
+ # HuggingFace loading utilities and light wrappers for a model
121
+
122
+ class ModelWrapper:
123
+ """Lightweight wrapper for a HuggingFace model"""
124
+ def __init__(self, model, max_seq_len=None):
125
+ self.model = model
126
+ self.max_seq_len = max_seq_len
127
+
128
+ def __call__(self, input_ids):
129
+ outputs = self.model(input_ids)
130
+ logits = outputs.logits
131
+ return logits
132
+
133
+ def load_hf_model(hf_path: str, device):
134
+ print0(f"Loading model from: {hf_path}")
135
+ # Load the model
136
+ from transformers import AutoModelForCausalLM
137
+ model = AutoModelForCausalLM.from_pretrained(hf_path)
138
+ model.to(device)
139
+ model.eval()
140
+ max_seq_len = 1024 if "openai-community/gpt2" in hf_path else None
141
+ model = ModelWrapper(model, max_seq_len=max_seq_len)
142
+ # Load the tokenizer
143
+ tokenizer = HuggingFaceTokenizer.from_pretrained(hf_path)
144
+ return model, tokenizer
145
+
146
+ # -----------------------------------------------------------------------------
147
+ def main():
148
+ import argparse
149
+ parser = argparse.ArgumentParser()
150
+ parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate')
151
+ parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)')
152
+ args = parser.parse_args()
153
+
154
+ # distributed / precision setup
155
+ device_type = autodetect_device_type()
156
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
157
+ autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
158
+
159
+ # Load model and tokenizer from command line or from file system
160
+ if args.hf_path is not None:
161
+ # atm assume that if a path is given, it's a huggingface model path
162
+ hf_path = args.hf_path
163
+ print0(f"Loading huggingface model from: {hf_path}")
164
+ model, tokenizer = load_hf_model(hf_path, device)
165
+ model_name = hf_path # just for logging
166
+ model_slug = hf_path.replace("/", "-") # for the output csv file
167
+ else:
168
+ # load a local model from the file system
169
+ model, tokenizer, meta = load_model("base", device, phase="eval")
170
+ model_name = f"base_model (step {meta['step']})" # just for logging
171
+ model_slug = f"base_model_{meta['step']:06d}" # for the output csv file
172
+
173
+ # Evaluate the model
174
+ with autocast_ctx:
175
+ out = evaluate_model(model, tokenizer, device, max_per_task=args.max_per_task)
176
+
177
+ # Write out the results to a csv file
178
+ core_metric = None
179
+ centered_results = {}
180
+ if ddp_rank == 0:
181
+ base_dir = get_base_dir()
182
+ output_csv_path = os.path.join(base_dir, "base_eval", f"{model_slug}.csv")
183
+ os.makedirs(os.path.dirname(output_csv_path), exist_ok=True)
184
+ results = out["results"]
185
+ centered_results = out["centered_results"]
186
+ core_metric = out["core_metric"]
187
+ with open(output_csv_path, 'w', encoding='utf-8', newline='') as f:
188
+ f.write(f"{'Task':<35}, {'Accuracy':<10}, {'Centered':<10}\n")
189
+ for label in results:
190
+ f.write(f"{label:<35}, {results[label]:<10.6f}, {centered_results[label]:<10.6f}\n")
191
+ f.write(f"{'CORE':<35}, {'':<10}, {core_metric:<10.6f}\n")
192
+ # Print the content of the csv file to console too
193
+ print0("="*80)
194
+ print0(f"Model: {model_name}")
195
+ print0("="*80)
196
+ with open(output_csv_path, 'r', encoding='utf-8') as f:
197
+ print0(f.read())
198
+
199
+ # Log to report
200
+ from nanochat.report import get_report
201
+ get_report().log(section="Base model evaluation", data=[
202
+ {
203
+ "Model": model_name,
204
+ "CORE metric": core_metric,
205
+ },
206
+ centered_results, # the full table
207
+ ])
208
+
209
+ compute_cleanup()
210
+
211
+ if __name__ == "__main__":
212
+ main()
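For reference, the centering used above rescales each task's accuracy so that 0 corresponds to the random baseline (stored as a percentage in eval_meta_data.csv) and 1 to perfect accuracy; the CORE metric is the mean of these centered scores. A self-contained sketch of the same arithmetic with made-up numbers:

    def center(accuracy, random_baseline_pct):
        rb = 0.01 * random_baseline_pct            # same formula as in evaluate_model above
        return (accuracy - rb) / (1.0 - rb)

    # e.g. 60% on a 4-way multiple-choice task (25% random baseline) recovers ~47% of the headroom
    scores = [center(0.60, 25.0), center(0.10, 0.0)]
    core = sum(scores) / len(scores)               # mean of centered scores
    print(scores, core)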
scripts/base_loss.py ADDED
@@ -0,0 +1,79 @@
1
+ """
2
+ Loads a checkpoint, and:
3
+ - Evaluates the loss on a larger chunk of train/val splits
4
+ - Samples from the model
5
+
6
+ Example run as:
7
+ torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
8
+ """
9
+ import os
10
+ from contextlib import nullcontext
11
+ import torch
12
+ from nanochat.checkpoint_manager import load_model
13
+ from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type
14
+ from nanochat.dataloader import tokenizing_distributed_data_loader
15
+ from nanochat.tokenizer import get_token_bytes
16
+ from nanochat.loss_eval import evaluate_bpb
17
+ from nanochat.engine import Engine
18
+
19
+ # Configuration
20
+ device_batch_size = 32
21
+ split_tokens = 20*524288 # number of tokens to evaluate per split
22
+ model_tag = None # optional model tag for the output directory name
23
+ model_step = None # optional model step for the output directory name
24
+ device_type = "" # cuda|cpu|mps (empty => autodetect)
25
+ exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
26
+
27
+ # Load the base model and the tokenizer
28
+ device_type = autodetect_device_type() if device_type == "" else device_type
29
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
30
+ model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=model_step)
31
+ sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
32
+ autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
33
+
34
+ # Evaluate the loss on each split
35
+ tokens_per_step = device_batch_size * sequence_len * ddp_world_size
36
+ assert split_tokens % tokens_per_step == 0, "split_tokens must be divisible by tokens_per_step"
37
+ steps = split_tokens // tokens_per_step
38
+ token_bytes = get_token_bytes(device=device)
39
+ bpb_results = {}
40
+ for split_name in ["train", "val"]:
41
+ loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name, device=device)
42
+ with autocast_ctx:
43
+ bpb = evaluate_bpb(model, loader, steps, token_bytes)
44
+ print0(f"{split_name} bpb: {bpb:.4f}")
45
+ bpb_results[split_name] = bpb
46
+
47
+ # Master process also samples from the model
48
+ samples = []
49
+ if ddp_rank == 0:
50
+ prompts = [
51
+ "The capital of France is",
52
+ "The chemical symbol of gold is",
53
+ "If yesterday was Friday, then tomorrow will be",
54
+ "The opposite of hot is",
55
+ "The planets of the solar system are:",
56
+ "My favorite color is",
57
+ "If 5*x + 3 = 13, then x is",
58
+ ]
59
+ engine = Engine(model, tokenizer)
60
+ for prompt in prompts:
61
+ tokens = tokenizer(prompt, prepend="<|bos|>")
62
+ with autocast_ctx:
63
+ sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
64
+ sample_str = tokenizer.decode(sample[0])
65
+ print0(sample_str)
66
+ samples.append(sample_str)
67
+
68
+ # Log to report
69
+ from nanochat.report import get_report
70
+ get_report().log(section="Base model loss", data=[
71
+ {
72
+ "train bpb": bpb_results["train"],
73
+ "val bpb": bpb_results["val"],
74
+ },
75
+ {f"sample {i}": sample for i, sample in enumerate(samples)},
76
+ ])
77
+
78
+ # Cleanup
79
+ compute_cleanup()
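evaluate_bpb (imported from nanochat.loss_eval, which is not part of this diff) reports bits per byte rather than raw cross-entropy, which makes the number comparable across tokenizers with different vocabularies. A rough sketch of the conversion, under the assumption that the per-token loss in nats is normalized by the byte length of the target tokens (the token_bytes tensor above):

    import math

    def bits_per_byte(total_loss_nats, total_bytes):
        # negative log-likelihood in nats, per byte, converted from nats to bits
        return total_loss_nats / (total_bytes * math.log(2))

    # e.g. an average loss of 1.0 nat/token with tokens averaging ~4.5 bytes each:
    print(bits_per_byte(1.0 * 1000, 4.5 * 1000))   # ~0.32 bpb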
scripts/base_train.py ADDED
@@ -0,0 +1,355 @@
1
+ """
2
+ Train model. Run as:
3
+
4
+ python base_train.py
5
+
6
+ or distributed as:
7
+
8
+ torchrun --nproc_per_node=8 base_train.py
9
+
10
+ If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
11
+ python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_every=-1 --total_batch_size=512 --num_iterations=20
12
+ """
13
+
14
+ import os
15
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
16
+ import time
17
+ from contextlib import nullcontext
18
+
19
+ import wandb
20
+ import torch
21
+
22
+ from nanochat.gpt import GPT, GPTConfig
23
+ from nanochat.dataloader import tokenizing_distributed_data_loader
24
+ from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
25
+ from nanochat.tokenizer import get_tokenizer, get_token_bytes
26
+ from nanochat.checkpoint_manager import save_checkpoint
27
+ from nanochat.loss_eval import evaluate_bpb
28
+ from nanochat.engine import Engine
29
+ from scripts.base_eval import evaluate_model
30
+ print_banner()
31
+
32
+ # -----------------------------------------------------------------------------
33
+ # User settings
34
+ run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
35
+ # Runtime
36
+ device_type = "" # cuda|cpu|mps (empty => autodetect the best available device type, in order: CUDA > MPS > CPU)
37
+ # Model architecture
38
+ depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
39
+ max_seq_len = 2048 # max context length
40
+ # Training horizon. Only one of these 3 will be used, in this order of precedence.
41
+ num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
42
+ target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
43
+ target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
44
+ # Optimization
45
+ device_batch_size = 32 # per-device batch size (set to not OOM)
46
+ total_batch_size = 524288 # total desired batch size, in #tokens
47
+ embedding_lr = 0.2 # learning rate for the embedding parameters (Adam)
48
+ unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
49
+ weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam)
50
+ matrix_lr = 0.02 # learning rate for the matrix parameters (Muon)
51
+ grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
52
+ warmup_ratio = 0.0 # ratio of iterations for LR warmup
53
+ warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
54
+ final_lr_frac = 0.0 # final LR is this fraction of the initial LR
55
+ # Evaluation
56
+ eval_every = 250 # every how many steps to evaluate the model for val bpb
57
+ eval_tokens = 20*524288 # number of tokens to evaluate val loss on
58
+ core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable)
59
+ core_metric_max_per_task = 500 # examples per task in estimating the core metric
60
+ sample_every = 2000 # every how many steps to sample from the model
61
+ # Output
62
+ model_tag = "" # optionally override the model tag for the output checkpoint directory name
63
+ # now allow CLI to override the settings via the configurator lol
64
+ config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
65
+ exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
66
+ user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
67
+ # -----------------------------------------------------------------------------
68
+
69
+ # Compute init
70
+ device_type = autodetect_device_type() if device_type == "" else device_type
71
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
72
+ master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
73
+ autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
74
+ synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
75
+ get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
76
+
77
+ # wandb logging init
78
+ use_dummy_wandb = run == "dummy" or not master_process
79
+ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=run, config=user_config)
80
+
81
+ # Tokenizer will be useful for evaluation, also we need the vocab size
82
+ tokenizer = get_tokenizer()
83
+ token_bytes = get_token_bytes(device=device)
84
+ vocab_size = tokenizer.get_vocab_size()
85
+ print0(f"Vocab size: {vocab_size:,}")
86
+
87
+ # Model kwargs are derived from the desired depth of the model
88
+ num_layers = depth
89
+ model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
90
+ num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div)
91
+ num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
92
+ print0(f"num_layers: {num_layers}")
93
+ print0(f"model_dim: {model_dim}")
94
+ print0(f"num_heads: {num_heads}")
95
+ print0(f"num_kv_heads: {num_kv_heads}")
96
+
97
+ # Optimizer / data / training length related hyperparameters
98
+ # figure out the needed gradient accumulation to reach the desired total batch size
99
+ tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
100
+ world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
101
+ assert total_batch_size % world_tokens_per_fwdbwd == 0
102
+ grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
103
+ print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
104
+ print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
105
+ print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
106
+ # -----------------------------------------------------------------------------
107
+ # Initialize the Model
108
+ model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
109
+ with torch.device("meta"):
110
+ model_config = GPTConfig(**model_config_kwargs)
111
+ model = GPT(model_config)
112
+ model.to_empty(device=device)
113
+ model.init_weights()
114
+ orig_model = model # original, uncompiled model, for saving raw model state_dict
115
+ model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
116
+ num_params = sum(p.numel() for p in model.parameters())
117
+ print0(f"Number of parameters: {num_params:,}")
118
+ num_flops_per_token = model.estimate_flops()
119
+ print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")
120
+
121
+ # Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order)
122
+ assert num_iterations > 0 or target_param_data_ratio > 0 or target_flops > 0
123
+ if num_iterations > 0:
124
+ print0(f"Using user-provided number of iterations: {num_iterations:,}")
125
+ elif target_flops > 0:
126
+ # calculate the number of iterations from the target flops
127
+ num_iterations = round(target_flops / (num_flops_per_token * total_batch_size))
128
+ print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
129
+ elif target_param_data_ratio > 0:
130
+ # calculate the number of iterations from the target param data ratio
131
+ target_tokens = target_param_data_ratio * num_params
132
+ num_iterations = target_tokens // total_batch_size
133
+ print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
134
+ else:
135
+ raise ValueError("No training horizon specified")
136
+ total_tokens = total_batch_size * num_iterations
137
+ print0(f"Total number of training tokens: {total_tokens:,}")
138
+ print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20
139
+ print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
140
+
141
+ # -----------------------------------------------------------------------------
142
+ # Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
143
+ optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
144
+ adamw_optimizer, muon_optimizer = optimizers
145
+
146
+ # Initialize the DataLoaders for train/val
147
+ base_dir = get_base_dir()
148
+ tokens_dir = os.path.join(base_dir, "tokenized_data")
149
+ train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device)
150
+ build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
151
+ x, y = next(train_loader) # kick off load of the very first batch of data
152
+
153
+ # -----------------------------------------------------------------------------
154
+ # Set up hyperparameter schedulers
155
+
156
+ # Learning rate scheduler
157
+ def get_lr_multiplier(it):
158
+ warmup_iters = round(warmup_ratio * num_iterations)
159
+ warmdown_iters = round(warmdown_ratio * num_iterations)
160
+ if it < warmup_iters:
161
+ return (it + 1) / warmup_iters
162
+ elif it <= num_iterations - warmdown_iters:
163
+ return 1.0
164
+ else:
165
+ progress = (num_iterations - it) / warmdown_iters
166
+ return progress * 1.0 + (1 - progress) * final_lr_frac
167
+
168
+ # Momentum scheduler for Muon optimizer
169
+ def get_muon_momentum(it):
170
+ frac = min(it / 300, 1)
171
+ momentum = (1 - frac) * 0.85 + frac * 0.95
172
+ return momentum
173
+
174
+ # -----------------------------------------------------------------------------
175
+ # Training loop
176
+ min_val_bpb = float("inf")
177
+ smooth_train_loss = 0 # EMA of training loss
178
+ ema_beta = 0.9 # EMA decay factor
179
+ total_training_time = 0 # total wall-clock time of training
180
+ # note that we run +1 steps only so that we can eval and save at the end
181
+ for step in range(num_iterations + 1):
182
+ last_step = step == num_iterations
183
+ flops_so_far = num_flops_per_token * total_batch_size * step
184
+
185
+ # once in a while: evaluate the val bpb (all ranks participate)
186
+ if last_step or step % eval_every == 0:
187
+ model.eval()
188
+ val_loader = build_val_loader()
189
+ eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
190
+ with autocast_ctx:
191
+ val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
192
+ print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
193
+ if val_bpb < min_val_bpb:
194
+ min_val_bpb = val_bpb
195
+ wandb_run.log({
196
+ "step": step,
197
+ "total_training_flops": flops_so_far,
198
+ "total_training_time": total_training_time,
199
+ "val/bpb": val_bpb,
200
+ })
201
+ model.train()
202
+
203
+ # once in a while: estimate the CORE metric (all ranks participate)
204
+ # use the original uncompiled model because the inputs keep changing shape
205
+ results = {}
206
+ if core_metric_every > 0 and (last_step or (step > 0 and step % core_metric_every == 0)):
207
+ model.eval()
208
+ with autocast_ctx:
209
+ results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
210
+ print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
211
+ wandb_run.log({
212
+ "step": step,
213
+ "total_training_flops": flops_so_far,
214
+ "core_metric": results["core_metric"],
215
+ "centered_results": results["centered_results"],
216
+ })
217
+ model.train()
218
+
219
+ # once in a while: sample from the model (only on master process)
220
+ # use the original uncompiled model because the inputs keep changing shape
221
+ if master_process and (last_step or (step > 0 and step % sample_every == 0)):
222
+ model.eval()
223
+ prompts = [
224
+ "The capital of France is",
225
+ "The chemical symbol of gold is",
226
+ "If yesterday was Friday, then tomorrow will be",
227
+ "The opposite of hot is",
228
+ "The planets of the solar system are:",
229
+ "My favorite color is",
230
+ "If 5*x + 3 = 13, then x is",
231
+ ]
232
+ engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation
233
+ for prompt in prompts:
234
+ tokens = tokenizer(prompt, prepend="<|bos|>")
235
+ with autocast_ctx:
236
+ sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
237
+ print0(tokenizer.decode(sample[0]))
238
+ model.train()
239
+
240
+ # save checkpoint at the end of the run (only on master process)
241
+ if master_process and last_step:
242
+ output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
243
+ checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
244
+ save_checkpoint(
245
+ checkpoint_dir,
246
+ step,
247
+ orig_model.state_dict(),
248
+ [opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
249
+ {
250
+ "step": step,
251
+ "val_bpb": val_bpb, # loss at last step
252
+ "model_config": model_config_kwargs,
253
+ "user_config": user_config, # inputs to the training script
254
+ "device_batch_size": device_batch_size,
255
+ "max_seq_len": max_seq_len,
256
+ }
257
+ )
258
+
259
+ if last_step:
260
+ break
261
+
262
+ # -------------------------------------------------------------------------
263
+ # single training step
264
+ # evaluate the gradient
265
+ synchronize()
266
+ t0 = time.time()
267
+ for micro_step in range(grad_accum_steps):
268
+ with autocast_ctx:
269
+ loss = model(x, y)
270
+ train_loss = loss.detach() # for logging
271
+ loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
272
+ loss.backward()
273
+ x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
274
+ # gradient clipping
275
+ grad_clip_enabled = grad_clip > 0.0
276
+ if grad_clip_enabled:
277
+ grad_norm_tensor = torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
278
+ grad_norm = grad_norm_tensor.item() # GPU tensor -> CPU float (note: cpu-gpu sync point)
279
+ # step the optimizers
280
+ lrm = get_lr_multiplier(step)
281
+ for opt in optimizers:
282
+ for group in opt.param_groups:
283
+ group["lr"] = group["initial_lr"] * lrm
284
+ muon_momentum = get_muon_momentum(step)
285
+ for group in muon_optimizer.param_groups:
286
+ group["momentum"] = muon_momentum
287
+ for opt in optimizers:
288
+ opt.step()
289
+ model.zero_grad(set_to_none=True)
290
+ synchronize()
291
+ t1 = time.time()
292
+ dt = t1 - t0
293
+ # -------------------------------------------------------------------------
294
+
295
+ # logging
296
+ smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
297
+ debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
298
+ pct_done = 100 * step / num_iterations
299
+ tok_per_sec = int(total_batch_size / dt)
300
+ flops_per_sec = num_flops_per_token * total_batch_size / dt
301
+ promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
302
+ mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
303
+ if step > 10:
304
+ total_training_time += dt # only count the time after the first 10 steps
305
+ print_grad_norm = f" grad norm: {grad_norm:.4f} |" if grad_clip_enabled else ""
306
+ print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} |{print_grad_norm} lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
307
+ if step % 100 == 0:
308
+ log_data = {
309
+ "step": step,
310
+ "total_training_flops": flops_so_far,
311
+ "total_training_time": total_training_time,
312
+ "train/loss": debiased_smooth_loss,
313
+ "train/lrm": lrm,
314
+ "train/dt": dt,
315
+ "train/tok_per_sec": tok_per_sec,
316
+ "train/mfu": mfu,
317
+ }
318
+ if grad_clip_enabled:
319
+ log_data["train/grad_norm"] = grad_norm
320
+ wandb_run.log(log_data)
321
+
322
+ # print a few more stats
323
+ print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
324
+ print0(f"Total training time: {total_training_time/60:.2f}m")
325
+ print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
326
+
327
+ # Log to report
328
+ from nanochat.report import get_report
329
+ get_report().log(section="Base model training", data=[
330
+ user_config, # CLI args
331
+ { # stats about the training setup
332
+ "Number of parameters": num_params,
333
+ "Number of FLOPs per token": f"{num_flops_per_token:e}",
334
+ "Calculated number of iterations": num_iterations,
335
+ "Number of training tokens": total_tokens,
336
+ "Tokens : Params ratio": total_batch_size * num_iterations / num_params,
337
+ "DDP world size": ddp_world_size,
338
+ "warmup_ratio": warmup_ratio,
339
+ "warmdown_ratio": warmdown_ratio,
340
+ "final_lr_frac": final_lr_frac,
341
+ },
342
+ { # stats about training outcomes
343
+ "Minimum validation bpb": min_val_bpb,
344
+ "Final validation bpb": val_bpb,
345
+ "CORE metric estimate": results.get("core_metric", None),
346
+ "MFU %": f"{mfu:.2f}%",
347
+ "Total training flops": f"{flops_so_far:e}",
348
+ "Total training time": f"{total_training_time/60:.2f}m",
349
+ "Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
350
+ }
351
+ ])
352
+
353
+ # cleanup
354
+ wandb_run.finish() # wandb run finish
355
+ compute_cleanup()
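To make the derived hyperparameters above concrete, a small sketch reproducing the arithmetic for the defaults (depth=20, max_seq_len=2048, device_batch_size=32, total_batch_size=524288) on an 8-GPU run:

    depth, max_seq_len, device_batch_size, total_batch_size, world_size = 20, 2048, 32, 524288, 8

    model_dim = depth * 64                                          # 1280
    num_heads = max(1, (model_dim + 127) // 128)                    # 10 heads of dim 128
    tokens_per_fwdbwd = device_batch_size * max_seq_len             # 65,536 tokens per rank per micro-step
    world_tokens_per_fwdbwd = tokens_per_fwdbwd * world_size        # 524,288 tokens across all ranks
    grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd  # 1, i.e. no gradient accumulation needed
    print(model_dim, num_heads, grad_accum_steps)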
scripts/chat_cli.py ADDED
@@ -0,0 +1,105 @@
1
+ """
2
+ New and upgraded chat mode because a lot of the code has changed since the last one.
3
+
4
+ Intended to be run single GPU only atm:
5
+ python -m scripts.chat_cli -i mid
6
+ """
7
+ import argparse
8
+ import torch
9
+ from nanochat.common import compute_init, autodetect_device_type
10
+ from contextlib import nullcontext
11
+ from nanochat.engine import Engine
12
+ from nanochat.checkpoint_manager import load_model
13
+
14
+ parser = argparse.ArgumentParser(description='Chat with the model')
15
+ parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
16
+ parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
17
+ parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
18
+ parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back')
19
+ parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation')
20
+ parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
21
+ parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
22
+ parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
23
+ args = parser.parse_args()
24
+
25
+ # Init the model and tokenizer
26
+
27
+ device_type = autodetect_device_type() if args.device_type == "" else args.device_type
28
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
29
+ ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
30
+ autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
31
+ model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
32
+
33
+ # Special tokens for the chat state machine
34
+ bos = tokenizer.get_bos_token_id()
35
+ user_start, user_end = tokenizer.encode_special("<|user_start|>"), tokenizer.encode_special("<|user_end|>")
36
+ assistant_start, assistant_end = tokenizer.encode_special("<|assistant_start|>"), tokenizer.encode_special("<|assistant_end|>")
37
+
38
+ # Create Engine for efficient generation
39
+ engine = Engine(model, tokenizer)
40
+
41
+ print("\nNanoChat Interactive Mode")
42
+ print("-" * 50)
43
+ print("Type 'quit' or 'exit' to end the conversation")
44
+ print("Type 'clear' to start a new conversation")
45
+ print("-" * 50)
46
+
47
+ conversation_tokens = [bos]
48
+
49
+ while True:
50
+
51
+ if args.prompt:
52
+ # Get the prompt from the launch command
53
+ user_input = args.prompt
54
+ else:
55
+ # Get the prompt interactively from the console
56
+ try:
57
+ user_input = input("\nUser: ").strip()
58
+ except (EOFError, KeyboardInterrupt):
59
+ print("\nGoodbye!")
60
+ break
61
+
62
+ # Handle special commands
63
+ if user_input.lower() in ['quit', 'exit']:
64
+ print("Goodbye!")
65
+ break
66
+
67
+ if user_input.lower() == 'clear':
68
+ conversation_tokens = [bos]
69
+ print("Conversation cleared.")
70
+ continue
71
+
72
+ if not user_input:
73
+ continue
74
+
75
+ # Add User message to the conversation
76
+ conversation_tokens.append(user_start)
77
+ conversation_tokens.extend(tokenizer.encode(user_input))
78
+ conversation_tokens.append(user_end)
79
+
80
+ # Kick off the assistant
81
+ conversation_tokens.append(assistant_start)
82
+ generate_kwargs = {
83
+ "num_samples": 1,
84
+ "max_tokens": 256,
85
+ "temperature": args.temperature,
86
+ "top_k": args.top_k,
87
+ }
88
+ response_tokens = []
89
+ print("\nAssistant: ", end="", flush=True)
90
+ with autocast_ctx:
91
+ for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs):
92
+ token = token_column[0] # pop the batch dimension (num_samples=1)
93
+ response_tokens.append(token)
94
+ token_text = tokenizer.decode([token])
95
+ print(token_text, end="", flush=True)
96
+ print()
97
+ # we have to ensure that the assistant end token is the last token
98
+ # so even if generation ends due to max tokens, we have to append it to the end
99
+ if response_tokens[-1] != assistant_end:
100
+ response_tokens.append(assistant_end)
101
+ conversation_tokens.extend(response_tokens)
102
+
103
+ # In the prompt mode, we only want a single response and exit
104
+ if args.prompt:
105
+ break
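The conversation buffer built above is a flat list of token ids, with special tokens delimiting each turn. After one full exchange the layout is:

    <|bos|> <|user_start|> ...user tokens... <|user_end|> <|assistant_start|> ...assistant tokens... <|assistant_end|>

and the next user turn is simply appended in the same pattern.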
scripts/chat_eval.py ADDED
@@ -0,0 +1,257 @@
1
+ """
2
+ Evaluate the Chat model.
3
+ All the generic code lives here, and all the evaluation-specific
4
+ code lives in the nanochat directory and is imported from here.
5
+
6
+ Example runs:
7
+ python -m scripts.chat_eval -a ARC-Easy
8
+ torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy
9
+ """
10
+
11
+ import argparse
12
+ from functools import partial
13
+ from contextlib import nullcontext
14
+
15
+ import torch
16
+ import torch.distributed as dist
17
+
18
+ from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0, autodetect_device_type
19
+ from nanochat.checkpoint_manager import load_model
20
+ from nanochat.engine import Engine
21
+
22
+ from tasks.humaneval import HumanEval
23
+ from tasks.mmlu import MMLU
24
+ from tasks.arc import ARC
25
+ from tasks.gsm8k import GSM8K
26
+ from tasks.spellingbee import SpellingBee
27
+
28
+ # -----------------------------------------------------------------------------
29
+ # Generative evaluation loop (we go one problem at a time, sample, evaluate)
30
+
31
+ def run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_new_tokens, temperature, top_k, max_problems=None):
32
+
33
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
34
+ device = model.get_device()
35
+
36
+ num_problems = len(task_object) if max_problems is None else min(len(task_object), max_problems)
37
+
38
+ # Run the evaluation
39
+ num_passed, total = 0, 0
40
+ for i in range(ddp_rank, num_problems, ddp_world_size):
41
+ conversation = task_object[i]
42
+
43
+ # Tokenize the prompt
44
+ encoded_prompt = tokenizer.render_for_completion(conversation)
45
+ # Get the completions
46
+ results, _ = engine.generate_batch(
47
+ encoded_prompt,
48
+ num_samples=num_samples,
49
+ max_tokens=max_new_tokens,
50
+ temperature=temperature,
51
+ top_k=top_k,
52
+ )
53
+ # Decode the completions as text
54
+ prefix_length = len(encoded_prompt)
55
+ completions = [tokenizer.decode(result_tokens[prefix_length:]) for result_tokens in results]
56
+ # Evaluate success criteria
57
+ outcomes = [task_object.evaluate(conversation, completion) for completion in completions]
58
+ passed = any(outcomes)
59
+
60
+ # Keep stats
61
+ total += 1
62
+ num_passed += int(passed)
63
+
64
+ # Logging (overwrite the same line in the console)
65
+ print(f"\r\033[KRank {ddp_rank} | {num_passed}/{total} ({100*num_passed/total:.2f}%)", end='', flush=True)
66
+
67
+ # Finish the in-place progress line with a newline before final summary
68
+ print()
69
+
70
+ # Aggregate results across all ranks
71
+ if ddp:
72
+ num_passed_tensor = torch.tensor([num_passed], dtype=torch.long, device=device)
73
+ total_tensor = torch.tensor([total], dtype=torch.long, device=device)
74
+ dist.all_reduce(num_passed_tensor, op=dist.ReduceOp.SUM)
75
+ dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM)
76
+ num_passed = num_passed_tensor.item()
77
+ total = total_tensor.item()
78
+
79
+ print0("=" * 50)
80
+ print0(f"Final: {num_passed}/{total} ({100*num_passed/total:.2f}%)")
81
+
82
+ # Return the accuracy
83
+ return num_passed/total
84
+
85
+ # -----------------------------------------------------------------------------
86
+ # Categorical evaluation loop
87
+ # A lot easier because we don't have to sample. Therefore, we can actually go
88
+ # batches at a time and just check the logits for correct answer choices.
89
+
90
+ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=None):
91
+
92
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
93
+ device = model.get_device()
94
+ bos = tokenizer.get_bos_token_id() # use BOS as pad token is ok, these positions are ignored
95
+
96
+ # We'll process batches of independent problems at a time because there is no sampling needed
97
+ num_problems = len(task_object) if max_problems is None else min(len(task_object), max_problems)
98
+ ceil_div = lambda x, y: -(-x // y)
99
+ num_batches = ceil_div(num_problems, batch_size)
100
+
101
+ # Run the evaluation
102
+ letter_to_id_cache = {} # many letters will repeat often, let's save the tokenizer some work
103
+ num_passed, total = 0, 0
104
+ for i in range(ddp_rank, num_batches, ddp_world_size):
105
+ i0, i1 = i * batch_size, min((i + 1) * batch_size, num_problems)
106
+
107
+ # Prepare the batch of problems. They might all be of different length, so we pad/collate them.
108
+ conversations = [task_object[ii] for ii in range(i0, i1)]
109
+ prompt_ids = [tokenizer.render_for_completion(conversation) for conversation in conversations] # TODO: remake the way this works
110
+ max_length = max(len(ids) for ids in prompt_ids)
111
+ answer_time_positions = [len(ids) - 1 for ids in prompt_ids] # where the last token is (and the predicted answer)
112
+ padded_prompt_ids = [ids + [bos] * (max_length - len(ids)) for ids in prompt_ids]
113
+ prompt_ids = torch.tensor(padded_prompt_ids, dtype=torch.long, device=device)
114
+
115
+ # Get the logits for the whole batch of conversations in parallel (efficiency win here)
116
+ with torch.no_grad():
117
+ logits = model(prompt_ids) # (B, T, V)
118
+
119
+ # Focus on the answer position and restrict to just the letters corresponding to the available choices
120
+ # Note that this helps the evaluation a lot because it specifically narrows the focus to only the available letters
121
+ # The much harder alternative would be to just generate from the Assistant and check if it responded with the correct
122
+ # letter (e.g. A, B, C, D), but evaluations typically make the task easier in this way.
123
+ for idx, conversation in enumerate(conversations):
124
+ # get the token ids of all the available letters of this problem
125
+ letters = conversation['letters']
126
+ letter_ids = []
127
+ for letter in letters:
128
+ if not letter in letter_to_id_cache:
129
+ encoded_letter = tokenizer.encode(letter)
130
+ assert len(encoded_letter) == 1, "Each letter must be a single token"
131
+ letter_to_id_cache[letter] = encoded_letter[0]
132
+ letter_ids.append(letter_to_id_cache[letter])
133
+ # focus logits just down to the answer position and the available letters of the answer
134
+ answer_pos = answer_time_positions[idx]
135
+ focus_logits = logits[idx, answer_pos, letter_ids]
136
+ # get the argmax letter (the predicted answer)
137
+ argmax_letter_id = focus_logits.argmax(dim=-1).item()
138
+ predicted_letter = letters[argmax_letter_id]
139
+ # evaluate the outcome
140
+ outcome = task_object.evaluate(conversation, predicted_letter)
141
+ num_passed += int(outcome)
142
+ total += 1
143
+
144
+ # Aggregate results across all ranks
145
+ if ddp:
146
+ num_passed_tensor = torch.tensor([num_passed], dtype=torch.long, device=device)
147
+ total_tensor = torch.tensor([total], dtype=torch.long, device=device)
148
+ dist.all_reduce(num_passed_tensor, op=dist.ReduceOp.SUM)
149
+ dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM)
150
+ num_passed = num_passed_tensor.item()
151
+ total = total_tensor.item()
152
+
153
+ average = num_passed/total
154
+ print0(f"Final: {num_passed}/{total} ({100*average:.2f}%)")
155
+ return average
156
+
157
+ # -----------------------------------------------------------------------------
158
+
159
+ def run_chat_eval(task_name, model, tokenizer, engine,
160
+ batch_size=1, num_samples=1, max_new_tokens=512, temperature=0.0, top_k=50,
161
+ max_problems=None):
162
+ # Create the evaluation object
163
+ task_module = {
164
+ 'HumanEval': HumanEval,
165
+ 'MMLU': partial(MMLU, subset="all", split="test"),
166
+ 'ARC-Easy': partial(ARC, subset="ARC-Easy", split="test"),
167
+ 'ARC-Challenge': partial(ARC, subset="ARC-Challenge", split="test"),
168
+ 'GSM8K': partial(GSM8K, subset="main", split="test"),
169
+ 'SpellingBee': partial(SpellingBee, size=256, split="test"),
170
+ }[task_name]
171
+ task_object = task_module()
172
+ # Run the evaluation
173
+ if task_object.eval_type == 'generative':
174
+ acc = run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_new_tokens, temperature, top_k, max_problems=max_problems)
175
+ elif task_object.eval_type == 'categorical':
176
+ acc = run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=max_problems)
177
+ else:
178
+ raise ValueError(f"Unsupported task evaluation type: {task_object.eval_type}")
179
+ return acc
180
+
181
+ # -----------------------------------------------------------------------------
182
+ if __name__ == "__main__":
183
+
184
+ # Parse command-line arguments
185
+ parser = argparse.ArgumentParser()
186
+ parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|mid|rl")
187
+ parser.add_argument('-a', '--task-name', type=str, default=None, help="Task name. Default = all tasks. Use | to split multiple tasks.")
188
+ parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
189
+ parser.add_argument('-t', '--temperature', type=float, default=0.0)
190
+ parser.add_argument('-m', '--max-new-tokens', type=int, default=512)
191
+ parser.add_argument('-n', '--num-samples', type=int, default=1)
192
+ parser.add_argument('-k', '--top-k', type=int, default=50)
193
+ parser.add_argument('-b', '--batch-size', type=int, default=8, help='Batch size for categorical evaluation')
194
+ parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
195
+ parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
196
+ parser.add_argument('-x', '--max-problems', type=int, default=None, help='Max problems to evaluate')
197
+ parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
198
+ args = parser.parse_args()
199
+
200
+ device_type = autodetect_device_type() if args.device_type == "" else args.device_type
201
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
202
+ ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
203
+ autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
204
+
205
+ model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
206
+ engine = Engine(model, tokenizer)
207
+
208
+ # Get the tasks to evaluate on
209
+ all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee']
210
+ baseline_accuracies = {
211
+ 'ARC-Easy': 0.25, # multiple choice 1 of 4 => 25%
212
+ 'ARC-Challenge': 0.25, # multiple choice 1 of 4 => 25%
213
+ 'MMLU': 0.25, # multiple choice 1 of 4 => 25%
214
+ 'GSM8K': 0.0, # open-ended => 0%
215
+ 'HumanEval': 0.0, # open-ended => 0%
216
+ 'SpellingBee': 0.0, # open-ended => 0%
217
+ }
218
+ task_names = all_tasks if args.task_name is None else args.task_name.split('|')
219
+
220
+ # Run all the task evaluations sequentially
221
+ results = {}
222
+ for task_name in task_names:
223
+ with autocast_ctx:
224
+ acc = run_chat_eval(
225
+ task_name,
226
+ model, tokenizer, engine,
227
+ batch_size=args.batch_size,
228
+ num_samples=args.num_samples,
229
+ max_new_tokens=args.max_new_tokens,
230
+ temperature=args.temperature,
231
+ top_k=args.top_k,
232
+ max_problems=args.max_problems,
233
+ )
234
+ results[task_name] = acc
235
+ print0(f"{task_name} accuracy: {100 * acc:.2f}%")
236
+
237
+ # Log to report
238
+ from nanochat.report import get_report
239
+ all_tasks_were_evaluated = all(task_name in results for task_name in all_tasks)
240
+ # calculate the ChatCORE metric if we can (similar to CORE, it's the mean centered accuracy)
241
+ # this way, ChatCORE ranges from 0 (at random baseline) to 1 (peak performance)
242
+ chatcore_metric_dict = {}
243
+ if all_tasks_were_evaluated:
244
+ centered_mean = 0
245
+ for task_name, acc in results.items():
246
+ baseline_acc = baseline_accuracies.get(task_name, 0.0)
247
+ centered_acc = (acc - baseline_acc) / (1.0 - baseline_acc)
248
+ centered_mean += centered_acc
249
+ chatcore_metric = centered_mean / len(results)
250
+ chatcore_metric_dict = {"ChatCORE metric": chatcore_metric}
251
+ get_report().log(section="Chat evaluation " + args.source, data=[
252
+ vars(args), # CLI args
253
+ results,
254
+ chatcore_metric_dict,
255
+ ])
256
+
257
+ compute_cleanup()
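For reference, the ChatCORE reduction above is just a mean of baseline-centered accuracies. A tiny standalone sketch with made-up numbers (illustrative only, not real results):

```python
# hypothetical per-task accuracies and random-guess baselines
results   = {"ARC-Easy": 0.55, "MMLU": 0.40, "GSM8K": 0.10}
baselines = {"ARC-Easy": 0.25, "MMLU": 0.25, "GSM8K": 0.00}

centered = [(results[t] - baselines[t]) / (1.0 - baselines[t]) for t in results]
chatcore = sum(centered) / len(centered)   # (0.4 + 0.2 + 0.1) / 3 = 0.2333...
```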
scripts/chat_rl.py ADDED
@@ -0,0 +1,331 @@
1
+ """
2
+ Reinforcement learning on GSM8K via "GRPO".
3
+
4
+ I put GRPO in quotes because we actually end up with something a lot
5
+ simpler and more similar to just REINFORCE:
6
+
7
+ 1) Delete trust region, so there is no KL regularization to a reference model
8
+ 2) We are on policy, so there's no need for PPO ratio+clip.
9
+ 3) We use GAPO style normalization that is token-level, not sequence-level.
10
+ 4) Instead of z-score normalization (r - mu)/sigma, we only use (r - mu) as the advantage.
11
+
12
+ 1 GPU:
13
+ python -m scripts.chat_rl
14
+
15
+ 8 GPUs:
16
+ torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=default
17
+ """
18
+
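Taken together, the four simplifications listed in the docstring reduce the update to a REINFORCE-style objective: the advantage is the reward minus the per-question mean reward, applied to the log-probability of every valid completion token. A minimal sketch of that objective under those assumptions (random tensors stand in for model outputs):

```python
import torch

B, T = 4, 8                                   # samples per question, completion length
logp = torch.randn(B, T)                      # per-token log-probs of sampled tokens
valid = torch.randint(0, 2, (B, T)).float()   # 1 = completion token, 0 = prompt/padding
rewards = torch.tensor([1.0, 0.0, 1.0, 0.0])  # one scalar reward per sample

advantages = rewards - rewards.mean()         # (r - mu), no division by sigma
pg_obj = (logp * valid * advantages.unsqueeze(-1)).sum() / valid.sum().clamp(min=1)
loss = -pg_obj                                # minimize the negated objective
```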
19
+ import os
20
+ import itertools
21
+ import re
22
+ import wandb
23
+ import torch
24
+ import torch.distributed as dist
25
+
26
+ from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb
27
+ from nanochat.checkpoint_manager import save_checkpoint, load_model
28
+ from nanochat.engine import Engine
29
+ from tasks.gsm8k import GSM8K
30
+
31
+ # RL hyperparameters
32
+ run = "dummy" # wandb run name
33
+ source = "sft" # mid|sft
34
+ dtype = "bfloat16"
35
+ device_batch_size = 8 # no forward pass will go above this to not OOM
36
+ examples_per_step = 16 # in total and across all ranks (note: examples, not samples/completions!)
37
+ num_samples = 16 # number of samples per example (/question)
38
+ max_new_tokens = 256
39
+ temperature = 1.0
40
+ top_k = 50 # TODO: try None?
41
+ unembedding_lr = 0.004
42
+ embedding_lr = 0.2
43
+ matrix_lr = 0.02
44
+ weight_decay = 0.0
45
+ init_lr_frac = 0.05
46
+ num_epochs = 1 # how many epochs of gsm8k to train on
47
+ save_every = 60 # every how many steps to save the model
48
+ eval_every = 60 # every how many steps to evaluate the model for val pass@k
49
+ eval_examples = 400 # number of examples used for evaluating pass@k
50
+ # now allow CLI to override the settings via the configurator lol
51
+ config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
52
+ exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
53
+ user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
54
+ # -----------------------------------------------------------------------------
55
+
56
+ # Init compute/precision
57
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
58
+ master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
59
+ dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
60
+ autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
61
+
62
+ # wandb logging init
63
+ use_dummy_wandb = run == "dummy" or not master_process
64
+ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=run, config=user_config)
65
+
66
+ # Init model and tokenizer
67
+ model, tokenizer, meta = load_model(source, device, phase="eval")
68
+ engine = Engine(model, tokenizer) # for sampling rollouts
69
+
70
+ # -----------------------------------------------------------------------------
71
+ # Rollout / sampling generator loop that yields batches of examples for training
72
+
73
+ train_task = GSM8K(subset="main", split="train")
74
+ val_task = GSM8K(subset="main", split="test")
75
+ num_steps = (len(train_task) // examples_per_step) * num_epochs
76
+ print0(f"Calculated number of steps: {num_steps}")
77
+
78
+ @torch.no_grad()
79
+ def get_batch():
80
+ assistant_end = tokenizer.encode_special("<|assistant_end|>") # ok to use this token, it's only for padding and isn't used in the loss.
81
+ rank_indices = range(ddp_rank, len(train_task), ddp_world_size) # each rank is responsible for different examples in the training data
82
+ for example_idx in itertools.cycle(rank_indices):
83
+
84
+ # First get the full conversation of both user and assistant messages
85
+ conversation = train_task[example_idx]
86
+
87
+ # Tokenize the conversation, deleting the last Assistant message and priming the Assistant for a completion instead
88
+ # (i.e. keep the <|assistant_start|>, but delete everything after it)
89
+ tokens = tokenizer.render_for_completion(conversation)
90
+ prefix_length = len(tokens)
91
+
92
+ # Generate num_samples samples using batched generation, use loop to avoid OOMs
93
+ model.eval() # ensure the model is in eval mode
94
+ generated_token_sequences = []
95
+ masks = []
96
+ num_sampling_steps = num_samples // device_batch_size # go sequentially to prevent OOMs
97
+ for sampling_step in range(num_sampling_steps):
98
+ seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF # positive half of int32
99
+ with autocast_ctx:
100
+ generated_token_sequences_batch, masks_batch = engine.generate_batch(
101
+ tokens,
102
+ num_samples=device_batch_size,
103
+ max_tokens=max_new_tokens,
104
+ temperature=temperature,
105
+ top_k=top_k,
106
+ seed=seed, # must make sure to change the seed for each sampling step
107
+ )
108
+ generated_token_sequences.extend(generated_token_sequences_batch)
109
+ masks.extend(masks_batch)
110
+
111
+ # Calculate the rewards for each sample
112
+ rewards = []
113
+ for sample_tokens in generated_token_sequences:
114
+ # Get just the generated tokens (after the prompt)
115
+ generated_tokens = sample_tokens[prefix_length:]
116
+ # Decode the generated response
117
+ generated_text = tokenizer.decode(generated_tokens)
118
+ # Calculate the reward
119
+ reward = train_task.reward(conversation, generated_text)
120
+ rewards.append(reward)
121
+
122
+ # Pad the sequences so that their lengths (in time) match
123
+ max_length = max(len(seq) for seq in generated_token_sequences)
124
+ padded_generated_token_sequences = [seq + [assistant_end] * (max_length - len(seq)) for seq in generated_token_sequences]
125
+ padded_masks = [mask + [0] * (max_length - len(mask)) for mask in masks]
126
+ # Stack up the sequences and masks into PyTorch tensors
127
+ ids = torch.tensor(padded_generated_token_sequences, dtype=torch.long, device=device)
128
+ mask_ids = torch.tensor(padded_masks, dtype=torch.long, device=device)
129
+ # Generate autoregressive inputs and targets to the Transformer
130
+ inputs = ids[:, :-1]
131
+ targets = ids[:, 1:].clone() # clone to avoid in-place modification:
132
+ targets[mask_ids[:, 1:] == 0] = -1 # <-- inplace modification right here. -1 is the ignore index
133
+ # NOTE also that the Engine returns mask=0 for BOTH the prompt tokens AND the tool use tokens.
134
+ # So we will (correctly) end up not training on the prompt tokens, or the tool use forced tokens.
135
+ rewards = torch.tensor(rewards, dtype=torch.float, device=device)
136
+ # Calculate the advantages by simply subtracting the mean (instead of z-score (x-mu)/sigma)
137
+ mu = rewards.mean()
138
+ advantages = rewards - mu
139
+ # yield inputs/targets as (B, T) of ids and rewards as (B,) of floats
140
+ yield generated_token_sequences, inputs, targets, rewards, advantages
141
+
142
+ # -----------------------------------------------------------------------------
143
+ # Simple evaluation loop for GSM8K pass@k
144
+ def run_gsm8k_eval(task, tokenizer, engine,
145
+ max_examples=None,
146
+ num_samples=1,
147
+ max_completion_tokens=256,
148
+ temperature=0.0,
149
+ top_k=50
150
+ ):
151
+ """
152
+ Evaluates GSM8K task and returns a list of records of evaluation outcomes.
153
+ In a distributed setting, all ranks cooperate but this function will NOT
154
+ do the reduction across ranks. This is the responsibility of the caller.
155
+ Because the evaluation can take a while, this function will yield records one by one.
156
+ """
157
+ max_examples = min(max_examples, len(task)) if max_examples is not None else len(task)
158
+ for idx in range(ddp_rank, max_examples, ddp_world_size):
159
+ conversation = task[idx]
160
+ tokens = tokenizer.render_for_completion(conversation)
161
+ prefix_length = len(tokens)
162
+ # Generate k samples using batched generation inside the Engine
163
+ assert num_samples <= device_batch_size # usually this is true. we can add a loop if not...
164
+ generated_token_sequences, masks = engine.generate_batch(
165
+ tokens,
166
+ num_samples=num_samples,
167
+ max_tokens=max_completion_tokens,
168
+ temperature=temperature,
169
+ top_k=top_k
170
+ )
171
+ # Check each sample for correctness
172
+ outcomes = []
173
+ for sample_tokens in generated_token_sequences:
174
+ generated_tokens = sample_tokens[prefix_length:]
175
+ generated_text = tokenizer.decode(generated_tokens)
176
+ is_correct = task.evaluate(conversation, generated_text)
177
+ outcomes.append({
178
+ "is_correct": is_correct
179
+ })
180
+ # A bit bloated because I wanted to do more complex logging at one point.
181
+ record = {
182
+ "idx": idx,
183
+ "outcomes": outcomes,
184
+ }
185
+ yield record
186
+
187
+ # -----------------------------------------------------------------------------
188
+ # Training loop
189
+
190
+ # Init the optimizer
191
+ optimizers = model.setup_optimizers(
192
+ unembedding_lr=unembedding_lr,
193
+ embedding_lr=embedding_lr,
194
+ matrix_lr=matrix_lr,
195
+ weight_decay=weight_decay,
196
+ )
197
+
198
+ # Set the initial learning rate as a fraction of the base learning rate
199
+ for opt in optimizers:
200
+ for group in opt.param_groups:
201
+ group["lr"] = group["lr"] * init_lr_frac
202
+ group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
203
+
204
+ # Learning rate scheduler: simple rampdown to zero over num_steps
205
+ def get_lr_multiplier(it):
206
+ lrm = 1.0 - it / num_steps
207
+ return lrm
208
+
209
+ # Calculate the number of examples each rank handles to achieve the desired examples_per_step
210
+ print0(f"Total sequences per step: {examples_per_step * num_samples}") # total batch size in sequences/step
211
+ assert examples_per_step % ddp_world_size == 0, "Desired examples per step must be divisible by the number of ranks"
212
+ examples_per_rank = examples_per_step // ddp_world_size # per GPU
213
+ print0(f"Calculated examples per rank: {examples_per_rank}")
214
+
215
+ # Kick off the training loop
216
+ batch_iterator = get_batch()
217
+ for step in range(num_steps):
218
+
219
+ # Evaluate the model once in a while and log to wandb
220
+ if step % eval_every == 0:
221
+ model.eval()
222
+ passk = torch.zeros(device_batch_size, device=device) # pass@k for k=1..device_batch_size
223
+ with autocast_ctx:
224
+ records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=device_batch_size, max_examples=eval_examples, temperature=1.0)
225
+ records = list(records_iter) # collect all records
226
+ for k in range(1, device_batch_size + 1):
227
+ passk[k - 1] = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records)
228
+ num_records = torch.tensor(len(records), dtype=torch.long, device=device)
229
+ if ddp:
230
+ dist.all_reduce(num_records, op=dist.ReduceOp.SUM)
231
+ dist.all_reduce(passk, op=dist.ReduceOp.SUM)
232
+ passk = passk / num_records.item() # normalize by the total number of records
233
+ print_passk = [f"Pass@{k}: {passk[k - 1].item():.4f}" for k in range(1, device_batch_size + 1)]
234
+ print0(f"Step {step} | {', '.join(print_passk)}")
235
+ log_passk = {f"pass@{k}": passk[k - 1].item() for k in range(1, device_batch_size + 1)}
236
+ wandb_run.log({
237
+ "step": step,
238
+ **log_passk,
239
+ })
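The pass@k reduction above counts a problem as solved at k if any of its first k samples is correct, then averages over problems. A toy example of the same computation:

```python
# per-problem correctness of 4 samples each (illustrative values)
records = [[False, True, False, True], [False, False, False, False]]
for k in range(1, 5):
    passk = sum(any(outcomes[:k]) for outcomes in records) / len(records)
    print(f"pass@{k}: {passk:.2f}")  # 0.00, 0.50, 0.50, 0.50
```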
240
+
241
+ # Forward/Backward on rollouts over multiple examples in the dataset
242
+ rewards_list = []
243
+ sequence_lengths = []
244
+ for example_step in range(examples_per_rank):
245
+ # Get one batch corresponding to one example in the training dataset
246
+ sequences_all, inputs_all, targets_all, rewards_all, advantages_all = next(batch_iterator)
247
+ # Evaluate the loss and gradients
248
+ model.train() # ensure the model is in train mode
249
+ # We need one more loop because we can never exceed the device_batch_size
250
+ assert inputs_all.size(0) % device_batch_size == 0
251
+ num_passes = inputs_all.size(0) // device_batch_size
252
+ for pass_idx in range(num_passes):
253
+ # Pluck out the batch for this pass
254
+ b0, b1 = pass_idx * device_batch_size, (pass_idx + 1) * device_batch_size
255
+ inputs = inputs_all[b0:b1]
256
+ targets = targets_all[b0:b1]
257
+ rewards = rewards_all[b0:b1]
258
+ advantages = advantages_all[b0:b1]
259
+ # Calculate log probabilities. Note that the loss calculates NLL = -logp, so we negate
260
+ with autocast_ctx:
261
+ logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T)
262
+ # Calculate the PG objective. Note that ignore_index=-1 ensures that invalid tokens have loss 0.
263
+ pg_obj = (logp * advantages.unsqueeze(-1)).sum()
264
+ # normalize by the number of valid tokens, number of passes, and examples_per_rank
265
+ num_valid = (targets >= 0).sum().clamp(min=1)
266
+ pg_obj = pg_obj / (num_valid * num_passes * examples_per_rank)
267
+ # Note, there is no need to add PPO ratio+clip because we are on policy
268
+ # Finally, formulate the loss that we want to minimize (instead of objective we wish to maximize)
269
+ loss = -pg_obj
270
+ loss.backward()
271
+ print0(f"Step {step}/{num_steps} | Example step {example_step} | Pass {pass_idx} | loss: {loss.item():.6f} | Average reward: {rewards.mean().item()}")
272
+ # For logging
273
+ rewards_list.append(rewards_all.mean().item())
274
+ sequence_lengths.extend(len(seq) for seq in sequences_all)
275
+
276
+ # A bunch of logging for how the rollouts went this step
277
+ mean_reward = sum(rewards_list) / len(rewards_list)
278
+ mean_sequence_length = sum(sequence_lengths) / len(sequence_lengths)
279
+ if ddp: # aggregate across ranks
280
+ mean_reward_tensor = torch.tensor(mean_reward, dtype=torch.float, device=device)
281
+ mean_sequence_length_tensor = torch.tensor(mean_sequence_length, dtype=torch.float, device=device)
282
+ dist.all_reduce(mean_reward_tensor, op=dist.ReduceOp.AVG)
283
+ dist.all_reduce(mean_sequence_length_tensor, op=dist.ReduceOp.AVG)
284
+ mean_reward = mean_reward_tensor.item()
285
+ mean_sequence_length = mean_sequence_length_tensor.item()
286
+ print0(f"Step {step}/{num_steps} | Average reward: {mean_reward} | Average sequence length: {mean_sequence_length:.2f}")
287
+ wandb_run.log({
288
+ "step": step,
289
+ "reward": mean_reward,
290
+ "sequence_length": mean_sequence_length,
291
+ })
292
+
293
+ # Update the model parameters
294
+ lrm = get_lr_multiplier(step)
295
+ for opt in optimizers: # first set the learning rate
296
+ for group in opt.param_groups:
297
+ group["lr"] = group["initial_lr"] * lrm
298
+ for opt in optimizers: # then step the optimizers
299
+ opt.step()
300
+ model.zero_grad(set_to_none=True)
301
+ wandb_run.log({
302
+ "step": step,
303
+ "lrm": lrm,
304
+ })
305
+
306
+ # Master process saves the model once in a while. Skip first step. Save last step.
307
+ if master_process and ((step > 0 and step % save_every == 0) or step == num_steps - 1):
308
+ base_dir = get_base_dir()
309
+ depth = model.config.n_layer
310
+ model_tag = f"d{depth}" # base the model tag on the depth of the base model
311
+ checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", model_tag)
312
+ model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
313
+ save_checkpoint(
314
+ checkpoint_dir,
315
+ step,
316
+ model.state_dict(),
317
+ None, # note: we don't bother to save the optimizer state
318
+ {
319
+ "model_config": model_config_kwargs,
320
+ }
321
+ )
322
+ print(f"✅ Saved model checkpoint to {checkpoint_dir}")
323
+
324
+ # Log to report
325
+ from nanochat.report import get_report
326
+ get_report().log(section="Chat RL", data=[
327
+ user_config, # CLI args
328
+ ])
329
+
330
+ wandb_run.finish() # wandb run finish
331
+ compute_cleanup()
scripts/chat_sft.py ADDED
@@ -0,0 +1,285 @@
1
+ """
2
+ Finetune a base model to be a chat model.
3
+ Run on one GPU e.g. for debugging:
4
+
5
+ python -m scripts.chat_sft
6
+
7
+ Or torchrun for training:
8
+
9
+ torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft
10
+ """
11
+
12
+ import os
13
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
14
+
15
+ import wandb
16
+ import torch
17
+ import torch.distributed as dist
18
+ from contextlib import nullcontext
19
+
20
+ from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb, autodetect_device_type
21
+ from nanochat.checkpoint_manager import load_model
22
+ from nanochat.checkpoint_manager import save_checkpoint
23
+ from nanochat.engine import Engine
24
+ from scripts.chat_eval import run_chat_eval
25
+
26
+ from tasks.common import TaskMixture
27
+ from tasks.arc import ARC
28
+ from tasks.gsm8k import GSM8K
29
+ from tasks.smoltalk import SmolTalk
30
+ from tasks.customjson import CustomJSON
31
+ from tasks.spellingbee import SimpleSpelling, SpellingBee
32
+
33
+ # -----------------------------------------------------------------------------
34
+ # SFT Hyperparameters
35
+ run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
36
+ # input model options
37
+ source = "mid" # base|mid , which checkpoint to load the model from (base model or midtrained model)
38
+ model_tag = None # model tag to load the model from (base model or midtrained model)
39
+ step = None # step to load the model from (base model or midtrained model)
40
+ # compute/precision
41
+ device_type = "" # cuda|cpu|mps (empty => autodetect)
42
+ dtype = "bfloat16"
43
+ device_batch_size = 4 # max to avoid OOM
44
+ # optimization
45
+ num_epochs = 1
46
+ num_iterations = -1 # override number of iterations (-1 = disable, use num_epochs to derive it)
47
+ target_examples_per_step = 32
48
+ unembedding_lr = 0.004
49
+ embedding_lr = 0.2
50
+ matrix_lr = 0.02
51
+ weight_decay = 0.0
52
+ init_lr_frac = 0.02
53
+ # evaluation and logging thereof
54
+ eval_every = 100
55
+ eval_steps = 100
56
+ eval_metrics_every = 200
57
+ eval_metrics_max_problems = 1024
58
+ # now allow CLI to override the settings via the configurator lol
59
+ config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
60
+ exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
61
+ user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
62
+ # -----------------------------------------------------------------------------
63
+
64
+ # Compute init
65
+ device_type = autodetect_device_type() if device_type == "" else device_type
66
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
67
+ master_process = ddp_rank == 0
68
+ ptdtype = torch.float32 if dtype == 'float32' else torch.bfloat16
69
+ autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
70
+
71
+ # wandb logging init
72
+ use_dummy_wandb = run == "dummy" or not master_process
73
+ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=run, config=user_config, save_code=True)
74
+
75
+ # Load the model and tokenizer
76
+ model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
77
+ orig_model = model # original, uncompiled model
78
+ # model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
79
+ engine = Engine(model, tokenizer) # will be used for inline model evaluation only
80
+
81
+ # -----------------------------------------------------------------------------
82
+ # Task data mixture we'll train on
83
+ identity_conversations_filepath = os.path.join(get_base_dir(), "identity_conversations.jsonl")
84
+ train_ds = TaskMixture([
85
+ ARC(subset="ARC-Easy", split="train"), # 2.3K rows
86
+ ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
87
+ GSM8K(subset="main", split="train"), # 8K rows
88
+ SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
89
+ CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
90
+ SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling (e.g. spell the word 'apple')
91
+ SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
92
+ ]) # 2.3K + 1.1K + 8K + 10K + 1K + 0.3K + 0.3K = 23K rows
93
+ val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it)
94
+
95
+ # -----------------------------------------------------------------------------
96
+ # DataLoader
97
+
98
+ def sft_data_generator(dataset, batch_size):
99
+ pad_token_id = tokenizer.encode_special("<|assistant_end|>") # using <|assistant_end|> as the pad token is ok; these positions are masked in the loss
100
+ # prepares a list of tokenized conversations into a batch and yields
101
+ def collate_and_yield(batch):
102
+ nrows = len(batch)
103
+ ncols = max(len(ids) for ids, mask in batch) - 1 # seq of n creates inputs/targets of n-1
104
+ inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long)
105
+ targets = torch.full((nrows, ncols), -1, dtype=torch.long) # -1 is ignore index
106
+ for i, (ids, mask) in enumerate(batch):
107
+ n = len(ids)
108
+ ids_tensor = torch.tensor(ids, dtype=torch.long)
109
+ inputs[i, :n-1] = ids_tensor[:-1]
110
+ # recall -1 is the ignore index, so mask out targets where mask is 0
111
+ row_targets = ids_tensor[1:]
112
+ # mask[1:] omits the mask for the BOS token, which is never a target atm so it's ok
113
+ mask_tensor = torch.tensor(mask[1:], dtype=torch.long)
114
+ row_targets[mask_tensor == 0] = -1 # mask out targets where mask is 0
115
+ targets[i, :n-1] = row_targets
116
+ inputs = inputs.to(device) # move to device
117
+ targets = targets.to(device)
118
+ return inputs, targets
119
+ # iterates over the dataset in epochs, tokenizes
120
+ batch = []
121
+ while True:
122
+ for i in range(ddp_rank, len(dataset), ddp_world_size):
123
+ doc = dataset[i]
124
+ ids, mask = tokenizer.render_conversation(doc)
125
+ batch.append((ids, mask))
126
+ if len(batch) == batch_size:
127
+ yield collate_and_yield(batch)
128
+ batch = []
129
+
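To make the masking in the collate step concrete: rows are right-padded with <|assistant_end|>, inputs/targets are the sequence shifted by one, and every position whose mask is 0 becomes the ignore index -1. A tiny standalone illustration (hypothetical token ids and pad id):

```python
import torch

ids  = [10, 20, 30, 40, 50]   # one tokenized conversation (hypothetical ids)
mask = [0, 0, 1, 1, 1]        # 1 = assistant tokens we supervise on
pad_id, ncols = 99, 6         # pretend the longest row in the batch needs 6 columns

inputs  = torch.full((1, ncols), pad_id)
targets = torch.full((1, ncols), -1)
n = len(ids)
inputs[0, :n-1] = torch.tensor(ids[:-1])
row_targets = torch.tensor(ids[1:])
row_targets[torch.tensor(mask[1:]) == 0] = -1   # don't supervise non-assistant positions
targets[0, :n-1] = row_targets
# inputs:  [[10, 20, 30, 40, 99, 99]]
# targets: [[-1, 30, 40, 50, -1, -1]]
```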
130
+ examples_per_step = device_batch_size * ddp_world_size
131
+ print0(f"Target examples per step: {target_examples_per_step}")
132
+ print0(f"Device batch size: {device_batch_size}")
133
+ print0(f"Examples per step is device_batch_size * ddp_world_size: {examples_per_step}")
134
+ assert target_examples_per_step % examples_per_step == 0, "Target examples per step must be divisible by examples per step"
135
+ grad_accum_steps = target_examples_per_step // examples_per_step
136
+ print0(f"=> Setting grad accum steps: {grad_accum_steps}")
137
+
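As a concrete example with the defaults above: on a single GPU, examples_per_step = 4 * 1 = 4, so reaching target_examples_per_step = 32 takes grad_accum_steps = 8 micro-steps; on 8 GPUs, examples_per_step = 4 * 8 = 32 and no gradient accumulation is needed.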
138
+ if num_iterations == -1:
139
+ # derive num_iterations from num_epochs and the size of the dataset
140
+ assert num_epochs > 0, "num_epochs must be positive if num_iterations is -1"
141
+ num_iterations = (len(train_ds) // target_examples_per_step) * num_epochs
142
+ train_loader = sft_data_generator(train_ds, batch_size=device_batch_size)
143
+ build_val_loader = lambda: sft_data_generator(val_ds, batch_size=device_batch_size)
144
+
145
+ # -----------------------------------------------------------------------------
146
+ # Initialize the Optimizer
147
+
148
+ optimizers = model.setup_optimizers(
149
+ unembedding_lr=unembedding_lr,
150
+ embedding_lr=embedding_lr,
151
+ matrix_lr=matrix_lr,
152
+ weight_decay=weight_decay,
153
+ )
154
+ # Set the initial learning rate as a fraction of the base learning rate
155
+ for opt in optimizers:
156
+ for group in opt.param_groups:
157
+ group["lr"] = group["lr"] * init_lr_frac
158
+ group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
159
+
160
+ # -----------------------------------------------------------------------------
161
+ # Training loop
162
+
163
+ # Learning rate scheduler
164
+ def get_lr_multiplier(it):
165
+ lrm = 1.0 - it / num_iterations
166
+ return lrm
167
+
168
+ # Go!
169
+ step = 0
170
+ train_iter = iter(train_loader)
171
+ for step in range(num_iterations):
172
+ last_step = step == num_iterations - 1
173
+
174
+ # evaluate the validation loss
175
+ if last_step or step % eval_every == 0:
176
+ model.eval()
177
+ val_iter = iter(build_val_loader())
178
+ losses = []
179
+ for _ in range(eval_steps):
180
+ val_inputs, val_targets = next(val_iter)
181
+ with torch.no_grad(), autocast_ctx:
182
+ loss = model(val_inputs, val_targets)
183
+ losses.append(loss)
184
+ val_loss = torch.stack(losses).mean() # average over eval_steps
185
+ if ddp:
186
+ dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) # average over ranks
187
+ val_loss = val_loss.item()
188
+ print0(f"Step {step:05d} | Validation loss: {val_loss:.6f}")
189
+ wandb_run.log({
190
+ "step": step,
191
+ "val_loss": val_loss,
192
+ })
193
+ model.train()
194
+
195
+ # evaluate accuracy of the multiple-choice tasks (which are quick to run)
196
+ if last_step or (step > 0 and step % eval_metrics_every == 0):
197
+ model.eval()
198
+ metrics = {}
199
+ with torch.no_grad(), autocast_ctx:
200
+ # note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
201
+ metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
202
+ metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
203
+ metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
204
+ print0(f"Step {step:05d} | {metrics_str}")
205
+ wandb_run.log({
206
+ "step": step,
207
+ **metrics,
208
+ })
209
+ model.train()
210
+
211
+ if last_step:
212
+ break
213
+
214
+ # evaluate the gradient
215
+ num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen
216
+ for micro_step in range(grad_accum_steps):
217
+ train_inputs, train_targets = next(train_iter)
218
+ with autocast_ctx:
219
+ loss = model(train_inputs, train_targets)
220
+ train_loss = loss.detach() # for logging
221
+ loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
222
+ loss.backward() # accumulate the gradient
223
+ num_tokens += (train_targets >= 0).sum()
224
+ if ddp:
225
+ dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM) # sum over ranks
226
+
227
+ # learning rate scheduler
228
+ lrm = get_lr_multiplier(step)
229
+ for opt in optimizers:
230
+ for group in opt.param_groups:
231
+ group["lr"] = group["initial_lr"] * lrm
232
+
233
+ # step the optimizers
234
+ for opt in optimizers:
235
+ opt.step()
236
+ model.zero_grad(set_to_none=True)
237
+
238
+ # logging
239
+ train_loss_item = train_loss.item()
240
+ num_tokens_item = num_tokens.item()
241
+ print0(f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f}| lrm: {lrm:.6f}| num_tokens: {num_tokens_item:,}")
242
+ wandb_run.log({
243
+ "step": step,
244
+ "lrm": lrm,
245
+ "train_loss": train_loss_item,
246
+ "num_tokens": num_tokens_item,
247
+ })
248
+ step += 1
249
+
250
+ # Save the model at the end of the run
251
+ if master_process:
252
+ base_dir = get_base_dir()
253
+ depth = model.config.n_layer
254
+ model_tag = f"d{depth}" # base the model tag on the depth of the base model
255
+ checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", model_tag)
256
+ model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
257
+ save_checkpoint(
258
+ checkpoint_dir,
259
+ step,
260
+ model.state_dict(),
261
+ None, # note: we don't bother to save the optimizer state
262
+ {
263
+ "step": step,
264
+ "val_loss": val_loss,
265
+ **metrics,
266
+ "model_config": model_config_kwargs,
267
+ }
268
+ )
269
+ print(f"✅ Saved model checkpoint to {checkpoint_dir}")
270
+
271
+ # Log to report
272
+ from nanochat.report import get_report
273
+ get_report().log(section="Chat SFT", data=[
274
+ user_config, # CLI args
275
+ {
276
+ "Training rows": len(train_ds),
277
+ "Number of iterations": num_iterations,
278
+ "Training loss": train_loss_item,
279
+ "Validation loss": val_loss,
280
+ },
281
+ ])
282
+
283
+ # Cleanup
284
+ wandb_run.finish()
285
+ compute_cleanup()
scripts/chat_web.py ADDED
@@ -0,0 +1,415 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Unified web chat server - serves both UI and API from a single FastAPI instance.
4
+
5
+ Uses data parallelism to distribute requests across multiple GPUs. Each GPU loads
6
+ a full copy of the model, and incoming requests are distributed to available workers.
7
+
8
+ Launch examples:
9
+
10
+ - single available GPU (default)
11
+ python -m scripts.chat_web
12
+
13
+ - 4 GPUs
14
+ python -m scripts.chat_web --num-gpus 4
15
+
16
+ To chat, open the URL printed in the console. (If on cloud box, make sure to use public IP)
17
+
18
+ Endpoints:
19
+ GET / - Chat UI
20
+ POST /chat/completions - Chat API (streaming only)
21
+ GET /health - Health check with worker pool status
22
+ GET /stats - Worker pool statistics and GPU utilization
23
+
24
+ Abuse Prevention:
25
+ - Maximum 500 messages per request
26
+ - Maximum 8000 characters per message
27
+ - Maximum 32000 characters total conversation length
28
+ - Temperature clamped to 0.0-2.0
29
+ - Top-k clamped to 1-200
30
+ - Max tokens clamped to 1-4096
31
+ """
32
+
33
+ import argparse
34
+ import json
35
+ import os
36
+ import torch
37
+ import asyncio
38
+ import logging
39
+ import random
40
+ from contextlib import asynccontextmanager
41
+ from fastapi import FastAPI, HTTPException
42
+ from fastapi.middleware.cors import CORSMiddleware
43
+ from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
44
+ from pydantic import BaseModel
45
+ from typing import List, Optional, AsyncGenerator
46
+ from dataclasses import dataclass
47
+ from contextlib import nullcontext
48
+ from nanochat.common import compute_init, autodetect_device_type
49
+ from nanochat.checkpoint_manager import load_model
50
+ from nanochat.engine import Engine
51
+
52
+ # Abuse prevention limits
53
+ MAX_MESSAGES_PER_REQUEST = 500
54
+ MAX_MESSAGE_LENGTH = 8000
55
+ MAX_TOTAL_CONVERSATION_LENGTH = 32000
56
+ MIN_TEMPERATURE = 0.0
57
+ MAX_TEMPERATURE = 2.0
58
+ MIN_TOP_K = 1
59
+ MAX_TOP_K = 200
60
+ MIN_MAX_TOKENS = 1
61
+ MAX_MAX_TOKENS = 4096
62
+
63
+ parser = argparse.ArgumentParser(description='NanoChat Web Server')
64
+ parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
65
+ parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
66
+ parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Default temperature for generation')
67
+ parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
68
+ parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default max tokens for generation')
69
+ parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
70
+ parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
71
+ parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
72
+ parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
73
+ parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
74
+ parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
75
+ args = parser.parse_args()
76
+
77
+ # Configure logging for conversation traffic
78
+ logging.basicConfig(
79
+ level=logging.INFO,
80
+ format='%(asctime)s - %(message)s',
81
+ datefmt='%Y-%m-%d %H:%M:%S'
82
+ )
83
+ logger = logging.getLogger(__name__)
84
+
85
+ device_type = autodetect_device_type() if args.device_type == "" else args.device_type
86
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
87
+ ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
88
+
89
+ @dataclass
90
+ class Worker:
91
+ """A worker with a model loaded on a specific GPU."""
92
+ gpu_id: int
93
+ device: torch.device
94
+ engine: Engine
95
+ tokenizer: object
96
+ autocast_ctx: torch.amp.autocast
97
+
98
+ class WorkerPool:
99
+ """Pool of workers, each with a model replica on a different GPU."""
100
+
101
+ def __init__(self, num_gpus: Optional[int] = None):
102
+ if num_gpus is None:
103
+ if device_type == "cuda":
104
+ num_gpus = torch.cuda.device_count()
105
+ else:
106
+ num_gpus = 1 # e.g. cpu|mps
107
+ self.num_gpus = num_gpus
108
+ self.workers: List[Worker] = []
109
+ self.available_workers: asyncio.Queue = asyncio.Queue()
110
+
111
+ async def initialize(self, source: str, model_tag: Optional[str] = None, step: Optional[int] = None):
112
+ """Load model on each GPU."""
113
+ print(f"Initializing worker pool with {self.num_gpus} GPUs...")
114
+ if self.num_gpus > 1:
115
+ assert device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not."
116
+
117
+ for gpu_id in range(self.num_gpus):
118
+
119
+ if device_type == "cuda":
120
+ device = torch.device(f"cuda:{gpu_id}")
121
+ print(f"Loading model on GPU {gpu_id}...")
122
+ else:
123
+ device = torch.device(device_type) # e.g. cpu|mps
124
+ print(f"Loading model on {device_type}...")
125
+
126
+ model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
127
+ engine = Engine(model, tokenizer)
128
+ autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
129
+
130
+ worker = Worker(
131
+ gpu_id=gpu_id,
132
+ device=device,
133
+ engine=engine,
134
+ tokenizer=tokenizer,
135
+ autocast_ctx=autocast_ctx
136
+ )
137
+ self.workers.append(worker)
138
+ await self.available_workers.put(worker)
139
+
140
+ print(f"All {self.num_gpus} workers initialized!")
141
+
142
+ async def acquire_worker(self) -> Worker:
143
+ """Get an available worker from the pool."""
144
+ return await self.available_workers.get()
145
+
146
+ async def release_worker(self, worker: Worker):
147
+ """Return a worker to the pool."""
148
+ await self.available_workers.put(worker)
149
+
150
+ class ChatMessage(BaseModel):
151
+ role: str
152
+ content: str
153
+
154
+ class ChatRequest(BaseModel):
155
+ messages: List[ChatMessage]
156
+ temperature: Optional[float] = None
157
+ max_tokens: Optional[int] = None
158
+ top_k: Optional[int] = None
159
+
160
+ def validate_chat_request(request: ChatRequest):
161
+ """Validate chat request to prevent abuse."""
162
+ # Check number of messages
163
+ if len(request.messages) == 0:
164
+ raise HTTPException(status_code=400, detail="At least one message is required")
165
+ if len(request.messages) > MAX_MESSAGES_PER_REQUEST:
166
+ raise HTTPException(
167
+ status_code=400,
168
+ detail=f"Too many messages. Maximum {MAX_MESSAGES_PER_REQUEST} messages allowed per request"
169
+ )
170
+
171
+ # Check individual message lengths and total conversation length
172
+ total_length = 0
173
+ for i, message in enumerate(request.messages):
174
+ if not message.content:
175
+ raise HTTPException(status_code=400, detail=f"Message {i} has empty content")
176
+
177
+ msg_length = len(message.content)
178
+ if msg_length > MAX_MESSAGE_LENGTH:
179
+ raise HTTPException(
180
+ status_code=400,
181
+ detail=f"Message {i} is too long. Maximum {MAX_MESSAGE_LENGTH} characters allowed per message"
182
+ )
183
+ total_length += msg_length
184
+
185
+ if total_length > MAX_TOTAL_CONVERSATION_LENGTH:
186
+ raise HTTPException(
187
+ status_code=400,
188
+ detail=f"Total conversation is too long. Maximum {MAX_TOTAL_CONVERSATION_LENGTH} characters allowed"
189
+ )
190
+
191
+ # Validate role values
192
+ for i, message in enumerate(request.messages):
193
+ if message.role not in ["user", "assistant"]:
194
+ raise HTTPException(
195
+ status_code=400,
196
+ detail=f"Message {i} has invalid role. Must be 'user', 'assistant', or 'system'"
197
+ )
198
+
199
+ # Validate temperature
200
+ if request.temperature is not None:
201
+ if not (MIN_TEMPERATURE <= request.temperature <= MAX_TEMPERATURE):
202
+ raise HTTPException(
203
+ status_code=400,
204
+ detail=f"Temperature must be between {MIN_TEMPERATURE} and {MAX_TEMPERATURE}"
205
+ )
206
+
207
+ # Validate top_k
208
+ if request.top_k is not None:
209
+ if not (MIN_TOP_K <= request.top_k <= MAX_TOP_K):
210
+ raise HTTPException(
211
+ status_code=400,
212
+ detail=f"top_k must be between {MIN_TOP_K} and {MAX_TOP_K}"
213
+ )
214
+
215
+ # Validate max_tokens
216
+ if request.max_tokens is not None:
217
+ if not (MIN_MAX_TOKENS <= request.max_tokens <= MAX_MAX_TOKENS):
218
+ raise HTTPException(
219
+ status_code=400,
220
+ detail=f"max_tokens must be between {MIN_MAX_TOKENS} and {MAX_MAX_TOKENS}"
221
+ )
222
+
223
+ @asynccontextmanager
224
+ async def lifespan(app: FastAPI):
225
+ """Load models on all GPUs on startup."""
226
+ print("Loading nanochat models across GPUs...")
227
+ app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus)
228
+ await app.state.worker_pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
229
+ print(f"Server ready at http://localhost:{args.port}")
230
+ yield
231
+
232
+ app = FastAPI(lifespan=lifespan)
233
+
234
+ app.add_middleware(
235
+ CORSMiddleware,
236
+ allow_origins=["*"],
237
+ allow_credentials=True,
238
+ allow_methods=["*"],
239
+ allow_headers=["*"],
240
+ )
241
+
242
+ @app.get("/")
243
+ async def root():
244
+ """Serve the chat UI."""
245
+ ui_html_path = os.path.join("nanochat", "ui.html")
246
+ with open(ui_html_path, "r", encoding="utf-8") as f:
247
+ html_content = f.read()
248
+ # Replace the API_URL to use the same origin
249
+ html_content = html_content.replace(
250
+ "const API_URL = `http://${window.location.hostname}:8000`;",
251
+ "const API_URL = '';"
252
+ )
253
+ return HTMLResponse(content=html_content)
254
+
255
+
256
+ @app.get("/logo.svg")
257
+ async def logo():
258
+ """Serve the NanoChat logo for favicon and header."""
259
+ logo_path = os.path.join("nanochat", "logo.svg")
260
+ return FileResponse(logo_path, media_type="image/svg+xml")
261
+
262
+ async def generate_stream(
263
+ worker: Worker,
264
+ tokens,
265
+ temperature=None,
266
+ max_new_tokens=None,
267
+ top_k=None
268
+ ) -> AsyncGenerator[str, None]:
269
+ """Generate assistant response with streaming."""
270
+ temperature = temperature if temperature is not None else args.temperature
271
+ max_new_tokens = max_new_tokens if max_new_tokens is not None else args.max_tokens
272
+ top_k = top_k if top_k is not None else args.top_k
273
+
274
+ assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
275
+ bos = worker.tokenizer.get_bos_token_id()
276
+
277
+ # Accumulate tokens to properly handle multi-byte UTF-8 characters (like emojis)
278
+ accumulated_tokens = []
279
+ # Track the last complete UTF-8 string (without replacement characters)
280
+ last_clean_text = ""
281
+
282
+ with worker.autocast_ctx:
283
+ for token_column, token_masks in worker.engine.generate(
284
+ tokens,
285
+ num_samples=1,
286
+ max_tokens=max_new_tokens,
287
+ temperature=temperature,
288
+ top_k=top_k,
289
+ seed=random.randint(0, 2**31 - 1)
290
+ ):
291
+ token = token_column[0]
292
+
293
+ # Stopping criteria
294
+ if token == assistant_end or token == bos:
295
+ break
296
+
297
+ # Append the token to sequence
298
+ accumulated_tokens.append(token)
299
+ # Decode all accumulated tokens to get proper UTF-8 handling
300
+ # Note that decode is quite an efficient operation, basically a table lookup and string concat
301
+ current_text = worker.tokenizer.decode(accumulated_tokens)
302
+ # Only emit text if it doesn't end with a replacement character
303
+ # This ensures we don't emit incomplete UTF-8 sequences
304
+ if not current_text.endswith('�'):
305
+ # Extract only the new text since last clean decode
306
+ new_text = current_text[len(last_clean_text):]
307
+ if new_text: # Only yield if there's new content
308
+ yield f"data: {json.dumps({'token': new_text, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n"
309
+ last_clean_text = current_text
310
+
311
+ yield f"data: {json.dumps({'done': True})}\n\n"
312
+
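The replacement-character check is what makes the streaming safe for multi-byte characters: tokens are re-decoded cumulatively and text is emitted only once it no longer ends in U+FFFD. The same idea in isolation, with `decode` standing in for any ids-to-string tokenizer function (hypothetical helper):

```python
def stream_text(token_ids, decode):
    """Yield only complete UTF-8 text as tokens arrive."""
    accumulated, last_clean = [], ""
    for tok in token_ids:
        accumulated.append(tok)
        text = decode(accumulated)
        if not text.endswith("\ufffd"):      # no dangling partial multi-byte sequence
            new_text = text[len(last_clean):]
            if new_text:
                yield new_text
            last_clean = text
```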
313
+ @app.post("/chat/completions")
314
+ async def chat_completions(request: ChatRequest):
315
+ """Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
316
+
317
+ # Basic validation to prevent abuse
318
+ validate_chat_request(request)
319
+
320
+ # Log incoming conversation to console
321
+ logger.info("="*20)
322
+ for i, message in enumerate(request.messages):
323
+ logger.info(f"[{message.role.upper()}]: {message.content}")
324
+ logger.info("-"*20)
325
+
326
+ # Acquire a worker from the pool (will wait if all are busy)
327
+ worker_pool = app.state.worker_pool
328
+ worker = await worker_pool.acquire_worker()
329
+
330
+ try:
331
+ # Build conversation tokens
332
+ bos = worker.tokenizer.get_bos_token_id()
333
+ user_start = worker.tokenizer.encode_special("<|user_start|>")
334
+ user_end = worker.tokenizer.encode_special("<|user_end|>")
335
+ assistant_start = worker.tokenizer.encode_special("<|assistant_start|>")
336
+ assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
337
+
338
+ conversation_tokens = [bos]
339
+ for message in request.messages:
340
+ if message.role == "user":
341
+ conversation_tokens.append(user_start)
342
+ conversation_tokens.extend(worker.tokenizer.encode(message.content))
343
+ conversation_tokens.append(user_end)
344
+ elif message.role == "assistant":
345
+ conversation_tokens.append(assistant_start)
346
+ conversation_tokens.extend(worker.tokenizer.encode(message.content))
347
+ conversation_tokens.append(assistant_end)
348
+
349
+ conversation_tokens.append(assistant_start)
350
+
351
+ # Streaming response with worker release after completion
352
+ response_tokens = []
353
+ async def stream_and_release():
354
+ try:
355
+ async for chunk in generate_stream(
356
+ worker,
357
+ conversation_tokens,
358
+ temperature=request.temperature,
359
+ max_new_tokens=request.max_tokens,
360
+ top_k=request.top_k
361
+ ):
362
+ # Accumulate response for logging
363
+ chunk_data = json.loads(chunk.replace("data: ", "").strip())
364
+ if "token" in chunk_data:
365
+ response_tokens.append(chunk_data["token"])
366
+ yield chunk
367
+ finally:
368
+ # Log the assistant response to console
369
+ full_response = "".join(response_tokens)
370
+ logger.info(f"[ASSISTANT] (GPU {worker.gpu_id}): {full_response}")
371
+ logger.info("="*20)
372
+ # Release worker back to pool after streaming is done
373
+ await worker_pool.release_worker(worker)
374
+
375
+ return StreamingResponse(
376
+ stream_and_release(),
377
+ media_type="text/event-stream"
378
+ )
379
+ except Exception as e:
380
+ # Make sure to release worker even on error
381
+ await worker_pool.release_worker(worker)
382
+ raise e
383
+
384
+ @app.get("/health")
385
+ async def health():
386
+ """Health check endpoint."""
387
+ worker_pool = getattr(app.state, 'worker_pool', None)
388
+ return {
389
+ "status": "ok",
390
+ "ready": worker_pool is not None and len(worker_pool.workers) > 0,
391
+ "num_gpus": worker_pool.num_gpus if worker_pool else 0,
392
+ "available_workers": worker_pool.available_workers.qsize() if worker_pool else 0
393
+ }
394
+
395
+ @app.get("/stats")
396
+ async def stats():
397
+ """Get worker pool statistics."""
398
+ worker_pool = app.state.worker_pool
399
+ return {
400
+ "total_workers": len(worker_pool.workers),
401
+ "available_workers": worker_pool.available_workers.qsize(),
402
+ "busy_workers": len(worker_pool.workers) - worker_pool.available_workers.qsize(),
403
+ "workers": [
404
+ {
405
+ "gpu_id": w.gpu_id,
406
+ "device": str(w.device)
407
+ } for w in worker_pool.workers
408
+ ]
409
+ }
410
+
411
+ if __name__ == "__main__":
412
+ import uvicorn
413
+ print(f"Starting NanoChat Web Server")
414
+ print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}")
415
+ uvicorn.run(app, host=args.host, port=args.port)
scripts/mid_train.py ADDED
@@ -0,0 +1,311 @@
1
+ """
2
+ Midtrain the model. Same as pretraining but simpler.
3
+ Run as:
4
+
5
+ python -m scripts.mid_train
6
+
7
+ Or torchrun for training:
8
+
9
+ torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16
10
+ """
11
+
12
+ from collections import deque
13
+ import os
14
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
15
+ import time
16
+ import wandb
17
+ import torch
18
+ from contextlib import nullcontext
19
+ from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
20
+ from nanochat.tokenizer import get_token_bytes
21
+ from nanochat.checkpoint_manager import save_checkpoint
22
+ from nanochat.loss_eval import evaluate_bpb
23
+ from nanochat.checkpoint_manager import load_model
24
+ import torch.distributed as dist
25
+
26
+ from tasks.common import TaskMixture
27
+ from tasks.gsm8k import GSM8K
28
+ from tasks.mmlu import MMLU
29
+ from tasks.smoltalk import SmolTalk
30
+ from tasks.customjson import CustomJSON
31
+ from tasks.spellingbee import SimpleSpelling, SpellingBee
32
+
33
+ # -----------------------------------------------------------------------------
34
+ run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
35
+ device_type = "" # cuda|cpu|mps (empty => autodetect)
36
+ model_tag = None # model tag to load the model from (base model or midtrained model)
37
+ step = None # step to load the model from (base model or midtrained model)
38
+ dtype = "bfloat16"
39
+ num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
40
+ max_seq_len = 2048
41
+ device_batch_size = 32
42
+ unembedding_lr = 0.004
43
+ embedding_lr = 0.2
44
+ matrix_lr = 0.02
45
+ init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate
46
+ weight_decay = 0.0
47
+ eval_every = 150 # -1 = disable
48
+ eval_tokens = 20*524288
49
+ total_batch_size = 524288
50
+ dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report
51
+ config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
52
+ exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
53
+ user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
54
+ # -----------------------------------------------------------------------------
55
+
56
+ # Compute init
57
+ device_type = autodetect_device_type() if device_type == "" else device_type
58
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
59
+ master_process = ddp_rank == 0
60
+ autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
61
+ synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
62
+ get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
63
+
64
+ # wandb logging init
65
+ use_dummy_wandb = run == "dummy" or not master_process
66
+ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mid", name=run, config=user_config)
67
+
68
+ # Load the model and tokenizer
69
+ model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step)
70
+ pretrain_batch_size = meta.get("device_batch_size", None)
71
+ if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size:
72
+ print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?")
73
+ orig_model = model
74
+ model = torch.compile(model, dynamic=False)
75
+ depth = model.config.n_layer
76
+ num_flops_per_token = model.estimate_flops()
77
+ tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
78
+ world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
79
+ assert total_batch_size % world_tokens_per_fwdbwd == 0
80
+ grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
81
+ print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
82
+ print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
83
+ print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
84
+ token_bytes = get_token_bytes(device=device)
85
+
86
+ # Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
87
+ optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
88
+ adamw_optimizer, muon_optimizer = optimizers
89
+ # Override the initial learning rate as a fraction of the base learning rate
90
+ for opt in optimizers:
91
+ for group in opt.param_groups:
92
+ group["lr"] = group["lr"] * init_lr_frac
93
+ group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
94
+
95
+ # Midtraining data mixture and DataLoader
96
+ base_dir = get_base_dir()
97
+ identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
98
+ train_dataset = TaskMixture([
99
+ SmolTalk(split="train"), # 460K rows of general conversations
100
+ MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
101
+ GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
102
+ CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
103
+ CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
104
+ SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
105
+ SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
106
+ ]) # total: 460K + 100K + 8K + 2x1K + 200K + 80K = 850K rows
107
+ val_dataset = TaskMixture([
108
+ SmolTalk(split="test"), # 24K rows in test set
109
+ MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
110
+ GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios
111
+ ]) # total: 24K + 5.2K + 0.42K ~= 29.6K rows (after clipping MMLU/GSM8K to match train ratios)
112
+ # DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len)
113
+ # A big problem is that we don't know the final num_iterations in advance. So we create
114
+ # these two global variables and update them from within the data generator.
115
+ last_step = False # we will toggle this to True when we reach the end of the dataset
116
+ approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
117
+ def mid_data_generator(split):
118
+ global last_step, approx_progress
119
+ assert split in {"train", "val"}, "split must be 'train' or 'val'"
120
+ dataset = train_dataset if split == "train" else val_dataset
121
+ dataset_size = len(dataset)
122
+ assert dataset_size > 0
123
+ needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets
124
+ token_buffer = deque()
125
+ # CUDA supports memory pinning for faster transfers between CPU and GPU:
126
+ scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=(device_type == "cuda"))
127
+ cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
128
+ it = 0 # iteration counter
129
+ while True:
130
+ # Accumulate enough tokens for one iteration before yielding
131
+ while len(token_buffer) < needed_tokens:
132
+ conversation = dataset[cursor]
133
+ ids, _ = tokenizer.render_conversation(conversation)
134
+ token_buffer.extend(ids)
135
+ cursor += ddp_world_size
136
+ if cursor >= dataset_size:
137
+ cursor -= dataset_size # wrap around for another epoch
138
+ if split == "train":
139
+ last_step = True # toggle last_step to True, which will terminate the training loop
140
+ # Stopping condition to respect num_iterations, if given
141
+ it += 1
142
+ if num_iterations > 0 and it >= num_iterations:
143
+ last_step = True # toggle last_step to True, which will terminate the training loop
144
+ # Build up inputs/targets and yield
145
+ for i in range(needed_tokens):
146
+ scratch[i] = token_buffer.popleft()
147
+ inputs_cpu = scratch[:-1].to(dtype=torch.int32)
148
+ targets_cpu = scratch[1:]
149
+ inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
150
+ targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
151
+ if split == "train":
152
+ if num_iterations > 0:
153
+ approx_progress = it / num_iterations # calculate progress from the max number of iterations
154
+ else:
155
+ approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
156
+ yield inputs, targets
157
+
158
+ train_loader = mid_data_generator("train")
159
+ build_val_loader = lambda: mid_data_generator("val")
160
+ progress = 0 # will go from 0 to 1 over the course of the epoch
161
+
162
+ # Learning rate scheduler
163
+ def get_lr_multiplier(progress):
164
+ # first 80% of training: no decay, then linearly ramp down to 0.
165
+ return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2
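+ # (e.g. the multiplier is 1.0 at progress 0.5, 0.5 at progress 0.9, and 0.0 at progress 1.0)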
166
+
167
+ # Momentum scheduler for Muon optimizer
168
+ def get_muon_momentum(it):
169
+ frac = min(it / 300, 1)
170
+ momentum = (1 - frac) * 0.85 + frac * 0.95
171
+ return momentum
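+ # (warms Muon momentum linearly from 0.85 at step 0 to 0.95 by step 300, then holds it there)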
172
+
173
+ # -----------------------------------------------------------------------------
174
+ # Training loop
175
+ x, y = next(train_loader) # prefetch the very first batch of data
176
+ min_val_bpb = float("inf")
177
+ smooth_train_loss = 0 # EMA of training loss
178
+ ema_beta = 0.9 # EMA decay factor
179
+ total_training_time = 0 # total wall-clock time of training
180
+ step = 0
181
+ while True:
182
+ flops_so_far = num_flops_per_token * total_batch_size * step
183
+
184
+ # Synchronize last_step across all ranks to avoid hangs in the distributed setting
185
+ if ddp:
186
+ last_step_tensor = torch.tensor(last_step, dtype=torch.int32, device=device)
187
+ dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX)
188
+ last_step = bool(last_step_tensor.item())
189
+
190
+ # once in a while: evaluate the val bpb (all ranks participate)
191
+ if eval_every > 0 and (last_step or step % eval_every == 0):
192
+ model.eval()
193
+ val_loader = build_val_loader()
194
+ eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
195
+ with autocast_ctx:
196
+ val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
197
+ print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
198
+ if val_bpb < min_val_bpb:
199
+ min_val_bpb = val_bpb
200
+ wandb_run.log({
201
+ "step": step,
202
+ "total_training_flops": flops_so_far,
203
+ "total_training_time": total_training_time,
204
+ "val/bpb": val_bpb,
205
+ })
206
+ model.train()
207
+
208
+ # save checkpoint at the end of the run (only on master process)
209
+ if master_process and last_step and not dry_run:
210
+ output_dirname = f"d{depth}" # e.g. d12
211
+ checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname)
212
+ save_checkpoint(
213
+ checkpoint_dir,
214
+ step,
215
+ orig_model.state_dict(),
216
+ [opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
217
+ {
218
+ "step": step,
219
+ "val_bpb": val_bpb, # loss at last step
220
+ "model_config": {
221
+ "sequence_len": max_seq_len,
222
+ "vocab_size": tokenizer.get_vocab_size(),
223
+ "n_layer": depth,
224
+ "n_head": model.config.n_head,
225
+ "n_kv_head": model.config.n_kv_head,
226
+ "n_embd": model.config.n_embd,
227
+ },
228
+ "user_config": user_config, # inputs to the training script
229
+ }
230
+ )
231
+
232
+ if last_step:
233
+ break
234
+
235
+ # -------------------------------------------------------------------------
236
+ # single training step
237
+ # evaluate the gradient
238
+ synchronize()
239
+ t0 = time.time()
240
+ for micro_step in range(grad_accum_steps):
241
+ with autocast_ctx:
242
+ loss = model(x, y)
243
+ train_loss = loss.detach() # for logging
244
+ loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
245
+ loss.backward()
246
+ x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
247
+ progress = max(progress, approx_progress) # only increase progress monotonically
248
+ # step the optimizers
249
+ lrm = get_lr_multiplier(progress)
250
+ for opt in optimizers:
251
+ for group in opt.param_groups:
252
+ group["lr"] = group["initial_lr"] * lrm
253
+ muon_momentum = get_muon_momentum(step)
254
+ for group in muon_optimizer.param_groups:
255
+ group["momentum"] = muon_momentum
256
+ for opt in optimizers:
257
+ opt.step()
258
+ model.zero_grad(set_to_none=True)
259
+ synchronize()
260
+ t1 = time.time()
261
+ dt = t1 - t0
262
+ # -------------------------------------------------------------------------
263
+
264
+ # State
265
+ step += 1
266
+
267
+ # logging
268
+ smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
269
+ debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
270
+ pct_done = 100 * progress
271
+ tok_per_sec = int(total_batch_size / dt)
272
+ flops_per_sec = num_flops_per_token * total_batch_size / dt
273
+ promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
274
+ mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
275
+ if step > 10:
276
+ total_training_time += dt # only count the time after the first 10 steps
277
+ print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
278
+ if step % 10 == 0:
279
+ wandb_run.log({
280
+ "step": step,
281
+ "total_training_flops": flops_so_far,
282
+ "total_training_time": total_training_time,
283
+ "train/loss": debiased_smooth_loss,
284
+ "train/lrm": lrm,
285
+ "train/dt": dt,
286
+ "train/tok_per_sec": tok_per_sec,
287
+ "train/mfu": mfu,
288
+ })
289
+
290
+ # print a few more stats
291
+ print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
292
+ print0(f"Total training time: {total_training_time/60:.2f}m")
293
+ print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
294
+
295
+ # Log to report
296
+ if not dry_run:
297
+ from nanochat.report import get_report
298
+ get_report().log(section="Midtraining", data=[
299
+ user_config, # CLI args
300
+ { # stats about the training setup
301
+ "Number of iterations": step,
302
+ "DDP world size": ddp_world_size,
303
+ },
304
+ { # stats about training outcomes
305
+ "Minimum validation bpb": min_val_bpb,
306
+ }
307
+ ])
308
+
309
+ # cleanup
310
+ wandb_run.finish() # wandb run finish
311
+ compute_cleanup()
scripts/tok_eval.py ADDED
@@ -0,0 +1,265 @@
1
+ """
2
+ Evaluate compression ratio of the tokenizer.
3
+ """
4
+
5
+ from nanochat.tokenizer import get_tokenizer, RustBPETokenizer
6
+ from nanochat.dataset import parquets_iter_batched
7
+
8
+ # Random text I got from a random website this morning
9
+ news_text = r"""
10
+ (Washington, D.C., July 9, 2025)- Yesterday, Mexico’s National Service of Agro-Alimentary Health, Safety, and Quality (SENASICA) reported a new case of New World Screwworm (NWS) in Ixhuatlan de Madero, Veracruz in Mexico, which is approximately 160 miles northward of the current sterile fly dispersal grid, on the eastern side of the country and 370 miles south of the U.S./Mexico border. This new northward detection comes approximately two months after northern detections were reported in Oaxaca and Veracruz, less than 700 miles away from the U.S. border, which triggered the closure of our ports to Mexican cattle, bison, and horses on May 11, 2025.
11
+
12
+ While USDA announced a risk-based phased port re-opening strategy for cattle, bison, and equine from Mexico beginning as early as July 7, 2025, this newly reported NWS case raises significant concern about the previously reported information shared by Mexican officials and severely compromises the outlined port reopening schedule of five ports from July 7-September 15. Therefore, in order to protect American livestock and our nation’s food supply, Secretary Rollins has ordered the closure of livestock trade through southern ports of entry effective immediately.
13
+
14
+ “The United States has promised to be vigilant — and after detecting this new NWS case, we are pausing the planned port reopening’s to further quarantine and target this deadly pest in Mexico. We must see additional progress combatting NWS in Veracruz and other nearby Mexican states in order to reopen livestock ports along the Southern border,” said U.S. Secretary of Agriculture Brooke L. Rollins. “Thanks to the aggressive monitoring by USDA staff in the U.S. and in Mexico, we have been able to take quick and decisive action to respond to the spread of this deadly pest.”
15
+ """.strip()
16
+
17
+ # Random Korean text (to test non-English compression)
18
+ korean_text = r"""
19
+ 정직한 사실 위에, 공정한 시선을 더하다
20
+ Herald Korea Times
21
+
22
+ 헤럴드코리아타임즈는 정치, 경제, 사회, 문화 등 한국 사회 전반의 주요 이슈를 심도 있게 다루는 종합 온라인 신문사입니다.
23
+
24
+ 우리는 단순히 뉴스를 전달하는 것이 아니라, 사실(Fact)에 기반한 양측의 시각을 균형 있게 조명하며, 독자 여러분이 스스로 판단할 수 있는 ‘정보의 균형’을 제공합니다.
25
+
26
+ 한국 언론의 오랜 문제로 지적되어 온 정치적 편향, 이념적 왜곡에서 벗어나
27
+ 오직 정직함과 공정함을 원칙으로 삼는 언론을 지향합니다.
28
+ 어느 한쪽의 주장만을 확대하거나 감추지 않고,
29
+ **모든 쟁점에 대해 ‘무엇이 쟁점인지’, ‘누가 무엇을 주장하는지’, ‘사실은 무엇인지’**를 명확히 전달하는 데 집중합니다.
30
+ """.strip()
31
+
32
+ # Random piece of code
33
+ code_text = r"""
34
+ class BasicTokenizer(Tokenizer):
35
+
36
+ def __init__(self):
37
+ super().__init__()
38
+
39
+ def train(self, text, vocab_size, verbose=False):
40
+ assert vocab_size >= 256
41
+ num_merges = vocab_size - 256
42
+
43
+ # input text preprocessing
44
+ text_bytes = text.encode("utf-8") # raw bytes
45
+ ids = list(text_bytes) # list of integers in range 0..255
46
+
47
+ # iteratively merge the most common pairs to create new tokens
48
+ merges = {} # (int, int) -> int
49
+ vocab = {idx: bytes([idx]) for idx in range(256)} # int -> bytes
50
+ for i in range(num_merges):
51
+ # count up the number of times every consecutive pair appears
52
+ stats = get_stats(ids)
53
+ # find the pair with the highest count
54
+ pair = max(stats, key=stats.get)
55
+ # mint a new token: assign it the next available id
56
+ idx = 256 + i
57
+ # replace all occurrences of pair in ids with idx
58
+ ids = merge(ids, pair, idx)
59
+ # save the merge
60
+ merges[pair] = idx
61
+ vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
62
+ # prints
63
+ if verbose:
64
+ print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
65
+ """.strip()
66
+
67
+ math_text = r"""
68
+ \documentclass[12pt]{article}
69
+ \usepackage{amsmath,amsthm,amssymb}
70
+ \usepackage[margin=1in]{geometry}
71
+
72
+ \newtheorem{theorem}{Theorem}
73
+ \newtheorem*{remark}{Remark}
74
+
75
+ \begin{document}
76
+
77
+ \begin{center}
78
+ {\Large A Cute Identity: The Sum of Cubes is a Square}
79
+ \end{center}
80
+
81
+ \begin{theorem}
82
+ For every integer $n \ge 1$,
83
+ \[
84
+ \sum_{k=1}^{n} k^{3} \;=\; \left(\frac{n(n+1)}{2}\right)^{2}.
85
+ \]
86
+ \end{theorem}
87
+
88
+ \begin{proof}[Proof 1 (Induction)]
89
+ Let $S(n) = \sum_{k=1}^{n} k^3$. For $n=1$, $S(1)=1=(1\cdot 2/2)^2$, so the base case holds.
90
+
91
+ Assume $S(n)=\big(\tfrac{n(n+1)}{2}\big)^2$ for some $n\ge 1$.
92
+ Then
93
+ \[
94
+ S(n+1)
95
+ = S(n) + (n+1)^3
96
+ = \left(\frac{n(n+1)}{2}\right)^2 + (n+1)^3.
97
+ \]
98
+ Factor out $(n+1)^2$:
99
+ \[
100
+ S(n+1)
101
+ = (n+1)^2\left( \frac{n^2}{4} + (n+1) \right)
102
+ = (n+1)^2\left( \frac{n^2 + 4n + 4}{4} \right)
103
+ = (n+1)^2\left( \frac{(n+2)^2}{4} \right).
104
+ \]
105
+ Thus
106
+ \[
107
+ S(n+1)=\left(\frac{(n+1)(n+2)}{2}\right)^2,
108
+ \]
109
+ which matches the claimed formula with $n$ replaced by $n+1$. By induction, the identity holds for all $n\ge 1$.
110
+ \end{proof}
111
+
112
+ \begin{proof}[Proof 2 (Algebraic telescoping)]
113
+ Recall the binomial identity
114
+ \[
115
+ (k+1)^4 - k^4 = 4k^3 + 6k^2 + 4k + 1.
116
+ \]
117
+ Summing both sides from $k=0$ to $n$ telescopes:
118
+ \[
119
+ (n+1)^4 - 0^4
120
+ = \sum_{k=0}^{n}\big(4k^3 + 6k^2 + 4k + 1\big)
121
+ = 4\sum_{k=1}^{n}k^3 + 6\sum_{k=1}^{n}k^2 + 4\sum_{k=1}^{n}k + (n+1).
122
+ \]
123
+ Using the standard sums
124
+ \[
125
+ \sum_{k=1}^{n}k = \frac{n(n+1)}{2}
126
+ \quad\text{and}\quad
127
+ \sum_{k=1}^{n}k^2 = \frac{n(n+1)(2n+1)}{6},
128
+ \]
129
+ solve for $\sum_{k=1}^{n}k^3$ to get
130
+ \[
131
+ \sum_{k=1}^{n}k^3 = \left(\frac{n(n+1)}{2}\right)^2.
132
+ \]
133
+ \end{proof}
134
+
135
+ \begin{remark}
136
+ Geometrically, the identity says: ``adding up $1^3,2^3,\dots,n^3$ builds a perfect square’’—namely the square of the $n$th triangular number. This is why one sometimes calls it the \emph{sum-of-cubes is a square} phenomenon.
137
+ \end{remark}
138
+
139
+ \end{document}
140
+ """.strip()
141
+
142
+ science_text = r"""
143
+ Photosynthesis is a photochemical energy transduction process in which light-harvesting pigment–protein complexes within the thylakoid membranes of oxygenic phototrophs absorb photons and initiate charge separation at the reaction center, driving the linear electron transport chain from water to NADP⁺ via photosystem II, the cytochrome b₆f complex, and photosystem I, concomitantly generating a trans-thylakoid proton motive force utilized by chloroplastic ATP synthase. The light-dependent reactions produce ATP and NADPH, which fuel the Calvin–Benson–Bassham cycle in the stroma, wherein ribulose-1,5-bisphosphate is carboxylated by ribulose-1,5-bisphosphate carboxylase/oxygenase (RuBisCO) to form 3-phosphoglycerate, subsequently reduced and regenerated through a series of enzymatic steps, enabling net assimilation of CO₂ into triose phosphates and ultimately carbohydrates. This process is tightly regulated by photoprotective mechanisms, redox feedback, and metabolite flux, representing a central biochemical pathway coupling solar energy capture to the biosphere’s primary productivity.
144
+ """.strip()
145
+
146
+ # The tokenizer was trained on data from earlier shards, so it has seen this data
147
+ train_docs = next(parquets_iter_batched(split="train"))
148
+ train_text = "\n".join(train_docs)
149
+ val_docs = next(parquets_iter_batched(split="val"))
150
+ val_text = "\n".join(val_docs)
151
+
152
+ all_text = [
153
+ ("news", news_text),
154
+ ("korean", korean_text),
155
+ ("code", code_text),
156
+ ("math", math_text),
157
+ ("science", science_text),
158
+ ("fwe-train", train_text),
159
+ ]
160
+ if val_text:
161
+ all_text.append(("fwe-val", val_text))
162
+
163
+ # Try out current default compared to GPT-2 and GPT-4 tokenizers
164
+ tokenizer_results = {}
165
+ vocab_sizes = {}
166
+
167
+ for tokenizer_name in ["gpt2", "gpt4", "ours"]:
168
+
169
+ if tokenizer_name == "gpt2":
170
+ tokenizer = RustBPETokenizer.from_pretrained("gpt2") # gpt-2 base model tokenizer
171
+ elif tokenizer_name == "gpt4":
172
+ tokenizer = RustBPETokenizer.from_pretrained("cl100k_base") # gpt-4 base model tokenizer
173
+ else:
174
+ tokenizer = get_tokenizer()
175
+
176
+ vocab_sizes[tokenizer_name] = tokenizer.get_vocab_size()
177
+ tokenizer_results[tokenizer_name] = {}
178
+
179
+ for name, text in all_text:
180
+ encoded = tokenizer.encode(text)
181
+ decoded = tokenizer.decode(encoded)
182
+ assert decoded == text
183
+
184
+ encoded_bytes = text.encode('utf-8')
185
+ ratio = len(encoded_bytes) / len(encoded)
186
+ tokenizer_results[tokenizer_name][name] = {
187
+ 'bytes': len(encoded_bytes),
188
+ 'tokens': len(encoded),
189
+ 'ratio': ratio
190
+ }
191
+
192
+ # ANSI color codes
193
+ GREEN = '\033[92m'
194
+ RED = '\033[91m'
195
+ RESET = '\033[0m'
196
+
197
+ # Print vocab sizes
198
+ print(f"\nVocab sizes:")
199
+ print(f"GPT-2: {vocab_sizes['gpt2']}")
200
+ print(f"GPT-4: {vocab_sizes['gpt4']}")
201
+ print(f"Ours: {vocab_sizes['ours']}")
202
+
203
+ def print_comparison(baseline_name, baseline_results, ours_results, all_text):
204
+ """Print comparison table between baseline tokenizer and ours."""
205
+ print(f"\nComparison with {baseline_name}:")
206
+ print("=" * 95)
207
+ print(f"{'Text Type':<10} {'Bytes':<8} {baseline_name:<15} {'Ours':<15} {'Relative':<12} {'Better':<10}")
208
+ print(f"{'':10} {'':8} {'Tokens':<7} {'Ratio':<7} {'Tokens':<7} {'Ratio':<7} {'Diff %':<12}")
209
+ print("-" * 95)
210
+
211
+ for name, text in all_text:
212
+ baseline_data = baseline_results[name]
213
+ ours_data = ours_results[name]
214
+
215
+ # Calculate relative difference (positive means ours is better, negative means worse)
216
+ # Using tokens: fewer tokens is better, so we calculate (baseline_tokens - ours_tokens) / baseline_tokens
217
+ relative_diff = ((baseline_data['tokens'] - ours_data['tokens']) / baseline_data['tokens']) * 100
218
+
219
+ # Determine which has better compression (higher ratio = better)
220
+ if baseline_data['ratio'] > ours_data['ratio']:
221
+ baseline_color, ours_color = GREEN, RED
222
+ better = baseline_name
223
+ diff_color = RED
224
+ elif ours_data['ratio'] > baseline_data['ratio']:
225
+ baseline_color, ours_color = RED, GREEN
226
+ better = "Ours"
227
+ diff_color = GREEN
228
+ else:
229
+ baseline_color, ours_color = "", ""
230
+ better = "Tie"
231
+ diff_color = ""
232
+
233
+ print(f"{name:<10} {baseline_data['bytes']:<8} "
234
+ f"{baseline_color}{baseline_data['tokens']:<7}{RESET} "
235
+ f"{baseline_color}{baseline_data['ratio']:<7.2f}{RESET} "
236
+ f"{ours_color}{ours_data['tokens']:<7}{RESET} "
237
+ f"{ours_color}{ours_data['ratio']:<7.2f}{RESET} "
238
+ f"{diff_color}{relative_diff:+7.1f}%{RESET} "
239
+ f"{better:<10}")
240
+
241
+ # Print comparisons
242
+ print_comparison("GPT-2", tokenizer_results['gpt2'], tokenizer_results['ours'], all_text)
243
+ print_comparison("GPT-4", tokenizer_results['gpt4'], tokenizer_results['ours'], all_text)
244
+
245
+ # Log to report
246
+ from nanochat.report import get_report
247
+ lines = []
248
+ for baseline_name in ["GPT-2", "GPT-4"]:
249
+ baseline_key = baseline_name.lower().replace('-', '')
250
+ baseline_results = tokenizer_results[baseline_key]
251
+ ours_results = tokenizer_results['ours']
252
+ lines.append(f"### Comparison with {baseline_name}")
253
+ lines.append("")
254
+ lines.append("| Text Type | Bytes | " + baseline_name + " Tokens | " + baseline_name + " Ratio | Ours Tokens | Ours Ratio | Relative Diff % |")
255
+ lines.append("|-----------|-------|--------------|--------------|-------------|------------|-----------------|")
256
+ for name, text in all_text:
257
+ baseline_data = baseline_results[name]
258
+ ours_data = ours_results[name]
259
+ relative_diff = ((baseline_data['tokens'] - ours_data['tokens']) / baseline_data['tokens']) * 100
260
+ lines.append(f"| {name} | {baseline_data['bytes']} | {baseline_data['tokens']} | {baseline_data['ratio']:.2f} | {ours_data['tokens']} | {ours_data['ratio']:.2f} | {relative_diff:+.1f}% |")
261
+ lines.append("")
262
+ report_markdown = "\n".join(lines)
263
+ get_report().log(section="Tokenizer evaluation", data=[
264
+ report_markdown,
265
+ ])
scripts/tok_train.py ADDED
@@ -0,0 +1,106 @@
1
+ """
2
+ Train a tokenizer using the HuggingFace Tokenizers library.
3
+ In the style of the GPT-4 tokenizer.
4
+ """
5
+ import os
6
+ import time
7
+ import argparse
8
+ import torch
9
+ from nanochat.tokenizer import RustBPETokenizer
10
+ from nanochat.common import get_base_dir
11
+ from nanochat.dataset import parquets_iter_batched
12
+
13
+ # -----------------------------------------------------------------------------
14
+ # Parse command line arguments
15
+
16
+ parser = argparse.ArgumentParser(description='Train a BPE tokenizer')
17
+ parser.add_argument('--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)')
18
+ parser.add_argument('--doc_cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
19
+ parser.add_argument('--vocab_size', type=int, default=65536, help='Vocabulary size (default: 65536 = 2^16)')
20
+ args = parser.parse_args()
21
+ print(f"max_chars: {args.max_chars:,}")
22
+ print(f"doc_cap: {args.doc_cap:,}")
23
+ print(f"vocab_size: {args.vocab_size:,}")
24
+
25
+ # -----------------------------------------------------------------------------
26
+ # Text iterator
27
+
28
+ def text_iterator():
29
+ """
30
+ 1) Flatten the batches into a single iterator
31
+ 2) Crop every document to args.doc_cap characters
32
+ 3) Break when we've seen args.max_chars characters
33
+ """
34
+ nchars = 0
35
+ for batch in parquets_iter_batched(split="train"):
36
+ for doc in batch:
37
+ doc_text = doc
38
+ if len(doc_text) > args.doc_cap:
39
+ doc_text = doc_text[:args.doc_cap]
40
+ nchars += len(doc_text)
41
+ yield doc_text
42
+ if nchars > args.max_chars:
43
+ return
44
+ text_iter = text_iterator()
45
+
46
+ # -----------------------------------------------------------------------------
47
+ # Train the tokenizer
48
+ t0 = time.time()
49
+ tokenizer = RustBPETokenizer.train_from_iterator(text_iter, args.vocab_size)
50
+ t1 = time.time()
51
+ train_time = t1 - t0
52
+ print(f"Training time: {train_time:.2f}s")
53
+
54
+ # -----------------------------------------------------------------------------
55
+ # Save the tokenizer to disk
56
+ base_dir = get_base_dir()
57
+ tokenizer_dir = os.path.join(base_dir, "tokenizer")
58
+ tokenizer.save(tokenizer_dir)
59
+
60
+ # -----------------------------------------------------------------------------
61
+ # Quick inline sanity check
62
+ test_text = """Hello world! This is a test.
63
+ Numbers: 123, 4567, 89
64
+ Contractions: I'm, you're, it's
65
+ Special chars: @#$%^&*()
66
+ Unicode: 你好世界 🌍"""
67
+ encoded = tokenizer.encode(test_text)
68
+ decoded = tokenizer.decode(encoded)
69
+ assert decoded == test_text
70
+
71
+ # -----------------------------------------------------------------------------
72
+ # One more thing: we wish to cache a mapping from token id to number of bytes of that token
73
+ # for efficient evaluation of bits per byte. Unlike the typical mean loss, this
74
+ # allows us to report a loss that is invariant to the vocab size of the tokenizer.
75
+ # The bits per byte on the validation set is then one of the primary metrics we care about.
76
+ vocab_size = tokenizer.get_vocab_size()
77
+ special_set = set(tokenizer.get_special_tokens())
78
+ token_strings = [tokenizer.decode([token_id]) for token_id in range(vocab_size)]
79
+ token_bytes = []
80
+ for token_id in range(vocab_size):
81
+ token_str = token_strings[token_id] # the Python string representation of this token
82
+ if token_str in special_set:
83
+ token_bytes.append(0) # special characters are not counted
84
+ else:
85
+ id_bytes = len(token_str.encode("utf-8")) # number of bytes that make up this token
86
+ token_bytes.append(id_bytes)
87
+ token_bytes = torch.tensor(token_bytes, dtype=torch.int32, device='cpu')
88
+ token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
89
+ with open(token_bytes_path, "wb") as f:
90
+ torch.save(token_bytes, f)
91
+ print(f"Saved token_bytes to {token_bytes_path}")
92
+
93
+ # Log to report
94
+ from nanochat.report import get_report
95
+ token_bytes_nonzero = (token_bytes[token_bytes > 0]).to(dtype=torch.float32)
96
+ get_report().log(section="Tokenizer training", data=[
97
+ vars(args), # argparse command line arguments
98
+ {"train_time": train_time},
99
+ {"num_special_tokens": len(special_set)},
100
+ {
101
+ "token_bytes_min": int(token_bytes_nonzero.min().item()),
102
+ "token_bytes_max": int(token_bytes_nonzero.max().item()),
103
+ "token_bytes_mean": token_bytes_nonzero.mean().item(),
104
+ "token_bytes_std": token_bytes_nonzero.std().item(),
105
+ }
106
+ ])
speedrun.log ADDED
The diff for this file is too large to render. See raw diff
 
speedrun.sh ADDED
@@ -0,0 +1,131 @@
1
+ #!/bin/bash
2
+
3
+ # This script is the "Best ChatGPT clone that $100 can buy".
4
+ # It is designed to run in ~4 hours on an 8XH100 node at $3/GPU/hour.
5
+
6
+ # 1) Example launch (simplest):
7
+ # bash speedrun.sh
8
+ # 2) Example launch in a screen session (because the run takes ~4 hours):
9
+ # screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
10
+ # 3) Example launch with wandb logging, but see below for setting up wandb first:
11
+ # WANDB_RUN=speedrun screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
12
+
13
+ # Default intermediate artifacts directory is in ~/.cache/nanochat
14
+ export OMP_NUM_THREADS=1
15
+ export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
16
+ mkdir -p $NANOCHAT_BASE_DIR
17
+
18
+ # -----------------------------------------------------------------------------
19
+ # Python venv setup with uv
20
+
21
+ # install uv (if not already installed)
22
+ command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
23
+ # create a .venv local virtual environment (if it doesn't exist)
24
+ [ -d ".venv" ] || uv venv
25
+ # install the repo dependencies
26
+ uv sync --extra gpu
27
+ # activate venv so that `python` uses the project's venv instead of system python
28
+ source .venv/bin/activate
29
+
30
+ # -----------------------------------------------------------------------------
31
+ # wandb setup
32
+ # If you wish to use wandb for logging (it's nice, and recommended):
33
+ # 1) Make sure to first log in to wandb, e.g. run:
34
+ # `wandb login`
35
+ # 2) Set the WANDB_RUN environment variable when running this script, e.g.:
36
+ # `WANDB_RUN=d26 bash speedrun.sh`
37
+ if [ -z "$WANDB_RUN" ]; then
38
+ # by default use "dummy" : it's handled as a special case, skips logging to wandb
39
+ WANDB_RUN=dummy
40
+ fi
41
+
42
+ # -----------------------------------------------------------------------------
43
+ # During the course of the run, we will be writing markdown reports to the report/
44
+ # directory in the base dir. This command clears it out and writes a header section
45
+ # with a bunch of system info and a timestamp that marks the start of the run.
46
+ python -m nanochat.report reset
47
+
48
+ # -----------------------------------------------------------------------------
49
+ # Tokenizer
50
+
51
+ # Install Rust / Cargo
52
+ curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
53
+ source "$HOME/.cargo/env"
54
+
55
+ # Build the rustbpe Tokenizer
56
+ uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
57
+
58
+ # Download the first ~2B characters of pretraining dataset
59
+ # look at dev/repackage_data_reference.py for details on how this data was prepared
60
+ # each data shard is ~250M chars
61
+ # so we download 2e9 / 250e6 = 8 data shards at this point
62
+ # each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk
63
+ python -m nanochat.dataset -n 8
64
+ # Immediately also kick off downloading more shards in the background while tokenizer trains
65
+ # See comment below for why 240 is the right number here
66
+ python -m nanochat.dataset -n 240 &
67
+ DATASET_DOWNLOAD_PID=$!
68
+ # train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data
69
+ python -m scripts.tok_train --max_chars=2000000000
70
+ # evaluate the tokenizer (report compression ratio etc.)
71
+ python -m scripts.tok_eval
72
+
73
+ # -----------------------------------------------------------------------------
74
+ # Base model (pretraining)
75
+
76
+ # The d20 model is 561M parameters.
77
+ # Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens.
78
+ # Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars.
79
+ # At 250M chars/shard, this is 54B / 250M ~= 216 shards needed for pretraining.
80
+ # Round up to 240 for safety. At ~100MB/shard, this downloads ~24GB of data to disk.
81
+ # (The total number of shards available in the entire dataset is 1822.)
82
+ echo "Waiting for dataset download to complete..."
83
+ wait $DATASET_DOWNLOAD_PID
84
+
85
+ # Number of processes/GPUs to use
86
+ NPROC_PER_NODE=8
87
+
88
+ # pretrain the d20 model
89
+ torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --run=$WANDB_RUN
90
+ # evaluate the model on a larger chunk of train/val data and draw some samples
91
+ torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
92
+ # evaluate the model on CORE tasks
93
+ torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval
94
+
95
+ # -----------------------------------------------------------------------------
96
+ # Midtraining (teach the model conversation special tokens, tool use, multiple choice)
97
+
98
+ # download 2.3MB of synthetic identity conversations to impart a personality to nanochat
99
+ # see dev/gen_sft_data.py for details on how this data was prepared and to get a sense of how you can easily tune it
100
+ curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
101
+
102
+ # run midtraining and eval the model
103
+ torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --run=$WANDB_RUN
104
+ torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid
105
+
106
+ # -----------------------------------------------------------------------------
107
+ # Supervised Finetuning (domain adaptation to each sequence all by itself per row)
108
+
109
+ # train sft and re-eval right away (should see a small bump)
110
+ torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN
111
+ torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft
112
+
113
+ # chat with the model over CLI! Leave out the -p to chat interactively
114
+ # python -m scripts.chat_cli -p "Why is the sky blue?"
115
+
116
+ # even better, chat with your model over a pretty WebUI ChatGPT style
117
+ # python -m scripts.chat_web
118
+
119
+ # -----------------------------------------------------------------------------
120
+ # Reinforcement Learning. Optional, and currently only on GSM8K
121
+ # (optional)
122
+
123
+ # run reinforcement learning
124
+ # torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_rl -- --run=$WANDB_RUN
125
+ # eval the RL model only on GSM8K
126
+ # torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i rl -a GSM8K
127
+
128
+ # -----------------------------------------------------------------------------
129
+ # Generate the full report by putting together all the sections
130
+ # report.md is the output and will be copied to current directory for convenience
131
+ python -m nanochat.report generate
tasks/arc.py ADDED
@@ -0,0 +1,48 @@
1
+ """
2
+ The ARC dataset from Allen AI.
3
+ https://huggingface.co/datasets/allenai/ai2_arc
4
+ """
5
+
6
+ from datasets import load_dataset
7
+ from tasks.common import Task, render_mc
8
+
9
+ class ARC(Task):
10
+
11
+ def __init__(self, subset, split, **kwargs):
12
+ super().__init__(**kwargs)
13
+ assert subset in ["ARC-Easy", "ARC-Challenge"], "ARC subset must be ARC-Easy or ARC-Challenge"
14
+ assert split in ["train", "validation", "test"], "ARC split must be train|validation|test"
15
+ self.ds = load_dataset("allenai/ai2_arc", subset, split=split).shuffle(seed=42)
16
+
17
+ @property
18
+ def eval_type(self):
19
+ return 'categorical'
20
+
21
+ def num_examples(self):
22
+ return len(self.ds)
23
+
24
+ def get_example(self, index):
25
+ row = self.ds[index]
26
+ question = row["question"] # the question text
27
+ choices = row["choices"]["text"] # the text of each choice
28
+ answer_string = row["answerKey"] # e.g. "A", "B", "C", "D"
29
+ letters = row["choices"]["label"] # e.g. ["A", "B", "C", "D"]
30
+ assert answer_string in letters, f"ARC answer {answer_string} must be one of {letters}" # sanity check
31
+ # create and return the Conversation object
32
+ user_message = render_mc(question, letters, choices)
33
+ messages = [
34
+ {"role": "user", "content": user_message},
35
+ {"role": "assistant", "content": answer_string}
36
+ ]
37
+ conversation = {
38
+ "messages": messages,
39
+ "letters": letters, # useful during evaluation, so we can narrow and clamp the assistant prediction to one of the letters
40
+ }
41
+ return conversation
42
+
43
+ def evaluate(self, conversation, assistant_response):
44
+ # the assert here is not strictly speaking needed, but currently the way we eval, we expect this to be true
45
+ # I'm going to leave the assert here to prevent footguns, but possibly in the future can remove it.
46
+ assert assistant_response in conversation['letters'], f"ARC answer {assistant_response} is expected to be one of {conversation['letters']}"
47
+ assistant_message = conversation['messages'][-1]['content'] # e.g. "A"
48
+ return assistant_response == assistant_message
tasks/common.py ADDED
@@ -0,0 +1,147 @@
1
+ """
2
+ Base class for all Tasks.
3
+ A Task is basically a dataset of conversations, together with some
4
+ metadata and often also evaluation criteria.
5
+ Example tasks: MMLU, ARC-Easy, ARC-Challenge, GSM8K, HumanEval, SmolTalk.
6
+ """
7
+
8
+ import random
9
+
10
+ class Task:
11
+ """
12
+ Base class of a Task. Allows for lightweight slicing of the underlying dataset.
13
+ """
14
+
15
+ def __init__(self, start=0, stop=None, step=1):
16
+ # allows a lightweight logical view over a dataset
17
+ assert start >= 0, f"Start must be non-negative, got {start}"
18
+ assert stop is None or stop >= start, f"Stop should be greater than or equal to start, got {stop} and {start}"
19
+ assert step >= 1, f"Step must be strictly positive, got {step}"
20
+ self.start = start
21
+ self.stop = stop # could be None here
22
+ self.step = step
23
+
24
+ @property
25
+ def eval_type(self):
26
+ # one of 'generative' | 'categorical'
27
+ raise NotImplementedError
28
+
29
+ def num_examples(self):
30
+ raise NotImplementedError
31
+
32
+ def get_example(self, index):
33
+ raise NotImplementedError
34
+
35
+ def __len__(self):
36
+ start = self.start
37
+ stop = self.num_examples() if self.stop is None else self.stop
38
+ step = self.step
39
+ span = stop - start
40
+ num = (span + step - 1) // step # ceil_div(span, step)
41
+ assert num >= 0, f"Negative number of examples???: {num}" # prevent footguns
42
+ return num
43
+
44
+ def __getitem__(self, index: int):
45
+ assert isinstance(index, int), f"Index must be an integer, got {type(index)}"
46
+ physical_index = self.start + index * self.step
47
+ conversation = self.get_example(physical_index)
48
+ return conversation
49
+
50
+ def evaluate(self, problem, completion):
51
+ raise NotImplementedError
52
+
53
+
54
+ class TaskMixture(Task):
55
+ """
56
+ For SFT Training it becomes useful to train on a task mixture of datasets.
57
+ Fun trick: if you wish to oversample any task, just pass it in multiple times in the list.
58
+ """
59
+
60
+ def __init__(self, tasks, **kwargs):
61
+ super().__init__(**kwargs)
62
+ # tasks is a list of Task objects
63
+ self.tasks = tasks
64
+ self.lengths = [len(task) for task in self.tasks]
65
+ self.num_conversations = sum(self.lengths)
66
+ # Build list of all (task_idx, local_idx) pairs
67
+ self.index_map = []
68
+ for task_idx, task_length in enumerate(self.lengths):
69
+ for local_idx in range(task_length):
70
+ self.index_map.append((task_idx, local_idx))
71
+ # Deterministically shuffle to mix tasks throughout training
72
+ rng = random.Random(42)
73
+ rng.shuffle(self.index_map)
74
+ # Note: this is not the most elegant or best solution, but it's ok for now
75
+
76
+ def num_examples(self):
77
+ return self.num_conversations
78
+
79
+ def get_example(self, index):
80
+ """
81
+ Access conversations according to a deterministic shuffle of all examples.
82
+ This ensures tasks are mixed throughout training, regardless of dataset size.
83
+ """
84
+ assert 0 <= index < self.num_conversations, f"Index {index} out of range for mixture with {self.num_conversations} conversations"
85
+ task_idx, local_idx = self.index_map[index]
86
+ return self.tasks[task_idx][local_idx]
87
+
88
+
89
+ class TaskSequence(Task):
90
+ """
91
+ For SFT Training sometimes we want to sequentially train on a list of tasks.
92
+ This is useful for cases that require a training curriculum.
93
+ """
94
+
95
+ def __init__(self, tasks, **kwargs):
96
+ super().__init__(**kwargs)
97
+ self.tasks = tasks
98
+ self.lengths = [len(task) for task in self.tasks]
99
+ self.num_conversations = sum(self.lengths)
100
+
101
+ def num_examples(self):
102
+ return self.num_conversations
103
+
104
+ def get_example(self, index):
105
+ assert 0 <= index < self.num_conversations, f"Index {index} out of range for sequence with {self.num_conversations} conversations"
106
+ for task_idx, task_length in enumerate(self.lengths):
107
+ if index < task_length:
108
+ return self.tasks[task_idx][index]
109
+ index -= task_length
110
+
111
+
112
+ def render_mc(question, letters, choices):
113
+ """
114
+ The common multiple choice rendering format we will use.
115
+
116
+ Note two important design decisions:
117
+ 1)
118
+ Bigger models don't care as much, but smaller models prefer to have
119
+ the letter *after* the choice, which results in better binding.
120
+ 2)
121
+ There is no whitespace between the delimiter (=) and the letter.
122
+ This is actually critical because the tokenizer has different token ids
123
+ for " A" vs. "A". The assistant responses will be just the letter itself,
124
+ i.e. "A", so it is important that here in the prompt it is the exact same
125
+ token, i.e. "A" with no whitespace before it. Again, bigger models don't care
126
+ about this too much, but smaller models do care about some of these details.
127
+ """
128
+ query = f"Multiple Choice question: {question}\n"
129
+ query += "".join([f"- {choice}={letter}\n" for letter, choice in zip(letters, choices)])
130
+ query += "\nRespond only with the letter of the correct answer."
131
+ return query
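+ # For illustration, with a made-up question this renders as:
+ #
+ #   Multiple Choice question: What color is the sky?
+ #   - Red=A
+ #   - Blue=B
+ #   - Green=C
+ #
+ #   Respond only with the letter of the correct answer.
+ #
+ # i.e. each choice comes first and its letter follows "=" with no preceding whitespace,
+ # matching the bare-letter assistant responses ("A", "B", ...).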
132
+
133
+
134
+ if __name__ == "__main__":
135
+ # very lightweight test of slicing
136
+ from tasks.mmlu import MMLU
137
+
138
+ ds = MMLU(subset="auxiliary_train", split="train")
139
+ print("Length of MMLU: ", len(ds))
140
+ ex = ds[5]
141
+ print("5th example: ", ex)
142
+
143
+ ds = MMLU(subset="auxiliary_train", split="train", start=5, stop=10)
144
+ print("Length of sliced MMLU[5:10]: ", len(ds))
145
+ print("0th example of sliced MMLU: ", ds[0])
146
+
147
+ print("They match: ", ex == ds[0])
tasks/customjson.py ADDED
@@ -0,0 +1,65 @@
1
+ """
2
+ CustomJSON task for loading conversations from JSONL files.
3
+ Each line in the JSONL file should be a JSON array of messages.
4
+ """
5
+
6
+ import os
7
+ import json
8
+ from tasks.common import Task
9
+
10
+ class CustomJSON(Task):
11
+ """
12
+ Load conversations from a JSONL file.
13
+ Each line should be a JSON array of message objects with 'role' and 'content' fields.
14
+ Example line: [{"role":"user","content":"Hi"},{"role":"assistant","content":"Hello"}]
15
+ """
16
+
17
+ def __init__(self, filepath, **kwargs):
18
+ super().__init__(**kwargs)
19
+ self.filepath = filepath
20
+ self.conversations = []
21
+
22
+ # Load all conversations from the JSONL file
23
+ if not os.path.exists(filepath):
24
+ # Helpful error message due to recent change. Will be removed in the future.
25
+ print("-" * 80)
26
+ print(f"Warning: File {filepath} does not exist")
27
+ print("HINT (Oct 21 2025)")
28
+ print("If you recently did a git pull and suddely see this, it might be due to the new addition of identity conversations")
29
+ print("See this discussion for more details: https://github.com/karpathy/nanochat/discussions/139")
30
+ print("Quick fix: simply run the following command to download the file and you're done:")
31
+ print(f"curl -L -o {filepath} https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl")
32
+ print("-" * 80)
33
+
34
+ else:
35
+ with open(filepath, 'r', encoding='utf-8') as f:
36
+ for line in f:
37
+ line = line.strip()
38
+ if not line: # skip empty lines
39
+ continue
40
+ messages = json.loads(line)
41
+ # Validate the conversation structure
42
+ assert isinstance(messages, list), f"Expected list of messages, got {type(messages)}"
43
+ assert len(messages) >= 2, f"Conversation must have at least 2 messages, got {len(messages)}"
44
+ # Validate message structure and alternating roles
45
+ for i, message in enumerate(messages):
46
+ assert "role" in message, f"Message {i} missing 'role' field"
47
+ assert "content" in message, f"Message {i} missing 'content' field"
48
+ expected_role = "user" if i % 2 == 0 else "assistant"
49
+ assert message["role"] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}"
50
+ assert isinstance(message["content"], str), f"Message {i} content must be a string"
51
+
52
+ self.conversations.append(messages)
53
+
54
+ self.length = len(self.conversations)
55
+
56
+ def num_examples(self):
57
+ return self.length
58
+
59
+ def get_example(self, index):
60
+ messages = self.conversations[index]
61
+ conversation = {
62
+ "messages": messages,
63
+ }
64
+ return conversation
65
+
tasks/gsm8k.py ADDED
@@ -0,0 +1,117 @@
1
+ """
2
+ GSM8K evaluation.
3
+ https://huggingface.co/datasets/openai/gsm8k
4
+
5
+ Example problem instance:
6
+
7
+ Question:
8
+ Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?
9
+ Answer:
10
+ Weng earns 12/60 = $<<12/60=0.2>>0.2 per minute.
11
+ Working 50 minutes, she earned 0.2 x 50 = $<<0.2*50=10>>10.
12
+ #### 10
13
+
14
+ Notice that GSM8K uses tool calls inside << >> tags.
15
+ """
16
+
17
+ import re
18
+ from datasets import load_dataset
19
+ from tasks.common import Task
20
+
21
+
22
+ GSM_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
23
+ def extract_answer(completion):
24
+ """
25
+ Extract the numerical answer after #### marker.
26
+ Follows official code for normalization:
27
+ https://github.com/openai/grade-school-math/blob/3101c7d5072418e28b9008a6636bde82a006892c/grade_school_math/dataset.py#L28
28
+ """
29
+ match = GSM_RE.search(completion)
30
+ if match:
31
+ match_str = match.group(1).strip()
32
+ match_str = match_str.replace(",", "")
33
+ return match_str
34
+ return None
35
+
36
+
37
+ class GSM8K(Task):
38
+
39
+ def __init__(self, subset, split, **kwargs):
40
+ super().__init__(**kwargs)
41
+ assert subset in ["main", "socratic"], "GSM8K subset must be main|socratic"
42
+ assert split in ["train", "test"], "GSM8K split must be train|test"
43
+ self.ds = load_dataset("openai/gsm8k", subset, split=split).shuffle(seed=42)
44
+
45
+ @property
46
+ def eval_type(self):
47
+ return 'generative'
48
+
49
+ def num_examples(self):
50
+ return len(self.ds)
51
+
52
+ def get_example(self, index):
53
+ """ Get a single problem from the dataset. """
54
+ row = self.ds[index]
55
+ question = row['question'] # string of the question prompt
56
+ answer = row['answer'] # string of the full solution and the answer after #### marker
57
+ # Create and return the Conversation object
58
+ # This is tricky because GSM8K uses tool calls, which we need to parse here.
59
+ assistant_message_parts = []
60
+ parts = re.split(r'(<<[^>]+>>)', answer)
61
+ for part in parts:
62
+ if part.startswith('<<') and part.endswith('>>'):
63
+ # This is a calculator tool call
64
+ inner = part[2:-2] # Remove << >>
65
+ # Split on = to get expression and result
66
+ if '=' in inner:
67
+ expr, result = inner.rsplit('=', 1)
68
+ else:
69
+ expr, result = inner, ""
70
+ # Add the tool call as a part
71
+ assistant_message_parts.append({"type": "python", "text": expr})
72
+ # Add the result as a part
73
+ assistant_message_parts.append({"type": "python_output", "text": result})
74
+ else:
75
+ # Regular text in between tool calls
76
+ assistant_message_parts.append({"type": "text", "text": part})
77
+ # Now put it all together
78
+ messages = [
79
+ {"role": "user", "content": question}, # note: simple string
80
+ {"role": "assistant", "content": assistant_message_parts}, # note: list of parts (as dicts)
81
+ ]
82
+ conversation = {
83
+ "messages": messages,
84
+ }
85
+ return conversation
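+ # For illustration, the example answer in the module docstring would be parsed into parts like:
+ #   {"type": "text",          "text": "Weng earns 12/60 = $"}
+ #   {"type": "python",        "text": "12/60"}
+ #   {"type": "python_output", "text": "0.2"}
+ #   {"type": "text",          "text": "0.2 per minute.\nWorking 50 minutes, she earned 0.2 x 50 = $"}
+ #   {"type": "python",        "text": "0.2*50"}
+ #   {"type": "python_output", "text": "10"}
+ #   {"type": "text",          "text": "10.\n#### 10"}
+ # so each <<expr=result>> calculator call becomes explicit python / python_output parts.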
86
+
87
+ def evaluate(self, conversation, assistant_response):
88
+ """
89
+ Given (conversation, completion), return evaluation outcome (0 = wrong, 1 = correct)
90
+ Note that:
91
+ - the conversation has both user AND assistant message (containing the ground truth answer)
92
+ - the assistant_response is usually the alternative assistant message achieved via sampling
93
+
94
+ TODO: Technically, assistant_response should be a Message (either a string or a list of parts)
95
+ We can handle this later possibly. For now just assume string.
96
+ """
97
+ assert isinstance(assistant_response, str), "Assuming simple string response for now"
98
+ # First extract the ground truth answer
99
+ assistant_message = conversation['messages'][-1]
100
+ assert assistant_message['role'] == "assistant", "Last message must be from the Assistant"
101
+ assert isinstance(assistant_message['content'], list), "This is expected to be a list of parts"
102
+ last_text_part = assistant_message['content'][-1]['text'] # this contains the final answer in GSM8K
103
+ # Extract both the ground truth answer and the predicted answer
104
+ ref_num = extract_answer(last_text_part)
105
+ pred_num = extract_answer(assistant_response)
106
+ # Compare and return the success as int
107
+ is_correct = int(pred_num == ref_num)
108
+ return is_correct
109
+
110
+ def reward(self, conversation, assistant_response):
111
+ """
112
+ Used during RL. To keep things simple, just re-use the evaluation above.
113
+ Later this could be made more complex (e.g. format matching etc.)
114
+ """
115
+ is_correct = self.evaluate(conversation, assistant_response)
116
+ is_correct_float = float(is_correct)
117
+ return is_correct_float