Instructions to use AlexWortega/moe100m-physics-tinybpe with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use AlexWortega/moe100m-physics-tinybpe with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("AlexWortega/moe100m-physics-tinybpe", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| """Single-GPU (V100, fp16) from-scratch trainer for the tiny-vocab physics MoE. | |
| Adapted from scaffold/train/train_200m.py — drops DDP / EMA / WSD-resume / | |
| HF-push machinery, keeps the load-bearing bits: | |
| - Muon (matrix) + AdamW (rest) via optim.make_param_groups | |
| - cosine LR schedule with warmup | |
| - fp16 autocast forward, fp32 router math (handled inside model.py), dynamic | |
| loss-scale, NaN-guard (skip step + halve scale; abort after nan_cap) | |
| - router aux/z loss added; router bias controller stepped each good step | |
| - chunked / Liger fused CE (from model.py) | |
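| Logs: train.log (every 20 steps and the first 5, plus skip/abort events), eval.log (periodic val loss), both written to the parent directory of --out. Checkpoints to --out (default ckpts/). | |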
| """ | |
| from __future__ import annotations | |
| import argparse, json, math, os, sys, time | |
| import torch | |
| _HERE = os.path.dirname(os.path.abspath(__file__)) | |
| sys.path.insert(0, os.path.join(_HERE, "..", "scaffold")) | |
| from model import MoEModel # noqa: E402 | |
| from optim import Muon, make_param_groups # noqa: E402 | |
| from config_100m import make_config # noqa: E402 | |
| import data_physics as dp # noqa: E402 | |
def cosine_lr(step, peak, warmup, total, min_lr):
    """Return the learning rate for ``step``: linear warmup, then cosine decay.

    For ``step < warmup`` the rate ramps linearly as
    ``peak * (step + 1) / warmup``, so it reaches ``peak`` on the last warmup
    step. After that it follows a half cosine from ``peak`` down to ``min_lr``
    across the remaining ``total - warmup`` steps and stays at ``min_lr`` once
    that span is exhausted.
    """
    still_warming = step < warmup
    if still_warming:
        return peak * (step + 1) / warmup

    # max(1, ...) keeps the division safe when total <= warmup.
    decay_span = max(1, total - warmup)
    # Clamp so steps beyond ``total`` stay at the floor instead of rising again.
    progress = min(1.0, (step - warmup) / decay_span)
    cosine_term = 1 + math.cos(math.pi * progress)
    return min_lr + 0.5 * (peak - min_lr) * cosine_term
def eval_loss(model, tok_path, seq_len, batch_size, n_batches, device):
    """Return the mean validation loss over up to ``n_batches`` finite batches.

    Iterates ``dp.batch_iterator`` on the ``val`` split (``infinite=False``,
    ``shuffle_buffer=0``, ``seed=123``) and runs the model under fp16 autocast
    with autograd disabled. Batches whose loss is ``None`` or non-finite are
    ignored and do not count towards ``n_batches``.

    The model is switched to eval mode for the duration and is always put back
    into train mode before returning, even if evaluation raises.

    Args:
        model: Model called as ``model(ids, labels=lbl)``; the second element
            of its return value is used as the loss.
        tok_path: Tokenizer path forwarded to ``dp.batch_iterator``.
        seq_len: Sequence length forwarded to ``dp.batch_iterator``.
        batch_size: Batch size forwarded to ``dp.batch_iterator``.
        n_batches: Stop after this many finite-loss batches.
        device: Device forwarded to ``dp.batch_iterator``.

    Returns:
        The average loss as a float, or ``nan`` when no batch produced a
        finite loss, so a failed evaluation can never look like a perfect
        (0.0) score to the caller.
    """
    model.eval()
    total, counted = 0.0, 0
    try:
        batches = dp.batch_iterator(tok_path, seq_len, batch_size, split="val",
                                    types=dp.TRAIN_TYPES, device=device,
                                    infinite=False, shuffle_buffer=0, seed=123)
        # Validation needs no gradients; skipping graph construction saves
        # memory that the training step would otherwise have to share.
        with torch.no_grad():
            for ids, lbl in batches:
                with torch.cuda.amp.autocast(dtype=torch.float16):
                    _, loss, _ = model(ids, labels=lbl)
                if loss is not None and torch.isfinite(loss):
                    total += float(loss.item())
                    counted += 1
                if counted >= n_batches:
                    break
    finally:
        model.train()
    if counted == 0:
        return float("nan")
    return total / counted
def main():
    """Parse CLI flags and run the single-GPU fp16 training loop.

    Each optimizer step accumulates ``--grad-accum`` micro-batches under fp16
    autocast with a manually managed loss scale:

    * a non-finite forward loss skips the step outright;
    * a non-finite gradient skips the step and halves the loss scale;
    * more than ``--nan-cap`` consecutive skips of either kind abort the run;
    * every ``grow_every`` good steps without an intervening gradient overflow
      double the scale, capped at ``2 ** 16``.

    Validation runs every ``--eval-every`` steps and an improved model is
    written to ``<out>/best.pt``. ``<out>/last.pt`` is refreshed every
    ``--ckpt-every`` steps and once more on exit. ``train.log``, ``eval.log``
    and ``train_summary.json`` are written to the parent directory of
    ``--out``.

    ``--resume`` restores only the model weights and the step counter; the
    optimizer state, loss scale and data-stream position start fresh.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--tokenizer", default="tokenizer.json")
    ap.add_argument("--vocab", type=int, default=512)
    ap.add_argument("--seq-len", type=int, default=1024)
    ap.add_argument("--batch-size", type=int, default=8)
    ap.add_argument("--grad-accum", type=int, default=1)
    ap.add_argument("--peak-lr", type=float, default=6e-4)
    ap.add_argument("--min-lr", type=float, default=3e-5)
    ap.add_argument("--warmup", type=int, default=500)
    ap.add_argument("--token-budget", type=float, default=2.5e9)
    ap.add_argument("--max-steps", type=int, default=0)  # 0 = derive from budget
    ap.add_argument("--eval-every", type=int, default=1000)
    ap.add_argument("--eval-batches", type=int, default=30)
    ap.add_argument("--ckpt-every", type=int, default=2000)
    ap.add_argument("--shuffle-buffer", type=int, default=200)
    ap.add_argument("--nan-cap", type=int, default=50)
    ap.add_argument("--out", default="ckpts")
    ap.add_argument("--max-wall-hours", type=float, default=23.0)
    ap.add_argument("--smoke", action="store_true")
    ap.add_argument("--resume", default="")
    ap.add_argument("--data-seed", type=int, default=0)
    args = ap.parse_args()
    # These three are divisors below; reject bad values up front instead of
    # failing later with a ZeroDivisionError.
    if min(args.batch_size, args.seq_len, args.grad_accum) < 1:
        ap.error("--batch-size, --seq-len and --grad-accum must be >= 1")

    device = "cuda"
    torch.manual_seed(0)
    os.makedirs(args.out, exist_ok=True)
    tokens_per_step = args.batch_size * args.seq_len * args.grad_accum
    total_steps = args.max_steps or int(args.token_budget / tokens_per_step)
    cfg = make_config(args.vocab, max_seq_len=args.seq_len)
    model = MoEModel(cfg).to(device)

    start_step = 0
    if args.resume and os.path.exists(args.resume):
        # NOTE(security): weights_only=False unpickles arbitrary objects; only
        # resume from checkpoints you trust.
        ck = torch.load(args.resume, map_location="cpu", weights_only=False)
        model.load_state_dict(ck["model"])
        start_step = int(ck.get("step", 0))
        print(f"[init] RESUMED weights from {args.resume} @ step {start_step}", flush=True)
    elif args.resume:
        # Still best-effort (train from scratch), but no longer silent.
        print(f"[init] WARNING: --resume {args.resume} not found; "
              f"training from scratch", flush=True)

    act = model.num_parameters(only_active=True) / 1e6
    tot = model.num_parameters() / 1e6
    print(f"[init] ACTIVE={act:.2f}M TOTAL={tot:.2f}M vocab={cfg.vocab_size} "
          f"total_steps={total_steps} tokens/step={tokens_per_step} start_step={start_step}", flush=True)

    matrix, non_matrix = make_param_groups(model)
    opt = Muon(matrix, non_matrix, lr=args.peak_lr, momentum=0.95,
               ns_mode="fp32", weight_decay=0.01, betas=(0.9, 0.95),
               foreach=True)

    loss_scale = 2.0 ** 14
    # cap at 2^16: physics grads occasionally overflow above that, causing a
    # benign-but-frequent NaN-skip oscillation. Lower ceiling = far fewer skips.
    loss_scale_min, loss_scale_max = 2.0 ** 0, 2.0 ** 16
    grow_every = 200
    n_good = 0
    nan_count = 0
    consec_nan = 0

    data = dp.batch_iterator(args.tokenizer, args.seq_len, args.batch_size,
                             split="train", types=dp.TRAIN_TYPES, device=device,
                             infinite=True, shuffle_buffer=args.shuffle_buffer,
                             seed=args.data_seed)

    def logln(f, s):
        """Append ``s`` to log file ``f`` (flushed) and echo it to stdout."""
        f.write(s + "\n")
        f.flush()
        print(s, flush=True)

    # Logs live next to (not inside) the checkpoint directory. The context
    # manager guarantees both handles are closed on every exit path.
    with open(os.path.join(args.out, "..", "train.log"), "a") as train_log, \
            open(os.path.join(args.out, "..", "eval.log"), "a") as eval_log:
        t_start = time.time()
        tokens_seen = start_step * tokens_per_step
        walls = []
        params = list(model.parameters())
        best_eval = float("inf")
        # Pre-bind so the final save and summary still work when the loop body
        # never finishes a good step (empty range, or abort on the first step).
        step = start_step
        avg_loss = float("nan")

        for step in range(start_step, total_steps):
            lr = cosine_lr(step, args.peak_lr, args.warmup, total_steps, args.min_lr)
            opt.set_lr(lr)
            opt.zero_grad()
            accum_loss = 0.0
            aux_last = None
            forward_bad = False
            for micro in range(args.grad_accum):
                ids, lbl = next(data)
                if micro == 0:
                    # Start the clock once per optimizer step so the wall time
                    # spans every micro-batch, matching tokens_per_step.
                    t0 = time.perf_counter()
                with torch.cuda.amp.autocast(dtype=torch.float16):
                    _, lm_loss, aux = model(ids, labels=lbl)
                # Catch activation/loss explosion AT THE SOURCE: if the forward loss
                # is already non-finite (a pathological high-velocity batch overflowed
                # fp16 inside CE/router), skip this batch entirely instead of letting
                # the NaN propagate into the weights via backward.
                if not torch.isfinite(lm_loss):
                    forward_bad = True
                    break
                loss = lm_loss
                if aux is not None:
                    loss = loss + cfg.router_z_coef * aux["z_loss"] + \
                        cfg.router_aux_coef * aux["aux_loss"]
                (loss * loss_scale / args.grad_accum).backward()
                accum_loss += float(lm_loss.item())
                aux_last = aux

            if forward_bad:
                nan_count += 1
                consec_nan += 1
                opt.zero_grad()
                logln(train_log, f"step {step} non-finite FORWARD loss -> skip batch "
                                 f"(consec={consec_nan} total={nan_count})")
                if consec_nan > args.nan_cap:
                    logln(train_log, f"step {step} >{args.nan_cap} CONSECUTIVE bad -> ABORT")
                    break
                continue

            # unscale + NaN guard
            inv = 1.0 / loss_scale
            nan_seen = False
            for p in params:
                if p.grad is None:
                    continue
                p.grad.data.mul_(inv)
                if not torch.isfinite(p.grad.data).all():
                    # Remaining grads stay scaled, which is fine: the step is
                    # skipped and zero_grad() runs at the top of the next one.
                    nan_seen = True
                    break
            if nan_seen:
                nan_count += 1
                consec_nan += 1
                loss_scale = max(loss_scale_min, loss_scale * 0.5)
                n_good = 0
                logln(train_log, f"step {step} NaN/Inf grad -> skip; scale={loss_scale:.0f} "
                                 f"(consec={consec_nan} total={nan_count})")
                # abort only on SUSTAINED divergence (consecutive), not cumulative —
                # occasional fp16 overflow at high loss-scale is benign and recovers.
                if consec_nan > args.nan_cap:
                    logln(train_log, f"step {step} >{args.nan_cap} CONSECUTIVE NaN -> ABORT")
                    break
                continue

            consec_nan = 0
            torch.nn.utils.clip_grad_norm_(params, 1.0)
            opt.step()
            n_good += 1
            if n_good >= grow_every:
                loss_scale = min(loss_scale_max, loss_scale * 2.0)
                n_good = 0
            if aux_last is not None:
                model.step_router_biases(aux_last["counts_per_layer"])
            torch.cuda.synchronize()
            walls.append(time.perf_counter() - t0)
            tokens_seen += tokens_per_step
            avg_loss = accum_loss / args.grad_accum

            if step % 20 == 0 or step < 5:
                cv = float(aux_last["router_cv"].item()) if aux_last is not None else 0.0
                tps = tokens_per_step / (sum(walls[-20:]) / len(walls[-20:]))
                logln(train_log, f"step {step} loss={avg_loss:.4f} lr={lr:.2e} "
                                 f"scale={loss_scale:.0f} cv={cv:.3f} tok={tokens_seen} "
                                 f"tok/s={tps:.0f} elapsed={(time.time()-t_start)/3600:.2f}h")

            if step > 0 and step % args.eval_every == 0:
                ev = eval_loss(model, args.tokenizer, args.seq_len, args.batch_size,
                               args.eval_batches, device)
                logln(eval_log, f"step {step} eval_loss={ev:.4f} train_loss={avg_loss:.4f} tok={tokens_seen}")
                if ev < best_eval:
                    best_eval = ev
                    torch.save({"model": model.state_dict(), "cfg": cfg.as_dict(),
                                "step": step, "eval_loss": ev},
                               os.path.join(args.out, "best.pt"))

            if step > 0 and step % args.ckpt_every == 0:
                torch.save({"model": model.state_dict(), "cfg": cfg.as_dict(),
                            "step": step}, os.path.join(args.out, "last.pt"))

            if (time.time() - t_start) / 3600.0 > args.max_wall_hours:
                logln(train_log, f"step {step} wall-cap {args.max_wall_hours}h reached -> stop")
                break

            if args.smoke and step >= 1:
                logln(train_log, f"[smoke] completed {step+1} real steps, loss={avg_loss:.4f}")
                break

        # final save
        torch.save({"model": model.state_dict(), "cfg": cfg.as_dict(),
                    "step": step, "final": True}, os.path.join(args.out, "last.pt"))
        if best_eval == float("inf"):
            # No evaluation ever improved on the initial value: fall back to
            # the final weights so best.pt always exists.
            torch.save({"model": model.state_dict(), "cfg": cfg.as_dict(),
                        "step": step}, os.path.join(args.out, "best.pt"))

        summary = {"final_train_loss": avg_loss, "best_eval_loss": best_eval,
                   "steps": step + 1, "tokens_seen": tokens_seen,
                   "active_M": act, "total_M": tot,
                   "wall_hours": (time.time() - t_start) / 3600.0,
                   "planned_tokens": args.token_budget, "total_steps": total_steps}
        with open(os.path.join(args.out, "..", "train_summary.json"), "w") as f:
            json.dump(summary, f, indent=2)
        logln(train_log, f"DONE {json.dumps(summary)}")
# Script entry point: run the trainer only when executed directly, not when
# this module is imported.
if __name__ == "__main__":
    main()