Spaces:
Sleeping
Sleeping
Upload 20 files
Browse files- nanochat/__init__.py +0 -0
- nanochat/adamw.py +143 -0
- nanochat/checkpoint_manager.py +174 -0
- nanochat/common.py +276 -0
- nanochat/core_eval.py +262 -0
- nanochat/dataloader.py +199 -0
- nanochat/dataset.py +128 -0
- nanochat/distill_loss.py +176 -0
- nanochat/engine.py +356 -0
- nanochat/execution.py +349 -0
- nanochat/flash_attention.py +178 -0
- nanochat/gpt.py +633 -0
- nanochat/logo.svg +8 -0
- nanochat/loss_eval.py +65 -0
- nanochat/muon.py +352 -0
- nanochat/prune.py +216 -0
- nanochat/quantize.py +113 -0
- nanochat/report.py +422 -0
- nanochat/tokenizer.py +406 -0
- nanochat/ui.html +566 -0
nanochat/__init__.py
ADDED
|
File without changes
|
nanochat/adamw.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Distributed AdamW optimizer with a fused step function.
|
| 3 |
+
A bunch of ideas (e.g. dist comms in slices) are borrowed from modded-nanogpt.
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
import torch.distributed as dist
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
|
| 9 |
+
@torch.compile(dynamic=False, fullgraph=True)
def adamw_step_fused(
    p: Tensor,
    grad: Tensor,
    exp_avg: Tensor,
    exp_avg_sq: Tensor,
    step_t: Tensor,
    lr_t: Tensor,
    beta1_t: Tensor,
    beta2_t: Tensor,
    eps_t: Tensor,
    wd_t: Tensor,
) -> None:
    """
    Fused AdamW step: weight_decay -> momentum_update -> bias_correction -> param_update
    All in one compiled graph to eliminate Python overhead between ops.
    The 0-D CPU tensors avoid recompilation when hyperparameter values change.
    Mutates p, exp_avg and exp_avg_sq in place; returns nothing.
    """
    # Weight decay (decoupled, applied before the update)
    p.mul_(1 - lr_t * wd_t)
    # Update running averages (lerp_ is cleaner and fuses well)
    exp_avg.lerp_(grad, 1 - beta1_t)
    exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
    # Bias corrections (step_t counts from 1 on the first call)
    bias1 = 1 - beta1_t ** step_t
    bias2 = 1 - beta2_t ** step_t
    # Compute update and apply
    denom = (exp_avg_sq / bias2).sqrt() + eps_t
    step_size = lr_t / bias1
    p.add_(exp_avg / denom, alpha=-step_size)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class DistAdamW(torch.optim.Optimizer):
    """
    Distributed AdamW optimizer.
    In the style of ZeRO-2, i.e. sharded optimizer states and gradient reduction

    Parameters with >= 1024 elements are "sliced": their gradients are
    reduce-scattered so each rank updates only its own 1/world_size slice
    (and keeps optimizer state only for that slice), then the updated slices
    are all-gathered back into the full parameter. Smaller parameters are
    all-reduced and updated redundantly on every rank.
    """
    def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        # Validate (only rank 0 prints/validates to avoid duplicated output)
        if rank == 0:
            for group in param_groups:
                assert isinstance(group, dict), "expecting param_groups to be a list of dicts"
                assert isinstance(group['params'], list), "expecting group['params'] to be a list of tensors"
                for p in group['params']:
                    sliced = p.numel() >= 1024  # same threshold as used in step()
                    print(f"AdamW: 1 param of shape {p.shape}, sliced={sliced}")
                    if sliced: # large parameter tensors will be operated on in slices
                        assert p.shape[0] % world_size == 0, f"First dim of parameter shape {p.shape} must be divisible by world size {world_size}"
        super().__init__(param_groups, defaults)
        # 0-D CPU tensors to avoid torch.compile recompilation when values change
        self._step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")

    @torch.no_grad()
    def step(self):
        """Run one optimizer step, overlapping gradient comms with the updates."""
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        reduce_futures: list[torch.Future] = []
        gather_futures: list[torch.Future] = []
        grad_slices = []
        is_small = [] # track which params are small (use all_reduce) vs large (use reduce_scatter)

        # Phase 1: launch all gradient reductions asynchronously, in a fixed
        # param order that phase 2 replays via `idx`.
        for group in self.param_groups:
            params: list[Tensor] = group["params"]
            for p in params:
                grad = p.grad
                # Small params: use all_reduce (no scatter/gather needed)
                if p.numel() < 1024:
                    is_small.append(True)
                    reduce_futures.append(dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
                    grad_slices.append(grad)
                else:
                    is_small.append(False)
                    rank_size = grad.shape[0] // world_size # p.shape[0] % world_size == 0 is checked in __init__
                    grad_slice = torch.empty_like(grad[:rank_size])
                    reduce_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
                    grad_slices.append(grad_slice)

        # Phase 2: wait per-param (in launch order), apply the fused update,
        # then kick off the all_gather of the updated slice for large params.
        idx = 0
        for group in self.param_groups:
            beta1, beta2 = group['betas']
            eps = group['eps']
            wd = group['weight_decay']
            params = group['params']
            for p in params:
                reduce_futures[idx].wait()
                g_slice = grad_slices[idx]
                # per-parameter learning-rate multiplier (optional attribute set by the model)
                lr = group['lr'] * getattr(p, "lr_mul", 1.0)
                state = self.state[p]

                # For small params, operate on full param; for large, operate on slice
                if is_small[idx]:
                    p_slice = p
                else:
                    rank_size = p.shape[0] // world_size
                    p_slice = p[rank * rank_size:(rank + 1) * rank_size]

                # State init (lazy, so state tensors match the slice shape)
                if not state:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_slice)
                    state['exp_avg_sq'] = torch.zeros_like(p_slice)
                exp_avg = state['exp_avg']
                exp_avg_sq = state['exp_avg_sq']
                state['step'] += 1

                # Fill 0-D tensors with current values
                eff_wd = wd * getattr(p, "wd_mul", 1.0)  # per-parameter weight-decay multiplier
                self._step_t.fill_(state['step'])
                self._lr_t.fill_(lr)
                self._beta1_t.fill_(beta1)
                self._beta2_t.fill_(beta2)
                self._eps_t.fill_(eps)
                self._wd_t.fill_(eff_wd)

                # Fused update: weight_decay -> momentum -> bias_correction -> param_update
                adamw_step_fused(
                    p_slice, g_slice, exp_avg, exp_avg_sq,
                    self._step_t, self._lr_t, self._beta1_t, self._beta2_t, self._eps_t, self._wd_t,
                )

                # Only large params need all_gather
                if not is_small[idx]:
                    gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
                idx += 1

        # Ensure every large parameter is fully re-assembled before returning.
        if gather_futures:
            torch.futures.collect_all(gather_futures).wait()
|
nanochat/checkpoint_manager.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utilities for saving and loading model/optim/state checkpoints.
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import re
|
| 6 |
+
import glob
|
| 7 |
+
import json
|
| 8 |
+
import logging
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from nanochat.common import get_base_dir
|
| 12 |
+
from nanochat.gpt import GPT, GPTConfig
|
| 13 |
+
from nanochat.tokenizer import get_tokenizer
|
| 14 |
+
from nanochat.common import setup_default_logging
|
| 15 |
+
|
| 16 |
+
# Set up logging
|
| 17 |
+
setup_default_logging()
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
def log0(message):
    """Log *message* via the module logger, but only on rank 0."""
    rank = int(os.environ.get('RANK', 0))
    if rank != 0:
        return
    logger.info(message)
|
| 22 |
+
|
| 23 |
+
def _patch_missing_config_keys(model_config_kwargs):
|
| 24 |
+
"""Add default values for new config keys missing in old checkpoints."""
|
| 25 |
+
# Old models were trained with full context (no sliding window)
|
| 26 |
+
if "window_pattern" not in model_config_kwargs:
|
| 27 |
+
model_config_kwargs["window_pattern"] = "L"
|
| 28 |
+
log0(f"Patching missing window_pattern in model config to 'L'")
|
| 29 |
+
|
| 30 |
+
def _patch_missing_keys(model_data, model_config):
|
| 31 |
+
"""Add default values for new parameters that may be missing in old checkpoints."""
|
| 32 |
+
n_layer = model_config.n_layer
|
| 33 |
+
# resid_lambdas defaults to 1.0 (identity scaling)
|
| 34 |
+
if "resid_lambdas" not in model_data:
|
| 35 |
+
model_data["resid_lambdas"] = torch.ones(n_layer)
|
| 36 |
+
log0(f"Patching missing resid_lambdas in model data to 1.0")
|
| 37 |
+
# x0_lambdas defaults to 0.0 (disabled)
|
| 38 |
+
if "x0_lambdas" not in model_data:
|
| 39 |
+
model_data["x0_lambdas"] = torch.zeros(n_layer)
|
| 40 |
+
log0(f"Patching missing x0_lambdas in model data to 0.0")
|
| 41 |
+
|
| 42 |
+
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
    """Write one checkpoint: model weights + json metadata (rank 0 only) and,
    when optimizer_data is given, this rank's shard of the optimizer state."""
    step_tag = f"{step:06d}"
    if rank == 0:
        os.makedirs(checkpoint_dir, exist_ok=True)
        # Save the model state parameters
        weights_path = os.path.join(checkpoint_dir, f"model_{step_tag}.pt")
        torch.save(model_data, weights_path)
        logger.info(f"Saved model parameters to: {weights_path}")
        # Save the metadata dict as json
        meta_path = os.path.join(checkpoint_dir, f"meta_{step_tag}.json")
        with open(meta_path, "w", encoding="utf-8") as f:
            json.dump(meta_data, f, indent=2)
        logger.info(f"Saved metadata to: {meta_path}")
    # Optimizer state is sharded across ranks, so each rank must save its own.
    if optimizer_data is not None:
        os.makedirs(checkpoint_dir, exist_ok=True)
        optim_path = os.path.join(checkpoint_dir, f"optim_{step_tag}_rank{rank:d}.pt")
        torch.save(optimizer_data, optim_path)
        logger.info(f"Saved optimizer state to: {optim_path}")
|
| 60 |
+
|
| 61 |
+
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
    """Load model weights, optional per-rank optimizer shard, and json metadata.

    Returns (model_data, optimizer_data, meta_data); optimizer_data is None
    unless load_optimizer is True.
    """
    weights_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
    model_data = torch.load(weights_path, map_location=device)
    optimizer_data = None
    if load_optimizer:
        optim_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
        optimizer_data = torch.load(optim_path, map_location=device)
    meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
    with open(meta_path, "r", encoding="utf-8") as f:
        meta_data = json.load(f)
    return model_data, optimizer_data, meta_data
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def build_model(checkpoint_dir, step, device, phase):
    """
    A bunch of repetitive code to build a model from a given checkpoint.

    Args:
        checkpoint_dir: directory holding model_/meta_ files for one model tag
        step: training step of the checkpoint to load
        device: torch.device to materialize the model on (also the map_location)
        phase: "train" or "eval" - sets the module's training mode accordingly

    Returns:
        - base model - uncompiled, not wrapped in DDP
        - tokenizer
        - meta data saved during base model training
    """
    assert phase in ["train", "eval"], f"Invalid phase: {phase}"
    model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
    if device.type in {"cpu", "mps"}:
        # Convert bfloat16 tensors to float for CPU inference
        model_data = {
            k: v.float() if v.dtype == torch.bfloat16 else v
            for k, v in model_data.items()
        }
    # Hack: fix torch compile issue, which prepends all keys with _orig_mod.
    model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
    model_config_kwargs = meta_data["model_config"]
    _patch_missing_config_keys(model_config_kwargs)
    log0(f"Building model with config: {model_config_kwargs}")
    model_config = GPTConfig(**model_config_kwargs)
    _patch_missing_keys(model_data, model_config)
    # Construct on the meta device (no storage allocated), then materialize below
    with torch.device("meta"):
        model = GPT(model_config)
    # Load the model state
    model.to_empty(device=device)
    model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
    # strict=False: presumably tolerates keys absent from old checkpoints (e.g. buffers
    # re-created by init_weights above) - TODO confirm; assign=True avoids an extra copy
    model.load_state_dict(model_data, strict=False, assign=True)
    # Put the model in the right training phase / mode
    if phase == "eval":
        model.eval()
    else:
        model.train()
    # Load the Tokenizer
    tokenizer = get_tokenizer()
    # Sanity check: compatibility between model and tokenizer
    assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"], f"Tokenizer vocab size {tokenizer.get_vocab_size()} does not match model config vocab size {model_config_kwargs['vocab_size']}"
    return model, tokenizer, meta_data
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def find_largest_model(checkpoints_dir):
    """Guess a model tag in checkpoints_dir: the deepest d<number> directory,
    falling back to the most recently modified directory.

    Raises FileNotFoundError when checkpoints_dir contains no directories.
    """
    model_tags = [f for f in os.listdir(checkpoints_dir) if os.path.isdir(os.path.join(checkpoints_dir, f))]
    if not model_tags:
        raise FileNotFoundError(f"No checkpoints found in {checkpoints_dir}")
    # 1) normally all model tags are of the form d<number>, try that first:
    candidates = []
    for tag in model_tags:
        m = re.match(r"d(\d+)", tag)
        if m is not None:
            candidates.append((int(m.group(1)), tag))
    if candidates:
        # deepest model wins; first listed wins ties (matches stable sort order)
        return max(candidates, key=lambda pair: pair[0])[1]
    # 2) if that failed, take the most recently updated model:
    return max(model_tags, key=lambda tag: os.path.getmtime(os.path.join(checkpoints_dir, tag)))
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def find_last_step(checkpoint_dir):
    """Return the highest <step> among model_<step>.pt files in checkpoint_dir.

    Raises:
        FileNotFoundError: if no model checkpoint files are present.
    """
    checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt"))
    if not checkpoint_files:
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
    # Compare steps numerically. The previous int(max(<strings>)) took a
    # lexicographic max over the raw step strings, which is wrong for unpadded
    # or >6-digit step numbers (e.g. "99" > "100").
    last_step = max(int(os.path.basename(f).split("_")[-1].split(".")[0]) for f in checkpoint_files)
    return last_step
|
| 145 |
+
|
| 146 |
+
# -----------------------------------------------------------------------------
|
| 147 |
+
# convenience functions that take into account nanochat's directory structure
|
| 148 |
+
|
| 149 |
+
def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
    """Build (model, tokenizer, meta_data) from checkpoints_dir, guessing
    model_tag (largest model) and step (last step) when not provided."""
    if model_tag is None:
        # guess the model tag by defaulting to the largest model
        model_tag = find_largest_model(checkpoints_dir)
        log0(f"No model tag provided, guessing model tag: {model_tag}")
    checkpoint_dir = os.path.join(checkpoints_dir, model_tag)
    if step is None:
        # guess the step by defaulting to the last step
        step = find_last_step(checkpoint_dir)
    assert step is not None, f"No checkpoints found in {checkpoint_dir}"
    # build the model
    log0(f"Loading model from {checkpoint_dir} with step {step}")
    return build_model(checkpoint_dir, step, device, phase)
|
| 163 |
+
|
| 164 |
+
def load_model(source, *args, **kwargs):
    """Load a model from one checkpoint family: "base", "distill", "mid",
    "sft" or "rl". Extra args/kwargs are forwarded to load_model_from_dir.
    Raises KeyError for an unknown source."""
    dir_names = {
        "base": "base_checkpoints",
        "distill": "distill_checkpoints",
        "mid": "mid_checkpoints",
        "sft": "chatsft_checkpoints",
        "rl": "chatrl_checkpoints",
    }
    checkpoints_dir = os.path.join(get_base_dir(), dir_names[source])
    return load_model_from_dir(checkpoints_dir, *args, **kwargs)
|
nanochat/common.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Common utilities for nanochat.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import re
|
| 7 |
+
import logging
|
| 8 |
+
import urllib.request
|
| 9 |
+
import torch
|
| 10 |
+
import torch.distributed as dist
|
| 11 |
+
from filelock import FileLock
|
| 12 |
+
|
| 13 |
+
class ColoredFormatter(logging.Formatter):
    """Custom formatter that adds ANSI colors to log messages."""
    # ANSI color codes
    COLORS = {
        'DEBUG': '\033[36m',     # Cyan
        'INFO': '\033[32m',      # Green
        'WARNING': '\033[33m',   # Yellow
        'ERROR': '\033[31m',     # Red
        'CRITICAL': '\033[35m',  # Magenta
    }
    RESET = '\033[0m'
    BOLD = '\033[1m'

    def format(self, record):
        """Format *record* with a colored level name and highlighted numbers.

        The level name is colorized only for the duration of this call and then
        restored: LogRecord instances are shared between handlers, so the
        previous permanent mutation of record.levelname leaked ANSI codes into
        every other handler (e.g. a plain file handler) that formatted the
        same record afterwards.
        """
        levelname = record.levelname
        try:
            if levelname in self.COLORS:
                record.levelname = f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}"
            # Format the message
            message = super().format(record)
        finally:
            record.levelname = levelname  # undo the temporary mutation
        # Add color to specific parts of the message
        if levelname == 'INFO':
            # Highlight numbers and percentages
            message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message)
            message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message)
        return message
|
| 38 |
+
|
| 39 |
+
def setup_default_logging():
    """Install a colorized stream handler as the root logging configuration."""
    stream_handler = logging.StreamHandler()
    fmt = ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    stream_handler.setFormatter(fmt)
    logging.basicConfig(level=logging.INFO, handlers=[stream_handler])
|
| 46 |
+
|
| 47 |
+
setup_default_logging()
|
| 48 |
+
logger = logging.getLogger(__name__)
|
| 49 |
+
|
| 50 |
+
def get_base_dir():
    """Return (creating if needed) the directory for nanochat intermediates.

    Defaults to ~/.cache/nanochat, co-located with other cached data; the
    NANOCHAT_BASE_DIR environment variable overrides it.
    """
    override = os.environ.get("NANOCHAT_BASE_DIR")
    if override:
        nanochat_dir = override
    else:
        nanochat_dir = os.path.join(os.path.expanduser("~"), ".cache", "nanochat")
    os.makedirs(nanochat_dir, exist_ok=True)
    return nanochat_dir
|
| 60 |
+
|
| 61 |
+
def download_file_with_lock(url, filename, postprocess_fn=None):
    """
    Download `url` into the base directory as `filename` and return the local
    path. A lock file makes this safe when multiple DDP ranks race to fetch
    the same file: one rank downloads, the others wait and reuse the result.
    """
    target_path = os.path.join(get_base_dir(), filename)

    # Fast path: the file is already there.
    if os.path.exists(target_path):
        return target_path

    # Only a single rank can acquire this lock; all others block until released.
    with FileLock(target_path + ".lock"):
        # Recheck after acquiring lock: another rank may have finished meanwhile.
        if os.path.exists(target_path):
            return target_path

        # Download the content as bytes
        print(f"Downloading {url}...")
        with urllib.request.urlopen(url) as response:
            payload = response.read()  # bytes

        # Write to local file
        with open(target_path, 'wb') as f:
            f.write(payload)
        print(f"Downloaded to {target_path}")

        # Run the postprocess function if provided
        if postprocess_fn is not None:
            postprocess_fn(target_path)

        return target_path
|
| 96 |
+
|
| 97 |
+
def print0(s="",**kwargs):
    """print(), but only on rank 0; a no-op on all other DDP ranks."""
    if int(os.environ.get('RANK', 0)) != 0:
        return
    print(s, **kwargs)
|
| 101 |
+
|
| 102 |
+
def print_banner():
    """Print the nanochat ASCII-art banner (rank 0 only, via print0)."""
    # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/
    banner = """
 █████ █████
░░███ ░░███
 ████████ ██████ ████████ ██████ ██████ ░███████ ██████ ███████
░░███░░███ ░░░░░███ ░░███░░███ ███░░███ ███░░███ ░███░░███ ░░░░░███░░░███░
 ░███ ░███ ███████ ░███ ░███ ░███ ░███░███ ░░░ ░███ ░███ ███████ ░███
 ░███ ░███ ███░░███ ░███ ░███ ░███ ░███░███ ███ ░███ ░███ ███░░███ ░███ ███
 ████ █████░░████████ ████ █████░░██████ ░░██████ ████ █████░░███████ ░░█████
░░░░ ░░░░░ ░░░░░░░░ ░░░░ ░░░░░ ░░░░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░░░░ ░░░░░
 """
    print0(banner)
|
| 115 |
+
|
| 116 |
+
def is_ddp_requested() -> bool:
    """
    True if launched by torchrun (env present), even before init.
    Used to decide whether we *should* initialize a PG.
    """
    required = ("RANK", "LOCAL_RANK", "WORLD_SIZE")
    return not any(k not in os.environ for k in required)
|
| 122 |
+
|
| 123 |
+
def is_ddp_initialized() -> bool:
    """
    True if torch.distributed is available and the process group is initialized.
    Used at cleanup to avoid destroying a non-existent PG.
    """
    if not dist.is_available():
        return False
    return dist.is_initialized()
|
| 129 |
+
|
| 130 |
+
def get_dist_info():
    """Return (ddp, rank, local_rank, world_size): values from the torchrun
    environment when present, else single-process defaults (False, 0, 0, 1)."""
    if not is_ddp_requested():
        return False, 0, 0, 1
    # We rely on torchrun's env to decide if we SHOULD init.
    # (Initialization itself happens in compute init.)
    assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
    return (
        True,
        int(os.environ['RANK']),
        int(os.environ['LOCAL_RANK']),
        int(os.environ['WORLD_SIZE']),
    )
|
| 141 |
+
|
| 142 |
+
def autodetect_device_type():
    """Pick the best available device type: prefer cuda, then mps, else cpu."""
    if torch.cuda.is_available():
        chosen = "cuda"
    elif torch.backends.mps.is_available():
        chosen = "mps"
    else:
        chosen = "cpu"
    print0(f"Autodetected device type: {chosen}")
    return chosen
|
| 152 |
+
|
| 153 |
+
def compute_init(device_type="cuda"): # cuda|cpu|mps
    """Basic initialization that we keep doing over and over, so make common.

    Seeds global RNGs, enables TF32 matmuls on CUDA, and (when launched via
    torchrun with CUDA) initializes the NCCL process group.

    Returns: (ddp_active, ddp_rank, ddp_local_rank, ddp_world_size, device)
    """

    assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
    if device_type == "cuda":
        assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
    if device_type == "mps":
        assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"

    # Reproducibility
    # Note that we set the global seeds here, but most of the code uses explicit rng objects.
    # The only place where global rng might be used is nn.Module initialization of the model weights.
    torch.manual_seed(42)
    if device_type == "cuda":
        torch.cuda.manual_seed(42)
    # skipping full reproducibility for now, possibly investigate slowdown later
    # torch.use_deterministic_algorithms(True)

    # Precision
    if device_type == "cuda":
        torch.backends.cuda.matmul.allow_tf32 = True # uses tf32 instead of fp32 for matmuls

    # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
    # NOTE(review): this local name shadows the module-level is_ddp_requested() function
    is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    if is_ddp_requested and device_type == "cuda":
        device = torch.device("cuda", ddp_local_rank)
        torch.cuda.set_device(device) # make "cuda" default to this device
        dist.init_process_group(backend="nccl", device_id=device)
        dist.barrier()
    else:
        # single-process path, or DDP requested without CUDA: plain device, no PG
        device = torch.device(device_type) # mps|cpu

    if ddp_rank == 0:
        logger.info(f"Distributed world size: {ddp_world_size}")

    return is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size, device
|
| 189 |
+
|
| 190 |
+
def compute_cleanup():
    """Companion to compute_init: tear down the process group (if any) before exit."""
    if not is_ddp_initialized():
        return
    dist.destroy_process_group()
|
| 194 |
+
|
| 195 |
+
class DummyWandb:
    """Drop-in stand-in for a wandb run when wandb logging is disabled."""

    def __init__(self):
        pass

    def log(self, *args, **kwargs):
        # Accept and ignore anything a real wandb run would log.
        pass

    def finish(self):
        # Nothing to flush or close.
        pass
|
| 203 |
+
|
| 204 |
+
# hardcoded BF16 peak flops for various GPUs
|
| 205 |
+
# inspired by torchtitan: https://github.com/pytorch/torchtitan/blob/main/torchtitan/tools/utils.py
|
| 206 |
+
# and PR: https://github.com/karpathy/nanochat/pull/147
|
| 207 |
+
def get_peak_flops(device_name: str) -> float:
    """Return the BF16 peak FLOPS for an accelerator, matched by lowercase
    substring of its name.

    The checks run most-specific first (e.g. "gb200" before "b200", "l40s"
    before "l4"), so their order is significant. Unknown devices return inf
    so a derived MFU shows as 0% rather than a misleading number.
    """
    name = device_name.lower()

    # --- NVIDIA Blackwell ---
    if "gb200" in name or "grace blackwell" in name:
        return 2.5e15
    if "b200" in name:
        return 2.25e15
    if "b100" in name:
        return 1.8e15

    # --- NVIDIA Hopper (H100/H200/H800) ---
    if "h200" in name:
        # NVL / PCIe variants vs the SXM part
        return 836e12 if ("nvl" in name or "pcie" in name) else 989e12
    if "h100" in name:
        if "nvl" in name:
            return 835e12
        return 756e12 if "pcie" in name else 989e12  # PCIe vs SXM
    if "h800" in name:
        return 989e12 if "nvl" in name else 756e12  # NVL vs PCIe

    # --- NVIDIA Ampere data center ---
    if "a100" in name or "a800" in name:
        return 312e12
    if "a40" in name:
        return 149.7e12
    if "a30" in name:
        return 165e12

    # --- NVIDIA Ada data center ---
    if "l40s" in name or "l40-s" in name or "l40 s" in name:
        return 362e12
    if "l4" in name:
        return 121e12

    # --- AMD CDNA accelerators ---
    if "mi355" in name:
        return 2.5e15
    if "mi325" in name or "mi300x" in name:
        return 1.3074e15
    if "mi300a" in name:
        return 980.6e12
    if "mi250x" in name:
        return 383e12
    if "mi250" in name:
        return 362.1e12

    # --- Intel ---
    if "data center gpu max 1550" in name:
        # Ponte Vecchio (PVC) - dynamic based on compute units
        max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
        return 512 * max_comp_units * 1300 * 10**6

    # --- Consumer RTX (for hobbyists) ---
    if "5090" in name:
        return 209.5e12
    if "4090" in name:
        return 165.2e12
    if "3090" in name:
        return 71e12

    # Unknown GPU - return inf so MFU shows as 0% rather than a wrong guess
    logger.warning(f"Peak flops undefined for: {device_name}, MFU will show as 0%")
    return float('inf')
|
nanochat/core_eval.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Functions for evaluating the CORE metric, as described in the DCLM paper.
|
| 3 |
+
https://arxiv.org/abs/2406.11794
|
| 4 |
+
|
| 5 |
+
TODOs:
|
| 6 |
+
- All tasks ~match except for squad. We get 31% reference is 37%. Figure out why.
|
| 7 |
+
"""
|
| 8 |
+
import random
|
| 9 |
+
|
| 10 |
+
from jinja2 import Template
|
| 11 |
+
import torch
|
| 12 |
+
import torch.distributed as dist
|
| 13 |
+
|
| 14 |
+
# -----------------------------------------------------------------------------
|
| 15 |
+
# Prompt rendering utilities
|
| 16 |
+
|
| 17 |
+
def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None):
    """Render one complete prompt per answer choice of a multiple-choice question."""
    template = Template("""
{%- for example in fewshot_examples -%}
{{ example.query }}{{ continuation_delimiter }}{{ example.choices[example.gold] }}

{% endfor -%}
{{ item.query }}{{ continuation_delimiter }}{{ choice }}""".strip())
    render_ctx = dict(
        fewshot_examples=fewshot_examples or [],
        continuation_delimiter=continuation_delimiter,
        item=item,
    )
    # One rendered prompt per candidate answer choice
    return [template.render(choice=c, **render_ctx) for c in item['choices']]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def render_prompts_schema(item, continuation_delimiter, fewshot_examples=None):
    """Render one complete prompt per context option of a schema question."""
    template = Template("""
{%- for example in fewshot_examples -%}
{{ example.context_options[example.gold] }}{{ continuation_delimiter }}{{ example.continuation }}

{% endfor -%}
{{ context }}{{ continuation_delimiter }}{{ item.continuation }}""".strip())
    render_ctx = dict(
        fewshot_examples=fewshot_examples or [],
        continuation_delimiter=continuation_delimiter,
        item=item,
    )
    # One rendered prompt per candidate context option
    return [template.render(context=option, **render_ctx)
            for option in item['context_options']]
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def render_prompts_lm(item, continuation_delimiter, fewshot_examples=None):
    """
    Render [prompt_without_continuation, prompt_with_continuation] for a
    language modeling task.

    Contexts are trimmed inside the template because some datasets carry
    trailing whitespace we don't want. prompt_without is additionally stripped
    so that it forms a clean token-space prefix of prompt_with (a trailing
    space would otherwise get absorbed into the next token of prompt_with).
    """
    template = Template("""
{%- for example in fewshot_examples -%}
{{ example.context | trim }}{{ continuation_delimiter }}{{ example.continuation }}

{% endfor -%}
{{ item.context | trim }}{{ continuation_delimiter }}{% if include_continuation %}{{ item.continuation }}{% endif %}""".strip())
    render_ctx = dict(
        fewshot_examples=fewshot_examples or [],
        continuation_delimiter=continuation_delimiter,
        item=item,
    )
    prompt_without = template.render(include_continuation=False, **render_ctx).strip()
    prompt_with = template.render(include_continuation=True, **render_ctx)
    return [prompt_without, prompt_with]
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def find_common_length(token_sequences, direction='left'):
    """
    Return the length of the common prefix (direction='left') or common
    suffix (direction='right') shared by all token sequences.
    """
    shortest = min(len(seq) for seq in token_sequences)
    # Positions to compare, walking inward from the chosen end
    positions = {
        'left': range(shortest),
        'right': range(-1, -shortest - 1, -1),
    }[direction]
    reference = token_sequences[0]
    for count, pos in enumerate(positions):
        if any(seq[pos] != reference[pos] for seq in token_sequences):
            return count  # first position of disagreement
    return shortest  # all compared positions agree
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def stack_sequences(tokens, pad_token_id):
    """Right-pad a list of token sequences to equal length and stack into a (B, T) long tensor."""
    longest = max(len(seq) for seq in tokens)
    padded = [list(seq) + [pad_token_id] * (longest - len(seq)) for seq in tokens]
    return torch.tensor(padded, dtype=torch.long)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def batch_sequences_mc(tokenizer, prompts):
    """Tokenize MC prompts: all share a common prefix (the context), continuations differ."""
    tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
    # The continuation of every option begins where the shared prefix ends
    shared_prefix = find_common_length(tokens, direction='left')
    starts = [shared_prefix] * len(prompts)
    ends = [len(seq) for seq in tokens]
    return tokens, starts, ends
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def batch_sequences_schema(tokenizer, prompts):
    """Tokenize schema prompts: contexts vary, but all share a common suffix (the continuation)."""
    tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
    # The shared continuation occupies the last suffix_length tokens of every sequence
    suffix_length = find_common_length(tokens, direction='right')
    ends = [len(seq) for seq in tokens]
    starts = [e - suffix_length for e in ends]
    return tokens, starts, ends
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def batch_sequences_lm(tokenizer, prompts):
    """Tokenize the (without, with) LM prompt pair; only the 'with' sequence is forwarded (batch of 1)."""
    tokens_without, tokens_with = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
    start_idx, end_idx = len(tokens_without), len(tokens_with)
    # Sanity: the continuation must extend the context in token space
    assert start_idx < end_idx, "prompt without is supposed to be a prefix of prompt with"
    assert tokens_without == tokens_with[:start_idx], "prompt without is supposed to be a prefix of prompt with"
    return [tokens_with], [start_idx], [end_idx]
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
@torch.no_grad()
def forward_model(model, input_ids):
    """
    Forward a BxT batch of token ids through the model.

    Returns (losses, predictions), both BxT: per-position cross-entropy against
    the next token, and per-position argmax token ids. losses[:, -1] is nan
    since the final position has no autoregressive target.
    """
    B, T = input_ids.size()
    logits = model(input_ids)
    # Targets are the inputs shifted left by one (next-token prediction)
    targets = torch.roll(input_ids, shifts=-1, dims=1)
    flat = torch.nn.functional.cross_entropy(
        logits.view(B * T, -1),
        targets.view(B * T),
        reduction='none',
    )
    losses = flat.view(B, T)
    losses[:, -1] = float('nan')  # no target beyond the last position
    predictions = logits.argmax(dim=-1)
    return losses, predictions
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
@torch.no_grad()
def evaluate_example(idx, model, tokenizer, data, device, task_meta):
    """Evaluate a single example; return True if the model gets it right, False otherwise."""
    item = data[idx]
    task_type = task_meta['task_type']
    num_fewshot = task_meta['num_fewshot']
    continuation_delimiter = task_meta['continuation_delimiter']

    # Deterministically sample few-shot examples, never including the item itself
    fewshot_examples = []
    if num_fewshot > 0:
        rng = random.Random(1234 + idx)
        candidates = [i for i in range(len(data)) if i != idx]
        fewshot_examples = [data[i] for i in rng.sample(candidates, num_fewshot)]

    # Dispatch prompt rendering + tokenization/batching on task type
    dispatch = {
        'multiple_choice': (render_prompts_mc, batch_sequences_mc),
        'schema': (render_prompts_schema, batch_sequences_schema),
        'language_modeling': (render_prompts_lm, batch_sequences_lm),
    }
    if task_type not in dispatch:
        raise ValueError(f"Unsupported task type: {task_type}")
    render_fn, batch_fn = dispatch[task_type]
    prompts = render_fn(item, continuation_delimiter, fewshot_examples)
    tokens, start_idxs, end_idxs = batch_fn(tokenizer, prompts)

    # Models with a bounded context (e.g. GPT-2) need left-truncation + index shifts
    if getattr(model, 'max_seq_len', None) is not None:
        max_tokens = model.max_seq_len
        kept_tokens, kept_starts, kept_ends = [], [], []
        for seq, s, e in zip(tokens, start_idxs, end_idxs):
            overflow = len(seq) - max_tokens
            if overflow > 0:
                assert s - overflow >= 0, "this should never happen right?"
                assert e - overflow >= 0, "this should never happen right?"
                kept_tokens.append(seq[-max_tokens:])  # keep the last max_tokens tokens
                kept_starts.append(s - overflow)       # shift the indices down accordingly
                kept_ends.append(e - overflow)
            else:
                kept_tokens.append(seq)
                kept_starts.append(s)
                kept_ends.append(e)
        tokens, start_idxs, end_idxs = kept_tokens, kept_starts, kept_ends

    # Stack into one batch; BOS doubles as the padding token (padded tail is ignored)
    input_ids = stack_sequences(tokens, tokenizer.get_bos_token_id()).to(device)

    # Per-position autoregressive losses and argmax predictions
    losses, predictions = forward_model(model, input_ids)

    if task_type == 'language_modeling':
        # Batch size 1: correct iff every continuation token is predicted exactly.
        # predictions[t] is the model's guess for input_ids[t+1].
        si, ei = start_idxs[0], end_idxs[0]
        predicted = predictions[0, si-1:ei-1]
        actual = input_ids[0, si:ei]
        is_correct = torch.all(predicted == actual).item()
    else:
        # MC/schema: the option whose continuation has the lowest mean loss wins
        mean_losses = [losses[i, si-1:ei-1].mean().item()
                       for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))]
        is_correct = mean_losses.index(min(mean_losses)) == item['gold']

    return is_correct
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def evaluate_task(model, tokenizer, data, device, task_meta):
    """
    Evaluate one task across all examples, averaging per-example correctness.
    Under torchrun, examples are strided across ranks and results all-reduced.
    """
    is_dist = dist.is_initialized()
    rank = dist.get_rank() if is_dist else 0
    world_size = dist.get_world_size() if is_dist else 1
    correct = torch.zeros(len(data), dtype=torch.float32, device=device)
    # Each rank evaluates its strided subset of examples
    for idx in range(rank, len(data), world_size):
        correct[idx] = float(evaluate_example(idx, model, tokenizer, data, device, task_meta))
    # Merge per-rank results (each index was filled by exactly one rank)
    if world_size > 1:
        dist.barrier()
        dist.all_reduce(correct, op=dist.ReduceOp.SUM)
    return correct.mean().item()
|
nanochat/dataloader.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Distributed dataloaders for pretraining.
|
| 3 |
+
|
| 4 |
+
Two implementations are provided:
|
| 5 |
+
|
| 6 |
+
1. Original (tokenizing_distributed_data_loader):
|
| 7 |
+
- Streams tokens into a flat buffer, reshapes to (B, T)
|
| 8 |
+
- Rows may start mid-document (no guaranteed BOS at position 0)
|
| 9 |
+
- 100% token utilization, simple and efficient
|
| 10 |
+
|
| 11 |
+
2. BOS-aligned bestfit (tokenizing_distributed_data_loader_bos_bestfit):
|
| 12 |
+
- Every row starts with BOS token
|
| 13 |
+
- Documents packed using best-fit algorithm to minimize cropping
|
| 14 |
+
- When no document fits remaining space, crops a document to fill exactly
|
| 15 |
+
- 100% utilization (no padding), ~35% tokens cropped at T=2048
|
| 16 |
+
|
| 17 |
+
The tradeoff: BOS-aligned loses ~35% of tokens to cropping, but ensures that
|
| 18 |
+
there are fewer "confusing" tokens in the train/val batches as every token can
|
| 19 |
+
now attend back to the BOS token and sees the full context of the document.
|
| 20 |
+
(2) is the new default if you have enough data.
|
| 21 |
+
Fallback to (1) if you have very limited data AND long documents.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
import pyarrow.parquet as pq
|
| 26 |
+
|
| 27 |
+
from nanochat.common import get_dist_info
|
| 28 |
+
from nanochat.dataset import list_parquet_files
|
| 29 |
+
|
| 30 |
+
def _document_batches(split, resume_state_dict, tokenizer_batch_size):
    """
    Infinite iterator over document batches (lists of text strings) from parquet files.

    Handles DDP sharding (row groups strided by rank) and approximate resume.
    Each yield is (text_batch, (pq_idx, rg_idx, epoch)) where the indices track
    the read position and epoch counts dataset passes (starting at 1).
    """
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()

    paths = list_parquet_files()
    assert len(paths) != 0, "No dataset parquet files found, did you run dataset.py?"
    # Last file is reserved for validation
    paths = paths[:-1] if split == "train" else paths[-1:]

    if resume_state_dict is None:
        resume_file_idx, resume_group_idx, epoch = 0, None, 1
    else:
        resume_file_idx = resume_state_dict["pq_idx"]
        resume_group_idx = resume_state_dict["rg_idx"]
        epoch = resume_state_dict.get("epoch", 1)

    first_pass = True
    while True:  # cycle through the dataset forever (multi-epoch)
        file_idx = resume_file_idx if first_pass else 0
        while file_idx < len(paths):
            pf = pq.ParquetFile(paths[file_idx])
            resuming_here = first_pass and (resume_group_idx is not None) and (file_idx == resume_file_idx)
            if resuming_here:
                # Advance one full stride past the saved row group so no data repeats
                group_idx = (resume_group_idx // ddp_world_size + 1) * ddp_world_size + ddp_rank
                if group_idx >= pf.num_row_groups:
                    file_idx += 1
                    continue
                resume_group_idx = None  # resume adjustment applies only once
            else:
                group_idx = ddp_rank
            while group_idx < pf.num_row_groups:
                docs = pf.read_row_group(group_idx).column('text').to_pylist()
                for start in range(0, len(docs), tokenizer_batch_size):
                    yield docs[start:start + tokenizer_batch_size], (file_idx, group_idx, epoch)
                group_idx += ddp_world_size
            file_idx += 1
        first_pass = False
        epoch += 1
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def tokenizing_distributed_data_loader_with_state(tokenizer, B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
    """
    Stream pretraining text from parquet files, tokenize, and yield
    (inputs, targets, state) training batches.

    Tokens are streamed into a flat buffer and reshaped to (B, T), so rows may
    start mid-document (no guaranteed BOS at position 0) but utilization is 100%.
    Supports approximate resume via resume_state_dict.
    """
    assert split in ["train", "val"], "split must be 'train' or 'val'"

    batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
    window = B * T  # tokens consumed per yielded batch (+1 extra read for the final target)
    bos_token = tokenizer.get_bos_token_id()
    buffer = []
    pq_idx, rg_idx, epoch = 0, 0, 1
    on_cuda = device == "cuda"

    while True:
        # Accumulate at least B*T+1 tokens
        while len(buffer) <= window:
            docs, (pq_idx, rg_idx, epoch) = next(batches)
            for toks in tokenizer.encode(docs, prepend=bos_token, num_threads=tokenizer_threads):
                buffer.extend(toks)
        chunk = buffer[:window + 1]  # the +1 supplies the target of the last position
        buffer = buffer[window:]     # slide forward by exactly B*T tokens

        # Package as inputs/targets shifted by one position
        scratch = torch.tensor(chunk, dtype=torch.long, pin_memory=on_cuda)
        inputs = scratch[:-1].view(B, T).to(device=device, non_blocking=on_cuda)
        targets = scratch[1:].view(B, T).to(device=device, non_blocking=on_cuda)
        yield inputs, targets, {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def tokenizing_distributed_data_loader(*args, **kwargs):
    """Same as the stateful loader, but yields only (inputs, targets)."""
    stream = tokenizing_distributed_data_loader_with_state(*args, **kwargs)
    yield from ((inputs, targets) for inputs, targets, _ in stream)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def tokenizing_distributed_data_loader_with_state_bos_bestfit(
    tokenizer, B, T, split,
    tokenizer_threads=4, tokenizer_batch_size=128,
    device="cuda", resume_state_dict=None,
    buffer_size=1000
):
    """
    BOS-aligned dataloader with best-fit cropping.

    Every row starts with a BOS token and is packed from whole documents.
    For each row we repeatedly take the largest buffered document that fits the
    remaining space; when nothing fits, the shortest buffered document is
    cropped to fill the row exactly. This gives 100% utilization (no padding)
    at the cost of discarding roughly 35% of tokens to cropping at T=2048.
    """
    assert split in ["train", "val"], "split must be 'train' or 'val'"

    row_capacity = T + 1  # +1 so inputs/targets can be shifted views of one row
    batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
    bos_token = tokenizer.get_bos_token_id()
    doc_buffer = []
    pq_idx, rg_idx, epoch = 0, 0, 1
    use_cuda = device == "cuda"

    def refill_buffer():
        # Pull one document batch, tokenize, and append to the packing buffer
        nonlocal pq_idx, rg_idx, epoch
        docs, (pq_idx, rg_idx, epoch) = next(batches)
        doc_buffer.extend(tokenizer.encode(docs, prepend=bos_token, num_threads=tokenizer_threads))

    while True:
        rows = []
        for _ in range(B):
            row = []
            while len(row) < row_capacity:
                # Keep the candidate pool full so best-fit has options
                while len(doc_buffer) < buffer_size:
                    refill_buffer()
                remaining = row_capacity - len(row)

                # Largest buffered doc that fits entirely (first one wins ties)
                best_idx, best_len = -1, 0
                for i, doc in enumerate(doc_buffer):
                    if best_len < len(doc) <= remaining:
                        best_idx, best_len = i, len(doc)

                if best_idx >= 0:
                    row.extend(doc_buffer.pop(best_idx))
                else:
                    # Nothing fits: crop the shortest doc to fill the row exactly
                    shortest = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
                    row.extend(doc_buffer.pop(shortest)[:remaining])

            rows.append(row[:row_capacity])

        batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
        inputs = batch_tensor[:, :-1].to(device=device, non_blocking=use_cuda)
        targets = batch_tensor[:, 1:].to(device=device, non_blocking=use_cuda)

        yield inputs, targets, {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def tokenizing_distributed_data_loader_bos_bestfit(*args, **kwargs):
    """Same as the stateful best-fit loader, but yields only (inputs, targets)."""
    stream = tokenizing_distributed_data_loader_with_state_bos_bestfit(*args, **kwargs)
    yield from ((inputs, targets) for inputs, targets, _ in stream)
|
nanochat/dataset.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The base/pretraining dataset is a set of parquet files.
|
| 3 |
+
This file contains utilities for:
|
| 4 |
+
- iterating over the parquet files and yielding documents from it
|
| 5 |
+
- download the files on demand if they are not on disk
|
| 6 |
+
|
| 7 |
+
For details of how the dataset was prepared, see `repackage_data_reference.py`.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import argparse
|
| 12 |
+
import time
|
| 13 |
+
import requests
|
| 14 |
+
import pyarrow.parquet as pq
|
| 15 |
+
from multiprocessing import Pool
|
| 16 |
+
|
| 17 |
+
from nanochat.common import get_base_dir
|
| 18 |
+
|
| 19 |
+
# -----------------------------------------------------------------------------
|
| 20 |
+
# The specifics of the current pretraining dataset
|
| 21 |
+
|
| 22 |
+
# The URL on the internet where the data is hosted and downloaded from on demand
|
| 23 |
+
BASE_URL = "https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main"
|
| 24 |
+
MAX_SHARD = 1822 # the last datashard is shard_01822.parquet
|
| 25 |
+
index_to_filename = lambda index: f"shard_{index:05d}.parquet" # format of the filenames
|
| 26 |
+
base_dir = get_base_dir()
|
| 27 |
+
DATA_DIR = os.path.join(base_dir, "base_data")
|
| 28 |
+
os.makedirs(DATA_DIR, exist_ok=True)
|
| 29 |
+
|
| 30 |
+
# -----------------------------------------------------------------------------
|
| 31 |
+
# These functions are useful utilities to other modules, can/should be imported
|
| 32 |
+
|
| 33 |
+
def list_parquet_files(data_dir=None):
    """Return sorted full paths to all .parquet files in data_dir (defaults to DATA_DIR).

    In-progress downloads are written with a ".parquet.tmp" suffix; those never
    end with ".parquet", so the suffix check alone excludes them. (The previous
    extra `not f.endswith('.tmp')` clause was always true once the first check
    passed, i.e. dead code — removed.)
    """
    if data_dir is None:
        data_dir = DATA_DIR
    filenames = sorted(f for f in os.listdir(data_dir) if f.endswith('.parquet'))
    return [os.path.join(data_dir, f) for f in filenames]
|
| 42 |
+
|
| 43 |
+
def parquets_iter_batched(split, start=0, step=1):
    """
    Iterate the dataset in batches of underlying row groups (efficient reads).

    - split: "train" or "val"; the last parquet file is the val split.
    - start/step: stride over row groups, useful for DDP (start=rank, step=world_size).
    Yields lists of document strings.
    """
    assert split in ["train", "val"], "split must be 'train' or 'val'"
    paths = list_parquet_files()
    paths = paths[:-1] if split == "train" else paths[-1:]
    for path in paths:
        pf = pq.ParquetFile(path)
        for group_idx in range(start, pf.num_row_groups, step):
            yield pf.read_row_group(group_idx).column('text').to_pylist()
|
| 58 |
+
|
| 59 |
+
# -----------------------------------------------------------------------------
|
| 60 |
+
def download_single_file(index):
    """Download one dataset shard into DATA_DIR, retrying with exponential backoff.

    Returns True on success (or if the shard already exists on disk),
    False after all attempts fail. Writes to a ".tmp" file first and renames
    into place so readers never observe a partially downloaded shard.
    """
    # Construct the local filepath for this shard and skip if it already exists
    filename = index_to_filename(index)
    filepath = os.path.join(DATA_DIR, filename)
    if os.path.exists(filepath):
        print(f"Skipping {filepath} (already exists)")
        return True

    # Bug fix: the remote URL must include the shard filename
    # (it previously embedded a literal "(unknown)" placeholder, so every
    # request would have 404'd).
    url = f"{BASE_URL}/{filename}"
    print(f"Downloading {filename}...")

    temp_path = filepath + ".tmp"
    max_attempts = 5
    for attempt in range(1, max_attempts + 1):
        try:
            response = requests.get(url, stream=True, timeout=30)
            response.raise_for_status()
            # Stream to a temporary file first
            with open(temp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):  # 1MB chunks
                    if chunk:
                        f.write(chunk)
            # Publish atomically by renaming into the final location
            os.rename(temp_path, filepath)
            print(f"Successfully downloaded {filename}")
            return True

        except (requests.RequestException, IOError) as e:
            print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
            # Clean up any partial files (narrowed from a bare except)
            for path in (temp_path, filepath):
                try:
                    os.remove(path)
                except OSError:
                    pass
            if attempt < max_attempts:
                # Exponential backoff: 2, 4, 8, ... seconds
                wait_time = 2 ** attempt
                print(f"Waiting {wait_time} seconds before retry...")
                time.sleep(wait_time)
            else:
                print(f"Failed to download {filename} after {max_attempts} attempts")
                return False

    return False
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Download FineWeb-Edu 100BT dataset shards")
    parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of shards to download (default: -1), -1 = disable")
    parser.add_argument("-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)")
    args = parser.parse_args()

    # -1 means "all shards"; otherwise cap at the number of shards that exist
    total_shards = MAX_SHARD + 1
    num = total_shards if args.num_files == -1 else min(args.num_files, total_shards)
    ids_to_download = list(range(num))

    print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
    print(f"Target directory: {DATA_DIR}")
    print()
    with Pool(processes=args.num_workers) as pool:
        results = pool.map(download_single_file, ids_to_download)

    # Report how many shards landed successfully
    successful = sum(1 for success in results if success)
    print(f"Done! Downloaded: {successful}/{len(ids_to_download)} shards to {DATA_DIR}")
|
nanochat/distill_loss.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def compute_distillation_loss(
    student_logits,
    teacher_logits,
    temperature=1.0,
    reduction='mean'
):
    """
    Compute KL divergence loss between student and teacher logits.

    Args:
        student_logits: (B, T, vocab_size) logits from student model
        teacher_logits: (B, T, vocab_size) logits from teacher model
        temperature: Temperature for softmax (higher = softer distribution)
        reduction: 'mean' or 'sum' or 'none'

    Returns:
        loss: Scalar for 'mean'/'sum', or a (B, T) per-token tensor for 'none'
    """
    # Apply temperature scaling to both distributions.
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)

    # KL divergence: we use KL(teacher || student).
    # F.kl_div with input=log(student), target=teacher, log_target=False computes
    # sum(target * (log(target) - input)) = KL(teacher || student).
    kl_loss = F.kl_div(
        student_log_probs,
        teacher_probs,
        reduction='none',
        log_target=False
    )  # (B, T, vocab_size)

    # Sum over the vocab dimension exactly once -> per-token KL of shape (B, T).
    # BUGFIX: the previous version summed over dim=-1 a second time, collapsing
    # (B, T) to (B,), which inflated 'mean' by a factor of T and broke the
    # (B, T) shape contract for reduction='none'.
    kl_loss = kl_loss.sum(dim=-1)  # (B, T)

    # Scale by temperature^2 (standard in the distillation literature,
    # Hinton et al. 2015), so gradients keep a consistent magnitude.
    kl_loss = kl_loss * (temperature ** 2)

    if reduction == 'mean':
        return kl_loss.mean()
    elif reduction == 'sum':
        return kl_loss.sum()
    else:
        return kl_loss
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def compute_combined_loss(
    student_logits,
    teacher_logits,
    targets,
    temperature=1.0,
    alpha=0.5,
    ignore_index=-1,
    reduction='mean'
):
    """
    Blend the distillation (soft-target) loss with standard cross-entropy.

    Args:
        student_logits: (B, T, vocab_size) logits from the student model
        teacher_logits: (B, T, vocab_size) logits from the teacher model
        targets: (B, T) ground-truth token ids
        temperature: softmax temperature used for distillation
        alpha: weight on the distillation term (1 - alpha on cross-entropy)
        ignore_index: target value excluded from the cross-entropy loss
        reduction: 'mean', 'sum' or 'none'

    Returns:
        (total_loss, distill_loss, ce_loss) — the blended loss plus both components.
    """
    # Soft-target component: KL between teacher and student distributions.
    soft_loss = compute_distillation_loss(
        student_logits,
        teacher_logits,
        temperature=temperature,
        reduction=reduction
    )

    # Hard-target component: ordinary next-token cross-entropy.
    flat_logits = student_logits.view(-1, student_logits.size(-1))
    hard_loss = F.cross_entropy(
        flat_logits,
        targets.view(-1),
        ignore_index=ignore_index,
        reduction=reduction
    )

    # With reduction='none' the two components have different shapes:
    # soft_loss is (B, T) while cross_entropy returns (B*T,), so reshape
    # the latter before blending element-wise.
    if reduction == 'none':
        hard_loss = hard_loss.view(student_logits.shape[:2])

    total_loss = alpha * soft_loss + (1 - alpha) * hard_loss
    return total_loss, soft_loss, hard_loss
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def compute_multi_token_loss(multi_token_logits, targets, ignore_index=-1, reduction='mean'):
    """Average cross-entropy over the auxiliary multi-token heads (t+2, t+3, t+4, ...).

    Each entry of multi_token_logits maps "head_<k>" -> (B, T, V) logits predicting
    the token k steps ahead, so head k is scored against targets shifted by k-1.
    Returns a zero tensor when no head has enough context to contribute.
    """
    per_head_losses = []
    for head_name, head_logits in multi_token_logits.items():
        step = int(head_name.split('_')[1])  # "head_2" -> 2
        if targets.size(1) < step:
            continue  # sequence too short for this head's horizon
        # Align predictions with the future tokens they are supposed to hit.
        aligned_targets = targets[:, step - 1:]
        aligned_logits = head_logits[:, :targets.size(1) - step + 1, :]
        if aligned_targets.numel() == 0:
            continue
        per_head_losses.append(F.cross_entropy(
            aligned_logits.reshape(-1, aligned_logits.size(-1)),
            aligned_targets.reshape(-1),
            ignore_index=ignore_index,
            reduction=reduction,
        ))
    if not per_head_losses:
        return torch.tensor(0.0, device=targets.device)
    return sum(per_head_losses) / len(per_head_losses)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def compute_draft_loss(student_model, x, teacher_logits, temperature=1.0):
    """Train draft head to predict multiple future tokens.

    Runs the student transformer manually over input ids `x`, feeds the final
    position's hidden state into `student_model.draft_head`, and distills each
    draft prediction against the teacher's logits one step further ahead via KL.
    Returns a scalar tensor (0.0 when there is no draft head or T <= 1).
    """
    if student_model.draft_head is None:
        return torch.tensor(0.0, device=x.device)

    # Re-run the student forward pass by hand so we can grab the last hidden state.
    # NOTE(review): this mirrors the block loop in nanochat/gpt.py — keep in sync.
    from nanochat.gpt import norm
    hidden = student_model.transformer.wte(x)
    hidden = norm(hidden)
    x0 = hidden  # normalized embeddings, re-injected into every layer below

    for i, block in enumerate(student_model.transformer.h):
        # Learned per-layer mix of the residual stream and the original embeddings.
        hidden = student_model.resid_lambdas[i] * hidden + student_model.x0_lambdas[i] * x0
        # Value embeddings exist only for some layers (keyed by layer index).
        ve = student_model.value_embeds[str(i)](x) if str(i) in student_model.value_embeds else None
        # Rotary position tables truncated to the current sequence length.
        cos_sin = student_model.cos[:, :x.size(1)], student_model.sin[:, :x.size(1)]
        hidden = block(hidden, ve, cos_sin, student_model.window_sizes[i], None)

    hidden = norm(hidden)
    last_hidden = hidden[:, -1, :]  # (B, n_embd)

    # Draft head predicts the next `draft_n` tokens from the final position.
    draft_logits = student_model.draft_head(last_hidden)  # (B, draft_n, vocab)

    # Match each draft slot i with the teacher's distribution at position i+1.
    # NOTE(review): teacher positions are absolute (from the sequence start),
    # while draft slots extend past the last position — confirm alignment intent.
    B, T, V = teacher_logits.shape
    draft_n = draft_logits.shape[1]

    total_loss = 0.0
    for i in range(min(draft_n, T-1)):
        draft_pred = draft_logits[:, i, :]
        teacher_future = teacher_logits[:, i+1, :]

        # Temperature-softened KL(teacher || student) per draft slot.
        student_log_probs = F.log_softmax(draft_pred / temperature, dim=-1)
        teacher_probs = F.softmax(teacher_future / temperature, dim=-1)

        kl = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean', log_target=False)
        total_loss += kl

    # Average over the slots actually scored; degenerate sequences yield 0.0.
    return total_loss / min(draft_n, T-1) if T > 1 else torch.tensor(0.0, device=x.device)
|
| 176 |
+
|
nanochat/engine.py
ADDED
|
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Engine for efficient inference of our models.
|
| 3 |
+
|
| 4 |
+
Everything works around token sequences:
|
| 5 |
+
- The user can send token sequences to the engine
|
| 6 |
+
- The engine returns the next token
|
| 7 |
+
|
| 8 |
+
Notes:
|
| 9 |
+
- The engine knows nothing about tokenization, it's purely token id sequences.
|
| 10 |
+
|
| 11 |
+
The whole thing is made as efficient as possible.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
import signal
|
| 17 |
+
import warnings
|
| 18 |
+
from contextlib import contextmanager
|
| 19 |
+
from collections import deque
|
| 20 |
+
from nanochat.common import compute_init, autodetect_device_type
|
| 21 |
+
from nanochat.checkpoint_manager import load_model
|
| 22 |
+
from contextlib import nullcontext
|
| 23 |
+
|
| 24 |
+
# -----------------------------------------------------------------------------
|
| 25 |
+
# Calculator tool helpers
|
| 26 |
+
@contextmanager
def timeout(duration, formula):
    """Context manager that raises if the body runs longer than `duration` seconds.

    Uses SIGALRM, so it only works on Unix and only in the main thread.
    `formula` is included in the exception message for debugging.
    """
    def timeout_handler(signum, frame):
        raise Exception(f"'{formula}': timed out after {duration} seconds")

    signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(duration)
    try:
        yield
    finally:
        # BUGFIX: always cancel the pending alarm, even when the body raises.
        # Previously an exception escaping the body left the alarm armed, so a
        # stray SIGALRM could fire later in unrelated code.
        signal.alarm(0)
|
| 35 |
+
|
| 36 |
+
def eval_with_timeout(formula, max_time=3):
    """Evaluate `formula` with empty builtins, giving up after `max_time` seconds.

    Returns the evaluated value, or None on any failure (syntax error, timeout,
    runtime error) — wrong calculator usage is deliberately tolerated.
    """
    try:
        with timeout(max_time, formula), warnings.catch_warnings():
            warnings.simplefilter("ignore", SyntaxWarning)
            return eval(formula, {"__builtins__": {}}, {})
    except Exception:
        # Make sure no alarm stays armed, then swallow the failure.
        signal.alarm(0)
        return None
|
| 46 |
+
|
| 47 |
+
def use_calculator(expr):
    """
    Evaluate a Python expression safely.
    Supports both math expressions and string operations like .count()
    """
    # Strip thousands separators so "1,000" parses as a number.
    expr = expr.replace(",", "")

    # Pure arithmetic path (original behavior): digits and basic operators only.
    if all(ch in "0123456789*+-/.() " for ch in expr):
        # Refuse the power operator to avoid huge-number DoS.
        return None if "**" in expr else eval_with_timeout(expr)

    # String-operation path: restrict to a conservative character set
    # (quotes, letters, digits, parens, dot, underscore, space).
    allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
    if any(ch not in allowed_chars for ch in expr):
        return None

    # Reject anything that smells like introspection or code execution.
    dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
                          'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
                          'getattr', 'setattr', 'delattr', 'hasattr']
    lowered = expr.lower()
    if any(pattern in lowered for pattern in dangerous_patterns):
        return None

    # Only the .count() method is whitelisted for now (can expand later).
    if '.count(' not in expr:
        return None

    # Evaluate with timeout
    return eval_with_timeout(expr)
|
| 81 |
+
|
| 82 |
+
# -----------------------------------------------------------------------------
|
| 83 |
+
class KVCache:
    """
    Pre-allocated key/value cache shaped for Flash Attention 3's
    flash_attn_with_kvcache API.

    Layout notes (vs an FA2-style cache):
    - Tensors are (B, T, H, D), not (B, H, T, D)
    - FA3 writes new keys/values into the cache in-place during attention
    - The current length of each batch row is tracked in `cache_seqlens`
    """

    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype):
        self.batch_size = batch_size
        self.max_seq_len = seq_len
        self.n_layers = num_layers
        self.n_heads = num_heads
        self.head_dim = head_dim
        # One (B, T, H, D) tensor per layer, stacked along dim 0.
        shape = (num_layers, batch_size, seq_len, num_heads, head_dim)
        self.k_cache = torch.zeros(shape, device=device, dtype=dtype)
        self.v_cache = torch.zeros(shape, device=device, dtype=dtype)
        # Per-row sequence length; FA3 requires int32 here.
        self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)

    def reset(self):
        """Forget all cached tokens (buffers are reused, not reallocated)."""
        self.cache_seqlens.zero_()

    def get_pos(self):
        """Current position; assumes every batch row is at the same length."""
        return self.cache_seqlens[0].item()

    def get_layer_cache(self, layer_idx):
        """Return (k_cache, v_cache) views for one layer."""
        return self.k_cache[layer_idx], self.v_cache[layer_idx]

    def advance(self, num_tokens):
        """Move every row's position forward by `num_tokens`."""
        self.cache_seqlens += num_tokens

    def prefill(self, other):
        """
        Seed this (empty) cache from another cache, broadcasting across batch.
        Used when we prefill once at batch=1 and then decode many samples in parallel.
        """
        assert self.get_pos() == 0, "Cannot prefill a non-empty KV cache"
        assert self.n_layers == other.n_layers and self.n_heads == other.n_heads and self.head_dim == other.head_dim
        assert self.max_seq_len >= other.max_seq_len
        pos = other.get_pos()
        # Assignment broadcasts other's batch dim (typically 1) across ours.
        self.k_cache[:, :, :pos, :, :] = other.k_cache[:, :, :pos, :, :]
        self.v_cache[:, :, :pos, :, :] = other.v_cache[:, :, :pos, :, :]
        self.cache_seqlens.fill_(pos)
|
| 133 |
+
|
| 134 |
+
# -----------------------------------------------------------------------------
|
| 135 |
+
@torch.inference_mode()
def sample_next_token(logits, rng, temperature=1.0, top_k=None):
    """Sample one next-token id per row from (B, vocab_size) logits; returns (B, 1)."""
    assert temperature >= 0.0, "temperature must be non-negative"
    # Greedy decoding: temperature 0 means deterministic argmax.
    if temperature == 0.0:
        return torch.argmax(logits, dim=-1, keepdim=True)
    use_topk = top_k is not None and top_k > 0
    if not use_topk:
        # Sample from the full temperature-scaled distribution.
        full_probs = F.softmax(logits / temperature, dim=-1)
        return torch.multinomial(full_probs, num_samples=1, generator=rng)
    # Restrict sampling to the k most likely tokens, then map the chosen
    # index within the top-k slice back to a vocabulary id.
    k = min(top_k, logits.size(-1))
    top_vals, top_idx = torch.topk(logits, k, dim=-1)
    top_probs = F.softmax(top_vals / temperature, dim=-1)
    picked = torch.multinomial(top_probs, num_samples=1, generator=rng)
    return top_idx.gather(1, picked)
|
| 152 |
+
|
| 153 |
+
# -----------------------------------------------------------------------------
|
| 154 |
+
|
| 155 |
+
class RowState:
    """Mutable per-row bookkeeping used while decoding one sample."""

    def __init__(self, current_tokens=None):
        # Full token sequence for this row so far (prompt + generated).
        self.current_tokens = current_tokens if current_tokens else []
        # Tokens queued for forced injection (e.g. calculator tool output).
        self.forced_tokens = deque()
        # Calculator tool state: are we currently inside a python block,
        # and which tokens make up the expression collected so far?
        self.in_python_block = False
        self.python_expr_tokens = []
        # Set once the row emits a terminal token.
        self.completed = False
|
| 163 |
+
|
| 164 |
+
class Engine:
    """Batched, KV-cached token generation with inline calculator tool use.

    Works purely on token id sequences; the tokenizer is only consulted for
    special-token ids and for decoding/encoding calculator expressions.
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer # needed for tool use

    @torch.inference_mode()
    def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
        """Stream generation: prefill the prompt once, then decode `num_samples`
        rows in parallel, yielding (token_column, token_masks) per step.

        token_column: one token id per row; token_masks: 1 if sampled, 0 if
        force-injected (calculator output).
        """
        assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
        device = self.model.get_device()
        # NOTE: setting the dtype here and in this way is an ugly hack.
        # Currently the repo assumes that cuda -> bfloat16 and everything else -> float32.
        # We need to know the dtype here to call __init__ on KVCache and pre-allocate its tensors.
        # As a quick hack, we're making generate() inherit this repo-wise assumption.
        # A bigger refactor should track device/dtype across the codebase;
        # in particular the KVCache should allocate its tensors lazily.
        dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
        rng = torch.Generator(device=device)
        rng.manual_seed(seed)

        # Get the special tokens we need to coordinate the tool use state machine.
        get_special = lambda s: self.tokenizer.encode_special(s)
        python_start = get_special("<|python_start|>")
        python_end = get_special("<|python_end|>")
        output_start = get_special("<|output_start|>")
        output_end = get_special("<|output_end|>")
        assistant_end = get_special("<|assistant_end|>") # if sampled, ends row
        bos = self.tokenizer.get_bos_token_id() # if sampled, ends row

        # 1) Run a batch-1 prefill of the prompt tokens.
        m = self.model.config
        kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
        kv_cache_prefill = KVCache(
            batch_size=1,
            seq_len=len(tokens),
            device=device,
            dtype=dtype,
            **kv_model_kwargs,
        )
        ids = torch.tensor([tokens], dtype=torch.long, device=device)
        logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
        # Keep only the last position and replicate it for every sample row.
        logits = logits[:, -1, :].expand(num_samples, -1) # (num_samples, vocab_size)

        # 2) Replicate the KV cache for each sample/row.
        kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
        kv_cache_decode = KVCache(
            batch_size=num_samples,
            seq_len=kv_length_hint,
            device=device,
            dtype=dtype,
            **kv_model_kwargs,
        )
        kv_cache_decode.prefill(kv_cache_prefill)
        del kv_cache_prefill # no need to keep this memory around

        # 3) Initialize per-row decoding state.
        row_states = [RowState(tokens.copy()) for _ in range(num_samples)]

        # 4) Main generation loop.
        num_generated = 0
        while True:
            # Stop condition: we've reached max tokens.
            if max_tokens is not None and num_generated >= max_tokens:
                break
            # Stop condition: all rows are completed.
            if all(state.completed for state in row_states):
                break

            # Sample a candidate next token for every row in one batched call.
            next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
            sampled_tokens = next_ids[:, 0].tolist()

            # Process each row: choose the next token, update state, optional tool use.
            token_column = [] # the chosen next token id for each row
            token_masks = []  # 1 if sampled, 0 if force-injected
            for i, state in enumerate(row_states):
                # Forced tokens (queued calculator output) take priority over sampling.
                is_forced = len(state.forced_tokens) > 0
                token_masks.append(0 if is_forced else 1)
                next_token = state.forced_tokens.popleft() if is_forced else sampled_tokens[i]
                token_column.append(next_token)
                state.current_tokens.append(next_token)
                # On <|assistant_end|> or <|bos|>, mark the row as completed.
                if next_token == assistant_end or next_token == bos:
                    state.completed = True
                # Calculator tool state machine:
                # <|python_start|> opens an expression; <|python_end|> closes it,
                # evaluates it, and (on success) queues <|output_start|> result
                # <|output_end|> as forced tokens; everything in between is collected.
                if next_token == python_start:
                    state.in_python_block = True
                    state.python_expr_tokens = []
                elif next_token == python_end and state.in_python_block:
                    state.in_python_block = False
                    if state.python_expr_tokens:
                        expr = self.tokenizer.decode(state.python_expr_tokens)
                        result = use_calculator(expr)
                        if result is not None:
                            result_tokens = self.tokenizer.encode(str(result))
                            state.forced_tokens.append(output_start)
                            state.forced_tokens.extend(result_tokens)
                            state.forced_tokens.append(output_end)
                    state.python_expr_tokens = []
                elif state.in_python_block:
                    state.python_expr_tokens.append(next_token)

            # Yield the token column (even for completed rows, so shapes stay stable).
            yield token_column, token_masks
            num_generated += 1

            # Feed the chosen tokens back through the model for the next step.
            ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)
            logits = self.model.forward(ids, kv_cache=kv_cache_decode)[:, -1, :] # (B, vocab_size)

    def generate_batch(self, tokens, num_samples=1, **kwargs):
        """
        Non-streaming batch generation that just returns the final token sequences.
        Returns (results, masks): token sequences and sampled/forced masks per row.
        Terminal tokens (assistant_end, bos) are not included in the results.
        """
        assistant_end = self.tokenizer.encode_special("<|assistant_end|>")
        bos = self.tokenizer.get_bos_token_id()
        results = [tokens.copy() for _ in range(num_samples)]
        masks = [[0] * len(tokens) for _ in range(num_samples)]
        completed = [False] * num_samples
        for token_column, token_masks in self.generate(tokens, num_samples, **kwargs):
            for i, (token, mask) in enumerate(zip(token_column, token_masks)):
                if not completed[i]:
                    if token == assistant_end or token == bos:
                        completed[i] = True
                    else:
                        results[i].append(token)
                        masks[i].append(mask)
            # Stop if all rows are completed.
            if all(completed):
                break
        return results, masks
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
if __name__ == "__main__":
    """
    Quick inline test to make sure that the naive/slow model.generate function
    is equivalent to the faster Engine.generate function here.
    """
    import time
    # init compute
    device_type = autodetect_device_type()
    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()

    # load the model and tokenizer
    model, tokenizer, meta = load_model("base", device, phase="eval")
    bos_token_id = tokenizer.get_bos_token_id()
    # common hyperparameters; temperature=0 makes both paths deterministic/comparable
    kwargs = dict(max_tokens=64, temperature=0.0)
    # set the starting prompt
    prompt_tokens = tokenizer.encode("The chemical formula of water is", prepend=bos_token_id)
    # generate the reference sequence using the model.generate() function
    # NOTE(review): torch.cuda.synchronize() will fail on non-CUDA devices — confirm intended.
    generated_tokens = []
    torch.cuda.synchronize()
    t0 = time.time()
    stream = model.generate(prompt_tokens, **kwargs)
    with autocast_ctx:
        for token in stream:
            generated_tokens.append(token)
            chunk = tokenizer.decode([token])
            print(chunk, end="", flush=True)
    print()
    torch.cuda.synchronize()
    t1 = time.time()
    print(f"Reference time: {t1 - t0:.2f}s")
    reference_ids = generated_tokens
    # generate tokens with Engine
    generated_tokens = []
    engine = Engine(model, tokenizer)
    stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
    torch.cuda.synchronize()
    t0 = time.time()
    with autocast_ctx:
        for token_column, token_masks in stream:
            token = token_column[0] # only print out the first row
            generated_tokens.append(token)
            chunk = tokenizer.decode([token])
            print(chunk, end="", flush=True)
    print()
    torch.cuda.synchronize()
    t1 = time.time()
    print(f"Engine time: {t1 - t0:.2f}s")
    # compare the two sequences token by token; report the first divergence
    for i in range(len(reference_ids)):
        if reference_ids[i] != generated_tokens[i]:
            print(f"Mismatch at {i}: {reference_ids[i]} != {generated_tokens[i]}")
            break
    print(f"Match: {reference_ids == generated_tokens}")
|
nanochat/execution.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sandboxed execution utilities for running Python code that comes out of an LLM.
|
| 3 |
+
Adapted from OpenAI HumanEval code:
|
| 4 |
+
https://github.com/openai/human-eval/blob/master/human_eval/execution.py
|
| 5 |
+
|
| 6 |
+
What is covered:
|
| 7 |
+
- Each execution runs in its own process (can be killed if it hangs or crashes)
|
| 8 |
+
- Execution is limited by a timeout to stop infinite loops
|
| 9 |
+
- Memory limits are enforced by default (256MB)
|
| 10 |
+
- stdout and stderr are captured and returned
|
| 11 |
+
- Code runs in a temporary directory that is deleted afterwards
|
| 12 |
+
- Dangerous functions are disabled (examples: os.system, os.kill, shutil.rmtree, subprocess.Popen)
|
| 13 |
+
|
| 14 |
+
What is not covered:
|
| 15 |
+
- Not a true security sandbox
|
| 16 |
+
- Network access is not blocked (e.g. sockets could be opened)
|
| 17 |
+
- Python's dynamic features (e.g. ctypes) could bypass restrictions
|
| 18 |
+
- No kernel-level isolation (no seccomp, no containers, no virtualization)
|
| 19 |
+
|
| 20 |
+
Overall this sandbox is good for evaluation of generated code and protects against
|
| 21 |
+
accidental destructive behavior, but it is not safe against malicious adversarial code.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import contextlib
|
| 25 |
+
import faulthandler
|
| 26 |
+
import io
|
| 27 |
+
import multiprocessing
|
| 28 |
+
import os
|
| 29 |
+
import platform
|
| 30 |
+
import signal
|
| 31 |
+
import tempfile
|
| 32 |
+
from dataclasses import dataclass
|
| 33 |
+
from typing import Optional
|
| 34 |
+
|
| 35 |
+
# -----------------------------------------------------------------------------
|
| 36 |
+
|
| 37 |
+
@dataclass
class ExecutionResult:
    """Result of executing Python code in a sandbox."""
    success: bool                       # did the code run to completion?
    stdout: str                         # captured standard output
    stderr: str                         # captured standard error
    error: Optional[str] = None         # exception text, if any
    timeout: bool = False               # killed for exceeding the time limit
    memory_exceeded: bool = False       # killed for exceeding the memory limit

    def __repr__(self):
        # Compact repr: always show success, then only the fields that are set.
        text = f"ExecutionResult(success={self.success}"
        if self.timeout:
            text += ", timeout=True"
        if self.memory_exceeded:
            text += ", memory_exceeded=True"
        if self.error:
            text += f", error={self.error!r}"
        if self.stdout:
            text += f", stdout={self.stdout!r}"
        if self.stderr:
            text += f", stderr={self.stderr!r}"
        return text + ")"
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@contextlib.contextmanager
def time_limit(seconds: float):
    """Raise TimeoutException in the calling thread after `seconds` of wall time.

    Implemented with SIGALRM + a real-time interval timer, so it only works on
    the main thread of a Unix process (which is how the sandbox subprocess
    invokes it).
    """
    def signal_handler(signum, frame):
        raise TimeoutException("Timed out!")

    # Install the handler *before* arming the timer. The original order armed
    # the timer first, leaving a tiny window where the alarm could fire while
    # the previous (default, process-killing) handler was still installed.
    signal.signal(signal.SIGALRM, signal_handler)
    signal.setitimer(signal.ITIMER_REAL, seconds)
    try:
        yield
    finally:
        # Always disarm the timer so it cannot fire after the block exits.
        signal.setitimer(signal.ITIMER_REAL, 0)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@contextlib.contextmanager
def capture_io():
    """Capture stdout and stderr, and disable stdin."""
    out_buf = io.StringIO()
    err_buf = io.StringIO()
    blocked_stdin = WriteOnlyStringIO()
    # One ExitStack instead of three nested with-statements.
    with contextlib.ExitStack() as stack:
        stack.enter_context(contextlib.redirect_stdout(out_buf))
        stack.enter_context(contextlib.redirect_stderr(err_buf))
        stack.enter_context(redirect_stdin(blocked_stdin))
        yield out_buf, err_buf
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
@contextlib.contextmanager
def create_tempdir():
    """Create a temporary directory, chdir into it, and clean everything up on exit."""
    with tempfile.TemporaryDirectory() as dirname, chdir(dirname):
        yield dirname
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class TimeoutException(Exception):
    """Raised by time_limit() when sandboxed code exceeds its time budget."""
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class WriteOnlyStringIO(io.StringIO):
    """StringIO that throws an exception when it's read from"""

    def _refuse_read(self, *args, **kwargs):
        # Single refusal implementation shared by every read entry point.
        raise IOError

    read = _refuse_read
    readline = _refuse_read
    readlines = _refuse_read

    def readable(self, *args, **kwargs):
        """Returns True if the IO object can be read."""
        return False
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class redirect_stdin(contextlib._RedirectStream):  # type: ignore
    # Context manager that swaps sys.stdin, mirroring contextlib's own
    # redirect_stdout/redirect_stderr (which share this private base class).
    _stream = "stdin"
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
@contextlib.contextmanager
def chdir(root):
    """Temporarily change the working directory to `root`; "." is a no-op."""
    if root == ".":
        yield
        return
    previous_dir = os.getcwd()
    os.chdir(root)
    try:
        yield
    finally:
        # Restore the original working directory even if the body raised.
        os.chdir(previous_dir)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def reliability_guard(maximum_memory_bytes: Optional[int] = None):
    """
    This disables various destructive functions and prevents the generated code
    from interfering with the test (e.g. fork bomb, killing other processes,
    removing filesystem files, etc.)

    WARNING
    This function is NOT a security sandbox. Untrusted code, including, model-
    generated code, should not be blindly executed outside of one. See the
    Codex paper for more information about OpenAI's code sandbox, and proceed
    with caution.
    """

    if platform.uname().system != "Darwin":
        # These resource limit calls seem to fail on macOS (Darwin), skip?
        # NOTE(review): setrlimit needs integer limits, so this branch assumes
        # maximum_memory_bytes is not None on non-macOS hosts — verify callers.
        import resource
        resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
        resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
        resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))

    # Turn off the fault handler enabled by the host process (if any).
    faulthandler.disable()

    import builtins

    # Block interpreter shutdown from inside the sandboxed code.
    builtins.exit = None
    builtins.quit = None

    import os

    os.environ["OMP_NUM_THREADS"] = "1"

    # Null out os-level functions that could kill processes, change privileges,
    # or mutate the filesystem outside the sandbox tempdir.
    os.kill = None
    os.system = None
    os.putenv = None
    os.remove = None
    os.removedirs = None
    os.rmdir = None
    os.fchdir = None
    os.setuid = None
    os.fork = None
    os.forkpty = None
    os.killpg = None
    os.rename = None
    os.renames = None
    os.truncate = None
    os.replace = None
    os.unlink = None
    os.fchmod = None
    os.fchown = None
    os.chmod = None
    os.chown = None
    os.chroot = None
    os.fchdir = None  # (duplicate of the assignment above; harmless)
    os.lchflags = None
    os.lchmod = None
    os.lchown = None
    os.getcwd = None
    os.chdir = None

    import shutil

    # Block recursive deletion, moves, and ownership changes.
    shutil.rmtree = None
    shutil.move = None
    shutil.chown = None

    import subprocess

    # Prevent spawning of subprocesses.
    subprocess.Popen = None  # type: ignore

    # NOTE(review): __builtins__ is a dict when this module is *imported*, but a
    # module object in __main__ — this subscript assumes the dict case; confirm
    # this function is only ever called from the imported module path.
    __builtins__["help"] = None

    import sys

    # Prevent import of modules that could be used to debug, escape limits,
    # or spawn work outside the sandbox.
    sys.modules["ipdb"] = None
    sys.modules["joblib"] = None
    sys.modules["resource"] = None
    sys.modules["psutil"] = None
    sys.modules["tkinter"] = None
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[int], result_dict):
    """Execute code in a subprocess with safety guards. Results are written to result_dict."""
    with create_tempdir():

        # These system calls are needed when cleaning up tempdir.
        import os
        import shutil

        # Save references before reliability_guard() nulls them out below;
        # they are restored at the end so TemporaryDirectory cleanup works.
        rmtree = shutil.rmtree
        rmdir = os.rmdir
        chdir = os.chdir
        unlink = os.unlink

        # Disable functionalities that can make destructive changes to the test.
        reliability_guard(maximum_memory_bytes=maximum_memory_bytes)

        # Default to failure; the success branch below overwrites these keys.
        result_dict.update({
            "success": False,
            "stdout": "",
            "stderr": "",
            "timeout": False,
            "memory_exceeded": False,
            "error": None,
        })

        try:
            exec_globals = {}
            with capture_io() as (stdout_capture, stderr_capture):
                with time_limit(timeout):
                    # WARNING
                    # This program exists to execute untrusted model-generated code. Although
                    # it is highly unlikely that model-generated code will do something overtly
                    # malicious in response to this test suite, model-generated code may act
                    # destructively due to a lack of model capability or alignment.
                    # Users are strongly encouraged to sandbox this evaluation suite so that it
                    # does not perform destructive actions on their host or network. For more
                    # information on how OpenAI sandboxes its code, see the accompanying paper.
                    # Once you have read this disclaimer and taken appropriate precautions,
                    # uncomment the following line and proceed at your own risk:
                    exec(code, exec_globals)

            # Reaching here means exec() completed without raising.
            result_dict.update({
                "success": True,
                "stdout": stdout_capture.getvalue(),
                "stderr": stderr_capture.getvalue(),
            })

        except TimeoutException:
            # NOTE(review): stdout/stderr captured before the timeout are
            # discarded on this path (they stay "") — confirm that's intended.
            result_dict.update({
                "timeout": True,
                "error": "Execution timed out",
            })

        except MemoryError as e:
            result_dict.update({
                "memory_exceeded": True,
                "error": f"Memory limit exceeded: {e}",
            })

        except BaseException as e:
            # Catch-all, including SystemExit/KeyboardInterrupt raised by the
            # sandboxed code itself.
            result_dict.update({
                "error": f"{type(e).__name__}: {e}",
            })

        # Needed for cleaning up: restore the functions saved above.
        shutil.rmtree = rmtree
        os.rmdir = rmdir
        os.chdir = chdir
        os.unlink = unlink
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def execute_code(
    code: str,
    timeout: float = 5.0,  # 5 seconds default
    maximum_memory_bytes: Optional[int] = 256 * 1024 * 1024,  # 256MB default
) -> ExecutionResult:
    """
    Execute Python code in a sandboxed environment.

    Args:
        code: Python code to execute as a string
        timeout: Maximum execution time in seconds (default: 5.0)
        maximum_memory_bytes: Memory limit in bytes (default: 256MB, None to disable)

    Returns:
        ExecutionResult with success status, stdout/stderr, and error information

    Example:
        >>> result = execute_code("print('hello world')")
        >>> result.success
        True
        >>> result.stdout
        'hello world\\n'
    """

    # Manager dict lets the child process report results back to us.
    manager = multiprocessing.Manager()
    result_dict = manager.dict()

    p = multiprocessing.Process(
        target=_unsafe_execute,
        args=(code, timeout, maximum_memory_bytes, result_dict)
    )
    p.start()
    # Give the child a grace second beyond its internal time_limit().
    p.join(timeout=timeout + 1)

    if p.is_alive():
        # The child hung past the grace period: kill it, then join() to reap
        # it — without the join the killed child lingers as a zombie process.
        p.kill()
        p.join()
        return ExecutionResult(
            success=False,
            stdout="",
            stderr="",
            error="Execution timed out (process killed)",
            timeout=True,
            memory_exceeded=False,
        )

    if not result_dict:
        # Child died before writing anything (e.g. hard OOM kill).
        return ExecutionResult(
            success=False,
            stdout="",
            stderr="",
            error="Execution failed (no result returned)",
            timeout=True,
            memory_exceeded=False,
        )

    # Copy the manager-backed dict into a plain result object.
    return ExecutionResult(
        success=result_dict["success"],
        stdout=result_dict["stdout"],
        stderr=result_dict["stderr"],
        error=result_dict["error"],
        timeout=result_dict["timeout"],
        memory_exceeded=result_dict["memory_exceeded"],
    )
|
| 349 |
+
|
nanochat/flash_attention.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unified Flash Attention interface with automatic FA3/SDPA switching.
|
| 3 |
+
|
| 4 |
+
Exports `flash_attn` module that matches the FA3 API exactly, but falls back
|
| 5 |
+
to PyTorch SDPA on non-Hopper GPUs, MPS, and CPU.
|
| 6 |
+
|
| 7 |
+
Usage (drop-in replacement for FA3):
|
| 8 |
+
from nanochat.flash_attention import flash_attn
|
| 9 |
+
|
| 10 |
+
# Training (no KV cache)
|
| 11 |
+
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
|
| 12 |
+
|
| 13 |
+
# Inference (with KV cache)
|
| 14 |
+
y = flash_attn.flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v, ...)
|
| 15 |
+
"""
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# =============================================================================
|
| 21 |
+
# Detection: Try to load FA3 on Hopper+ GPUs
|
| 22 |
+
# =============================================================================
|
| 23 |
+
def _load_flash_attention_3():
|
| 24 |
+
"""Try to load Flash Attention 3 (requires Hopper+ GPU)."""
|
| 25 |
+
if not torch.cuda.is_available():
|
| 26 |
+
return None
|
| 27 |
+
try:
|
| 28 |
+
major, _ = torch.cuda.get_device_capability()
|
| 29 |
+
if major < 9: # Hopper is sm90
|
| 30 |
+
return None
|
| 31 |
+
import os
|
| 32 |
+
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
|
| 33 |
+
from kernels import get_kernel
|
| 34 |
+
return get_kernel('varunneal/flash-attention-3').flash_attn_interface
|
| 35 |
+
except Exception:
|
| 36 |
+
return None
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
_fa3 = _load_flash_attention_3()
|
| 40 |
+
HAS_FA3 = _fa3 is not None
|
| 41 |
+
|
| 42 |
+
# Override for testing: set to 'fa3', 'sdpa', or None (auto)
|
| 43 |
+
_override_impl = None
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _use_fa3():
    """Determine whether to use FA3 based on availability and override."""
    # Explicit overrides win; otherwise fall back to hardware auto-detection.
    if _override_impl == 'sdpa':
        return False
    if _override_impl == 'fa3':
        assert HAS_FA3, "Cannot override to FA3: not available on this hardware"
        return True
    return HAS_FA3  # auto
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# =============================================================================
|
| 57 |
+
# SDPA helpers
|
| 58 |
+
# =============================================================================
|
| 59 |
+
def _sdpa_attention(q, k, v, window_size, enable_gqa):
|
| 60 |
+
"""
|
| 61 |
+
SDPA attention with sliding window support.
|
| 62 |
+
q, k, v are (B, H, T, D) format.
|
| 63 |
+
"""
|
| 64 |
+
Tq = q.size(2)
|
| 65 |
+
Tk = k.size(2)
|
| 66 |
+
window = window_size[0]
|
| 67 |
+
|
| 68 |
+
# Full context, same length
|
| 69 |
+
if (window < 0 or window >= Tq) and Tq == Tk:
|
| 70 |
+
return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
|
| 71 |
+
|
| 72 |
+
# Single token generation
|
| 73 |
+
if Tq == 1:
|
| 74 |
+
return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
|
| 75 |
+
|
| 76 |
+
# Need explicit mask
|
| 77 |
+
device = q.device
|
| 78 |
+
if Tq == Tk:
|
| 79 |
+
# Causal + sliding window
|
| 80 |
+
mask = torch.tril(torch.ones(Tq, Tk, device=device, dtype=torch.bool))
|
| 81 |
+
if window > 0 and window < Tq:
|
| 82 |
+
row_idx = torch.arange(Tq, device=device).unsqueeze(1)
|
| 83 |
+
col_idx = torch.arange(Tk, device=device).unsqueeze(0)
|
| 84 |
+
mask = mask & ((row_idx - col_idx) <= window)
|
| 85 |
+
else:
|
| 86 |
+
# Chunk inference: attend to prefix + causal within chunk
|
| 87 |
+
prefix_len = Tk - Tq
|
| 88 |
+
mask = torch.zeros(Tq, Tk, device=device, dtype=torch.bool)
|
| 89 |
+
mask[:, :prefix_len] = True
|
| 90 |
+
mask[:, prefix_len:] = torch.tril(torch.ones(Tq, Tq, device=device, dtype=torch.bool))
|
| 91 |
+
|
| 92 |
+
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# =============================================================================
|
| 96 |
+
# Public API: Same interface as FA3
|
| 97 |
+
# =============================================================================
|
| 98 |
+
def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
    """
    Flash Attention for training (no KV cache).

    Args:
        q, k, v: Tensors of shape (B, T, H, D)
        causal: Whether to use causal masking
        window_size: (left, right) sliding window. -1 means unlimited.

    Returns:
        Output tensor of shape (B, T, H, D)
    """
    if _use_fa3():
        return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)

    # SDPA fallback works in (B, H, T, D); FA3's native layout is (B, T, H, D).
    # GQA is in play whenever the query and KV head counts differ (dim 2 here).
    out = _sdpa_attention(
        q.transpose(1, 2),
        k.transpose(1, 2),
        v.transpose(1, 2),
        window_size,
        q.size(2) != k.size(2),
    )
    return out.transpose(1, 2)  # back to (B, T, H, D)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=None,
                            causal=False, window_size=(-1, -1)):
    """
    Flash Attention with KV cache for inference.

    FA3 updates k_cache/v_cache in-place. Our SDPA fallback does the same.

    Args:
        q: Queries, shape (B, T_new, H, D)
        k_cache, v_cache: Pre-allocated cache tensors, shape (B, T_max, H_kv, D)
        k, v: New keys/values to insert, shape (B, T_new, H_kv, D)
        cache_seqlens: Current position in cache, shape (B,) int32
        causal: Whether to use causal masking
        window_size: (left, right) sliding window. -1 means unlimited.

    Returns:
        Output tensor of shape (B, T_new, H, D)
    """
    if _use_fa3():
        return _fa3.flash_attn_with_kvcache(
            q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
            causal=causal, window_size=window_size
        )

    # SDPA fallback: manually manage KV cache
    B, T_new, H, D = q.shape
    # NOTE(review): only batch element 0's position is consulted — assumes all
    # rows of cache_seqlens are equal (uniform decode position); verify callers.
    pos = cache_seqlens[0].item()  # assume uniform position across batch

    # Insert new k, v into cache (in-place, matching FA3 behavior)
    if k is not None and v is not None:
        k_cache[:, pos:pos+T_new, :, :] = k
        v_cache[:, pos:pos+T_new, :, :] = v

    # Get full cache up to current position + new tokens (views, no copy)
    end_pos = pos + T_new
    k_full = k_cache[:, :end_pos, :, :]
    v_full = v_cache[:, :end_pos, :, :]

    # Transpose to SDPA layout: (B, T, H, D) -> (B, H, T, D)
    q_sdpa = q.transpose(1, 2)
    k_sdpa = k_full.transpose(1, 2)
    v_sdpa = v_full.transpose(1, 2)

    # GQA is in play whenever query and KV head counts differ.
    enable_gqa = q_sdpa.size(1) != k_sdpa.size(1)
    y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa)

    return y_sdpa.transpose(1, 2)  # back to (B, T_new, H, D)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
# =============================================================================
|
| 172 |
+
# Export: flash_attn module interface (drop-in replacement for FA3)
|
| 173 |
+
# =============================================================================
|
| 174 |
+
from types import SimpleNamespace
|
| 175 |
+
flash_attn = SimpleNamespace(
|
| 176 |
+
flash_attn_func=flash_attn_func,
|
| 177 |
+
flash_attn_with_kvcache=flash_attn_with_kvcache,
|
| 178 |
+
)
|
nanochat/gpt.py
ADDED
|
@@ -0,0 +1,633 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GPT model (rewrite, a lot simpler)
|
| 3 |
+
Notable features:
|
| 4 |
+
- rotary embeddings (and no positional embeddings)
|
| 5 |
+
- QK norm
|
| 6 |
+
- untied weights for token embedding and lm_head
|
| 7 |
+
- relu^2 activation in MLP
|
| 8 |
+
- norm after token embedding
|
| 9 |
+
- no learnable params in rmsnorm
|
| 10 |
+
- no bias in linear layers
|
| 11 |
+
- Group-Query Attention (GQA) support for more efficient inference
|
| 12 |
+
- Multi-Query Attention (MQA) option for maximum KV cache compression
|
| 13 |
+
- Multi-Token Prediction heads for improved training signal
|
| 14 |
+
- Flash Attention 3 integration
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from functools import partial
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
|
| 24 |
+
from nanochat.common import get_dist_info, print0
|
| 25 |
+
from nanochat.muon import Muon, DistMuon
|
| 26 |
+
from nanochat.adamw import DistAdamW
|
| 27 |
+
|
| 28 |
+
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
|
| 29 |
+
from nanochat.flash_attention import flash_attn
|
| 30 |
+
|
| 31 |
+
@dataclass
class GPTConfig:
    """Hyperparameters for the GPT model (see __post_init__ for the MQA invariant)."""
    sequence_len: int = 2048
    vocab_size: int = 32768
    n_layer: int = 12
    n_head: int = 6 # number of query heads
    n_kv_head: int = 6 # number of key/value heads (GQA)
    n_embd: int = 768
    # Sliding window attention pattern string, tiled across layers. Final layer always L.
    # Characters: L=long (full context), S=short (half context)
    # Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
    window_pattern: str = "SSSL"
    # Multi-Query Attention: share a single KV head across all query heads.
    # Shrinks the KV cache by a factor of n_head.
    use_mqa: bool = False
    # Multi-Token Prediction: extra heads predicting future tokens (t+2, t+3, t+4)
    # Improves training signal and enables speculative decoding
    multi_token_n: int = 3 # predicts 3 future tokens (t+2, t+3, t+4)
    # Draft Head for self-draft speculative decoding
    # Lightweight MLP that predicts multiple tokens at once for fast drafting
    draft_n: int = 4 # number of tokens to draft in one shot
    draft_hidden_mult: float = 0.5 # draft head hidden dim = n_embd * mult (smaller = faster)

    def __post_init__(self):
        # MQA is simply GQA with exactly one KV head; enforce that here so
        # use_mqa=True always overrides whatever n_kv_head was passed.
        if self.use_mqa:
            self.n_kv_head = 1
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def norm(x):
|
| 60 |
+
# Purely functional rmsnorm with no learnable params
|
| 61 |
+
return F.rms_norm(x, (x.size(-1),))
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def has_ve(layer_idx, n_layer):
    """Returns True if GPT layer should have Value Embedding (alternating, last layer always included)."""
    last_layer_parity = (n_layer - 1) % 2
    return layer_idx % 2 == last_layer_parity
|
| 67 |
+
|
| 68 |
+
def apply_rotary_emb(x, cos, sin):
    """Rotary position embedding: rotate the two halves of the head dim by (cos, sin)."""
    assert x.ndim == 4  # multihead attention
    half = x.shape[3] // 2
    first, second = x[..., :half], x[..., half:]
    rotated_first = first * cos + second * sin
    rotated_second = second * cos - first * sin
    return torch.cat([rotated_first, rotated_second], 3)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class DraftHead(nn.Module):
    """
    Lightweight MLP head for self-draft speculative decoding.
    Predicts multiple tokens at once from the last hidden state.

    During inference:
    1. Draft head quickly predicts N draft tokens
    2. Main model verifies all N tokens in one parallel forward pass
    3. Accept verified tokens, resample where draft was wrong

    This amortizes the cost of autoregressive decoding.
    """
    def __init__(self, n_embd, vocab_size, draft_n, hidden_mult=0.5):
        super().__init__()
        self.draft_n = draft_n
        # Two-layer MLP: a (smaller) hidden layer, then one output layer that
        # emits draft_n independent vocab distributions at once.
        hidden_dim = int(n_embd * hidden_mult)
        self.fc1 = nn.Linear(n_embd, hidden_dim, bias=False)
        self.fc2 = nn.Linear(hidden_dim, draft_n * vocab_size, bias=False)
        self.vocab_size = vocab_size

    def forward(self, x):
        """
        Args:
            x: hidden states (B, T, n_embd) or (B, n_embd) for single position
        Returns:
            draft_logits: (B, T, draft_n, vocab_size) or (B, draft_n, vocab_size)
        """
        single_position = x.dim() == 2
        if single_position:
            x = x.unsqueeze(1)  # promote to (B, 1, n_embd)

        batch, seq = x.shape[0], x.shape[1]
        hidden = F.relu(self.fc1(x)) ** 2  # ReLU^2, matching the main MLP
        logits = self.fc2(hidden).view(batch, seq, self.draft_n, self.vocab_size)

        return logits.squeeze(1) if single_position else logits
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with GQA, QK norm, rotary embeddings,
    optional gated value embeddings, and FA3/SDPA attention via flash_attn."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        # GQA requires the query heads to divide evenly across KV heads.
        assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
        # Separate (bias-free) projections for Q, K, V; K/V are smaller under GQA.
        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
        # Gate for the value-embedding residual: reads only the first 32 input
        # channels and produces one gate per KV head. Only layers selected by
        # has_ve() get a gate; others set it to None.
        self.ve_gate_channels = 32
        self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None

    def forward(self, x, ve, cos_sin, window_size, kv_cache):
        # x: (B, T, C) residual stream. ve: optional value embedding for this layer.
        # cos_sin: precomputed rotary tables. kv_cache: project cache object
        # (presumably exposes get_layer_cache/cache_seqlens/advance/n_layers —
        # interface inferred from usage here) or None during training.
        B, T, C = x.size()

        # Project the input to get queries, keys, and values
        # Shape: (B, T, H, D) - FA3's native layout, no transpose needed!
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)

        # Value residual (ResFormer): mix in value embedding with input-dependent gate per head
        if ve is not None:
            ve = ve.view(B, T, self.n_kv_head, self.head_dim)
            gate = 2 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels])) # (B, T, n_kv_head), range (0, 2)
            v = v + gate.unsqueeze(-1) * ve

        # Apply Rotary Embeddings to queries and keys to get relative positional encoding
        cos, sin = cos_sin
        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
        q, k = norm(q), norm(k) # QK norm

        # Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere)
        # window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
        if kv_cache is None:
            # Training: causal attention with optional sliding window
            y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
        else:
            # Inference: use flash_attn_with_kvcache which handles cache management
            # (writes the new k/v into the cache in-place).
            k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
            y = flash_attn.flash_attn_with_kvcache(
                q, k_cache, v_cache,
                k=k, v=v,
                cache_seqlens=kv_cache.cache_seqlens,
                causal=True,
                window_size=window_size,
            )
            # Advance position after last layer processes
            # (all layers share one position counter, so only bump it once per step).
            if self.layer_idx == kv_cache.n_layers - 1:
                kv_cache.advance(T)

        # Re-assemble the heads and project back to residual stream
        y = y.contiguous().view(B, T, -1)
        y = self.c_proj(y)
        return y
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class MLP(nn.Module):
    """Position-wise feed-forward block: expand 4x, squared-ReLU, project back.

    Both linear layers are bias-free, matching the rest of the model.
    """

    def __init__(self, config):
        super().__init__()
        hidden_dim = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=False)
        self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)

    def forward(self, x):
        # relu(x)^2 ("ReLU squared") activation between the two projections.
        return self.c_proj(F.relu(self.c_fc(x)).square())
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
class Block(nn.Module):
    """One transformer layer: pre-norm attention and pre-norm MLP, each wrapped in a residual add."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)

    def forward(self, x, ve, cos_sin, window_size, kv_cache):
        # Attention sub-block: normalize, attend, add back to the residual stream.
        attn_out = self.attn(norm(x), ve, cos_sin, window_size, kv_cache)
        x = x + attn_out
        # MLP sub-block: normalize, transform, add back to the residual stream.
        mlp_out = self.mlp(norm(x))
        return x + mlp_out
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
class GPT(nn.Module):
    def __init__(self, config, pad_vocab_size_to=64):
        """
        NOTE a major footgun: this __init__ function runs in meta device context (!!)
        Therefore, any calculations inside here are shapes and dtypes only, no actual data.
        => We actually initialize all data (parameters, buffers, etc.) in init_weights() instead.

        Args:
            config: model hyperparameters; fields read here are n_layer, n_head,
                n_kv_head, n_embd, vocab_size, sequence_len, window_pattern,
                multi_token_n, draft_n, draft_hidden_mult.
            pad_vocab_size_to: round the embedding/unembedding vocab dimension up
                to a multiple of this value; logits are cropped back to
                config.vocab_size in forward().
        """
        super().__init__()
        self.config = config
        # Compute per-layer window sizes for sliding window attention
        # window_size is (left, right) tuple: (-1, 0) for full context, (N, 0) for sliding window
        self.window_sizes = self._compute_window_sizes(config)
        # Pad vocab for efficiency (DDP, tensor cores). This is just an optimization - outputs are cropped in forward().
        # https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings
        padded_vocab_size = ((config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to
        if padded_vocab_size != config.vocab_size:
            print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} for efficiency")
        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(padded_vocab_size, config.n_embd),
            "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
        })
        self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False)
        # Multi-token prediction heads: predict tokens at t+2, t+3, etc.
        self.multi_token_heads = nn.ModuleDict()
        for i in range(config.multi_token_n):
            self.multi_token_heads[f"head_{i+2}"] = nn.Linear(config.n_embd, padded_vocab_size, bias=False)
        # Draft head for self-draft speculative decoding (None when draft_n == 0)
        self.draft_head = None
        if config.draft_n > 0:
            self.draft_head = DraftHead(
                n_embd=config.n_embd,
                vocab_size=config.vocab_size, # use actual vocab, not padded
                draft_n=config.draft_n,
                hidden_mult=config.draft_hidden_mult
            )
        # Per-layer learnable scalars (inspired by modded-nanogpt)
        # resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
        # x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
        # Separate parameters so they can have different optimizer treatment
        self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
        self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
        # Value embeddings (ResFormer-style): alternating layers, last layer always included
        # (which layers get one is decided by has_ve, defined elsewhere in this file)
        head_dim = config.n_embd // config.n_head
        kv_dim = config.n_kv_head * head_dim
        self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)})
        # To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
        # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
        # so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
        # In the future we can dynamically grow the cache, for now it's fine.
        self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
        head_dim = config.n_embd // config.n_head  # NOTE: recomputed; same value as above
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
        self.register_buffer("sin", sin, persistent=False)
|
| 260 |
+
|
| 261 |
+
    @torch.no_grad()
    def init_weights(self):
        """
        Initialize the full model in this one function for maximum clarity.
        (Called after construction because __init__ runs on the meta device.)

        wte (embedding): normal, std=1.0
        lm_head: normal, std=0.001
        for each block:
            attn.c_q: uniform, std=1/sqrt(n_embd)
            attn.c_k: uniform, std=1/sqrt(n_embd)
            attn.c_v: uniform, std=1/sqrt(n_embd)
            attn.c_proj: zeros
            mlp.c_fc: uniform, std=1/sqrt(n_embd)
            mlp.c_proj: zeros
        """

        # Embedding and unembedding
        torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=1.0)
        torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
        # Multi-token prediction heads (same init as lm_head)
        for head in self.multi_token_heads.values():
            torch.nn.init.normal_(head.weight, mean=0.0, std=0.001)
        # Draft head: small std for fc1 (like other projections), zeros for fc2 (starts neutral)
        if self.draft_head is not None:
            torch.nn.init.normal_(self.draft_head.fc1.weight, mean=0.0, std=self.config.n_embd**-0.5)
            torch.nn.init.zeros_(self.draft_head.fc2.weight) # start with zero output

        # Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal)
        n_embd = self.config.n_embd
        s = 3**0.5 * n_embd**-0.5 # sqrt(3) multiplier makes sure Uniform achieves the same std as Normal
        for block in self.transformer.h:
            torch.nn.init.uniform_(block.attn.c_q.weight, -s, s) # weights use Uniform to avoid outliers
            torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
            torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
            torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
            torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
            torch.nn.init.zeros_(block.mlp.c_proj.weight)

        # Per-layer scalars
        self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
        self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init

        # Value embeddings (init like c_v: uniform with same std)
        for ve in self.value_embeds.values():
            torch.nn.init.uniform_(ve.weight, -s, s)

        # Gate weights init to zero so gates start at sigmoid(0) = 0.5, scaled by 2 -> 1.0 (neutral)
        for block in self.transformer.h:
            if block.attn.ve_gate is not None:
                torch.nn.init.zeros_(block.attn.ve_gate.weight)

        # Rotary embeddings: recompute for real (the ones from __init__ were meta tensors)
        head_dim = self.config.n_embd // self.config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.cos, self.sin = cos, sin

        # Cast embeddings to bf16: optimizer can tolerate it and it saves memory
        # (Module.to(dtype=...) converts parameters in place.)
        if self.transformer.wte.weight.device.type == "cuda":
            self.transformer.wte.to(dtype=torch.bfloat16)
            for ve in self.value_embeds.values():
                ve.to(dtype=torch.bfloat16)
|
| 322 |
+
|
| 323 |
+
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
|
| 324 |
+
# TODO: bump base theta more? e.g. 100K is more common more recently
|
| 325 |
+
# autodetect the device from model embeddings
|
| 326 |
+
if device is None:
|
| 327 |
+
device = self.transformer.wte.weight.device
|
| 328 |
+
# stride the channels
|
| 329 |
+
channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
|
| 330 |
+
inv_freq = 1.0 / (base ** (channel_range / head_dim))
|
| 331 |
+
# stride the time steps
|
| 332 |
+
t = torch.arange(seq_len, dtype=torch.float32, device=device)
|
| 333 |
+
# calculate the rotation frequencies at each (time, channel) pair
|
| 334 |
+
freqs = torch.outer(t, inv_freq)
|
| 335 |
+
cos, sin = freqs.cos(), freqs.sin()
|
| 336 |
+
cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
|
| 337 |
+
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
|
| 338 |
+
return cos, sin
|
| 339 |
+
|
| 340 |
+
def _compute_window_sizes(self, config):
|
| 341 |
+
"""
|
| 342 |
+
Compute per-layer window sizes for sliding window attention.
|
| 343 |
+
|
| 344 |
+
Returns list of (left, right) tuples for FA3's window_size parameter:
|
| 345 |
+
- left: how many tokens before current position to attend to (-1 = unlimited)
|
| 346 |
+
- right: how many tokens after current position to attend to (0 for causal)
|
| 347 |
+
|
| 348 |
+
Pattern string is tiled across layers. Final layer always gets L (full context).
|
| 349 |
+
Characters: L=long (full context), S=short (half context)
|
| 350 |
+
"""
|
| 351 |
+
pattern = config.window_pattern.upper()
|
| 352 |
+
assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L."
|
| 353 |
+
# Map characters to window sizes
|
| 354 |
+
long_window = config.sequence_len
|
| 355 |
+
short_window = long_window // 2
|
| 356 |
+
char_to_window = {
|
| 357 |
+
"L": (long_window, 0),
|
| 358 |
+
"S": (short_window, 0),
|
| 359 |
+
}
|
| 360 |
+
# Tile pattern across layers
|
| 361 |
+
window_sizes = []
|
| 362 |
+
for layer_idx in range(config.n_layer):
|
| 363 |
+
char = pattern[layer_idx % len(pattern)]
|
| 364 |
+
window_sizes.append(char_to_window[char])
|
| 365 |
+
# Final layer always gets full context
|
| 366 |
+
window_sizes[-1] = (long_window, 0)
|
| 367 |
+
return window_sizes
|
| 368 |
+
|
| 369 |
+
def get_device(self):
|
| 370 |
+
return self.transformer.wte.weight.device
|
| 371 |
+
|
| 372 |
+
    def estimate_flops(self):
        """
        Return the estimated FLOPs per token for the model (forward + backward).
        Each matmul weight parameter contributes 2 FLOPs (multiply *, accumulate +) in forward, and 2X that in backward => 2+4=6.
        Cleanest explanation of this: https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4
        On top of that, 12 * h * q * effective_seq_len accounts for key @ query matmul flops inside attention.
        With sliding windows, effective_seq_len varies per layer (capped by window size).
        Ref: https://arxiv.org/abs/2204.02311 (PaLM paper).
        This is ~1% off from the exact formulas of Chinchilla paper, the difference is:
        - Chinchilla counts the embedding layer as flops (? weird, it's just a lookup => we ignore)
        - Chinchilla counts exp/sum/divide in attention softmax as flops (a little sus and very tiny => we ignore)
        """
        nparams = sum(p.numel() for p in self.parameters())
        # Exclude non-matmul params: embeddings and per-layer scalars
        value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
        nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
                           self.resid_lambdas.numel() + self.x0_lambdas.numel())
        # h = heads, q = head dim, t = training sequence length
        h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
        # Sum attention FLOPs per layer, accounting for sliding window
        attn_flops = 0
        for window_size in self.window_sizes:
            window = window_size[0] # (left, right) tuple, we use left
            # window < 0 means unlimited left context, so use full sequence length
            effective_seq = t if window < 0 else min(window, t)
            attn_flops += 12 * h * q * effective_seq
        num_flops_per_token = 6 * (nparams - nparams_exclude) + attn_flops
        return num_flops_per_token
|
| 398 |
+
|
| 399 |
+
def num_scaling_params(self):
|
| 400 |
+
"""
|
| 401 |
+
Return all of the parameters, same as Chinchilla paper.
|
| 402 |
+
Kaplan et al. did not include embedding parameters and said that this led to cleaner scaling laws.
|
| 403 |
+
But Kaplan et al. also had a bug in their results (as pointed out by Chinchilla).
|
| 404 |
+
My own experiments in nanochat confirm the Chinchilla approach gives the much cleaner scaling law.
|
| 405 |
+
Ref: https://arxiv.org/abs/2203.15556 (Chinchilla paper <- good).
|
| 406 |
+
Ref: https://arxiv.org/abs/2001.08361 (Kaplan et al. original scaling laws paper <- bad)
|
| 407 |
+
"""
|
| 408 |
+
nparams = sum(p.numel() for p in self.parameters())
|
| 409 |
+
return nparams
|
| 410 |
+
|
| 411 |
+
    def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
        """
        Build the two-optimizer setup: AdamW for embeddings/heads/scalars, Muon for the block matrices.

        Args:
            unembedding_lr: AdamW LR for lm_head / multi-token / draft heads (before dmodel scaling).
            embedding_lr: AdamW LR for token and value embeddings (before dmodel scaling).
            matrix_lr: Muon LR for all transformer block (matrix) parameters.
            weight_decay: weight decay for Muon only (AdamW's is hardcoded to 0.0 below).
            adam_betas: AdamW betas for all groups except x0_lambdas (which overrides betas).
            scalar_lr: base LR for the per-layer scalars (resid_lambdas use scalar_lr * 0.01).

        Returns:
            [adamw_optimizer, muon_optimizer]; each param group also gets "initial_lr"
            recorded for LR scheduling.
        """
        model_dim = self.config.n_embd
        ddp, rank, local_rank, world_size = get_dist_info()
        # Separate out all parameters into groups
        matrix_params = list(self.transformer.h.parameters())
        value_embeds_params = list(self.value_embeds.parameters())
        embedding_params = list(self.transformer.wte.parameters())
        lm_head_params = list(self.lm_head.parameters())
        multi_token_params = list(self.multi_token_heads.parameters())
        draft_head_params = list(self.draft_head.parameters()) if self.draft_head is not None else []
        resid_params = [self.resid_lambdas]
        x0_params = [self.x0_lambdas]
        # Sanity check: the groups must partition all model parameters
        assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(multi_token_params) + len(draft_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)
        # Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars
        # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
        dmodel_lr_scale = (model_dim / 768) ** -0.5
        print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
        adam_groups = [
            dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
            dict(params=multi_token_params, lr=unembedding_lr * dmodel_lr_scale), # same LR as lm_head
            dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
            dict(params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding
            dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream
            dict(params=x0_params, lr=scalar_lr, betas=(0.96, 0.95)), # higher beta1 for x0 scalars
        ]
        # Add draft head params if present (inserted before the embedding group)
        if draft_head_params:
            adam_groups.insert(2, dict(params=draft_head_params, lr=unembedding_lr * dmodel_lr_scale))
        adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon
        # Distributed variant when running under DDP, fused single-process AdamW otherwise
        AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
        adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
        # Create the Muon optimizer for the linear layers
        muon_kwargs = dict(lr=matrix_lr, momentum=0.95, weight_decay=weight_decay)
        MuonFactory = DistMuon if ddp else Muon
        muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
        # Combine the two optimizers into one list
        optimizers = [adamw_optimizer, muon_optimizer]
        # Record each group's starting LR so schedulers can scale relative to it
        for opt in optimizers:
            for group in opt.param_groups:
                group["initial_lr"] = group["lr"]
        return optimizers
|
| 452 |
+
|
| 453 |
+
    def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean', return_multi_token=False):
        """
        Run the transformer.

        Args:
            idx: (B, T) long tensor of token ids.
            targets: optional (B, T) long tensor; when given, the cross-entropy
                loss is computed (ignore_index=-1) and returned instead of logits.
            kv_cache: optional KV cache for incremental inference; its position
                offsets the rotary embeddings.
            loss_reduction: passed to F.cross_entropy ('mean' or 'none').
            return_multi_token: also return logits from the multi-token heads.

        Returns:
            loss, or (loss, logits, multi_token_logits) when training with
            return_multi_token; logits, or (logits, multi_token_logits) at inference.
        """
        B, T = idx.size()

        # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
        assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
        assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
        assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
        # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
        T0 = 0 if kv_cache is None else kv_cache.get_pos()
        cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length

        # Forward the trunk of the Transformer
        x = self.transformer.wte(idx)
        x = norm(x)
        x0 = x # save initial normalized embedding for x0 residual
        for i, block in enumerate(self.transformer.h):
            # Per-layer learned blend of residual stream and initial embedding
            x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
            ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
            x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
        x = norm(x)

        # Forward the lm_head (compute logits)
        softcap = 15 # smoothly cap the logits to the range [-softcap, softcap]
        logits = self.lm_head(x) # (B, T, padded_vocab_size) <- very big tensor, large amount of memory
        logits = logits[..., :self.config.vocab_size] # slice to remove padding
        logits = logits.float() # switch to fp32 for logit softcap and loss computation
        logits = softcap * torch.tanh(logits / softcap) # squash the logits

        # Multi-token prediction heads (for training with future token prediction)
        multi_token_logits = {}
        if return_multi_token and self.multi_token_heads:
            for name, head in self.multi_token_heads.items():
                mt_logits = head(x)
                mt_logits = mt_logits[..., :self.config.vocab_size]
                mt_logits = mt_logits.float()
                mt_logits = softcap * torch.tanh(mt_logits / softcap)
                multi_token_logits[name] = mt_logits

        if targets is not None:
            # training: given the targets, compute and return the loss
            # TODO experiment with chunked cross-entropy?
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
            if return_multi_token:
                return loss, logits, multi_token_logits
            return loss
        else:
            # inference: just return the logits directly
            if return_multi_token:
                return logits, multi_token_logits
            return logits
|
| 503 |
+
|
| 504 |
+
@torch.inference_mode()
|
| 505 |
+
def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
|
| 506 |
+
"""
|
| 507 |
+
Naive autoregressive streaming inference.
|
| 508 |
+
To make it super simple, let's assume:
|
| 509 |
+
- batch size is 1
|
| 510 |
+
- ids and the yielded tokens are simple Python lists and ints
|
| 511 |
+
"""
|
| 512 |
+
assert isinstance(tokens, list)
|
| 513 |
+
device = self.get_device()
|
| 514 |
+
rng = None
|
| 515 |
+
if temperature > 0:
|
| 516 |
+
rng = torch.Generator(device=device)
|
| 517 |
+
rng.manual_seed(seed)
|
| 518 |
+
ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim
|
| 519 |
+
for _ in range(max_tokens):
|
| 520 |
+
logits = self.forward(ids) # (B, T, vocab_size)
|
| 521 |
+
logits = logits[:, -1, :] # (B, vocab_size)
|
| 522 |
+
if top_k is not None:
|
| 523 |
+
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
| 524 |
+
logits[logits < v[:, [-1]]] = -float('Inf')
|
| 525 |
+
if temperature > 0:
|
| 526 |
+
logits = logits / temperature
|
| 527 |
+
probs = F.softmax(logits, dim=-1)
|
| 528 |
+
next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
|
| 529 |
+
else:
|
| 530 |
+
next_ids = torch.argmax(logits, dim=-1, keepdim=True)
|
| 531 |
+
ids = torch.cat((ids, next_ids), dim=1)
|
| 532 |
+
token = next_ids.item()
|
| 533 |
+
yield token
|
| 534 |
+
|
| 535 |
+
    @torch.inference_mode()
    def generate_speculative(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
        """
        Speculative decoding using self-draft.

        Algorithm:
        1. Get hidden state for last token
        2. Draft head predicts N tokens quickly
        3. Verify all N+1 positions (original + drafts) in one forward pass
        4. Accept longest prefix where draft matches verification
        5. Yield accepted tokens, repeat

        This reduces the effective number of forward passes from max_tokens to ~max_tokens / acceptance_rate.

        NOTE(review): each loop iteration runs the trunk once for the draft hidden
        state and again inside self.forward() for verification (no KV cache), so
        the whole prefix is recomputed twice per round.
        """
        assert isinstance(tokens, list)
        assert self.draft_head is not None, "Draft head not available (draft_n=0 in config)"
        device = self.get_device()
        draft_n = self.config.draft_n
        rng = None
        if temperature > 0:
            rng = torch.Generator(device=device)
            rng.manual_seed(seed)

        ids = torch.tensor([tokens], dtype=torch.long, device=device)
        tokens_generated = 0

        while tokens_generated < max_tokens:
            # Forward pass to get hidden states (we need the raw hidden state for draft head)
            # Run trunk to get hidden states
            B, T = ids.size()
            T0 = 0 # no kv cache for simplicity in this version
            cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T]

            # Same trunk computation as GPT.forward (embed -> norm -> blocks -> norm)
            x = self.transformer.wte(ids)
            x = norm(x)
            x0 = x
            for i, block in enumerate(self.transformer.h):
                x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
                ve = self.value_embeds[str(i)](ids) if str(i) in self.value_embeds else None
                x = block(x, ve, cos_sin, self.window_sizes[i], None)
            x = norm(x)

            # Get hidden state for last position
            last_hidden = x[:, -1, :] # (B, n_embd)

            # Draft N tokens using draft head
            draft_logits = self.draft_head(last_hidden) # (B, draft_n, vocab_size)
            if temperature > 0:
                draft_logits = draft_logits / temperature
                draft_probs = F.softmax(draft_logits, dim=-1)
                draft_tokens = torch.multinomial(draft_probs.view(-1, draft_probs.size(-1)), num_samples=1, generator=rng)
                draft_tokens = draft_tokens.view(B, draft_n) # (B, draft_n)
            else:
                draft_tokens = torch.argmax(draft_logits, dim=-1) # (B, draft_n)

            # Prepare verification sequence: original + draft tokens
            verify_ids = torch.cat([ids, draft_tokens], dim=1) # (B, T + draft_n)

            # Verify all draft tokens with full model in one forward pass
            verify_logits = self.forward(verify_ids) # (B, T + draft_n, vocab_size)

            # Sample from verification logits for positions T-1 to T+draft_n-1
            # Position T-1 verifies the first draft token, etc.
            accepted = []
            for i in range(draft_n):
                pos = T - 1 + i # verification position
                if pos >= verify_logits.size(1):
                    break

                v_logits = verify_logits[:, pos, :]
                if top_k is not None:
                    v, _ = torch.topk(v_logits, min(top_k, v_logits.size(-1)))
                    v_logits[v_logits < v[:, [-1]]] = -float('Inf')

                if temperature > 0:
                    v_logits = v_logits / temperature
                    v_probs = F.softmax(v_logits, dim=-1)
                    verified_token = torch.multinomial(v_probs, num_samples=1, generator=rng)
                else:
                    verified_token = torch.argmax(v_logits, dim=-1, keepdim=True)

                # Check if draft matches verification
                # NOTE(review): `i < draft_n` is always true inside this loop; the
                # branches differ only in the `break` (the verified token is
                # appended either way).
                if i < draft_n and draft_tokens[0, i] == verified_token[0, 0]:
                    accepted.append(verified_token[0, 0].item())
                else:
                    # Draft wrong, accept verified token and stop
                    accepted.append(verified_token[0, 0].item())
                    break

            # Yield accepted tokens
            for tok in accepted:
                if tokens_generated >= max_tokens:
                    return
                yield tok
                tokens_generated += 1

            # Update ids with accepted tokens
            accepted_tensor = torch.tensor([accepted], dtype=torch.long, device=device)
            ids = torch.cat([ids, accepted_tensor], dim=1)
|
nanochat/logo.svg
ADDED
|
|
nanochat/loss_eval.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
A number of functions that help with evaluating a base model.
|
| 3 |
+
"""
|
| 4 |
+
import math
|
| 5 |
+
import torch
|
| 6 |
+
import torch.distributed as dist
|
| 7 |
+
|
| 8 |
+
@torch.no_grad()
def evaluate_bpb(model, batches, steps, token_bytes):
    """
    Evaluate the model in bits per byte (bpb) instead of naive mean loss.

    bpb is tokenizer-vocab-size independent: we accumulate the summed loss (nats)
    and, separately, the summed byte length of all counted target tokens, then
    divide. Tokens whose byte count is 0 (special tokens) and targets masked
    with a negative ignore_index (e.g. -1) contribute neither nats nor bytes.

    token_bytes: 1D tensor of shape (vocab_size,) giving the byte length of each
    token id, or 0 for tokens that must not be counted (e.g. special tokens).
    Returns float('inf') if no bytes were counted at all.
    """
    device = model.get_device()
    nats_sum = torch.tensor(0.0, dtype=torch.float32, device=device)
    bytes_sum = torch.tensor(0, dtype=torch.int64, device=device)
    batch_iter = iter(batches)
    for _ in range(steps):
        x, y = next(batch_iter)
        losses = model(x, y, loss_reduction='none').view(-1)  # flatten (B, T) -> (B*T,)
        targets = y.view(-1)
        # mps currently lacks an int64 `< 0` kernel, hence the .int() cast here
        if (targets.int() < 0).any():
            # Slower path when some targets carry a negative ignore_index:
            # never index token_bytes with negatives; masked targets get 0 bytes.
            keep = targets >= 0
            safe_targets = torch.where(keep, targets, torch.zeros_like(targets))
            byte_counts = torch.where(
                keep,
                token_bytes[safe_targets],
                torch.zeros_like(targets, dtype=token_bytes.dtype),
            )
        else:
            # Fast path: all targets valid, index directly.
            byte_counts = token_bytes[targets]
        # Only count loss for tokens that contribute bytes (excludes special tokens).
        nats_sum += (losses * (byte_counts > 0)).sum()
        bytes_sum += byte_counts.sum()
    # Sum-reduce across ranks when running distributed.
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    if world_size > 1:
        dist.all_reduce(nats_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(bytes_sum, op=dist.ReduceOp.SUM)
    total_nats = nats_sum.item()
    total_bytes = bytes_sum.item()
    if total_bytes == 0:
        return float('inf')
    # nats -> bits via log(2), then normalize per byte.
    return total_nats / (math.log(2) * total_bytes)
|
nanochat/muon.py
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Muon optimizer adapted and simplified from modded-nanogpt.
|
| 3 |
+
https://github.com/KellerJordan/modded-nanogpt
|
| 4 |
+
|
| 5 |
+
Background:
|
| 6 |
+
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
|
| 7 |
+
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
|
| 8 |
+
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
|
| 9 |
+
zero even beyond the point where the iteration no longer converges all the way to one everywhere
|
| 10 |
+
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
|
| 11 |
+
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
|
| 12 |
+
performance at all relative to UV^T, where USV^T = G is the SVD.
|
| 13 |
+
|
| 14 |
+
Here, an alternative to Newton-Schulz iteration with potentially better convergence properties:
|
| 15 |
+
Polar Express Sign Method for orthogonalization.
|
| 16 |
+
https://arxiv.org/pdf/2505.16932
|
| 17 |
+
by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
|
| 18 |
+
|
| 19 |
+
Some of the changes in nanochat implementation:
|
| 20 |
+
- Uses a simpler, more general approach to parameter grouping and stacking
|
| 21 |
+
- Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step
|
| 22 |
+
- Makes no assumptions about model architecture (e.g. that attention weights are fused into QKVO format)
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
from torch import Tensor
|
| 27 |
+
import torch.distributed as dist
|
| 28 |
+
|
| 29 |
+
# Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
# From https://arxiv.org/pdf/2505.16932
# Each (a, b, c) triple parameterizes one quintic iterate X <- a*X + (b*A + c*A@A) @ X
# with A = X @ X^T (see muon_step_fused); only the first `ns_steps` entries are used.
polar_express_coeffs = [
    (8.156554524902461, -22.48329292557795, 15.878769915207462),
    (4.042929935166739, -2.808917465908714, 0.5000178451051316),
    (3.8916678022926607, -2.772484153217685, 0.5060648178503393),
    (3.285753657755655, -2.3681294933425376, 0.46449024233003106),
    (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
]
|
| 38 |
+
|
| 39 |
+
@torch.compile(dynamic=False, fullgraph=True)
def muon_step_fused(
    stacked_grads: Tensor,       # (num_params, rows, cols) averaged grads; consumed in place
    stacked_params: Tensor,      # (num_params, rows, cols) params; updated in place
    momentum_buffer: Tensor,     # (num_params, rows, cols) first-moment state; updated in place
    second_momentum_buffer: Tensor,  # factored second moment, (num_params, rows, 1) or (num_params, 1, cols)
    momentum_t: Tensor,          # 0-D CPU tensor holding the momentum coefficient
    lr_t: Tensor,                # 0-D CPU tensor holding the (aspect-scaled) learning rate
    wd_t: Tensor,                # 0-D CPU tensor holding the weight decay coefficient
    beta2_t: Tensor,             # 0-D CPU tensor holding the second-moment decay (0.0 disables)
    ns_steps: int,
    red_dim: int,                # dim reduced for the factored second moment (-1 or -2)
) -> None:
    """
    Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update
    All in one compiled graph to eliminate Python overhead between ops.
    Some of the constants are 0-D CPU tensors to avoid recompilation when values change.
    Returns nothing; all state (params, momentum buffers) is mutated in place.
    """

    # Nesterov momentum: buffer is an EMA of grads; the effective grad re-blends the
    # fresh grad with the updated buffer (note: overwrites stacked_grads in place)
    momentum = momentum_t.to(stacked_grads.dtype)
    momentum_buffer.lerp_(stacked_grads, 1 - momentum)
    g = stacked_grads.lerp_(momentum_buffer, momentum)

    # Polar express: approximate orthogonalization in bfloat16. The transpose keeps the
    # iteration on the smaller Gram matrix (rows <= cols) for speed/stability.
    X = g.bfloat16()
    if g.size(-2) > g.size(-1):
        X = X.mT
    # Normalize so the spectral norm is (safely) <= 1 before iterating
    X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
    for a, b, c in polar_express_coeffs[:ns_steps]:
        A = X @ X.mT
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    if g.size(-2) > g.size(-1):
        X = X.mT
    g = X

    # Variance reduction: factored (per-row or per-column) second moment, then rescale
    # so the overall update norm is preserved (v_norm / v_norm_new)
    beta2 = beta2_t.to(g.dtype)
    v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
    red_dim_size = g.size(red_dim)
    v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
    v_norm = v_norm_sq.sqrt()
    second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
    step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
    scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
    v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
    final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
    g = g * final_scale.to(g.dtype)

    # Cautious weight decay + parameter update: decay only where the update and the
    # weight agree in sign (mask), folded into a single in-place subtraction
    lr = lr_t.to(g.dtype)
    wd = wd_t.to(g.dtype)
    mask = (g * stacked_params) >= 0
    stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
|
| 94 |
+
|
| 95 |
+
class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-schulz

    https://kellerjordan.github.io/posts/muon/

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
    the advantage that it can be stably run in bfloat16 on the GPU.

    Some warnings:
    - This optimizer should not be used for the embedding layer, the final fully connected layer,
      or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
    - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.

    Arguments:
        lr: The learning rate used by the internal SGD.
        momentum: The momentum used by the internal SGD.
        ns_steps: The number of Newton-Schulz iteration steps to use.
        beta2: The decay rate for the second moment (variance) estimate. Set to None to disable.
        weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
    """
    def __init__(self, params, lr=0.02, momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=0.0):
        defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
        # Materialize FIRST: if `params` is an exhaustible iterator (e.g. a generator),
        # iterating it inside the ndim assert would leave list(params) empty.
        # (Previously the assert ran before list() and silently consumed generators.)
        params = list(params)
        assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
        # Group by shape so we can stack tensors into one (num_params, *shape) batch
        shapes = sorted({p.shape for p in params})
        param_groups = []
        for shape in shapes:
            group_params = [p for p in params if p.shape == shape]
            param_groups.append(dict(params=group_params))
        super().__init__(param_groups, defaults)
        # 0-D CPU tensors to avoid torch.compile recompilation when values change
        self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")

    @torch.no_grad()
    def step(self):
        """Run one fused Muon update over every shape-group of parameters."""
        for group in self.param_groups:
            params: list[Tensor] = group["params"]
            if not params:
                continue

            # Get or create group-level buffers (stored in first param's state for convenience)
            state = self.state[params[0]]
            num_params = len(params)  # e.g.: 12 (for a d12 model)
            # e.g.: shape = (768, 3072), device = cuda:0, dtype = torch.float32, for one of the MLP projections
            shape, device, dtype = params[0].shape, params[0].device, params[0].dtype

            # Momentum for every individual parameter
            if "momentum_buffer" not in state:
                state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
            momentum_buffer = state["momentum_buffer"]  # e.g.: (12, 768, 3072)

            # Second momentum buffer is factored, either per-row or per-column
            if "second_momentum_buffer" not in state:
                if shape[-2] >= shape[-1]:
                    state["second_momentum_buffer"] = torch.zeros(num_params, shape[-2], 1, dtype=dtype, device=device)
                else:
                    state["second_momentum_buffer"] = torch.zeros(num_params, 1, shape[-1], dtype=dtype, device=device)
            second_momentum_buffer = state["second_momentum_buffer"]  # (12, 1, 3072)
            red_dim = -1 if shape[-2] >= shape[-1] else -2  # e.g.: -2

            # Stack grads and params so the fused kernel sees one batched tensor
            stacked_grads = torch.stack([p.grad for p in params])  # (12, 768, 3072)
            stacked_params = torch.stack(params)  # (12, 768, 3072)

            # Fill all the 0-D tensors with current values; lr is scaled by the
            # aspect ratio of the matrix as in modded-nanogpt
            self._momentum_t.fill_(group["momentum"])
            self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
            self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
            self._wd_t.fill_(group["weight_decay"])

            # Single fused kernel: momentum -> polar_express -> variance_reduction -> update
            muon_step_fused(
                stacked_grads,
                stacked_params,
                momentum_buffer,
                second_momentum_buffer,
                self._momentum_t,
                self._lr_t,
                self._wd_t,
                self._beta2_t,
                group["ns_steps"],
                red_dim,
            )

            # Copy back to original params: [(768, 3072), (768, 3072), ...] <- (12, 768, 3072)
            torch._foreach_copy_(params, list(stacked_params.unbind(0)))
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class DistMuon(torch.optim.Optimizer):
    """
    Distributed version of the Muon optimizer.

    Each rank owns a contiguous chunk of every shape-group: gradients are averaged and
    sharded with reduce_scatter, each rank updates only its owned chunk with the fused
    Muon kernel, and the updated parameters are replicated back with all_gather.
    Requires torch.distributed to be initialized before construction.
    """
    def __init__(self, params, lr: float = 0.02, momentum: float = 0.95,
                 ns_steps: int = 5, beta2: float = 0.95, weight_decay: float = 0.0):
        defaults = dict(lr=lr, momentum=momentum, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
        # Materialize FIRST: if `params` is an exhaustible iterator (e.g. a generator),
        # iterating it inside the ndim assert would leave list(params) empty.
        # (Previously the assert ran before list() and silently consumed generators.)
        params = list(params)
        assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        # Group all parameters by their shape
        shapes = sorted({p.shape for p in params}) # sort for deterministic ordering across ranks
        param_groups = []
        for shape in shapes:
            group_params = [p for p in params if p.shape == shape]
            device, dtype = group_params[0].device, group_params[0].dtype
            assert all(p.device == device for p in group_params)
            assert all(p.dtype == dtype for p in group_params)
            # Compute chunk size for this group (how many params each rank owns)
            chunk_size = (len(group_params) + world_size - 1) // world_size
            if rank == 0:
                print(f"Muon: {len(group_params)} params of shape {shape}, chunk_size={chunk_size}")
            param_groups.append(dict(params=group_params, chunk_size=chunk_size))
        super().__init__(param_groups, defaults)
        # 0-D CPU tensors to avoid torch.compile recompilation when values change
        self._momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")

    @torch.no_grad()
    def step(self):
        """One distributed step: reduce_scatter grads, update owned chunks, all_gather params."""
        rank = dist.get_rank()
        world_size = dist.get_world_size()

        # Ensure all grads exist
        assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"

        # First pass: stack grads and kick off reduce_scatter for each group
        # (async so comms for different groups overlap)
        group_infos = []
        for group in self.param_groups:
            params: list[Tensor] = group["params"]
            chunk_size = group["chunk_size"]
            padded_num_params = chunk_size * world_size
            shape = params[0].shape
            device, dtype = params[0].device, params[0].dtype

            # Stack all gradients into a single tensor (single kernel via torch.stack)
            grad_stack = torch.stack([p.grad for p in params])
            stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device)
            stacked_grads[:len(params)].copy_(grad_stack)
            # Zero-pad if we have fewer params than padded size
            if len(params) < padded_num_params:
                stacked_grads[len(params):].zero_()

            # Output buffer for this rank's chunk
            grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device)

            # Async reduce_scatter on the stacked tensor
            reduce_future = dist.reduce_scatter_tensor(
                grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
            ).get_future()

            group_infos.append(dict(
                grad_chunk=grad_chunk,
                reduce_future=reduce_future,
                stacked_grads=stacked_grads,  # reuse for all_gather output
            ))

        # Second pass: wait for reduce, compute batched updates, kick off all_gather
        all_gather_futures = []
        for group, info in zip(self.param_groups, group_infos):
            info["reduce_future"].wait()

            params = group["params"]
            chunk_size = group["chunk_size"]
            shape = params[0].shape
            device, dtype = params[0].device, params[0].dtype
            grad_chunk = info["grad_chunk"]

            # How many params does this rank actually own? (last ranks may own fewer)
            start_idx = rank * chunk_size
            num_owned = min(chunk_size, max(0, len(params) - start_idx))

            # Get or create group-level state (stored keyed by first param)
            state = self.state[params[0]]

            # Momentum buffer (only for this rank's chunk)
            if "momentum_buffer" not in state:
                state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
            momentum_buffer = state["momentum_buffer"]

            # Second momentum buffer is factored, either per-row or per-column
            if "second_momentum_buffer" not in state:
                if shape[-2] >= shape[-1]:
                    state["second_momentum_buffer"] = torch.zeros(chunk_size, shape[-2], 1, dtype=dtype, device=device)
                else:
                    state["second_momentum_buffer"] = torch.zeros(chunk_size, 1, shape[-1], dtype=dtype, device=device)
            second_momentum_buffer = state["second_momentum_buffer"]
            red_dim = -1 if shape[-2] >= shape[-1] else -2

            # Build updated_params tensor for all_gather
            updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device)

            if num_owned > 0:
                # Stack owned params (single kernel via torch.stack)
                owned_params = [params[start_idx + i] for i in range(num_owned)]
                stacked_owned_params = torch.stack(owned_params)

                # Get owned slices of buffers and grads
                owned_grads = grad_chunk[:num_owned]
                owned_momentum = momentum_buffer[:num_owned]
                owned_second_momentum = second_momentum_buffer[:num_owned]

                # Fill 0-D tensors with current values (lr scaled by matrix aspect ratio)
                self._momentum_t.fill_(group["momentum"])
                self._beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
                self._lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
                self._wd_t.fill_(group["weight_decay"])

                # Single fused kernel: momentum -> polar_express -> variance_reduction -> update
                muon_step_fused(
                    owned_grads,
                    stacked_owned_params,
                    owned_momentum,
                    owned_second_momentum,
                    self._momentum_t,
                    self._lr_t,
                    self._wd_t,
                    self._beta2_t,
                    group["ns_steps"],
                    red_dim,
                )

                # Copy updated params to output buffer
                updated_params[:num_owned].copy_(stacked_owned_params)

            # Zero-pad the rest (for ranks that own fewer params)
            if num_owned < chunk_size:
                updated_params[num_owned:].zero_()

            # Reuse stacked_grads buffer for all_gather output
            stacked_params = info["stacked_grads"]

            # Async all_gather to replicate updated params to all ranks
            gather_future = dist.all_gather_into_tensor(
                stacked_params, updated_params, async_op=True
            ).get_future()

            all_gather_futures.append(dict(
                gather_future=gather_future,
                stacked_params=stacked_params,
                params=params,
            ))

        # Final pass: wait for all_gather and copy back to params
        for info in all_gather_futures:
            info["gather_future"].wait()
            stacked_params = info["stacked_params"]
            params = info["params"]
            # Batched copy back (single kernel instead of N individual copies)
            torch._foreach_copy_(params, list(stacked_params[:len(params)].unbind(0)))
|
nanochat/prune.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from nanochat.gpt import GPT, GPTConfig
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def head_imp(model):
    """
    Compute an L1-magnitude importance score per attention head, per layer.

    Each head's score sums |W| over its Q slice, its output-projection slice,
    and (when every head has its own KV, i.e. n_kv_head == n_head) its K and V
    slices. Returns {layer_idx: 1-D tensor of length n_head}.
    """
    n_embd = model.config.n_embd
    head_dim = n_embd // model.config.n_head
    device = next(model.parameters()).device

    scores_by_layer = {}
    for idx, blk in enumerate(model.transformer.h):
        att = blk.attn
        heads = att.n_head

        # Q contribution: view as (head, head_dim, in) and collapse all but the head axis
        per_head = torch.zeros(heads, device=device)
        per_head = per_head + att.c_q.weight.view(heads, head_dim, n_embd).abs().sum(dim=(1, 2))

        # Output projection contribution: heads live on the middle axis here
        per_head = per_head + att.c_proj.weight.view(n_embd, heads, head_dim).abs().sum(dim=(0, 2))

        # K/V contribute per-head only when there is one KV head per query head
        if att.n_kv_head == heads:
            for kv_proj in (att.c_k, att.c_v):
                per_head = per_head + kv_proj.weight.view(heads, head_dim, n_embd).abs().sum(dim=(1, 2))

        scores_by_layer[idx] = per_head
    return scores_by_layer
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def neuron_imp(model):
    """
    Compute an L1-magnitude importance score per MLP hidden neuron, per layer.

    A neuron's score is the sum of |W| over its incoming row in c_fc plus its
    outgoing column in c_proj. Returns {layer_idx: 1-D tensor of hidden size}.
    """
    def _layer_scores(mlp):
        inbound = mlp.c_fc.weight.abs().sum(dim=1)    # per hidden row
        outbound = mlp.c_proj.weight.abs().sum(dim=0)  # per hidden column
        return inbound + outbound

    return {idx: _layer_scores(blk.mlp) for idx, blk in enumerate(model.transformer.h)}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def select_heads(head_importance, prune_ratio):
    """
    Pick, per layer, the indices of the most important heads to keep.

    Keeps int(n_head * (1 - prune_ratio)) heads (at least one), highest score
    first, and returns {layer_idx: sorted list of kept head indices}.
    """
    kept = {}
    for layer, scores in head_importance.items():
        keep_count = max(1, int(len(scores) * (1 - prune_ratio)))
        top = torch.topk(scores, keep_count, largest=True).indices
        kept[layer] = sorted(top.tolist())
    return kept
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def select_neurons(neuron_importance, prune_ratio):
    """
    Pick, per layer, the indices of the most important MLP neurons to keep.

    Keeps int(n_neurons * (1 - prune_ratio)) neurons (at least one), highest
    score first; returns {layer_idx: sorted list of kept neuron indices}.
    """
    def _keep(scores):
        count = max(1, int(len(scores) * (1 - prune_ratio)))
        chosen = torch.topk(scores, count, largest=True).indices
        return chosen.sort().values.tolist()

    return {layer: _keep(scores) for layer, scores in neuron_importance.items()}
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def make_pruned_config(original_config, heads_to_keep, neurons_to_keep):
    """
    Derive a GPTConfig for the pruned model.

    Every layer must end up the same width, so the head/neuron counts are the
    minimum kept across layers. The embedding width shrinks proportionally
    (kept_heads * head_dim); all other fields carry over unchanged.
    Returns (config, kept_heads, kept_neurons).
    """
    kept_heads = min(len(indices) for indices in heads_to_keep.values())
    kept_neurons = min(len(indices) for indices in neurons_to_keep.values())
    head_dim = original_config.n_embd // original_config.n_head

    new_config = GPTConfig(
        sequence_len=original_config.sequence_len,
        vocab_size=original_config.vocab_size,
        n_layer=original_config.n_layer,
        n_head=kept_heads,
        n_kv_head=1 if original_config.use_mqa else kept_heads,
        n_embd=kept_heads * head_dim,
        window_pattern=original_config.window_pattern,
        use_mqa=original_config.use_mqa,
        multi_token_n=original_config.multi_token_n,
        draft_n=original_config.draft_n,
        draft_hidden_mult=original_config.draft_hidden_mult,
    )

    return new_config, kept_heads, kept_neurons
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def prune_weights(model, heads_to_keep, neurons_to_keep, pruned_config, min_heads, min_neurons):
    """
    Materialize a smaller GPT and copy the surviving weights from `model` into it.

    Embedding channels are truncated to the first pruned_config.n_embd columns;
    attention heads and MLP neurons are gathered via the per-layer index lists
    (each truncated to the global minimum so all layers share one shape).

    NOTE(review): channel truncation (rather than importance-based selection) is a
    crude heuristic; the pruned model needs finetuning to recover quality.
    """
    head_dim = model.config.n_embd // model.config.n_head
    original_n_embd = model.config.n_embd
    pruned_n_embd = pruned_config.n_embd

    # Build the pruned skeleton on the meta device, then materialize real storage
    with torch.device("meta"):
        pruned_model = GPT(pruned_config)
    device = next(model.parameters()).device
    pruned_model.to_empty(device=device)

    # Token embedding: keep the first pruned_n_embd channels
    pruned_model.transformer.wte.weight.data.copy_(model.transformer.wte.weight.data[:, :pruned_n_embd])

    # Per-layer scalar mixing coefficients copy over verbatim
    pruned_model.resid_lambdas.data.copy_(model.resid_lambdas.data)
    pruned_model.x0_lambdas.data.copy_(model.x0_lambdas.data)

    # Value embeddings: truncate channels where the pruned model is narrower
    for key in model.value_embeds.keys():
        if key in pruned_model.value_embeds:
            orig_ve = model.value_embeds[key].weight.data
            pruned_ve = pruned_model.value_embeds[key].weight.data
            if orig_ve.size(1) > pruned_ve.size(1):
                pruned_model.value_embeds[key].weight.data.copy_(orig_ve[:, :pruned_ve.size(1)])
            else:
                pruned_model.value_embeds[key].weight.data.copy_(orig_ve)

    # Auxiliary heads: only the input channel dim shrinks
    for key in model.multi_token_heads.keys():
        if key in pruned_model.multi_token_heads:
            orig_weight = model.multi_token_heads[key].weight.data
            pruned_model.multi_token_heads[key].weight.data.copy_(orig_weight[:, :pruned_n_embd])

    if model.draft_head is not None and pruned_model.draft_head is not None:
        pruned_model.draft_head.fc1.weight.data.copy_(model.draft_head.fc1.weight.data[:, :pruned_n_embd])
        pruned_model.draft_head.fc2.weight.data.copy_(model.draft_head.fc2.weight.data)

    pruned_model.lm_head.weight.data.copy_(model.lm_head.weight.data[:, :pruned_n_embd])

    for layer_idx in range(model.config.n_layer):
        orig_block = model.transformer.h[layer_idx]
        pruned_block = pruned_model.transformer.h[layer_idx]

        layer_heads = heads_to_keep[layer_idx]
        layer_neurons = neurons_to_keep[layer_idx]

        attn_orig = orig_block.attn
        attn_pruned = pruned_block.attn

        # Q projection: gather kept heads, THEN slice input channels.
        # BUG FIX: the gathered rows still have original_n_embd input channels, so the
        # view must use the original width followed by a column slice. The previous
        # view(min_heads * head_dim, pruned_n_embd) raised a RuntimeError (element-count
        # mismatch) whenever any heads were actually pruned.
        q_orig = attn_orig.c_q.weight.view(model.config.n_head, head_dim, original_n_embd)
        q_pruned = q_orig[layer_heads[:min_heads]].contiguous().view(min_heads * head_dim, original_n_embd)[:, :pruned_n_embd]
        attn_pruned.c_q.weight.data.copy_(q_pruned)

        if attn_orig.n_kv_head == model.config.n_head:
            # Per-head KV: gather kept heads (same fix as Q)
            k_orig = attn_orig.c_k.weight.view(model.config.n_head, head_dim, original_n_embd)
            k_pruned = k_orig[layer_heads[:min_heads]].contiguous().view(min_heads * head_dim, original_n_embd)[:, :pruned_n_embd]
            attn_pruned.c_k.weight.data.copy_(k_pruned)
        else:
            # MQA: single shared KV head; only the input channel dim shrinks
            k_pruned = attn_orig.c_k.weight[:, :pruned_n_embd]
            attn_pruned.c_k.weight.data.copy_(k_pruned)

        if attn_orig.n_kv_head == model.config.n_head:
            v_orig = attn_orig.c_v.weight.view(model.config.n_head, head_dim, original_n_embd)
            v_pruned = v_orig[layer_heads[:min_heads]].contiguous().view(min_heads * head_dim, original_n_embd)[:, :pruned_n_embd]
            attn_pruned.c_v.weight.data.copy_(v_pruned)
        else:
            v_pruned = attn_orig.c_v.weight[:, :pruned_n_embd]
            attn_pruned.c_v.weight.data.copy_(v_pruned)

        # Output projection: gather kept heads on its input side, truncate output rows
        proj_orig = attn_orig.c_proj.weight.view(original_n_embd, model.config.n_head, head_dim)
        proj_pruned = proj_orig[:, layer_heads[:min_heads], :].contiguous().view(original_n_embd, min_heads * head_dim)
        proj_pruned = proj_pruned[:pruned_n_embd, :]
        attn_pruned.c_proj.weight.data.copy_(proj_pruned)

        # Value-embedding gate: per-head when KV is per-head, otherwise copied as-is
        if attn_orig.ve_gate is not None and attn_pruned.ve_gate is not None:
            if attn_orig.n_kv_head == model.config.n_head:
                gate_orig = attn_orig.ve_gate.weight.view(model.config.n_head, -1)
                gate_pruned = gate_orig[layer_heads[:min_heads]]
                attn_pruned.ve_gate.weight.data.copy_(gate_pruned.view(min_heads, -1))
            else:
                attn_pruned.ve_gate.weight.data.copy_(attn_orig.ve_gate.weight.data)

        # MLP: gather kept neurons; truncate embedding channels.
        # NOTE(review): copy_ requires min_neurons to equal the pruned model's MLP
        # hidden size (4 * pruned_n_embd in a standard GPT) -- confirm the neuron
        # prune ratio is chosen consistently with the head prune ratio.
        mlp_orig = orig_block.mlp
        mlp_pruned = pruned_block.mlp

        fc_pruned = mlp_orig.c_fc.weight[layer_neurons[:min_neurons]][:, :pruned_n_embd]
        mlp_pruned.c_fc.weight.data.copy_(fc_pruned)

        proj_pruned = mlp_orig.c_proj.weight[:, layer_neurons[:min_neurons]][:pruned_n_embd, :]
        mlp_pruned.c_proj.weight.data.copy_(proj_pruned)

    # Rotary caches are sized by head_dim, which pruning leaves unchanged
    pruned_model.cos.copy_(model.cos)
    pruned_model.sin.copy_(model.sin)

    return pruned_model
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def prune_model(model, head_prune_ratio=0.2, neuron_prune_ratio=0.2):
    """
    One-shot structured pruning driver.

    Scores heads and neurons by L1 magnitude, selects the survivors, derives a
    smaller config, and copies the surviving weights into a fresh model.
    Returns (pruned_model, pruned_config).
    """
    keep_heads = select_heads(head_imp(model), head_prune_ratio)
    keep_neurons = select_neurons(neuron_imp(model), neuron_prune_ratio)

    new_config, heads_kept, neurons_kept = make_pruned_config(
        model.config, keep_heads, keep_neurons
    )

    smaller = prune_weights(
        model, keep_heads, keep_neurons, new_config, heads_kept, neurons_kept
    )

    return smaller, new_config
|
nanochat/quantize.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from nanochat.gpt import GPT, GPTConfig
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def quantize_tensor(weight, bits=8):
    """
    Symmetric per-tensor quantization to signed `bits`-bit integers.

    The scale is chosen so the largest |weight| maps to the top positive code
    (clamped to >= 1e-8 to survive all-zero tensors).
    Returns (int8 code tensor, python-float scale).
    """
    level_max = (1 << (bits - 1)) - 1
    level_min = -(1 << (bits - 1))

    step = (weight.abs().max() / level_max).clamp(min=1e-8)
    codes = (weight / step).round().clamp(level_min, level_max)

    return codes.to(torch.int8), step.item()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def dequantize_tensor(quantized, scale):
    """Map integer codes back to float32 values: value = code * scale."""
    restored = quantized.float()
    return restored * scale
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def quantize_linear(linear_layer, bits=8):
    """
    Quantize just the weight matrix of a linear layer.

    Thin wrapper over quantize_tensor; the bias (if any) is left untouched.
    Returns (int8 codes, scale).
    """
    return quantize_tensor(linear_layer.weight.data, bits)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def quantize_model(model, bits=8):
    """
    Quantize every trainable >=2-D parameter of `model`; pass everything else through.

    1-D params (biases, norms), frozen params, and buffers keep their original
    tensors. Returns (state dict of int8/original tensors, {name: scale}).
    """
    packed = {}
    scale_map = {}

    for pname, p in model.named_parameters():
        if p.requires_grad and len(p.shape) >= 2:
            packed[pname], scale_map[pname] = quantize_tensor(p.data, bits)
        else:
            packed[pname] = p.data

    for bname, b in model.named_buffers():
        packed[bname] = b.data

    return packed, scale_map
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def apply_quantization(model, scales, bits=8):
    """
    Dequantize in place: rewrite each >=2-D parameter whose name has a scale.

    Assumes param.data currently holds the quantized codes; replaces them with
    float values (code * scale). `bits` is accepted for interface symmetry but
    not needed for dequantization.
    """
    for name, param in model.named_parameters():
        if name not in scales or len(param.shape) < 2:
            continue
        param.data = dequantize_tensor(param.data, scales[name])
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def export_int8(model, output_path, bits=8):
    """
    Quantize `model` and serialize everything needed to rebuild it.

    Saves (via torch.save) the quantized state dict, the per-tensor scales, the
    GPTConfig fields, and the bit width; returns the same payload dict.
    """
    weights, scale_map = quantize_model(model, bits)
    cfg = model.config

    payload = {
        'quantized_weights': weights,
        'scales': scale_map,
        'config': {
            'n_layer': cfg.n_layer,
            'n_head': cfg.n_head,
            'n_kv_head': cfg.n_kv_head,
            'n_embd': cfg.n_embd,
            'vocab_size': cfg.vocab_size,
            'sequence_len': cfg.sequence_len,
            'window_pattern': cfg.window_pattern,
            'use_mqa': cfg.use_mqa,
            'multi_token_n': cfg.multi_token_n,
            'draft_n': cfg.draft_n,
            'draft_hidden_mult': cfg.draft_hidden_mult,
        },
        'bits': bits,
    }

    torch.save(payload, output_path)
    return payload
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def load_int8(model_path, device):
    """
    Load an int8 export produced by export_int8 and rebuild a float model.

    Dequantizes every parameter that has a recorded scale, passes other saved
    parameters/buffers through unchanged, and returns
    (model in eval mode, config dict, scales dict).
    """
    # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted checkpoints
    data = torch.load(model_path, map_location=device)

    config_kwargs = data['config']
    config = GPTConfig(**config_kwargs)

    # Build the skeleton on the meta device, then materialize storage on `device`
    with torch.device("meta"):
        model = GPT(config)

    model.to_empty(device=device)
    # init_weights fills anything the checkpoint does not cover (load below is strict=False)
    model.init_weights()

    quantized_state = data['quantized_weights']
    scales = data['scales']

    state_dict = {}
    for name, param in model.named_parameters():
        if name in quantized_state:
            if name in scales:
                # Quantized entry: restore floats as code * scale
                quantized = quantized_state[name]
                scale = scales[name]
                state_dict[name] = dequantize_tensor(quantized, scale)
            else:
                # Saved unquantized (1-D / frozen params): use as-is
                state_dict[name] = quantized_state[name]

    for name, buffer in model.named_buffers():
        if name in quantized_state:
            state_dict[name] = quantized_state[name]

    model.load_state_dict(state_dict, strict=False)
    model.eval()

    return model, data['config'], scales
|
| 113 |
+
|
nanochat/report.py
ADDED
|
@@ -0,0 +1,422 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utilities for generating training report cards. More messy code than usual, will fix.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import re
|
| 7 |
+
import shutil
|
| 8 |
+
import subprocess
|
| 9 |
+
import socket
|
| 10 |
+
import datetime
|
| 11 |
+
import platform
|
| 12 |
+
import psutil
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
def run_command(cmd):
    """Run a shell command and return its stripped stdout.

    Returns:
        - the stripped stdout if the command produced any output (even when
          the exit code is non-zero, e.g. some files in an xargs batch failed),
        - "" if the command succeeded with no output,
        - None if the command failed with no output, timed out, or could not run.
    """
    try:
        result = subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=5)
    # Catch only subprocess/OS failures. A bare `except:` would also swallow
    # KeyboardInterrupt and SystemExit, making the program hard to interrupt.
    except (subprocess.SubprocessError, OSError):
        return None
    stdout = result.stdout.strip()
    if stdout:
        return stdout
    if result.returncode == 0:
        return ""
    return None
|
| 27 |
+
|
| 28 |
+
def get_git_info():
    """Collect the current commit hash, branch, dirty flag, and last commit message."""
    commit = run_command("git rev-parse --short HEAD") or "unknown"
    branch = run_command("git rev-parse --abbrev-ref HEAD") or "unknown"
    # Any output from `git status --porcelain` means there are uncommitted changes.
    porcelain = run_command("git status --porcelain")
    dirty = bool(porcelain) if porcelain is not None else False
    # Keep only the first line of the latest commit message, capped at 80 chars.
    message = (run_command("git log -1 --pretty=%B") or "").split('\n')[0][:80]
    return {
        'commit': commit,
        'branch': branch,
        'dirty': dirty,
        'message': message,
    }
|
| 43 |
+
|
| 44 |
+
def get_gpu_info():
    """Report CUDA availability, device count, per-device names/VRAM, and CUDA version."""
    if not torch.cuda.is_available():
        return {"available": False}
    n = torch.cuda.device_count()
    props = [torch.cuda.get_device_properties(i) for i in range(n)]
    return {
        "available": True,
        "count": n,
        "names": [p.name for p in props],
        "memory_gb": [p.total_memory / (1024**3) for p in props],
        "cuda_version": torch.version.cuda or "unknown",
    }
|
| 66 |
+
|
| 67 |
+
def get_system_info():
    """Gather host name, platform, Python/PyTorch versions, CPU/RAM stats, and env details."""
    return {
        # basic system identification
        'hostname': socket.gethostname(),
        'platform': platform.system(),
        'python_version': platform.python_version(),
        'torch_version': torch.__version__,
        # physical vs. logical core counts
        'cpu_count': psutil.cpu_count(logical=False),
        'cpu_count_logical': psutil.cpu_count(logical=True),
        'memory_gb': psutil.virtual_memory().total / (1024**3),
        # user and environment
        'user': os.environ.get('USER', 'unknown'),
        'nanochat_base_dir': os.environ.get('NANOCHAT_BASE_DIR', 'out'),
        'working_dir': os.getcwd(),
    }
|
| 88 |
+
|
| 89 |
+
def estimate_cost(gpu_info, runtime_hours=None):
    """Rough dollar-cost estimate from GPU type and an optional runtime in hours.

    Returns None when no GPU is available; otherwise a dict with the combined
    hourly rate across all GPUs, the detected GPU name, and (when runtime_hours
    is given) the estimated total cost.
    """
    if not gpu_info.get("available"):
        return None

    # Approximate per-GPU hourly pricing, from Lambda Cloud.
    known_rates = {
        "H100": 3.00,
        "A100": 1.79,
        "V100": 0.55,
    }
    fallback_rate = 2.0

    # Identify the GPU family by substring match on the first device's name;
    # fall back to a generic rate if the family is unrecognized.
    gpu_name = gpu_info["names"][0] if gpu_info["names"] else "unknown"
    per_gpu_rate = next(
        (rate for family, rate in known_rates.items() if family in gpu_name),
        fallback_rate,
    )
    hourly_rate = per_gpu_rate * gpu_info["count"]

    return {
        "hourly_rate": hourly_rate,
        "gpu_type": gpu_name,
        "estimated_total": hourly_rate * runtime_hours if runtime_hours else None,
    }
|
| 119 |
+
|
| 120 |
+
def generate_header():
    """Generate the markdown header for a training report.

    The header includes: git state, hardware/software environment, an
    estimated hourly GPU cost, and "bloat" metrics (size of the git-tracked
    source tree and dependency count). Returns the header as one string.
    """
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # Gather all environment facts up front.
    git_info = get_git_info()
    gpu_info = get_gpu_info()
    sys_info = get_system_info()
    cost_info = estimate_cost(gpu_info)  # no runtime yet, so only hourly rate

    header = f"""# nanochat training report

Generated: {timestamp}

## Environment

### Git Information
- Branch: {git_info['branch']}
- Commit: {git_info['commit']} {"(dirty)" if git_info['dirty'] else "(clean)"}
- Message: {git_info['message']}

### Hardware
- Platform: {sys_info['platform']}
- CPUs: {sys_info['cpu_count']} cores ({sys_info['cpu_count_logical']} logical)
- Memory: {sys_info['memory_gb']:.1f} GB
"""

    if gpu_info.get("available"):
        # de-duplicate device names in case of a homogeneous multi-GPU node
        gpu_names = ", ".join(set(gpu_info["names"]))
        total_vram = sum(gpu_info["memory_gb"])
        header += f"""- GPUs: {gpu_info['count']}x {gpu_names}
- GPU Memory: {total_vram:.1f} GB total
- CUDA Version: {gpu_info['cuda_version']}
"""
    else:
        header += "- GPUs: None available\n"

    if cost_info and cost_info["hourly_rate"] > 0:
        header += f"""- Hourly Rate: ${cost_info['hourly_rate']:.2f}/hour\n"""

    header += f"""
### Software
- Python: {sys_info['python_version']}
- PyTorch: {sys_info['torch_version']}

"""

    # bloat metrics: count lines/chars in git-tracked source files only
    extensions = ['py', 'md', 'rs', 'html', 'toml', 'sh']
    git_patterns = ' '.join(f"'*.{ext}'" for ext in extensions)
    files_output = run_command(f"git ls-files -- {git_patterns}")
    file_list = [f for f in (files_output or '').split('\n') if f]
    num_files = len(file_list)
    num_lines = 0
    num_chars = 0
    if num_files > 0:
        # `wc -lc` over all tracked files; the last line is the grand total
        wc_output = run_command(f"git ls-files -- {git_patterns} | xargs wc -lc 2>/dev/null")
        if wc_output:
            total_line = wc_output.strip().split('\n')[-1]
            parts = total_line.split()
            if len(parts) >= 2:
                num_lines = int(parts[0])
                num_chars = int(parts[1])
    num_tokens = num_chars // 4 # assume approximately 4 chars per token

    # count dependencies via uv.lock
    uv_lock_lines = 0
    if os.path.exists('uv.lock'):
        with open('uv.lock', 'r', encoding='utf-8') as f:
            uv_lock_lines = len(f.readlines())

    # NOTE: Report.generate() later re-extracts this "### Bloat" section by
    # regex, so the section title and trailing blank line matter.
    header += f"""
### Bloat
- Characters: {num_chars:,}
- Lines: {num_lines:,}
- Files: {num_files:,}
- Tokens (approx): {num_tokens:,}
- Dependencies (uv.lock lines): {uv_lock_lines:,}

"""
    return header
|
| 200 |
+
|
| 201 |
+
# -----------------------------------------------------------------------------
|
| 202 |
+
|
| 203 |
+
def slugify(text):
    """Lowercase *text* and turn spaces into hyphens (used for section file names)."""
    lowered = text.lower()
    return "-".join(lowered.split(" "))
|
| 206 |
+
|
| 207 |
+
# the expected files and their order
# Each entry is one report section written by Report.log() during the pipeline;
# Report.generate() concatenates them in exactly this order.
EXPECTED_FILES = [
    "tokenizer-training.md",
    "tokenizer-evaluation.md",
    "base-model-training.md",
    "base-model-loss.md",
    "base-model-evaluation.md",
    "midtraining.md",
    "chat-evaluation-mid.md",
    "chat-sft.md",
    "chat-evaluation-sft.md",
    "chat-rl.md",
    "chat-evaluation-rl.md",
]
# the metrics we're currently interested in
# (pulled out of the chat evaluation sections into the final summary table)
chat_metrics = ["ARC-Easy", "ARC-Challenge", "MMLU", "GSM8K", "HumanEval", "ChatCORE"]
|
| 223 |
+
|
| 224 |
+
def extract(section, keys):
    """Extract "key: value" metric lines from a markdown report section.

    Args:
        section: markdown text, one metric per line (e.g. "- MMLU: 0.31").
        keys: a single key string or a list of key strings; a line matches a
            key by substring containment.

    Returns:
        dict mapping each found key to its value string (text after the first
        colon, stripped). Later matching lines overwrite earlier ones.
    """
    if not isinstance(keys, list):
        keys = [keys]  # convenience: accept a single key string
    out = {}
    for line in section.split("\n"):
        for key in keys:
            if key in line:
                # Split on the FIRST colon only, so values containing ":" are
                # kept intact; the original split(":")[1] also raised
                # IndexError on a line that mentions the key without a colon.
                _, sep, value = line.partition(":")
                if sep:
                    out[key] = value.strip()
    return out
|
| 234 |
+
|
| 235 |
+
def extract_timestamp(content, prefix):
    """Return the first timestamp found on a line starting with *prefix*.

    Expects lines of the form "<prefix> YYYY-mm-dd HH:MM:SS", where the
    timestamp follows the first colon on the line. Returns a datetime, or
    None if no line matches or the timestamp fails to parse.
    """
    for line in content.split('\n'):
        if line.startswith(prefix):
            time_str = line.split(":", 1)[1].strip()
            try:
                return datetime.datetime.strptime(time_str, "%Y-%m-%d %H:%M:%S")
            # Only swallow parse failures; the original bare `except:` also
            # masked KeyboardInterrupt/SystemExit and genuine bugs.
            except ValueError:
                pass  # malformed timestamp on this line; keep scanning
    return None
|
| 245 |
+
|
| 246 |
+
class Report:
    """Maintains a bunch of per-stage log files, generates a final markdown report.

    Each pipeline stage calls log() to write one section file under
    report_dir; generate() stitches the header plus all EXPECTED_FILES into
    report.md with a summary metrics table; reset() wipes the sections and
    writes a fresh header with the run start time.
    """

    def __init__(self, report_dir):
        # all section files, header.md and report.md live under report_dir
        os.makedirs(report_dir, exist_ok=True)
        self.report_dir = report_dir

    def log(self, section, data):
        """Log a section of data to the report.

        section: human-readable title; slugified into the section file name.
        data: iterable of items — strings are written verbatim, dicts are
        rendered as "- key: value" bullets, falsy items are skipped.
        Returns the path of the section file written.
        """
        slug = slugify(section)
        file_name = f"{slug}.md"
        file_path = os.path.join(self.report_dir, file_name)
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(f"## {section}\n")
            # timestamp line is parsed later by generate() via extract_timestamp()
            f.write(f"timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
            for item in data:
                if not item:
                    # skip falsy values like None or empty dict etc.
                    continue
                if isinstance(item, str):
                    # directly write the string
                    f.write(item)
                else:
                    # render a dict
                    for k, v in item.items():
                        if isinstance(v, float):
                            vstr = f"{v:.4f}"
                        elif isinstance(v, int) and v >= 10000:
                            # thousands separators for large integers
                            vstr = f"{v:,.0f}"
                        else:
                            vstr = str(v)
                        f.write(f"- {k}: {vstr}\n")
                f.write("\n")
        return file_path

    def generate(self):
        """Generate the final report.md by concatenating header + all sections.

        Also builds a per-stage summary table of key metrics and the total
        wall clock time (header start timestamp -> last non-RL section
        timestamp). Returns the path of the generated report file.
        """
        report_dir = self.report_dir
        report_file = os.path.join(report_dir, "report.md")
        print(f"Generating report to {report_file}")
        final_metrics = {} # the most important final metrics we'll add as table at the end
        start_time = None
        end_time = None
        with open(report_file, "w", encoding="utf-8") as out_file:
            # write the header first
            header_file = os.path.join(report_dir, "header.md")
            if os.path.exists(header_file):
                with open(header_file, "r", encoding="utf-8") as f:
                    header_content = f.read()
                out_file.write(header_content)
                start_time = extract_timestamp(header_content, "Run started:")
                # capture bloat data for summary later (the stuff after Bloat header and until \n\n)
                bloat_data = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL)
                bloat_data = bloat_data.group(1) if bloat_data else ""
            else:
                start_time = None # will cause us to not write the total wall clock time
                bloat_data = "[bloat data missing]"
                print(f"Warning: {header_file} does not exist. Did you forget to run `nanochat reset`?")
            # process all the individual sections
            for file_name in EXPECTED_FILES:
                section_file = os.path.join(report_dir, file_name)
                if not os.path.exists(section_file):
                    print(f"Warning: {section_file} does not exist, skipping")
                    continue
                with open(section_file, "r", encoding="utf-8") as in_file:
                    section = in_file.read()
                # Extract timestamp from this section (the last section's timestamp will "stick" as end_time)
                if "rl" not in file_name:
                    # Skip RL sections for end_time calculation because RL is experimental
                    end_time = extract_timestamp(section, "timestamp:")
                # extract the most important metrics from the sections
                if file_name == "base-model-evaluation.md":
                    final_metrics["base"] = extract(section, "CORE")
                if file_name == "chat-evaluation-mid.md":
                    final_metrics["mid"] = extract(section, chat_metrics)
                if file_name == "chat-evaluation-sft.md":
                    final_metrics["sft"] = extract(section, chat_metrics)
                if file_name == "chat-evaluation-rl.md":
                    final_metrics["rl"] = extract(section, "GSM8K") # RL only evals GSM8K
                # append this section of the report
                out_file.write(section)
                out_file.write("\n")
            # add the final metrics table
            out_file.write("## Summary\n\n")
            # Copy over the bloat metrics from the header
            out_file.write(bloat_data)
            out_file.write("\n\n")
            # Collect all unique metric names
            all_metrics = set()
            for stage_metrics in final_metrics.values():
                all_metrics.update(stage_metrics.keys())
            # Custom ordering: CORE first, ChatCORE last, rest in middle
            all_metrics = sorted(all_metrics, key=lambda x: (x != "CORE", x == "ChatCORE", x))
            # Fixed column widths
            stages = ["base", "mid", "sft", "rl"]
            metric_width = 15
            value_width = 8
            # Write table header
            header = f"| {'Metric'.ljust(metric_width)} |"
            for stage in stages:
                header += f" {stage.upper().ljust(value_width)} |"
            out_file.write(header + "\n")
            # Write separator
            separator = f"|{'-' * (metric_width + 2)}|"
            for stage in stages:
                separator += f"{'-' * (value_width + 2)}|"
            out_file.write(separator + "\n")
            # Write table rows
            for metric in all_metrics:
                row = f"| {metric.ljust(metric_width)} |"
                for stage in stages:
                    # "-" marks a metric that was not evaluated at this stage
                    value = final_metrics.get(stage, {}).get(metric, "-")
                    row += f" {str(value).ljust(value_width)} |"
                out_file.write(row + "\n")
            out_file.write("\n")
            # Calculate and write total wall clock time
            if start_time and end_time:
                duration = end_time - start_time
                total_seconds = int(duration.total_seconds())
                hours = total_seconds // 3600
                minutes = (total_seconds % 3600) // 60
                out_file.write(f"Total wall clock time: {hours}h{minutes}m\n")
            else:
                out_file.write("Total wall clock time: unknown\n")
        # also cp the report.md file to current directory
        print(f"Copying report.md to current directory for convenience")
        shutil.copy(report_file, "report.md")
        return report_file

    def reset(self):
        """Reset the report: delete all section files and write a fresh header."""
        # Remove section files
        for file_name in EXPECTED_FILES:
            file_path = os.path.join(self.report_dir, file_name)
            if os.path.exists(file_path):
                os.remove(file_path)
        # Remove report.md if it exists
        report_file = os.path.join(self.report_dir, "report.md")
        if os.path.exists(report_file):
            os.remove(report_file)
        # Generate and write the header section with start timestamp
        # ("Run started:" is parsed back by generate() for wall-clock time)
        header_file = os.path.join(self.report_dir, "header.md")
        header = generate_header()
        start_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        with open(header_file, "w", encoding="utf-8") as f:
            f.write(header)
            f.write(f"Run started: {start_time}\n\n---\n\n")
        print(f"Reset report and wrote header to {header_file}")
|
| 394 |
+
|
| 395 |
+
# -----------------------------------------------------------------------------
|
| 396 |
+
# nanochat-specific convenience functions
|
| 397 |
+
|
| 398 |
+
class DummyReport:
    """No-op stand-in for Report, returned on non-zero ranks so that only
    rank 0 writes report files. Note it only stubs log() and reset();
    generate() is expected to be called on rank 0 only."""
    def log(self, *args, **kwargs):
        pass
    def reset(self, *args, **kwargs):
        pass
|
| 403 |
+
|
| 404 |
+
def get_report():
    """Return a real Report on rank 0 and a no-op DummyReport on all other ranks."""
    # imported lazily to avoid a circular import at module load time
    from nanochat.common import get_base_dir, get_dist_info
    _, rank, _, _ = get_dist_info()
    if rank != 0:
        return DummyReport()
    return Report(os.path.join(get_base_dir(), "report"))
|
| 413 |
+
|
| 414 |
+
# CLI entry point: `python -m nanochat.report [generate|reset]`
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Generate or reset nanochat training reports.")
    # positional command is optional and defaults to "generate"
    parser.add_argument("command", nargs="?", default="generate", choices=["generate", "reset"], help="Operation to perform (default: generate)")
    args = parser.parse_args()
    if args.command == "generate":
        get_report().generate()
    elif args.command == "reset":
        get_report().reset()
|
nanochat/tokenizer.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
BPE Tokenizer in the style of GPT-4.
|
| 3 |
+
|
| 4 |
+
Two implementations are available:
|
| 5 |
+
1) HuggingFace Tokenizer that can do both training and inference but is really confusing
|
| 6 |
+
2) Our own RustBPE Tokenizer for training and tiktoken for efficient inference
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import copy
|
| 11 |
+
from functools import lru_cache
|
| 12 |
+
|
| 13 |
+
# Special tokens shared by both tokenizer implementations. Their position in
# this list determines their token ids (appended after the BPE vocab).
SPECIAL_TOKENS = [
    # every document begins with the Beginning of Sequence (BOS) token that delimits documents
    "<|bos|>",
    # tokens below are only used during finetuning to render Conversations into token ids
    "<|user_start|>", # user messages
    "<|user_end|>",
    "<|assistant_start|>", # assistant messages
    "<|assistant_end|>",
    "<|python_start|>", # assistant invokes python REPL tool
    "<|python_end|>",
    "<|output_start|>", # python REPL outputs back to assistant
    "<|output_end|>",
]

# NOTE: this split pattern deviates from GPT-4 in that we use \p{N}{1,2} instead of \p{N}{1,3}
# I did this because I didn't want to "waste" too many tokens on numbers for smaller vocab sizes.
# I verified that 2 is the sweet spot for vocab size of 32K. 1 is a bit worse, 3 was worse still.
# (Uses \p{...} Unicode property classes, so it is consumed by regex engines
# that support them — e.g. tiktoken / HF tokenizers — not Python's stdlib `re`.)
SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
|
| 31 |
+
|
| 32 |
+
# -----------------------------------------------------------------------------
|
| 33 |
+
# Generic GPT-4-style tokenizer based on HuggingFace Tokenizer
|
| 34 |
+
from tokenizers import Tokenizer as HFTokenizer
|
| 35 |
+
from tokenizers import pre_tokenizers, decoders, Regex
|
| 36 |
+
from tokenizers.models import BPE
|
| 37 |
+
from tokenizers.trainers import BpeTrainer
|
| 38 |
+
|
| 39 |
+
class HuggingFaceTokenizer:
    """Light wrapper around HuggingFace Tokenizer for some utilities.

    Supports loading a pretrained tokenizer, loading from a local directory,
    or training a GPT-4-style byte-level BPE tokenizer from a text iterator.
    """

    def __init__(self, tokenizer):
        # tokenizer: a tokenizers.Tokenizer (HFTokenizer) instance
        self.tokenizer = tokenizer

    @classmethod
    def from_pretrained(cls, hf_path):
        # init from a HuggingFace pretrained tokenizer (e.g. "gpt2")
        tokenizer = HFTokenizer.from_pretrained(hf_path)
        return cls(tokenizer)

    @classmethod
    def from_directory(cls, tokenizer_dir):
        # init from a local directory on disk (e.g. "out/tokenizer")
        tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
        tokenizer = HFTokenizer.from_file(tokenizer_path)
        return cls(tokenizer)

    @classmethod
    def train_from_iterator(cls, text_iterator, vocab_size):
        """Train a byte-level BPE tokenizer of size *vocab_size* from an iterator of text."""
        # Configure the HuggingFace Tokenizer
        tokenizer = HFTokenizer(BPE(
            byte_fallback=True, # needed!
            unk_token=None,
            fuse_unk=False,
        ))
        # Normalizer: None
        tokenizer.normalizer = None
        # Pre-tokenizer: GPT-4 style
        # the regex pattern used by GPT-4 to split text into groups before BPE
        # NOTE: The pattern was changed from \p{N}{1,3} to \p{N}{1,2} because I suspect it is harmful to
        # very small models and smaller vocab sizes, because it is a little bit wasteful in the token space.
        # (but I haven't validated this! TODO)
        gpt4_split_regex = Regex(SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!!
        tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
            pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False),
            pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
        ])
        # Decoder: ByteLevel (it pairs together with the ByteLevel pre-tokenizer)
        tokenizer.decoder = decoders.ByteLevel()
        # Post-processor: None
        tokenizer.post_processor = None
        # Trainer: BPE
        trainer = BpeTrainer(
            vocab_size=vocab_size,
            show_progress=True,
            min_frequency=0, # no minimum frequency
            initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
            special_tokens=SPECIAL_TOKENS,
        )
        # Kick off the training
        tokenizer.train_from_iterator(text_iterator, trainer)
        return cls(tokenizer)

    def get_vocab_size(self):
        # total vocab size, including special tokens
        return self.tokenizer.get_vocab_size()

    def get_special_tokens(self):
        # list of the special token strings registered on this tokenizer
        special_tokens_map = self.tokenizer.get_added_tokens_decoder()
        special_tokens = [w.content for w in special_tokens_map.values()]
        return special_tokens

    def id_to_token(self, id):
        # map a single token id back to its string form
        return self.tokenizer.id_to_token(id)

    def _encode_one(self, text, prepend=None, append=None, num_threads=None):
        # encode a single string
        # prepend/append can be either a string of a special token or a token id directly.
        # num_threads is ignored (only used by the nanochat Tokenizer for parallel encoding)
        assert isinstance(text, str)
        ids = []
        if prepend is not None:
            prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend)
            ids.append(prepend_id)
        ids.extend(self.tokenizer.encode(text, add_special_tokens=False).ids)
        if append is not None:
            append_id = append if isinstance(append, int) else self.encode_special(append)
            ids.append(append_id)
        return ids

    def encode_special(self, text):
        # encode a single special token via exact match
        return self.tokenizer.token_to_id(text)

    def get_bos_token_id(self):
        # Different HuggingFace models use different BOS tokens and there is little consistency
        # 1) attempt to find a <|bos|> token
        bos = self.encode_special("<|bos|>")
        # 2) if that fails, attempt to find a <|endoftext|> token (e.g. GPT-2 models)
        if bos is None:
            bos = self.encode_special("<|endoftext|>")
        # 3) if these fail, it's better to crash than to silently return None
        assert bos is not None, "Failed to find BOS token in tokenizer"
        return bos

    def encode(self, text, *args, **kwargs):
        # encode a string -> list[int], or a list of strings -> list[list[int]]
        if isinstance(text, str):
            return self._encode_one(text, *args, **kwargs)
        elif isinstance(text, list):
            return [self._encode_one(t, *args, **kwargs) for t in text]
        else:
            raise ValueError(f"Invalid input type: {type(text)}")

    def __call__(self, *args, **kwargs):
        # convenience alias: tokenizer(text) == tokenizer.encode(text)
        return self.encode(*args, **kwargs)

    def decode(self, ids):
        # decode token ids back to text; special tokens are kept in the output
        return self.tokenizer.decode(ids, skip_special_tokens=False)

    def save(self, tokenizer_dir):
        # save the tokenizer to disk
        os.makedirs(tokenizer_dir, exist_ok=True)
        tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
        self.tokenizer.save(tokenizer_path)
        print(f"Saved tokenizer to {tokenizer_path}")
|
| 156 |
+
|
| 157 |
+
# -----------------------------------------------------------------------------
|
| 158 |
+
# Tokenizer based on rustbpe + tiktoken combo
|
| 159 |
+
import pickle
|
| 160 |
+
import rustbpe
|
| 161 |
+
import tiktoken
|
| 162 |
+
|
| 163 |
+
class RustBPETokenizer:
    """Light wrapper around tiktoken (for efficient inference) but train with rustbpe"""

    def __init__(self, enc, bos_token):
        # enc: a tiktoken.Encoding; bos_token: the string name of the special
        # token that delimits sequences (e.g. "<|bos|>" or "<|endoftext|>")
        self.enc = enc
        self.bos_token_id = self.encode_special(bos_token)

    @classmethod
    def train_from_iterator(cls, text_iterator, vocab_size):
        """Train a fresh BPE vocabulary with rustbpe and wrap it for inference.

        vocab_size is the TOTAL vocabulary size, including the SPECIAL_TOKENS,
        which are appended after the trained vocabulary (they are not trained).
        """
        # 1) train using rustbpe
        tokenizer = rustbpe.Tokenizer()
        # the special tokens are inserted later in __init__, we don't train them here
        vocab_size_no_special = vocab_size - len(SPECIAL_TOKENS)
        assert vocab_size_no_special >= 256, f"vocab_size_no_special must be at least 256, got {vocab_size_no_special}"
        tokenizer.train_from_iterator(text_iterator, vocab_size_no_special, pattern=SPLIT_PATTERN)
        # 2) construct the associated tiktoken encoding for inference
        pattern = tokenizer.get_pattern()
        mergeable_ranks_list = tokenizer.get_mergeable_ranks()
        mergeable_ranks = {bytes(k): v for k, v in mergeable_ranks_list}
        # special token ids start immediately after the trained vocabulary
        tokens_offset = len(mergeable_ranks)
        special_tokens = {name: tokens_offset + i for i, name in enumerate(SPECIAL_TOKENS)}
        enc = tiktoken.Encoding(
            name="rustbpe",
            pat_str=pattern,
            mergeable_ranks=mergeable_ranks, # dict[bytes, int] (token bytes -> merge priority rank)
            special_tokens=special_tokens, # dict[str, int] (special token name -> token id)
        )
        return cls(enc, "<|bos|>")

    @classmethod
    def from_directory(cls, tokenizer_dir):
        """Load a tokenizer previously written by save() from tokenizer_dir."""
        pickle_path = os.path.join(tokenizer_dir, "tokenizer.pkl")
        with open(pickle_path, "rb") as f:
            enc = pickle.load(f)
        return cls(enc, "<|bos|>")

    @classmethod
    def from_pretrained(cls, tiktoken_name):
        """Wrap one of OpenAI's stock tiktoken encodings (e.g. "gpt2")."""
        # https://github.com/openai/tiktoken/blob/eedc8563/tiktoken_ext/openai_public.py
        enc = tiktoken.get_encoding(tiktoken_name)
        # tiktoken calls the special document delimiter token "<|endoftext|>"
        # yes this is confusing because this token is almost always PREPENDED to the beginning of the document
        # it most often is used to signal the start of a new sequence to the LLM during inference etc.
        # so in nanoChat we always use "<|bos|>" short for "beginning of sequence", but historically it is often called "<|endoftext|>".
        return cls(enc, "<|endoftext|>")

    def get_vocab_size(self):
        """Total number of token ids, including the special tokens."""
        return self.enc.n_vocab

    def get_special_tokens(self):
        """Set of special-token strings known to the underlying encoding."""
        return self.enc.special_tokens_set

    def id_to_token(self, id):
        """Decode a single token id back to its string form."""
        return self.enc.decode([id])

    @lru_cache(maxsize=32)
    def encode_special(self, text):
        # Map a special token's string name to its token id.
        # NOTE(review): lru_cache on an instance method keys on self and keeps
        # the instance alive for the cache's lifetime; acceptable here since
        # tokenizers are few and long-lived.
        return self.enc.encode_single_token(text)

    def get_bos_token_id(self):
        """Token id of the beginning-of-sequence special token."""
        return self.bos_token_id

    def encode(self, text, prepend=None, append=None, num_threads=8):
        """Encode a string (or a list of strings, in parallel) to token ids.

        prepend/append may each be a token id (int) or a special token's string
        name; when given, they are attached to every encoded sequence. Special
        tokens appearing inside `text` itself are NOT treated specially
        (encode_ordinary). num_threads only applies to the batch (list) path.
        """
        # text can be either a string or a list of strings

        if prepend is not None:
            prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend)
        if append is not None:
            append_id = append if isinstance(append, int) else self.encode_special(append)

        if isinstance(text, str):
            ids = self.enc.encode_ordinary(text)
            if prepend is not None:
                ids.insert(0, prepend_id) # TODO: slightly inefficient here? :( hmm
            if append is not None:
                ids.append(append_id)
        elif isinstance(text, list):
            ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads)
            if prepend is not None:
                for ids_row in ids:
                    ids_row.insert(0, prepend_id) # TODO: same
            if append is not None:
                for ids_row in ids:
                    ids_row.append(append_id)
        else:
            raise ValueError(f"Invalid input type: {type(text)}")

        return ids

    def __call__(self, *args, **kwargs):
        """Alias for encode()."""
        return self.encode(*args, **kwargs)

    def decode(self, ids):
        """Decode a list of token ids back into a string."""
        return self.enc.decode(ids)

    def save(self, tokenizer_dir):
        """Pickle the underlying tiktoken encoding into tokenizer_dir."""
        # save the encoding object to disk
        os.makedirs(tokenizer_dir, exist_ok=True)
        pickle_path = os.path.join(tokenizer_dir, "tokenizer.pkl")
        with open(pickle_path, "wb") as f:
            pickle.dump(self.enc, f)
        print(f"Saved tokenizer encoding to {pickle_path}")

    def render_conversation(self, conversation, max_tokens=2048):
        """
        Tokenize a single Chat conversation (which we call a "doc" or "document" here).
        Returns:
        - ids: list[int] is a list of token ids of this rendered conversation
        - mask: list[int] of same length, mask = 1 for tokens that the Assistant is expected to train on.
        """
        # ids, masks that we will return and a helper function to help build them up.
        ids, mask = [], []
        def add_tokens(token_ids, mask_val):
            # accept either a single token id or a list of ids
            if isinstance(token_ids, int):
                token_ids = [token_ids]
            ids.extend(token_ids)
            mask.extend([mask_val] * len(token_ids))

        # sometimes the first message is a system message...
        # => just merge it with the second (user) message
        if conversation["messages"][0]["role"] == "system":
            # some conversation surgery is necessary here for now...
            conversation = copy.deepcopy(conversation) # avoid mutating the original
            messages = conversation["messages"]
            assert messages[1]["role"] == "user", "System message must be followed by a user message"
            messages[1]["content"] = messages[0]["content"] + "\n\n" + messages[1]["content"]
            messages = messages[1:]
        else:
            messages = conversation["messages"]
        assert len(messages) >= 1, f"Conversation has less than 1 message: {messages}"

        # fetch all the special tokens we need
        bos = self.get_bos_token_id()
        user_start, user_end = self.encode_special("<|user_start|>"), self.encode_special("<|user_end|>")
        assistant_start, assistant_end = self.encode_special("<|assistant_start|>"), self.encode_special("<|assistant_end|>")
        python_start, python_end = self.encode_special("<|python_start|>"), self.encode_special("<|python_end|>")
        output_start, output_end = self.encode_special("<|output_start|>"), self.encode_special("<|output_end|>")

        # now we can tokenize the conversation
        add_tokens(bos, 0)
        for i, message in enumerate(messages):

            # some sanity checking here around assumptions, to prevent footguns
            # (conversations must strictly alternate user/assistant, user first)
            must_be_from = "user" if i % 2 == 0 else "assistant"
            assert message["role"] == must_be_from, f"Message {i} is from {message['role']} but should be from {must_be_from}"

            # content can be either a simple string or a list of parts (e.g. containing tool calls)
            content = message["content"]

            if message["role"] == "user":
                # user tokens are never supervised (mask=0)
                assert isinstance(content, str), "User messages are simply expected to be strings"
                value_ids = self.encode(content)
                add_tokens(user_start, 0)
                add_tokens(value_ids, 0)
                add_tokens(user_end, 0)
            elif message["role"] == "assistant":
                add_tokens(assistant_start, 0)
                if isinstance(content, str):
                    # simple string => simply add the tokens
                    add_tokens(self.encode(content) if False else self.encode(content), 1) if False else None
                    value_ids = self.encode(content)
                    add_tokens(value_ids, 1)
                elif isinstance(content, list):
                    for part in content:
                        value_ids = self.encode(part["text"])
                        if part["type"] == "text":
                            # string part => simply add the tokens
                            add_tokens(value_ids, 1)
                        elif part["type"] == "python":
                            # python tool call => add the tokens inside <|python_start|> and <|python_end|>
                            add_tokens(python_start, 1)
                            add_tokens(value_ids, 1)
                            add_tokens(python_end, 1)
                        elif part["type"] == "python_output":
                            # python output => add the tokens inside <|output_start|> and <|output_end|>
                            # none of these tokens are supervised because the tokens come from Python at test time
                            add_tokens(output_start, 0)
                            add_tokens(value_ids, 0)
                            add_tokens(output_end, 0)
                        else:
                            raise ValueError(f"Unknown part type: {part['type']}")
                else:
                    raise ValueError(f"Unknown content type: {type(content)}")
                add_tokens(assistant_end, 1)

        # truncate to max_tokens tokens MAX (helps prevent OOMs)
        ids = ids[:max_tokens]
        mask = mask[:max_tokens]
        return ids, mask

    def visualize_tokenization(self, ids, mask, with_token_id=False):
        """Small helper function useful in debugging: visualize the tokenization of render_conversation"""
        # supervised tokens render green, unsupervised red; optional gray token ids
        RED = '\033[91m'
        GREEN = '\033[92m'
        RESET = '\033[0m'
        GRAY = '\033[90m'
        tokens = []
        for i, (token_id, mask_val) in enumerate(zip(ids, mask)):
            token_str = self.decode([token_id])
            color = GREEN if mask_val == 1 else RED
            tokens.append(f"{color}{token_str}{RESET}")
            if with_token_id:
                tokens.append(f"{GRAY}({token_id}){RESET}")
        return '|'.join(tokens)

    def render_for_completion(self, conversation):
        """
        Used during Reinforcement Learning. In that setting, we want to
        render the conversation priming the Assistant for a completion.
        Unlike the Chat SFT case, we don't need to return the mask.
        """
        # We have some surgery to do: we need to pop the last message (of the Assistant)
        conversation = copy.deepcopy(conversation) # avoid mutating the original
        messages = conversation["messages"]
        assert messages[-1]["role"] == "assistant", "Last message must be from the Assistant"
        messages.pop() # remove the last message (of the Assistant) inplace

        # Now tokenize the conversation
        ids, mask = self.render_conversation(conversation)

        # Finally, to prime the Assistant for a completion, append the Assistant start token
        assistant_start = self.encode_special("<|assistant_start|>")
        ids.append(assistant_start)
        return ids
|
| 386 |
+
|
| 387 |
+
# -----------------------------------------------------------------------------
|
| 388 |
+
# nanochat-specific convenience functions
|
| 389 |
+
|
| 390 |
+
def get_tokenizer():
    """Load the project's trained tokenizer from the standard base directory."""
    from nanochat.common import get_base_dir
    tok_dir = os.path.join(get_base_dir(), "tokenizer")
    # return HuggingFaceTokenizer.from_directory(tok_dir)
    return RustBPETokenizer.from_directory(tok_dir)
|
| 396 |
+
|
| 397 |
+
def get_token_bytes(device="cpu"):
    """Load the precomputed per-token byte tensor written by tok_train.py.

    Args:
        device: torch device (string or torch.device) to map the tensor onto.

    Returns:
        The tensor stored at <base_dir>/tokenizer/token_bytes.pt.

    Raises:
        FileNotFoundError: if token_bytes.pt has not been generated yet.
    """
    import torch
    from nanochat.common import get_base_dir
    base_dir = get_base_dir()
    token_bytes_path = os.path.join(base_dir, "tokenizer", "token_bytes.pt")
    # Raise explicitly rather than `assert`: the check must survive `python -O`
    # and give an actionable message instead of a confusing torch.load error.
    if not os.path.exists(token_bytes_path):
        raise FileNotFoundError(
            f"Token bytes not found at {token_bytes_path}? It gets written by tok_train.py"
        )
    # torch.load accepts a path directly; no need to manage a file handle here.
    return torch.load(token_bytes_path, map_location=device)
|
nanochat/ui.html
ADDED
|
@@ -0,0 +1,566 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0, viewport-fit=cover">
|
| 6 |
+
<title>NanoChat</title>
|
| 7 |
+
<link rel="icon" type="image/svg+xml" href="/logo.svg">
|
| 8 |
+
<style>
|
| 9 |
+
:root {
|
| 10 |
+
color-scheme: light;
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
* {
|
| 14 |
+
box-sizing: border-box;
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
html, body{
|
| 18 |
+
height: 100%;
|
| 19 |
+
margin: 0;
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
body {
|
| 23 |
+
font-family: ui-sans-serif, -apple-system, system-ui, "Segoe UI", Helvetica, "Apple Color Emoji", Arial, sans-serif, "Segoe UI Emoji", "Segoe UI Symbol";
|
| 24 |
+
background-color: #ffffff;
|
| 25 |
+
color: #111827;
|
| 26 |
+
min-height: 100dvh;
|
| 27 |
+
margin: 0;
|
| 28 |
+
display: flex;
|
| 29 |
+
flex-direction: column;
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
.header {
|
| 33 |
+
background-color: #ffffff;
|
| 34 |
+
padding: 1.25rem 1.5rem;
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
.header-left {
|
| 38 |
+
display: flex;
|
| 39 |
+
align-items: center;
|
| 40 |
+
gap: 0.75rem;
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
.header-logo {
|
| 44 |
+
height: 32px;
|
| 45 |
+
width: auto;
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
.header h1 {
|
| 49 |
+
font-size: 1.25rem;
|
| 50 |
+
font-weight: 600;
|
| 51 |
+
margin: 0;
|
| 52 |
+
color: #111827;
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
.new-conversation-btn {
|
| 56 |
+
width: 32px;
|
| 57 |
+
height: 32px;
|
| 58 |
+
padding: 0;
|
| 59 |
+
border: 1px solid #e5e7eb;
|
| 60 |
+
border-radius: 0.5rem;
|
| 61 |
+
background-color: #ffffff;
|
| 62 |
+
color: #6b7280;
|
| 63 |
+
cursor: pointer;
|
| 64 |
+
display: flex;
|
| 65 |
+
align-items: center;
|
| 66 |
+
justify-content: center;
|
| 67 |
+
transition: all 0.2s ease;
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
.new-conversation-btn:hover {
|
| 71 |
+
background-color: #f3f4f6;
|
| 72 |
+
border-color: #d1d5db;
|
| 73 |
+
color: #374151;
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
.chat-container {
|
| 77 |
+
flex: 1;
|
| 78 |
+
overflow-y: auto;
|
| 79 |
+
background-color: #ffffff;
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
.chat-wrapper {
|
| 83 |
+
max-width: 48rem;
|
| 84 |
+
margin: 0 auto;
|
| 85 |
+
padding: 2rem 1.5rem 3rem;
|
| 86 |
+
display: flex;
|
| 87 |
+
flex-direction: column;
|
| 88 |
+
gap: 0.75rem;
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
.message {
|
| 92 |
+
display: flex;
|
| 93 |
+
justify-content: flex-start;
|
| 94 |
+
margin-bottom: 0.5rem;
|
| 95 |
+
color: #0d0d0d;
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
.message.assistant {
|
| 99 |
+
justify-content: flex-start;
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
.message.user {
|
| 103 |
+
justify-content: flex-end;
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
.message-content {
|
| 107 |
+
white-space: pre-wrap;
|
| 108 |
+
line-height: 1.6;
|
| 109 |
+
max-width: 100%;
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
.message.assistant .message-content {
|
| 113 |
+
background: transparent;
|
| 114 |
+
border: none;
|
| 115 |
+
cursor: pointer;
|
| 116 |
+
border-radius: 0.5rem;
|
| 117 |
+
padding: 0.5rem;
|
| 118 |
+
margin-left: -0.5rem;
|
| 119 |
+
transition: background-color 0.2s ease;
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
.message.assistant .message-content:hover {
|
| 123 |
+
background-color: #f9fafb;
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
.message.user .message-content {
|
| 127 |
+
background-color: #f3f4f6;
|
| 128 |
+
border-radius: 1.25rem;
|
| 129 |
+
padding: 0.8rem 1rem;
|
| 130 |
+
max-width: 65%;
|
| 131 |
+
cursor: pointer;
|
| 132 |
+
transition: background-color 0.2s ease;
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
.message.user .message-content:hover {
|
| 136 |
+
background-color: #e5e7eb;
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
.message.console .message-content {
|
| 140 |
+
font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', 'Consolas', 'Courier New', monospace;
|
| 141 |
+
font-size: 0.875rem;
|
| 142 |
+
background-color: #fafafa;
|
| 143 |
+
padding: 0.75rem 1rem;
|
| 144 |
+
color: #374151;
|
| 145 |
+
max-width: 80%;
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
.input-container {
|
| 149 |
+
background-color: #ffffff;
|
| 150 |
+
padding: 1rem;
|
| 151 |
+
padding-bottom: calc(1rem + env(safe-area-inset-bottom))
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
.input-wrapper {
|
| 155 |
+
max-width: 48rem;
|
| 156 |
+
margin: 0 auto;
|
| 157 |
+
display: flex;
|
| 158 |
+
gap: 0.75rem;
|
| 159 |
+
align-items: flex-end;
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
.chat-input {
|
| 163 |
+
flex: 1;
|
| 164 |
+
padding: 0.8rem 1rem;
|
| 165 |
+
border: 1px solid #d1d5db;
|
| 166 |
+
border-radius: 0.75rem;
|
| 167 |
+
background-color: #ffffff;
|
| 168 |
+
color: #111827;
|
| 169 |
+
font-size: 1rem;
|
| 170 |
+
line-height: 1.5;
|
| 171 |
+
resize: none;
|
| 172 |
+
outline: none;
|
| 173 |
+
min-height: 54px;
|
| 174 |
+
max-height: 200px;
|
| 175 |
+
transition: border-color 0.2s ease, box-shadow 0.2s ease;
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
.chat-input::placeholder {
|
| 179 |
+
color: #9ca3af;
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
.chat-input:focus {
|
| 183 |
+
border-color: #2563eb;
|
| 184 |
+
box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.1);
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
.send-button {
|
| 188 |
+
flex-shrink: 0;
|
| 189 |
+
padding: 0;
|
| 190 |
+
width: 54px;
|
| 191 |
+
height: 54px;
|
| 192 |
+
border: 1px solid #111827;
|
| 193 |
+
border-radius: 0.75rem;
|
| 194 |
+
background-color: #111827;
|
| 195 |
+
color: #ffffff;
|
| 196 |
+
display: flex;
|
| 197 |
+
align-items: center;
|
| 198 |
+
justify-content: center;
|
| 199 |
+
cursor: pointer;
|
| 200 |
+
transition: background-color 0.2s ease, border-color 0.2s ease, color 0.2s ease;
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
.send-button:hover:not(:disabled) {
|
| 204 |
+
background-color: #2563eb;
|
| 205 |
+
border-color: #2563eb;
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
.send-button:disabled {
|
| 209 |
+
cursor: not-allowed;
|
| 210 |
+
border-color: #d1d5db;
|
| 211 |
+
background-color: #e5e7eb;
|
| 212 |
+
color: #9ca3af;
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
.typing-indicator {
|
| 216 |
+
display: inline-block;
|
| 217 |
+
color: #6b7280;
|
| 218 |
+
letter-spacing: 0.15em;
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
.typing-indicator::after {
|
| 222 |
+
content: '···';
|
| 223 |
+
animation: typing 1.4s infinite;
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
@keyframes typing {
|
| 227 |
+
0%, 60%, 100% { opacity: 0.2; }
|
| 228 |
+
30% { opacity: 1; }
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
.error-message {
|
| 232 |
+
background-color: #fee2e2;
|
| 233 |
+
border: 1px solid #fecaca;
|
| 234 |
+
color: #b91c1c;
|
| 235 |
+
padding: 0.75rem 1rem;
|
| 236 |
+
border-radius: 0.75rem;
|
| 237 |
+
margin-top: 0.5rem;
|
| 238 |
+
}
|
| 239 |
+
</style>
|
| 240 |
+
</head>
|
| 241 |
+
<body>
|
| 242 |
+
<div class="header">
|
| 243 |
+
<div class="header-left">
|
| 244 |
+
<button class="new-conversation-btn" onclick="newConversation()" title="New Conversation (Ctrl+Shift+N)">
|
| 245 |
+
<svg width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
| 246 |
+
<path d="M12 5v14"></path>
|
| 247 |
+
<path d="M5 12h14"></path>
|
| 248 |
+
</svg>
|
| 249 |
+
</button>
|
| 250 |
+
<h1>nanochat</h1>
|
| 251 |
+
</div>
|
| 252 |
+
</div>
|
| 253 |
+
|
| 254 |
+
<div class="chat-container" id="chatContainer">
|
| 255 |
+
<div class="chat-wrapper" id="chatWrapper">
|
| 256 |
+
<!-- Messages will be added here -->
|
| 257 |
+
</div>
|
| 258 |
+
</div>
|
| 259 |
+
|
| 260 |
+
<div class="input-container">
|
| 261 |
+
<div class="input-wrapper">
|
| 262 |
+
<textarea
|
| 263 |
+
id="chatInput"
|
| 264 |
+
class="chat-input"
|
| 265 |
+
placeholder="Ask anything"
|
| 266 |
+
rows="1"
|
| 267 |
+
onkeydown="handleKeyDown(event)"
|
| 268 |
+
></textarea>
|
| 269 |
+
<button id="sendButton" class="send-button" onclick="sendMessage()" disabled>
|
| 270 |
+
<svg width="22" height="22" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
| 271 |
+
<path d="M22 2L11 13"></path>
|
| 272 |
+
<path d="M22 2l-7 20-4-9-9-4 20-7z"></path>
|
| 273 |
+
</svg>
|
| 274 |
+
</button>
|
| 275 |
+
</div>
|
| 276 |
+
</div>
|
| 277 |
+
|
| 278 |
+
<script>
|
| 279 |
+
const API_URL = '';
|
| 280 |
+
const chatContainer = document.getElementById('chatContainer');
|
| 281 |
+
const chatWrapper = document.getElementById('chatWrapper');
|
| 282 |
+
const chatInput = document.getElementById('chatInput');
|
| 283 |
+
const sendButton = document.getElementById('sendButton');
|
| 284 |
+
|
| 285 |
+
let messages = [];
|
| 286 |
+
let isGenerating = false;
|
| 287 |
+
let currentTemperature = 0.8;
|
| 288 |
+
let currentTopK = 50;
|
| 289 |
+
|
| 290 |
+
        // Auto-grow the textarea to fit its content (capped at 200px) and keep
        // the send button disabled while the input is empty or a response is
        // still being generated.
        chatInput.addEventListener('input', function() {
            this.style.height = 'auto';
            this.style.height = Math.min(this.scrollHeight, 200) + 'px';
            sendButton.disabled = !this.value.trim() || isGenerating;
        });
|
| 295 |
+
|
| 296 |
+
function handleKeyDown(event) {
|
| 297 |
+
if (event.key === 'Enter' && !event.shiftKey) {
|
| 298 |
+
event.preventDefault();
|
| 299 |
+
sendMessage();
|
| 300 |
+
}
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
        // Global shortcut: Ctrl+Shift+N starts a new conversation.
        // Ignored while a response is streaming so state isn't reset mid-flight.
        document.addEventListener('keydown', function(event) {
            // Ctrl+Shift+N for new conversation
            if (event.ctrlKey && event.shiftKey && event.key === 'N') {
                event.preventDefault();
                if (!isGenerating) {
                    newConversation();
                }
            }
        });
|
| 312 |
+
|
| 313 |
+
function newConversation() {
|
| 314 |
+
messages = [];
|
| 315 |
+
chatWrapper.innerHTML = '';
|
| 316 |
+
chatInput.value = '';
|
| 317 |
+
chatInput.style.height = 'auto';
|
| 318 |
+
sendButton.disabled = false;
|
| 319 |
+
isGenerating = false;
|
| 320 |
+
chatInput.focus();
|
| 321 |
+
}
|
| 322 |
+
|
| 323 |
+
function addMessage(role, content, messageIndex = null) {
|
| 324 |
+
const messageDiv = document.createElement('div');
|
| 325 |
+
messageDiv.className = `message ${role}`;
|
| 326 |
+
|
| 327 |
+
const contentDiv = document.createElement('div');
|
| 328 |
+
contentDiv.className = 'message-content';
|
| 329 |
+
contentDiv.textContent = content;
|
| 330 |
+
|
| 331 |
+
// Add click handler for user messages to enable editing
|
| 332 |
+
if (role === 'user' && messageIndex !== null) {
|
| 333 |
+
contentDiv.setAttribute('data-message-index', messageIndex);
|
| 334 |
+
contentDiv.setAttribute('title', 'Click to edit and restart from here');
|
| 335 |
+
contentDiv.addEventListener('click', function() {
|
| 336 |
+
if (!isGenerating) {
|
| 337 |
+
editMessage(messageIndex);
|
| 338 |
+
}
|
| 339 |
+
});
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
// Add click handler for assistant messages to enable regeneration
|
| 343 |
+
if (role === 'assistant' && messageIndex !== null) {
|
| 344 |
+
contentDiv.setAttribute('data-message-index', messageIndex);
|
| 345 |
+
contentDiv.setAttribute('title', 'Click to regenerate this response');
|
| 346 |
+
contentDiv.addEventListener('click', function() {
|
| 347 |
+
if (!isGenerating) {
|
| 348 |
+
regenerateMessage(messageIndex);
|
| 349 |
+
}
|
| 350 |
+
});
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
messageDiv.appendChild(contentDiv);
|
| 354 |
+
chatWrapper.appendChild(messageDiv);
|
| 355 |
+
|
| 356 |
+
chatContainer.scrollTop = chatContainer.scrollHeight;
|
| 357 |
+
return contentDiv;
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
function editMessage(messageIndex) {
|
| 361 |
+
// Find the message in the messages array
|
| 362 |
+
if (messageIndex < 0 || messageIndex >= messages.length) return;
|
| 363 |
+
|
| 364 |
+
const messageToEdit = messages[messageIndex];
|
| 365 |
+
if (messageToEdit.role !== 'user') return;
|
| 366 |
+
|
| 367 |
+
// Copy message content to input
|
| 368 |
+
chatInput.value = messageToEdit.content;
|
| 369 |
+
chatInput.style.height = 'auto';
|
| 370 |
+
chatInput.style.height = Math.min(chatInput.scrollHeight, 200) + 'px';
|
| 371 |
+
|
| 372 |
+
// Remove this message and all subsequent messages from the array
|
| 373 |
+
messages = messages.slice(0, messageIndex);
|
| 374 |
+
|
| 375 |
+
// Remove message elements from DOM starting from messageIndex
|
| 376 |
+
const allMessages = chatWrapper.querySelectorAll('.message');
|
| 377 |
+
for (let i = messageIndex; i < allMessages.length; i++) {
|
| 378 |
+
allMessages[i].remove();
|
| 379 |
+
}
|
| 380 |
+
|
| 381 |
+
// Enable send button and focus input
|
| 382 |
+
sendButton.disabled = false;
|
| 383 |
+
chatInput.focus();
|
| 384 |
+
}
|
| 385 |
+
|
| 386 |
+
async function generateAssistantResponse() {
|
| 387 |
+
isGenerating = true;
|
| 388 |
+
sendButton.disabled = true;
|
| 389 |
+
|
| 390 |
+
const assistantContent = addMessage('assistant', '');
|
| 391 |
+
assistantContent.innerHTML = '<span class="typing-indicator"></span>';
|
| 392 |
+
|
| 393 |
+
try {
|
| 394 |
+
const response = await fetch(`${API_URL}/chat/completions`, {
|
| 395 |
+
method: 'POST',
|
| 396 |
+
headers: {
|
| 397 |
+
'Content-Type': 'application/json',
|
| 398 |
+
},
|
| 399 |
+
body: JSON.stringify({
|
| 400 |
+
messages: messages,
|
| 401 |
+
temperature: currentTemperature,
|
| 402 |
+
top_k: currentTopK,
|
| 403 |
+
max_tokens: 512
|
| 404 |
+
}),
|
| 405 |
+
});
|
| 406 |
+
|
| 407 |
+
if (!response.ok) {
|
| 408 |
+
throw new Error(`HTTP error! status: ${response.status}`);
|
| 409 |
+
}
|
| 410 |
+
|
| 411 |
+
const reader = response.body.getReader();
|
| 412 |
+
const decoder = new TextDecoder();
|
| 413 |
+
let fullResponse = '';
|
| 414 |
+
assistantContent.textContent = '';
|
| 415 |
+
|
| 416 |
+
while (true) {
|
| 417 |
+
const { done, value } = await reader.read();
|
| 418 |
+
if (done) break;
|
| 419 |
+
|
| 420 |
+
const chunk = decoder.decode(value);
|
| 421 |
+
const lines = chunk.split('\n');
|
| 422 |
+
|
| 423 |
+
for (const line of lines) {
|
| 424 |
+
if (line.startsWith('data: ')) {
|
| 425 |
+
try {
|
| 426 |
+
const data = JSON.parse(line.slice(6));
|
| 427 |
+
if (data.token) {
|
| 428 |
+
fullResponse += data.token;
|
| 429 |
+
assistantContent.textContent = fullResponse;
|
| 430 |
+
chatContainer.scrollTop = chatContainer.scrollHeight;
|
| 431 |
+
}
|
| 432 |
+
} catch (e) {
|
| 433 |
+
}
|
| 434 |
+
}
|
| 435 |
+
}
|
| 436 |
+
}
|
| 437 |
+
|
| 438 |
+
const assistantMessageIndex = messages.length;
|
| 439 |
+
messages.push({ role: 'assistant', content: fullResponse });
|
| 440 |
+
|
| 441 |
+
// Add click handler to regenerate this assistant message
|
| 442 |
+
assistantContent.setAttribute('data-message-index', assistantMessageIndex);
|
| 443 |
+
assistantContent.setAttribute('title', 'Click to regenerate this response');
|
| 444 |
+
assistantContent.addEventListener('click', function() {
|
| 445 |
+
if (!isGenerating) {
|
| 446 |
+
regenerateMessage(assistantMessageIndex);
|
| 447 |
+
}
|
| 448 |
+
});
|
| 449 |
+
|
| 450 |
+
} catch (error) {
|
| 451 |
+
console.error('Error:', error);
|
| 452 |
+
assistantContent.innerHTML = `<div class="error-message">Error: ${error.message}</div>`;
|
| 453 |
+
} finally {
|
| 454 |
+
isGenerating = false;
|
| 455 |
+
sendButton.disabled = !chatInput.value.trim();
|
| 456 |
+
}
|
| 457 |
+
}
|
| 458 |
+
|
| 459 |
+
async function regenerateMessage(messageIndex) {
|
| 460 |
+
// Find the message in the messages array
|
| 461 |
+
if (messageIndex < 0 || messageIndex >= messages.length) return;
|
| 462 |
+
|
| 463 |
+
const messageToRegenerate = messages[messageIndex];
|
| 464 |
+
if (messageToRegenerate.role !== 'assistant') return;
|
| 465 |
+
|
| 466 |
+
// Remove this message and all subsequent messages from the array
|
| 467 |
+
messages = messages.slice(0, messageIndex);
|
| 468 |
+
|
| 469 |
+
// Remove message elements from DOM starting from messageIndex
|
| 470 |
+
const allMessages = chatWrapper.querySelectorAll('.message');
|
| 471 |
+
for (let i = messageIndex; i < allMessages.length; i++) {
|
| 472 |
+
allMessages[i].remove();
|
| 473 |
+
}
|
| 474 |
+
|
| 475 |
+
// Regenerate the assistant response
|
| 476 |
+
await generateAssistantResponse();
|
| 477 |
+
}
|
| 478 |
+
|
| 479 |
+
function handleSlashCommand(command) {
|
| 480 |
+
const parts = command.trim().split(/\s+/);
|
| 481 |
+
const cmd = parts[0].toLowerCase();
|
| 482 |
+
const arg = parts[1];
|
| 483 |
+
|
| 484 |
+
if (cmd === '/temperature') {
|
| 485 |
+
if (arg === undefined) {
|
| 486 |
+
addMessage('console', `Current temperature: ${currentTemperature}`);
|
| 487 |
+
} else {
|
| 488 |
+
const temp = parseFloat(arg);
|
| 489 |
+
if (isNaN(temp) || temp < 0 || temp > 2) {
|
| 490 |
+
addMessage('console', 'Invalid temperature. Must be between 0.0 and 2.0');
|
| 491 |
+
} else {
|
| 492 |
+
currentTemperature = temp;
|
| 493 |
+
addMessage('console', `Temperature set to ${currentTemperature}`);
|
| 494 |
+
}
|
| 495 |
+
}
|
| 496 |
+
return true;
|
| 497 |
+
} else if (cmd === '/topk') {
|
| 498 |
+
if (arg === undefined) {
|
| 499 |
+
addMessage('console', `Current top-k: ${currentTopK}`);
|
| 500 |
+
} else {
|
| 501 |
+
const topk = parseInt(arg);
|
| 502 |
+
if (isNaN(topk) || topk < 1 || topk > 200) {
|
| 503 |
+
addMessage('console', 'Invalid top-k. Must be between 1 and 200');
|
| 504 |
+
} else {
|
| 505 |
+
currentTopK = topk;
|
| 506 |
+
addMessage('console', `Top-k set to ${currentTopK}`);
|
| 507 |
+
}
|
| 508 |
+
}
|
| 509 |
+
return true;
|
| 510 |
+
} else if (cmd === '/clear') {
|
| 511 |
+
newConversation();
|
| 512 |
+
return true;
|
| 513 |
+
} else if (cmd === '/help') {
|
| 514 |
+
addMessage('console',
|
| 515 |
+
'Available commands:\n' +
|
| 516 |
+
'/temperature - Show current temperature\n' +
|
| 517 |
+
'/temperature <value> - Set temperature (0.0-2.0)\n' +
|
| 518 |
+
'/topk - Show current top-k\n' +
|
| 519 |
+
'/topk <value> - Set top-k (1-200)\n' +
|
| 520 |
+
'/clear - Clear conversation\n' +
|
| 521 |
+
'/help - Show this help message'
|
| 522 |
+
);
|
| 523 |
+
return true;
|
| 524 |
+
}
|
| 525 |
+
return false;
|
| 526 |
+
}
|
| 527 |
+
|
| 528 |
+
async function sendMessage() {
|
| 529 |
+
const message = chatInput.value.trim();
|
| 530 |
+
if (!message || isGenerating) return;
|
| 531 |
+
|
| 532 |
+
// Handle slash commands
|
| 533 |
+
if (message.startsWith('/')) {
|
| 534 |
+
chatInput.value = '';
|
| 535 |
+
chatInput.style.height = 'auto';
|
| 536 |
+
handleSlashCommand(message);
|
| 537 |
+
return;
|
| 538 |
+
}
|
| 539 |
+
|
| 540 |
+
chatInput.value = '';
|
| 541 |
+
chatInput.style.height = 'auto';
|
| 542 |
+
|
| 543 |
+
const userMessageIndex = messages.length;
|
| 544 |
+
messages.push({ role: 'user', content: message });
|
| 545 |
+
addMessage('user', message, userMessageIndex);
|
| 546 |
+
|
| 547 |
+
await generateAssistantResponse();
|
| 548 |
+
}
|
| 549 |
+
|
| 550 |
+
        // Send is enabled from the start; the input listener and the
        // generation code re-gate it as needed.
        sendButton.disabled = false;

        // Autofocus the chat input on page load
        chatInput.focus();

        // Startup health probe: log engine status, or show a banner telling
        // the user to start the backend first.
        fetch(`${API_URL}/health`)
            .then(response => response.json())
            .then(data => {
                console.log('Engine status:', data);
            })
            .catch(error => {
                console.error('Engine not available:', error);
                chatWrapper.innerHTML = '<div class="error-message">Engine not running. Please start engine.py first.</div>';
            });
|
| 564 |
+
</script>
|
| 565 |
+
</body>
|
| 566 |
+
</html>
|