MGow committed
Commit 491d2ed · Parent: e108963

Chat frontend.
.gitignore ADDED
@@ -0,0 +1,33 @@
+ __pycache__/
+ *.py[cod]
+ *.pyo
+
+ # Virtual environments
+ .venv/
+ venv/
+ ENV/
+
+ # Local caches
+ .pytest_cache/
+ .mypy_cache/
+ .ruff_cache/
+ .ipynb_checkpoints/
+
+ # OS/editor
+ .DS_Store
+ .vscode/
+ .idea/
+
+ # Hugging Face / Gradio caches (avoid committing downloaded weights)
+ .cache/
+ hf_cache/
+
+ # Large model artifacts (Space downloads these from the Hub)
+ *.pt
+ *.pth
+ *.bin
+ *.safetensors
+ *.pkl
+
+ # Secrets
+ .env
README.md CHANGED
@@ -1,17 +1,32 @@
  ---
  title: PicoChat
- emoji: 💬
- colorFrom: yellow
- colorTo: purple
+ emoji: 🤖
+ colorFrom: blue
+ colorTo: indigo
  sdk: gradio
  sdk_version: 5.42.0
  app_file: app.py
  pinned: false
- hf_oauth: true
- hf_oauth_scopes:
-   - inference-api
  license: mit
  short_description: PicoChat frontend.
  ---

- An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
+ # 🤖 PicoChat
+
+ > A 335M parameter language model trained from scratch on a MacBook Air M2 in ~7 days.
+
+ This Space is the interactive frontend for **PicoChat**. The model weights are hosted separately to optimize for bandwidth and storage.
+
+ 🔗 **Model Weights:** [huggingface.co/MGow/PicoChat](https://huggingface.co/MGow/PicoChat)
+
+ ## 📊 Model Stats
+
+ | Feature | Details |
+ | :--- | :--- |
+ | **Architecture** | Depth-16 GPT-style Transformer |
+ | **Parameters** | ~335M |
+ | **Training Data** | 377M tokens (FineWeb-Edu + Synthetic) |
+ | **Hardware** | MacBook Air M2 (16GB RAM) |
+ | **Training Cost** | 5 kWh of electricity |
+
+ **PicoChat** is a "lab notebook" proof-of-concept for training capable small language models (SLMs) on consumer hardware. It can chat, solve basic math problems, and has a quirky personality.
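For reference, a minimal sketch of fetching the hosted weights with `huggingface_hub`, mirroring what `app.py` below does at startup (the filenames are assumed to match the ones the app requests):

```python
from huggingface_hub import hf_hub_download

# Same repo and filenames that app.py requests on boot.
model_path = hf_hub_download(repo_id="MGow/PicoChat", filename="model.pt")
meta_path = hf_hub_download(repo_id="MGow/PicoChat", filename="meta.json")
tok_path = hf_hub_download(repo_id="MGow/PicoChat", filename="tokenizer.pkl")
print(model_path, meta_path, tok_path)
```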
app.py CHANGED
@@ -1,70 +1,182 @@
+ import os
+ import sys
+ import torch
+ import json
  import gradio as gr
- from huggingface_hub import InferenceClient
-
-
- def respond(
-     message,
-     history: list[dict[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
-     hf_token: gr.OAuthToken,
- ):
+ from contextlib import nullcontext
+
+ # Add current directory to path so we can import nanochat
+ sys.path.append(os.path.dirname(__file__))
+
+ from nanochat.gpt import GPT, GPTConfig
+ from nanochat.tokenizer import RustBPETokenizer
+ from nanochat.engine import Engine
+
+ # -----------------------------------------------------------------------------
+ # Configuration
+ # -----------------------------------------------------------------------------
+ DEVICE = "cpu" # Hugging Face Free Tier is CPU only
+ HF_MODEL_REPO = os.environ.get("HF_MODEL_REPO", "MGow/PicoChat")
+ MODEL_FILENAME = "model.pt"
+ META_FILENAME = "meta.json"
+ TOKENIZER_FILENAME = "tokenizer.pkl"
+
+ print(f"Initializing PicoChat on {DEVICE}...")
+
+ # -----------------------------------------------------------------------------
+ # Load Components
+ # -----------------------------------------------------------------------------
+
+ from huggingface_hub import hf_hub_download
+
+ def get_file_path(filename):
+     """Download file from HF Hub if not local, or return local path"""
+     if os.path.exists(filename):
+         return filename
+     print(f"Downloading {filename} from {HF_MODEL_REPO}...")
+     try:
+         return hf_hub_download(repo_id=HF_MODEL_REPO, filename=filename)
+     except Exception as e:
+         print(f"Error downloading {filename}: {e}")
+         # Fallback for testing/building if files are local
+         return filename
+
+ # 1. Load Metadata
+ meta_path = get_file_path(META_FILENAME)
+ print(f"Loading metadata from {meta_path}...")
+ with open(meta_path, "r") as f:
+     meta = json.load(f)
+ model_config = meta["model_config"]
+ print(f"Model config: {model_config}")
+
+ # 2. Load Tokenizer
+ tok_path = get_file_path(TOKENIZER_FILENAME)
+ print(f"Loading tokenizer from {tok_path}...")
+ with open(tok_path, "rb") as f:
+     import pickle
+     # The tokenizer.pkl contains the tiktoken Encoding object
+     enc = pickle.load(f)
+ # Re-construct RustBPETokenizer (wrapper around tiktoken)
+ # We use <|bos|> as the start token
+ tokenizer = RustBPETokenizer(enc, "<|bos|>")
+
+ # 3. Load Model
+ model_path = get_file_path(MODEL_FILENAME)
+ print(f"Loading model from {model_path}...")
+ # Initialize model with config
+ model = GPT(GPTConfig(**model_config))
+
+ # Load state dict
+ # map_location=DEVICE ensures we load directly to CPU
+ state_dict = torch.load(model_path, map_location=DEVICE, weights_only=True)
+
+ # Fix torch compile prefix if present (remove _orig_mod.)
+ state_dict = {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}
+
+ # Ensure float32 for CPU (bfloat16 not supported on all CPUs perfectly, and float32 is safer for inference)
+ state_dict = {k: v.float() if v.dtype == torch.bfloat16 else v for k, v in state_dict.items()}
+
+ # Load weights
+ model.load_state_dict(state_dict)
+ model.to(DEVICE)
+ model.eval()
+
+ print("Model loaded successfully!")
+
+ # 4. Create Engine
+ engine = Engine(model, tokenizer)
+
+ # -----------------------------------------------------------------------------
+ # Chat Logic
+ # -----------------------------------------------------------------------------
+
+ def chat_function(message, history):
      """
-     For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
+     message: str, current user message
+     history: list of [user_msg, bot_msg] from previous turns
      """
-     client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")

-     messages = [{"role": "system", "content": system_message}]
+     # Prepare special tokens
+     bos = tokenizer.get_bos_token_id()
+     user_start = tokenizer.encode_special("<|user_start|>")
+     user_end = tokenizer.encode_special("<|user_end|>")
+     assistant_start = tokenizer.encode_special("<|assistant_start|>")
+     assistant_end = tokenizer.encode_special("<|assistant_end|>")

-     messages.extend(history)
+     # Build conversation tokens
+     conversation_tokens = [bos]

-     messages.append({"role": "user", "content": message})
+     # Add history
+     for user_msg, assistant_msg in history:
+         if user_msg:
+             conversation_tokens.append(user_start)
+             conversation_tokens.extend(tokenizer.encode(user_msg))
+             conversation_tokens.append(user_end)
+         if assistant_msg:
+             conversation_tokens.append(assistant_start)
+             conversation_tokens.extend(tokenizer.encode(assistant_msg))
+             conversation_tokens.append(assistant_end)

-     response = ""
+     # Add current message
+     conversation_tokens.append(user_start)
+     conversation_tokens.extend(tokenizer.encode(message))
+     conversation_tokens.append(user_end)

-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         choices = message.choices
-         token = ""
-         if len(choices) and choices[0].delta.content:
-             token = choices[0].delta.content
+     # Prime assistant
+     conversation_tokens.append(assistant_start)

-         response += token
-         yield response
+     # Generation parameters
+     generate_kwargs = {
+         "num_samples": 1,
+         "max_tokens": 512,
+         "temperature": 0.7,
+         "top_k": 50,
+     }

+     response_text = ""

+     # Generate stream
+     # Engine.generate yields (token_column, token_masks)
+     for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs):
+         token = token_column[0]
+
+         # Stop if assistant ends
+         if token == assistant_end:
+             break
+
+         # Decode and append
+         text_chunk = tokenizer.decode([token])
+         response_text += text_chunk
+
+         # Yield partial response for streaming UI
+         yield response_text
+
+ # -----------------------------------------------------------------------------
+ # Gradio UI
+ # -----------------------------------------------------------------------------
+
+ custom_css = """
+ .gradio-container {
+     font-family: 'Inter', sans-serif;
+ }
  """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- chatbot = gr.ChatInterface(
-     respond,
-     type="messages",
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
+
+ demo = gr.ChatInterface(
+     fn=chat_function,
+     title="PicoChat",
+     description="""
+     **PicoChat** is a 335M parameter model trained from scratch on a MacBook Air M2.
+     It knows how to chat, do basic math, and tell stories.
+     It is NOT ChatGPT (it's much smaller), but it runs purely on CPU here.
+     """,
+     examples=[
+         "Tell me a story about a robot named beep.",
+         "What is 25 * 12?",
+         "Explain gravity to a 5 year old.",
+         "Write a python function to calculate fibonacci."
      ],
+     cache_examples=False,
  )

- with gr.Blocks() as demo:
-     with gr.Sidebar():
-         gr.LoginButton()
-     chatbot.render()
-
-
  if __name__ == "__main__":
      demo.launch()
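For reference, a minimal single-turn sketch of the prompt format `chat_function` builds and streams through `Engine.generate` (it assumes the `tokenizer` and `engine` objects constructed in `app.py` above):

```python
# Single-turn version of chat_function's token layout:
# <|bos|> <|user_start|> ...message... <|user_end|> <|assistant_start|> -> generate until <|assistant_end|>
tokens = [tokenizer.get_bos_token_id()]
tokens.append(tokenizer.encode_special("<|user_start|>"))
tokens.extend(tokenizer.encode("What is 25 * 12?"))
tokens.append(tokenizer.encode_special("<|user_end|>"))
tokens.append(tokenizer.encode_special("<|assistant_start|>"))

assistant_end = tokenizer.encode_special("<|assistant_end|>")
reply = []
for token_column, token_masks in engine.generate(tokens, num_samples=1, max_tokens=64, temperature=0.7, top_k=50):
    if token_column[0] == assistant_end:
        break
    reply.append(token_column[0])
print(tokenizer.decode(reply))
```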
nanochat/__init__.py ADDED
File without changes (empty file)
nanochat/adamw.py ADDED
@@ -0,0 +1,76 @@
1
+ """
2
+ Borrowed from modded-nanogpt. By Keller, @vagrawal, et al.
3
+ Not a general optimizer! But works for our specific use.
4
+ """
5
+ import torch
6
+ import torch.distributed as dist
7
+ from torch import Tensor
8
+
9
+
10
+ class DistAdamW(torch.optim.Optimizer):
11
+ """
12
+ Distributed AdamW optimizer.
13
+ In the style of ZeRO-2, i.e. sharded optimizer states and gradient reduction
14
+ """
15
+ def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
16
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
17
+ super().__init__(param_groups, defaults)
18
+
19
+ @torch.compile
20
+ @torch.no_grad()
21
+ def step(self):
22
+ rank = dist.get_rank()
23
+ world_size = dist.get_world_size()
24
+ reduce_scatter_futures: list[torch.Future] = []
25
+ all_reduce_futures: list[torch.Future] = []
26
+ grad_slices = []
27
+ for group in self.param_groups:
28
+ params: list[Tensor] = group["params"]
29
+ for base_i in range(len(params)):
30
+ grad = params[base_i].grad
31
+ rank_size = grad.shape[0] // world_size
32
+ grad_slice = torch.empty_like(grad[:rank_size])
33
+ reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
34
+ grad_slices.append(grad_slice)
35
+
36
+ idx = 0
37
+ for group in self.param_groups:
38
+ beta1, beta2 = group['betas']
39
+ eps = group['eps']
40
+ wd = group['weight_decay']
41
+ params = group['params']
42
+ for base in range(len(params)):
43
+ reduce_scatter_futures[idx].wait()
44
+ p = params[base]
45
+ rank_size = p.shape[0] // world_size
46
+ p_slice = p[rank * rank_size:(rank + 1) * rank_size]
47
+ lr = group['lr'] * getattr(p, "lr_mul", 1.0)
48
+ state = self.state[p]
49
+ g_slice = grad_slices[idx]
50
+ # State init
51
+ if not state:
52
+ state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device)
53
+ state['exp_avg'] = torch.zeros_like(p_slice)
54
+ state['exp_avg_sq'] = torch.zeros_like(p_slice)
55
+ exp_avg = state['exp_avg']
56
+ exp_avg_sq = state['exp_avg_sq']
57
+ state['step'] += 1
58
+ t = state['step']
59
+ # weight decay
60
+ if wd != 0:
61
+ eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0)
62
+ p_slice.mul_(1 - eff_weight_decay)
63
+ # update running averages
64
+ exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1)
65
+ exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2)
66
+ # bias corrections
67
+ bias1 = 1 - beta1 ** t
68
+ bias2 = 1 - beta2 ** t
69
+ # compute step
70
+ denom = exp_avg_sq.sqrt().add_(eps)
71
+ step_size = lr * (torch.sqrt(bias2) / bias1)
72
+ update = exp_avg.div(denom).mul_(step_size)
73
+ p_slice.add_(other=update, alpha=-1.0)
74
+ idx += 1
75
+ all_reduce_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
76
+ torch.futures.collect_all(all_reduce_futures).wait()
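A hypothetical launch sketch for `DistAdamW` above, hedged: it assumes a `torchrun` launch with NCCL, CUDA devices, a working `torch.compile` backend for `step()`, and parameters whose first dimension divides evenly by the world size (which the sharding logic relies on):

```python
import os
import torch
import torch.distributed as dist
from nanochat.adamw import DistAdamW

dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
model = torch.nn.Linear(256, 256, bias=False).cuda()  # dim 0 divisible by world_size

opt = DistAdamW([{"params": list(model.parameters())}], lr=1e-3, weight_decay=0.01)
loss = model(torch.randn(8, 256, device="cuda")).square().mean()
loss.backward()
opt.step()                           # reduce-scatter grads, update local shard, all-gather params
opt.zero_grad(set_to_none=False)
dist.destroy_process_group()
```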
nanochat/checkpoint_manager.py ADDED
@@ -0,0 +1,151 @@
1
+ """
2
+ Utilities for saving and loading model/optim/state checkpoints.
3
+ """
4
+ import os
5
+ import re
6
+ import glob
7
+ import json
8
+ import logging
9
+ import torch
10
+
11
+ from nanochat.common import get_base_dir
12
+ from nanochat.gpt import GPT, GPTConfig
13
+ from nanochat.tokenizer import get_tokenizer
14
+ from nanochat.common import setup_default_logging
15
+
16
+ # Set up logging
17
+ setup_default_logging()
18
+ logger = logging.getLogger(__name__)
19
+ def log0(message):
20
+ if int(os.environ.get('RANK', 0)) == 0:
21
+ logger.info(message)
22
+
23
+ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
24
+ if rank == 0:
25
+ os.makedirs(checkpoint_dir, exist_ok=True)
26
+ # Save the model state parameters
27
+ model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
28
+ torch.save(model_data, model_path)
29
+ logger.info(f"Saved model parameters to: {model_path}")
30
+ # Save the metadata dict as json
31
+ meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
32
+ with open(meta_path, "w", encoding="utf-8") as f:
33
+ json.dump(meta_data, f, indent=2)
34
+ logger.info(f"Saved metadata to: {meta_path}")
35
+ # Note that optimizer state is sharded across ranks, so each rank must save its own.
36
+ if optimizer_data is not None:
37
+ optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
38
+ torch.save(optimizer_data, optimizer_path)
39
+ logger.info(f"Saved optimizer state to: {optimizer_path}")
40
+
41
+ def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
42
+ # Load the model state
43
+ model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
44
+ model_data = torch.load(model_path, map_location=device)
45
+ # Load the optimizer state if requested
46
+ optimizer_data = None
47
+ if load_optimizer:
48
+ optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
49
+ optimizer_data = torch.load(optimizer_path, map_location=device)
50
+ # Load the metadata
51
+ meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
52
+ with open(meta_path, "r", encoding="utf-8") as f:
53
+ meta_data = json.load(f)
54
+ return model_data, optimizer_data, meta_data
55
+
56
+
57
+ def build_model(checkpoint_dir, step, device, phase):
58
+ """
59
+ A bunch of repetitive code to build a model from a given checkpoint.
60
+ Returns:
61
+ - base model - uncompiled, not wrapped in DDP
62
+ - tokenizer
63
+ - meta data saved during base model training
64
+ """
65
+ assert phase in ["train", "eval"], f"Invalid phase: {phase}"
66
+ model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
67
+ if device.type in {"cpu", "mps"}:
68
+ # Convert bfloat16 tensors to float for CPU inference
69
+ model_data = {
70
+ k: v.float() if v.dtype == torch.bfloat16 else v
71
+ for k, v in model_data.items()
72
+ }
73
+ # Hack: fix torch compile issue, which prepends all keys with _orig_mod.
74
+ model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
75
+ model_config_kwargs = meta_data["model_config"]
76
+ log0(f"Building model with config: {model_config_kwargs}")
77
+ model_config = GPTConfig(**model_config_kwargs)
78
+ with torch.device("meta"):
79
+ model = GPT(model_config)
80
+ # Load the model state
81
+ model.to_empty(device=device)
82
+ model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
83
+ model.load_state_dict(model_data, strict=True, assign=True)
84
+ # Put the model in the right training phase / mode
85
+ if phase == "eval":
86
+ model.eval()
87
+ else:
88
+ model.train()
89
+ # Load the Tokenizer
90
+ tokenizer = get_tokenizer()
91
+ # Sanity check: compatibility between model and tokenizer
92
+ assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"]
93
+ return model, tokenizer, meta_data
94
+
95
+
96
+ def find_largest_model(checkpoint_dir):
97
+ # attempt to guess the model tag: take the biggest model available
98
+ model_tags = [f for f in os.listdir(checkpoint_dir) if os.path.isdir(os.path.join(checkpoint_dir, f))]
99
+ if not model_tags:
100
+ raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
101
+ # 1) normally all model tags are of the form d<number>, try that first:
102
+ candidates = []
103
+ for model_tag in model_tags:
104
+ match = re.match(r"d(\d+)", model_tag)
105
+ if match:
106
+ model_depth = int(match.group(1))
107
+ candidates.append((model_depth, model_tag))
108
+ if candidates:
109
+ candidates.sort(key=lambda x: x[0], reverse=True)
110
+ return candidates[0][1]
111
+ # 2) if that failed, take the most recently updated model:
112
+ model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
113
+ return model_tags[0]
114
+
115
+
116
+ def find_last_step(checkpoint_dir):
117
+ # Look into checkpoint_dir and find model_<step>.pt with the highest step
118
+ checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt"))
119
+ if not checkpoint_files:
120
+ raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
121
+ last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files))
122
+ return last_step
123
+
124
+ # -----------------------------------------------------------------------------
125
+ # convenience functions that take into account nanochat's directory structure
126
+
127
+ def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
128
+ if model_tag is None:
129
+ # guess the model tag by defaulting to the largest model
130
+ model_tag = find_largest_model(checkpoints_dir)
131
+ log0(f"No model tag provided, guessing model tag: {model_tag}")
132
+ checkpoint_dir = os.path.join(checkpoints_dir, model_tag)
133
+ if step is None:
134
+ # guess the step by defaulting to the last step
135
+ step = find_last_step(checkpoint_dir)
136
+ assert step is not None, f"No checkpoints found in {checkpoint_dir}"
137
+ # build the model
138
+ log0(f"Loading model from {checkpoint_dir} with step {step}")
139
+ model, tokenizer, meta_data = build_model(checkpoint_dir, step, device, phase)
140
+ return model, tokenizer, meta_data
141
+
142
+ def load_model(source, *args, **kwargs):
143
+ model_dir = {
144
+ "base": "base_checkpoints",
145
+ "mid": "mid_checkpoints",
146
+ "sft": "chatsft_checkpoints",
147
+ "rl": "chatrl_checkpoints",
148
+ }[source]
149
+ base_dir = get_base_dir()
150
+ checkpoints_dir = os.path.join(base_dir, model_dir)
151
+ return load_model_from_dir(checkpoints_dir, *args, **kwargs)
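A usage sketch for the convenience loaders above (it assumes checkpoints already exist under `~/.cache/nanochat` or `$NANOCHAT_BASE_DIR`; model tag and step are guessed when omitted):

```python
import torch
from nanochat.checkpoint_manager import load_model

device = torch.device("cpu")
# Picks the largest d<depth> model tag and its latest step from chatsft_checkpoints.
model, tokenizer, meta = load_model("sft", device, phase="eval")
print(meta["model_config"])
```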
nanochat/common.py ADDED
@@ -0,0 +1,190 @@
1
+ """
2
+ Common utilities for nanochat.
3
+ """
4
+
5
+ import os
6
+ import re
7
+ import logging
8
+ import urllib.request
9
+ import torch
10
+ import torch.distributed as dist
11
+ from filelock import FileLock
12
+
13
+ class ColoredFormatter(logging.Formatter):
14
+ """Custom formatter that adds colors to log messages."""
15
+ # ANSI color codes
16
+ COLORS = {
17
+ 'DEBUG': '\033[36m', # Cyan
18
+ 'INFO': '\033[32m', # Green
19
+ 'WARNING': '\033[33m', # Yellow
20
+ 'ERROR': '\033[31m', # Red
21
+ 'CRITICAL': '\033[35m', # Magenta
22
+ }
23
+ RESET = '\033[0m'
24
+ BOLD = '\033[1m'
25
+ def format(self, record):
26
+ # Add color to the level name
27
+ levelname = record.levelname
28
+ if levelname in self.COLORS:
29
+ record.levelname = f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}"
30
+ # Format the message
31
+ message = super().format(record)
32
+ # Add color to specific parts of the message
33
+ if levelname == 'INFO':
34
+ # Highlight numbers and percentages
35
+ message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message)
36
+ message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message)
37
+ return message
38
+
39
+ def setup_default_logging():
40
+ handler = logging.StreamHandler()
41
+ handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
42
+ logging.basicConfig(
43
+ level=logging.INFO,
44
+ handlers=[handler]
45
+ )
46
+
47
+ setup_default_logging()
48
+ logger = logging.getLogger(__name__)
49
+
50
+ def get_base_dir():
51
+ # co-locate nanochat intermediates with other cached data in ~/.cache (by default)
52
+ if os.environ.get("NANOCHAT_BASE_DIR"):
53
+ nanochat_dir = os.environ.get("NANOCHAT_BASE_DIR")
54
+ else:
55
+ home_dir = os.path.expanduser("~")
56
+ cache_dir = os.path.join(home_dir, ".cache")
57
+ nanochat_dir = os.path.join(cache_dir, "nanochat")
58
+ os.makedirs(nanochat_dir, exist_ok=True)
59
+ return nanochat_dir
60
+
61
+ def download_file_with_lock(url, filename, postprocess_fn=None):
62
+ """
63
+ Downloads a file from a URL to a local path in the base directory.
64
+ Uses a lock file to prevent concurrent downloads among multiple ranks.
65
+ """
66
+ base_dir = get_base_dir()
67
+ file_path = os.path.join(base_dir, filename)
68
+ lock_path = file_path + ".lock"
69
+
70
+ if os.path.exists(file_path):
71
+ return file_path
72
+
73
+ with FileLock(lock_path):
74
+ # Only a single rank can acquire this lock
75
+ # All other ranks block until it is released
76
+
77
+ # Recheck after acquiring lock
78
+ if os.path.exists(file_path):
79
+ return file_path
80
+
81
+ # Download the content as bytes
82
+ print(f"Downloading {url}...")
83
+ with urllib.request.urlopen(url) as response:
84
+ content = response.read() # bytes
85
+
86
+ # Write to local file
87
+ with open(file_path, 'wb') as f:
88
+ f.write(content)
89
+ print(f"Downloaded to {file_path}")
90
+
91
+ # Run the postprocess function if provided
92
+ if postprocess_fn is not None:
93
+ postprocess_fn(file_path)
94
+
95
+ return file_path
96
+
97
+ def print0(s="",**kwargs):
98
+ ddp_rank = int(os.environ.get('RANK', 0))
99
+ if ddp_rank == 0:
100
+ print(s, **kwargs)
101
+
102
+ def print_banner():
103
+ # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/
104
+ banner = """
105
+ █████ █████
106
+ ░░███ ░░███
107
+ ████████ ██████ ████████ ██████ ██████ ░███████ ██████ ███████
108
+ ░░███░░███ ░░░░░███ ░░███░░███ ███░░███ ███░░███ ░███░░███ ░░░░░███░░░███░
109
+ ░███ ░███ ███████ ░███ ░███ ░███ ░███░███ ░░░ ░███ ░███ ███████ ░███
110
+ ░███ ░███ ███░░███ ░███ ░███ ░███ ░███░███ ███ ░███ ░███ ███░░███ ░███ ███
111
+ ████ █████░░████████ ████ █████░░██████ ░░██████ ████ █████░░███████ ░░█████
112
+ ░░░░ ░░░░░ ░░░░░░░░ ░░░░ ░░░░░ ░░░░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░░░░ ░░░░░
113
+ """
114
+ print0(banner)
115
+
116
+ def is_ddp():
117
+ # TODO is there a proper way
118
+ return int(os.environ.get('RANK', -1)) != -1
119
+
120
+ def get_dist_info():
121
+ if is_ddp():
122
+ assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
123
+ ddp_rank = int(os.environ['RANK'])
124
+ ddp_local_rank = int(os.environ['LOCAL_RANK'])
125
+ ddp_world_size = int(os.environ['WORLD_SIZE'])
126
+ return True, ddp_rank, ddp_local_rank, ddp_world_size
127
+ else:
128
+ return False, 0, 0, 1
129
+
130
+ def autodetect_device_type():
131
+ # prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
132
+ if torch.cuda.is_available():
133
+ device_type = "cuda"
134
+ elif torch.backends.mps.is_available():
135
+ device_type = "mps"
136
+ else:
137
+ device_type = "cpu"
138
+ print0(f"Autodetected device type: {device_type}")
139
+ return device_type
140
+
141
+ def compute_init(device_type="cuda"): # cuda|cpu|mps
142
+ """Basic initialization that we keep doing over and over, so make common."""
143
+
144
+ assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
145
+ if device_type == "cuda":
146
+ assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
147
+ if device_type == "mps":
148
+ assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
149
+
150
+ # Reproducibility
151
+ # Note that we set the global seeds here, but most of the code uses explicit rng objects.
152
+ # The only place where global rng might be used is nn.Module initialization of the model weights.
153
+ torch.manual_seed(42)
154
+ if device_type == "cuda":
155
+ torch.cuda.manual_seed(42)
156
+ # skipping full reproducibility for now, possibly investigate slowdown later
157
+ # torch.use_deterministic_algorithms(True)
158
+
159
+ # Precision
160
+ if device_type == "cuda":
161
+ torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
162
+
163
+ # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
164
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
165
+ if ddp and device_type == "cuda":
166
+ device = torch.device("cuda", ddp_local_rank)
167
+ torch.cuda.set_device(device) # make "cuda" default to this device
168
+ dist.init_process_group(backend="nccl", device_id=device)
169
+ dist.barrier()
170
+ else:
171
+ device = torch.device(device_type) # mps|cpu
172
+
173
+ if ddp_rank == 0:
174
+ logger.info(f"Distributed world size: {ddp_world_size}")
175
+
176
+ return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device
177
+
178
+ def compute_cleanup():
179
+ """Companion function to compute_init, to clean things up before script exit"""
180
+ if is_ddp():
181
+ dist.destroy_process_group()
182
+
183
+ class DummyWandb:
184
+ """Useful if we wish to not use wandb but have all the same signatures"""
185
+ def __init__(self):
186
+ pass
187
+ def log(self, *args, **kwargs):
188
+ pass
189
+ def finish(self):
190
+ pass
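A typical bootstrap sketch using these helpers (single process shown; under `torchrun` with CUDA the same call also initializes the NCCL process group):

```python
from nanochat.common import autodetect_device_type, compute_init, compute_cleanup, print0

device_type = autodetect_device_type()  # "cuda" | "mps" | "cpu"
ddp, rank, local_rank, world_size, device = compute_init(device_type)
print0(f"rank {rank}/{world_size} running on {device}")
compute_cleanup()
```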
nanochat/configurator.py ADDED
@@ -0,0 +1,56 @@
1
+ """
2
+ Poor Man's Configurator. Probably a terrible idea. Example usage:
3
+ $ python train.py config/override_file.py --batch_size=32
4
+ this will first run config/override_file.py, then override batch_size to 32
5
+
6
+ The code in this file will be run as follows from e.g. train.py:
7
+ >>> exec(open('configurator.py').read())
8
+
9
+ So it's not a Python module, it's just shuttling this code away from train.py
10
+ The code in this script then overrides the globals()
11
+
12
+ I know people are not going to love this, I just really dislike configuration
13
+ complexity and having to prepend config. to every single variable. If someone
14
+ comes up with a better simple Python solution I am all ears.
15
+ """
16
+
17
+ import os
18
+ import sys
19
+ from ast import literal_eval
20
+
21
+ def print0(s="",**kwargs):
22
+ ddp_rank = int(os.environ.get('RANK', 0))
23
+ if ddp_rank == 0:
24
+ print(s, **kwargs)
25
+
26
+ for arg in sys.argv[1:]:
27
+ if '=' not in arg:
28
+ # assume it's the name of a config file
29
+ assert not arg.startswith('--')
30
+ config_file = arg
31
+ print0(f"Overriding config with {config_file}:")
32
+ with open(config_file) as f:
33
+ print0(f.read())
34
+ exec(open(config_file).read())
35
+ else:
36
+ # assume it's a --key=value argument
37
+ assert arg.startswith('--')
38
+ key, val = arg.split('=')
39
+ key = key[2:]
40
+ if key in globals():
41
+ try:
42
+ # attempt to eval it (e.g. if it's a bool, number, etc.)
43
+ attempt = literal_eval(val)
44
+ except (SyntaxError, ValueError):
45
+ # if that goes wrong, just use the string
46
+ attempt = val
47
+ # ensure the types match ok
48
+ if globals()[key] is not None:
49
+ attempt_type = type(attempt)
50
+ default_type = type(globals()[key])
51
+ assert attempt_type == default_type, f"Type mismatch: {attempt_type} != {default_type}"
52
+ # cross fingers
53
+ print0(f"Overriding: {key} = {attempt}")
54
+ globals()[key] = attempt
55
+ else:
56
+ raise ValueError(f"Unknown config key: {key}")
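A sketch of the exec-based pattern the docstring describes, from the point of view of a training script (the path to `configurator.py` is an assumption about the working directory):

```python
# train.py (sketch): defaults are plain globals, then the configurator may override them from argv
batch_size = 64
learning_rate = 3e-4
exec(open("nanochat/configurator.py").read())  # e.g. `python train.py --batch_size=32`
print(batch_size, learning_rate)
```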
nanochat/core_eval.py ADDED
@@ -0,0 +1,262 @@
1
+ """
2
+ Functions for evaluating the CORE metric, as described in the DCLM paper.
3
+ https://arxiv.org/abs/2406.11794
4
+
5
+ TODOs:
6
+ - All tasks ~match except for squad. We get 31%, the reference is 37%. Figure out why.
7
+ """
8
+ import random
9
+
10
+ from jinja2 import Template
11
+ import torch
12
+ import torch.distributed as dist
13
+
14
+ # -----------------------------------------------------------------------------
15
+ # Prompt rendering utilities
16
+
17
+ def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None):
18
+ """Render complete prompts for a multiple choice question"""
19
+ template_str = """
20
+ {%- for example in fewshot_examples -%}
21
+ {{ example.query }}{{ continuation_delimiter }}{{ example.choices[example.gold] }}
22
+
23
+ {% endfor -%}
24
+ {{ item.query }}{{ continuation_delimiter }}{{ choice }}""".strip()
25
+ template = Template(template_str)
26
+ fewshot_examples = fewshot_examples or []
27
+ context = {
28
+ 'fewshot_examples': fewshot_examples,
29
+ 'continuation_delimiter': continuation_delimiter,
30
+ 'item': item
31
+ }
32
+ prompts = [template.render(choice=choice, **context) for choice in item['choices']]
33
+ return prompts
34
+
35
+
36
+ def render_prompts_schema(item, continuation_delimiter, fewshot_examples=None):
37
+ """Render complete prompts for a schema question"""
38
+ template_str = """
39
+ {%- for example in fewshot_examples -%}
40
+ {{ example.context_options[example.gold] }}{{ continuation_delimiter }}{{ example.continuation }}
41
+
42
+ {% endfor -%}
43
+ {{ context }}{{ continuation_delimiter }}{{ item.continuation }}""".strip()
44
+ template = Template(template_str)
45
+ fewshot_examples = fewshot_examples or []
46
+ context = {
47
+ 'fewshot_examples': fewshot_examples,
48
+ 'continuation_delimiter': continuation_delimiter,
49
+ 'item': item
50
+ }
51
+ prompts = [template.render(context=context_option, **context)
52
+ for context_option in item['context_options']]
53
+ return prompts
54
+
55
+
56
+ def render_prompts_lm(item, continuation_delimiter, fewshot_examples=None):
57
+ """
58
+ Render complete prompt for a language modeling task.
59
+ Notice that we manually trim the context in the template,
60
+ which in some datasets seems to have trailing whitespace (which we don't want).
61
+ """
62
+ template_str = """
63
+ {%- for example in fewshot_examples -%}
64
+ {{ example.context | trim }}{{ continuation_delimiter }}{{ example.continuation }}
65
+
66
+ {% endfor -%}
67
+ {{ item.context | trim }}{{ continuation_delimiter }}{% if include_continuation %}{{ item.continuation }}{% endif %}""".strip()
68
+ template = Template(template_str)
69
+ fewshot_examples = fewshot_examples or []
70
+ context = {
71
+ 'fewshot_examples': fewshot_examples,
72
+ 'continuation_delimiter': continuation_delimiter,
73
+ 'item': item
74
+ }
75
+ # Return two prompts: without and with the continuation
76
+ prompt_without = template.render(include_continuation=False, **context)
77
+ prompt_with = template.render(include_continuation=True, **context)
78
+ # Due to the way the data seems to be stored, I think I need to strip in the case of LM here.
79
+ # Otherwise we may get trailing whitespaces in prompt_without (which get absorbed into the next
80
+ # token in prompt_with), meaning we don't get a nice and clean prefix in the token space
81
+ # to detect the final continuation. Tokenizers...
82
+ prompt_without = prompt_without.strip()
83
+ return [prompt_without, prompt_with]
84
+
85
+
86
+ def find_common_length(token_sequences, direction='left'):
87
+ """
88
+ Find the length of the common prefix or suffix across token sequences
89
+ - direction: 'left' for prefix, 'right' for suffix
90
+ """
91
+ min_len = min(len(seq) for seq in token_sequences)
92
+ indices = {
93
+ 'left': range(min_len),
94
+ 'right': range(-1, -min_len-1, -1)
95
+ }[direction]
96
+ # Find the first position where the token sequences differ
97
+ for i, idx in enumerate(indices):
98
+ token = token_sequences[0][idx]
99
+ if not all(seq[idx] == token for seq in token_sequences):
100
+ return i
101
+ return min_len
102
+
103
+
104
+ def stack_sequences(tokens, pad_token_id):
105
+ """Stack up a list of token sequences, pad to longest on the right"""
106
+ bsz, seq_len = len(tokens), max(len(x) for x in tokens)
107
+ input_ids = torch.full((bsz, seq_len), pad_token_id, dtype=torch.long)
108
+ for i, x in enumerate(tokens):
109
+ input_ids[i, :len(x)] = torch.tensor(x, dtype=torch.long)
110
+ return input_ids
111
+
112
+
113
+ def batch_sequences_mc(tokenizer, prompts):
114
+ # In multiple choice, contexts are the same but the continuation is different (common prefix)
115
+ tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
116
+ # figure out the start and end of each continuation
117
+ answer_start_idx = find_common_length(tokens, direction='left')
118
+ start_indices = [answer_start_idx] * len(prompts)
119
+ end_indices = [len(x) for x in tokens]
120
+ return tokens, start_indices, end_indices
121
+
122
+
123
+ def batch_sequences_schema(tokenizer, prompts):
124
+ # In schema tasks, contexts vary but continuation is the same (common suffix)
125
+ tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
126
+ # figure out the start and end of each context
127
+ suffix_length = find_common_length(tokens, direction='right')
128
+ end_indices = [len(x) for x in tokens]
129
+ start_indices = [ei - suffix_length for ei in end_indices]
130
+ return tokens, start_indices, end_indices
131
+
132
+
133
+ def batch_sequences_lm(tokenizer, prompts):
134
+ # In LM tasks, we have two prompts: without and with continuation
135
+ tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
136
+ tokens_without, tokens_with = tokens
137
+ start_idx, end_idx = len(tokens_without), len(tokens_with)
138
+ assert start_idx < end_idx, "prompt without is supposed to be a prefix of prompt with"
139
+ assert tokens_without == tokens_with[:start_idx], "prompt without is supposed to be a prefix of prompt with"
140
+ # we only need the with continuation prompt in the LM task, i.e. batch size of 1
141
+ return [tokens_with], [start_idx], [end_idx]
142
+
143
+
144
+ @torch.no_grad()
145
+ def forward_model(model, input_ids):
146
+ """
147
+ Take BxT tensor of token ids, return BxT tensor of losses and argmax predictions.
148
+ The last column of losses is set to nan because we don't have autoregressive targets there.
149
+ """
150
+ batch_size, seq_len = input_ids.size()
151
+ outputs = model(input_ids)
152
+ # Roll the tensor to the left by one position to get the (autoregressive) target ids
153
+ target_ids = torch.roll(input_ids, shifts=-1, dims=1)
154
+ # Calculate cross entropy at all positions
155
+ losses = torch.nn.functional.cross_entropy(
156
+ outputs.view(batch_size * seq_len, -1),
157
+ target_ids.view(batch_size * seq_len),
158
+ reduction='none'
159
+ ).view(batch_size, seq_len)
160
+ # Set the last column to be nan because there is no autoregressive loss there
161
+ losses[:, -1] = float('nan')
162
+ # Get the argmax predictions at each position
163
+ predictions = outputs.argmax(dim=-1)
164
+ return losses, predictions
165
+
166
+
167
+ @torch.no_grad()
168
+ def evaluate_example(idx, model, tokenizer, data, device, task_meta):
169
+ """Evaluate a single example, return True if correct, False otherwise"""
170
+ item = data[idx]
171
+ task_type = task_meta['task_type']
172
+ num_fewshot = task_meta['num_fewshot']
173
+ continuation_delimiter = task_meta['continuation_delimiter']
174
+
175
+ # Sample few-shot examples (excluding current item)
176
+ fewshot_examples = []
177
+ if num_fewshot > 0:
178
+ rng = random.Random(1234 + idx)
179
+ available_indices = [i for i in range(len(data)) if i != idx]
180
+ fewshot_indices = rng.sample(available_indices, num_fewshot)
181
+ fewshot_examples = [data[i] for i in fewshot_indices]
182
+
183
+ # Render prompts and batch sequences based on task type
184
+ if task_type == 'multiple_choice':
185
+ prompts = render_prompts_mc(item, continuation_delimiter, fewshot_examples)
186
+ tokens, start_idxs, end_idxs = batch_sequences_mc(tokenizer, prompts)
187
+ elif task_type == 'schema':
188
+ prompts = render_prompts_schema(item, continuation_delimiter, fewshot_examples)
189
+ tokens, start_idxs, end_idxs = batch_sequences_schema(tokenizer, prompts)
190
+ elif task_type == 'language_modeling':
191
+ prompts = render_prompts_lm(item, continuation_delimiter, fewshot_examples)
192
+ tokens, start_idxs, end_idxs = batch_sequences_lm(tokenizer, prompts)
193
+ else:
194
+ raise ValueError(f"Unsupported task type: {task_type}")
195
+
196
+ # Some models can't forward sequences beyond a certain length (e.g. GPT-2)
197
+ # In these cases, we have to truncate sequences to max length and adjust the indices
198
+ if hasattr(model, 'max_seq_len') and model.max_seq_len is not None:
199
+ max_tokens = model.max_seq_len
200
+ new_tokens, new_start_idxs, new_end_idxs = [], [], []
201
+ for t, s, e in zip(tokens, start_idxs, end_idxs):
202
+ if len(t) > max_tokens:
203
+ num_to_crop = len(t) - max_tokens
204
+ new_tokens.append(t[-max_tokens:]) # take the last max_tokens tokens
205
+ new_start_idxs.append(s - num_to_crop) # shift the indices down
206
+ new_end_idxs.append(e - num_to_crop)
207
+ assert s - num_to_crop >= 0, "this should never happen right?"
208
+ assert e - num_to_crop >= 0, "this should never happen right?"
209
+ else:
210
+ new_tokens.append(t) # keep unchanged
211
+ new_start_idxs.append(s)
212
+ new_end_idxs.append(e)
213
+ tokens, start_idxs, end_idxs = new_tokens, new_start_idxs, new_end_idxs
214
+
215
+ # Stack up all the sequences into a batch
216
+ pad_token_id = tokenizer.get_bos_token_id() # use BOS as pad token is ok
217
+ input_ids = stack_sequences(tokens, pad_token_id)
218
+ input_ids = input_ids.to(device)
219
+
220
+ # Forward the model, get the autoregressive loss and argmax prediction at each token
221
+ losses, predictions = forward_model(model, input_ids)
222
+
223
+ # See if the losses/predictions come out correctly
224
+ if task_type == 'language_modeling':
225
+ # language modeling task is currently always batch size 1
226
+ si = start_idxs[0]
227
+ ei = end_idxs[0]
228
+ # predictions[i] predict input_ids[i+1] autoregressively
229
+ predicted_tokens = predictions[0, si-1:ei-1]
230
+ actual_tokens = input_ids[0, si:ei]
231
+ is_correct = torch.all(predicted_tokens == actual_tokens).item()
232
+ elif task_type in ['multiple_choice', 'schema']:
233
+ # For MC/schema: find the option with lowest average loss
234
+ mean_losses = [losses[i, si-1:ei-1].mean().item()
235
+ for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))]
236
+ pred_idx = mean_losses.index(min(mean_losses))
237
+ is_correct = pred_idx == item['gold']
238
+ else:
239
+ raise ValueError(f"Unsupported task type: {task_type}")
240
+
241
+ return is_correct
242
+
243
+
244
+ def evaluate_task(model, tokenizer, data, device, task_meta):
245
+ """
246
+ This function is responsible for evaluating one task across many examples.
247
+ It also handles dispatch to all processes if the script is run with torchrun.
248
+ """
249
+ rank = dist.get_rank() if dist.is_initialized() else 0
250
+ world_size = dist.get_world_size() if dist.is_initialized() else 1
251
+ correct = torch.zeros(len(data), dtype=torch.float32, device=device)
252
+ # stride the examples to each rank
253
+ for idx in range(rank, len(data), world_size):
254
+ is_correct = evaluate_example(idx, model, tokenizer, data, device, task_meta)
255
+ correct[idx] = float(is_correct)
256
+ # sync results across all the processes if running distributed
257
+ if world_size > 1:
258
+ dist.barrier()
259
+ dist.all_reduce(correct, op=dist.ReduceOp.SUM)
260
+ # compute the mean
261
+ mean_correct = correct.mean().item()
262
+ return mean_correct
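A small rendering sketch for the multiple-choice path above (runnable with just `jinja2`; the item and few-shot dicts follow the `query`/`choices`/`gold` fields these functions index into):

```python
from nanochat.core_eval import render_prompts_mc

item = {"query": "Q: What is 2 + 2?\nA:", "choices": [" 3", " 4"], "gold": 1}
fewshot = [{"query": "Q: What is 1 + 1?\nA:", "choices": [" 1", " 2"], "gold": 1}]

# One prompt per choice; evaluate_example scores each and picks the lowest mean loss.
for prompt in render_prompts_mc(item, continuation_delimiter="", fewshot_examples=fewshot):
    print(repr(prompt))
```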
nanochat/dataloader.py ADDED
@@ -0,0 +1,87 @@
1
+ from collections import deque
2
+
3
+ import torch
4
+ import pyarrow.parquet as pq
5
+
6
+ from nanochat.common import get_dist_info
7
+ from nanochat.dataset import list_parquet_files
8
+ from nanochat.tokenizer import get_tokenizer
9
+
10
+ def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
11
+ """
12
+ Stream pretraining text from parquet files, tokenize, yield training batches.
13
+
14
+ This implementation became a bit more complex because we wish to support approximate resume training.
15
+ Instead of turning this into a Class, we opt to return the state_dict with every batch,
16
+ and then the caller can pass in a state_dict to resume training from a desired point.
17
+ Note that this resumption is atm only *approximate* for simplicity.
18
+ We won't repeat the same documents but we might skip a few.
19
+ The state_dict that is returned can be later passed into this function via `resume_state_dict` to approximately resume.
20
+
21
+ Perfect state resumption is possible but would be a lot more bloated, probably not worth it atm.
22
+ """
23
+ assert split in ["train", "val"], "split must be 'train' or 'val'"
24
+
25
+ # infinite iterator over document batches (list of text strings)
26
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
27
+ def document_batches():
28
+ parquet_paths = list_parquet_files()
29
+ parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
30
+ resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
31
+ resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
32
+ pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0)
33
+ while True: # iterate infinitely (multi-epoch)
34
+ while pq_idx < len(parquet_paths): # iterate over all parquet files
35
+ filepath = parquet_paths[pq_idx]
36
+ pf = pq.ParquetFile(filepath)
37
+ # Start from resume point if resuming on same file, otherwise from DDP rank
38
+ # I know this state resumption is a little bit tricky and a little bit hacky... sigh.
39
+ if resume_rg_idx is not None:
40
+ base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size
41
+ base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming
42
+ rg_idx = base_idx * ddp_world_size + ddp_rank
43
+ resume_rg_idx = None # set to None as we only want to do this a single time
44
+ else:
45
+ rg_idx = ddp_rank
46
+ while rg_idx < pf.num_row_groups:
47
+ rg = pf.read_row_group(rg_idx)
48
+ batch = rg.column('text').to_pylist() # each batch is a parquet group, e.g. 1024 rows
49
+ # the tokenizer encode might want to go in even smaller batches, e.g. 128 rows
50
+ for i in range(0, len(batch), tokenizer_batch_size):
51
+ yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx)
52
+ rg_idx += ddp_world_size # advance to the next row group (in DDP)
53
+ pq_idx += 1 # advance to the next parquet file
54
+ batches = document_batches()
55
+
56
+ # Now emit batches of tokens.
57
+ needed_tokens = B * T + 1 # +1 is because we also need the target at the last token
58
+ # get the tokenizer and the bos token
59
+ tokenizer = get_tokenizer()
60
+ bos_token = tokenizer.get_bos_token_id()
61
+ # scratch buffer holds the tokens for one iteration
62
+ token_buffer = deque() # we stream tokens on the right and pop from the left
63
+ while True:
64
+ # Accumulate enough tokens for one iteration before yielding.
65
+ while len(token_buffer) < needed_tokens:
66
+ doc_batch, (pq_idx, rg_idx) = next(batches)
67
+ token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
68
+ for tokens in token_lists:
69
+ token_buffer.extend(tokens)
70
+ # Move tokens from the deque into the scratch buffer
71
+ tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
72
+ # CUDA supports memory pinning for asynchronous transfers between CPU and GPU
73
+ use_cuda_optimizations = device == "cuda"
74
+ scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda_optimizations) # in PyTorch, long=int64
75
+ # Create the inputs/targets as 1D tensors
76
+ inputs_cpu = scratch[:-1]
77
+ targets_cpu = scratch[1:]
78
+ # Reshape to 2D and move to GPU async
79
+ inputs = inputs_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
80
+ targets = targets_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
81
+ state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx} # we need this in case we wish to approximately resume training
82
+ yield inputs, targets, state_dict
83
+
84
+ def tokenizing_distributed_data_loader(*args, **kwargs):
85
+ # helper function that only emits the inputs/targets and not the state_dict
86
+ for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs):
87
+ yield inputs, targets
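A consumption sketch for the streaming loader above (it assumes the parquet shards from `nanochat/dataset.py` and the tokenizer assets are already available locally):

```python
from nanochat.dataloader import tokenizing_distributed_data_loader_with_state

loader = tokenizing_distributed_data_loader_with_state(B=4, T=256, split="train", device="cpu")
inputs, targets, state = next(loader)
print(inputs.shape, targets.shape, state)  # (4, 256), (4, 256), {'pq_idx': ..., 'rg_idx': ...}

# Later: resume approximately from the saved state (may skip a few documents, never repeats).
loader = tokenizing_distributed_data_loader_with_state(
    B=4, T=256, split="train", device="cpu", resume_state_dict=state
)
```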
nanochat/dataset.py ADDED
@@ -0,0 +1,128 @@
1
+ """
2
+ The base/pretraining dataset is a set of parquet files.
3
+ This file contains utilities for:
4
+ - iterating over the parquet files and yielding documents from it
5
+ - download the files on demand if they are not on disk
6
+
7
+ For details of how the dataset was prepared, see `repackage_data_reference.py`.
8
+ """
9
+
10
+ import os
11
+ import argparse
12
+ import time
13
+ import requests
14
+ import pyarrow.parquet as pq
15
+ from multiprocessing import Pool
16
+
17
+ from nanochat.common import get_base_dir
18
+
19
+ # -----------------------------------------------------------------------------
20
+ # The specifics of the current pretraining dataset
21
+
22
+ # The URL on the internet where the data is hosted and downloaded from on demand
23
+ BASE_URL = "https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main"
24
+ MAX_SHARD = 1822 # the last datashard is shard_01822.parquet
25
+ index_to_filename = lambda index: f"shard_{index:05d}.parquet" # format of the filenames
26
+ base_dir = get_base_dir()
27
+ DATA_DIR = os.path.join(base_dir, "base_data")
28
+ os.makedirs(DATA_DIR, exist_ok=True)
29
+
30
+ # -----------------------------------------------------------------------------
31
+ # These functions are useful utilities to other modules, can/should be imported
32
+
33
+ def list_parquet_files(data_dir=None):
34
+ """ Looks into a data dir and returns full paths to all parquet files. """
35
+ data_dir = DATA_DIR if data_dir is None else data_dir
36
+ parquet_files = sorted([
37
+ f for f in os.listdir(data_dir)
38
+ if f.endswith('.parquet') and not f.endswith('.tmp')
39
+ ])
40
+ parquet_paths = [os.path.join(data_dir, f) for f in parquet_files]
41
+ return parquet_paths
42
+
43
+ def parquets_iter_batched(split, start=0, step=1):
44
+ """
45
+ Iterate through the dataset, in batches of underlying row_groups for efficiency.
46
+ - split can be "train" or "val". the last parquet file will be val.
47
+ - start/step are useful for skipping rows in DDP. e.g. start=rank, step=world_size
48
+ """
49
+ assert split in ["train", "val"], "split must be 'train' or 'val'"
50
+ parquet_paths = list_parquet_files()
51
+ parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
52
+ for filepath in parquet_paths:
53
+ pf = pq.ParquetFile(filepath)
54
+ for rg_idx in range(start, pf.num_row_groups, step):
55
+ rg = pf.read_row_group(rg_idx)
56
+ texts = rg.column('text').to_pylist()
57
+ yield texts
58
+
59
+ # -----------------------------------------------------------------------------
60
+ def download_single_file(index):
61
+ """ Downloads a single file index, with some backoff """
62
+
63
+ # Construct the local filepath for this file and skip if it already exists
64
+ filename = index_to_filename(index)
65
+ filepath = os.path.join(DATA_DIR, filename)
66
+ if os.path.exists(filepath):
67
+ print(f"Skipping {filepath} (already exists)")
68
+ return True
69
+
70
+ # Construct the remote URL for this file
71
+ url = f"{BASE_URL}/{filename}"
72
+ print(f"Downloading {filename}...")
73
+
74
+ # Download with retries
75
+ max_attempts = 5
76
+ for attempt in range(1, max_attempts + 1):
77
+ try:
78
+ response = requests.get(url, stream=True, timeout=30)
79
+ response.raise_for_status()
80
+ # Write to temporary file first
81
+ temp_path = filepath + f".tmp"
82
+ with open(temp_path, 'wb') as f:
83
+ for chunk in response.iter_content(chunk_size=1024 * 1024): # 1MB chunks
84
+ if chunk:
85
+ f.write(chunk)
86
+ # Move temp file to final location
87
+ os.rename(temp_path, filepath)
88
+ print(f"Successfully downloaded {filename}")
89
+ return True
90
+
91
+ except (requests.RequestException, IOError) as e:
92
+ print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
93
+ # Clean up any partial files
94
+ for path in [filepath + f".tmp", filepath]:
95
+ if os.path.exists(path):
96
+ try:
97
+ os.remove(path)
98
+ except:
99
+ pass
100
+ # Try a few times with exponential backoff: 2^attempt seconds
101
+ if attempt < max_attempts:
102
+ wait_time = 2 ** attempt
103
+ print(f"Waiting {wait_time} seconds before retry...")
104
+ time.sleep(wait_time)
105
+ else:
106
+ print(f"Failed to download {filename} after {max_attempts} attempts")
107
+ return False
108
+
109
+ return False
110
+
111
+
112
+ if __name__ == "__main__":
113
+ parser = argparse.ArgumentParser(description="Download FineWeb-Edu 100BT dataset shards")
114
+ parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of shards to download (default: -1), -1 = disable")
115
+ parser.add_argument("-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)")
116
+ args = parser.parse_args()
117
+
118
+ num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1)
119
+ ids_to_download = list(range(num))
120
+ print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
121
+ print(f"Target directory: {DATA_DIR}")
122
+ print()
123
+ with Pool(processes=args.num_workers) as pool:
124
+ results = pool.map(download_single_file, ids_to_download)
125
+
126
+ # Report results
127
+ successful = sum(1 for success in results if success)
128
+ print(f"Done! Downloaded: {successful}/{len(ids_to_download)} shards to {DATA_DIR}")
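A quick iteration sketch over downloaded shards (the docstring above suggests `start=rank, step=world_size` under DDP; a single process just uses the defaults). Shards can be fetched first with, for example, `python -m nanochat.dataset -n 8`:

```python
from nanochat.dataset import parquets_iter_batched

for texts in parquets_iter_batched("val"):
    print(len(texts), texts[0][:80])  # one row group of documents at a time
    break
```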
nanochat/engine.py ADDED
@@ -0,0 +1,378 @@
1
+ """
2
+ Engine for efficient inference of our models.
3
+
4
+ Everything works around token sequences:
5
+ - The user can send token sequences to the engine
6
+ - The engine returns the next token
7
+
8
+ Notes:
9
+ - The engine knows nothing about tokenization, it's purely token id sequences.
10
+
11
+ The whole thing is made as efficient as possible.
12
+ """
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+ import signal
17
+ import warnings
18
+ from contextlib import contextmanager
19
+ from collections import deque
20
+ from nanochat.common import compute_init, autodetect_device_type
21
+ from nanochat.checkpoint_manager import load_model
22
+ from contextlib import nullcontext
23
+
24
+ # -----------------------------------------------------------------------------
25
+ # Calculator tool helpers
26
+ @contextmanager
27
+ def timeout(duration, formula):
28
+ def timeout_handler(signum, frame):
29
+ raise Exception(f"'{formula}': timed out after {duration} seconds")
30
+
31
+ signal.signal(signal.SIGALRM, timeout_handler)
32
+ signal.alarm(duration)
33
+ yield
34
+ signal.alarm(0)
35
+
36
+ def eval_with_timeout(formula, max_time=3):
37
+ try:
38
+ with timeout(max_time, formula):
39
+ with warnings.catch_warnings():
40
+ warnings.simplefilter("ignore", SyntaxWarning)
41
+ return eval(formula, {"__builtins__": {}}, {})
42
+ except Exception as e:
43
+ signal.alarm(0)
44
+ # print(f"Warning: Failed to eval {formula}, exception: {e}") # it's ok ignore wrong calculator usage
45
+ return None
46
+
47
+ def use_calculator(expr):
48
+ """
49
+ Evaluate a Python expression safely.
50
+ Supports both math expressions and string operations like .count()
51
+ """
52
+ # Remove commas from numbers
53
+ expr = expr.replace(",", "")
54
+
55
+ # Check if it's a pure math expression (old behavior)
56
+ if all([x in "0123456789*+-/.() " for x in expr]):
57
+ if "**" in expr: # disallow power operator
58
+ return None
59
+ return eval_with_timeout(expr)
60
+
61
+ # Check if it's a string operation we support
62
+ # Allow: strings (single/double quotes), .count(), letters, numbers, spaces, parens
63
+ allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
64
+ if not all([x in allowed_chars for x in expr]):
65
+ return None
66
+
67
+ # Disallow dangerous patterns
68
+ dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
69
+ 'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
70
+ 'getattr', 'setattr', 'delattr', 'hasattr']
71
+ expr_lower = expr.lower()
72
+ if any(pattern in expr_lower for pattern in dangerous_patterns):
73
+ return None
74
+
75
+ # Only allow .count() method for now (can expand later)
76
+ if '.count(' not in expr:
77
+ return None
78
+
79
+ # Evaluate with timeout
80
+ return eval_with_timeout(expr)
81
+
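For quick intuition, here is a minimal sketch of what the calculator helper above accepts and rejects (illustrative inputs, not part of the patch; the signal-based timeout is Unix-only, and in practice the expressions come from model-emitted `<|python_start|>...<|python_end|>` blocks):

```python
# Illustrative sketch of use_calculator's behavior (assumed inputs, not from the patch).
from nanochat.engine import use_calculator

assert use_calculator("1,234 + 5") == 1239               # commas stripped, plain arithmetic allowed
assert use_calculator("2 ** 10") is None                 # the power operator is disallowed
assert use_calculator("'strawberry'.count('r')") == 3    # .count() is the only string method allowed
assert use_calculator("__import__('os')") is None        # dangerous patterns are rejected
```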
82
+ # -----------------------------------------------------------------------------
83
+ class KVCache:
84
+ """
85
+ Works hand-in-hand with the GPT model to maintain the KV cache.
86
+ Note that the .pos advances automatically after the last layer of the Transformer inserts.
87
+ """
88
+
89
+ def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
90
+ # Each of K/V is of shape (B, H, T, D) and we have one per layer of the Transformer.
91
+ self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
92
+ self.kv_cache = None
93
+ self.pos = 0 # current position in time in the cache
94
+
95
+ def reset(self):
96
+ self.pos = 0
97
+
98
+ def get_pos(self):
99
+ return self.pos
100
+
101
+ def prefill(self, other):
102
+ """
103
+ Prefill given another KV cache. Optionally expand along batch dim.
104
+ This is used when we do batch 1 prefill and then want to generate
105
+ multiple samples in parallel from there.
106
+ """
107
+ # 1) validate the shapes
108
+ assert self.kv_cache is None, "Cannot prefill a non-empty KV cache"
109
+ assert other.kv_cache is not None, "Cannot prefill with a None KV cache"
110
+ for ix, (dim1, dim2) in enumerate(zip(self.kv_shape, other.kv_shape)):
111
+ # ix 0: num_layers, 1: k/v, 2: batch_size, 3: num_heads, 4: seq_len, 5: head_dim
112
+ if ix in [0, 1, 3, 5]:
113
+ # num_layers, k/v, num_heads, head_dim must match
114
+ assert dim1 == dim2, f"Dim {ix} mismatch: {dim1} != {dim2}"
115
+ elif ix == 2:
116
+ # batch_size can be expanded
117
+ assert dim1 == dim2 or dim2 == 1, f"Batch dim mismatch: {dim1} != {dim2}"
118
+ elif ix == 4:
119
+ # seq_len: self must be longer than other
120
+ assert dim1 >= dim2, f"Seq len mismatch: {dim1} < {dim2}"
121
+ # 2) initialize the cache
122
+ dtype, device = other.kv_cache.dtype, other.kv_cache.device
123
+ self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)
124
+ # 3) copy the data over
125
+ self.kv_cache[:, :, :, :, :other.pos, :] = other.kv_cache
126
+ # 4) update the pos
127
+ self.pos = other.pos
128
+
129
+ def insert_kv(self, layer_idx, k, v):
130
+ # Lazy initialize the cache here because we need to know the dtype/device
131
+ if self.kv_cache is None:
132
+ self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)
133
+ # Insert new keys/values to the cache and return the full cache so far
134
+ B, H, T_add, D = k.size()
135
+ t0, t1 = self.pos, self.pos + T_add
136
+ # Dynamically grow the cache if needed
137
+ if t1 > self.kv_cache.size(4):
138
+ t_needed = t1 + 1024 # as much as we need plus buffer of 1024
139
+ t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024
140
+ additional_shape = list(self.kv_cache.shape)
141
+ additional_shape[4] = t_needed - self.kv_cache.size(4)
142
+ additional_cache = torch.empty(additional_shape, dtype=k.dtype, device=k.device)
143
+ self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4).contiguous()
144
+ self.kv_shape = self.kv_cache.shape
145
+ # Insert k, v into the cache
146
+ self.kv_cache[layer_idx, 0, :, :, t0:t1] = k
147
+ self.kv_cache[layer_idx, 1, :, :, t0:t1] = v
148
+ # Return the full cached keys/values up to current position (as a view)
149
+ key_view = self.kv_cache[layer_idx, 0, :, :, :t1]
150
+ value_view = self.kv_cache[layer_idx, 1, :, :, :t1]
151
+ # Increment pos after the last layer of the Transformer processes
152
+ if layer_idx == self.kv_cache.size(0) - 1:
153
+ self.pos = t1
154
+ return key_view, value_view
155
+
156
+
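As a sketch of the bookkeeping (dummy tensors, made-up sizes): the cache is allocated lazily on the first insert, and `.pos` only advances once the final layer has inserted its keys/values.

```python
# Minimal KVCache bookkeeping sketch with dummy tensors (sizes are made up).
import torch
from nanochat.engine import KVCache

B, H, T, D, L = 1, 4, 8, 16, 2
cache = KVCache(batch_size=B, num_heads=H, seq_len=T, head_dim=D, num_layers=L)
k = torch.zeros(B, H, 3, D)  # pretend keys for 3 new time steps
v = torch.zeros(B, H, 3, D)
cache.insert_kv(0, k, v)     # first layer inserts: position does not advance yet
assert cache.get_pos() == 0
cache.insert_kv(1, k, v)     # last layer inserts: position advances to 3
assert cache.get_pos() == 3
```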
157
+ # -----------------------------------------------------------------------------
158
+ @torch.inference_mode()
159
+ def sample_next_token(logits, rng, temperature=1.0, top_k=None):
160
+ """Sample a single next token from given logits of shape (B, vocab_size). Returns (B, 1)."""
161
+ assert temperature >= 0.0, "temperature must be non-negative"
162
+ if temperature == 0.0:
163
+ return torch.argmax(logits, dim=-1, keepdim=True)
164
+ if top_k is not None:
165
+ k = min(top_k, logits.size(-1))
166
+ vals, idx = torch.topk(logits, k, dim=-1)
167
+ vals = vals / temperature
168
+ probs = F.softmax(vals, dim=-1)
169
+ choice = torch.multinomial(probs, num_samples=1, generator=rng)
170
+ return idx.gather(1, choice)
171
+ else:
172
+ logits = logits / temperature
173
+ probs = F.softmax(logits, dim=-1)
174
+ return torch.multinomial(probs, num_samples=1, generator=rng)
175
+
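A tiny sketch of the two sampling modes (made-up logits): temperature 0 reduces to argmax, while top-k restricts the multinomial draw to the k highest-scoring tokens.

```python
# Sketch: greedy vs. top-k sampling from a toy logits row (values are made up).
import torch
from nanochat.engine import sample_next_token

logits = torch.tensor([[0.1, 2.0, 0.5, -1.0]])  # (B=1, vocab_size=4)
rng = torch.Generator().manual_seed(0)
greedy = sample_next_token(logits, rng, temperature=0.0)         # argmax -> token 1
topk = sample_next_token(logits, rng, temperature=1.0, top_k=2)  # draws only from tokens 1 and 2
assert greedy.item() == 1 and topk.item() in (1, 2)
```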
176
+ # -----------------------------------------------------------------------------
177
+
178
+ class RowState:
179
+ # Per-row state tracking during generation
180
+ def __init__(self, current_tokens=None):
181
+ self.current_tokens = current_tokens or [] # Current token sequence for this row
182
+ self.forced_tokens = deque() # Queue of tokens to force inject
183
+ self.in_python_block = False # Whether we are inside a python block
184
+ self.python_expr_tokens = [] # Tokens of the current python expression
185
+ self.completed = False # Whether this row has completed generation
186
+
187
+ class Engine:
188
+
189
+ def __init__(self, model, tokenizer):
190
+ self.model = model
191
+ self.tokenizer = tokenizer # needed for tool use
192
+
193
+ @torch.inference_mode()
194
+ def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
195
+ """Same as generate, but does single prefill and then clones the KV cache."""
196
+ assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
197
+ device = self.model.get_device()
198
+ rng = torch.Generator(device=device)
199
+ rng.manual_seed(seed)
200
+
201
+ # Get the special tokens we need to coordinate the tool use state machine
202
+ get_special = lambda s: self.tokenizer.encode_special(s)
203
+ python_start = get_special("<|python_start|>")
204
+ python_end = get_special("<|python_end|>")
205
+ output_start = get_special("<|output_start|>")
206
+ output_end = get_special("<|output_end|>")
207
+ assistant_end = get_special("<|assistant_end|>") # if sampled, ends row
208
+ bos = self.tokenizer.get_bos_token_id() # if sampled, ends row
209
+
210
+ # 1) Run a batch 1 prefill of the prompt tokens
211
+ m = self.model.config
212
+ kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
213
+ kv_cache_prefill = KVCache(
214
+ batch_size=1,
215
+ seq_len=len(tokens),
216
+ **kv_model_kwargs,
217
+ )
218
+ ids = torch.tensor([tokens], dtype=torch.long, device=device)
219
+ logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
220
+ logits = logits[:, -1, :]
221
+ next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
222
+ sampled_tokens = next_ids[:, 0].tolist()
223
+
224
+ # 2) Replicate the KV cache for each sample/row
225
+ kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
226
+ kv_cache_decode = KVCache(
227
+ batch_size=num_samples,
228
+ seq_len=kv_length_hint,
229
+ **kv_model_kwargs,
230
+ )
231
+ kv_cache_decode.prefill(kv_cache_prefill)
232
+ del kv_cache_prefill # no need to keep this memory around
233
+
234
+ # 3) Initialize states for each sample
235
+ row_states = [RowState(tokens.copy()) for _ in range(num_samples)]
236
+
237
+ # 4) Main generation loop
238
+ num_generated = 0
239
+ first_iteration = True
240
+ while True:
241
+ # Stop condition: we've reached max tokens
242
+ if max_tokens is not None and num_generated >= max_tokens:
243
+ break
244
+ # Stop condition: all rows are completed
245
+ if all(state.completed for state in row_states):
246
+ break
247
+
248
+ # Get sampled tokens - either from prefill or from forward pass
249
+ if first_iteration:
250
+ # Use the tokens we already sampled from prefill
251
+ sampled_tokens = [sampled_tokens[0]] * num_samples # Broadcast first token to all rows
252
+ # TODO: we should sample a token for each row instead of broadcasting
253
+ first_iteration = False
254
+ else:
255
+ # Forward the model and get the next token for each row
256
+ logits = self.model.forward(ids, kv_cache=kv_cache_decode) # (B, T, vocab_size)
257
+ logits = logits[:, -1, :] # (B, vocab_size) at last time step
258
+ next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
259
+ sampled_tokens = next_ids[:, 0].tolist()
260
+
261
+ # Process each row: choose the next token, update state, optional tool use
262
+ token_column = [] # contains the next token id along each row
263
+ token_masks = [] # contains the mask (was it sampled (1) or forced (0)?) along each row
264
+ for i, state in enumerate(row_states):
265
+ # Select the next token in this row
266
+ is_forced = len(state.forced_tokens) > 0 # are there tokens waiting to be forced in deque?
267
+ token_masks.append(0 if is_forced else 1) # mask is 0 if forced, 1 if sampled
268
+ next_token = state.forced_tokens.popleft() if is_forced else sampled_tokens[i]
269
+ token_column.append(next_token)
270
+ # Update the state of this row to include the next token
271
+ state.current_tokens.append(next_token)
272
+ # On <|assistant_end|> or <|bos|>, mark the row as completed
273
+ if next_token == assistant_end or next_token == bos:
274
+ state.completed = True
275
+ # Handle tool logic
276
+ if next_token == python_start:
277
+ state.in_python_block = True
278
+ state.python_expr_tokens = []
279
+ elif next_token == python_end and state.in_python_block:
280
+ state.in_python_block = False
281
+ if state.python_expr_tokens:
282
+ expr = self.tokenizer.decode(state.python_expr_tokens)
283
+ result = use_calculator(expr)
284
+ if result is not None:
285
+ result_tokens = self.tokenizer.encode(str(result))
286
+ state.forced_tokens.append(output_start)
287
+ state.forced_tokens.extend(result_tokens)
288
+ state.forced_tokens.append(output_end)
289
+ state.python_expr_tokens = []
290
+ elif state.in_python_block:
291
+ state.python_expr_tokens.append(next_token)
292
+
293
+ # Yield the token column
294
+ yield token_column, token_masks
295
+ num_generated += 1
296
+ # Prepare ids for next iteration
297
+ ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)
298
+
299
+ def generate_batch(self, tokens, num_samples=1, **kwargs):
300
+ """
301
+ Non-streaming batch generation that just returns the final token sequences.
302
+ Returns a list of token sequences (list of lists of ints).
303
+ Terminal tokens (assistant_end, bos) are not included in the results.
304
+ """
305
+ assistant_end = self.tokenizer.encode_special("<|assistant_end|>")
306
+ bos = self.tokenizer.get_bos_token_id()
307
+ results = [tokens.copy() for _ in range(num_samples)]
308
+ masks = [[0] * len(tokens) for _ in range(num_samples)]
309
+ completed = [False] * num_samples
310
+ for token_column, token_masks in self.generate(tokens, num_samples, **kwargs):
311
+ for i, (token, mask) in enumerate(zip(token_column, token_masks)):
312
+ if not completed[i]:
313
+ if token == assistant_end or token == bos:
314
+ completed[i] = True
315
+ else:
316
+ results[i].append(token)
317
+ masks[i].append(mask)
318
+ # Stop if all rows are completed
319
+ if all(completed):
320
+ break
321
+ return results, masks
322
+
323
+
324
+ if __name__ == "__main__":
325
+ """
326
+ Quick inline test to make sure that the naive/slow model.generate function
327
+ is equivalent to the faster Engine.generate function here.
328
+ """
329
+ import time
330
+ # init compute
331
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
332
+ device_type = autodetect_device_type()
333
+ autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
334
+
335
+ # load the model and tokenizer
336
+ model, tokenizer, meta = load_model("base", device, phase="eval")
337
+ bos_token_id = tokenizer.get_bos_token_id()
338
+ # common hyperparameters
339
+ kwargs = dict(max_tokens=64, temperature=0.0)
340
+ # set the starting prompt
341
+ prompt_tokens = tokenizer.encode("The chemical formula of water is", prepend=bos_token_id)
342
+ # generate the reference sequence using the model.generate() function
343
+ generated_tokens = []
344
+ torch.cuda.synchronize()
345
+ t0 = time.time()
346
+ stream = model.generate(prompt_tokens, **kwargs)
347
+ with autocast_ctx:
348
+ for token in stream:
349
+ generated_tokens.append(token)
350
+ chunk = tokenizer.decode([token])
351
+ print(chunk, end="", flush=True)
352
+ print()
353
+ torch.cuda.synchronize()
354
+ t1 = time.time()
355
+ print(f"Reference time: {t1 - t0:.2f}s")
356
+ reference_ids = generated_tokens
357
+ # generate tokens with Engine
358
+ generated_tokens = []
359
+ engine = Engine(model, tokenizer)
360
+ stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
361
+ torch.cuda.synchronize()
362
+ t0 = time.time()
363
+ with autocast_ctx:
364
+ for token_column, token_masks in stream:
365
+ token = token_column[0] # only print out the first row
366
+ generated_tokens.append(token)
367
+ chunk = tokenizer.decode([token])
368
+ print(chunk, end="", flush=True)
369
+ print()
370
+ torch.cuda.synchronize()
371
+ t1 = time.time()
372
+ print(f"Engine time: {t1 - t0:.2f}s")
373
+ # compare the two sequences
374
+ for i in range(len(reference_ids)):
375
+ if reference_ids[i] != generated_tokens[i]:
376
+ print(f"Mismatch at {i}: {reference_ids[i]} != {generated_tokens[i]}")
377
+ break
378
+ print(f"Match: {reference_ids == generated_tokens}")
nanochat/execution.py ADDED
@@ -0,0 +1,349 @@
1
+ """
2
+ Sandboxed execution utilities for running Python code that comes out of an LLM.
3
+ Adapted from OpenAI HumanEval code:
4
+ https://github.com/openai/human-eval/blob/master/human_eval/execution.py
5
+
6
+ What is covered:
7
+ - Each execution runs in its own process (can be killed if it hangs or crashes)
8
+ - Execution is limited by a timeout to stop infinite loops
9
+ - Memory limits are enforced by default (256MB)
10
+ - stdout and stderr are captured and returned
11
+ - Code runs in a temporary directory that is deleted afterwards
12
+ - Dangerous functions are disabled (examples: os.system, os.kill, shutil.rmtree, subprocess.Popen)
13
+
14
+ What is not covered:
15
+ - Not a true security sandbox
16
+ - Network access is not blocked (e.g. sockets could be opened)
17
+ - Python's dynamic features (e.g. ctypes) could bypass restrictions
18
+ - No kernel-level isolation (no seccomp, no containers, no virtualization)
19
+
20
+ Overall this sandbox is good for evaluation of generated code and protects against
21
+ accidental destructive behavior, but it is not safe against malicious adversarial code.
22
+ """
23
+
24
+ import contextlib
25
+ import faulthandler
26
+ import io
27
+ import multiprocessing
28
+ import os
29
+ import platform
30
+ import signal
31
+ import tempfile
32
+ from dataclasses import dataclass
33
+ from typing import Optional
34
+
35
+ # -----------------------------------------------------------------------------
36
+
37
+ @dataclass
38
+ class ExecutionResult:
39
+ """Result of executing Python code in a sandbox."""
40
+ success: bool
41
+ stdout: str
42
+ stderr: str
43
+ error: Optional[str] = None
44
+ timeout: bool = False
45
+ memory_exceeded: bool = False
46
+
47
+ def __repr__(self):
48
+ parts = []
49
+ parts.append(f"ExecutionResult(success={self.success}")
50
+ if self.timeout:
51
+ parts.append(", timeout=True")
52
+ if self.memory_exceeded:
53
+ parts.append(", memory_exceeded=True")
54
+ if self.error:
55
+ parts.append(f", error={self.error!r}")
56
+ if self.stdout:
57
+ parts.append(f", stdout={self.stdout!r}")
58
+ if self.stderr:
59
+ parts.append(f", stderr={self.stderr!r}")
60
+ parts.append(")")
61
+ return "".join(parts)
62
+
63
+
64
+ @contextlib.contextmanager
65
+ def time_limit(seconds: float):
66
+ def signal_handler(signum, frame):
67
+ raise TimeoutException("Timed out!")
68
+
69
+ signal.setitimer(signal.ITIMER_REAL, seconds)
70
+ signal.signal(signal.SIGALRM, signal_handler)
71
+ try:
72
+ yield
73
+ finally:
74
+ signal.setitimer(signal.ITIMER_REAL, 0)
75
+
76
+
77
+ @contextlib.contextmanager
78
+ def capture_io():
79
+ """Capture stdout and stderr, and disable stdin."""
80
+ stdout_capture = io.StringIO()
81
+ stderr_capture = io.StringIO()
82
+ stdin_block = WriteOnlyStringIO()
83
+ with contextlib.redirect_stdout(stdout_capture):
84
+ with contextlib.redirect_stderr(stderr_capture):
85
+ with redirect_stdin(stdin_block):
86
+ yield stdout_capture, stderr_capture
87
+
88
+
89
+ @contextlib.contextmanager
90
+ def create_tempdir():
91
+ with tempfile.TemporaryDirectory() as dirname:
92
+ with chdir(dirname):
93
+ yield dirname
94
+
95
+
96
+ class TimeoutException(Exception):
97
+ pass
98
+
99
+
100
+ class WriteOnlyStringIO(io.StringIO):
101
+ """StringIO that throws an exception when it's read from"""
102
+
103
+ def read(self, *args, **kwargs):
104
+ raise IOError
105
+
106
+ def readline(self, *args, **kwargs):
107
+ raise IOError
108
+
109
+ def readlines(self, *args, **kwargs):
110
+ raise IOError
111
+
112
+ def readable(self, *args, **kwargs):
113
+ """Returns True if the IO object can be read."""
114
+ return False
115
+
116
+
117
+ class redirect_stdin(contextlib._RedirectStream): # type: ignore
118
+ _stream = "stdin"
119
+
120
+
121
+ @contextlib.contextmanager
122
+ def chdir(root):
123
+ if root == ".":
124
+ yield
125
+ return
126
+ cwd = os.getcwd()
127
+ os.chdir(root)
128
+ try:
129
+ yield
130
+ finally:
131
+ os.chdir(cwd)
132
+
133
+
134
+ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
135
+ """
136
+ This disables various destructive functions and prevents the generated code
137
+ from interfering with the test (e.g. fork bomb, killing other processes,
138
+ removing filesystem files, etc.)
139
+
140
+ WARNING
141
+ This function is NOT a security sandbox. Untrusted code, including model-
142
+ generated code, should not be blindly executed outside of one. See the
143
+ Codex paper for more information about OpenAI's code sandbox, and proceed
144
+ with caution.
145
+ """
146
+
147
+ if maximum_memory_bytes is not None and platform.uname().system != "Darwin":
148
+ # The resource.setrlimit calls below fail on macOS (Darwin), so skip them there
149
+ import resource
150
+ resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
151
+ resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
152
+ resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
153
+
154
+ faulthandler.disable()
155
+
156
+ import builtins
157
+
158
+ builtins.exit = None
159
+ builtins.quit = None
160
+
161
+ import os
162
+
163
+ os.environ["OMP_NUM_THREADS"] = "1"
164
+
165
+ os.kill = None
166
+ os.system = None
167
+ os.putenv = None
168
+ os.remove = None
169
+ os.removedirs = None
170
+ os.rmdir = None
171
+ os.fchdir = None
172
+ os.setuid = None
173
+ os.fork = None
174
+ os.forkpty = None
175
+ os.killpg = None
176
+ os.rename = None
177
+ os.renames = None
178
+ os.truncate = None
179
+ os.replace = None
180
+ os.unlink = None
181
+ os.fchmod = None
182
+ os.fchown = None
183
+ os.chmod = None
184
+ os.chown = None
185
+ os.chroot = None
186
+ os.fchdir = None
187
+ os.lchflags = None
188
+ os.lchmod = None
189
+ os.lchown = None
190
+ os.getcwd = None
191
+ os.chdir = None
192
+
193
+ import shutil
194
+
195
+ shutil.rmtree = None
196
+ shutil.move = None
197
+ shutil.chown = None
198
+
199
+ import subprocess
200
+
201
+ subprocess.Popen = None # type: ignore
202
+
203
+ __builtins__["help"] = None
204
+
205
+ import sys
206
+
207
+ sys.modules["ipdb"] = None
208
+ sys.modules["joblib"] = None
209
+ sys.modules["resource"] = None
210
+ sys.modules["psutil"] = None
211
+ sys.modules["tkinter"] = None
212
+
213
+
214
+ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[int], result_dict):
215
+ """Execute code in a subprocess with safety guards. Results are written to result_dict."""
216
+ with create_tempdir():
217
+
218
+ # These system calls are needed when cleaning up tempdir.
219
+ import os
220
+ import shutil
221
+
222
+ rmtree = shutil.rmtree
223
+ rmdir = os.rmdir
224
+ chdir = os.chdir
225
+ unlink = os.unlink
226
+
227
+ # Disable functionalities that can make destructive changes to the test.
228
+ reliability_guard(maximum_memory_bytes=maximum_memory_bytes)
229
+
230
+ # Default to failure
231
+ result_dict.update({
232
+ "success": False,
233
+ "stdout": "",
234
+ "stderr": "",
235
+ "timeout": False,
236
+ "memory_exceeded": False,
237
+ "error": None,
238
+ })
239
+
240
+ try:
241
+ exec_globals = {}
242
+ with capture_io() as (stdout_capture, stderr_capture):
243
+ with time_limit(timeout):
244
+ # WARNING
245
+ # This program exists to execute untrusted model-generated code. Although
246
+ # it is highly unlikely that model-generated code will do something overtly
247
+ # malicious in response to this test suite, model-generated code may act
248
+ # destructively due to a lack of model capability or alignment.
249
+ # Users are strongly encouraged to sandbox this evaluation suite so that it
250
+ # does not perform destructive actions on their host or network. For more
251
+ # information on how OpenAI sandboxes its code, see the accompanying paper.
252
+ # Once you have read this disclaimer and taken appropriate precautions,
253
+ # the line below executes the model-generated code, so proceed at your own risk:
254
+ exec(code, exec_globals)
255
+
256
+ result_dict.update({
257
+ "success": True,
258
+ "stdout": stdout_capture.getvalue(),
259
+ "stderr": stderr_capture.getvalue(),
260
+ })
261
+
262
+ except TimeoutException:
263
+ result_dict.update({
264
+ "timeout": True,
265
+ "error": "Execution timed out",
266
+ })
267
+
268
+ except MemoryError as e:
269
+ result_dict.update({
270
+ "memory_exceeded": True,
271
+ "error": f"Memory limit exceeded: {e}",
272
+ })
273
+
274
+ except BaseException as e:
275
+ result_dict.update({
276
+ "error": f"{type(e).__name__}: {e}",
277
+ })
278
+
279
+ # Needed for cleaning up.
280
+ shutil.rmtree = rmtree
281
+ os.rmdir = rmdir
282
+ os.chdir = chdir
283
+ os.unlink = unlink
284
+
285
+
286
+ def execute_code(
287
+ code: str,
288
+ timeout: float = 5.0, # 5 seconds default
289
+ maximum_memory_bytes: Optional[int] = 256 * 1024 * 1024, # 256MB default
290
+ ) -> ExecutionResult:
291
+ """
292
+ Execute Python code in a sandboxed environment.
293
+
294
+ Args:
295
+ code: Python code to execute as a string
296
+ timeout: Maximum execution time in seconds (default: 5.0)
297
+ maximum_memory_bytes: Memory limit in bytes (default: 256MB, None to disable)
298
+
299
+ Returns:
300
+ ExecutionResult with success status, stdout/stderr, and error information
301
+
302
+ Example:
303
+ >>> result = execute_code("print('hello world')")
304
+ >>> result.success
305
+ True
306
+ >>> result.stdout
307
+ 'hello world\\n'
308
+ """
309
+
310
+ manager = multiprocessing.Manager()
311
+ result_dict = manager.dict()
312
+
313
+ p = multiprocessing.Process(
314
+ target=_unsafe_execute,
315
+ args=(code, timeout, maximum_memory_bytes, result_dict)
316
+ )
317
+ p.start()
318
+ p.join(timeout=timeout + 1)
319
+
320
+ if p.is_alive():
321
+ p.kill()
322
+ return ExecutionResult(
323
+ success=False,
324
+ stdout="",
325
+ stderr="",
326
+ error="Execution timed out (process killed)",
327
+ timeout=True,
328
+ memory_exceeded=False,
329
+ )
330
+
331
+ if not result_dict:
332
+ return ExecutionResult(
333
+ success=False,
334
+ stdout="",
335
+ stderr="",
336
+ error="Execution failed (no result returned)",
337
+ timeout=True,
338
+ memory_exceeded=False,
339
+ )
340
+
341
+ return ExecutionResult(
342
+ success=result_dict["success"],
343
+ stdout=result_dict["stdout"],
344
+ stderr=result_dict["stderr"],
345
+ error=result_dict["error"],
346
+ timeout=result_dict["timeout"],
347
+ memory_exceeded=result_dict["memory_exceeded"],
348
+ )
349
+
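A short sketch of how results come back from the sandbox (no exceptions are raised to the caller; failures are reported in the `ExecutionResult` fields):

```python
# Sketch: the sandbox returns structured results instead of raising.
from nanochat.execution import execute_code

ok = execute_code("print(1 + 1)")
print(ok.success, repr(ok.stdout))   # True '2\n'

hung = execute_code("while True: pass", timeout=1.0)
print(hung.success, hung.timeout)    # False True
```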
nanochat/gpt.py ADDED
@@ -0,0 +1,313 @@
1
+ """
2
+ GPT model (rewrite, a lot simpler)
3
+ Notable features:
4
+ - rotary embeddings (and no positional embeddings)
5
+ - QK norm
6
+ - untied weights for token embedding and lm_head
7
+ - relu^2 activation in MLP
8
+ - norm after token embedding
9
+ - no learnable params in rmsnorm
10
+ - no bias in linear layers
11
+ - Group-Query Attention (GQA) support for more efficient inference
12
+ """
13
+
14
+ import math
15
+ from functools import partial
16
+ from dataclasses import dataclass
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+ from nanochat.common import get_dist_info, print0
23
+ from nanochat.muon import Muon, DistMuon
24
+ from nanochat.adamw import DistAdamW
25
+
26
+ @dataclass
27
+ class GPTConfig:
28
+ sequence_len: int = 1024
29
+ vocab_size: int = 50304
30
+ n_layer: int = 12
31
+ n_head: int = 6 # number of query heads
32
+ n_kv_head: int = 6 # number of key/value heads (GQA)
33
+ n_embd: int = 768
34
+
35
+
36
+ def norm(x):
37
+ # Purely functional rmsnorm with no learnable params
38
+ return F.rms_norm(x, (x.size(-1),))
39
+
40
+
41
+ def apply_rotary_emb(x, cos, sin):
42
+ assert x.ndim == 4 # multihead attention
43
+ d = x.shape[3] // 2
44
+ x1, x2 = x[..., :d], x[..., d:] # split the last dim into two halves
45
+ y1 = x1 * cos + x2 * sin # rotate pairs of dims
46
+ y2 = x1 * (-sin) + x2 * cos
47
+ out = torch.cat([y1, y2], 3) # re-assemble
48
+ out = out.to(x.dtype) # ensure input/output dtypes match
49
+ return out
50
+
51
+ class CausalSelfAttention(nn.Module):
52
+ def __init__(self, config, layer_idx):
53
+ super().__init__()
54
+ self.layer_idx = layer_idx
55
+ self.n_head = config.n_head
56
+ self.n_kv_head = config.n_kv_head
57
+ self.n_embd = config.n_embd
58
+ self.head_dim = self.n_embd // self.n_head
59
+ assert self.n_embd % self.n_head == 0
60
+ assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
61
+ self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
62
+ self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
63
+ self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
64
+ self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
65
+
66
+ def forward(self, x, cos_sin, kv_cache):
67
+ B, T, C = x.size()
68
+
69
+ # Project the input to get queries, keys, and values
70
+ q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
71
+ k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
72
+ v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
73
+
74
+ # Apply Rotary Embeddings to queries and keys to get relative positional encoding
75
+ cos, sin = cos_sin
76
+ q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
77
+ q, k = norm(q), norm(k) # QK norm
78
+ q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
79
+
80
+ # Apply KV cache: insert current k,v into cache, get the full view so far
81
+ if kv_cache is not None:
82
+ k, v = kv_cache.insert_kv(self.layer_idx, k, v)
83
+ Tq = q.size(2) # number of queries in this forward pass
84
+ Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)
85
+
86
+ # Attention: queries attend to keys/values autoregressively. A few cases to handle:
87
+ enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
88
+ if kv_cache is None or Tq == Tk:
89
+ # During training (no KV cache), attend as usual with causal attention
90
+ # And even if there is KV cache, we can still use this simple version when Tq == Tk
91
+ y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
92
+ elif Tq == 1:
93
+ # During inference but with a single query in this forward pass:
94
+ # The query has to attend to all the keys/values in the cache
95
+ y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
96
+ else:
97
+ # During inference AND we have a chunk of queries in this forward pass:
98
+ # First, each query attends to all the cached keys/values (i.e. full prefix)
99
+ attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
100
+ prefix_len = Tk - Tq
101
+ if prefix_len > 0: # can't be negative but could be zero
102
+ attn_mask[:, :prefix_len] = True
103
+ # Then, causal attention within this chunk
104
+ attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
105
+ y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
106
+
107
+ # Re-assemble the heads side by side and project back to residual stream
108
+ y = y.transpose(1, 2).contiguous().view(B, T, -1)
109
+ y = self.c_proj(y)
110
+ return y
111
+
112
+
113
+ class MLP(nn.Module):
114
+ def __init__(self, config):
115
+ super().__init__()
116
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
117
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
118
+
119
+ def forward(self, x):
120
+ x = self.c_fc(x)
121
+ x = F.relu(x).square()
122
+ x = self.c_proj(x)
123
+ return x
124
+
125
+
126
+ class Block(nn.Module):
127
+ def __init__(self, config, layer_idx):
128
+ super().__init__()
129
+ self.attn = CausalSelfAttention(config, layer_idx)
130
+ self.mlp = MLP(config)
131
+
132
+ def forward(self, x, cos_sin, kv_cache):
133
+ x = x + self.attn(norm(x), cos_sin, kv_cache)
134
+ x = x + self.mlp(norm(x))
135
+ return x
136
+
137
+
138
+ class GPT(nn.Module):
139
+ def __init__(self, config):
140
+ super().__init__()
141
+ self.config = config
142
+ self.transformer = nn.ModuleDict({
143
+ "wte": nn.Embedding(config.vocab_size, config.n_embd),
144
+ "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
145
+ })
146
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
147
+ # To support meta device initialization, we init the rotary embeddings here, but they get recomputed for real in init_weights()
148
+ # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
149
+ # so let's just over-compute them, but assert fail if we ever reach that amount.
150
+ # In the future we can dynamically grow the cache, for now it's fine.
151
+ self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
152
+ head_dim = config.n_embd // config.n_head
153
+ cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
154
+ self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
155
+ self.register_buffer("sin", sin, persistent=False)
156
+
157
+ def init_weights(self):
158
+ self.apply(self._init_weights)
159
+ # zero out classifier weights
160
+ torch.nn.init.zeros_(self.lm_head.weight)
161
+ # zero out c_proj weights in all blocks
162
+ for block in self.transformer.h:
163
+ torch.nn.init.zeros_(block.mlp.c_proj.weight)
164
+ torch.nn.init.zeros_(block.attn.c_proj.weight)
165
+ # init the rotary embeddings
166
+ head_dim = self.config.n_embd // self.config.n_head
167
+ cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
168
+ self.cos, self.sin = cos, sin
169
+ # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
170
+ if self.transformer.wte.weight.device.type == "cuda":
171
+ self.transformer.wte.to(dtype=torch.bfloat16)
172
+
173
+ def _init_weights(self, module):
174
+ if isinstance(module, nn.Linear):
175
+ # https://arxiv.org/pdf/2310.17813
176
+ fan_out = module.weight.size(0)
177
+ fan_in = module.weight.size(1)
178
+ std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
179
+ torch.nn.init.normal_(module.weight, mean=0.0, std=std)
180
+ if module.bias is not None:
181
+ torch.nn.init.zeros_(module.bias)
182
+ elif isinstance(module, nn.Embedding):
183
+ torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
184
+
185
+ # TODO: bump base theta more, e.g. 100K is more common more recently
186
+ def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
187
+ # autodetect the device from model embeddings
188
+ if device is None:
189
+ device = self.transformer.wte.weight.device
190
+ # stride the channels
191
+ channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
192
+ inv_freq = 1.0 / (base ** (channel_range / head_dim))
193
+ # stride the time steps
194
+ t = torch.arange(seq_len, dtype=torch.float32, device=device)
195
+ # calculate the rotation frequencies at each (time, channel) pair
196
+ freqs = torch.outer(t, inv_freq)
197
+ cos, sin = freqs.cos(), freqs.sin()
198
+ # Use bfloat16 on CUDA for memory savings, float32 on MPS/CPU for stability
199
+ if device.type == "cuda":
200
+ cos, sin = cos.bfloat16(), sin.bfloat16()
201
+ else:
202
+ cos, sin = cos.float(), sin.float()
203
+ cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
204
+ return cos, sin
205
+
206
+ def get_device(self):
207
+ return self.transformer.wte.weight.device
208
+
209
+ def estimate_flops(self):
210
+ """ Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311 """
211
+ nparams = sum(p.numel() for p in self.parameters())
212
+ nparams_embedding = self.transformer.wte.weight.numel()
213
+ l, h, q, t = self.config.n_layer, self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
214
+ num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
215
+ return num_flops_per_token
216
+
217
+ def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
218
+ model_dim = self.config.n_embd
219
+ ddp, rank, local_rank, world_size = get_dist_info()
220
+ # Separate out all parameters into 3 groups (matrix, embedding, lm_head)
221
+ matrix_params = list(self.transformer.h.parameters())
222
+ embedding_params = list(self.transformer.wte.parameters())
223
+ lm_head_params = list(self.lm_head.parameters())
224
+ assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params)
225
+ # Create the AdamW optimizer for the embedding and lm_head
226
+ # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
227
+ dmodel_lr_scale = (model_dim / 768) ** -0.5
228
+ if rank == 0:
229
+ print(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
230
+ adam_groups = [
231
+ dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
232
+ dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
233
+ ]
234
+ adamw_kwargs = dict(betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay)
235
+ AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
236
+ adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
237
+ # Create the Muon optimizer for the linear layers
238
+ muon_kwargs = dict(lr=matrix_lr, momentum=0.95)
239
+ MuonFactory = DistMuon if ddp else Muon
240
+ muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
241
+ # Combine the two optimizers into one list
242
+ optimizers = [adamw_optimizer, muon_optimizer]
243
+ for opt in optimizers:
244
+ for group in opt.param_groups:
245
+ group["initial_lr"] = group["lr"]
246
+ return optimizers
247
+
248
+ def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
249
+ B, T = idx.size()
250
+
251
+ # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
252
+ assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
253
+ assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
254
+ # Verify cos/sin dtype is correct for device (bfloat16 on CUDA, float32 on MPS/CPU)
255
+ expected_dtype = torch.bfloat16 if idx.device.type == "cuda" else torch.float32
256
+ assert self.cos.dtype == expected_dtype, f"Rotary embeddings dtype mismatch: expected {expected_dtype}, got {self.cos.dtype}"
257
+ # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
258
+ T0 = 0 if kv_cache is None else kv_cache.get_pos()
259
+ cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
260
+
261
+ # Forward the trunk of the Transformer
262
+ x = self.transformer.wte(idx)
263
+ x = norm(x)
264
+ for block in self.transformer.h:
265
+ x = block(x, cos_sin, kv_cache)
266
+ x = norm(x)
267
+
268
+ # Forward the lm_head (compute logits)
269
+ softcap = 15
270
+ if targets is not None:
271
+ # training mode: compute and return the loss
272
+ # TODO: experiment with Liger Kernels / chunked cross-entropy etc.
273
+ logits = self.lm_head(x)
274
+ logits = softcap * torch.tanh(logits / softcap) # logits softcap
275
+ logits = logits.float() # use tf32/fp32 for logits
276
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
277
+ return loss
278
+ else:
279
+ # inference mode: compute and return the logits
280
+ logits = self.lm_head(x)
281
+ logits = softcap * torch.tanh(logits / softcap) # logits softcap
282
+ return logits
283
+
284
+ @torch.inference_mode()
285
+ def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
286
+ """
287
+ Naive autoregressive streaming inference.
288
+ To make it super simple, let's assume:
289
+ - batch size is 1
290
+ - ids and the yielded tokens are simple Python lists and ints
291
+ """
292
+ assert isinstance(tokens, list)
293
+ device = self.get_device()
294
+ rng = None
295
+ if temperature > 0:
296
+ rng = torch.Generator(device=device)
297
+ rng.manual_seed(seed)
298
+ ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim
299
+ for _ in range(max_tokens):
300
+ logits = self.forward(ids) # (B, T, vocab_size)
301
+ logits = logits[:, -1, :] # (B, vocab_size)
302
+ if top_k is not None:
303
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
304
+ logits[logits < v[:, [-1]]] = -float('Inf')
305
+ if temperature > 0:
306
+ logits = logits / temperature
307
+ probs = F.softmax(logits, dim=-1)
308
+ next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
309
+ else:
310
+ next_ids = torch.argmax(logits, dim=-1, keepdim=True)
311
+ ids = torch.cat((ids, next_ids), dim=1)
312
+ token = next_ids.item()
313
+ yield token
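As a plumbing check (not the trained PicoChat weights), a minimal sketch that instantiates a tiny random-weight GPT on CPU and streams a few tokens through the naive `generate` path:

```python
# Sketch: tiny random-weight GPT on CPU, just to show the config/generate plumbing.
import torch
from nanochat.gpt import GPT, GPTConfig

config = GPTConfig(sequence_len=64, vocab_size=256, n_layer=2, n_head=2, n_kv_head=2, n_embd=64)
model = GPT(config)
model.init_weights()
prompt = [1, 2, 3]  # arbitrary token ids < vocab_size
for token in model.generate(prompt, max_tokens=5, temperature=1.0):
    print(token)
```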
nanochat/logo.svg ADDED
nanochat/loss_eval.py ADDED
@@ -0,0 +1,65 @@
1
+ """
2
+ A number of functions that help with evaluating a base model.
3
+ """
4
+ import math
5
+ import torch
6
+ import torch.distributed as dist
7
+
8
+ @torch.no_grad()
9
+ def evaluate_bpb(model, batches, steps, token_bytes):
10
+ """
11
+ Instead of the naive 'mean loss', this function returns the bits per byte (bpb),
12
+ which is a tokenization vocab size-independent metric, meaning you are still comparing
13
+ apples:apples if you change the vocab size. The way this works is that instead of just
14
+ calculating the average loss as usual, you calculate the sum loss, and independently
15
+ also the sum bytes (of all the target tokens), and divide. This normalizes the loss by
16
+ the number of bytes that the target tokens represent.
17
+
18
+ The added complexity is so that:
19
+ 1) All "normal" tokens are normalized by the length of the token in bytes
20
+ 2) No special tokens (e.g. <|bos|>) are included in the metric - they are masked out.
21
+ 3) No actively masked tokens (using ignore_index of e.g. -1) are included in the metric.
22
+
23
+ In addition to the usual loss evaluation inputs, we need the token_bytes tensor:
24
+ It is a 1D tensor of shape (vocab_size,), indicating the number of bytes for
25
+ each token id, or 0 if the token is to not be counted (e.g. special tokens).
26
+ """
27
+ # record the losses
28
+ total_nats = torch.tensor(0.0, dtype=torch.float32, device=model.get_device())
29
+ total_bytes = torch.tensor(0, dtype=torch.int64, device=model.get_device())
30
+ batch_iter = iter(batches)
31
+ for _ in range(steps):
32
+ x, y = next(batch_iter)
33
+ loss2d = model(x, y, loss_reduction='none') # (B, T)
34
+ loss2d = loss2d.view(-1) # flatten
35
+ y = y.view(-1) # flatten
36
+ if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32
37
+ # slightly more complex code path if some target tokens are ignore_index (e.g. -1)
38
+ # any target token < 0 is to be ignored: do NOT index token_bytes with negatives
39
+ valid = y >= 0
40
+ y_safe = torch.where(valid, y, torch.zeros_like(y))
41
+ # map valid targets to their byte length; ignored targets contribute 0 bytes
42
+ num_bytes2d = torch.where(
43
+ valid,
44
+ token_bytes[y_safe],
45
+ torch.zeros_like(y, dtype=token_bytes.dtype)
46
+ )
47
+ total_nats += (loss2d * (num_bytes2d > 0)).sum()
48
+ total_bytes += num_bytes2d.sum()
49
+ else:
50
+ # fast path: no ignored targets, safe to index directly
51
+ num_bytes2d = token_bytes[y]
52
+ total_nats += (loss2d * (num_bytes2d > 0)).sum()
53
+ total_bytes += num_bytes2d.sum()
54
+ # sum reduce across all ranks
55
+ world_size = dist.get_world_size() if dist.is_initialized() else 1
56
+ if world_size > 1:
57
+ dist.all_reduce(total_nats, op=dist.ReduceOp.SUM)
58
+ dist.all_reduce(total_bytes, op=dist.ReduceOp.SUM)
59
+ # move both to cpu, calculate bpb and return
60
+ total_nats = total_nats.item()
61
+ total_bytes = total_bytes.item()
62
+ if total_bytes == 0:
63
+ return float('inf')
64
+ bpb = total_nats / (math.log(2) * total_bytes)
65
+ return bpb
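For intuition, the final division converts summed nats to bits and normalizes by the number of bytes the target tokens represent; a worked sketch with made-up numbers:

```python
# Made-up numbers: summed loss of 2,800 nats over targets totalling 4,200 bytes.
import math
total_nats, total_bytes = 2800.0, 4200
bpb = total_nats / (math.log(2) * total_bytes)
print(f"{bpb:.3f} bits per byte")  # ~0.962
```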
nanochat/muon.py ADDED
@@ -0,0 +1,190 @@
1
+ """
2
+ Muon optimizer from Keller et al.
3
+ Also a lot of borrowing of ideas from modded-nanogpt.
4
+ """
5
+ import torch
6
+ from torch import Tensor
7
+ import torch.distributed as dist
8
+
9
+ def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
10
+ """
11
+ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
12
+ quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
13
+ of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
14
+ zero even beyond the point where the iteration no longer converges all the way to one everywhere
15
+ on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
16
+ where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
17
+ performance at all relative to UV^T, where USV^T = G is the SVD.
18
+ """
19
+ assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
20
+ a, b, c = (3.4445, -4.7750, 2.0315)
21
+ # Use bfloat16 for CUDA to save memory/compute, but float32 for MPS/CPU for stability
22
+ if G.device.type == "cuda":
23
+ X = G.bfloat16()
24
+ else:
25
+ X = G.float() # MPS/CPU: use float32 for Newton-Schulz to avoid NaNs
26
+ if G.size(-2) > G.size(-1):
27
+ X = X.mT
28
+
29
+ # Ensure spectral norm is at most 1
30
+ X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
31
+ # Perform the NS iterations
32
+ for _ in range(steps):
33
+ A = X @ X.mT
34
+ B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
35
+ X = a * X + B @ X
36
+
37
+ if G.size(-2) > G.size(-1):
38
+ X = X.mT
39
+ return X
40
+
41
+ class Muon(torch.optim.Optimizer):
42
+ """
43
+ Muon - MomentUm Orthogonalized by Newton-schulz
44
+
45
+ https://kellerjordan.github.io/posts/muon/
46
+
47
+ Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
48
+ processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
49
+ matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
50
+ the advantage that it can be stably run in bfloat16 on the GPU.
51
+
52
+ Some warnings:
53
+ - This optimizer should not be used for the embedding layer, the final fully connected layer,
54
+ or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
55
+ - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
56
+
57
+ Arguments:
58
+ lr: The learning rate used by the internal SGD.
59
+ momentum: The momentum used by the internal SGD.
60
+ nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
61
+ ns_steps: The number of Newton-Schulz iteration steps to use.
62
+ """
63
+ def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
64
+ defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
65
+ params: list[Tensor] = [*params]
66
+ param_groups = []
67
+ for size in {p.numel() for p in params}:
68
+ group = dict(params=[p for p in params if p.numel() == size])
69
+ param_groups.append(group)
70
+ super().__init__(param_groups, defaults)
71
+
72
+ @torch.no_grad()
73
+ def step(self):
74
+ for group in self.param_groups:
75
+ params: list[Tensor] = group["params"]
76
+ for p in params:
77
+ g = p.grad
78
+ assert g is not None
79
+ state = self.state[p]
80
+ if "momentum_buffer" not in state:
81
+ state["momentum_buffer"] = torch.zeros_like(g)
82
+ buf: Tensor = state["momentum_buffer"]
83
+ buf.lerp_(g, 1 - group["momentum"])
84
+ g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
85
+ g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
86
+ p.add_(g, alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5)
87
+
88
+
89
+ class DistMuon(torch.optim.Optimizer):
90
+ """
91
+ Muon: SGD-momentum + (optional) Nesterov, then orthogonalize the 2D update via Newton–Schulz,
92
+ finally apply aspect-ratio scaled step. Performs its own distributed synchronization:
93
+ - reduce_scatter(AVG) for gradient averaging
94
+ - all_gather to replicate updated weights
95
+
96
+ Notes:
97
+ * Designed for 2D parameters (e.g., linear/conv kernels reshaped to 2D). Do not use for 0D/1D
98
+ params like embeddings or scalars.
99
+ * Momentum buffers are maintained only on the 'owner' rank for each parameter (rank chosen
100
+ by block-cyclic assignment below). If you checkpoint optimizer state on a single rank,
101
+ consolidate states beforehand.
102
+
103
+ Args:
104
+ params: iterable of Tensors
105
+ lr: learning rate
106
+ momentum: momentum coefficient in [0,1)
107
+ nesterov: if True, Nesterov-style update (g <- lerp(g, buf, momentum)); else use buf
108
+ ns_steps: number of Newton–Schulz iterations for the orthogonalization
109
+ """
110
+ def __init__(self, params, lr: float = 0.02, momentum: float = 0.95,
111
+ nesterov: bool = True, ns_steps: int = 5):
112
+ defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
113
+ params = list(params)
114
+ assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
115
+ rank = dist.get_rank()
116
+ # Group all parameters by their shape
117
+ shapes = sorted({p.shape for p in params}) # sort to ensure consistent / deterministic ordering
118
+ param_groups = []
119
+ for shape in shapes:
120
+ group_params = [p for p in params if p.shape == shape]
121
+ device, dtype = group_params[0].device, group_params[0].dtype
122
+ assert all(p.device == device for p in group_params)
123
+ assert all(p.dtype == dtype for p in group_params)
124
+ if rank == 0:
125
+ print(f"Muon: Grouping {len(group_params)} params of shape {shape}, device {device}, dtype {dtype}")
126
+ param_groups.append(dict(params=group_params, zero_buffer=torch.zeros_like(group_params[0])))
127
+ super().__init__(param_groups, defaults)
128
+
129
+ @torch.no_grad()
130
+ def step(self):
131
+ rank = dist.get_rank()
132
+ world_size = dist.get_world_size()
133
+
134
+ # Ensure all grads exist
135
+ assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"
136
+
137
+ # Kick off all the reduce scatter operations to average up the gradients across all ranks
138
+ all_reduce_futures = []
139
+ for group in self.param_groups:
140
+ params = group["params"]
141
+ zero_buffer = group["zero_buffer"]
142
+ # Go through params in groups of world_size.
143
+ for base_i in range(0, len(params), world_size):
144
+ # The compute owner of each param is rank i % world_size
145
+ owner_idx = base_i + rank
146
+ # each rank stacks up its chunk of world_size params into a list
147
+ rs_input = [p.grad for p in params[base_i:base_i + world_size]]
148
+ # pad rs_input with the zero buffer to complete the group
149
+ rs_input.extend([zero_buffer] * (world_size - len(rs_input)))
150
+ # the output buffer gets strided across the group based on the rank
151
+ rs_output = params[owner_idx].grad if owner_idx < len(params) else torch.empty_like(zero_buffer)
152
+ # reduce scatter the gradients within this group of world_size params
153
+ work = dist.reduce_scatter(rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True).get_future()
154
+ all_reduce_futures.append(work)
155
+
156
+ # Now each rank computes the update and gathers
157
+ future_idx = 0
158
+ all_gather_futures = []
159
+ for group in self.param_groups:
160
+ params = group["params"]
161
+ zero_buffer = group["zero_buffer"]
162
+ # Go through params in groups of world_size.
163
+ for base_i in range(0, len(params), world_size):
164
+ # The compute owner of each param is rank i % world_size
165
+ owner_idx = base_i + rank # calculate the index of the param that this rank owns
166
+ # Wait for the reduce scatter to complete
167
+ all_reduce_futures[future_idx].wait() # possibly later we could use wait_any polling instead
168
+ future_idx += 1
169
+ # Owner computes the Muon update, result is in its param
170
+ if owner_idx < len(params):
171
+ p = params[owner_idx]
172
+ g = p.grad # now averaged across ranks
173
+ state = self.state[p]
174
+ if "momentum_buffer" not in state:
175
+ state["momentum_buffer"] = torch.zeros_like(g)
176
+ buf: Tensor = state["momentum_buffer"]
177
+ buf.lerp_(g, 1.0 - group["momentum"])
178
+ g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
179
+ g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
180
+ scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
181
+ p.add_(g, alpha=-group["lr"] * scale)
182
+ # Replicate updated parameters to all ranks
183
+ ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer
184
+ ag_output = params[base_i:base_i + world_size]
185
+ ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))]) # pad
186
+ work = dist.all_gather(ag_output, ag_input, async_op=True).get_future()
187
+ all_gather_futures.append(work)
188
+
189
+ # Wait for all work to finish
190
+ torch.futures.collect_all(all_gather_futures).wait()
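A hedged single-process sketch of the plain `Muon` class on a toy 2D weight (per the warning above, 0/1-D parameters like embeddings would go to AdamW instead):

```python
# Sketch: one Muon step on a toy 2D parameter (single process, no DDP).
import torch
from nanochat.muon import Muon

w = torch.nn.Parameter(torch.randn(32, 16))
opt = Muon([w], lr=0.02, momentum=0.95)
loss = (w ** 2).sum()
loss.backward()
opt.step()       # momentum -> Newton-Schulz orthogonalization -> aspect-ratio scaled update
opt.zero_grad()
```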
nanochat/report.py ADDED
@@ -0,0 +1,408 @@
1
+ """
2
+ Utilities for generating training report cards. More messy code than usual, will fix.
3
+ """
4
+
5
+ import os
6
+ import re
7
+ import shutil
8
+ import subprocess
9
+ import socket
10
+ import datetime
11
+ import platform
12
+ import psutil
13
+ import torch
14
+
15
+ def run_command(cmd):
16
+ """Run a shell command and return output, or None if it fails."""
17
+ try:
18
+ result = subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=5)
19
+ if result.returncode == 0:
20
+ return result.stdout.strip()
21
+ return None
22
+ except:
23
+ return None
24
+
25
+ def get_git_info():
26
+ """Get current git commit, branch, and dirty status."""
27
+ info = {}
28
+ info['commit'] = run_command("git rev-parse --short HEAD") or "unknown"
29
+ info['branch'] = run_command("git rev-parse --abbrev-ref HEAD") or "unknown"
30
+
31
+ # Check if repo is dirty (has uncommitted changes)
32
+ status = run_command("git status --porcelain")
33
+ info['dirty'] = bool(status) if status is not None else False
34
+
35
+ # Get commit message
36
+ info['message'] = run_command("git log -1 --pretty=%B") or ""
37
+ info['message'] = info['message'].split('\n')[0][:80] # First line, truncated
38
+
39
+ return info
40
+
41
+ def get_gpu_info():
42
+ """Get GPU information."""
43
+ if not torch.cuda.is_available():
44
+ return {"available": False}
45
+
46
+ num_devices = torch.cuda.device_count()
47
+ info = {
48
+ "available": True,
49
+ "count": num_devices,
50
+ "names": [],
51
+ "memory_gb": []
52
+ }
53
+
54
+ for i in range(num_devices):
55
+ props = torch.cuda.get_device_properties(i)
56
+ info["names"].append(props.name)
57
+ info["memory_gb"].append(props.total_memory / (1024**3))
58
+
59
+ # Get CUDA version
60
+ info["cuda_version"] = torch.version.cuda or "unknown"
61
+
62
+ return info
63
+
64
+ def get_system_info():
65
+ """Get system information."""
66
+ info = {}
67
+
68
+ # Basic system info
69
+ info['hostname'] = socket.gethostname()
70
+ info['platform'] = platform.system()
71
+ info['python_version'] = platform.python_version()
72
+ info['torch_version'] = torch.__version__
73
+
74
+ # CPU and memory
75
+ info['cpu_count'] = psutil.cpu_count(logical=False)
76
+ info['cpu_count_logical'] = psutil.cpu_count(logical=True)
77
+ info['memory_gb'] = psutil.virtual_memory().total / (1024**3)
78
+
79
+ # User and environment
80
+ info['user'] = os.environ.get('USER', 'unknown')
81
+ info['nanochat_base_dir'] = os.environ.get('NANOCHAT_BASE_DIR', 'out')
82
+ info['working_dir'] = os.getcwd()
83
+
84
+ return info
85
+
86
+ def estimate_cost(gpu_info, runtime_hours=None):
87
+ """Estimate training cost based on GPU type and runtime."""
88
+
89
+ # Rough pricing, from Lambda Cloud
90
+ default_rate = 2.0
91
+ gpu_hourly_rates = {
92
+ "H100": 3.00,
93
+ "A100": 1.79,
94
+ "V100": 0.55,
95
+ }
96
+
97
+ if not gpu_info.get("available"):
98
+ return None
99
+
100
+ # Try to identify GPU type from name
101
+ hourly_rate = None
102
+ gpu_name = gpu_info["names"][0] if gpu_info["names"] else "unknown"
103
+ for gpu_type, rate in gpu_hourly_rates.items():
104
+ if gpu_type in gpu_name:
105
+ hourly_rate = rate * gpu_info["count"]
106
+ break
107
+
108
+ if hourly_rate is None:
109
+ hourly_rate = default_rate * gpu_info["count"] # Default estimate
110
+
111
+ return {
112
+ "hourly_rate": hourly_rate,
113
+ "gpu_type": gpu_name,
114
+ "estimated_total": hourly_rate * runtime_hours if runtime_hours else None
115
+ }
116
+
117
+ def generate_header():
118
+ """Generate the header for a training report."""
119
+ timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
120
+
121
+ git_info = get_git_info()
122
+ gpu_info = get_gpu_info()
123
+ sys_info = get_system_info()
124
+ cost_info = estimate_cost(gpu_info)
125
+
126
+ header = f"""# nanochat training report
127
+
128
+ Generated: {timestamp}
129
+
130
+ ## Environment
131
+
132
+ ### Git Information
133
+ - Branch: {git_info['branch']}
134
+ - Commit: {git_info['commit']} {"(dirty)" if git_info['dirty'] else "(clean)"}
135
+ - Message: {git_info['message']}
136
+
137
+ ### Hardware
138
+ - Platform: {sys_info['platform']}
139
+ - CPUs: {sys_info['cpu_count']} cores ({sys_info['cpu_count_logical']} logical)
140
+ - Memory: {sys_info['memory_gb']:.1f} GB
141
+ """
142
+
143
+ if gpu_info.get("available"):
144
+ gpu_names = ", ".join(set(gpu_info["names"]))
145
+ total_vram = sum(gpu_info["memory_gb"])
146
+ header += f"""- GPUs: {gpu_info['count']}x {gpu_names}
147
+ - GPU Memory: {total_vram:.1f} GB total
148
+ - CUDA Version: {gpu_info['cuda_version']}
149
+ """
150
+ else:
151
+ header += "- GPUs: None available\n"
152
+
153
+ if cost_info and cost_info["hourly_rate"] > 0:
154
+ header += f"""- Hourly Rate: ${cost_info['hourly_rate']:.2f}/hour\n"""
155
+
156
+ header += f"""
157
+ ### Software
158
+ - Python: {sys_info['python_version']}
159
+ - PyTorch: {sys_info['torch_version']}
160
+
161
+ """
162
+
163
+ # bloat metrics: package all of the source code and assess its weight
164
+ packaged = run_command('files-to-prompt . -e py -e md -e rs -e html -e toml -e sh --ignore "*target*" --cxml') or ""  # fall back to empty output if the tool is unavailable
165
+ num_chars = len(packaged)
166
+ num_lines = len(packaged.split('\n'))
167
+ num_files = len([x for x in packaged.split('\n') if x.startswith('<source>')])
168
+ num_tokens = num_chars // 4 # assume approximately 4 chars per token
169
+
170
+ # count dependencies via uv.lock
171
+ uv_lock_lines = 0
172
+ if os.path.exists('uv.lock'):
173
+ with open('uv.lock', 'r', encoding='utf-8') as f:
174
+ uv_lock_lines = len(f.readlines())
175
+
176
+ header += f"""
177
+ ### Bloat
178
+ - Characters: {num_chars:,}
179
+ - Lines: {num_lines:,}
180
+ - Files: {num_files:,}
181
+ - Tokens (approx): {num_tokens:,}
182
+ - Dependencies (uv.lock lines): {uv_lock_lines:,}
183
+
184
+ """
185
+ return header
186
+
187
+ # -----------------------------------------------------------------------------
188
+
189
+ def slugify(text):
190
+ """Slugify a text string."""
191
+ return text.lower().replace(" ", "-")
192
+
193
+ # the expected files and their order
194
+ EXPECTED_FILES = [
195
+ "tokenizer-training.md",
196
+ "tokenizer-evaluation.md",
197
+ "base-model-training.md",
198
+ "base-model-loss.md",
199
+ "base-model-evaluation.md",
200
+ "midtraining.md",
201
+ "chat-evaluation-mid.md",
202
+ "chat-sft.md",
203
+ "chat-evaluation-sft.md",
204
+ "chat-rl.md",
205
+ "chat-evaluation-rl.md",
206
+ ]
207
+ # the metrics we're currently interested in
208
+ chat_metrics = ["ARC-Easy", "ARC-Challenge", "MMLU", "GSM8K", "HumanEval", "ChatCORE"]
209
+
210
+ def extract(section, keys):
211
+ """simple def to extract a single key from a section"""
212
+ if not isinstance(keys, list):
213
+ keys = [keys] # convenience
214
+ out = {}
215
+ for line in section.split("\n"):
216
+ for key in keys:
217
+ if key in line:
218
+ out[key] = line.split(":")[1].strip()
219
+ return out
220
+
221
+ def extract_timestamp(content, prefix):
222
+ """Extract timestamp from content with given prefix."""
223
+ for line in content.split('\n'):
224
+ if line.startswith(prefix):
225
+ time_str = line.split(":", 1)[1].strip()
226
+ try:
227
+ return datetime.datetime.strptime(time_str, "%Y-%m-%d %H:%M:%S")
228
+ except ValueError:
229
+ pass
230
+ return None
231
+
232
+ class Report:
233
+ """Maintains a bunch of logs, generates a final markdown report."""
234
+
235
+ def __init__(self, report_dir):
236
+ os.makedirs(report_dir, exist_ok=True)
237
+ self.report_dir = report_dir
238
+
239
+ def log(self, section, data):
240
+ """Log a section of data to the report."""
241
+ slug = slugify(section)
242
+ file_name = f"{slug}.md"
243
+ file_path = os.path.join(self.report_dir, file_name)
244
+ with open(file_path, "w", encoding="utf-8") as f:
245
+ f.write(f"## {section}\n")
246
+ f.write(f"timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
247
+ for item in data:
248
+ if not item:
249
+ # skip falsy values like None or empty dict etc.
250
+ continue
251
+ if isinstance(item, str):
252
+ # directly write the string
253
+ f.write(item)
254
+ else:
255
+ # render a dict
256
+ for k, v in item.items():
257
+ if isinstance(v, float):
258
+ vstr = f"{v:.4f}"
259
+ elif isinstance(v, int) and v >= 10000:
260
+ vstr = f"{v:,.0f}"
261
+ else:
262
+ vstr = str(v)
263
+ f.write(f"- {k}: {vstr}\n")
264
+ f.write("\n")
265
+ return file_path
266
+
267
+ def generate(self):
268
+ """Generate the final report."""
269
+ report_dir = self.report_dir
270
+ report_file = os.path.join(report_dir, "report.md")
271
+ print(f"Generating report to {report_file}")
272
+ final_metrics = {} # the most important final metrics we'll add as table at the end
273
+ start_time = None
274
+ end_time = None
275
+ with open(report_file, "w", encoding="utf-8") as out_file:
276
+ # write the header first
277
+ header_file = os.path.join(report_dir, "header.md")
278
+ if os.path.exists(header_file):
279
+ with open(header_file, "r", encoding="utf-8") as f:
280
+ header_content = f.read()
281
+ out_file.write(header_content)
282
+ start_time = extract_timestamp(header_content, "Run started:")
283
+ # capture bloat data for summary later (the stuff after Bloat header and until \n\n)
284
+ bloat_data = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL)
285
+ bloat_data = bloat_data.group(1) if bloat_data else ""
286
+ else:
287
+ start_time = None # will cause us to not write the total wall clock time
288
+ bloat_data = "[bloat data missing]"
289
+ print(f"Warning: {header_file} does not exist. Did you forget to run `nanochat reset`?")
290
+ # process all the individual sections
291
+ for file_name in EXPECTED_FILES:
292
+ section_file = os.path.join(report_dir, file_name)
293
+ if not os.path.exists(section_file):
294
+ print(f"Warning: {section_file} does not exist, skipping")
295
+ continue
296
+ with open(section_file, "r", encoding="utf-8") as in_file:
297
+ section = in_file.read()
298
+ # Extract timestamp from this section (the last section's timestamp will "stick" as end_time)
299
+ if "rl" not in file_name:
300
+ # Skip RL sections for end_time calculation because RL is experimental
301
+ end_time = extract_timestamp(section, "timestamp:")
302
+ # extract the most important metrics from the sections
303
+ if file_name == "base-model-evaluation.md":
304
+ final_metrics["base"] = extract(section, "CORE")
305
+ if file_name == "chat-evaluation-mid.md":
306
+ final_metrics["mid"] = extract(section, chat_metrics)
307
+ if file_name == "chat-evaluation-sft.md":
308
+ final_metrics["sft"] = extract(section, chat_metrics)
309
+ if file_name == "chat-evaluation-rl.md":
310
+ final_metrics["rl"] = extract(section, "GSM8K") # RL only evals GSM8K
311
+ # append this section of the report
312
+ out_file.write(section)
313
+ out_file.write("\n")
314
+ # add the final metrics table
315
+ out_file.write("## Summary\n\n")
316
+ # Copy over the bloat metrics from the header
317
+ out_file.write(bloat_data)
318
+ out_file.write("\n\n")
319
+ # Collect all unique metric names
320
+ all_metrics = set()
321
+ for stage_metrics in final_metrics.values():
322
+ all_metrics.update(stage_metrics.keys())
323
+ # Custom ordering: CORE first, ChatCORE last, rest in middle
324
+ all_metrics = sorted(all_metrics, key=lambda x: (x != "CORE", x == "ChatCORE", x))
325
+ # Fixed column widths
326
+ stages = ["base", "mid", "sft", "rl"]
327
+ metric_width = 15
328
+ value_width = 8
329
+ # Write table header
330
+ header = f"| {'Metric'.ljust(metric_width)} |"
331
+ for stage in stages:
332
+ header += f" {stage.upper().ljust(value_width)} |"
333
+ out_file.write(header + "\n")
334
+ # Write separator
335
+ separator = f"|{'-' * (metric_width + 2)}|"
336
+ for stage in stages:
337
+ separator += f"{'-' * (value_width + 2)}|"
338
+ out_file.write(separator + "\n")
339
+ # Write table rows
340
+ for metric in all_metrics:
341
+ row = f"| {metric.ljust(metric_width)} |"
342
+ for stage in stages:
343
+ value = final_metrics.get(stage, {}).get(metric, "-")
344
+ row += f" {str(value).ljust(value_width)} |"
345
+ out_file.write(row + "\n")
346
+ out_file.write("\n")
347
+ # Calculate and write total wall clock time
348
+ if start_time and end_time:
349
+ duration = end_time - start_time
350
+ total_seconds = int(duration.total_seconds())
351
+ hours = total_seconds // 3600
352
+ minutes = (total_seconds % 3600) // 60
353
+ out_file.write(f"Total wall clock time: {hours}h{minutes}m\n")
354
+ else:
355
+ out_file.write("Total wall clock time: unknown\n")
356
+ # also cp the report.md file to current directory
357
+ print(f"Copying report.md to current directory for convenience")
358
+ shutil.copy(report_file, "report.md")
359
+ return report_file
360
+
361
+ def reset(self):
362
+ """Reset the report."""
363
+ # Remove section files
364
+ for file_name in EXPECTED_FILES:
365
+ file_path = os.path.join(self.report_dir, file_name)
366
+ if os.path.exists(file_path):
367
+ os.remove(file_path)
368
+ # Remove report.md if it exists
369
+ report_file = os.path.join(self.report_dir, "report.md")
370
+ if os.path.exists(report_file):
371
+ os.remove(report_file)
372
+ # Generate and write the header section with start timestamp
373
+ header_file = os.path.join(self.report_dir, "header.md")
374
+ header = generate_header()
375
+ start_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
376
+ with open(header_file, "w", encoding="utf-8") as f:
377
+ f.write(header)
378
+ f.write(f"Run started: {start_time}\n\n---\n\n")
379
+ print(f"Reset report and wrote header to {header_file}")
380
+
381
+ # -----------------------------------------------------------------------------
382
+ # nanochat-specific convenience functions
383
+
384
+ class DummyReport:
385
+ def log(self, *args, **kwargs):
386
+ pass
387
+ def reset(self, *args, **kwargs):
388
+ pass
389
+
390
+ def get_report():
391
+ # just for convenience, only rank 0 logs to report
392
+ from nanochat.common import get_base_dir, get_dist_info
393
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
394
+ if ddp_rank == 0:
395
+ report_dir = os.path.join(get_base_dir(), "report")
396
+ return Report(report_dir)
397
+ else:
398
+ return DummyReport()
399
+
400
+ if __name__ == "__main__":
401
+ import argparse
402
+ parser = argparse.ArgumentParser(description="Generate or reset nanochat training reports.")
403
+ parser.add_argument("command", nargs="?", default="generate", choices=["generate", "reset"], help="Operation to perform (default: generate)")
404
+ args = parser.parse_args()
405
+ if args.command == "generate":
406
+ get_report().generate()
407
+ elif args.command == "reset":
408
+ get_report().reset()
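A minimal usage sketch of the `Report` class added above (the section name and metric value are made up for illustration; `generate()` is normally run on rank 0, e.g. via `python -m nanochat.report generate`):

```python
from nanochat.report import get_report

report = get_report()  # real Report on rank 0, DummyReport on other ranks
report.log("Base model evaluation", [
    {"CORE": 0.2123},             # floats are rendered with 4 decimals
    "Free-form notes go here\n",  # strings are written verbatim
])
# On rank 0, this stitches header.md plus the expected section files into report.md
report.generate()
```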
nanochat/tokenizer.py ADDED
@@ -0,0 +1,399 @@
1
+ """
2
+ BPE Tokenizer in the style of GPT-4.
3
+
4
+ Two implementations are available:
5
+ 1) HuggingFace Tokenizer that can do both training and inference but is really confusing
6
+ 2) Our own RustBPE Tokenizer for training and tiktoken for efficient inference
7
+ """
8
+
9
+ import os
10
+ import copy
11
+ from functools import lru_cache
12
+
13
+ SPECIAL_TOKENS = [
14
+ # every document begins with the Beginning of Sequence (BOS) token that delimits documents
15
+ "<|bos|>",
16
+ # tokens below are only used during finetuning to render Conversations into token ids
17
+ "<|user_start|>", # user messages
18
+ "<|user_end|>",
19
+ "<|assistant_start|>", # assistant messages
20
+ "<|assistant_end|>",
21
+ "<|python_start|>", # assistant invokes python REPL tool
22
+ "<|python_end|>",
23
+ "<|output_start|>", # python REPL outputs back to assistant
24
+ "<|output_end|>",
25
+ ]
26
+
27
+ # NOTE: this split pattern deviates from GPT-4 in that we use \p{N}{1,2} instead of \p{N}{1,3}
28
+ # I did this because I didn't want to "waste" too many tokens on numbers for smaller vocab sizes.
29
+ # I haven't validated that this is actually a good idea, TODO.
30
+ SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
31
+
32
+ # -----------------------------------------------------------------------------
33
+ # Generic GPT-4-style tokenizer based on HuggingFace Tokenizer
34
+ from tokenizers import Tokenizer as HFTokenizer
35
+ from tokenizers import pre_tokenizers, decoders, Regex
36
+ from tokenizers.models import BPE
37
+ from tokenizers.trainers import BpeTrainer
38
+
39
+ class HuggingFaceTokenizer:
40
+ """Light wrapper around HuggingFace Tokenizer for some utilities"""
41
+
42
+ def __init__(self, tokenizer):
43
+ self.tokenizer = tokenizer
44
+
45
+ @classmethod
46
+ def from_pretrained(cls, hf_path):
47
+ # init from a HuggingFace pretrained tokenizer (e.g. "gpt2")
48
+ tokenizer = HFTokenizer.from_pretrained(hf_path)
49
+ return cls(tokenizer)
50
+
51
+ @classmethod
52
+ def from_directory(cls, tokenizer_dir):
53
+ # init from a local directory on disk (e.g. "out/tokenizer")
54
+ tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
55
+ tokenizer = HFTokenizer.from_file(tokenizer_path)
56
+ return cls(tokenizer)
57
+
58
+ @classmethod
59
+ def train_from_iterator(cls, text_iterator, vocab_size):
60
+ # train from an iterator of text
61
+ # Configure the HuggingFace Tokenizer
62
+ tokenizer = HFTokenizer(BPE(
63
+ byte_fallback=True, # needed!
64
+ unk_token=None,
65
+ fuse_unk=False,
66
+ ))
67
+ # Normalizer: None
68
+ tokenizer.normalizer = None
69
+ # Pre-tokenizer: GPT-4 style
70
+ # the regex pattern used by GPT-4 to split text into groups before BPE
71
+ # NOTE: The pattern was changed from \p{N}{1,3} to \p{N}{1,2} because I suspect it is harmful to
72
+ # very small models and smaller vocab sizes, because it is a little bit wasteful in the token space.
73
+ # (but I haven't validated this! TODO)
74
+ gpt4_split_regex = Regex(SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!!
75
+ tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
76
+ pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False),
77
+ pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
78
+ ])
79
+ # Decoder: ByteLevel (it pairs together with the ByteLevel pre-tokenizer)
80
+ tokenizer.decoder = decoders.ByteLevel()
81
+ # Post-processor: None
82
+ tokenizer.post_processor = None
83
+ # Trainer: BPE
84
+ trainer = BpeTrainer(
85
+ vocab_size=vocab_size,
86
+ show_progress=True,
87
+ min_frequency=0, # no minimum frequency
88
+ initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
89
+ special_tokens=SPECIAL_TOKENS,
90
+ )
91
+ # Kick off the training
92
+ tokenizer.train_from_iterator(text_iterator, trainer)
93
+ return cls(tokenizer)
94
+
95
+ def get_vocab_size(self):
96
+ return self.tokenizer.get_vocab_size()
97
+
98
+ def get_special_tokens(self):
99
+ special_tokens_map = self.tokenizer.get_added_tokens_decoder()
100
+ special_tokens = [w.content for w in special_tokens_map.values()]
101
+ return special_tokens
102
+
103
+ def id_to_token(self, id):
104
+ return self.tokenizer.id_to_token(id)
105
+
106
+ def _encode_one(self, text, prepend=None, append=None):
107
+ # encode a single string
108
+ # prepend/append can each be either a special-token string or a token id directly.
109
+ assert isinstance(text, str)
110
+ ids = []
111
+ if prepend is not None:
112
+ prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend)
113
+ ids.append(prepend_id)
114
+ ids.extend(self.tokenizer.encode(text, add_special_tokens=False).ids)
115
+ if append is not None:
116
+ append_id = append if isinstance(append, int) else self.encode_special(append)
117
+ ids.append(append_id)
118
+ return ids
119
+
120
+ def encode_special(self, text):
121
+ # encode a single special token via exact match
122
+ return self.tokenizer.token_to_id(text)
123
+
124
+ def get_bos_token_id(self):
125
+ bos = self.encode_special("<|bos|>")
126
+ return bos
127
+
128
+ def encode(self, text, *args, **kwargs):
129
+ if isinstance(text, str):
130
+ return self._encode_one(text, *args, **kwargs)
131
+ elif isinstance(text, list):
132
+ return [self._encode_one(t, *args, **kwargs) for t in text]
133
+ else:
134
+ raise ValueError(f"Invalid input type: {type(text)}")
135
+
136
+ def __call__(self, *args, **kwargs):
137
+ return self.encode(*args, **kwargs)
138
+
139
+ def decode(self, ids):
140
+ return self.tokenizer.decode(ids, skip_special_tokens=False)
141
+
142
+ def save(self, tokenizer_dir):
143
+ # save the tokenizer to disk
144
+ os.makedirs(tokenizer_dir, exist_ok=True)
145
+ tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
146
+ self.tokenizer.save(tokenizer_path)
147
+ print(f"Saved tokenizer to {tokenizer_path}")
148
+
149
+ # -----------------------------------------------------------------------------
150
+ # Tokenizer based on rustbpe + tiktoken combo
151
+ import pickle
152
+ # import rustbpe # NOT AVAILABLE IN DEPLOYMENT
153
+ import tiktoken
154
+
155
+ class RustBPETokenizer:
156
+ """Light wrapper around tiktoken (for efficient inference) but train with rustbpe"""
157
+
158
+ def __init__(self, enc, bos_token):
159
+ self.enc = enc
160
+ self.bos_token_id = self.encode_special(bos_token)
161
+
162
+ @classmethod
163
+ def train_from_iterator(cls, text_iterator, vocab_size):
164
+ import rustbpe # Only needed for training
165
+ # 1) train using rustbpe
166
+ tokenizer = rustbpe.Tokenizer()
167
+ # the special tokens are inserted later in __init__, we don't train them here
168
+ vocab_size_no_special = vocab_size - len(SPECIAL_TOKENS)
169
+ assert vocab_size_no_special >= 256, f"vocab_size_no_special must be at least 256, got {vocab_size_no_special}"
170
+ tokenizer.train_from_iterator(text_iterator, vocab_size_no_special, pattern=SPLIT_PATTERN)
171
+ # 2) construct the associated tiktoken encoding for inference
172
+ pattern = tokenizer.get_pattern()
173
+ mergeable_ranks_list = tokenizer.get_mergeable_ranks()
174
+ mergeable_ranks = {bytes(k): v for k, v in mergeable_ranks_list}
175
+ tokens_offset = len(mergeable_ranks)
176
+ special_tokens = {name: tokens_offset + i for i, name in enumerate(SPECIAL_TOKENS)}
177
+ enc = tiktoken.Encoding(
178
+ name="rustbpe",
179
+ pat_str=pattern,
180
+ mergeable_ranks=mergeable_ranks, # dict[bytes, int] (token bytes -> merge priority rank)
181
+ special_tokens=special_tokens, # dict[str, int] (special token name -> token id)
182
+ )
183
+ return cls(enc, "<|bos|>")
184
+
185
+ @classmethod
186
+ def from_directory(cls, tokenizer_dir):
187
+ pickle_path = os.path.join(tokenizer_dir, "tokenizer.pkl")
188
+ with open(pickle_path, "rb") as f:
189
+ enc = pickle.load(f)
190
+ return cls(enc, "<|bos|>")
191
+
192
+ @classmethod
193
+ def from_pretrained(cls, tiktoken_name):
194
+ # https://github.com/openai/tiktoken/blob/eedc8563/tiktoken_ext/openai_public.py
195
+ enc = tiktoken.get_encoding(tiktoken_name)
196
+ # tiktoken calls the special document delimiter token "<|endoftext|>"
197
+ # yes this is confusing because this token is almost always PREPENDED to the beginning of the document
198
+ # it most often is used to signal the start of a new sequence to the LLM during inference etc.
199
+ # so in nanochat we always use "<|bos|>" short for "beginning of sequence", but historically it is often called "<|endoftext|>".
200
+ return cls(enc, "<|endoftext|>")
201
+
202
+ def get_vocab_size(self):
203
+ return self.enc.n_vocab
204
+
205
+ def get_special_tokens(self):
206
+ return self.enc.special_tokens_set
207
+
208
+ def id_to_token(self, id):
209
+ return self.enc.decode([id])
210
+
211
+ @lru_cache(maxsize=32)
212
+ def encode_special(self, text):
213
+ return self.enc.encode_single_token(text)
214
+
215
+ def get_bos_token_id(self):
216
+ return self.bos_token_id
217
+
218
+ def encode(self, text, prepend=None, append=None, num_threads=8):
219
+ # text can be either a string or a list of strings
220
+
221
+ if prepend is not None:
222
+ prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend)
223
+ if append is not None:
224
+ append_id = append if isinstance(append, int) else self.encode_special(append)
225
+
226
+ if isinstance(text, str):
227
+ ids = self.enc.encode_ordinary(text)
228
+ if prepend is not None:
229
+ ids.insert(0, prepend_id) # TODO: slightly inefficient here? :( hmm
230
+ if append is not None:
231
+ ids.append(append_id)
232
+ elif isinstance(text, list):
233
+ ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads)
234
+ if prepend is not None:
235
+ for ids_row in ids:
236
+ ids_row.insert(0, prepend_id) # TODO: same
237
+ if append is not None:
238
+ for ids_row in ids:
239
+ ids_row.append(append_id)
240
+ else:
241
+ raise ValueError(f"Invalid input type: {type(text)}")
242
+
243
+ return ids
244
+
245
+ def __call__(self, *args, **kwargs):
246
+ return self.encode(*args, **kwargs)
247
+
248
+ def decode(self, ids):
249
+ return self.enc.decode(ids)
250
+
251
+ def save(self, tokenizer_dir):
252
+ # save the encoding object to disk
253
+ os.makedirs(tokenizer_dir, exist_ok=True)
254
+ pickle_path = os.path.join(tokenizer_dir, "tokenizer.pkl")
255
+ with open(pickle_path, "wb") as f:
256
+ pickle.dump(self.enc, f)
257
+ print(f"Saved tokenizer encoding to {pickle_path}")
258
+
259
+ def render_conversation(self, conversation, max_tokens=2048):
260
+ """
261
+ Tokenize a single Chat conversation (which we call a "doc" or "document" here).
262
+ Returns:
263
+ - ids: list[int] is a list of token ids of this rendered conversation
264
+ - mask: list[int] of same length, mask = 1 for tokens that the Assistant is expected to train on.
265
+ """
266
+ # ids, masks that we will return and a helper function to help build them up.
267
+ ids, mask = [], []
268
+ def add_tokens(token_ids, mask_val):
269
+ if isinstance(token_ids, int):
270
+ token_ids = [token_ids]
271
+ ids.extend(token_ids)
272
+ mask.extend([mask_val] * len(token_ids))
273
+
274
+ # sometimes the first message is a system message...
275
+ # => just merge it with the second (user) message
276
+ if conversation["messages"][0]["role"] == "system":
277
+ # some conversation surgery is necessary here for now...
278
+ conversation = copy.deepcopy(conversation) # avoid mutating the original
279
+ messages = conversation["messages"]
280
+ assert messages[1]["role"] == "user", "System message must be followed by a user message"
281
+ messages[1]["content"] = messages[0]["content"] + "\n\n" + messages[1]["content"]
282
+ messages = messages[1:]
283
+ else:
284
+ messages = conversation["messages"]
285
+ assert len(messages) >= 1, f"Conversation has less than 1 message: {messages}"
286
+
287
+ # fetch all the special tokens we need
288
+ bos = self.get_bos_token_id()
289
+ user_start, user_end = self.encode_special("<|user_start|>"), self.encode_special("<|user_end|>")
290
+ assistant_start, assistant_end = self.encode_special("<|assistant_start|>"), self.encode_special("<|assistant_end|>")
291
+ python_start, python_end = self.encode_special("<|python_start|>"), self.encode_special("<|python_end|>")
292
+ output_start, output_end = self.encode_special("<|output_start|>"), self.encode_special("<|output_end|>")
293
+
294
+ # now we can tokenize the conversation
295
+ add_tokens(bos, 0)
296
+ for i, message in enumerate(messages):
297
+
298
+ # some sanity checking here around assumptions, to prevent footguns
299
+ must_be_from = "user" if i % 2 == 0 else "assistant"
300
+ assert message["role"] == must_be_from, f"Message {i} is from {message['role']} but should be from {must_be_from}"
301
+
302
+ # content can be either a simple string or a list of parts (e.g. containing tool calls)
303
+ content = message["content"]
304
+
305
+ if message["role"] == "user":
306
+ assert isinstance(content, str), "User messages are simply expected to be strings"
307
+ value_ids = self.encode(content)
308
+ add_tokens(user_start, 0)
309
+ add_tokens(value_ids, 0)
310
+ add_tokens(user_end, 0)
311
+ elif message["role"] == "assistant":
312
+ add_tokens(assistant_start, 0)
313
+ if isinstance(content, str):
314
+ # simple string => simply add the tokens
315
+ value_ids = self.encode(content)
316
+ add_tokens(value_ids, 1)
317
+ elif isinstance(content, list):
318
+ for part in content:
319
+ value_ids = self.encode(part["text"])
320
+ if part["type"] == "text":
321
+ # string part => simply add the tokens
322
+ add_tokens(value_ids, 1)
323
+ elif part["type"] == "python":
324
+ # python tool call => add the tokens inside <|python_start|> and <|python_end|>
325
+ add_tokens(python_start, 1)
326
+ add_tokens(value_ids, 1)
327
+ add_tokens(python_end, 1)
328
+ elif part["type"] == "python_output":
329
+ # python output => add the tokens inside <|output_start|> and <|output_end|>
330
+ # none of these tokens are supervised because the tokens come from Python at test time
331
+ add_tokens(output_start, 0)
332
+ add_tokens(value_ids, 0)
333
+ add_tokens(output_end, 0)
334
+ else:
335
+ raise ValueError(f"Unknown part type: {part['type']}")
336
+ else:
337
+ raise ValueError(f"Unknown content type: {type(content)}")
338
+ add_tokens(assistant_end, 1)
339
+
340
+ # truncate to max_tokens tokens MAX (helps prevent OOMs)
341
+ ids = ids[:max_tokens]
342
+ mask = mask[:max_tokens]
343
+ return ids, mask
344
+
345
+ def visualize_tokenization(self, ids, mask, with_token_id=False):
346
+ """Small helper function useful in debugging: visualize the tokenization of render_conversation"""
347
+ RED = '\033[91m'
348
+ GREEN = '\033[92m'
349
+ RESET = '\033[0m'
350
+ GRAY = '\033[90m'
351
+ tokens = []
352
+ for i, (token_id, mask_val) in enumerate(zip(ids, mask)):
353
+ token_str = self.decode([token_id])
354
+ color = GREEN if mask_val == 1 else RED
355
+ tokens.append(f"{color}{token_str}{RESET}")
356
+ if with_token_id:
357
+ tokens.append(f"{GRAY}({token_id}){RESET}")
358
+ return '|'.join(tokens)
359
+
360
+ def render_for_completion(self, conversation):
361
+ """
362
+ Used during Reinforcement Learning. In that setting, we want to
363
+ render the conversation priming the Assistant for a completion.
364
+ Unlike the Chat SFT case, we don't need to return the mask.
365
+ """
366
+ # We have some surgery to do: we need to pop the last message (of the Assistant)
367
+ conversation = copy.deepcopy(conversation) # avoid mutating the original
368
+ messages = conversation["messages"]
369
+ assert messages[-1]["role"] == "assistant", "Last message must be from the Assistant"
370
+ messages.pop() # remove the last message (of the Assistant) inplace
371
+
372
+ # Now tokenize the conversation
373
+ ids, mask = self.render_conversation(conversation)
374
+
375
+ # Finally, to prime the Assistant for a completion, append the Assistant start token
376
+ assistant_start = self.encode_special("<|assistant_start|>")
377
+ ids.append(assistant_start)
378
+ return ids
379
+
380
+ # -----------------------------------------------------------------------------
381
+ # nanochat-specific convenience functions
382
+
383
+ def get_tokenizer():
384
+ from nanochat.common import get_base_dir
385
+ base_dir = get_base_dir()
386
+ tokenizer_dir = os.path.join(base_dir, "tokenizer")
387
+ # return HuggingFaceTokenizer.from_directory(tokenizer_dir)
388
+ return RustBPETokenizer.from_directory(tokenizer_dir)
389
+
390
+ def get_token_bytes(device="cpu"):
391
+ import torch
392
+ from nanochat.common import get_base_dir
393
+ base_dir = get_base_dir()
394
+ tokenizer_dir = os.path.join(base_dir, "tokenizer")
395
+ token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
396
+ assert os.path.exists(token_bytes_path), f"Token bytes not found at {token_bytes_path}? It gets written by tok_train.py"
397
+ with open(token_bytes_path, "rb") as f:
398
+ token_bytes = torch.load(f, map_location=device)
399
+ return token_bytes
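A small sketch of how the chat rendering in `RustBPETokenizer.render_conversation` is meant to be used (the conversation content is made up; this assumes a trained tokenizer already exists under `$NANOCHAT_BASE_DIR/tokenizer`, written by `tok_train.py`):

```python
from nanochat.tokenizer import get_tokenizer

tok = get_tokenizer()
conversation = {"messages": [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "2 + 2 = 4"},
]}
ids, mask = tok.render_conversation(conversation)
# mask is 1 only on assistant tokens (the ones trained on); visualize in color:
print(tok.visualize_tokenization(ids, mask))
```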
nanochat/ui.html ADDED
@@ -0,0 +1,562 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0, viewport-fit=cover">
6
+ <title>NanoChat</title>
7
+ <link rel="icon" type="image/svg+xml" href="/logo.svg">
8
+ <style>
9
+ :root {
10
+ color-scheme: light;
11
+ }
12
+
13
+ * {
14
+ box-sizing: border-box;
15
+ }
16
+
17
+ body {
18
+ font-family: ui-sans-serif, -apple-system, system-ui, "Segoe UI", Helvetica, "Apple Color Emoji", Arial, sans-serif, "Segoe UI Emoji", "Segoe UI Symbol";
19
+ background-color: #ffffff;
20
+ color: #111827;
21
+ min-height: 100dvh;
22
+ margin: 0;
23
+ display: flex;
24
+ flex-direction: column;
25
+ }
26
+
27
+ .header {
28
+ background-color: #ffffff;
29
+ padding: 1.25rem 1.5rem;
30
+ }
31
+
32
+ .header-left {
33
+ display: flex;
34
+ align-items: center;
35
+ gap: 0.75rem;
36
+ }
37
+
38
+ .header-logo {
39
+ height: 32px;
40
+ width: auto;
41
+ }
42
+
43
+ .header h1 {
44
+ font-size: 1.25rem;
45
+ font-weight: 600;
46
+ margin: 0;
47
+ color: #111827;
48
+ }
49
+
50
+ .new-conversation-btn {
51
+ width: 32px;
52
+ height: 32px;
53
+ padding: 0;
54
+ border: 1px solid #e5e7eb;
55
+ border-radius: 0.5rem;
56
+ background-color: #ffffff;
57
+ color: #6b7280;
58
+ cursor: pointer;
59
+ display: flex;
60
+ align-items: center;
61
+ justify-content: center;
62
+ transition: all 0.2s ease;
63
+ }
64
+
65
+ .new-conversation-btn:hover {
66
+ background-color: #f3f4f6;
67
+ border-color: #d1d5db;
68
+ color: #374151;
69
+ }
70
+
71
+ .chat-container {
72
+ flex: 1;
73
+ overflow-y: auto;
74
+ background-color: #ffffff;
75
+ }
76
+
77
+ .chat-wrapper {
78
+ max-width: 48rem;
79
+ margin: 0 auto;
80
+ padding: 2rem 1.5rem 3rem;
81
+ display: flex;
82
+ flex-direction: column;
83
+ gap: 0.75rem;
84
+ }
85
+
86
+ .message {
87
+ display: flex;
88
+ justify-content: flex-start;
89
+ margin-bottom: 0.5rem;
90
+ color: #0d0d0d;
91
+ }
92
+
93
+ .message.assistant {
94
+ justify-content: flex-start;
95
+ }
96
+
97
+ .message.user {
98
+ justify-content: flex-end;
99
+ }
100
+
101
+ .message-content {
102
+ white-space: pre-wrap;
103
+ line-height: 1.6;
104
+ max-width: 100%;
105
+ }
106
+
107
+ .message.assistant .message-content {
108
+ background: transparent;
109
+ border: none;
110
+ padding: 0.25rem 0;
111
+ cursor: pointer;
112
+ border-radius: 0.5rem;
113
+ padding: 0.5rem;
114
+ margin-left: -0.5rem;
115
+ transition: background-color 0.2s ease;
116
+ }
117
+
118
+ .message.assistant .message-content:hover {
119
+ background-color: #f9fafb;
120
+ }
121
+
122
+ .message.user .message-content {
123
+ background-color: #f3f4f6;
124
+ border-radius: 1.25rem;
125
+ padding: 0.8rem 1rem;
126
+ max-width: 65%;
127
+ cursor: pointer;
128
+ transition: background-color 0.2s ease;
129
+ }
130
+
131
+ .message.user .message-content:hover {
132
+ background-color: #e5e7eb;
133
+ }
134
+
135
+ .message.console .message-content {
136
+ font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', 'Consolas', 'Courier New', monospace;
137
+ font-size: 0.875rem;
138
+ background-color: #fafafa;
139
+ padding: 0.75rem 1rem;
140
+ color: #374151;
141
+ max-width: 80%;
142
+ }
143
+
144
+ .input-container {
145
+ background-color: #ffffff;
146
+ padding: 1rem;
147
+ padding-bottom: calc(1rem + env(safe-area-inset-bottom))
148
+ }
149
+
150
+ .input-wrapper {
151
+ max-width: 48rem;
152
+ margin: 0 auto;
153
+ display: flex;
154
+ gap: 0.75rem;
155
+ align-items: flex-end;
156
+ }
157
+
158
+ .chat-input {
159
+ flex: 1;
160
+ padding: 0.8rem 1rem;
161
+ border: 1px solid #d1d5db;
162
+ border-radius: 0.75rem;
163
+ background-color: #ffffff;
164
+ color: #111827;
165
+ font-size: 1rem;
166
+ line-height: 1.5;
167
+ resize: none;
168
+ outline: none;
169
+ min-height: 54px;
170
+ max-height: 200px;
171
+ transition: border-color 0.2s ease, box-shadow 0.2s ease;
172
+ }
173
+
174
+ .chat-input::placeholder {
175
+ color: #9ca3af;
176
+ }
177
+
178
+ .chat-input:focus {
179
+ border-color: #2563eb;
180
+ box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.1);
181
+ }
182
+
183
+ .send-button {
184
+ flex-shrink: 0;
185
+ padding: 0;
186
+ width: 54px;
187
+ height: 54px;
188
+ border: 1px solid #111827;
189
+ border-radius: 0.75rem;
190
+ background-color: #111827;
191
+ color: #ffffff;
192
+ display: flex;
193
+ align-items: center;
194
+ justify-content: center;
195
+ cursor: pointer;
196
+ transition: background-color 0.2s ease, border-color 0.2s ease, color 0.2s ease;
197
+ }
198
+
199
+ .send-button:hover:not(:disabled) {
200
+ background-color: #2563eb;
201
+ border-color: #2563eb;
202
+ }
203
+
204
+ .send-button:disabled {
205
+ cursor: not-allowed;
206
+ border-color: #d1d5db;
207
+ background-color: #e5e7eb;
208
+ color: #9ca3af;
209
+ }
210
+
211
+ .typing-indicator {
212
+ display: inline-block;
213
+ color: #6b7280;
214
+ letter-spacing: 0.15em;
215
+ }
216
+
217
+ .typing-indicator::after {
218
+ content: '···';
219
+ animation: typing 1.4s infinite;
220
+ }
221
+
222
+ @keyframes typing {
223
+ 0%, 60%, 100% { opacity: 0.2; }
224
+ 30% { opacity: 1; }
225
+ }
226
+
227
+ .error-message {
228
+ background-color: #fee2e2;
229
+ border: 1px solid #fecaca;
230
+ color: #b91c1c;
231
+ padding: 0.75rem 1rem;
232
+ border-radius: 0.75rem;
233
+ margin-top: 0.5rem;
234
+ }
235
+ </style>
236
+ </head>
237
+ <body>
238
+ <div class="header">
239
+ <div class="header-left">
240
+ <button class="new-conversation-btn" onclick="newConversation()" title="New Conversation (Ctrl+Shift+N)">
241
+ <svg width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
242
+ <path d="M12 5v14"></path>
243
+ <path d="M5 12h14"></path>
244
+ </svg>
245
+ </button>
246
+ <h1>nanochat</h1>
247
+ </div>
248
+ </div>
249
+
250
+ <div class="chat-container" id="chatContainer">
251
+ <div class="chat-wrapper" id="chatWrapper">
252
+ <!-- Messages will be added here -->
253
+ </div>
254
+ </div>
255
+
256
+ <div class="input-container">
257
+ <div class="input-wrapper">
258
+ <textarea
259
+ id="chatInput"
260
+ class="chat-input"
261
+ placeholder="Ask anything"
262
+ rows="1"
263
+ onkeydown="handleKeyDown(event)"
264
+ ></textarea>
265
+ <button id="sendButton" class="send-button" onclick="sendMessage()" disabled>
266
+ <svg width="22" height="22" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
267
+ <path d="M22 2L11 13"></path>
268
+ <path d="M22 2l-7 20-4-9-9-4 20-7z"></path>
269
+ </svg>
270
+ </button>
271
+ </div>
272
+ </div>
273
+
274
+ <script>
275
+ const API_URL = '';
276
+ const chatContainer = document.getElementById('chatContainer');
277
+ const chatWrapper = document.getElementById('chatWrapper');
278
+ const chatInput = document.getElementById('chatInput');
279
+ const sendButton = document.getElementById('sendButton');
280
+
281
+ let messages = [];
282
+ let isGenerating = false;
283
+ let currentTemperature = 0.8;
284
+ let currentTopK = 50;
285
+
286
+ chatInput.addEventListener('input', function() {
287
+ this.style.height = 'auto';
288
+ this.style.height = Math.min(this.scrollHeight, 200) + 'px';
289
+ sendButton.disabled = !this.value.trim() || isGenerating;
290
+ });
291
+
292
+ function handleKeyDown(event) {
293
+ if (event.key === 'Enter' && !event.shiftKey) {
294
+ event.preventDefault();
295
+ sendMessage();
296
+ }
297
+ }
298
+
299
+ document.addEventListener('keydown', function(event) {
300
+ // Ctrl+Shift+N for new conversation
301
+ if (event.ctrlKey && event.shiftKey && event.key === 'N') {
302
+ event.preventDefault();
303
+ if (!isGenerating) {
304
+ newConversation();
305
+ }
306
+ }
307
+ });
308
+
309
+ function newConversation() {
310
+ messages = [];
311
+ chatWrapper.innerHTML = '';
312
+ chatInput.value = '';
313
+ chatInput.style.height = 'auto';
314
+ sendButton.disabled = false;
315
+ isGenerating = false;
316
+ chatInput.focus();
317
+ }
318
+
319
+ function addMessage(role, content, messageIndex = null) {
320
+ const messageDiv = document.createElement('div');
321
+ messageDiv.className = `message ${role}`;
322
+
323
+ const contentDiv = document.createElement('div');
324
+ contentDiv.className = 'message-content';
325
+ contentDiv.textContent = content;
326
+
327
+ // Add click handler for user messages to enable editing
328
+ if (role === 'user' && messageIndex !== null) {
329
+ contentDiv.setAttribute('data-message-index', messageIndex);
330
+ contentDiv.setAttribute('title', 'Click to edit and restart from here');
331
+ contentDiv.addEventListener('click', function() {
332
+ if (!isGenerating) {
333
+ editMessage(messageIndex);
334
+ }
335
+ });
336
+ }
337
+
338
+ // Add click handler for assistant messages to enable regeneration
339
+ if (role === 'assistant' && messageIndex !== null) {
340
+ contentDiv.setAttribute('data-message-index', messageIndex);
341
+ contentDiv.setAttribute('title', 'Click to regenerate this response');
342
+ contentDiv.addEventListener('click', function() {
343
+ if (!isGenerating) {
344
+ regenerateMessage(messageIndex);
345
+ }
346
+ });
347
+ }
348
+
349
+ messageDiv.appendChild(contentDiv);
350
+ chatWrapper.appendChild(messageDiv);
351
+
352
+ chatContainer.scrollTop = chatContainer.scrollHeight;
353
+ return contentDiv;
354
+ }
355
+
356
+ function editMessage(messageIndex) {
357
+ // Find the message in the messages array
358
+ if (messageIndex < 0 || messageIndex >= messages.length) return;
359
+
360
+ const messageToEdit = messages[messageIndex];
361
+ if (messageToEdit.role !== 'user') return;
362
+
363
+ // Copy message content to input
364
+ chatInput.value = messageToEdit.content;
365
+ chatInput.style.height = 'auto';
366
+ chatInput.style.height = Math.min(chatInput.scrollHeight, 200) + 'px';
367
+
368
+ // Remove this message and all subsequent messages from the array
369
+ messages = messages.slice(0, messageIndex);
370
+
371
+ // Remove message elements from DOM starting from messageIndex
372
+ const allMessages = chatWrapper.querySelectorAll('.message');
373
+ for (let i = messageIndex; i < allMessages.length; i++) {
374
+ allMessages[i].remove();
375
+ }
376
+
377
+ // Enable send button and focus input
378
+ sendButton.disabled = false;
379
+ chatInput.focus();
380
+ }
381
+
382
+ async function generateAssistantResponse() {
383
+ isGenerating = true;
384
+ sendButton.disabled = true;
385
+
386
+ const assistantContent = addMessage('assistant', '');
387
+ assistantContent.innerHTML = '<span class="typing-indicator"></span>';
388
+
389
+ try {
390
+ const response = await fetch(`${API_URL}/chat/completions`, {
391
+ method: 'POST',
392
+ headers: {
393
+ 'Content-Type': 'application/json',
394
+ },
395
+ body: JSON.stringify({
396
+ messages: messages,
397
+ temperature: currentTemperature,
398
+ top_k: currentTopK,
399
+ max_tokens: 512
400
+ }),
401
+ });
402
+
403
+ if (!response.ok) {
404
+ throw new Error(`HTTP error! status: ${response.status}`);
405
+ }
406
+
407
+ const reader = response.body.getReader();
408
+ const decoder = new TextDecoder();
409
+ let fullResponse = '';
410
+ assistantContent.textContent = '';
411
+
412
+ while (true) {
413
+ const { done, value } = await reader.read();
414
+ if (done) break;
415
+
416
+ const chunk = decoder.decode(value);
417
+ const lines = chunk.split('\n');
418
+
419
+ for (const line of lines) {
420
+ if (line.startsWith('data: ')) {
421
+ try {
422
+ const data = JSON.parse(line.slice(6));
423
+ if (data.token) {
424
+ fullResponse += data.token;
425
+ assistantContent.textContent = fullResponse;
426
+ chatContainer.scrollTop = chatContainer.scrollHeight;
427
+ }
428
+ } catch (e) {
429
+ }
430
+ }
431
+ }
432
+ }
433
+
434
+ const assistantMessageIndex = messages.length;
435
+ messages.push({ role: 'assistant', content: fullResponse });
436
+
437
+ // Add click handler to regenerate this assistant message
438
+ assistantContent.setAttribute('data-message-index', assistantMessageIndex);
439
+ assistantContent.setAttribute('title', 'Click to regenerate this response');
440
+ assistantContent.addEventListener('click', function() {
441
+ if (!isGenerating) {
442
+ regenerateMessage(assistantMessageIndex);
443
+ }
444
+ });
445
+
446
+ } catch (error) {
447
+ console.error('Error:', error);
448
+ assistantContent.innerHTML = `<div class="error-message">Error: ${error.message}</div>`;
449
+ } finally {
450
+ isGenerating = false;
451
+ sendButton.disabled = !chatInput.value.trim();
452
+ }
453
+ }
454
+
455
+ async function regenerateMessage(messageIndex) {
456
+ // Find the message in the messages array
457
+ if (messageIndex < 0 || messageIndex >= messages.length) return;
458
+
459
+ const messageToRegenerate = messages[messageIndex];
460
+ if (messageToRegenerate.role !== 'assistant') return;
461
+
462
+ // Remove this message and all subsequent messages from the array
463
+ messages = messages.slice(0, messageIndex);
464
+
465
+ // Remove message elements from DOM starting from messageIndex
466
+ const allMessages = chatWrapper.querySelectorAll('.message');
467
+ for (let i = messageIndex; i < allMessages.length; i++) {
468
+ allMessages[i].remove();
469
+ }
470
+
471
+ // Regenerate the assistant response
472
+ await generateAssistantResponse();
473
+ }
474
+
475
+ function handleSlashCommand(command) {
476
+ const parts = command.trim().split(/\s+/);
477
+ const cmd = parts[0].toLowerCase();
478
+ const arg = parts[1];
479
+
480
+ if (cmd === '/temperature') {
481
+ if (arg === undefined) {
482
+ addMessage('console', `Current temperature: ${currentTemperature}`);
483
+ } else {
484
+ const temp = parseFloat(arg);
485
+ if (isNaN(temp) || temp < 0 || temp > 2) {
486
+ addMessage('console', 'Invalid temperature. Must be between 0.0 and 2.0');
487
+ } else {
488
+ currentTemperature = temp;
489
+ addMessage('console', `Temperature set to ${currentTemperature}`);
490
+ }
491
+ }
492
+ return true;
493
+ } else if (cmd === '/topk') {
494
+ if (arg === undefined) {
495
+ addMessage('console', `Current top-k: ${currentTopK}`);
496
+ } else {
497
+ const topk = parseInt(arg);
498
+ if (isNaN(topk) || topk < 1 || topk > 200) {
499
+ addMessage('console', 'Invalid top-k. Must be between 1 and 200');
500
+ } else {
501
+ currentTopK = topk;
502
+ addMessage('console', `Top-k set to ${currentTopK}`);
503
+ }
504
+ }
505
+ return true;
506
+ } else if (cmd === '/clear') {
507
+ newConversation();
508
+ return true;
509
+ } else if (cmd === '/help') {
510
+ addMessage('console',
511
+ 'Available commands:\n' +
512
+ '/temperature - Show current temperature\n' +
513
+ '/temperature <value> - Set temperature (0.0-2.0)\n' +
514
+ '/topk - Show current top-k\n' +
515
+ '/topk <value> - Set top-k (1-200)\n' +
516
+ '/clear - Clear conversation\n' +
517
+ '/help - Show this help message'
518
+ );
519
+ return true;
520
+ }
521
+ return false;
522
+ }
523
+
524
+ async function sendMessage() {
525
+ const message = chatInput.value.trim();
526
+ if (!message || isGenerating) return;
527
+
528
+ // Handle slash commands
529
+ if (message.startsWith('/')) {
530
+ chatInput.value = '';
531
+ chatInput.style.height = 'auto';
532
+ handleSlashCommand(message);
533
+ return;
534
+ }
535
+
536
+ chatInput.value = '';
537
+ chatInput.style.height = 'auto';
538
+
539
+ const userMessageIndex = messages.length;
540
+ messages.push({ role: 'user', content: message });
541
+ addMessage('user', message, userMessageIndex);
542
+
543
+ await generateAssistantResponse();
544
+ }
545
+
546
+ sendButton.disabled = false;
547
+
548
+ // Autofocus the chat input on page load
549
+ chatInput.focus();
550
+
551
+ fetch(`${API_URL}/health`)
552
+ .then(response => response.json())
553
+ .then(data => {
554
+ console.log('Engine status:', data);
555
+ })
556
+ .catch(error => {
557
+ console.error('Engine not available:', error);
558
+ chatWrapper.innerHTML = '<div class="error-message">Engine not running. Please start engine.py first.</div>';
559
+ });
560
+ </script>
561
+ </body>
562
+ </html>
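The page above talks to two endpoints: `GET /health` for a status check and `POST /chat/completions`, which must stream `data: {"token": ...}` lines. The real backend is `engine.py` (not part of this diff); the FastAPI snippet below is only an illustration of that streaming contract, with a hard-coded token list standing in for model sampling.

```python
import json
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()  # illustrative stand-in for engine.py, which is not shown in this diff

@app.get("/health")
def health():
    return {"status": "ok"}

@app.post("/chat/completions")
async def chat_completions(payload: dict):
    def stream():
        # ui.html accumulates the "token" field of each `data: ...` line
        for token in ["Hello", ",", " world", "!"]:
            yield f"data: {json.dumps({'token': token})}\n\n"
    return StreamingResponse(stream(), media_type="text/event-stream")
```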
requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ torch
2
+ numpy
3
+ tiktoken
4
+ gradio
5
+ filelock
6
+ tokenizers
7
+ huggingface_hub