TMVishnu committed on
Commit 8bc34c1 · verified · 1 Parent(s): 6c192fb

Upload 20 files

nanochat/__init__.py ADDED
File without changes
nanochat/adamw.py ADDED
@@ -0,0 +1,143 @@
+"""
+Distributed AdamW optimizer with a fused step function.
+A bunch of ideas (e.g. dist comms in slices) are borrowed from modded-nanogpt.
+"""
+import torch
+import torch.distributed as dist
+from torch import Tensor
+
+@torch.compile(dynamic=False, fullgraph=True)
+def adamw_step_fused(
+    p: Tensor,
+    grad: Tensor,
+    exp_avg: Tensor,
+    exp_avg_sq: Tensor,
+    step_t: Tensor,
+    lr_t: Tensor,
+    beta1_t: Tensor,
+    beta2_t: Tensor,
+    eps_t: Tensor,
+    wd_t: Tensor,
+) -> None:
+    """
+    Fused AdamW step: weight_decay -> momentum_update -> bias_correction -> param_update
+    All in one compiled graph to eliminate Python overhead between ops.
+    The 0-D CPU tensors avoid recompilation when hyperparameter values change.
+    """
+    # Weight decay (decoupled, applied before the update)
+    p.mul_(1 - lr_t * wd_t)
+    # Update running averages (lerp_ is cleaner and fuses well)
+    exp_avg.lerp_(grad, 1 - beta1_t)
+    exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
+    # Bias corrections
+    bias1 = 1 - beta1_t ** step_t
+    bias2 = 1 - beta2_t ** step_t
+    # Compute update and apply
+    denom = (exp_avg_sq / bias2).sqrt() + eps_t
+    step_size = lr_t / bias1
+    p.add_(exp_avg / denom, alpha=-step_size)
+
+
+class DistAdamW(torch.optim.Optimizer):
+    """
+    Distributed AdamW optimizer.
+    In the style of ZeRO-2, i.e. sharded optimizer states and gradient reduction
+    """
+    def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
+        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        # Validate
+        if rank == 0:
+            for group in param_groups:
+                assert isinstance(group, dict), "expecting param_groups to be a list of dicts"
+                assert isinstance(group['params'], list), "expecting group['params'] to be a list of tensors"
+                for p in group['params']:
+                    sliced = p.numel() >= 1024
+                    print(f"AdamW: 1 param of shape {p.shape}, sliced={sliced}")
+                    if sliced:  # large parameter tensors will be operated on in slices
+                        assert p.shape[0] % world_size == 0, f"First dim of parameter shape {p.shape} must be divisible by world size {world_size}"
+        super().__init__(param_groups, defaults)
+        # 0-D CPU tensors to avoid torch.compile recompilation when values change
+        self._step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
+        self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
+        self._beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
+        self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
+        self._eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
+        self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
+
+    @torch.no_grad()
+    def step(self):
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        reduce_futures: list[torch.Future] = []
+        gather_futures: list[torch.Future] = []
+        grad_slices = []
+        is_small = []  # track which params are small (use all_reduce) vs large (use reduce_scatter)
+
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            for p in params:
+                grad = p.grad
+                # Small params: use all_reduce (no scatter/gather needed)
+                if p.numel() < 1024:
+                    is_small.append(True)
+                    reduce_futures.append(dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
+                    grad_slices.append(grad)
+                else:
+                    is_small.append(False)
+                    rank_size = grad.shape[0] // world_size  # p.shape[0] % world_size == 0 is checked in __init__
+                    grad_slice = torch.empty_like(grad[:rank_size])
+                    reduce_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
+                    grad_slices.append(grad_slice)
+
+        idx = 0
+        for group in self.param_groups:
+            beta1, beta2 = group['betas']
+            eps = group['eps']
+            wd = group['weight_decay']
+            params = group['params']
+            for p in params:
+                reduce_futures[idx].wait()
+                g_slice = grad_slices[idx]
+                lr = group['lr'] * getattr(p, "lr_mul", 1.0)
+                state = self.state[p]
+
+                # For small params, operate on full param; for large, operate on slice
+                if is_small[idx]:
+                    p_slice = p
+                else:
+                    rank_size = p.shape[0] // world_size
+                    p_slice = p[rank * rank_size:(rank + 1) * rank_size]
+
+                # State init
+                if not state:
+                    state['step'] = 0
+                    state['exp_avg'] = torch.zeros_like(p_slice)
+                    state['exp_avg_sq'] = torch.zeros_like(p_slice)
+                exp_avg = state['exp_avg']
+                exp_avg_sq = state['exp_avg_sq']
+                state['step'] += 1
+
+                # Fill 0-D tensors with current values
+                eff_wd = wd * getattr(p, "wd_mul", 1.0)
+                self._step_t.fill_(state['step'])
+                self._lr_t.fill_(lr)
+                self._beta1_t.fill_(beta1)
+                self._beta2_t.fill_(beta2)
+                self._eps_t.fill_(eps)
+                self._wd_t.fill_(eff_wd)
+
+                # Fused update: weight_decay -> momentum -> bias_correction -> param_update
+                adamw_step_fused(
+                    p_slice, g_slice, exp_avg, exp_avg_sq,
+                    self._step_t, self._lr_t, self._beta1_t, self._beta2_t, self._eps_t, self._wd_t,
+                )
+
+                # Only large params need all_gather
+                if not is_small[idx]:
+                    gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
+                idx += 1
+
+        if gather_futures:
+            torch.futures.collect_all(gather_futures).wait()
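
A minimal single-process sanity sketch of the fused step above (my addition, assuming torch.compile works on the current backend): one fused step should match stock torch.optim.AdamW up to float error, since both apply decoupled weight decay before the bias-corrected Adam update.

import torch
from nanochat.adamw import adamw_step_fused

torch.manual_seed(0)
grad = torch.randn(64)
p_ref = torch.randn(64, requires_grad=True)
p_fused = p_ref.detach().clone()
# Reference: one step of stock AdamW with matching hyperparameters
opt = torch.optim.AdamW([p_ref], lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01)
p_ref.grad = grad.clone()
opt.step()
# Fused: hyperparameters passed as 0-D CPU tensors, step counter at 1
scalar = lambda v: torch.tensor(float(v))
adamw_step_fused(
    p_fused, grad, torch.zeros(64), torch.zeros(64),
    scalar(1), scalar(1e-3), scalar(0.9), scalar(0.999), scalar(1e-8), scalar(0.01),
)
print(torch.allclose(p_ref.detach(), p_fused, atol=1e-6))  # expect True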
nanochat/checkpoint_manager.py ADDED
@@ -0,0 +1,174 @@
+"""
+Utilities for saving and loading model/optim/state checkpoints.
+"""
+import os
+import re
+import glob
+import json
+import logging
+import torch
+
+from nanochat.common import get_base_dir
+from nanochat.gpt import GPT, GPTConfig
+from nanochat.tokenizer import get_tokenizer
+from nanochat.common import setup_default_logging
+
+# Set up logging
+setup_default_logging()
+logger = logging.getLogger(__name__)
+def log0(message):
+    if int(os.environ.get('RANK', 0)) == 0:
+        logger.info(message)
+
+def _patch_missing_config_keys(model_config_kwargs):
+    """Add default values for new config keys missing in old checkpoints."""
+    # Old models were trained with full context (no sliding window)
+    if "window_pattern" not in model_config_kwargs:
+        model_config_kwargs["window_pattern"] = "L"
+        log0("Patching missing window_pattern in model config to 'L'")
+
+def _patch_missing_keys(model_data, model_config):
+    """Add default values for new parameters that may be missing in old checkpoints."""
+    n_layer = model_config.n_layer
+    # resid_lambdas defaults to 1.0 (identity scaling)
+    if "resid_lambdas" not in model_data:
+        model_data["resid_lambdas"] = torch.ones(n_layer)
+        log0("Patching missing resid_lambdas in model data to 1.0")
+    # x0_lambdas defaults to 0.0 (disabled)
+    if "x0_lambdas" not in model_data:
+        model_data["x0_lambdas"] = torch.zeros(n_layer)
+        log0("Patching missing x0_lambdas in model data to 0.0")
+
+def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
+    if rank == 0:
+        os.makedirs(checkpoint_dir, exist_ok=True)
+        # Save the model state parameters
+        model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
+        torch.save(model_data, model_path)
+        logger.info(f"Saved model parameters to: {model_path}")
+        # Save the metadata dict as json
+        meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
+        with open(meta_path, "w", encoding="utf-8") as f:
+            json.dump(meta_data, f, indent=2)
+        logger.info(f"Saved metadata to: {meta_path}")
+    # Note that optimizer state is sharded across ranks, so each rank must save its own.
+    if optimizer_data is not None:
+        os.makedirs(checkpoint_dir, exist_ok=True)
+        optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
+        torch.save(optimizer_data, optimizer_path)
+        logger.info(f"Saved optimizer state to: {optimizer_path}")
+
+def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
+    # Load the model state
+    model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
+    model_data = torch.load(model_path, map_location=device)
+    # Load the optimizer state if requested
+    optimizer_data = None
+    if load_optimizer:
+        optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
+        optimizer_data = torch.load(optimizer_path, map_location=device)
+    # Load the metadata
+    meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
+    with open(meta_path, "r", encoding="utf-8") as f:
+        meta_data = json.load(f)
+    return model_data, optimizer_data, meta_data
+
+
+def build_model(checkpoint_dir, step, device, phase):
+    """
+    A bunch of repetitive code to build a model from a given checkpoint.
+    Returns:
+    - base model - uncompiled, not wrapped in DDP
+    - tokenizer
+    - meta data saved during base model training
+    """
+    assert phase in ["train", "eval"], f"Invalid phase: {phase}"
+    model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
+    if device.type in {"cpu", "mps"}:
+        # Convert bfloat16 tensors to float for CPU inference
+        model_data = {
+            k: v.float() if v.dtype == torch.bfloat16 else v
+            for k, v in model_data.items()
+        }
+    # Hack: fix torch compile issue, which prepends all keys with _orig_mod.
+    model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
+    model_config_kwargs = meta_data["model_config"]
+    _patch_missing_config_keys(model_config_kwargs)
+    log0(f"Building model with config: {model_config_kwargs}")
+    model_config = GPTConfig(**model_config_kwargs)
+    _patch_missing_keys(model_data, model_config)
+    with torch.device("meta"):
+        model = GPT(model_config)
+    # Load the model state
+    model.to_empty(device=device)
+    model.init_weights()  # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
+    model.load_state_dict(model_data, strict=False, assign=True)
+    # Put the model in the right training phase / mode
+    if phase == "eval":
+        model.eval()
+    else:
+        model.train()
+    # Load the Tokenizer
+    tokenizer = get_tokenizer()
+    # Sanity check: compatibility between model and tokenizer
+    assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"], f"Tokenizer vocab size {tokenizer.get_vocab_size()} does not match model config vocab size {model_config_kwargs['vocab_size']}"
+    return model, tokenizer, meta_data
+
+
+def find_largest_model(checkpoints_dir):
+    # attempt to guess the model tag: take the biggest model available
+    model_tags = [f for f in os.listdir(checkpoints_dir) if os.path.isdir(os.path.join(checkpoints_dir, f))]
+    if not model_tags:
+        raise FileNotFoundError(f"No checkpoints found in {checkpoints_dir}")
+    # 1) normally all model tags are of the form d<number>, try that first:
+    candidates = []
+    for model_tag in model_tags:
+        match = re.match(r"d(\d+)", model_tag)
+        if match:
+            model_depth = int(match.group(1))
+            candidates.append((model_depth, model_tag))
+    if candidates:
+        candidates.sort(key=lambda x: x[0], reverse=True)
+        return candidates[0][1]
+    # 2) if that failed, take the most recently updated model:
+    model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoints_dir, x)), reverse=True)
+    return model_tags[0]
+
+
+def find_last_step(checkpoint_dir):
+    # Look into checkpoint_dir and find model_<step>.pt with the highest step
+    checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt"))
+    if not checkpoint_files:
+        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
+    last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files))
+    return last_step
+
+# -----------------------------------------------------------------------------
+# convenience functions that take into account nanochat's directory structure
+
+def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
+    if model_tag is None:
+        # guess the model tag by defaulting to the largest model
+        model_tag = find_largest_model(checkpoints_dir)
+        log0(f"No model tag provided, guessing model tag: {model_tag}")
+    checkpoint_dir = os.path.join(checkpoints_dir, model_tag)
+    if step is None:
+        # guess the step by defaulting to the last step
+        step = find_last_step(checkpoint_dir)
+        assert step is not None, f"No checkpoints found in {checkpoint_dir}"
+    # build the model
+    log0(f"Loading model from {checkpoint_dir} with step {step}")
+    model, tokenizer, meta_data = build_model(checkpoint_dir, step, device, phase)
+    return model, tokenizer, meta_data
+
+def load_model(source, *args, **kwargs):
+    model_dir = {
+        "base": "base_checkpoints",
+        "distill": "distill_checkpoints",
+        "mid": "mid_checkpoints",
+        "sft": "chatsft_checkpoints",
+        "rl": "chatrl_checkpoints",
+    }[source]
+    base_dir = get_base_dir()
+    checkpoints_dir = os.path.join(base_dir, model_dir)
+    return load_model_from_dir(checkpoints_dir, *args, **kwargs)
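
A hedged usage sketch for the convenience loaders above, assuming a base checkpoint already exists under ~/.cache/nanochat/base_checkpoints (model tag and step are auto-detected when omitted):

import torch
from nanochat.checkpoint_manager import load_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, tokenizer, meta = load_model("base", device, "eval")  # source, device, phase
print(meta["model_config"])  # the GPTConfig kwargs recorded at training time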
nanochat/common.py ADDED
@@ -0,0 +1,276 @@
+"""
+Common utilities for nanochat.
+"""
+
+import os
+import re
+import logging
+import urllib.request
+import torch
+import torch.distributed as dist
+from filelock import FileLock
+
+class ColoredFormatter(logging.Formatter):
+    """Custom formatter that adds colors to log messages."""
+    # ANSI color codes
+    COLORS = {
+        'DEBUG': '\033[36m',     # Cyan
+        'INFO': '\033[32m',      # Green
+        'WARNING': '\033[33m',   # Yellow
+        'ERROR': '\033[31m',     # Red
+        'CRITICAL': '\033[35m',  # Magenta
+    }
+    RESET = '\033[0m'
+    BOLD = '\033[1m'
+    def format(self, record):
+        # Add color to the level name
+        levelname = record.levelname
+        if levelname in self.COLORS:
+            record.levelname = f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}"
+        # Format the message
+        message = super().format(record)
+        # Add color to specific parts of the message
+        if levelname == 'INFO':
+            # Highlight numbers and percentages
+            message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message)
+            message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message)
+        return message
+
+def setup_default_logging():
+    handler = logging.StreamHandler()
+    handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
+    logging.basicConfig(
+        level=logging.INFO,
+        handlers=[handler]
+    )
+
+setup_default_logging()
+logger = logging.getLogger(__name__)
+
+def get_base_dir():
+    # co-locate nanochat intermediates with other cached data in ~/.cache (by default)
+    if os.environ.get("NANOCHAT_BASE_DIR"):
+        nanochat_dir = os.environ.get("NANOCHAT_BASE_DIR")
+    else:
+        home_dir = os.path.expanduser("~")
+        cache_dir = os.path.join(home_dir, ".cache")
+        nanochat_dir = os.path.join(cache_dir, "nanochat")
+    os.makedirs(nanochat_dir, exist_ok=True)
+    return nanochat_dir
+
+def download_file_with_lock(url, filename, postprocess_fn=None):
+    """
+    Downloads a file from a URL to a local path in the base directory.
+    Uses a lock file to prevent concurrent downloads among multiple ranks.
+    """
+    base_dir = get_base_dir()
+    file_path = os.path.join(base_dir, filename)
+    lock_path = file_path + ".lock"
+
+    if os.path.exists(file_path):
+        return file_path
+
+    with FileLock(lock_path):
+        # Only a single rank can acquire this lock
+        # All other ranks block until it is released
+
+        # Recheck after acquiring lock
+        if os.path.exists(file_path):
+            return file_path
+
+        # Download the content as bytes
+        print(f"Downloading {url}...")
+        with urllib.request.urlopen(url) as response:
+            content = response.read()  # bytes
+
+        # Write to local file
+        with open(file_path, 'wb') as f:
+            f.write(content)
+        print(f"Downloaded to {file_path}")
+
+        # Run the postprocess function if provided
+        if postprocess_fn is not None:
+            postprocess_fn(file_path)
+
+    return file_path
+
+def print0(s="", **kwargs):
+    ddp_rank = int(os.environ.get('RANK', 0))
+    if ddp_rank == 0:
+        print(s, **kwargs)
+
+def print_banner():
+    # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/
+    banner = """
+ █████ █████
+ ░░███ ░░███
+ ████████ ██████ ████████ ██████ ██████ ░███████ ██████ ███████
+░░███░░███ ░░░░░███ ░░███░░███ ███░░███ ███░░███ ░███░░███ ░░░░░███░░░███░
+ ░███ ░███ ███████ ░███ ░███ ░███ ░███░███ ░░░ ░███ ░███ ███████ ░███
+ ░███ ░███ ███░░███ ░███ ░███ ░███ ░███░███ ███ ░███ ░███ ███░░███ ░███ ███
+ ████ █████░░████████ ████ █████░░██████ ░░██████ ████ █████░░███████ ░░█████
+ ░░░░ ░░░░░ ░░░░░░░░ ░░░░ ░░░░░ ░░░░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░░░░ ░░░░░
+    """
+    print0(banner)
+
+def is_ddp_requested() -> bool:
+    """
+    True if launched by torchrun (env present), even before init.
+    Used to decide whether we *should* initialize a PG.
+    """
+    return all(k in os.environ for k in ("RANK", "LOCAL_RANK", "WORLD_SIZE"))
+
+def is_ddp_initialized() -> bool:
+    """
+    True if torch.distributed is available and the process group is initialized.
+    Used at cleanup to avoid destroying a non-existent PG.
+    """
+    return dist.is_available() and dist.is_initialized()
+
+def get_dist_info():
+    if is_ddp_requested():
+        # We rely on torchrun's env to decide if we SHOULD init.
+        # (Initialization itself happens in compute init.)
+        assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
+        ddp_rank = int(os.environ['RANK'])
+        ddp_local_rank = int(os.environ['LOCAL_RANK'])
+        ddp_world_size = int(os.environ['WORLD_SIZE'])
+        return True, ddp_rank, ddp_local_rank, ddp_world_size
+    else:
+        return False, 0, 0, 1
+
+def autodetect_device_type():
+    # prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
+    if torch.cuda.is_available():
+        device_type = "cuda"
+    elif torch.backends.mps.is_available():
+        device_type = "mps"
+    else:
+        device_type = "cpu"
+    print0(f"Autodetected device type: {device_type}")
+    return device_type
+
+def compute_init(device_type="cuda"):  # cuda|cpu|mps
+    """Basic initialization that we keep doing over and over, so make common."""
+
+    assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
+    if device_type == "cuda":
+        assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
+    if device_type == "mps":
+        assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
+
+    # Reproducibility
+    # Note that we set the global seeds here, but most of the code uses explicit rng objects.
+    # The only place where global rng might be used is nn.Module initialization of the model weights.
+    torch.manual_seed(42)
+    if device_type == "cuda":
+        torch.cuda.manual_seed(42)
+    # skipping full reproducibility for now, possibly investigate slowdown later
+    # torch.use_deterministic_algorithms(True)
+
+    # Precision
+    if device_type == "cuda":
+        torch.backends.cuda.matmul.allow_tf32 = True  # uses tf32 instead of fp32 for matmuls
+
+    # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
+    is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
+    if is_ddp_requested and device_type == "cuda":
+        device = torch.device("cuda", ddp_local_rank)
+        torch.cuda.set_device(device)  # make "cuda" default to this device
+        dist.init_process_group(backend="nccl", device_id=device)
+        dist.barrier()
+    else:
+        device = torch.device(device_type)  # mps|cpu
+
+    if ddp_rank == 0:
+        logger.info(f"Distributed world size: {ddp_world_size}")
+
+    return is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size, device
+
+def compute_cleanup():
+    """Companion function to compute_init, to clean things up before script exit"""
+    if is_ddp_initialized():
+        dist.destroy_process_group()
+
+class DummyWandb:
+    """Useful if we wish to not use wandb but have all the same signatures"""
+    def __init__(self):
+        pass
+    def log(self, *args, **kwargs):
+        pass
+    def finish(self):
+        pass
+
+# hardcoded BF16 peak flops for various GPUs
+# inspired by torchtitan: https://github.com/pytorch/torchtitan/blob/main/torchtitan/tools/utils.py
+# and PR: https://github.com/karpathy/nanochat/pull/147
+def get_peak_flops(device_name: str) -> float:
+    name = device_name.lower()
+
+    # --- NVIDIA Blackwell ---
+    if "gb200" in name or "grace blackwell" in name:
+        return 2.5e15
+    if "b200" in name:
+        return 2.25e15
+    if "b100" in name:
+        return 1.8e15
+
+    # --- NVIDIA Hopper (H100/H200/H800) ---
+    if "h200" in name:
+        if "nvl" in name or "pcie" in name:
+            return 836e12
+        return 989e12  # H200 SXM
+    if "h100" in name:
+        if "nvl" in name:
+            return 835e12
+        if "pcie" in name:
+            return 756e12
+        return 989e12  # H100 SXM
+    if "h800" in name:
+        if "nvl" in name:
+            return 989e12
+        return 756e12  # H800 PCIe
+
+    # --- NVIDIA Ampere data center ---
+    if "a100" in name or "a800" in name:
+        return 312e12
+    if "a40" in name:
+        return 149.7e12
+    if "a30" in name:
+        return 165e12
+
+    # --- NVIDIA Ada data center ---
+    if "l40s" in name or "l40-s" in name or "l40 s" in name:
+        return 362e12
+    if "l4" in name:
+        return 121e12
+
+    # --- AMD CDNA accelerators ---
+    if "mi355" in name:
+        return 2.5e15
+    if "mi325" in name or "mi300x" in name:
+        return 1.3074e15
+    if "mi300a" in name:
+        return 980.6e12
+    if "mi250x" in name:
+        return 383e12
+    if "mi250" in name:
+        return 362.1e12
+
+    # --- Intel ---
+    if "data center gpu max 1550" in name:
+        # Ponte Vecchio (PVC) - dynamic based on compute units
+        max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
+        return 512 * max_comp_units * 1300 * 10**6
+
+    # --- Consumer RTX (for hobbyists) ---
+    if "5090" in name:
+        return 209.5e12
+    if "4090" in name:
+        return 165.2e12
+    if "3090" in name:
+        return 71e12
+
+    # Unknown GPU - return inf so MFU shows as 0% rather than a wrong guess
+    logger.warning(f"Peak flops undefined for: {device_name}, MFU will show as 0%")
+    return float('inf')
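
A sketch of how the peak-flops table is typically consumed for an MFU estimate; the model size, tokens per step, and step time below are illustrative placeholders, not measurements:

import torch
from nanochat.common import get_peak_flops

device_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu"
peak_flops = get_peak_flops(device_name)           # float('inf') for unknown devices
num_params = 561e6                                 # hypothetical model size
tokens_per_step = 524288                           # hypothetical tokens per optimizer step
step_time_s = 2.0                                  # hypothetical measured seconds per step
flops_per_step = 6 * num_params * tokens_per_step  # ~6*N*D rule of thumb for transformers
mfu = flops_per_step / step_time_s / peak_flops    # reads 0.0 when peak_flops is inf
print(f"MFU: {100 * mfu:.1f}%")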
nanochat/core_eval.py ADDED
@@ -0,0 +1,262 @@
+"""
+Functions for evaluating the CORE metric, as described in the DCLM paper.
+https://arxiv.org/abs/2406.11794
+
+TODOs:
+- All tasks ~match except for squad. We get 31%, reference is 37%. Figure out why.
+"""
+import random
+
+from jinja2 import Template
+import torch
+import torch.distributed as dist
+
+# -----------------------------------------------------------------------------
+# Prompt rendering utilities
+
+def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None):
+    """Render complete prompts for a multiple choice question"""
+    template_str = """
+{%- for example in fewshot_examples -%}
+{{ example.query }}{{ continuation_delimiter }}{{ example.choices[example.gold] }}
+
+{% endfor -%}
+{{ item.query }}{{ continuation_delimiter }}{{ choice }}""".strip()
+    template = Template(template_str)
+    fewshot_examples = fewshot_examples or []
+    context = {
+        'fewshot_examples': fewshot_examples,
+        'continuation_delimiter': continuation_delimiter,
+        'item': item
+    }
+    prompts = [template.render(choice=choice, **context) for choice in item['choices']]
+    return prompts
+
+
+def render_prompts_schema(item, continuation_delimiter, fewshot_examples=None):
+    """Render complete prompts for a schema question"""
+    template_str = """
+{%- for example in fewshot_examples -%}
+{{ example.context_options[example.gold] }}{{ continuation_delimiter }}{{ example.continuation }}
+
+{% endfor -%}
+{{ context }}{{ continuation_delimiter }}{{ item.continuation }}""".strip()
+    template = Template(template_str)
+    fewshot_examples = fewshot_examples or []
+    context = {
+        'fewshot_examples': fewshot_examples,
+        'continuation_delimiter': continuation_delimiter,
+        'item': item
+    }
+    prompts = [template.render(context=context_option, **context)
+               for context_option in item['context_options']]
+    return prompts
+
+
+def render_prompts_lm(item, continuation_delimiter, fewshot_examples=None):
+    """
+    Render complete prompt for a language modeling task.
+    Notice that we manually trim the context in the template,
+    which in some datasets seems to have trailing whitespace (which we don't want).
+    """
+    template_str = """
+{%- for example in fewshot_examples -%}
+{{ example.context | trim }}{{ continuation_delimiter }}{{ example.continuation }}
+
+{% endfor -%}
+{{ item.context | trim }}{{ continuation_delimiter }}{% if include_continuation %}{{ item.continuation }}{% endif %}""".strip()
+    template = Template(template_str)
+    fewshot_examples = fewshot_examples or []
+    context = {
+        'fewshot_examples': fewshot_examples,
+        'continuation_delimiter': continuation_delimiter,
+        'item': item
+    }
+    # Return two prompts: without and with the continuation
+    prompt_without = template.render(include_continuation=False, **context)
+    prompt_with = template.render(include_continuation=True, **context)
+    # Due to the way the data seems to be stored, I think I need to strip in the case of LM here.
+    # Otherwise we may get trailing whitespaces in prompt_without (which get absorbed into the next
+    # token in prompt_with), meaning we don't get a nice and clean prefix in the token space
+    # to detect the final continuation. Tokenizers...
+    prompt_without = prompt_without.strip()
+    return [prompt_without, prompt_with]
+
+
+def find_common_length(token_sequences, direction='left'):
+    """
+    Find the length of the common prefix or suffix across token sequences
+    - direction: 'left' for prefix, 'right' for suffix
+    """
+    min_len = min(len(seq) for seq in token_sequences)
+    indices = {
+        'left': range(min_len),
+        'right': range(-1, -min_len-1, -1)
+    }[direction]
+    # Find the first position where the token sequences differ
+    for i, idx in enumerate(indices):
+        token = token_sequences[0][idx]
+        if not all(seq[idx] == token for seq in token_sequences):
+            return i
+    return min_len
+
+
+def stack_sequences(tokens, pad_token_id):
+    """Stack up a list of token sequences, pad to longest on the right"""
+    bsz, seq_len = len(tokens), max(len(x) for x in tokens)
+    input_ids = torch.full((bsz, seq_len), pad_token_id, dtype=torch.long)
+    for i, x in enumerate(tokens):
+        input_ids[i, :len(x)] = torch.tensor(x, dtype=torch.long)
+    return input_ids
+
+
+def batch_sequences_mc(tokenizer, prompts):
+    # In multiple choice, contexts are the same but the continuation is different (common prefix)
+    tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
+    # figure out the start and end of each continuation
+    answer_start_idx = find_common_length(tokens, direction='left')
+    start_indices = [answer_start_idx] * len(prompts)
+    end_indices = [len(x) for x in tokens]
+    return tokens, start_indices, end_indices
+
+
+def batch_sequences_schema(tokenizer, prompts):
+    # In schema tasks, contexts vary but continuation is the same (common suffix)
+    tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
+    # figure out the start and end of each context
+    suffix_length = find_common_length(tokens, direction='right')
+    end_indices = [len(x) for x in tokens]
+    start_indices = [ei - suffix_length for ei in end_indices]
+    return tokens, start_indices, end_indices
+
+
+def batch_sequences_lm(tokenizer, prompts):
+    # In LM tasks, we have two prompts: without and with continuation
+    tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
+    tokens_without, tokens_with = tokens
+    start_idx, end_idx = len(tokens_without), len(tokens_with)
+    assert start_idx < end_idx, "prompt without is supposed to be a prefix of prompt with"
+    assert tokens_without == tokens_with[:start_idx], "prompt without is supposed to be a prefix of prompt with"
+    # we only need the with continuation prompt in the LM task, i.e. batch size of 1
+    return [tokens_with], [start_idx], [end_idx]
+
+
+@torch.no_grad()
+def forward_model(model, input_ids):
+    """
+    Take BxT tensor of token ids, return BxT tensor of losses and argmax predictions.
+    The last column of losses is set to nan because we don't have autoregressive targets there.
+    """
+    batch_size, seq_len = input_ids.size()
+    outputs = model(input_ids)
+    # Roll the tensor to the left by one position to get the (autoregressive) target ids
+    target_ids = torch.roll(input_ids, shifts=-1, dims=1)
+    # Calculate cross entropy at all positions
+    losses = torch.nn.functional.cross_entropy(
+        outputs.view(batch_size * seq_len, -1),
+        target_ids.view(batch_size * seq_len),
+        reduction='none'
+    ).view(batch_size, seq_len)
+    # Set the last column to be nan because there is no autoregressive loss there
+    losses[:, -1] = float('nan')
+    # Get the argmax predictions at each position
+    predictions = outputs.argmax(dim=-1)
+    return losses, predictions
+
+
+@torch.no_grad()
+def evaluate_example(idx, model, tokenizer, data, device, task_meta):
+    """Evaluate a single example, return True if correct, False otherwise"""
+    item = data[idx]
+    task_type = task_meta['task_type']
+    num_fewshot = task_meta['num_fewshot']
+    continuation_delimiter = task_meta['continuation_delimiter']
+
+    # Sample few-shot examples (excluding current item)
+    fewshot_examples = []
+    if num_fewshot > 0:
+        rng = random.Random(1234 + idx)
+        available_indices = [i for i in range(len(data)) if i != idx]
+        fewshot_indices = rng.sample(available_indices, num_fewshot)
+        fewshot_examples = [data[i] for i in fewshot_indices]
+
+    # Render prompts and batch sequences based on task type
+    if task_type == 'multiple_choice':
+        prompts = render_prompts_mc(item, continuation_delimiter, fewshot_examples)
+        tokens, start_idxs, end_idxs = batch_sequences_mc(tokenizer, prompts)
+    elif task_type == 'schema':
+        prompts = render_prompts_schema(item, continuation_delimiter, fewshot_examples)
+        tokens, start_idxs, end_idxs = batch_sequences_schema(tokenizer, prompts)
+    elif task_type == 'language_modeling':
+        prompts = render_prompts_lm(item, continuation_delimiter, fewshot_examples)
+        tokens, start_idxs, end_idxs = batch_sequences_lm(tokenizer, prompts)
+    else:
+        raise ValueError(f"Unsupported task type: {task_type}")
+
+    # Some models can't forward sequences beyond a certain length (e.g. GPT-2)
+    # In these cases, we have to truncate sequences to max length and adjust the indices
+    if hasattr(model, 'max_seq_len') and model.max_seq_len is not None:
+        max_tokens = model.max_seq_len
+        new_tokens, new_start_idxs, new_end_idxs = [], [], []
+        for t, s, e in zip(tokens, start_idxs, end_idxs):
+            if len(t) > max_tokens:
+                num_to_crop = len(t) - max_tokens
+                new_tokens.append(t[-max_tokens:])  # take the last max_tokens tokens
+                new_start_idxs.append(s - num_to_crop)  # shift the indices down
+                new_end_idxs.append(e - num_to_crop)
+                assert s - num_to_crop >= 0, "this should never happen right?"
+                assert e - num_to_crop >= 0, "this should never happen right?"
+            else:
+                new_tokens.append(t)  # keep unchanged
+                new_start_idxs.append(s)
+                new_end_idxs.append(e)
+        tokens, start_idxs, end_idxs = new_tokens, new_start_idxs, new_end_idxs
+
+    # Stack up all the sequences into a batch
+    pad_token_id = tokenizer.get_bos_token_id()  # use BOS as pad token is ok
+    input_ids = stack_sequences(tokens, pad_token_id)
+    input_ids = input_ids.to(device)
+
+    # Forward the model, get the autoregressive loss and argmax prediction at each token
+    losses, predictions = forward_model(model, input_ids)
+
+    # See if the losses/predictions come out correctly
+    if task_type == 'language_modeling':
+        # language modeling task is currently always batch size 1
+        si = start_idxs[0]
+        ei = end_idxs[0]
+        # predictions[i] predict input_ids[i+1] autoregressively
+        predicted_tokens = predictions[0, si-1:ei-1]
+        actual_tokens = input_ids[0, si:ei]
+        is_correct = torch.all(predicted_tokens == actual_tokens).item()
+    elif task_type in ['multiple_choice', 'schema']:
+        # For MC/schema: find the option with lowest average loss
+        mean_losses = [losses[i, si-1:ei-1].mean().item()
+                       for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))]
+        pred_idx = mean_losses.index(min(mean_losses))
+        is_correct = pred_idx == item['gold']
+    else:
+        raise ValueError(f"Unsupported task type: {task_type}")
+
+    return is_correct
+
+
+def evaluate_task(model, tokenizer, data, device, task_meta):
+    """
+    This function is responsible for evaluating one task across many examples.
+    It also handles dispatch to all processes if the script is run with torchrun.
+    """
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    correct = torch.zeros(len(data), dtype=torch.float32, device=device)
+    # stride the examples to each rank
+    for idx in range(rank, len(data), world_size):
+        is_correct = evaluate_example(idx, model, tokenizer, data, device, task_meta)
+        correct[idx] = float(is_correct)
+    # sync results across all the processes if running distributed
+    if world_size > 1:
+        dist.barrier()
+        dist.all_reduce(correct, op=dist.ReduceOp.SUM)
+    # compute the mean
+    mean_correct = correct.mean().item()
+    return mean_correct
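
For reference, a toy sketch of the inputs evaluate_task expects for a multiple-choice item (field names follow the render/batch functions above; model, tokenizer, and device come from your own setup):

item = {
    "query": "The capital of France is",
    "choices": [" Berlin", " Paris", " Rome"],
    "gold": 1,
}
task_meta = {
    "task_type": "multiple_choice",
    "num_fewshot": 0,
    "continuation_delimiter": "",
}
# accuracy = evaluate_task(model, tokenizer, [item], device, task_meta)
# -> 1.0 if " Paris" gets the lowest mean loss over its continuation tokens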
nanochat/dataloader.py ADDED
@@ -0,0 +1,199 @@
+"""
+Distributed dataloaders for pretraining.
+
+Two implementations are provided:
+
+1. Original (tokenizing_distributed_data_loader):
+   - Streams tokens into a flat buffer, reshapes to (B, T)
+   - Rows may start mid-document (no guaranteed BOS at position 0)
+   - 100% token utilization, simple and efficient
+
+2. BOS-aligned bestfit (tokenizing_distributed_data_loader_bos_bestfit):
+   - Every row starts with BOS token
+   - Documents packed using best-fit algorithm to minimize cropping
+   - When no document fits remaining space, crops a document to fill exactly
+   - 100% utilization (no padding), ~35% tokens cropped at T=2048
+
+The tradeoff: BOS-aligned loses ~35% of tokens to cropping, but ensures that
+there are fewer "confusing" tokens in the train/val batches as every token can
+now attend back to the BOS token and sees the full context of the document.
+(2) is the new default if you have enough data.
+Fallback to (1) if you have very limited data AND long documents.
+"""
+
+import torch
+import pyarrow.parquet as pq
+
+from nanochat.common import get_dist_info
+from nanochat.dataset import list_parquet_files
+
+def _document_batches(split, resume_state_dict, tokenizer_batch_size):
+    """
+    Infinite iterator over document batches (list of text strings) from parquet files.
+
+    Handles DDP sharding and approximate resume. Each yield is (text_batch, (pq_idx, rg_idx, epoch))
+    where text_batch is a list of document strings, indices track position for resumption,
+    and epoch counts how many times we've cycled through the dataset (starts at 1).
+    """
+    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
+
+    parquet_paths = list_parquet_files()
+    assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?"
+    parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
+
+    resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
+    resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
+    resume_epoch = resume_state_dict.get("epoch", 1) if resume_state_dict is not None else 1
+    first_pass = True
+    pq_idx = resume_pq_idx
+    epoch = resume_epoch
+
+    while True:  # iterate infinitely (multi-epoch)
+        pq_idx = resume_pq_idx if first_pass else 0
+        while pq_idx < len(parquet_paths):
+            filepath = parquet_paths[pq_idx]
+            pf = pq.ParquetFile(filepath)
+            # Start from resume point if resuming on same file, otherwise from DDP rank
+            if first_pass and (resume_rg_idx is not None) and (pq_idx == resume_pq_idx):
+                base_idx = resume_rg_idx // ddp_world_size
+                base_idx += 1  # advance by 1 so we don't repeat data after resuming
+                rg_idx = base_idx * ddp_world_size + ddp_rank
+                if rg_idx >= pf.num_row_groups:
+                    pq_idx += 1
+                    continue
+                resume_rg_idx = None  # only do this once
+            else:
+                rg_idx = ddp_rank
+            while rg_idx < pf.num_row_groups:
+                rg = pf.read_row_group(rg_idx)
+                batch = rg.column('text').to_pylist()
+                for i in range(0, len(batch), tokenizer_batch_size):
+                    yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx, epoch)
+                rg_idx += ddp_world_size
+            pq_idx += 1
+        first_pass = False
+        epoch += 1
+
+
+def tokenizing_distributed_data_loader_with_state(tokenizer, B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
+    """
+    Stream pretraining text from parquet files, tokenize, yield training batches.
+
+    This is the original dataloader that streams tokens into a flat buffer and reshapes.
+    Rows may start mid-document (no guaranteed BOS at position 0).
+
+    Supports approximate resume via state_dict.
+    """
+    assert split in ["train", "val"], "split must be 'train' or 'val'"
+
+    batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
+    needed_tokens = B * T + 1  # +1 for target at last position
+    bos_token = tokenizer.get_bos_token_id()
+    token_buffer = []
+    pq_idx, rg_idx, epoch = 0, 0, 1
+
+    while True:
+
+        # Accumulate enough tokens
+        while len(token_buffer) < needed_tokens:
+            doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
+            token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
+            for tokens in token_lists:
+                token_buffer.extend(tokens)
+        tokens = token_buffer[:needed_tokens]  # Read B*T+1 tokens (+1 is only for the target for the last token)
+        token_buffer = token_buffer[B*T:]  # Advance by B*T tokens, so we move exactly one window of B*T tokens over
+
+        # Package tokens into inputs and targets, yield
+        use_cuda = device == "cuda"
+        scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda)
+        inputs = scratch[:-1].view(B, T).to(device=device, non_blocking=use_cuda)
+        targets = scratch[1:].view(B, T).to(device=device, non_blocking=use_cuda)
+        yield inputs, targets, {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
+
+
+def tokenizing_distributed_data_loader(*args, **kwargs):
+    """Helper that omits state_dict from yields."""
+    for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs):
+        yield inputs, targets
+
+
+def tokenizing_distributed_data_loader_with_state_bos_bestfit(
+    tokenizer, B, T, split,
+    tokenizer_threads=4, tokenizer_batch_size=128,
+    device="cuda", resume_state_dict=None,
+    buffer_size=1000
+):
+    """
+    BOS-aligned dataloader with Best-Fit Cropping.
+
+    Reduces token waste compared to simple greedy cropping by searching a buffer
+    for documents that fit well, while maintaining 100% utilization (no padding).
+
+    Algorithm for each row:
+    1. From buffered docs, pick the LARGEST doc that fits entirely
+    2. Repeat until no doc fits
+    3. When nothing fits, crop a doc to fill remaining space exactly
+
+    Key properties:
+    - Every row starts with BOS
+    - 100% utilization (no padding, every token is trained on)
+    - Approximately 35% of all tokens are discarded due to cropping
+    """
+    assert split in ["train", "val"], "split must be 'train' or 'val'"
+
+    row_capacity = T + 1
+    batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
+    bos_token = tokenizer.get_bos_token_id()
+    doc_buffer = []
+    pq_idx, rg_idx, epoch = 0, 0, 1
+
+    def refill_buffer():
+        nonlocal pq_idx, rg_idx, epoch
+        doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
+        token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
+        for tokens in token_lists:
+            doc_buffer.append(tokens)
+
+    while True:
+        rows = []
+        for _ in range(B):
+            row = []
+            while len(row) < row_capacity:
+                # Ensure buffer has documents
+                while len(doc_buffer) < buffer_size:
+                    refill_buffer()
+
+                remaining = row_capacity - len(row)
+
+                # Find largest doc that fits entirely
+                best_idx = -1
+                best_len = 0
+                for i, doc in enumerate(doc_buffer):
+                    doc_len = len(doc)
+                    if doc_len <= remaining and doc_len > best_len:
+                        best_idx = i
+                        best_len = doc_len
+
+                if best_idx >= 0:
+                    doc = doc_buffer.pop(best_idx)
+                    row.extend(doc)
+                else:
+                    # No doc fits - crop shortest in buffer to fill remaining and minimize waste
+                    shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
+                    doc = doc_buffer.pop(shortest_idx)
+                    row.extend(doc[:remaining])
+
+            rows.append(row[:row_capacity])
+
+        use_cuda = device == "cuda"
+        batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
+        inputs = batch_tensor[:, :-1].to(device=device, non_blocking=use_cuda)
+        targets = batch_tensor[:, 1:].to(device=device, non_blocking=use_cuda)
+
+        yield inputs, targets, {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
+
+
+def tokenizing_distributed_data_loader_bos_bestfit(*args, **kwargs):
+    """Helper that omits state_dict from yields."""
+    for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state_bos_bestfit(*args, **kwargs):
+        yield inputs, targets
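
A tokenizer-free toy illustration of the best-fit packing rule described in the docstring above (pick the largest buffered doc that fits entirely; when nothing fits, crop the shortest to fill the row exactly):

def pack_row(capacity, docs):
    row = []
    while len(row) < capacity:
        remaining = capacity - len(row)
        fitting = [d for d in docs if len(d) <= remaining]
        if fitting:
            doc = max(fitting, key=len)  # largest doc that fits entirely
            docs.remove(doc)
            row.extend(doc)
        else:
            doc = min(docs, key=len)     # crop the shortest to fill exactly
            docs.remove(doc)
            row.extend(doc[:remaining])
    return row

docs = [list(range(5)), list(range(3)), list(range(7)), list(range(2))]
print(pack_row(8, docs))  # [0, 1, 2, 3, 4, 5, 6, 0]: the 7-doc first, then a cropped 2-doc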
nanochat/dataset.py ADDED
@@ -0,0 +1,128 @@
+"""
+The base/pretraining dataset is a set of parquet files.
+This file contains utilities for:
+- iterating over the parquet files and yielding documents from them
+- downloading the files on demand if they are not on disk
+
+For details of how the dataset was prepared, see `repackage_data_reference.py`.
+"""
+
+import os
+import argparse
+import time
+import requests
+import pyarrow.parquet as pq
+from multiprocessing import Pool
+
+from nanochat.common import get_base_dir
+
+# -----------------------------------------------------------------------------
+# The specifics of the current pretraining dataset
+
+# The URL on the internet where the data is hosted and downloaded from on demand
+BASE_URL = "https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main"
+MAX_SHARD = 1822  # the last datashard is shard_01822.parquet
+index_to_filename = lambda index: f"shard_{index:05d}.parquet"  # format of the filenames
+base_dir = get_base_dir()
+DATA_DIR = os.path.join(base_dir, "base_data")
+os.makedirs(DATA_DIR, exist_ok=True)
+
+# -----------------------------------------------------------------------------
+# These functions are useful utilities to other modules, can/should be imported
+
+def list_parquet_files(data_dir=None):
+    """ Looks into a data dir and returns full paths to all parquet files. """
+    data_dir = DATA_DIR if data_dir is None else data_dir
+    parquet_files = sorted([
+        f for f in os.listdir(data_dir)
+        if f.endswith('.parquet') and not f.endswith('.tmp')
+    ])
+    parquet_paths = [os.path.join(data_dir, f) for f in parquet_files]
+    return parquet_paths
+
+def parquets_iter_batched(split, start=0, step=1):
+    """
+    Iterate through the dataset, in batches of underlying row_groups for efficiency.
+    - split can be "train" or "val". the last parquet file will be val.
+    - start/step are useful for skipping rows in DDP. e.g. start=rank, step=world_size
+    """
+    assert split in ["train", "val"], "split must be 'train' or 'val'"
+    parquet_paths = list_parquet_files()
+    parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
+    for filepath in parquet_paths:
+        pf = pq.ParquetFile(filepath)
+        for rg_idx in range(start, pf.num_row_groups, step):
+            rg = pf.read_row_group(rg_idx)
+            texts = rg.column('text').to_pylist()
+            yield texts
+
+# -----------------------------------------------------------------------------
+def download_single_file(index):
+    """ Downloads a single file index, with some backoff """
+
+    # Construct the local filepath for this file and skip if it already exists
+    filename = index_to_filename(index)
+    filepath = os.path.join(DATA_DIR, filename)
+    if os.path.exists(filepath):
+        print(f"Skipping {filepath} (already exists)")
+        return True
+
+    # Construct the remote URL for this file
+    url = f"{BASE_URL}/{filename}"
+    print(f"Downloading {filename}...")
+
+    # Download with retries
+    max_attempts = 5
+    for attempt in range(1, max_attempts + 1):
+        try:
+            response = requests.get(url, stream=True, timeout=30)
+            response.raise_for_status()
+            # Write to temporary file first
+            temp_path = filepath + ".tmp"
+            with open(temp_path, 'wb') as f:
+                for chunk in response.iter_content(chunk_size=1024 * 1024):  # 1MB chunks
+                    if chunk:
+                        f.write(chunk)
+            # Move temp file to final location
+            os.rename(temp_path, filepath)
+            print(f"Successfully downloaded {filename}")
+            return True
+
+        except (requests.RequestException, IOError) as e:
+            print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
+            # Clean up any partial files
+            for path in [filepath + ".tmp", filepath]:
+                if os.path.exists(path):
+                    try:
+                        os.remove(path)
+                    except OSError:
+                        pass
+            # Try a few times with exponential backoff: 2^attempt seconds
+            if attempt < max_attempts:
+                wait_time = 2 ** attempt
+                print(f"Waiting {wait_time} seconds before retry...")
+                time.sleep(wait_time)
+            else:
+                print(f"Failed to download {filename} after {max_attempts} attempts")
+                return False
+
+    return False
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Download FineWeb-Edu 100BT dataset shards")
+    parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of shards to download (default: -1 = all)")
+    parser.add_argument("-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)")
+    args = parser.parse_args()
+
+    num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1)
+    ids_to_download = list(range(num))
+    print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
+    print(f"Target directory: {DATA_DIR}")
+    print()
+    with Pool(processes=args.num_workers) as pool:
+        results = pool.map(download_single_file, ids_to_download)
+
+    # Report results
+    successful = sum(1 for success in results if success)
+    print(f"Done! Downloaded: {successful}/{len(ids_to_download)} shards to {DATA_DIR}")
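
A sketch of typical usage (network access to BASE_URL assumed): fetch a few shards via the CLI defined in __main__ above, then stream the validation split.

#   python -m nanochat.dataset -n 8 -w 4
from nanochat.dataset import parquets_iter_batched

for texts in parquets_iter_batched("val"):
    print(f"{len(texts)} docs in first row group; first doc starts: {texts[0][:60]!r}")
    break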
nanochat/distill_loss.py ADDED
@@ -0,0 +1,176 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
+ def compute_distillation_loss(
6
+ student_logits,
7
+ teacher_logits,
8
+ temperature=1.0,
9
+ reduction='mean'
10
+ ):
11
+ """
12
+ Compute KL divergence loss between student and teacher logits.
13
+
14
+ Args:
15
+ student_logits: (B, T, vocab_size) logits from student model
16
+ teacher_logits: (B, T, vocab_size) logits from teacher model
17
+ temperature: Temperature for softmax (higher = softer distribution)
18
+ reduction: 'mean' or 'sum' or 'none'
19
+
20
+ Returns:
21
+ loss: Scalar or (B, T) tensor depending on reduction
22
+ """
23
+ # Apply temperature scaling
24
+ student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
25
+ teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
26
+
27
+ # KL divergence: We use KL(teacher || student) which is more numerically stable
28
+ # KL(teacher || student) = sum(teacher * log(teacher/student))
29
+     # = sum(teacher * log(teacher)) - sum(teacher * log(student))
+     # Using F.kl_div: input=log(student), target=teacher, log_target=False
+     # This computes: sum(target * (log(target) - input))
+     # = sum(teacher * (log(teacher) - log(student))) = KL(teacher || student)
+     kl_loss = F.kl_div(
+         student_log_probs,
+         teacher_probs,
+         reduction='none',
+         log_target=False
+     )  # (B, T, vocab_size)
+
+     # Sum over the vocab dimension
+     kl_loss = kl_loss.sum(dim=-1)  # (B, T)
+
+     # Scale by temperature^2 (standard in the distillation literature)
+     kl_loss = kl_loss * (temperature ** 2)
+
+     # Apply the requested reduction
+     if reduction == 'mean':
+         return kl_loss.mean()
+     elif reduction == 'sum':
+         return kl_loss.sum()
+     else:
+         return kl_loss
+
+
+ def compute_combined_loss(
+     student_logits,
+     teacher_logits,
+     targets,
+     temperature=1.0,
+     alpha=0.5,
+     ignore_index=-1,
+     reduction='mean'
+ ):
+     """
+     Combine distillation loss with standard cross-entropy loss.
+
+     Args:
+         student_logits: (B, T, vocab_size) logits from student model
+         teacher_logits: (B, T, vocab_size) logits from teacher model
+         targets: (B, T) ground truth token ids
+         temperature: Temperature for distillation
+         alpha: Weight for distillation loss (1-alpha for CE loss)
+         ignore_index: Tokens to ignore in CE loss
+         reduction: 'mean' or 'sum' or 'none'
+
+     Returns:
+         total_loss: Combined loss
+         distill_loss: Distillation loss component
+         ce_loss: Cross-entropy loss component
+     """
+     # Distillation loss
+     distill_loss = compute_distillation_loss(
+         student_logits,
+         teacher_logits,
+         temperature=temperature,
+         reduction=reduction
+     )
+
+     # Standard cross-entropy loss
+     ce_loss = F.cross_entropy(
+         student_logits.view(-1, student_logits.size(-1)),
+         targets.view(-1),
+         ignore_index=ignore_index,
+         reduction=reduction
+     )
+
+     # Combine: alpha * distill + (1-alpha) * ce
+     if reduction == 'none':
+         # For 'none', reshape to fix the shape mismatch:
+         # distill_loss is (B, T), ce_loss is (B*T,)
+         ce_loss = ce_loss.view(student_logits.shape[:2])
+     total_loss = alpha * distill_loss + (1 - alpha) * ce_loss
+
+     return total_loss, distill_loss, ce_loss
+
+
+ def compute_multi_token_loss(multi_token_logits, targets, ignore_index=-1, reduction='mean'):
+     """Train multi-token heads (t+2, t+3, t+4 predictions)"""
+     total_loss = 0.0
+     count = 0
+
+     for head_name, logits in multi_token_logits.items():
+         offset = int(head_name.split('_')[1])  # "head_2" -> 2
+
+         # Shift targets: y already holds t+1 at each position, so head_k needs y shifted by k-1
+         if targets.size(1) >= offset:
+             shifted_targets = targets[:, offset-1:]
+             shifted_logits = logits[:, :targets.size(1)-offset+1, :]
+
+             if shifted_targets.numel() > 0:
+                 loss = F.cross_entropy(
+                     shifted_logits.reshape(-1, shifted_logits.size(-1)),
+                     shifted_targets.reshape(-1),
+                     ignore_index=ignore_index,
+                     reduction=reduction
+                 )
+                 total_loss += loss
+                 count += 1
+
+     return total_loss / count if count > 0 else torch.tensor(0.0, device=targets.device)
+
+
+ def compute_draft_loss(student_model, x, teacher_logits, temperature=1.0):
+     """Train draft head to predict multiple future tokens"""
+     if student_model.draft_head is None:
+         return torch.tensor(0.0, device=x.device)
+
+     # Run the student trunk to get hidden states from the last transformer layer
+     from nanochat.gpt import norm
+     hidden = student_model.transformer.wte(x)
+     hidden = norm(hidden)
+     x0 = hidden
+
+     for i, block in enumerate(student_model.transformer.h):
+         hidden = student_model.resid_lambdas[i] * hidden + student_model.x0_lambdas[i] * x0
+         ve = student_model.value_embeds[str(i)](x) if str(i) in student_model.value_embeds else None
+         cos_sin = student_model.cos[:, :x.size(1)], student_model.sin[:, :x.size(1)]
+         hidden = block(hidden, ve, cos_sin, student_model.window_sizes[i], None)
+
+     hidden = norm(hidden)
+     last_hidden = hidden[:, -1, :]  # (B, n_embd)
+
+     # The draft head predicts the next N tokens from the last position
+     draft_logits = student_model.draft_head(last_hidden)  # (B, draft_n, vocab)
+
+     # Match with the teacher's future predictions
+     B, T, V = teacher_logits.shape
+     draft_n = draft_logits.shape[1]
+
+     total_loss = 0.0
+     for i in range(min(draft_n, T-1)):
+         draft_pred = draft_logits[:, i, :]
+         teacher_future = teacher_logits[:, i+1, :]
+
+         student_log_probs = F.log_softmax(draft_pred / temperature, dim=-1)
+         teacher_probs = F.softmax(teacher_future / temperature, dim=-1)
+
+         kl = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean', log_target=False)
+         total_loss += kl
+
+     return total_loss / min(draft_n, T-1) if T > 1 else torch.tensor(0.0, device=x.device)
+
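As a sanity check on the F.kl_div bookkeeping above, here is a minimal standalone sketch (plain torch, illustrative shapes; not part of the uploaded file) verifying that the fused call matches the manual sum(teacher * (log teacher - log student)) formula, including the temperature^2 scaling:

    import torch
    import torch.nn.functional as F

    B, T, V, temp = 2, 3, 5, 2.0
    student_logits = torch.randn(B, T, V)
    teacher_logits = torch.randn(B, T, V)

    # Route used above: F.kl_div on log-probs of the student, probs of the teacher
    student_log_probs = F.log_softmax(student_logits / temp, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temp, dim=-1)
    kl_fused = F.kl_div(student_log_probs, teacher_probs, reduction='none', log_target=False)
    kl_fused = kl_fused.sum(dim=-1) * temp ** 2  # (B, T)

    # Manual formula: KL(teacher || student) = sum(p_t * (log p_t - log p_s))
    kl_manual = (teacher_probs * (teacher_probs.log() - student_log_probs)).sum(dim=-1) * temp ** 2

    print(torch.allclose(kl_fused, kl_manual, atol=1e-6))  # True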
nanochat/engine.py ADDED
@@ -0,0 +1,356 @@
+ """
+ Engine for efficient inference of our models.
+
+ Everything works around token sequences:
+ - The user can send token sequences to the engine
+ - The engine returns the next token
+
+ Notes:
+ - The engine knows nothing about tokenization, it's purely token id sequences.
+
+ The whole thing is made as efficient as possible.
+ """
+
+ import torch
+ import torch.nn.functional as F
+ import signal
+ import warnings
+ from contextlib import contextmanager
+ from collections import deque
+ from nanochat.common import compute_init, autodetect_device_type
+ from nanochat.checkpoint_manager import load_model
+ from contextlib import nullcontext
+
+ # -----------------------------------------------------------------------------
+ # Calculator tool helpers
+ @contextmanager
+ def timeout(duration, formula):
+     def timeout_handler(signum, frame):
+         raise Exception(f"'{formula}': timed out after {duration} seconds")
+
+     signal.signal(signal.SIGALRM, timeout_handler)
+     signal.alarm(duration)
+     try:
+         yield
+     finally:
+         signal.alarm(0)  # always clear the alarm, even if the body raised
+
+ def eval_with_timeout(formula, max_time=3):
+     try:
+         with timeout(max_time, formula):
+             with warnings.catch_warnings():
+                 warnings.simplefilter("ignore", SyntaxWarning)
+                 return eval(formula, {"__builtins__": {}}, {})
+     except Exception as e:
+         signal.alarm(0)
+         # print(f"Warning: Failed to eval {formula}, exception: {e}")  # it's ok, ignore wrong calculator usage
+         return None
+
+ def use_calculator(expr):
+     """
+     Evaluate a Python expression safely.
+     Supports both math expressions and string operations like .count()
+     """
+     # Remove commas from numbers
+     expr = expr.replace(",", "")
+
+     # Check if it's a pure math expression (old behavior)
+     if all([x in "0123456789*+-/.() " for x in expr]):
+         if "**" in expr:  # disallow the power operator
+             return None
+         return eval_with_timeout(expr)
+
+     # Check if it's a string operation we support
+     # Allow: strings (single/double quotes), .count(), letters, numbers, spaces, parens
+     allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
+     if not all([x in allowed_chars for x in expr]):
+         return None
+
+     # Disallow dangerous patterns
+     dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
+                           'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
+                           'getattr', 'setattr', 'delattr', 'hasattr']
+     expr_lower = expr.lower()
+     if any(pattern in expr_lower for pattern in dangerous_patterns):
+         return None
+
+     # Only allow the .count() method for now (can expand later)
+     if '.count(' not in expr:
+         return None
+
+     # Evaluate with timeout
+     return eval_with_timeout(expr)
+
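For reference, a few illustrative calls to the helper above (a sketch; assumes this file is importable as nanochat.engine):

    from nanochat.engine import use_calculator

    print(use_calculator("1,234 + 766"))             # 2000 (commas are stripped)
    print(use_calculator("2 ** 100"))                # None (power operator disallowed)
    print(use_calculator("'strawberry'.count('r')")) # 3 (only .count() is whitelisted)
    print(use_calculator("__import__('os')"))        # None (dangerous pattern rejected)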
+ # -----------------------------------------------------------------------------
+ class KVCache:
+     """
+     KV Cache designed for Flash Attention 3's flash_attn_with_kvcache API.
+
+     Key differences from FA2-style cache:
+     - Tensors are (B, T, H, D) not (B, H, T, D)
+     - FA3 updates the cache in-place during flash_attn_with_kvcache
+     - Position tracked per batch element via cache_seqlens tensor
+     """
+
+     def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype):
+         self.batch_size = batch_size
+         self.max_seq_len = seq_len
+         self.n_layers = num_layers
+         self.n_heads = num_heads
+         self.head_dim = head_dim
+         # Pre-allocate cache tensors: (n_layers, B, T, H, D)
+         self.k_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
+         self.v_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
+         # Current sequence length per batch element (FA3 needs int32)
+         self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
+
+     def reset(self):
+         """Reset cache to empty state."""
+         self.cache_seqlens.zero_()
+
+     def get_pos(self):
+         """Get current position (assumes all batch elements at same position)."""
+         return self.cache_seqlens[0].item()
+
+     def get_layer_cache(self, layer_idx):
+         """Return (k_cache, v_cache) views for a specific layer."""
+         return self.k_cache[layer_idx], self.v_cache[layer_idx]
+
+     def advance(self, num_tokens):
+         """Advance the cache position by num_tokens."""
+         self.cache_seqlens += num_tokens
+
+     def prefill(self, other):
+         """
+         Copy cached KV from another cache into this one.
+         Used when we do batch=1 prefill and then want to generate multiple samples in parallel.
+         """
+         assert self.get_pos() == 0, "Cannot prefill a non-empty KV cache"
+         assert self.n_layers == other.n_layers and self.n_heads == other.n_heads and self.head_dim == other.head_dim
+         assert self.max_seq_len >= other.max_seq_len
+         other_pos = other.get_pos()
+         self.k_cache[:, :, :other_pos, :, :] = other.k_cache[:, :, :other_pos, :, :]
+         self.v_cache[:, :, :other_pos, :, :] = other.v_cache[:, :, :other_pos, :, :]
+         self.cache_seqlens.fill_(other_pos)
+
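A minimal sketch of the prefill-then-clone pattern that Engine.generate below relies on (assumes nanochat.engine is importable; tiny illustrative shapes):

    import torch
    from nanochat.engine import KVCache

    kw = dict(num_heads=2, head_dim=4, num_layers=3, device="cpu", dtype=torch.float32)

    # Batch-1 cache filled during prompt prefill (here we just simulate 5 cached positions)
    prefill = KVCache(batch_size=1, seq_len=5, **kw)
    prefill.k_cache.normal_()
    prefill.v_cache.normal_()
    prefill.advance(5)

    # Decode-time cache for 4 parallel samples, seeded from the single prefill
    decode = KVCache(batch_size=4, seq_len=16, **kw)
    decode.prefill(prefill)
    print(decode.get_pos())  # 5: all 4 rows now share the prompt's KV state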
+ # -----------------------------------------------------------------------------
+ @torch.inference_mode()
+ def sample_next_token(logits, rng, temperature=1.0, top_k=None):
+     """Sample a single next token from given logits of shape (B, vocab_size). Returns (B, 1)."""
+     assert temperature >= 0.0, "temperature must be non-negative"
+     if temperature == 0.0:
+         return torch.argmax(logits, dim=-1, keepdim=True)
+     if top_k is not None and top_k > 0:
+         k = min(top_k, logits.size(-1))
+         vals, idx = torch.topk(logits, k, dim=-1)
+         vals = vals / temperature
+         probs = F.softmax(vals, dim=-1)
+         choice = torch.multinomial(probs, num_samples=1, generator=rng)
+         return idx.gather(1, choice)
+     else:
+         logits = logits / temperature
+         probs = F.softmax(logits, dim=-1)
+         return torch.multinomial(probs, num_samples=1, generator=rng)
+
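To see the top-k path in isolation, a self-contained snippet (plain torch) mirroring the logic above: restrict to the k largest logits, renormalize, sample, then map the sampled slot back to a vocabulary id:

    import torch
    import torch.nn.functional as F

    rng = torch.Generator().manual_seed(0)
    logits = torch.tensor([[2.0, 0.5, 1.0, -1.0]])  # (B=1, vocab=4)

    k = 2
    vals, idx = torch.topk(logits, k, dim=-1)            # vals=[[2.0, 1.0]], idx=[[0, 2]]
    probs = F.softmax(vals / 1.0, dim=-1)                # distribution over the top-2 only
    choice = torch.multinomial(probs, 1, generator=rng)  # index into the top-k slots
    token = idx.gather(1, choice)                        # back to vocab ids: always 0 or 2 here
    print(token)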
+ # -----------------------------------------------------------------------------
+
+ class RowState:
+     # Per-row state tracking during generation
+     def __init__(self, current_tokens=None):
+         self.current_tokens = current_tokens or []  # Current token sequence for this row
+         self.forced_tokens = deque()  # Queue of tokens to force inject
+         self.in_python_block = False  # Whether we are inside a python block
+         self.python_expr_tokens = []  # Tokens of the current python expression
+         self.completed = False  # Whether this row has completed generation
+
+ class Engine:
+
+     def __init__(self, model, tokenizer):
+         self.model = model
+         self.tokenizer = tokenizer  # needed for tool use
+
+     @torch.inference_mode()
+     def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
+         """Streaming generation: runs a single batch-1 prefill of the prompt, then clones the KV cache across num_samples rows and decodes them in parallel."""
+         assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
+         device = self.model.get_device()
+         # NOTE: setting the dtype here and in this way is an ugly hack.
+         # Currently the repo assumes that cuda -> bfloat16 and everything else -> float32.
+         # We need to know the dtype here to call __init__ on KVCache and pre-allocate its tensors.
+         # As a quick hack, we're making the generate() function inherit and know about this repo-wide assumption.
+         # I think there has to be a bigger refactor to deal with device/dtype tracking across the codebase.
+         # In particular, the KVCache should allocate its tensors lazily.
+         dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
+         rng = torch.Generator(device=device)
+         rng.manual_seed(seed)
+
+         # Get the special tokens we need to coordinate the tool use state machine
+         get_special = lambda s: self.tokenizer.encode_special(s)
+         python_start = get_special("<|python_start|>")
+         python_end = get_special("<|python_end|>")
+         output_start = get_special("<|output_start|>")
+         output_end = get_special("<|output_end|>")
+         assistant_end = get_special("<|assistant_end|>")  # if sampled, ends row
+         bos = self.tokenizer.get_bos_token_id()  # if sampled, ends row
+
+         # 1) Run a batch 1 prefill of the prompt tokens
+         m = self.model.config
+         kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
+         kv_cache_prefill = KVCache(
+             batch_size=1,
+             seq_len=len(tokens),
+             device=device,
+             dtype=dtype,
+             **kv_model_kwargs,
+         )
+         ids = torch.tensor([tokens], dtype=torch.long, device=device)
+         logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
+         logits = logits[:, -1, :].expand(num_samples, -1)  # (num_samples, vocab_size)
+
+         # 2) Replicate the KV cache for each sample/row
+         kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
+         kv_cache_decode = KVCache(
+             batch_size=num_samples,
+             seq_len=kv_length_hint,
+             device=device,
+             dtype=dtype,
+             **kv_model_kwargs,
+         )
+         kv_cache_decode.prefill(kv_cache_prefill)
+         del kv_cache_prefill  # no need to keep this memory around
+
+         # 3) Initialize states for each sample
+         row_states = [RowState(tokens.copy()) for _ in range(num_samples)]
+
+         # 4) Main generation loop
+         num_generated = 0
+         while True:
+             # Stop condition: we've reached max tokens
+             if max_tokens is not None and num_generated >= max_tokens:
+                 break
+             # Stop condition: all rows are completed
+             if all(state.completed for state in row_states):
+                 break
+
+             # Sample the next token for each row
+             next_ids = sample_next_token(logits, rng, temperature, top_k)  # (B, 1)
+             sampled_tokens = next_ids[:, 0].tolist()
+
+             # Process each row: choose the next token, update state, optional tool use
+             token_column = []  # contains the next token id along each row
+             token_masks = []  # contains the mask (was it sampled (1) or forced (0)?) along each row
+             for i, state in enumerate(row_states):
+                 # Select the next token in this row
+                 is_forced = len(state.forced_tokens) > 0  # are there tokens waiting to be forced in the deque?
+                 token_masks.append(0 if is_forced else 1)  # mask is 0 if forced, 1 if sampled
+                 next_token = state.forced_tokens.popleft() if is_forced else sampled_tokens[i]
+                 token_column.append(next_token)
+                 # Update the state of this row to include the next token
+                 state.current_tokens.append(next_token)
+                 # On <|assistant_end|> or <|bos|>, mark the row as completed
+                 if next_token == assistant_end or next_token == bos:
+                     state.completed = True
+                 # Handle tool logic
+                 if next_token == python_start:
+                     state.in_python_block = True
+                     state.python_expr_tokens = []
+                 elif next_token == python_end and state.in_python_block:
+                     state.in_python_block = False
+                     if state.python_expr_tokens:
+                         expr = self.tokenizer.decode(state.python_expr_tokens)
+                         result = use_calculator(expr)
+                         if result is not None:
+                             result_tokens = self.tokenizer.encode(str(result))
+                             state.forced_tokens.append(output_start)
+                             state.forced_tokens.extend(result_tokens)
+                             state.forced_tokens.append(output_end)
+                         state.python_expr_tokens = []
+                 elif state.in_python_block:
+                     state.python_expr_tokens.append(next_token)
+
+             # Yield the token column
+             yield token_column, token_masks
+             num_generated += 1
+
+             # Prepare logits for the next iteration
+             ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)
+             logits = self.model.forward(ids, kv_cache=kv_cache_decode)[:, -1, :]  # (B, vocab_size)
+
+     def generate_batch(self, tokens, num_samples=1, **kwargs):
+         """
+         Non-streaming batch generation that just returns the final token sequences.
+         Returns a list of token sequences (list of lists of ints).
+         Terminal tokens (assistant_end, bos) are not included in the results.
+         """
+         assistant_end = self.tokenizer.encode_special("<|assistant_end|>")
+         bos = self.tokenizer.get_bos_token_id()
+         results = [tokens.copy() for _ in range(num_samples)]
+         masks = [[0] * len(tokens) for _ in range(num_samples)]
+         completed = [False] * num_samples
+         for token_column, token_masks in self.generate(tokens, num_samples, **kwargs):
+             for i, (token, mask) in enumerate(zip(token_column, token_masks)):
+                 if not completed[i]:
+                     if token == assistant_end or token == bos:
+                         completed[i] = True
+                     else:
+                         results[i].append(token)
+                         masks[i].append(mask)
+             # Stop if all rows are completed
+             if all(completed):
+                 break
+         return results, masks
+
+
+ if __name__ == "__main__":
+     """
+     Quick inline test to make sure that the naive/slow model.generate function
+     is equivalent to the faster Engine.generate function here.
+     """
+     import time
+     # init compute
+     device_type = autodetect_device_type()
+     ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
+     autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
+
+     # load the model and tokenizer
+     model, tokenizer, meta = load_model("base", device, phase="eval")
+     bos_token_id = tokenizer.get_bos_token_id()
+     # common hyperparameters
+     kwargs = dict(max_tokens=64, temperature=0.0)
+     # set the starting prompt
+     prompt_tokens = tokenizer.encode("The chemical formula of water is", prepend=bos_token_id)
+     # generate the reference sequence using the model.generate() function
+     generated_tokens = []
+     torch.cuda.synchronize()
+     t0 = time.time()
+     stream = model.generate(prompt_tokens, **kwargs)
+     with autocast_ctx:
+         for token in stream:
+             generated_tokens.append(token)
+             chunk = tokenizer.decode([token])
+             print(chunk, end="", flush=True)
+     print()
+     torch.cuda.synchronize()
+     t1 = time.time()
+     print(f"Reference time: {t1 - t0:.2f}s")
+     reference_ids = generated_tokens
+     # generate tokens with the Engine
+     generated_tokens = []
+     engine = Engine(model, tokenizer)
+     stream = engine.generate(prompt_tokens, num_samples=1, **kwargs)  # note: runs in fp32
+     torch.cuda.synchronize()
+     t0 = time.time()
+     with autocast_ctx:
+         for token_column, token_masks in stream:
+             token = token_column[0]  # only print out the first row
+             generated_tokens.append(token)
+             chunk = tokenizer.decode([token])
+             print(chunk, end="", flush=True)
+     print()
+     torch.cuda.synchronize()
+     t1 = time.time()
+     print(f"Engine time: {t1 - t0:.2f}s")
+     # compare the two sequences
+     for i in range(len(reference_ids)):
+         if reference_ids[i] != generated_tokens[i]:
+             print(f"Mismatch at {i}: {reference_ids[i]} != {generated_tokens[i]}")
+             break
+     print(f"Match: {reference_ids == generated_tokens}")
nanochat/execution.py ADDED
@@ -0,0 +1,349 @@
+ """
+ Sandboxed execution utilities for running Python code that comes out of an LLM.
+ Adapted from OpenAI HumanEval code:
+ https://github.com/openai/human-eval/blob/master/human_eval/execution.py
+
+ What is covered:
+ - Each execution runs in its own process (can be killed if it hangs or crashes)
+ - Execution is limited by a timeout to stop infinite loops
+ - Memory limits are enforced by default (256MB)
+ - stdout and stderr are captured and returned
+ - Code runs in a temporary directory that is deleted afterwards
+ - Dangerous functions are disabled (examples: os.system, os.kill, shutil.rmtree, subprocess.Popen)
+
+ What is not covered:
+ - Not a true security sandbox
+ - Network access is not blocked (e.g. sockets could be opened)
+ - Python's dynamic features (e.g. ctypes) could bypass restrictions
+ - No kernel-level isolation (no seccomp, no containers, no virtualization)
+
+ Overall this sandbox is good for evaluation of generated code and protects against
+ accidental destructive behavior, but it is not safe against malicious adversarial code.
+ """
+
+ import contextlib
+ import faulthandler
+ import io
+ import multiprocessing
+ import os
+ import platform
+ import signal
+ import tempfile
+ from dataclasses import dataclass
+ from typing import Optional
+
+ # -----------------------------------------------------------------------------
+
+ @dataclass
+ class ExecutionResult:
+     """Result of executing Python code in a sandbox."""
+     success: bool
+     stdout: str
+     stderr: str
+     error: Optional[str] = None
+     timeout: bool = False
+     memory_exceeded: bool = False
+
+     def __repr__(self):
+         parts = []
+         parts.append(f"ExecutionResult(success={self.success}")
+         if self.timeout:
+             parts.append(", timeout=True")
+         if self.memory_exceeded:
+             parts.append(", memory_exceeded=True")
+         if self.error:
+             parts.append(f", error={self.error!r}")
+         if self.stdout:
+             parts.append(f", stdout={self.stdout!r}")
+         if self.stderr:
+             parts.append(f", stderr={self.stderr!r}")
+         parts.append(")")
+         return "".join(parts)
+
+
+ @contextlib.contextmanager
+ def time_limit(seconds: float):
+     def signal_handler(signum, frame):
+         raise TimeoutException("Timed out!")
+
+     signal.setitimer(signal.ITIMER_REAL, seconds)
+     signal.signal(signal.SIGALRM, signal_handler)
+     try:
+         yield
+     finally:
+         signal.setitimer(signal.ITIMER_REAL, 0)
+
+
+ @contextlib.contextmanager
+ def capture_io():
+     """Capture stdout and stderr, and disable stdin."""
+     stdout_capture = io.StringIO()
+     stderr_capture = io.StringIO()
+     stdin_block = WriteOnlyStringIO()
+     with contextlib.redirect_stdout(stdout_capture):
+         with contextlib.redirect_stderr(stderr_capture):
+             with redirect_stdin(stdin_block):
+                 yield stdout_capture, stderr_capture
+
+
+ @contextlib.contextmanager
+ def create_tempdir():
+     with tempfile.TemporaryDirectory() as dirname:
+         with chdir(dirname):
+             yield dirname
+
+
+ class TimeoutException(Exception):
+     pass
+
+
+ class WriteOnlyStringIO(io.StringIO):
+     """StringIO that throws an exception when it's read from"""
+
+     def read(self, *args, **kwargs):
+         raise IOError
+
+     def readline(self, *args, **kwargs):
+         raise IOError
+
+     def readlines(self, *args, **kwargs):
+         raise IOError
+
+     def readable(self, *args, **kwargs):
+         """Returns True if the IO object can be read."""
+         return False
+
+
+ class redirect_stdin(contextlib._RedirectStream):  # type: ignore
+     _stream = "stdin"
+
+
+ @contextlib.contextmanager
+ def chdir(root):
+     if root == ".":
+         yield
+         return
+     cwd = os.getcwd()
+     os.chdir(root)
+     try:
+         yield
+     finally:
+         os.chdir(cwd)
+
+
+ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
+     """
+     This disables various destructive functions and prevents the generated code
+     from interfering with the test (e.g. fork bomb, killing other processes,
+     removing filesystem files, etc.)
+
+     WARNING
+     This function is NOT a security sandbox. Untrusted code, including
+     model-generated code, should not be blindly executed outside of one. See the
+     Codex paper for more information about OpenAI's code sandbox, and proceed
+     with caution.
+     """
+
+     if maximum_memory_bytes is not None and platform.uname().system != "Darwin":
+         # These resource limit calls fail on macOS (Darwin), so we skip them there
+         import resource
+         resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
+         resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
+         resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
+
+     faulthandler.disable()
+
+     import builtins
+
+     builtins.exit = None
+     builtins.quit = None
+
+     import os
+
+     os.environ["OMP_NUM_THREADS"] = "1"
+
+     os.kill = None
+     os.system = None
+     os.putenv = None
+     os.remove = None
+     os.removedirs = None
+     os.rmdir = None
+     os.fchdir = None
+     os.setuid = None
+     os.fork = None
+     os.forkpty = None
+     os.killpg = None
+     os.rename = None
+     os.renames = None
+     os.truncate = None
+     os.replace = None
+     os.unlink = None
+     os.fchmod = None
+     os.fchown = None
+     os.chmod = None
+     os.chown = None
+     os.chroot = None
+     os.fchdir = None
+     os.lchflags = None
+     os.lchmod = None
+     os.lchown = None
+     os.getcwd = None
+     os.chdir = None
+
+     import shutil
+
+     shutil.rmtree = None
+     shutil.move = None
+     shutil.chown = None
+
+     import subprocess
+
+     subprocess.Popen = None  # type: ignore
+
+     __builtins__["help"] = None
+
+     import sys
+
+     sys.modules["ipdb"] = None
+     sys.modules["joblib"] = None
+     sys.modules["resource"] = None
+     sys.modules["psutil"] = None
+     sys.modules["tkinter"] = None
+
+
+ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[int], result_dict):
+     """Execute code in a subprocess with safety guards. Results are written to result_dict."""
+     with create_tempdir():
+
+         # Save these system calls: they are needed when cleaning up the tempdir
+         # after reliability_guard() has disabled them.
+         import os
+         import shutil
+
+         rmtree = shutil.rmtree
+         rmdir = os.rmdir
+         chdir = os.chdir
+         unlink = os.unlink
+
+         # Disable functionalities that can make destructive changes to the test.
+         reliability_guard(maximum_memory_bytes=maximum_memory_bytes)
+
+         # Default to failure
+         result_dict.update({
+             "success": False,
+             "stdout": "",
+             "stderr": "",
+             "timeout": False,
+             "memory_exceeded": False,
+             "error": None,
+         })
+
+         try:
+             exec_globals = {}
+             with capture_io() as (stdout_capture, stderr_capture):
+                 with time_limit(timeout):
+                     # WARNING
+                     # This exec runs untrusted model-generated code. Although
+                     # it is highly unlikely that model-generated code will do something overtly
+                     # malicious in response to this test suite, model-generated code may act
+                     # destructively due to a lack of model capability or alignment.
+                     # Users are strongly encouraged to sandbox this evaluation suite so that it
+                     # does not perform destructive actions on their host or network.
+                     exec(code, exec_globals)
+
+             result_dict.update({
+                 "success": True,
+                 "stdout": stdout_capture.getvalue(),
+                 "stderr": stderr_capture.getvalue(),
+             })
+
+         except TimeoutException:
+             result_dict.update({
+                 "timeout": True,
+                 "error": "Execution timed out",
+             })
+
+         except MemoryError as e:
+             result_dict.update({
+                 "memory_exceeded": True,
+                 "error": f"Memory limit exceeded: {e}",
+             })
+
+         except BaseException as e:
+             result_dict.update({
+                 "error": f"{type(e).__name__}: {e}",
+             })
+
+         # Needed for cleaning up.
+         shutil.rmtree = rmtree
+         os.rmdir = rmdir
+         os.chdir = chdir
+         os.unlink = unlink
+
+
+ def execute_code(
+     code: str,
+     timeout: float = 5.0,  # 5 seconds default
+     maximum_memory_bytes: Optional[int] = 256 * 1024 * 1024,  # 256MB default
+ ) -> ExecutionResult:
+     """
+     Execute Python code in a sandboxed environment.
+
+     Args:
+         code: Python code to execute as a string
+         timeout: Maximum execution time in seconds (default: 5.0)
+         maximum_memory_bytes: Memory limit in bytes (default: 256MB, None to disable)
+
+     Returns:
+         ExecutionResult with success status, stdout/stderr, and error information
+
+     Example:
+         >>> result = execute_code("print('hello world')")
+         >>> result.success
+         True
+         >>> result.stdout
+         'hello world\\n'
+     """
+
+     manager = multiprocessing.Manager()
+     result_dict = manager.dict()
+
+     p = multiprocessing.Process(
+         target=_unsafe_execute,
+         args=(code, timeout, maximum_memory_bytes, result_dict)
+     )
+     p.start()
+     p.join(timeout=timeout + 1)
+
+     if p.is_alive():
+         p.kill()
+         return ExecutionResult(
+             success=False,
+             stdout="",
+             stderr="",
+             error="Execution timed out (process killed)",
+             timeout=True,
+             memory_exceeded=False,
+         )
+
+     if not result_dict:
+         return ExecutionResult(
+             success=False,
+             stdout="",
+             stderr="",
+             error="Execution failed (no result returned)",
+             timeout=True,
+             memory_exceeded=False,
+         )
+
+     return ExecutionResult(
+         success=result_dict["success"],
+         stdout=result_dict["stdout"],
+         stderr=result_dict["stderr"],
+         error=result_dict["error"],
+         timeout=result_dict["timeout"],
+         memory_exceeded=result_dict["memory_exceeded"],
+     )
+
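Two illustrative calls (a sketch; assumes this file is importable as nanochat.execution; the infinite loop exercises the timeout path):

    from nanochat.execution import execute_code

    ok = execute_code("print(sum(range(10)))")
    print(ok.success, ok.stdout)       # True '45\n'

    hang = execute_code("while True: pass", timeout=1.0)
    print(hang.success, hang.timeout)  # False True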
nanochat/flash_attention.py ADDED
@@ -0,0 +1,178 @@
+ """
+ Unified Flash Attention interface with automatic FA3/SDPA switching.
+
+ Exports `flash_attn` module that matches the FA3 API exactly, but falls back
+ to PyTorch SDPA on non-Hopper GPUs, MPS, and CPU.
+
+ Usage (drop-in replacement for FA3):
+     from nanochat.flash_attention import flash_attn
+
+     # Training (no KV cache)
+     y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
+
+     # Inference (with KV cache)
+     y = flash_attn.flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v, ...)
+ """
+ import torch
+ import torch.nn.functional as F
+
+
+ # =============================================================================
+ # Detection: Try to load FA3 on Hopper+ GPUs
+ # =============================================================================
+ def _load_flash_attention_3():
+     """Try to load Flash Attention 3 (requires Hopper+ GPU)."""
+     if not torch.cuda.is_available():
+         return None
+     try:
+         major, _ = torch.cuda.get_device_capability()
+         if major < 9:  # Hopper is sm90
+             return None
+         import os
+         os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
+         from kernels import get_kernel
+         return get_kernel('varunneal/flash-attention-3').flash_attn_interface
+     except Exception:
+         return None
+
+
+ _fa3 = _load_flash_attention_3()
+ HAS_FA3 = _fa3 is not None
+
+ # Override for testing: set to 'fa3', 'sdpa', or None (auto)
+ _override_impl = None
+
+
+ def _use_fa3():
+     """Determine whether to use FA3 based on availability and override."""
+     if _override_impl == 'fa3':
+         assert HAS_FA3, "Cannot override to FA3: not available on this hardware"
+         return True
+     if _override_impl == 'sdpa':
+         return False
+     return HAS_FA3  # auto
+
+
+ # =============================================================================
+ # SDPA helpers
+ # =============================================================================
+ def _sdpa_attention(q, k, v, window_size, enable_gqa):
+     """
+     SDPA attention with sliding window support.
+     q, k, v are (B, H, T, D) format.
+     """
+     Tq = q.size(2)
+     Tk = k.size(2)
+     window = window_size[0]
+
+     # Full context, same length
+     if (window < 0 or window >= Tq) and Tq == Tk:
+         return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
+
+     # Single token generation
+     if Tq == 1:
+         return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
+
+     # Need an explicit mask
+     device = q.device
+     if Tq == Tk:
+         # Causal + sliding window
+         mask = torch.tril(torch.ones(Tq, Tk, device=device, dtype=torch.bool))
+         if window > 0 and window < Tq:
+             row_idx = torch.arange(Tq, device=device).unsqueeze(1)
+             col_idx = torch.arange(Tk, device=device).unsqueeze(0)
+             mask = mask & ((row_idx - col_idx) <= window)
+     else:
+         # Chunk inference: attend to prefix + causal within the chunk
+         prefix_len = Tk - Tq
+         mask = torch.zeros(Tq, Tk, device=device, dtype=torch.bool)
+         mask[:, :prefix_len] = True
+         mask[:, prefix_len:] = torch.tril(torch.ones(Tq, Tq, device=device, dtype=torch.bool))
+
+     return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
+
+
+ # =============================================================================
+ # Public API: Same interface as FA3
+ # =============================================================================
+ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
+     """
+     Flash Attention for training (no KV cache).
+
+     Args:
+         q, k, v: Tensors of shape (B, T, H, D)
+         causal: Whether to use causal masking
+         window_size: (left, right) sliding window. -1 means unlimited.
+
+     Returns:
+         Output tensor of shape (B, T, H, D)
+     """
+     if _use_fa3():
+         return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
+
+     # SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D)
+     q = q.transpose(1, 2)
+     k = k.transpose(1, 2)
+     v = v.transpose(1, 2)
+     enable_gqa = q.size(1) != k.size(1)
+     y = _sdpa_attention(q, k, v, window_size, enable_gqa)
+     return y.transpose(1, 2)  # back to (B, T, H, D)
+
+
+ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=None,
+                             causal=False, window_size=(-1, -1)):
+     """
+     Flash Attention with KV cache for inference.
+
+     FA3 updates k_cache/v_cache in-place. Our SDPA fallback does the same.
+
+     Args:
+         q: Queries, shape (B, T_new, H, D)
+         k_cache, v_cache: Pre-allocated cache tensors, shape (B, T_max, H_kv, D)
+         k, v: New keys/values to insert, shape (B, T_new, H_kv, D)
+         cache_seqlens: Current position in cache, shape (B,) int32
+         causal: Whether to use causal masking
+         window_size: (left, right) sliding window. -1 means unlimited.
+
+     Returns:
+         Output tensor of shape (B, T_new, H, D)
+     """
+     if _use_fa3():
+         return _fa3.flash_attn_with_kvcache(
+             q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
+             causal=causal, window_size=window_size
+         )
+
+     # SDPA fallback: manually manage the KV cache
+     B, T_new, H, D = q.shape
+     pos = cache_seqlens[0].item()  # assume uniform position across the batch
+
+     # Insert new k, v into the cache (in-place, matching FA3 behavior)
+     if k is not None and v is not None:
+         k_cache[:, pos:pos+T_new, :, :] = k
+         v_cache[:, pos:pos+T_new, :, :] = v
+
+     # Get the full cache up to the current position + new tokens
+     end_pos = pos + T_new
+     k_full = k_cache[:, :end_pos, :, :]
+     v_full = v_cache[:, :end_pos, :, :]
+
+     # Transpose to SDPA layout: (B, T, H, D) -> (B, H, T, D)
+     q_sdpa = q.transpose(1, 2)
+     k_sdpa = k_full.transpose(1, 2)
+     v_sdpa = v_full.transpose(1, 2)
+
+     enable_gqa = q_sdpa.size(1) != k_sdpa.size(1)
+     y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa)
+
+     return y_sdpa.transpose(1, 2)  # back to (B, T, H, D)
+
+
+ # =============================================================================
+ # Export: flash_attn module interface (drop-in replacement for FA3)
+ # =============================================================================
+ from types import SimpleNamespace
+ flash_attn = SimpleNamespace(
+     flash_attn_func=flash_attn_func,
+     flash_attn_with_kvcache=flash_attn_with_kvcache,
+ )
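A small self-contained check (plain torch) of the sliding-window mask construction used in the SDPA fallback above: a window at least as long as the sequence must reproduce plain causal attention, while a short window changes the outputs:

    import torch
    import torch.nn.functional as F

    B, H, T, D = 1, 2, 8, 16
    q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

    row = torch.arange(T).unsqueeze(1)
    col = torch.arange(T).unsqueeze(0)

    def windowed(window):
        # causal mask intersected with a sliding window, as in _sdpa_attention
        mask = (row >= col) & ((row - col) <= window)
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

    y_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    print(torch.allclose(windowed(T), y_causal, atol=1e-6))  # True: window >= T is plain causal
    print(torch.allclose(windowed(2), y_causal, atol=1e-6))  # False: a short window changes outputs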
nanochat/gpt.py ADDED
@@ -0,0 +1,633 @@
+ """
+ GPT model (rewrite, a lot simpler)
+ Notable features:
+ - rotary embeddings (and no positional embeddings)
+ - QK norm
+ - untied weights for token embedding and lm_head
+ - relu^2 activation in MLP
+ - norm after token embedding
+ - no learnable params in rmsnorm
+ - no bias in linear layers
+ - Group-Query Attention (GQA) support for more efficient inference
+ - Multi-Query Attention (MQA) option for maximum KV cache compression
+ - Multi-Token Prediction heads for improved training signal
+ - Flash Attention 3 integration
+ """
+
+ from functools import partial
+ from dataclasses import dataclass
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from nanochat.common import get_dist_info, print0
+ from nanochat.muon import Muon, DistMuon
+ from nanochat.adamw import DistAdamW
+
+ # Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
+ from nanochat.flash_attention import flash_attn
+
+ @dataclass
+ class GPTConfig:
+     sequence_len: int = 2048
+     vocab_size: int = 32768
+     n_layer: int = 12
+     n_head: int = 6  # number of query heads
+     n_kv_head: int = 6  # number of key/value heads (GQA)
+     n_embd: int = 768
+     # Sliding window attention pattern string, tiled across layers. Final layer always L.
+     # Characters: L=long (full context), S=short (half context)
+     # Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
+     window_pattern: str = "SSSL"
+     # Multi-Query Attention: use a single KV head shared across all query heads.
+     # Reduces the KV cache size by a factor of n_head.
+     use_mqa: bool = False
+     # Multi-Token Prediction: extra heads predicting future tokens (t+2, t+3, t+4)
+     # Improves training signal and enables speculative decoding
+     multi_token_n: int = 3  # predicts 3 future tokens (t+2, t+3, t+4)
+     # Draft Head for self-draft speculative decoding
+     # Lightweight MLP that predicts multiple tokens at once for fast drafting
+     draft_n: int = 4  # number of tokens to draft in one shot
+     draft_hidden_mult: float = 0.5  # draft head hidden dim = n_embd * mult (smaller = faster)
+
+     def __post_init__(self):
+         if self.use_mqa:
+             self.n_kv_head = 1
+
+
+ def norm(x):
+     # Purely functional rmsnorm with no learnable params
+     return F.rms_norm(x, (x.size(-1),))
+
+
+ def has_ve(layer_idx, n_layer):
+     """Returns True if a GPT layer should have a Value Embedding (alternating, last layer always included)."""
+     return layer_idx % 2 == (n_layer - 1) % 2
+
+ def apply_rotary_emb(x, cos, sin):
+     assert x.ndim == 4  # multihead attention
+     d = x.shape[3] // 2
+     x1, x2 = x[..., :d], x[..., d:]  # split up last dim into two halves
+     y1 = x1 * cos + x2 * sin  # rotate pairs of dims
+     y2 = x1 * (-sin) + x2 * cos
+     return torch.cat([y1, y2], 3)
+
+
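A quick standalone check (plain torch; the function is restated locally) that the half-split rotation above is orthogonal: rotating by theta and then by -theta recovers the input, and per-vector norms are preserved:

    import torch

    def apply_rotary_emb(x, cos, sin):  # same formula as above
        d = x.shape[3] // 2
        x1, x2 = x[..., :d], x[..., d:]
        return torch.cat([x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos], 3)

    B, T, H, D = 1, 4, 2, 8
    x = torch.randn(B, T, H, D)
    theta = torch.rand(1, T, 1, D // 2)
    cos, sin = theta.cos(), theta.sin()

    y = apply_rotary_emb(x, cos, sin)
    x_back = apply_rotary_emb(y, cos, -sin)  # inverse rotation
    print(torch.allclose(x_back, x, atol=1e-6))                       # True
    print(torch.allclose(y.norm(dim=-1), x.norm(dim=-1), atol=1e-5))  # True: rotation preserves norm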
+ class DraftHead(nn.Module):
+     """
+     Lightweight MLP head for self-draft speculative decoding.
+     Predicts multiple tokens at once from the last hidden state.
+
+     During inference:
+     1. Draft head quickly predicts N draft tokens
+     2. Main model verifies all N tokens in one parallel forward pass
+     3. Accept verified tokens, resample where draft was wrong
+
+     This amortizes the cost of autoregressive decoding.
+     """
+     def __init__(self, n_embd, vocab_size, draft_n, hidden_mult=0.5):
+         super().__init__()
+         self.draft_n = draft_n
+         hidden_dim = int(n_embd * hidden_mult)
+         # 2-layer MLP: hidden layer + output layer predicting draft_n * vocab_size
+         self.fc1 = nn.Linear(n_embd, hidden_dim, bias=False)
+         self.fc2 = nn.Linear(hidden_dim, draft_n * vocab_size, bias=False)
+         self.vocab_size = vocab_size
+
+     def forward(self, x):
+         """
+         Args:
+             x: hidden states (B, T, n_embd) or (B, n_embd) for a single position
+         Returns:
+             draft_logits: (B, T, draft_n, vocab_size) or (B, draft_n, vocab_size)
+         """
+         squeeze = x.dim() == 2
+         if squeeze:
+             x = x.unsqueeze(1)  # (B, 1, n_embd)
+
+         B, T, _ = x.shape
+         h = F.relu(self.fc1(x)) ** 2  # ReLU^2 like the main MLP
+         out = self.fc2(h)  # (B, T, draft_n * vocab_size)
+         out = out.view(B, T, self.draft_n, self.vocab_size)
+
+         if squeeze:
+             out = out.squeeze(1)  # (B, draft_n, vocab_size)
+         return out
+
+
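Shape sanity check for the head above (a sketch; assumes nanochat.gpt is importable and uses the illustrative default dimensions):

    import torch
    from nanochat.gpt import DraftHead

    head = DraftHead(n_embd=768, vocab_size=32768, draft_n=4, hidden_mult=0.5)
    x = torch.randn(2, 768)                      # last hidden state for a batch of 2
    print(head(x).shape)                         # torch.Size([2, 4, 32768]): 4 draft tokens at once
    print(head(torch.randn(2, 10, 768)).shape)   # torch.Size([2, 10, 4, 32768])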
+ class CausalSelfAttention(nn.Module):
+     def __init__(self, config, layer_idx):
+         super().__init__()
+         self.layer_idx = layer_idx
+         self.n_head = config.n_head
+         self.n_kv_head = config.n_kv_head
+         self.n_embd = config.n_embd
+         self.head_dim = self.n_embd // self.n_head
+         assert self.n_embd % self.n_head == 0
+         assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
+         self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
+         self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
+         self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
+         self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
+         self.ve_gate_channels = 32
+         self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
+
+     def forward(self, x, ve, cos_sin, window_size, kv_cache):
+         B, T, C = x.size()
+
+         # Project the input to get queries, keys, and values
+         # Shape: (B, T, H, D) - FA3's native layout, no transpose needed!
+         q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
+         k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
+         v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
+
+         # Value residual (ResFormer): mix in value embedding with input-dependent gate per head
+         if ve is not None:
+             ve = ve.view(B, T, self.n_kv_head, self.head_dim)
+             gate = 2 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels]))  # (B, T, n_kv_head), range (0, 2)
+             v = v + gate.unsqueeze(-1) * ve
+
+         # Apply Rotary Embeddings to queries and keys to get relative positional encoding
+         cos, sin = cos_sin
+         q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
+         q, k = norm(q), norm(k)  # QK norm
+
+         # Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere)
+         # window_size is a (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
+         if kv_cache is None:
+             # Training: causal attention with optional sliding window
+             y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
+         else:
+             # Inference: use flash_attn_with_kvcache which handles cache management
+             k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
+             y = flash_attn.flash_attn_with_kvcache(
+                 q, k_cache, v_cache,
+                 k=k, v=v,
+                 cache_seqlens=kv_cache.cache_seqlens,
+                 causal=True,
+                 window_size=window_size,
+             )
+             # Advance the position after the last layer processes
+             if self.layer_idx == kv_cache.n_layers - 1:
+                 kv_cache.advance(T)
+
+         # Re-assemble the heads and project back to the residual stream
+         y = y.contiguous().view(B, T, -1)
+         y = self.c_proj(y)
+         return y
+
+
+ class MLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
+         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
+
+     def forward(self, x):
+         x = self.c_fc(x)
+         x = F.relu(x).square()
+         x = self.c_proj(x)
+         return x
+
+
+ class Block(nn.Module):
+     def __init__(self, config, layer_idx):
+         super().__init__()
+         self.attn = CausalSelfAttention(config, layer_idx)
+         self.mlp = MLP(config)
+
+     def forward(self, x, ve, cos_sin, window_size, kv_cache):
+         x = x + self.attn(norm(x), ve, cos_sin, window_size, kv_cache)
+         x = x + self.mlp(norm(x))
+         return x
+
+
+ class GPT(nn.Module):
+     def __init__(self, config, pad_vocab_size_to=64):
+         """
+         NOTE a major footgun: this __init__ function runs in meta device context (!!)
+         Therefore, any calculations inside here are shapes and dtypes only, no actual data.
+         => We actually initialize all data (parameters, buffers, etc.) in init_weights() instead.
+         """
+         super().__init__()
+         self.config = config
+         # Compute per-layer window sizes for sliding window attention
+         # window_size is a (left, right) tuple: (-1, 0) for full context, (N, 0) for sliding window
+         self.window_sizes = self._compute_window_sizes(config)
+         # Pad vocab for efficiency (DDP, tensor cores). This is just an optimization - outputs are cropped in forward().
+         # https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings
+         padded_vocab_size = ((config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to
+         if padded_vocab_size != config.vocab_size:
+             print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} for efficiency")
+         self.transformer = nn.ModuleDict({
+             "wte": nn.Embedding(padded_vocab_size, config.n_embd),
+             "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
+         })
+         self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False)
+         # Multi-token prediction heads: predict tokens at t+2, t+3, etc.
+         self.multi_token_heads = nn.ModuleDict()
+         for i in range(config.multi_token_n):
+             self.multi_token_heads[f"head_{i+2}"] = nn.Linear(config.n_embd, padded_vocab_size, bias=False)
+         # Draft head for self-draft speculative decoding
+         self.draft_head = None
+         if config.draft_n > 0:
+             self.draft_head = DraftHead(
+                 n_embd=config.n_embd,
+                 vocab_size=config.vocab_size,  # use the actual vocab, not padded
+                 draft_n=config.draft_n,
+                 hidden_mult=config.draft_hidden_mult
+             )
+         # Per-layer learnable scalars (inspired by modded-nanogpt)
+         # resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
+         # x0_lambdas: blends the initial embedding back in at each layer (init 0.0 = disabled)
+         # Separate parameters so they can have different optimizer treatment
+         self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer))  # fake init, real init in init_weights()
+         self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer))  # fake init, real init in init_weights()
+         # Value embeddings (ResFormer-style): alternating layers, last layer always included
+         head_dim = config.n_embd // config.n_head
+         kv_dim = config.n_kv_head * head_dim
+         self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)})
+         # To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
+         # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
+         # so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
+         # In the future we can dynamically grow the cache, for now it's fine.
+         self.rotary_seq_len = config.sequence_len * 10  # 10X over-compute should be enough, TODO make nicer?
+         head_dim = config.n_embd // config.n_head
+         cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
+         self.register_buffer("cos", cos, persistent=False)  # persistent=False means it's not saved to the checkpoint
+         self.register_buffer("sin", sin, persistent=False)
+
+     @torch.no_grad()
+     def init_weights(self):
+         """
+         Initialize the full model in this one function for maximum clarity.
+
+         wte (embedding): normal, std=1.0
+         lm_head: normal, std=0.001
+         for each block:
+             attn.c_q: uniform, std=1/sqrt(n_embd)
+             attn.c_k: uniform, std=1/sqrt(n_embd)
+             attn.c_v: uniform, std=1/sqrt(n_embd)
+             attn.c_proj: zeros
+             mlp.c_fc: uniform, std=1/sqrt(n_embd)
+             mlp.c_proj: zeros
+         """
+
+         # Embedding and unembedding
+         torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=1.0)
+         torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
+         # Multi-token prediction heads (same init as lm_head)
+         for head in self.multi_token_heads.values():
+             torch.nn.init.normal_(head.weight, mean=0.0, std=0.001)
+         # Draft head: small std for fc1 (like other projections), zeros for fc2 (starts neutral)
+         if self.draft_head is not None:
+             torch.nn.init.normal_(self.draft_head.fc1.weight, mean=0.0, std=self.config.n_embd**-0.5)
+             torch.nn.init.zeros_(self.draft_head.fc2.weight)  # start with zero output
+
+         # Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal)
+         n_embd = self.config.n_embd
+         s = 3**0.5 * n_embd**-0.5  # sqrt(3) multiplier makes sure Uniform achieves the same std as Normal
+         for block in self.transformer.h:
+             torch.nn.init.uniform_(block.attn.c_q.weight, -s, s)  # weights use Uniform to avoid outliers
+             torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
+             torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
+             torch.nn.init.zeros_(block.attn.c_proj.weight)  # projections are zero
+             torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
+             torch.nn.init.zeros_(block.mlp.c_proj.weight)
+
+         # Per-layer scalars
+         self.resid_lambdas.fill_(1.0)  # 1.0 => typical residual connections at init
+         self.x0_lambdas.fill_(0.0)  # 0.0 => skip connection to the input is disabled at init
+
+         # Value embeddings (init like c_v: uniform with the same std)
+         for ve in self.value_embeds.values():
+             torch.nn.init.uniform_(ve.weight, -s, s)
+
+         # Gate weights init to zero so gates start at sigmoid(0) = 0.5, scaled by 2 -> 1.0 (neutral)
+         for block in self.transformer.h:
+             if block.attn.ve_gate is not None:
+                 torch.nn.init.zeros_(block.attn.ve_gate.weight)
+
+         # Rotary embeddings
+         head_dim = self.config.n_embd // self.config.n_head
+         cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
+         self.cos, self.sin = cos, sin
+
+         # Cast embeddings to bf16: the optimizer can tolerate it and it saves memory
+         if self.transformer.wte.weight.device.type == "cuda":
+             self.transformer.wte.to(dtype=torch.bfloat16)
+             for ve in self.value_embeds.values():
+                 ve.to(dtype=torch.bfloat16)
+
+     def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
+         # TODO: bump base theta more? e.g. 100K is more common more recently
+         # autodetect the device from the model embeddings
+         if device is None:
+             device = self.transformer.wte.weight.device
+         # stride the channels
+         channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
+         inv_freq = 1.0 / (base ** (channel_range / head_dim))
+         # stride the time steps
+         t = torch.arange(seq_len, dtype=torch.float32, device=device)
+         # calculate the rotation frequencies at each (time, channel) pair
+         freqs = torch.outer(t, inv_freq)
+         cos, sin = freqs.cos(), freqs.sin()
+         cos, sin = cos.bfloat16(), sin.bfloat16()  # keep them in bfloat16
+         cos, sin = cos[None, :, None, :], sin[None, :, None, :]  # add batch and head dims for later broadcasting
+         return cos, sin
+
+     def _compute_window_sizes(self, config):
+         """
+         Compute per-layer window sizes for sliding window attention.
+
+         Returns a list of (left, right) tuples for FA3's window_size parameter:
+         - left: how many tokens before the current position to attend to (-1 = unlimited)
+         - right: how many tokens after the current position to attend to (0 for causal)
+
+         The pattern string is tiled across layers. The final layer always gets L (full context).
+         Characters: L=long (full context), S=short (half context)
+         """
+         pattern = config.window_pattern.upper()
+         assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L."
+         # Map characters to window sizes
+         long_window = config.sequence_len
+         short_window = long_window // 2
+         char_to_window = {
+             "L": (long_window, 0),
+             "S": (short_window, 0),
+         }
+         # Tile the pattern across layers
+         window_sizes = []
+         for layer_idx in range(config.n_layer):
+             char = pattern[layer_idx % len(pattern)]
+             window_sizes.append(char_to_window[char])
+         # Final layer always gets full context
+         window_sizes[-1] = (long_window, 0)
+         return window_sizes
+
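The tiling rule is easy to verify standalone; a minimal re-sketch of the same logic with the default config values:

    sequence_len, n_layer, pattern = 2048, 12, "SSSL"
    char_to_window = {"L": (sequence_len, 0), "S": (sequence_len // 2, 0)}
    window_sizes = [char_to_window[pattern[i % len(pattern)]] for i in range(n_layer)]
    window_sizes[-1] = (sequence_len, 0)  # final layer always full context
    print(window_sizes)
    # [(1024, 0), (1024, 0), (1024, 0), (2048, 0), ...]: three short then one long, last layer forced long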
+     def get_device(self):
+         return self.transformer.wte.weight.device
+
+     def estimate_flops(self):
+         """
+         Return the estimated FLOPs per token for the model (forward + backward).
+         Each matmul weight parameter contributes 2 FLOPs (multiply *, accumulate +) in forward, and 2X that in backward => 2+4=6.
+         Cleanest explanation of this: https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4
+         On top of that, 12 * h * q * effective_seq_len accounts for the key @ query matmul flops inside attention.
+         With sliding windows, effective_seq_len varies per layer (capped by the window size).
+         Ref: https://arxiv.org/abs/2204.02311 (PaLM paper).
+         This is ~1% off from the exact formulas of the Chinchilla paper; the difference is:
+         - Chinchilla counts the embedding layer as flops (? weird, it's just a lookup => we ignore)
+         - Chinchilla counts exp/sum/divide in the attention softmax as flops (a little sus and very tiny => we ignore)
+         """
+         nparams = sum(p.numel() for p in self.parameters())
+         # Exclude non-matmul params: embeddings and per-layer scalars
+         value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
+         nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
+                            self.resid_lambdas.numel() + self.x0_lambdas.numel())
+         h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
+         # Sum attention FLOPs per layer, accounting for the sliding window
+         attn_flops = 0
+         for window_size in self.window_sizes:
+             window = window_size[0]  # (left, right) tuple, we use left
+             effective_seq = t if window < 0 else min(window, t)
+             attn_flops += 12 * h * q * effective_seq
+         num_flops_per_token = 6 * (nparams - nparams_exclude) + attn_flops
+         return num_flops_per_token
+
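To make the accounting concrete, a back-of-the-envelope sketch with hypothetical numbers (the parameter count below is made up, not the real config):

    # Hypothetical example: 100M matmul params, 12 layers, h=6 heads, q=128 head dim, t=2048
    n_matmul_params = 100_000_000
    h, q, t = 6, 128, 2048
    window_sizes = [1024, 1024, 1024, 2048] * 3  # "SSSL" tiling, left window only

    dense_flops = 6 * n_matmul_params  # 2 forward + 4 backward FLOPs per weight
    attn_flops = sum(12 * h * q * min(w, t) for w in window_sizes)
    print(dense_flops + attn_flops)  # FLOPs per token, forward + backward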
399
+ def num_scaling_params(self):
400
+ """
401
+ Return all of the parameters, same as Chinchilla paper.
402
+ Kaplan et al. did not include embedding parameters and said that this led to cleaner scaling laws.
403
+ But Kaplan et al. also had a bug in their results (as pointed out by Chinchilla).
404
+ My own experiments in nanochat confirm the Chinchilla approach gives the much cleaner scaling law.
405
+ Ref: https://arxiv.org/abs/2203.15556 (Chinchilla paper <- good).
406
+ Ref: https://arxiv.org/abs/2001.08361 (Kaplan et al. original scaling laws paper <- bad)
407
+ """
408
+ nparams = sum(p.numel() for p in self.parameters())
409
+ return nparams
410
+
411
+ def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
412
+ model_dim = self.config.n_embd
413
+ ddp, rank, local_rank, world_size = get_dist_info()
414
+ # Separate out all parameters into groups
415
+ matrix_params = list(self.transformer.h.parameters())
416
+ value_embeds_params = list(self.value_embeds.parameters())
417
+ embedding_params = list(self.transformer.wte.parameters())
418
+ lm_head_params = list(self.lm_head.parameters())
419
+ multi_token_params = list(self.multi_token_heads.parameters())
420
+ draft_head_params = list(self.draft_head.parameters()) if self.draft_head is not None else []
421
+ resid_params = [self.resid_lambdas]
422
+ x0_params = [self.x0_lambdas]
423
+ assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(multi_token_params) + len(draft_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)
424
+ # Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars
425
+ # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
426
+ dmodel_lr_scale = (model_dim / 768) ** -0.5
427
+ print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
428
+ adam_groups = [
429
+ dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
430
+ dict(params=multi_token_params, lr=unembedding_lr * dmodel_lr_scale), # same LR as lm_head
431
+ dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
432
+ dict(params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding
433
+ dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream
434
+ dict(params=x0_params, lr=scalar_lr, betas=(0.96, 0.95)), # higher beta1 for x0 scalars
435
+ ]
436
+ # Add draft head params if present
437
+ if draft_head_params:
438
+ adam_groups.insert(2, dict(params=draft_head_params, lr=unembedding_lr * dmodel_lr_scale))
439
+ adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon
440
+ AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
441
+ adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
442
+ # Create the Muon optimizer for the linear layers
443
+ muon_kwargs = dict(lr=matrix_lr, momentum=0.95, weight_decay=weight_decay)
444
+ MuonFactory = DistMuon if ddp else Muon
445
+ muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
446
+ # Combine the two optimizers into one list
447
+ optimizers = [adamw_optimizer, muon_optimizer]
448
+ for opt in optimizers:
449
+ for group in opt.param_groups:
450
+ group["initial_lr"] = group["lr"]
451
+ return optimizers
452
+
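The `initial_lr` stashed on each param group is what an external scheduler scales against. A minimal, hypothetical training-loop sketch (forward/backward elided):

```python
optimizers = model.setup_optimizers()
num_steps = 1000
for step in range(num_steps):
    # ... forward pass and loss.backward() would go here ...
    frac = 1.0 - step / num_steps  # e.g. a simple linear decay to zero
    for opt in optimizers:
        for group in opt.param_groups:
            group["lr"] = group["initial_lr"] * frac
        opt.step()
        opt.zero_grad(set_to_none=True)
```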
453
+ def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean', return_multi_token=False):
454
+ B, T = idx.size()
455
+
456
+ # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
457
+ assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
458
+ assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
459
+ assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
460
+ # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
461
+ T0 = 0 if kv_cache is None else kv_cache.get_pos()
462
+ cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
463
+
464
+ # Forward the trunk of the Transformer
465
+ x = self.transformer.wte(idx)
466
+ x = norm(x)
467
+ x0 = x # save initial normalized embedding for x0 residual
468
+ for i, block in enumerate(self.transformer.h):
469
+ x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
470
+ ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
471
+ x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
472
+ x = norm(x)
473
+
474
+ # Forward the lm_head (compute logits)
475
+ softcap = 15 # smoothly cap the logits to the range [-softcap, softcap]
476
+ logits = self.lm_head(x) # (B, T, padded_vocab_size) <- very big tensor, large amount of memory
477
+ logits = logits[..., :self.config.vocab_size] # slice to remove padding
478
+ logits = logits.float() # switch to fp32 for logit softcap and loss computation
479
+ logits = softcap * torch.tanh(logits / softcap) # squash the logits
480
+
481
+ # Multi-token prediction heads (for training with future token prediction)
482
+ multi_token_logits = {}
483
+ if return_multi_token and self.multi_token_heads:
484
+ for name, head in self.multi_token_heads.items():
485
+ mt_logits = head(x)
486
+ mt_logits = mt_logits[..., :self.config.vocab_size]
487
+ mt_logits = mt_logits.float()
488
+ mt_logits = softcap * torch.tanh(mt_logits / softcap)
489
+ multi_token_logits[name] = mt_logits
490
+
491
+ if targets is not None:
492
+ # training: given the targets, compute and return the loss
493
+ # TODO experiment with chunked cross-entropy?
494
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
495
+ if return_multi_token:
496
+ return loss, logits, multi_token_logits
497
+ return loss
498
+ else:
499
+ # inference: just return the logits directly
500
+ if return_multi_token:
501
+ return logits, multi_token_logits
502
+ return logits
503
+
504
+ @torch.inference_mode()
505
+ def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
506
+ """
507
+ Naive autoregressive streaming inference.
508
+ To make it super simple, let's assume:
509
+ - batch size is 1
510
+ - the input tokens and the yielded tokens are simple Python lists and ints
511
+ """
512
+ assert isinstance(tokens, list)
513
+ device = self.get_device()
514
+ rng = None
515
+ if temperature > 0:
516
+ rng = torch.Generator(device=device)
517
+ rng.manual_seed(seed)
518
+ ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim
519
+ for _ in range(max_tokens):
520
+ logits = self.forward(ids) # (B, T, vocab_size)
521
+ logits = logits[:, -1, :] # (B, vocab_size)
522
+ if top_k is not None:
523
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
524
+ logits[logits < v[:, [-1]]] = -float('Inf')
525
+ if temperature > 0:
526
+ logits = logits / temperature
527
+ probs = F.softmax(logits, dim=-1)
528
+ next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
529
+ else:
530
+ next_ids = torch.argmax(logits, dim=-1, keepdim=True)
531
+ ids = torch.cat((ids, next_ids), dim=1)
532
+ token = next_ids.item()
533
+ yield token
534
+
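A usage sketch for the generator above (assumes a trained `model` and a `tokenizer` with `encode`/`decode`; neither is defined in this file):

```python
prompt_tokens = tokenizer.encode("The capital of France is")
for token in model.generate(prompt_tokens, max_tokens=32, temperature=0.8, top_k=50):
    print(tokenizer.decode([token]), end="", flush=True)
```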
535
+ @torch.inference_mode()
536
+ def generate_speculative(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
537
+ """
538
+ Speculative decoding using self-draft.
539
+
540
+ Algorithm:
541
+ 1. Get hidden state for last token
542
+ 2. Draft head predicts N tokens quickly
543
+ 3. Verify all N+1 positions (original + drafts) in one forward pass
544
+ 4. Accept longest prefix where draft matches verification
545
+ 5. Yield accepted tokens, repeat
546
+
547
+ This cuts the number of full forward passes from ~max_tokens down to roughly 2 * max_tokens / (avg tokens accepted per iteration), since each iteration here costs two full passes.
548
+ """
549
+ assert isinstance(tokens, list)
550
+ assert self.draft_head is not None, "Draft head not available (draft_n=0 in config)"
551
+ device = self.get_device()
552
+ draft_n = self.config.draft_n
553
+ rng = None
554
+ if temperature > 0:
555
+ rng = torch.Generator(device=device)
556
+ rng.manual_seed(seed)
557
+
558
+ ids = torch.tensor([tokens], dtype=torch.long, device=device)
559
+ tokens_generated = 0
560
+
561
+ while tokens_generated < max_tokens:
562
+ # Forward pass to get hidden states (we need the raw hidden state for draft head)
563
+ # Run trunk to get hidden states
564
+ B, T = ids.size()
565
+ T0 = 0 # no kv cache for simplicity in this version
566
+ cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T]
567
+
568
+ x = self.transformer.wte(ids)
569
+ x = norm(x)
570
+ x0 = x
571
+ for i, block in enumerate(self.transformer.h):
572
+ x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
573
+ ve = self.value_embeds[str(i)](ids) if str(i) in self.value_embeds else None
574
+ x = block(x, ve, cos_sin, self.window_sizes[i], None)
575
+ x = norm(x)
576
+
577
+ # Get hidden state for last position
578
+ last_hidden = x[:, -1, :] # (B, n_embd)
579
+
580
+ # Draft N tokens using draft head
581
+ draft_logits = self.draft_head(last_hidden) # (B, draft_n, vocab_size)
582
+ if temperature > 0:
583
+ draft_logits = draft_logits / temperature
584
+ draft_probs = F.softmax(draft_logits, dim=-1)
585
+ draft_tokens = torch.multinomial(draft_probs.view(-1, draft_probs.size(-1)), num_samples=1, generator=rng)
586
+ draft_tokens = draft_tokens.view(B, draft_n) # (B, draft_n)
587
+ else:
588
+ draft_tokens = torch.argmax(draft_logits, dim=-1) # (B, draft_n)
589
+
590
+ # Prepare verification sequence: original + draft tokens
591
+ verify_ids = torch.cat([ids, draft_tokens], dim=1) # (B, T + draft_n)
592
+
593
+ # Verify all draft tokens with full model in one forward pass
594
+ verify_logits = self.forward(verify_ids) # (B, T + draft_n, vocab_size)
595
+
596
+ # Sample from verification logits for positions T-1 to T+draft_n-1
597
+ # Position T-1 verifies the first draft token, etc.
598
+ accepted = []
599
+ for i in range(draft_n):
600
+ pos = T - 1 + i # verification position
601
+ if pos >= verify_logits.size(1):
602
+ break
603
+
604
+ v_logits = verify_logits[:, pos, :]
605
+ if top_k is not None:
606
+ v, _ = torch.topk(v_logits, min(top_k, v_logits.size(-1)))
607
+ v_logits[v_logits < v[:, [-1]]] = -float('Inf')
608
+
609
+ if temperature > 0:
610
+ v_logits = v_logits / temperature
611
+ v_probs = F.softmax(v_logits, dim=-1)
612
+ verified_token = torch.multinomial(v_probs, num_samples=1, generator=rng)
613
+ else:
614
+ verified_token = torch.argmax(v_logits, dim=-1, keepdim=True)
615
+
616
+ # Check if draft matches verification
617
+ if draft_tokens[0, i] == verified_token[0, 0]:
618
+ accepted.append(verified_token[0, 0].item())
619
+ else:
620
+ # Draft wrong, accept verified token and stop
621
+ accepted.append(verified_token[0, 0].item())
622
+ break
623
+
624
+ # Yield accepted tokens
625
+ for tok in accepted:
626
+ if tokens_generated >= max_tokens:
627
+ return
628
+ yield tok
629
+ tokens_generated += 1
630
+
631
+ # Update ids with accepted tokens
632
+ accepted_tensor = torch.tensor([accepted], dtype=torch.long, device=device)
633
+ ids = torch.cat([ids, accepted_tensor], dim=1)
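Rough cost intuition for the loop above: each iteration pays roughly two full forward passes (the trunk pass for the draft's hidden state plus the verification pass) and yields between 1 and draft_n tokens, so the win hinges on the acceptance rate. Illustrative arithmetic with an assumed acceptance number:

```python
draft_n = 4
mean_accepted = 2.5                          # assumed average tokens kept per iteration
passes_per_token_naive = 1.0
passes_per_token_spec = 2.0 / mean_accepted  # two full passes per iteration
print(f"{passes_per_token_spec:.2f} vs {passes_per_token_naive:.2f} forward passes per token")
```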
nanochat/logo.svg ADDED
nanochat/loss_eval.py ADDED
@@ -0,0 +1,65 @@
1
+ """
2
+ A number of functions that help with evaluating a base model.
3
+ """
4
+ import math
5
+ import torch
6
+ import torch.distributed as dist
7
+
8
+ @torch.no_grad()
9
+ def evaluate_bpb(model, batches, steps, token_bytes):
10
+ """
11
+ Instead of the naive 'mean loss', this function returns the bits per byte (bpb),
12
+ which is a tokenization vocab size-independent metric, meaning you are still comparing
13
+ apples:apples if you change the vocab size. The way this works is that instead of just
14
+ calculating the average loss as usual, you calculate the sum loss, and independently
15
+ also the sum bytes (of all the target tokens), and divide. This normalizes the loss by
16
+ the number of bytes that the target tokens represent.
17
+
18
+ The added complexity is so that:
19
+ 1) All "normal" tokens are normalized by the length of the token in bytes
20
+ 2) No special tokens (e.g. <|bos|>) are included in the metric - they are masked out.
21
+ 3) No actively masked tokens (using ignore_index of e.g. -1) are included in the metric.
22
+
23
+ In addition to evaluate_loss, we need the token_bytes tensor:
24
+ It is a 1D tensor of shape (vocab_size,), indicating the number of bytes for
25
+ each token id, or 0 if the token is to not be counted (e.g. special tokens).
26
+ """
27
+ # record the losses
28
+ total_nats = torch.tensor(0.0, dtype=torch.float32, device=model.get_device())
29
+ total_bytes = torch.tensor(0, dtype=torch.int64, device=model.get_device())
30
+ batch_iter = iter(batches)
31
+ for _ in range(steps):
32
+ x, y = next(batch_iter)
33
+ loss2d = model(x, y, loss_reduction='none') # (B, T)
34
+ loss2d = loss2d.view(-1) # flatten
35
+ y = y.view(-1) # flatten
36
+ if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32
37
+ # slightly more complex code path if some target tokens are ignore_index (e.g. -1)
38
+ # any target token < 0 is to be ignored: do NOT index token_bytes with negatives
39
+ valid = y >= 0
40
+ y_safe = torch.where(valid, y, torch.zeros_like(y))
41
+ # map valid targets to their byte length; ignored targets contribute 0 bytes
42
+ num_bytes2d = torch.where(
43
+ valid,
44
+ token_bytes[y_safe],
45
+ torch.zeros_like(y, dtype=token_bytes.dtype)
46
+ )
47
+ total_nats += (loss2d * (num_bytes2d > 0)).sum()
48
+ total_bytes += num_bytes2d.sum()
49
+ else:
50
+ # fast path: no ignored targets, safe to index directly
51
+ num_bytes2d = token_bytes[y]
52
+ total_nats += (loss2d * (num_bytes2d > 0)).sum()
53
+ total_bytes += num_bytes2d.sum()
54
+ # sum reduce across all ranks
55
+ world_size = dist.get_world_size() if dist.is_initialized() else 1
56
+ if world_size > 1:
57
+ dist.all_reduce(total_nats, op=dist.ReduceOp.SUM)
58
+ dist.all_reduce(total_bytes, op=dist.ReduceOp.SUM)
59
+ # move both to cpu, calculate bpb and return
60
+ total_nats = total_nats.item()
61
+ total_bytes = total_bytes.item()
62
+ if total_bytes == 0:
63
+ return float('inf')
64
+ bpb = total_nats / (math.log(2) * total_bytes)
65
+ return bpb
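As a sanity check of the final formula, toy numbers (illustrative only): a mean loss of 1 nat per token over tokens averaging 4 bytes comes out to about 0.36 bits per byte.

```python
import math
total_nats, total_bytes = 1.0 * 1000, 4 * 1000  # 1000 tokens, 1 nat each, 4 bytes each
bpb = total_nats / (math.log(2) * total_bytes)
print(f"{bpb:.3f} bpb")  # 0.361
```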
nanochat/muon.py ADDED
@@ -0,0 +1,352 @@
1
+ """
2
+ Muon optimizer adapted and simplified from modded-nanogpt.
3
+ https://github.com/KellerJordan/modded-nanogpt
4
+
5
+ Background:
6
+ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
7
+ quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
8
+ of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
9
+ zero even beyond the point where the iteration no longer converges all the way to one everywhere
10
+ on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
11
+ where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
12
+ performance at all relative to UV^T, where USV^T = G is the SVD.
13
+
14
+ Here, an alternative to Newton-Schulz iteration with potentially better convergence properties:
15
+ Polar Express Sign Method for orthogonalization.
16
+ https://arxiv.org/pdf/2505.16932
17
+ by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
18
+
19
+ Some of the changes in nanochat implementation:
20
+ - Uses a simpler, more general approach to parameter grouping and stacking
21
+ - Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step
22
+ - Makes no assumptions about model architecture (e.g. that attention weights are fused into QKVO format)
23
+ """
24
+
25
+ import torch
26
+ from torch import Tensor
27
+ import torch.distributed as dist
28
+
29
+ # Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
30
+ # From https://arxiv.org/pdf/2505.16932
31
+ polar_express_coeffs = [
32
+ (8.156554524902461, -22.48329292557795, 15.878769915207462),
33
+ (4.042929935166739, -2.808917465908714, 0.5000178451051316),
34
+ (3.8916678022926607, -2.772484153217685, 0.5060648178503393),
35
+ (3.285753657755655, -2.3681294933425376, 0.46449024233003106),
36
+ (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
37
+ ]
38
+
39
+ @torch.compile(dynamic=False, fullgraph=True)
40
+ def muon_step_fused(
41
+ stacked_grads: Tensor,
42
+ stacked_params: Tensor,
43
+ momentum_buffer: Tensor,
44
+ second_momentum_buffer: Tensor,
45
+ momentum_t: Tensor,
46
+ lr_t: Tensor,
47
+ wd_t: Tensor,
48
+ beta2_t: Tensor,
49
+ ns_steps: int,
50
+ red_dim: int,
51
+ ) -> None:
52
+ """
53
+ Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update
54
+ All in one compiled graph to eliminate Python overhead between ops.
55
+ Some of the constants are 0-D CPU tensors to avoid recompilation when values change.
56
+ """
57
+
58
+ # Nesterov momentum
59
+ momentum = momentum_t.to(stacked_grads.dtype)
60
+ momentum_buffer.lerp_(stacked_grads, 1 - momentum)
61
+ g = stacked_grads.lerp_(momentum_buffer, momentum)
62
+
63
+ # Polar express
64
+ X = g.bfloat16()
65
+ if g.size(-2) > g.size(-1):
66
+ X = X.mT
67
+ X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
68
+ for a, b, c in polar_express_coeffs[:ns_steps]:
69
+ A = X @ X.mT
70
+ B = b * A + c * (A @ A)
71
+ X = a * X + B @ X
72
+ if g.size(-2) > g.size(-1):
73
+ X = X.mT
74
+ g = X
75
+
76
+ # Variance reduction
77
+ beta2 = beta2_t.to(g.dtype)
78
+ v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
79
+ red_dim_size = g.size(red_dim)
80
+ v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
81
+ v_norm = v_norm_sq.sqrt()
82
+ second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
83
+ step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
84
+ scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
85
+ v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
86
+ final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
87
+ g = g * final_scale.to(g.dtype)
88
+
89
+ # Cautious weight decay + parameter update
90
+ lr = lr_t.to(g.dtype)
91
+ wd = wd_t.to(g.dtype)
92
+ mask = (g * stacked_params) >= 0
93
+ stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
94
+
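A standalone way to see what the Polar Express block above does (a sketch that reuses the coefficient table directly): after the iteration, X @ X^T is close to the identity, though by design not exactly equal to it.

```python
import torch

X = torch.randn(64, 256)
X = X / (X.norm() * 1.02 + 1e-6)      # same normalization as the fused kernel
for a, b, c in polar_express_coeffs:
    A = X @ X.mT
    X = a * X + (b * A + c * (A @ A)) @ X
err = (X @ X.mT - torch.eye(64)).abs().max()
print(f"max |X X^T - I| = {err:.3f}")  # small, but not exactly 0 by design
```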
95
+ class Muon(torch.optim.Optimizer):
96
+ """
97
+ Muon - MomentUm Orthogonalized by Newton-schulz
98
+
99
+ https://kellerjordan.github.io/posts/muon/
100
+
101
+ Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
102
+ processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
103
+ matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
104
+ the advantage that it can be stably run in bfloat16 on the GPU.
105
+
106
+ Some warnings:
107
+ - This optimizer should not be used for the embedding layer, the final fully connected layer,
108
+ or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
109
+ - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
110
+
111
+ Arguments:
112
+ lr: The learning rate used by the internal SGD.
113
+ momentum: The momentum used by the internal SGD.
114
+ ns_steps: The number of Newton-Schulz iteration steps to use.
115
+ beta2: The decay rate for the second moment (variance) estimate. Set to None to disable.
116
+ weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
117
+ """
118
+ def __init__(self, params, lr=0.02, momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=0.0):
119
+ defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
120
+ params = list(params) # listify first: an exhaustible iterator must not be consumed by the assert below
121
+ assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
122
+ # Group by shape so we can stack tensors
123
+ shapes = sorted({p.shape for p in params})
124
+ param_groups = []
125
+ for shape in shapes:
126
+ group_params = [p for p in params if p.shape == shape]
127
+ param_groups.append(dict(params=group_params))
128
+ super().__init__(param_groups, defaults)
129
+ # 0-D CPU tensors to avoid torch.compile recompilation when values change
130
+ self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
131
+ self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
132
+ self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
133
+ self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
134
+
135
+ @torch.no_grad()
136
+ def step(self):
137
+ for group in self.param_groups:
138
+ params: list[Tensor] = group["params"]
139
+ if not params:
140
+ continue
141
+
142
+ # Get or create group-level buffers (stored in first param's state for convenience)
143
+ state = self.state[params[0]]
144
+ num_params = len(params) # e.g.: 12 (for a d12 model)
145
+ # e.g.: shape = (768, 3072), device = cuda:0, dtype = torch.float32, for one of the MLP projections
146
+ shape, device, dtype = params[0].shape, params[0].device, params[0].dtype
147
+
148
+ # Momentum for every individual parameter
149
+ if "momentum_buffer" not in state:
150
+ state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
151
+ momentum_buffer = state["momentum_buffer"] # e.g.: (12, 768, 3072)
152
+
153
+ # Second momentum buffer is factored, either per-row or per-column
154
+ if "second_momentum_buffer" not in state:
155
+ if shape[-2] >= shape[-1]:
156
+ state["second_momentum_buffer"] = torch.zeros(num_params, shape[-2], 1, dtype=dtype, device=device)
157
+ else:
158
+ state["second_momentum_buffer"] = torch.zeros(num_params, 1, shape[-1], dtype=dtype, device=device)
159
+ second_momentum_buffer = state["second_momentum_buffer"] # (12, 1, 3072)
160
+ red_dim = -1 if shape[-2] >= shape[-1] else -2 # e.g.: -2
161
+
162
+ # Stack grads and params
163
+ stacked_grads = torch.stack([p.grad for p in params]) # (12, 768, 3072)
164
+ stacked_params = torch.stack(params) # (12, 768, 3072)
165
+
166
+ # Fill all the 0-D tensors with current values
167
+ self._momentum_t.fill_(group["momentum"])
168
+ self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
169
+ self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
170
+ self._wd_t.fill_(group["weight_decay"])
171
+
172
+ # Single fused kernel: momentum -> polar_express -> variance_reduction -> update
173
+ muon_step_fused(
174
+ stacked_grads,
175
+ stacked_params,
176
+ momentum_buffer,
177
+ second_momentum_buffer,
178
+ self._momentum_t,
179
+ self._lr_t,
180
+ self._wd_t,
181
+ self._beta2_t,
182
+ group["ns_steps"],
183
+ red_dim,
184
+ )
185
+
186
+ # Copy back to original params: [(768, 3072), (768, 3072), ...] <- (12, 768, 3072)
187
+ torch._foreach_copy_(params, list(stacked_params.unbind(0)))
188
+
189
+
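A hypothetical single-process usage sketch (the model and batch are placeholders; in nanochat the actual grouping happens in `setup_optimizers` in gpt.py):

```python
# Muon gets only the transformer block matrices; embeddings, lm_head and
# per-layer scalars should go to AdamW instead.
matrix_params = list(model.transformer.h.parameters())
opt = Muon(matrix_params, lr=0.02, momentum=0.95, weight_decay=0.0)
loss = model(x, y)  # placeholder forward pass
loss.backward()
opt.step()
opt.zero_grad(set_to_none=True)
```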
190
+ class DistMuon(torch.optim.Optimizer):
191
+ """
192
+ Distributed version of the Muon optimizer.
193
+ """
194
+ def __init__(self, params, lr: float = 0.02, momentum: float = 0.95,
195
+ ns_steps: int = 5, beta2: float = 0.95, weight_decay: float = 0.0):
196
+ defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
197
+ params = list(params) # listify first so the assert below doesn't exhaust an iterator
198
+ assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
199
+ world_size = dist.get_world_size()
200
+ rank = dist.get_rank()
201
+ # Group all parameters by their shape
202
+ shapes = sorted({p.shape for p in params}) # sort for deterministic ordering across ranks
203
+ param_groups = []
204
+ for shape in shapes:
205
+ group_params = [p for p in params if p.shape == shape]
206
+ device, dtype = group_params[0].device, group_params[0].dtype
207
+ assert all(p.device == device for p in group_params)
208
+ assert all(p.dtype == dtype for p in group_params)
209
+ # Compute chunk size for this group (how many params each rank owns)
210
+ chunk_size = (len(group_params) + world_size - 1) // world_size
211
+ if rank == 0:
212
+ print(f"Muon: {len(group_params)} params of shape {shape}, chunk_size={chunk_size}")
213
+ param_groups.append(dict(params=group_params, chunk_size=chunk_size))
214
+ super().__init__(param_groups, defaults)
215
+ # 0-D CPU tensors to avoid torch.compile recompilation when values change
216
+ self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
217
+ self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
218
+ self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
219
+ self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
220
+
221
+ @torch.no_grad()
222
+ def step(self):
223
+ rank = dist.get_rank()
224
+ world_size = dist.get_world_size()
225
+
226
+ # Ensure all grads exist
227
+ assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"
228
+
229
+ # First pass: stack grads and kick off reduce_scatter for each group
230
+ group_infos = []
231
+ for group in self.param_groups:
232
+ params: list[Tensor] = group["params"]
233
+ chunk_size = group["chunk_size"]
234
+ padded_num_params = chunk_size * world_size
235
+ shape = params[0].shape
236
+ device, dtype = params[0].device, params[0].dtype
237
+
238
+ # Stack all gradients into a single tensor (single kernel via torch.stack)
239
+ grad_stack = torch.stack([p.grad for p in params])
240
+ stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device)
241
+ stacked_grads[:len(params)].copy_(grad_stack)
242
+ # Zero-pad if we have fewer params than padded size
243
+ if len(params) < padded_num_params:
244
+ stacked_grads[len(params):].zero_()
245
+
246
+ # Output buffer for this rank's chunk
247
+ grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
248
+
249
+ # Async reduce_scatter on the stacked tensor
250
+ reduce_future = dist.reduce_scatter_tensor(
251
+ grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
252
+ ).get_future()
253
+
254
+ group_infos.append(dict(
255
+ grad_chunk=grad_chunk,
256
+ reduce_future=reduce_future,
257
+ stacked_grads=stacked_grads, # reuse for all_gather output
258
+ ))
259
+
260
+ # Second pass: wait for reduce, compute batched updates, kick off all_gather
261
+ all_gather_futures = []
262
+ for group, info in zip(self.param_groups, group_infos):
263
+ info["reduce_future"].wait()
264
+
265
+ params = group["params"]
266
+ chunk_size = group["chunk_size"]
267
+ shape = params[0].shape
268
+ device, dtype = params[0].device, params[0].dtype
269
+ grad_chunk = info["grad_chunk"]
270
+
271
+ # How many params does this rank actually own?
272
+ start_idx = rank * chunk_size
273
+ num_owned = min(chunk_size, max(0, len(params) - start_idx))
274
+
275
+ # Get or create group-level state (stored keyed by first param)
276
+ state = self.state[params[0]]
277
+
278
+ # Momentum buffer
279
+ if "momentum_buffer" not in state:
280
+ state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
281
+ momentum_buffer = state["momentum_buffer"]
282
+
283
+ # Second momentum buffer is factored, either per-row or per-column
284
+ if "second_momentum_buffer" not in state:
285
+ if shape[-2] >= shape[-1]:
286
+ state["second_momentum_buffer"] = torch.zeros(chunk_size, shape[-2], 1, dtype=dtype, device=device)
287
+ else:
288
+ state["second_momentum_buffer"] = torch.zeros(chunk_size, 1, shape[-1], dtype=dtype, device=device)
289
+ second_momentum_buffer = state["second_momentum_buffer"]
290
+ red_dim = -1 if shape[-2] >= shape[-1] else -2
291
+
292
+ # Build updated_params tensor for all_gather
293
+ updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
294
+
295
+ if num_owned > 0:
296
+ # Stack owned params (single kernel via torch.stack)
297
+ owned_params = [params[start_idx + i] for i in range(num_owned)]
298
+ stacked_owned_params = torch.stack(owned_params)
299
+
300
+ # Get owned slices of buffers and grads
301
+ owned_grads = grad_chunk[:num_owned]
302
+ owned_momentum = momentum_buffer[:num_owned]
303
+ owned_second_momentum = second_momentum_buffer[:num_owned]
304
+
305
+ # Fill 0-D tensors with current values
306
+ self._momentum_t.fill_(group["momentum"])
307
+ self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
308
+ self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
309
+ self._wd_t.fill_(group["weight_decay"])
310
+
311
+ # Single fused kernel: momentum -> polar_express -> variance_reduction -> update
312
+ muon_step_fused(
313
+ owned_grads,
314
+ stacked_owned_params,
315
+ owned_momentum,
316
+ owned_second_momentum,
317
+ self._momentum_t,
318
+ self._lr_t,
319
+ self._wd_t,
320
+ self._beta2_t,
321
+ group["ns_steps"],
322
+ red_dim,
323
+ )
324
+
325
+ # Copy updated params to output buffer
326
+ updated_params[:num_owned].copy_(stacked_owned_params)
327
+
328
+ # Zero-pad the rest (for ranks that own fewer params)
329
+ if num_owned < chunk_size:
330
+ updated_params[num_owned:].zero_()
331
+
332
+ # Reuse stacked_grads buffer for all_gather output
333
+ stacked_params = info["stacked_grads"]
334
+
335
+ # Async all_gather to replicate updated params to all ranks
336
+ gather_future = dist.all_gather_into_tensor(
337
+ stacked_params, updated_params, async_op=True
338
+ ).get_future()
339
+
340
+ all_gather_futures.append(dict(
341
+ gather_future=gather_future,
342
+ stacked_params=stacked_params,
343
+ params=params,
344
+ ))
345
+
346
+ # Final pass: wait for all_gather and copy back to params
347
+ for info in all_gather_futures:
348
+ info["gather_future"].wait()
349
+ stacked_params = info["stacked_params"]
350
+ params = info["params"]
351
+ # Batched copy back (single kernel instead of N individual copies)
352
+ torch._foreach_copy_(params, list(stacked_params[:len(params)].unbind(0)))
nanochat/prune.py ADDED
@@ -0,0 +1,216 @@
1
+ import torch
2
+ from nanochat.gpt import GPT, GPTConfig
3
+
4
+
5
+ def head_imp(model): # L1-norm importance score for each attention head, per layer
6
+ head_importance = {}
7
+ head_dim = model.config.n_embd // model.config.n_head
8
+
9
+ for layer_idx, block in enumerate(model.transformer.h):
10
+ attn = block.attn
11
+ n_head = attn.n_head
12
+ n_kv_head = attn.n_kv_head
13
+
14
+ head_scores = torch.zeros(n_head, device=next(model.parameters()).device)
15
+
16
+ q_weight = attn.c_q.weight.view(n_head, head_dim, model.config.n_embd)
17
+ head_scores += q_weight.abs().sum(dim=(1, 2))
18
+
19
+ proj_weight = attn.c_proj.weight.view(model.config.n_embd, n_head, head_dim)
20
+ head_scores += proj_weight.abs().sum(dim=(0, 2))
21
+
22
+ if n_kv_head == n_head:
23
+ k_weight = attn.c_k.weight.view(n_head, head_dim, model.config.n_embd)
24
+ v_weight = attn.c_v.weight.view(n_head, head_dim, model.config.n_embd)
25
+ head_scores += k_weight.abs().sum(dim=(1, 2))
26
+ head_scores += v_weight.abs().sum(dim=(1, 2))
27
+
28
+ head_importance[layer_idx] = head_scores
29
+
30
+ return head_importance
31
+
32
+
33
+ def neuron_imp(model): # L1-norm importance score for each MLP hidden neuron, per layer
34
+ neuron_importance = {}
36
+
37
+ for layer_idx, block in enumerate(model.transformer.h):
38
+ mlp = block.mlp
39
+
40
+ fc_importance = mlp.c_fc.weight.abs().sum(dim=1)
41
+ proj_importance = mlp.c_proj.weight.abs().sum(dim=0)
42
+
43
+ neuron_scores = fc_importance + proj_importance
44
+ neuron_importance[layer_idx] = neuron_scores
45
+
46
+ return neuron_importance
47
+
48
+
49
+ def select_heads(head_importance, prune_ratio): # keep the top (1 - prune_ratio) fraction of heads in each layer
50
+ heads_to_keep = {}
51
+
52
+ for layer_idx, scores in head_importance.items():
53
+ n_head = len(scores)
54
+ n_to_keep = int(n_head * (1 - prune_ratio))
55
+ n_to_keep = max(1, n_to_keep)
56
+
57
+ _, top_indices = torch.topk(scores, n_to_keep, largest=True)
58
+ heads_to_keep[layer_idx] = top_indices.sort().values.tolist()
59
+
60
+ return heads_to_keep
61
+
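A toy run of the selection rule (illustrative values):

```python
import torch
scores = torch.tensor([3.0, 1.0, 4.0, 1.5])   # importance of 4 heads
_, top = torch.topk(scores, 2, largest=True)  # keep 50%
print(top.sort().values.tolist())             # [0, 2]: kept head indices, in order
```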
62
+
63
+ def select_neurons(neuron_importance, prune_ratio): # keep the top (1 - prune_ratio) fraction of neurons in each layer
64
+ neurons_to_keep = {}
65
+
66
+ for layer_idx, scores in neuron_importance.items():
67
+ n_neurons = len(scores)
68
+ n_to_keep = int(n_neurons * (1 - prune_ratio))
69
+ n_to_keep = max(1, n_to_keep)
70
+
71
+ _, top_indices = torch.topk(scores, n_to_keep, largest=True)
72
+ neurons_to_keep[layer_idx] = top_indices.sort().values.tolist()
73
+
74
+ return neurons_to_keep
75
+
76
+
77
+ def make_pruned_config(original_config, heads_to_keep, neurons_to_keep): # shrink n_head/n_embd to the minimum kept across layers
78
+ min_heads = min(len(heads) for heads in heads_to_keep.values())
79
+ min_neurons = min(len(neurons) for neurons in neurons_to_keep.values())
80
+
81
+ head_dim = original_config.n_embd // original_config.n_head
82
+ new_n_embd = min_heads * head_dim
83
+
84
+ config = GPTConfig(
85
+ sequence_len=original_config.sequence_len,
86
+ vocab_size=original_config.vocab_size,
87
+ n_layer=original_config.n_layer,
88
+ n_head=min_heads,
89
+ n_kv_head=1 if original_config.use_mqa else min_heads,
90
+ n_embd=new_n_embd,
91
+ window_pattern=original_config.window_pattern,
92
+ use_mqa=original_config.use_mqa,
93
+ multi_token_n=original_config.multi_token_n,
94
+ draft_n=original_config.draft_n,
95
+ draft_hidden_mult=original_config.draft_hidden_mult,
96
+ )
97
+
98
+ return config, min_heads, min_neurons
99
+
100
+
101
+ def prune_weights(model, heads_to_keep, neurons_to_keep, pruned_config, min_heads, min_neurons): # copy the kept weight slices into a freshly built smaller model
102
+ head_dim = model.config.n_embd // model.config.n_head
103
+ original_n_embd = model.config.n_embd
104
+ pruned_n_embd = pruned_config.n_embd
105
+
106
+ with torch.device("meta"):
107
+ pruned_model = GPT(pruned_config)
108
+
109
+ device = next(model.parameters()).device
110
+ pruned_model.to_empty(device=device)
111
+
112
+ pruned_model.transformer.wte.weight.data.copy_(model.transformer.wte.weight.data[:, :pruned_n_embd])
113
+
114
+ pruned_model.resid_lambdas.data.copy_(model.resid_lambdas.data)
115
+ pruned_model.x0_lambdas.data.copy_(model.x0_lambdas.data)
116
+
117
+ for key in model.value_embeds.keys():
118
+ if key in pruned_model.value_embeds:
119
+ orig_ve = model.value_embeds[key].weight.data
120
+ pruned_ve = pruned_model.value_embeds[key].weight.data
121
+ if orig_ve.size(1) > pruned_ve.size(1):
122
+ pruned_model.value_embeds[key].weight.data.copy_(orig_ve[:, :pruned_ve.size(1)])
123
+ else:
124
+ pruned_model.value_embeds[key].weight.data.copy_(orig_ve)
125
+
126
+ for key in model.multi_token_heads.keys():
127
+ if key in pruned_model.multi_token_heads:
128
+ orig_weight = model.multi_token_heads[key].weight.data
129
+ pruned_model.multi_token_heads[key].weight.data.copy_(orig_weight[:, :pruned_n_embd])
130
+
131
+ if model.draft_head is not None and pruned_model.draft_head is not None:
132
+ pruned_model.draft_head.fc1.weight.data.copy_(model.draft_head.fc1.weight.data[:, :pruned_n_embd])
133
+ pruned_model.draft_head.fc2.weight.data.copy_(model.draft_head.fc2.weight.data)
134
+
135
+ pruned_model.lm_head.weight.data.copy_(model.lm_head.weight.data[:, :pruned_n_embd])
136
+
137
+ for layer_idx in range(model.config.n_layer):
138
+ orig_block = model.transformer.h[layer_idx]
139
+ pruned_block = pruned_model.transformer.h[layer_idx]
140
+
141
+ layer_heads = heads_to_keep[layer_idx]
142
+ layer_neurons = neurons_to_keep[layer_idx]
143
+
144
+ attn_orig = orig_block.attn
145
+ attn_pruned = pruned_block.attn
146
+
147
+ q_orig = attn_orig.c_q.weight.view(model.config.n_head, head_dim, original_n_embd)
148
+ q_pruned = q_orig[layer_heads[:min_heads]].contiguous().view(min_heads * head_dim, pruned_n_embd)
149
+ attn_pruned.c_q.weight.data.copy_(q_pruned)
150
+
151
+ if attn_orig.n_kv_head == model.config.n_head:
152
+ k_orig = attn_orig.c_k.weight.view(model.config.n_head, head_dim, original_n_embd)
153
+ k_pruned = k_orig[layer_heads[:min_heads]].contiguous().view(min_heads * head_dim, pruned_n_embd)
154
+ attn_pruned.c_k.weight.data.copy_(k_pruned)
155
+ else:
156
+ k_orig = attn_orig.c_k.weight
157
+ k_pruned = k_orig[:, :pruned_n_embd]
158
+ attn_pruned.c_k.weight.data.copy_(k_pruned)
159
+
160
+ if attn_orig.n_kv_head == model.config.n_head:
161
+ v_orig = attn_orig.c_v.weight.view(model.config.n_head, head_dim, original_n_embd)
162
+ v_pruned = v_orig[layer_heads[:min_heads]].contiguous().view(min_heads * head_dim, pruned_n_embd)
163
+ attn_pruned.c_v.weight.data.copy_(v_pruned)
164
+ else:
165
+ v_orig = attn_orig.c_v.weight
166
+ v_pruned = v_orig[:, :pruned_n_embd]
167
+ attn_pruned.c_v.weight.data.copy_(v_pruned)
168
+
169
+ proj_orig = attn_orig.c_proj.weight.view(original_n_embd, model.config.n_head, head_dim)
170
+ proj_pruned = proj_orig[:, layer_heads[:min_heads], :].contiguous().view(original_n_embd, min_heads * head_dim)
171
+ proj_pruned = proj_pruned[:pruned_n_embd, :]
172
+ attn_pruned.c_proj.weight.data.copy_(proj_pruned)
173
+
174
+ if attn_orig.ve_gate is not None and attn_pruned.ve_gate is not None:
175
+ if attn_orig.n_kv_head == model.config.n_head:
176
+ gate_orig = attn_orig.ve_gate.weight.view(model.config.n_head, -1)
177
+ gate_pruned = gate_orig[layer_heads[:min_heads]]
178
+ attn_pruned.ve_gate.weight.data.copy_(gate_pruned.view(min_heads, -1))
179
+ else:
180
+ attn_pruned.ve_gate.weight.data.copy_(attn_orig.ve_gate.weight.data)
181
+
182
+ mlp_orig = orig_block.mlp
183
+ mlp_pruned = pruned_block.mlp
184
+
185
+ fc_orig = mlp_orig.c_fc.weight
186
+ fc_pruned = fc_orig[layer_neurons[:min_neurons]]
187
+ fc_pruned = fc_pruned[:, :pruned_n_embd]
188
+ mlp_pruned.c_fc.weight.data.copy_(fc_pruned)
189
+
190
+ proj_orig = mlp_orig.c_proj.weight
191
+ proj_pruned = proj_orig[:, layer_neurons[:min_neurons]]
192
+ proj_pruned = proj_pruned[:pruned_n_embd, :]
193
+ mlp_pruned.c_proj.weight.data.copy_(proj_pruned)
194
+
195
+ pruned_model.cos.copy_(model.cos)
196
+ pruned_model.sin.copy_(model.sin)
197
+
198
+ return pruned_model
199
+
200
+
201
+ def prune_model(model, head_prune_ratio=0.2, neuron_prune_ratio=0.2): # end-to-end: score -> select -> rebuild
202
+ head_importance = head_imp(model)
203
+ neuron_importance = neuron_imp(model)
204
+
205
+ heads_to_keep = select_heads(head_importance, head_prune_ratio)
206
+ neurons_to_keep = select_neurons(neuron_importance, neuron_prune_ratio)
207
+
208
+ config, min_heads, min_neurons = make_pruned_config(
209
+ model.config, heads_to_keep, neurons_to_keep
210
+ )
211
+
212
+ pruned_model = prune_weights(
213
+ model, heads_to_keep, neurons_to_keep, config, min_heads, min_neurons
214
+ )
215
+
216
+ return pruned_model, config
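Hypothetical end-to-end usage (`model` is assumed to be a loaded GPT):

```python
pruned_model, pruned_config = prune_model(model, head_prune_ratio=0.25, neuron_prune_ratio=0.25)
n_before = sum(p.numel() for p in model.parameters())
n_after = sum(p.numel() for p in pruned_model.parameters())
print(f"params: {n_before:,} -> {n_after:,}")
```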
nanochat/quantize.py ADDED
@@ -0,0 +1,113 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from nanochat.gpt import GPT, GPTConfig
4
+
5
+
6
+ def quantize_tensor(weight, bits=8): # symmetric per-tensor absmax quantization
7
+ qmin = -(2 ** (bits - 1))
8
+ qmax = 2 ** (bits - 1) - 1
9
+
10
+ scale = weight.abs().max() / qmax
11
+ scale = scale.clamp(min=1e-8)
12
+
13
+ quantized = (weight / scale).round().clamp(qmin, qmax)
14
+
15
+ return quantized.to(torch.int8), scale.item()
16
+
17
+
18
+ def dequantize_tensor(quantized, scale):
19
+ return quantized.float() * scale
20
+
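A quick round-trip check of the two functions above (illustrative): with absmax scaling, the per-element reconstruction error is bounded by half a quantization step.

```python
import torch
w = torch.randn(4, 4)
q, s = quantize_tensor(w, bits=8)
w_hat = dequantize_tensor(q, s)
print((w - w_hat).abs().max().item() <= s / 2 + 1e-6)  # True
```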
21
+
22
+ def quantize_linear(linear_layer, bits=8):
23
+ weight = linear_layer.weight.data
24
+ quantized, scale = quantize_tensor(weight, bits)
25
+ return quantized, scale
26
+
27
+
28
+ def quantize_model(model, bits=8): # quantize all >=2D trainable weights; pass others through unchanged
29
+ quantized_state = {}
30
+ scales = {}
31
+
32
+ for name, param in model.named_parameters():
33
+ if param.requires_grad and len(param.shape) >= 2:
34
+ quantized, scale = quantize_tensor(param.data, bits)
35
+ quantized_state[name] = quantized
36
+ scales[name] = scale
37
+ else:
38
+ quantized_state[name] = param.data
39
+
40
+ for name, buffer in model.named_buffers():
41
+ quantized_state[name] = buffer.data
42
+
43
+ return quantized_state, scales
44
+
45
+
46
+ def apply_quantization(model, scales, bits=8): # dequantize in place (expects params to already hold quantized values)
47
+ for name, param in model.named_parameters():
48
+ if name in scales and len(param.shape) >= 2:
49
+ scale = scales[name]
50
+ quantized = param.data
51
+ param.data = dequantize_tensor(quantized, scale)
52
+
53
+
54
+ def export_int8(model, output_path, bits=8): # save int8 weights, scales and config to disk
55
+ quantized_state, scales = quantize_model(model, bits)
56
+
57
+ export_data = {
58
+ 'quantized_weights': quantized_state,
59
+ 'scales': scales,
60
+ 'config': {
61
+ 'n_layer': model.config.n_layer,
62
+ 'n_head': model.config.n_head,
63
+ 'n_kv_head': model.config.n_kv_head,
64
+ 'n_embd': model.config.n_embd,
65
+ 'vocab_size': model.config.vocab_size,
66
+ 'sequence_len': model.config.sequence_len,
67
+ 'window_pattern': model.config.window_pattern,
68
+ 'use_mqa': model.config.use_mqa,
69
+ 'multi_token_n': model.config.multi_token_n,
70
+ 'draft_n': model.config.draft_n,
71
+ 'draft_hidden_mult': model.config.draft_hidden_mult,
72
+ },
73
+ 'bits': bits,
74
+ }
75
+
76
+ torch.save(export_data, output_path)
77
+ return export_data
78
+
79
+
80
+ def load_int8(model_path, device): # rebuild the model and dequantize its weights back to float
81
+ data = torch.load(model_path, map_location=device)
82
+
83
+ config_kwargs = data['config']
84
+ config = GPTConfig(**config_kwargs)
85
+
86
+ with torch.device("meta"):
87
+ model = GPT(config)
88
+
89
+ model.to_empty(device=device)
90
+ model.init_weights()
91
+
92
+ quantized_state = data['quantized_weights']
93
+ scales = data['scales']
94
+
95
+ state_dict = {}
96
+ for name, param in model.named_parameters():
97
+ if name in quantized_state:
98
+ if name in scales:
99
+ quantized = quantized_state[name]
100
+ scale = scales[name]
101
+ state_dict[name] = dequantize_tensor(quantized, scale)
102
+ else:
103
+ state_dict[name] = quantized_state[name]
104
+
105
+ for name, buffer in model.named_buffers():
106
+ if name in quantized_state:
107
+ state_dict[name] = quantized_state[name]
108
+
109
+ model.load_state_dict(state_dict, strict=False)
110
+ model.eval()
111
+
112
+ return model, data['config'], scales
113
+
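Hypothetical export/load round trip (the path is a placeholder):

```python
export_int8(model, "model_int8.pt")
model2, config_dict, scales = load_int8("model_int8.pt", device="cpu")
```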
nanochat/report.py ADDED
@@ -0,0 +1,422 @@
1
+ """
2
+ Utilities for generating training report cards. More messy code than usual, will fix.
3
+ """
4
+
5
+ import os
6
+ import re
7
+ import shutil
8
+ import subprocess
9
+ import socket
10
+ import datetime
11
+ import platform
12
+ import psutil
13
+ import torch
14
+
15
+ def run_command(cmd):
16
+ """Run a shell command and return output, or None if it fails."""
17
+ try:
18
+ result = subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=5)
19
+ # Return stdout if we got output (even if some files in xargs failed)
20
+ if result.stdout.strip():
21
+ return result.stdout.strip()
22
+ if result.returncode == 0:
23
+ return ""
24
+ return None
25
+ except Exception:
26
+ return None
27
+
28
+ def get_git_info():
29
+ """Get current git commit, branch, and dirty status."""
30
+ info = {}
31
+ info['commit'] = run_command("git rev-parse --short HEAD") or "unknown"
32
+ info['branch'] = run_command("git rev-parse --abbrev-ref HEAD") or "unknown"
33
+
34
+ # Check if repo is dirty (has uncommitted changes)
35
+ status = run_command("git status --porcelain")
36
+ info['dirty'] = bool(status) if status is not None else False
37
+
38
+ # Get commit message
39
+ info['message'] = run_command("git log -1 --pretty=%B") or ""
40
+ info['message'] = info['message'].split('\n')[0][:80] # First line, truncated
41
+
42
+ return info
43
+
44
+ def get_gpu_info():
45
+ """Get GPU information."""
46
+ if not torch.cuda.is_available():
47
+ return {"available": False}
48
+
49
+ num_devices = torch.cuda.device_count()
50
+ info = {
51
+ "available": True,
52
+ "count": num_devices,
53
+ "names": [],
54
+ "memory_gb": []
55
+ }
56
+
57
+ for i in range(num_devices):
58
+ props = torch.cuda.get_device_properties(i)
59
+ info["names"].append(props.name)
60
+ info["memory_gb"].append(props.total_memory / (1024**3))
61
+
62
+ # Get CUDA version
63
+ info["cuda_version"] = torch.version.cuda or "unknown"
64
+
65
+ return info
66
+
67
+ def get_system_info():
68
+ """Get system information."""
69
+ info = {}
70
+
71
+ # Basic system info
72
+ info['hostname'] = socket.gethostname()
73
+ info['platform'] = platform.system()
74
+ info['python_version'] = platform.python_version()
75
+ info['torch_version'] = torch.__version__
76
+
77
+ # CPU and memory
78
+ info['cpu_count'] = psutil.cpu_count(logical=False)
79
+ info['cpu_count_logical'] = psutil.cpu_count(logical=True)
80
+ info['memory_gb'] = psutil.virtual_memory().total / (1024**3)
81
+
82
+ # User and environment
83
+ info['user'] = os.environ.get('USER', 'unknown')
84
+ info['nanochat_base_dir'] = os.environ.get('NANOCHAT_BASE_DIR', 'out')
85
+ info['working_dir'] = os.getcwd()
86
+
87
+ return info
88
+
89
+ def estimate_cost(gpu_info, runtime_hours=None):
90
+ """Estimate training cost based on GPU type and runtime."""
91
+
92
+ # Rough pricing, from Lambda Cloud
93
+ default_rate = 2.0
94
+ gpu_hourly_rates = {
95
+ "H100": 3.00,
96
+ "A100": 1.79,
97
+ "V100": 0.55,
98
+ }
99
+
100
+ if not gpu_info.get("available"):
101
+ return None
102
+
103
+ # Try to identify GPU type from name
104
+ hourly_rate = None
105
+ gpu_name = gpu_info["names"][0] if gpu_info["names"] else "unknown"
106
+ for gpu_type, rate in gpu_hourly_rates.items():
107
+ if gpu_type in gpu_name:
108
+ hourly_rate = rate * gpu_info["count"]
109
+ break
110
+
111
+ if hourly_rate is None:
112
+ hourly_rate = default_rate * gpu_info["count"] # Default estimate
113
+
114
+ return {
115
+ "hourly_rate": hourly_rate,
116
+ "gpu_type": gpu_name,
117
+ "estimated_total": hourly_rate * runtime_hours if runtime_hours else None
118
+ }
119
+
120
+ def generate_header():
121
+ """Generate the header for a training report."""
122
+ timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
123
+
124
+ git_info = get_git_info()
125
+ gpu_info = get_gpu_info()
126
+ sys_info = get_system_info()
127
+ cost_info = estimate_cost(gpu_info)
128
+
129
+ header = f"""# nanochat training report
130
+
131
+ Generated: {timestamp}
132
+
133
+ ## Environment
134
+
135
+ ### Git Information
136
+ - Branch: {git_info['branch']}
137
+ - Commit: {git_info['commit']} {"(dirty)" if git_info['dirty'] else "(clean)"}
138
+ - Message: {git_info['message']}
139
+
140
+ ### Hardware
141
+ - Platform: {sys_info['platform']}
142
+ - CPUs: {sys_info['cpu_count']} cores ({sys_info['cpu_count_logical']} logical)
143
+ - Memory: {sys_info['memory_gb']:.1f} GB
144
+ """
145
+
146
+ if gpu_info.get("available"):
147
+ gpu_names = ", ".join(set(gpu_info["names"]))
148
+ total_vram = sum(gpu_info["memory_gb"])
149
+ header += f"""- GPUs: {gpu_info['count']}x {gpu_names}
150
+ - GPU Memory: {total_vram:.1f} GB total
151
+ - CUDA Version: {gpu_info['cuda_version']}
152
+ """
153
+ else:
154
+ header += "- GPUs: None available\n"
155
+
156
+ if cost_info and cost_info["hourly_rate"] > 0:
157
+ header += f"""- Hourly Rate: ${cost_info['hourly_rate']:.2f}/hour\n"""
158
+
159
+ header += f"""
160
+ ### Software
161
+ - Python: {sys_info['python_version']}
162
+ - PyTorch: {sys_info['torch_version']}
163
+
164
+ """
165
+
166
+ # bloat metrics: count lines/chars in git-tracked source files only
167
+ extensions = ['py', 'md', 'rs', 'html', 'toml', 'sh']
168
+ git_patterns = ' '.join(f"'*.{ext}'" for ext in extensions)
169
+ files_output = run_command(f"git ls-files -- {git_patterns}")
170
+ file_list = [f for f in (files_output or '').split('\n') if f]
171
+ num_files = len(file_list)
172
+ num_lines = 0
173
+ num_chars = 0
174
+ if num_files > 0:
175
+ wc_output = run_command(f"git ls-files -- {git_patterns} | xargs wc -lc 2>/dev/null")
176
+ if wc_output:
177
+ total_line = wc_output.strip().split('\n')[-1]
178
+ parts = total_line.split()
179
+ if len(parts) >= 2:
180
+ num_lines = int(parts[0])
181
+ num_chars = int(parts[1])
182
+ num_tokens = num_chars // 4 # assume approximately 4 chars per token
183
+
184
+ # count dependencies via uv.lock
185
+ uv_lock_lines = 0
186
+ if os.path.exists('uv.lock'):
187
+ with open('uv.lock', 'r', encoding='utf-8') as f:
188
+ uv_lock_lines = len(f.readlines())
189
+
190
+ header += f"""
191
+ ### Bloat
192
+ - Characters: {num_chars:,}
193
+ - Lines: {num_lines:,}
194
+ - Files: {num_files:,}
195
+ - Tokens (approx): {num_tokens:,}
196
+ - Dependencies (uv.lock lines): {uv_lock_lines:,}
197
+
198
+ """
199
+ return header
200
+
201
+ # -----------------------------------------------------------------------------
202
+
203
+ def slugify(text):
204
+ """Slugify a text string."""
205
+ return text.lower().replace(" ", "-")
206
+
207
+ # the expected files and their order
208
+ EXPECTED_FILES = [
209
+ "tokenizer-training.md",
210
+ "tokenizer-evaluation.md",
211
+ "base-model-training.md",
212
+ "base-model-loss.md",
213
+ "base-model-evaluation.md",
214
+ "midtraining.md",
215
+ "chat-evaluation-mid.md",
216
+ "chat-sft.md",
217
+ "chat-evaluation-sft.md",
218
+ "chat-rl.md",
219
+ "chat-evaluation-rl.md",
220
+ ]
221
+ # the metrics we're currently interested in
222
+ chat_metrics = ["ARC-Easy", "ARC-Challenge", "MMLU", "GSM8K", "HumanEval", "ChatCORE"]
223
+
224
+ def extract(section, keys):
225
+ """simple def to extract a single key from a section"""
226
+ if not isinstance(keys, list):
227
+ keys = [keys] # convenience
228
+ out = {}
229
+ for line in section.split("\n"):
230
+ for key in keys:
231
+ if key in line:
232
+ out[key] = line.split(":")[1].strip()
233
+ return out
234
+
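A toy run of `extract` (illustrative section text):

```python
section = "## chat-evaluation\n- MMLU: 0.3501\n- GSM8K: 0.0425\n"
print(extract(section, ["MMLU", "GSM8K"]))  # {'MMLU': '0.3501', 'GSM8K': '0.0425'}
```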
235
+ def extract_timestamp(content, prefix):
236
+ """Extract timestamp from content with given prefix."""
237
+ for line in content.split('\n'):
238
+ if line.startswith(prefix):
239
+ time_str = line.split(":", 1)[1].strip()
240
+ try:
241
+ return datetime.datetime.strptime(time_str, "%Y-%m-%d %H:%M:%S")
242
+ except:
243
+ pass
244
+ return None
245
+
246
+ class Report:
247
+ """Maintains a bunch of logs, generates a final markdown report."""
248
+
249
+ def __init__(self, report_dir):
250
+ os.makedirs(report_dir, exist_ok=True)
251
+ self.report_dir = report_dir
252
+
253
+ def log(self, section, data):
254
+ """Log a section of data to the report."""
255
+ slug = slugify(section)
256
+ file_name = f"{slug}.md"
257
+ file_path = os.path.join(self.report_dir, file_name)
258
+ with open(file_path, "w", encoding="utf-8") as f:
259
+ f.write(f"## {section}\n")
260
+ f.write(f"timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
261
+ for item in data:
262
+ if not item:
263
+ # skip falsy values like None or empty dict etc.
264
+ continue
265
+ if isinstance(item, str):
266
+ # directly write the string
267
+ f.write(item)
268
+ else:
269
+ # render a dict
270
+ for k, v in item.items():
271
+ if isinstance(v, float):
272
+ vstr = f"{v:.4f}"
273
+ elif isinstance(v, int) and v >= 10000:
274
+ vstr = f"{v:,.0f}"
275
+ else:
276
+ vstr = str(v)
277
+ f.write(f"- {k}: {vstr}\n")
278
+ f.write("\n")
279
+ return file_path
280
+
281
+ def generate(self):
282
+ """Generate the final report."""
283
+ report_dir = self.report_dir
284
+ report_file = os.path.join(report_dir, "report.md")
285
+ print(f"Generating report to {report_file}")
286
+ final_metrics = {} # the most important final metrics we'll add as table at the end
287
+ start_time = None
288
+ end_time = None
289
+ with open(report_file, "w", encoding="utf-8") as out_file:
290
+ # write the header first
291
+ header_file = os.path.join(report_dir, "header.md")
292
+ if os.path.exists(header_file):
293
+ with open(header_file, "r", encoding="utf-8") as f:
294
+ header_content = f.read()
295
+ out_file.write(header_content)
296
+ start_time = extract_timestamp(header_content, "Run started:")
297
+ # capture bloat data for the summary later (the text after the "### Bloat" header, up to the next blank line)
+ bloat_data = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL)
+ bloat_data = bloat_data.group(1) if bloat_data else ""
+ else:
+ start_time = None # will cause us to not write the total wall clock time
+ bloat_data = "[bloat data missing]"
+ print(f"Warning: {header_file} does not exist. Did you forget to run `nanochat reset`?")
+ # process all the individual sections
+ for file_name in EXPECTED_FILES:
+ section_file = os.path.join(report_dir, file_name)
+ if not os.path.exists(section_file):
+ print(f"Warning: {section_file} does not exist, skipping")
+ continue
+ with open(section_file, "r", encoding="utf-8") as in_file:
+ section = in_file.read()
+ # Extract timestamp from this section (the last section's timestamp will "stick" as end_time)
+ if "rl" not in file_name:
+ # Skip RL sections for end_time calculation because RL is experimental
+ end_time = extract_timestamp(section, "timestamp:")
+ # extract the most important metrics from the sections
+ if file_name == "base-model-evaluation.md":
+ final_metrics["base"] = extract(section, "CORE")
+ if file_name == "chat-evaluation-mid.md":
+ final_metrics["mid"] = extract(section, chat_metrics)
+ if file_name == "chat-evaluation-sft.md":
+ final_metrics["sft"] = extract(section, chat_metrics)
+ if file_name == "chat-evaluation-rl.md":
+ final_metrics["rl"] = extract(section, "GSM8K") # RL only evals GSM8K
+ # append this section of the report
+ out_file.write(section)
+ out_file.write("\n")
+ # add the final metrics table
+ out_file.write("## Summary\n\n")
+ # Copy over the bloat metrics from the header
+ out_file.write(bloat_data)
+ out_file.write("\n\n")
+ # Collect all unique metric names
+ all_metrics = set()
+ for stage_metrics in final_metrics.values():
+ all_metrics.update(stage_metrics.keys())
+ # Custom ordering: CORE first, ChatCORE last, rest in middle
+ all_metrics = sorted(all_metrics, key=lambda x: (x != "CORE", x == "ChatCORE", x))
+ # Fixed column widths
+ stages = ["base", "mid", "sft", "rl"]
+ metric_width = 15
+ value_width = 8
+ # Write table header
+ header = f"| {'Metric'.ljust(metric_width)} |"
+ for stage in stages:
+ header += f" {stage.upper().ljust(value_width)} |"
+ out_file.write(header + "\n")
+ # Write separator
+ separator = f"|{'-' * (metric_width + 2)}|"
+ for stage in stages:
+ separator += f"{'-' * (value_width + 2)}|"
+ out_file.write(separator + "\n")
+ # Write table rows
+ for metric in all_metrics:
+ row = f"| {metric.ljust(metric_width)} |"
+ for stage in stages:
+ value = final_metrics.get(stage, {}).get(metric, "-")
+ row += f" {str(value).ljust(value_width)} |"
+ out_file.write(row + "\n")
+ out_file.write("\n")
+ # Calculate and write total wall clock time
+ if start_time and end_time:
+ duration = end_time - start_time
+ total_seconds = int(duration.total_seconds())
+ hours = total_seconds // 3600
+ minutes = (total_seconds % 3600) // 60
+ out_file.write(f"Total wall clock time: {hours}h{minutes}m\n")
+ else:
+ out_file.write("Total wall clock time: unknown\n")
+ # also copy the report.md file to the current directory for convenience
+ print("Copying report.md to current directory for convenience")
+ shutil.copy(report_file, "report.md")
+ return report_file
+
+ def reset(self):
+ """Reset the report."""
+ # Remove section files
+ for file_name in EXPECTED_FILES:
+ file_path = os.path.join(self.report_dir, file_name)
+ if os.path.exists(file_path):
+ os.remove(file_path)
+ # Remove report.md if it exists
+ report_file = os.path.join(self.report_dir, "report.md")
+ if os.path.exists(report_file):
+ os.remove(report_file)
+ # Generate and write the header section with start timestamp
+ header_file = os.path.join(self.report_dir, "header.md")
+ header = generate_header()
+ start_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+ with open(header_file, "w", encoding="utf-8") as f:
+ f.write(header)
+ f.write(f"Run started: {start_time}\n\n---\n\n")
+ print(f"Reset report and wrote header to {header_file}")
+
+ # -----------------------------------------------------------------------------
+ # nanochat-specific convenience functions
+
+ class DummyReport:
+ def log(self, *args, **kwargs):
+ pass
+ def reset(self, *args, **kwargs):
+ pass
+
+ def get_report():
+ # for convenience, only rank 0 logs to the report
+ from nanochat.common import get_base_dir, get_dist_info
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
+ if ddp_rank == 0:
+ report_dir = os.path.join(get_base_dir(), "report")
+ return Report(report_dir)
+ else:
+ return DummyReport()
+
+ if __name__ == "__main__":
+ import argparse
+ parser = argparse.ArgumentParser(description="Generate or reset nanochat training reports.")
+ parser.add_argument("command", nargs="?", default="generate", choices=["generate", "reset"], help="Operation to perform (default: generate)")
+ args = parser.parse_args()
+ if args.command == "generate":
+ get_report().generate()
+ elif args.command == "reset":
+ get_report().reset()
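For illustration, the Summary section assembled above comes out as a fixed-width markdown table; with metric_width=15 and value_width=8 it would look roughly like this (all metric names, values, and the wall clock line below are hypothetical, shown only for layout):

| Metric          | BASE     | MID      | SFT      | RL       |
|-----------------|----------|----------|----------|----------|
| CORE            | 0.2219   | -        | -        | -        |
| GSM8K           | -        | 0.0250   | 0.0455   | 0.0741   |
| ChatCORE        | -        | 0.0730   | 0.0884   | -        |

Total wall clock time: 3h51m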
nanochat/tokenizer.py ADDED
@@ -0,0 +1,406 @@
+ """
+ BPE Tokenizer in the style of GPT-4.
+
+ Two implementations are available:
+ 1) HuggingFace Tokenizer that can do both training and inference but is really confusing
+ 2) Our own RustBPE Tokenizer for training and tiktoken for efficient inference
+ """
+
+ import os
+ import copy
+ from functools import lru_cache
+
+ SPECIAL_TOKENS = [
+ # every document begins with the Beginning of Sequence (BOS) token that delimits documents
+ "<|bos|>",
+ # tokens below are only used during finetuning to render Conversations into token ids
+ "<|user_start|>", # user messages
+ "<|user_end|>",
+ "<|assistant_start|>", # assistant messages
+ "<|assistant_end|>",
+ "<|python_start|>", # assistant invokes python REPL tool
+ "<|python_end|>",
+ "<|output_start|>", # python REPL outputs back to assistant
+ "<|output_end|>",
+ ]
+
+ # NOTE: this split pattern deviates from GPT-4 in that we use \p{N}{1,2} instead of \p{N}{1,3}
+ # I did this because I didn't want to "waste" too many tokens on numbers for smaller vocab sizes.
+ # I verified that 2 is the sweet spot for vocab size of 32K. 1 is a bit worse, 3 was worse still.
+ SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
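A quick sketch of what the \p{N}{1,2} change does to number chunking, using the third-party regex package for the demo (the stdlib re module does not support \p{...} classes):

import regex

print(regex.findall(r"\p{N}{1,2}", "1234567"))  # ['12', '34', '56', '7']
print(regex.findall(r"\p{N}{1,3}", "1234567"))  # ['123', '456', '7']  (GPT-4's original grouping)
print(regex.compile(SPLIT_PATTERN).findall("hello world 123"))
# ['hello', ' world', ' ', '12', '3']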
+
+ # -----------------------------------------------------------------------------
+ # Generic GPT-4-style tokenizer based on HuggingFace Tokenizer
+ from tokenizers import Tokenizer as HFTokenizer
+ from tokenizers import pre_tokenizers, decoders, Regex
+ from tokenizers.models import BPE
+ from tokenizers.trainers import BpeTrainer
+
+ class HuggingFaceTokenizer:
+ """Light wrapper around HuggingFace Tokenizer for some utilities"""
+
+ def __init__(self, tokenizer):
+ self.tokenizer = tokenizer
+
+ @classmethod
+ def from_pretrained(cls, hf_path):
+ # init from a HuggingFace pretrained tokenizer (e.g. "gpt2")
+ tokenizer = HFTokenizer.from_pretrained(hf_path)
+ return cls(tokenizer)
+
+ @classmethod
+ def from_directory(cls, tokenizer_dir):
+ # init from a local directory on disk (e.g. "out/tokenizer")
+ tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
+ tokenizer = HFTokenizer.from_file(tokenizer_path)
+ return cls(tokenizer)
+
+ @classmethod
+ def train_from_iterator(cls, text_iterator, vocab_size):
+ # train from an iterator of text
+ # Configure the HuggingFace Tokenizer
+ tokenizer = HFTokenizer(BPE(
+ byte_fallback=True, # needed!
+ unk_token=None,
+ fuse_unk=False,
+ ))
+ # Normalizer: None
+ tokenizer.normalizer = None
+ # Pre-tokenizer: GPT-4 style
+ # the regex pattern used by GPT-4 to split text into groups before BPE
+ # NOTE: The pattern was changed from \p{N}{1,3} to \p{N}{1,2} because I suspect it is harmful to
+ # very small models and smaller vocab sizes, because it is a little bit wasteful in the token space.
+ # (but I haven't validated this! TODO)
+ gpt4_split_regex = Regex(SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!!
+ tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
+ pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False),
+ pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
+ ])
+ # Decoder: ByteLevel (it pairs together with the ByteLevel pre-tokenizer)
+ tokenizer.decoder = decoders.ByteLevel()
+ # Post-processor: None
+ tokenizer.post_processor = None
+ # Trainer: BPE
+ trainer = BpeTrainer(
+ vocab_size=vocab_size,
+ show_progress=True,
+ min_frequency=0, # no minimum frequency
+ initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
+ special_tokens=SPECIAL_TOKENS,
+ )
+ # Kick off the training
+ tokenizer.train_from_iterator(text_iterator, trainer)
+ return cls(tokenizer)
+
+ def get_vocab_size(self):
+ return self.tokenizer.get_vocab_size()
+
+ def get_special_tokens(self):
+ special_tokens_map = self.tokenizer.get_added_tokens_decoder()
+ special_tokens = [w.content for w in special_tokens_map.values()]
+ return special_tokens
+
+ def id_to_token(self, id):
+ return self.tokenizer.id_to_token(id)
+
+ def _encode_one(self, text, prepend=None, append=None, num_threads=None):
+ # encode a single string
+ # prepend/append can each be either a special-token string or a token id directly.
+ # num_threads is ignored (only used by the nanochat Tokenizer for parallel encoding)
+ assert isinstance(text, str)
+ ids = []
+ if prepend is not None:
+ prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend)
+ ids.append(prepend_id)
+ ids.extend(self.tokenizer.encode(text, add_special_tokens=False).ids)
+ if append is not None:
+ append_id = append if isinstance(append, int) else self.encode_special(append)
+ ids.append(append_id)
+ return ids
+
+ def encode_special(self, text):
+ # encode a single special token via exact match
+ return self.tokenizer.token_to_id(text)
+
+ def get_bos_token_id(self):
+ # Different HuggingFace models use different BOS tokens and there is little consistency
+ # 1) attempt to find a <|bos|> token
+ bos = self.encode_special("<|bos|>")
+ # 2) if that fails, attempt to find a <|endoftext|> token (e.g. GPT-2 models)
+ if bos is None:
+ bos = self.encode_special("<|endoftext|>")
+ # 3) if these fail, it's better to crash than to silently return None
+ assert bos is not None, "Failed to find BOS token in tokenizer"
+ return bos
+
+ def encode(self, text, *args, **kwargs):
+ if isinstance(text, str):
+ return self._encode_one(text, *args, **kwargs)
+ elif isinstance(text, list):
+ return [self._encode_one(t, *args, **kwargs) for t in text]
+ else:
+ raise ValueError(f"Invalid input type: {type(text)}")
+
+ def __call__(self, *args, **kwargs):
+ return self.encode(*args, **kwargs)
+
+ def decode(self, ids):
+ return self.tokenizer.decode(ids, skip_special_tokens=False)
+
+ def save(self, tokenizer_dir):
+ # save the tokenizer to disk
+ os.makedirs(tokenizer_dir, exist_ok=True)
+ tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
+ self.tokenizer.save(tokenizer_path)
+ print(f"Saved tokenizer to {tokenizer_path}")
+
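A minimal usage sketch of this wrapper, assuming a toy corpus and a hypothetical output path:

texts = ["hello world", "the quick brown fox"]  # hypothetical toy training corpus
tok = HuggingFaceTokenizer.train_from_iterator(iter(texts), vocab_size=512)
ids = tok.encode("hello world", prepend="<|bos|>")
text = tok.decode(ids[1:])  # skip the <|bos|> id; byte-level BPE should round-trip the text
tok.save("out/tokenizer")   # writes out/tokenizer/tokenizer.json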
+ # -----------------------------------------------------------------------------
+ # Tokenizer based on rustbpe + tiktoken combo
+ import pickle
+ import rustbpe
+ import tiktoken
+
+ class RustBPETokenizer:
+ """Light wrapper around tiktoken (for efficient inference) but train with rustbpe"""
+
+ def __init__(self, enc, bos_token):
+ self.enc = enc
+ self.bos_token_id = self.encode_special(bos_token)
+
+ @classmethod
+ def train_from_iterator(cls, text_iterator, vocab_size):
+ # 1) train using rustbpe
+ tokenizer = rustbpe.Tokenizer()
+ # the special tokens are inserted later in __init__, we don't train them here
+ vocab_size_no_special = vocab_size - len(SPECIAL_TOKENS)
+ assert vocab_size_no_special >= 256, f"vocab_size_no_special must be at least 256, got {vocab_size_no_special}"
+ tokenizer.train_from_iterator(text_iterator, vocab_size_no_special, pattern=SPLIT_PATTERN)
+ # 2) construct the associated tiktoken encoding for inference
+ pattern = tokenizer.get_pattern()
+ mergeable_ranks_list = tokenizer.get_mergeable_ranks()
+ mergeable_ranks = {bytes(k): v for k, v in mergeable_ranks_list}
+ tokens_offset = len(mergeable_ranks)
+ special_tokens = {name: tokens_offset + i for i, name in enumerate(SPECIAL_TOKENS)}
+ enc = tiktoken.Encoding(
+ name="rustbpe",
+ pat_str=pattern,
+ mergeable_ranks=mergeable_ranks, # dict[bytes, int] (token bytes -> merge priority rank)
+ special_tokens=special_tokens, # dict[str, int] (special token name -> token id)
+ )
+ return cls(enc, "<|bos|>")
+
+ @classmethod
+ def from_directory(cls, tokenizer_dir):
+ pickle_path = os.path.join(tokenizer_dir, "tokenizer.pkl")
+ with open(pickle_path, "rb") as f:
+ enc = pickle.load(f)
+ return cls(enc, "<|bos|>")
+
+ @classmethod
+ def from_pretrained(cls, tiktoken_name):
+ # https://github.com/openai/tiktoken/blob/eedc8563/tiktoken_ext/openai_public.py
+ enc = tiktoken.get_encoding(tiktoken_name)
+ # tiktoken calls the special document delimiter token "<|endoftext|>"
+ # yes this is confusing because this token is almost always PREPENDED to the beginning of the document
+ # it is most often used to signal the start of a new sequence to the LLM during inference etc.
+ # so in nanochat we always use "<|bos|>", short for "beginning of sequence"; historically it is often called "<|endoftext|>".
+ return cls(enc, "<|endoftext|>")
+
+ def get_vocab_size(self):
+ return self.enc.n_vocab
+
+ def get_special_tokens(self):
+ return self.enc.special_tokens_set
+
+ def id_to_token(self, id):
+ return self.enc.decode([id])
+
+ @lru_cache(maxsize=32)
+ def encode_special(self, text):
+ return self.enc.encode_single_token(text)
+
+ def get_bos_token_id(self):
+ return self.bos_token_id
+
+ def encode(self, text, prepend=None, append=None, num_threads=8):
+ # text can be either a string or a list of strings
+
+ if prepend is not None:
+ prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend)
+ if append is not None:
+ append_id = append if isinstance(append, int) else self.encode_special(append)
+
+ if isinstance(text, str):
+ ids = self.enc.encode_ordinary(text)
+ if prepend is not None:
+ ids.insert(0, prepend_id) # TODO: slightly inefficient here? :( hmm
+ if append is not None:
+ ids.append(append_id)
+ elif isinstance(text, list):
+ ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads)
+ if prepend is not None:
+ for ids_row in ids:
+ ids_row.insert(0, prepend_id) # TODO: same
+ if append is not None:
+ for ids_row in ids:
+ ids_row.append(append_id)
+ else:
+ raise ValueError(f"Invalid input type: {type(text)}")
+
+ return ids
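A minimal usage sketch (the actual ids depend on the trained vocab; the values shown are made up):

tok = get_tokenizer()  # defined at the bottom of this file
ids = tok.encode(["hello", "world"], prepend="<|bos|>")
# -> e.g. [[65527, 312], [65527, 429]]; one row per input string,
# each row starting with the <|bos|> token id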
+
+ def __call__(self, *args, **kwargs):
+ return self.encode(*args, **kwargs)
+
+ def decode(self, ids):
+ return self.enc.decode(ids)
+
+ def save(self, tokenizer_dir):
+ # save the encoding object to disk
+ os.makedirs(tokenizer_dir, exist_ok=True)
+ pickle_path = os.path.join(tokenizer_dir, "tokenizer.pkl")
+ with open(pickle_path, "wb") as f:
+ pickle.dump(self.enc, f)
+ print(f"Saved tokenizer encoding to {pickle_path}")
+
+ def render_conversation(self, conversation, max_tokens=2048):
+ """
+ Tokenize a single Chat conversation (which we call a "doc" or "document" here).
+ Returns:
+ - ids: list[int], the token ids of this rendered conversation
+ - mask: list[int] of the same length, mask = 1 for tokens that the Assistant is expected to train on.
+ """
+ # the ids and mask we will return, and a helper function to build them up.
+ ids, mask = [], []
+ def add_tokens(token_ids, mask_val):
+ if isinstance(token_ids, int):
+ token_ids = [token_ids]
+ ids.extend(token_ids)
+ mask.extend([mask_val] * len(token_ids))
+
+ # sometimes the first message is a system message...
+ # => just merge it with the second (user) message
+ if conversation["messages"][0]["role"] == "system":
+ # some conversation surgery is necessary here for now...
+ conversation = copy.deepcopy(conversation) # avoid mutating the original
+ messages = conversation["messages"]
+ assert messages[1]["role"] == "user", "System message must be followed by a user message"
+ messages[1]["content"] = messages[0]["content"] + "\n\n" + messages[1]["content"]
+ messages = messages[1:]
+ else:
+ messages = conversation["messages"]
+ assert len(messages) >= 1, f"Conversation must have at least 1 message: {messages}"
+
+ # fetch all the special tokens we need
+ bos = self.get_bos_token_id()
+ user_start, user_end = self.encode_special("<|user_start|>"), self.encode_special("<|user_end|>")
+ assistant_start, assistant_end = self.encode_special("<|assistant_start|>"), self.encode_special("<|assistant_end|>")
+ python_start, python_end = self.encode_special("<|python_start|>"), self.encode_special("<|python_end|>")
+ output_start, output_end = self.encode_special("<|output_start|>"), self.encode_special("<|output_end|>")
+
+ # now we can tokenize the conversation
+ add_tokens(bos, 0)
+ for i, message in enumerate(messages):
+
+ # some sanity checking here around assumptions, to prevent footguns
+ must_be_from = "user" if i % 2 == 0 else "assistant"
+ assert message["role"] == must_be_from, f"Message {i} is from {message['role']} but should be from {must_be_from}"
+
+ # content can be either a simple string or a list of parts (e.g. containing tool calls)
+ content = message["content"]
+
+ if message["role"] == "user":
+ assert isinstance(content, str), "User messages are simply expected to be strings"
+ value_ids = self.encode(content)
+ add_tokens(user_start, 0)
+ add_tokens(value_ids, 0)
+ add_tokens(user_end, 0)
+ elif message["role"] == "assistant":
+ add_tokens(assistant_start, 0)
+ if isinstance(content, str):
+ # simple string => simply add the tokens
+ value_ids = self.encode(content)
+ add_tokens(value_ids, 1)
+ elif isinstance(content, list):
+ for part in content:
+ value_ids = self.encode(part["text"])
+ if part["type"] == "text":
+ # string part => simply add the tokens
+ add_tokens(value_ids, 1)
+ elif part["type"] == "python":
+ # python tool call => add the tokens inside <|python_start|> and <|python_end|>
+ add_tokens(python_start, 1)
+ add_tokens(value_ids, 1)
+ add_tokens(python_end, 1)
+ elif part["type"] == "python_output":
+ # python output => add the tokens inside <|output_start|> and <|output_end|>
+ # none of these tokens are supervised because the tokens come from Python at test time
+ add_tokens(output_start, 0)
+ add_tokens(value_ids, 0)
+ add_tokens(output_end, 0)
+ else:
+ raise ValueError(f"Unknown part type: {part['type']}")
+ else:
+ raise ValueError(f"Unknown content type: {type(content)}")
+ add_tokens(assistant_end, 1)
+
+ # truncate to max_tokens tokens MAX (helps prevent OOMs)
+ ids = ids[:max_tokens]
+ mask = mask[:max_tokens]
+ return ids, mask
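A minimal sketch of a conversation in the expected schema and its rendering (roles must alternate user/assistant after the optional system message):

tok = get_tokenizer()
conversation = {
    "messages": [
        {"role": "user", "content": "What is 2+2?"},
        {"role": "assistant", "content": "4"},
    ]
}
ids, mask = tok.render_conversation(conversation)
# mask is 0 over <|bos|>, the user turn and <|assistant_start|>,
# and 1 over the assistant's tokens and <|assistant_end|>
print(tok.visualize_tokenization(ids, mask))  # green = supervised, red = not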
+
+ def visualize_tokenization(self, ids, mask, with_token_id=False):
+ """Small helper function useful in debugging: visualize the tokenization of render_conversation"""
+ RED = '\033[91m'
+ GREEN = '\033[92m'
+ RESET = '\033[0m'
+ GRAY = '\033[90m'
+ tokens = []
+ for token_id, mask_val in zip(ids, mask):
+ token_str = self.decode([token_id])
+ color = GREEN if mask_val == 1 else RED
+ tokens.append(f"{color}{token_str}{RESET}")
+ if with_token_id:
+ tokens.append(f"{GRAY}({token_id}){RESET}")
+ return '|'.join(tokens)
+
+ def render_for_completion(self, conversation):
+ """
+ Used during Reinforcement Learning. In that setting, we want to
+ render the conversation priming the Assistant for a completion.
+ Unlike the Chat SFT case, we don't need to return the mask.
+ """
+ # We have some surgery to do: we need to pop the last message (of the Assistant)
+ conversation = copy.deepcopy(conversation) # avoid mutating the original
+ messages = conversation["messages"]
+ assert messages[-1]["role"] == "assistant", "Last message must be from the Assistant"
+ messages.pop() # remove the last message (of the Assistant) in place
+
+ # Now tokenize the conversation
+ ids, mask = self.render_conversation(conversation)
+
+ # Finally, to prime the Assistant for a completion, append the Assistant start token
+ assistant_start = self.encode_special("<|assistant_start|>")
+ ids.append(assistant_start)
+ return ids
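And the RL-time counterpart, continuing the sketch above: the trailing assistant message is stripped and the sequence is left open for sampling:

ids = tok.render_for_completion(conversation)
# ids now ends with the <|assistant_start|> id; the model's completion is sampled from here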
+
+ # -----------------------------------------------------------------------------
+ # nanochat-specific convenience functions
+
+ def get_tokenizer():
+ from nanochat.common import get_base_dir
+ base_dir = get_base_dir()
+ tokenizer_dir = os.path.join(base_dir, "tokenizer")
+ # return HuggingFaceTokenizer.from_directory(tokenizer_dir)
+ return RustBPETokenizer.from_directory(tokenizer_dir)
+
+ def get_token_bytes(device="cpu"):
+ import torch
+ from nanochat.common import get_base_dir
+ base_dir = get_base_dir()
+ tokenizer_dir = os.path.join(base_dir, "tokenizer")
+ token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
+ assert os.path.exists(token_bytes_path), f"Token bytes not found at {token_bytes_path} (it gets written by tok_train.py)"
+ with open(token_bytes_path, "rb") as f:
+ token_bytes = torch.load(f, map_location=device)
+ return token_bytes
nanochat/ui.html ADDED
@@ -0,0 +1,566 @@
+ <!DOCTYPE html>
+ <html lang="en">
+ <head>
+ <meta charset="UTF-8">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0, viewport-fit=cover">
+ <title>NanoChat</title>
+ <link rel="icon" type="image/svg+xml" href="/logo.svg">
+ <style>
+ :root {
+ color-scheme: light;
+ }
+
+ * {
+ box-sizing: border-box;
+ }
+
+ html, body {
+ height: 100%;
+ margin: 0;
+ }
+
+ body {
+ font-family: ui-sans-serif, -apple-system, system-ui, "Segoe UI", Helvetica, "Apple Color Emoji", Arial, sans-serif, "Segoe UI Emoji", "Segoe UI Symbol";
+ background-color: #ffffff;
+ color: #111827;
+ min-height: 100dvh;
+ margin: 0;
+ display: flex;
+ flex-direction: column;
+ }
+
+ .header {
+ background-color: #ffffff;
+ padding: 1.25rem 1.5rem;
+ }
+
+ .header-left {
+ display: flex;
+ align-items: center;
+ gap: 0.75rem;
+ }
+
+ .header-logo {
+ height: 32px;
+ width: auto;
+ }
+
+ .header h1 {
+ font-size: 1.25rem;
+ font-weight: 600;
+ margin: 0;
+ color: #111827;
+ }
+
+ .new-conversation-btn {
+ width: 32px;
+ height: 32px;
+ padding: 0;
+ border: 1px solid #e5e7eb;
+ border-radius: 0.5rem;
+ background-color: #ffffff;
+ color: #6b7280;
+ cursor: pointer;
+ display: flex;
+ align-items: center;
+ justify-content: center;
+ transition: all 0.2s ease;
+ }
+
+ .new-conversation-btn:hover {
+ background-color: #f3f4f6;
+ border-color: #d1d5db;
+ color: #374151;
+ }
+
+ .chat-container {
+ flex: 1;
+ overflow-y: auto;
+ background-color: #ffffff;
+ }
+
+ .chat-wrapper {
+ max-width: 48rem;
+ margin: 0 auto;
+ padding: 2rem 1.5rem 3rem;
+ display: flex;
+ flex-direction: column;
+ gap: 0.75rem;
+ }
+
+ .message {
+ display: flex;
+ justify-content: flex-start;
+ margin-bottom: 0.5rem;
+ color: #0d0d0d;
+ }
+
+ .message.assistant {
+ justify-content: flex-start;
+ }
+
+ .message.user {
+ justify-content: flex-end;
+ }
+
+ .message-content {
+ white-space: pre-wrap;
+ line-height: 1.6;
+ max-width: 100%;
+ }
+
+ .message.assistant .message-content {
+ background: transparent;
+ border: none;
+ cursor: pointer;
+ border-radius: 0.5rem;
+ padding: 0.5rem;
+ margin-left: -0.5rem;
+ transition: background-color 0.2s ease;
+ }
+
+ .message.assistant .message-content:hover {
+ background-color: #f9fafb;
+ }
+
+ .message.user .message-content {
+ background-color: #f3f4f6;
+ border-radius: 1.25rem;
+ padding: 0.8rem 1rem;
+ max-width: 65%;
+ cursor: pointer;
+ transition: background-color 0.2s ease;
+ }
+
+ .message.user .message-content:hover {
+ background-color: #e5e7eb;
+ }
+
+ .message.console .message-content {
+ font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', 'Consolas', 'Courier New', monospace;
+ font-size: 0.875rem;
+ background-color: #fafafa;
+ padding: 0.75rem 1rem;
+ color: #374151;
+ max-width: 80%;
+ }
+
+ .input-container {
+ background-color: #ffffff;
+ padding: 1rem;
+ padding-bottom: calc(1rem + env(safe-area-inset-bottom));
+ }
+
+ .input-wrapper {
+ max-width: 48rem;
+ margin: 0 auto;
+ display: flex;
+ gap: 0.75rem;
+ align-items: flex-end;
+ }
+
+ .chat-input {
+ flex: 1;
+ padding: 0.8rem 1rem;
+ border: 1px solid #d1d5db;
+ border-radius: 0.75rem;
+ background-color: #ffffff;
+ color: #111827;
+ font-size: 1rem;
+ line-height: 1.5;
+ resize: none;
+ outline: none;
+ min-height: 54px;
+ max-height: 200px;
+ transition: border-color 0.2s ease, box-shadow 0.2s ease;
+ }
+
+ .chat-input::placeholder {
+ color: #9ca3af;
+ }
+
+ .chat-input:focus {
+ border-color: #2563eb;
+ box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.1);
+ }
+
+ .send-button {
+ flex-shrink: 0;
+ padding: 0;
+ width: 54px;
+ height: 54px;
+ border: 1px solid #111827;
+ border-radius: 0.75rem;
+ background-color: #111827;
+ color: #ffffff;
+ display: flex;
+ align-items: center;
+ justify-content: center;
+ cursor: pointer;
+ transition: background-color 0.2s ease, border-color 0.2s ease, color 0.2s ease;
+ }
+
+ .send-button:hover:not(:disabled) {
+ background-color: #2563eb;
+ border-color: #2563eb;
+ }
+
+ .send-button:disabled {
+ cursor: not-allowed;
+ border-color: #d1d5db;
+ background-color: #e5e7eb;
+ color: #9ca3af;
+ }
+
+ .typing-indicator {
+ display: inline-block;
+ color: #6b7280;
+ letter-spacing: 0.15em;
+ }
+
+ .typing-indicator::after {
+ content: '···';
+ animation: typing 1.4s infinite;
+ }
+
+ @keyframes typing {
+ 0%, 60%, 100% { opacity: 0.2; }
+ 30% { opacity: 1; }
+ }
+
+ .error-message {
+ background-color: #fee2e2;
+ border: 1px solid #fecaca;
+ color: #b91c1c;
+ padding: 0.75rem 1rem;
+ border-radius: 0.75rem;
+ margin-top: 0.5rem;
+ }
+ </style>
+ </head>
+ <body>
+ <div class="header">
+ <div class="header-left">
+ <button class="new-conversation-btn" onclick="newConversation()" title="New Conversation (Ctrl+Shift+N)">
+ <svg width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
+ <path d="M12 5v14"></path>
+ <path d="M5 12h14"></path>
+ </svg>
+ </button>
+ <h1>nanochat</h1>
+ </div>
+ </div>
+
+ <div class="chat-container" id="chatContainer">
+ <div class="chat-wrapper" id="chatWrapper">
+ <!-- Messages will be added here -->
+ </div>
+ </div>
+
+ <div class="input-container">
+ <div class="input-wrapper">
+ <textarea
+ id="chatInput"
+ class="chat-input"
+ placeholder="Ask anything"
+ rows="1"
+ onkeydown="handleKeyDown(event)"
+ ></textarea>
+ <button id="sendButton" class="send-button" onclick="sendMessage()" disabled>
+ <svg width="22" height="22" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
+ <path d="M22 2L11 13"></path>
+ <path d="M22 2l-7 20-4-9-9-4 20-7z"></path>
+ </svg>
+ </button>
+ </div>
+ </div>
+
+ <script>
+ const API_URL = '';
+ const chatContainer = document.getElementById('chatContainer');
+ const chatWrapper = document.getElementById('chatWrapper');
+ const chatInput = document.getElementById('chatInput');
+ const sendButton = document.getElementById('sendButton');
+
+ let messages = [];
+ let isGenerating = false;
+ let currentTemperature = 0.8;
+ let currentTopK = 50;
+
+ chatInput.addEventListener('input', function() {
+ this.style.height = 'auto';
+ this.style.height = Math.min(this.scrollHeight, 200) + 'px';
+ sendButton.disabled = !this.value.trim() || isGenerating;
+ });
+
+ function handleKeyDown(event) {
+ if (event.key === 'Enter' && !event.shiftKey) {
+ event.preventDefault();
+ sendMessage();
+ }
+ }
+
+ document.addEventListener('keydown', function(event) {
+ // Ctrl+Shift+N for new conversation
+ if (event.ctrlKey && event.shiftKey && event.key === 'N') {
+ event.preventDefault();
+ if (!isGenerating) {
+ newConversation();
+ }
+ }
+ });
+
+ function newConversation() {
+ messages = [];
+ chatWrapper.innerHTML = '';
+ chatInput.value = '';
+ chatInput.style.height = 'auto';
+ sendButton.disabled = false;
+ isGenerating = false;
+ chatInput.focus();
+ }
+
+ function addMessage(role, content, messageIndex = null) {
+ const messageDiv = document.createElement('div');
+ messageDiv.className = `message ${role}`;
+
+ const contentDiv = document.createElement('div');
+ contentDiv.className = 'message-content';
+ contentDiv.textContent = content;
+
+ // Add click handler for user messages to enable editing
+ if (role === 'user' && messageIndex !== null) {
+ contentDiv.setAttribute('data-message-index', messageIndex);
+ contentDiv.setAttribute('title', 'Click to edit and restart from here');
+ contentDiv.addEventListener('click', function() {
+ if (!isGenerating) {
+ editMessage(messageIndex);
+ }
+ });
+ }
+
+ // Add click handler for assistant messages to enable regeneration
+ if (role === 'assistant' && messageIndex !== null) {
+ contentDiv.setAttribute('data-message-index', messageIndex);
+ contentDiv.setAttribute('title', 'Click to regenerate this response');
+ contentDiv.addEventListener('click', function() {
+ if (!isGenerating) {
+ regenerateMessage(messageIndex);
+ }
+ });
+ }
+
+ messageDiv.appendChild(contentDiv);
+ chatWrapper.appendChild(messageDiv);
+
+ chatContainer.scrollTop = chatContainer.scrollHeight;
+ return contentDiv;
+ }
+
+ function editMessage(messageIndex) {
+ // Find the message in the messages array
+ if (messageIndex < 0 || messageIndex >= messages.length) return;
+
+ const messageToEdit = messages[messageIndex];
+ if (messageToEdit.role !== 'user') return;
+
+ // Copy message content to input
+ chatInput.value = messageToEdit.content;
+ chatInput.style.height = 'auto';
+ chatInput.style.height = Math.min(chatInput.scrollHeight, 200) + 'px';
+
+ // Remove this message and all subsequent messages from the array
+ messages = messages.slice(0, messageIndex);
+
+ // Remove message elements from DOM starting from messageIndex
+ const allMessages = chatWrapper.querySelectorAll('.message');
+ for (let i = messageIndex; i < allMessages.length; i++) {
+ allMessages[i].remove();
+ }
+
+ // Enable send button and focus input
+ sendButton.disabled = false;
+ chatInput.focus();
+ }
+
+ async function generateAssistantResponse() {
+ isGenerating = true;
+ sendButton.disabled = true;
+
+ const assistantContent = addMessage('assistant', '');
+ assistantContent.innerHTML = '<span class="typing-indicator"></span>';
+
+ try {
+ const response = await fetch(`${API_URL}/chat/completions`, {
+ method: 'POST',
+ headers: {
+ 'Content-Type': 'application/json',
+ },
+ body: JSON.stringify({
+ messages: messages,
+ temperature: currentTemperature,
+ top_k: currentTopK,
+ max_tokens: 512
+ }),
+ });
+
+ if (!response.ok) {
+ throw new Error(`HTTP error! status: ${response.status}`);
+ }
+
+ const reader = response.body.getReader();
+ const decoder = new TextDecoder();
+ let fullResponse = '';
+ assistantContent.textContent = '';
+
+ while (true) {
+ const { done, value } = await reader.read();
+ if (done) break;
+
+ const chunk = decoder.decode(value);
+ const lines = chunk.split('\n');
+
+ for (const line of lines) {
+ if (line.startsWith('data: ')) {
+ try {
+ const data = JSON.parse(line.slice(6));
+ if (data.token) {
+ fullResponse += data.token;
+ assistantContent.textContent = fullResponse;
+ chatContainer.scrollTop = chatContainer.scrollHeight;
+ }
+ } catch (e) {
+ // ignore malformed or partial JSON lines
+ }
+ }
+ }
+ }
+
+ const assistantMessageIndex = messages.length;
+ messages.push({ role: 'assistant', content: fullResponse });
+
+ // Add click handler to regenerate this assistant message
+ assistantContent.setAttribute('data-message-index', assistantMessageIndex);
+ assistantContent.setAttribute('title', 'Click to regenerate this response');
+ assistantContent.addEventListener('click', function() {
+ if (!isGenerating) {
+ regenerateMessage(assistantMessageIndex);
+ }
+ });
+
+ } catch (error) {
+ console.error('Error:', error);
+ assistantContent.innerHTML = `<div class="error-message">Error: ${error.message}</div>`;
+ } finally {
+ isGenerating = false;
+ sendButton.disabled = !chatInput.value.trim();
+ }
+ }
+
+ async function regenerateMessage(messageIndex) {
+ // Find the message in the messages array
+ if (messageIndex < 0 || messageIndex >= messages.length) return;
+
+ const messageToRegenerate = messages[messageIndex];
+ if (messageToRegenerate.role !== 'assistant') return;
+
+ // Remove this message and all subsequent messages from the array
+ messages = messages.slice(0, messageIndex);
+
+ // Remove message elements from DOM starting from messageIndex
+ const allMessages = chatWrapper.querySelectorAll('.message');
+ for (let i = messageIndex; i < allMessages.length; i++) {
+ allMessages[i].remove();
+ }
+
+ // Regenerate the assistant response
+ await generateAssistantResponse();
+ }
+
+ function handleSlashCommand(command) {
+ const parts = command.trim().split(/\s+/);
+ const cmd = parts[0].toLowerCase();
+ const arg = parts[1];
+
+ if (cmd === '/temperature') {
+ if (arg === undefined) {
+ addMessage('console', `Current temperature: ${currentTemperature}`);
+ } else {
+ const temp = parseFloat(arg);
+ if (isNaN(temp) || temp < 0 || temp > 2) {
+ addMessage('console', 'Invalid temperature. Must be between 0.0 and 2.0');
+ } else {
+ currentTemperature = temp;
+ addMessage('console', `Temperature set to ${currentTemperature}`);
+ }
+ }
+ return true;
+ } else if (cmd === '/topk') {
+ if (arg === undefined) {
+ addMessage('console', `Current top-k: ${currentTopK}`);
+ } else {
+ const topk = parseInt(arg, 10);
+ if (isNaN(topk) || topk < 1 || topk > 200) {
+ addMessage('console', 'Invalid top-k. Must be between 1 and 200');
+ } else {
+ currentTopK = topk;
+ addMessage('console', `Top-k set to ${currentTopK}`);
+ }
+ }
+ return true;
+ } else if (cmd === '/clear') {
+ newConversation();
+ return true;
+ } else if (cmd === '/help') {
+ addMessage('console',
+ 'Available commands:\n' +
+ '/temperature - Show current temperature\n' +
+ '/temperature <value> - Set temperature (0.0-2.0)\n' +
+ '/topk - Show current top-k\n' +
+ '/topk <value> - Set top-k (1-200)\n' +
+ '/clear - Clear conversation\n' +
+ '/help - Show this help message'
+ );
+ return true;
+ }
+ return false;
+ }
+
+ async function sendMessage() {
+ const message = chatInput.value.trim();
+ if (!message || isGenerating) return;
+
+ // Handle slash commands
+ if (message.startsWith('/')) {
+ chatInput.value = '';
+ chatInput.style.height = 'auto';
+ handleSlashCommand(message);
+ return;
+ }
+
+ chatInput.value = '';
+ chatInput.style.height = 'auto';
+
+ const userMessageIndex = messages.length;
+ messages.push({ role: 'user', content: message });
+ addMessage('user', message, userMessageIndex);
+
+ await generateAssistantResponse();
+ }
+
+ sendButton.disabled = false;
+
+ // Autofocus the chat input on page load
+ chatInput.focus();
+
+ fetch(`${API_URL}/health`)
+ .then(response => response.json())
+ .then(data => {
+ console.log('Engine status:', data);
+ })
+ .catch(error => {
+ console.error('Engine not available:', error);
+ chatWrapper.innerHTML = '<div class="error-message">Engine not running. Please start engine.py first.</div>';
+ });
+ </script>
+ </body>
+ </html>
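For reference, the streaming loop above assumes the /chat/completions endpoint emits newline-separated, SSE-style chunks, one JSON object with a token field per line. The exact wire format is inferred from this parsing code rather than documented here, but it would look roughly like:

data: {"token": "Hello"}
data: {"token": " world"}
data: {"token": "!"}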