Instructions to use AlexWortega/moe100m-physics-tinybpe with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use AlexWortega/moe100m-physics-tinybpe with Transformers:
```python
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("AlexWortega/moe100m-physics-tinybpe", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
File size: 10,913 Bytes
d0edc76 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 | """Single-GPU (V100, fp16) from-scratch trainer for the tiny-vocab physics MoE.
Adapted from scaffold/train/train_200m.py — drops DDP / EMA / WSD-resume /
HF-push machinery, keeps the load-bearing bits:
- Muon (matrix) + AdamW (rest) via optim.make_param_groups
- cosine LR schedule with warmup
- fp16 autocast forward, fp32 router math (handled inside model.py), dynamic
loss-scale, NaN-guard (skip step + halve scale; abort after nan_cap)
- router aux/z loss added; router bias controller stepped each good step
- chunked / Liger fused CE (from model.py)
Logs: train.log (per-step), eval.log (periodic val loss). Checkpoints to ckpts/.
"""
from __future__ import annotations
import argparse, json, math, os, sys, time
import torch
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(_HERE, "..", "scaffold"))
from model import MoEModel # noqa: E402
from optim import Muon, make_param_groups # noqa: E402
from config_100m import make_config # noqa: E402
import data_physics as dp # noqa: E402
def cosine_lr(step, peak, warmup, total, min_lr):
    """Return the learning rate for ``step``: linear warmup, then cosine decay.

    While ``step < warmup`` the rate ramps linearly and reaches ``peak`` on the
    last warmup step. Afterwards it follows a half-cosine from ``peak`` down to
    ``min_lr`` over the remaining ``total - warmup`` steps, and stays at
    ``min_lr`` for any step beyond ``total``.
    """
    in_warmup = step < warmup
    if in_warmup:
        return peak * (step + 1) / warmup

    # Fraction of the decay phase already completed, clamped to 1.0 so that
    # steps past ``total`` hold at ``min_lr``. ``max(1, ...)`` guards against a
    # zero/negative decay length.
    decay_steps = max(1, total - warmup)
    progress = min(1.0, (step - warmup) / decay_steps)

    half_span = 0.5 * (peak - min_lr)
    return min_lr + half_span * (1 + math.cos(math.pi * progress))
@torch.no_grad()
def eval_loss(model, tok_path, seq_len, batch_size, n_batches, device):
    """Return the mean validation loss over up to ``n_batches`` finite batches.

    Batches come from ``dp.batch_iterator`` on the ``"val"`` split with a fixed
    seed (123) and ``shuffle_buffer=0``, so repeated calls request the same
    stream. Batches whose loss is ``None`` or non-finite are not counted.

    Args:
        model: Module called as ``model(ids, labels=lbl)``; the second element
            of its return value is used as the loss.
        tok_path: Tokenizer path forwarded to ``dp.batch_iterator``.
        seq_len: Sequence length forwarded to ``dp.batch_iterator``.
        batch_size: Batch size forwarded to ``dp.batch_iterator``.
        n_batches: Stop once this many finite-loss batches have been averaged.
        device: Device forwarded to ``dp.batch_iterator``.

    Returns:
        The average loss as a float, or ``float("inf")`` when no batch produced
        a finite loss. (Returning 0.0 in that case would look like a perfect
        score to a caller that keeps the checkpoint with the lowest eval loss.)
    """
    # Remember the caller's mode so it is restored exactly, even on error.
    was_training = model.training
    model.eval()
    try:
        it = dp.batch_iterator(tok_path, seq_len, batch_size, split="val",
                               types=dp.TRAIN_TYPES, device=device, infinite=False,
                               shuffle_buffer=0, seed=123)
        tot, n = 0.0, 0
        for ids, lbl in it:
            # torch.autocast("cuda", ...) is the non-deprecated spelling of
            # torch.cuda.amp.autocast(...).
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                _, loss, _ = model(ids, labels=lbl)
            if loss is not None and torch.isfinite(loss):
                tot += float(loss.item())
                n += 1
            if n >= n_batches:
                break
    finally:
        model.train(was_training)
    if n == 0:
        return float("inf")
    return tot / n
def main():
    """Parse CLI flags and run the single-GPU fp16 training loop.

    Per optimizer step the loop:
      1. sets the LR from ``cosine_lr`` and zeroes gradients;
      2. runs ``--grad-accum`` micro-batches under fp16 autocast, adding the
         router z/aux losses and back-propagating the loss multiplied by the
         current loss scale;
      3. skips the step if the forward loss or any unscaled gradient is
         non-finite (halving the loss scale in the gradient case) and aborts
         after more than ``--nan-cap`` consecutive skips;
      4. otherwise clips gradients to norm 1.0, steps the optimizer, doubles the
         loss scale every ``grow_every`` good steps (capped), and steps the
         router biases.

    Side effects: appends to ``train.log`` / ``eval.log`` and writes
    ``train_summary.json`` in the parent directory of ``--out``; writes
    ``best.pt`` / ``last.pt`` inside ``--out``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--tokenizer", default="tokenizer.json")
    ap.add_argument("--vocab", type=int, default=512)
    ap.add_argument("--seq-len", type=int, default=1024)
    ap.add_argument("--batch-size", type=int, default=8)
    ap.add_argument("--grad-accum", type=int, default=1)
    ap.add_argument("--peak-lr", type=float, default=6e-4)
    ap.add_argument("--min-lr", type=float, default=3e-5)
    ap.add_argument("--warmup", type=int, default=500)
    ap.add_argument("--token-budget", type=float, default=2.5e9)
    ap.add_argument("--max-steps", type=int, default=0)  # 0 = derive from budget
    ap.add_argument("--eval-every", type=int, default=1000)
    ap.add_argument("--eval-batches", type=int, default=30)
    ap.add_argument("--ckpt-every", type=int, default=2000)
    ap.add_argument("--shuffle-buffer", type=int, default=200)
    ap.add_argument("--nan-cap", type=int, default=50)
    ap.add_argument("--out", default="ckpts")
    ap.add_argument("--max-wall-hours", type=float, default=23.0)
    ap.add_argument("--smoke", action="store_true")
    ap.add_argument("--resume", default="")
    ap.add_argument("--data-seed", type=int, default=0)
    args = ap.parse_args()

    device = "cuda"
    torch.manual_seed(0)
    os.makedirs(args.out, exist_ok=True)

    tokens_per_step = args.batch_size * args.seq_len * args.grad_accum
    total_steps = args.max_steps or int(args.token_budget / tokens_per_step)

    cfg = make_config(args.vocab, max_seq_len=args.seq_len)
    model = MoEModel(cfg).to(device)

    start_step = 0
    if args.resume and os.path.exists(args.resume):
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # point --resume at checkpoints produced by a trusted run.
        ck = torch.load(args.resume, map_location="cpu", weights_only=False)
        model.load_state_dict(ck["model"])
        # Only weights and the step counter are restored; optimizer state and
        # loss scale start fresh.
        start_step = int(ck.get("step", 0))
        print(f"[init] RESUMED weights from {args.resume} @ step {start_step}", flush=True)
    elif args.resume:
        # Keep the best-effort "resume if present" behaviour, but make a
        # mistyped path visible instead of silently restarting from step 0.
        print(f"[init] WARNING --resume path {args.resume} not found; "
              f"starting from scratch", flush=True)

    act = model.num_parameters(only_active=True) / 1e6
    tot = model.num_parameters() / 1e6
    print(f"[init] ACTIVE={act:.2f}M TOTAL={tot:.2f}M vocab={cfg.vocab_size} "
          f"total_steps={total_steps} tokens/step={tokens_per_step} start_step={start_step}", flush=True)

    matrix, non_matrix = make_param_groups(model)
    opt = Muon(matrix, non_matrix, lr=args.peak_lr, momentum=0.95,
               ns_mode="fp32", weight_decay=0.01, betas=(0.9, 0.95),
               foreach=True)

    loss_scale = 2.0 ** 14
    # cap at 2^16: physics grads occasionally overflow above that, causing a
    # benign-but-frequent NaN-skip oscillation. Lower ceiling = far fewer skips.
    loss_scale_min, loss_scale_max = 2.0 ** 0, 2.0 ** 16
    grow_every = 200
    n_good = 0
    nan_count = 0
    consec_nan = 0

    data = dp.batch_iterator(args.tokenizer, args.seq_len, args.batch_size,
                             split="train", types=dp.TRAIN_TYPES, device=device,
                             infinite=True, shuffle_buffer=args.shuffle_buffer,
                             seed=args.data_seed)

    def logln(f, s):
        """Write one line to log file ``f`` (flushed) and echo it to stdout."""
        f.write(s + "\n")
        f.flush()
        print(s, flush=True)

    # Logs and the summary live next to (not inside) the checkpoint directory.
    train_log_path = os.path.join(args.out, "..", "train.log")
    eval_log_path = os.path.join(args.out, "..", "eval.log")

    # `with` guarantees both log handles are closed on every exit path.
    with open(train_log_path, "a") as train_log, open(eval_log_path, "a") as eval_log:
        t_start = time.time()
        tokens_seen = start_step * tokens_per_step
        walls = []  # wall times of the most recent good steps (tok/s window)
        params = list(model.parameters())
        best_eval = float("inf")

        # Pre-bind the names read after the loop so the final save / summary
        # still work when the loop body never completes a good step (empty
        # step range, or an abort during the very first steps).
        step = start_step
        avg_loss = None

        for step in range(start_step, total_steps):
            lr = cosine_lr(step, args.peak_lr, args.warmup, total_steps, args.min_lr)
            opt.set_lr(lr)
            opt.zero_grad()
            accum_loss = 0.0
            aux_last = None
            forward_bad = False
            # Started once per optimizer step (after the first batch fetch) so
            # the measured wall time spans ALL micro-batches; restarting it on
            # every micro-batch would overstate tok/s when grad_accum > 1.
            t0 = None
            for _ in range(args.grad_accum):
                ids, lbl = next(data)
                if t0 is None:
                    t0 = time.perf_counter()
                with torch.autocast(device_type=device, dtype=torch.float16):
                    _, lm_loss, aux = model(ids, labels=lbl)
                # Catch activation/loss explosion AT THE SOURCE: if the forward loss
                # is already non-finite (a pathological high-velocity batch overflowed
                # fp16 inside CE/router), skip this batch entirely instead of letting
                # the NaN propagate into the weights via backward.
                if not torch.isfinite(lm_loss):
                    forward_bad = True
                    break
                loss = lm_loss
                if aux is not None:
                    loss = loss + cfg.router_z_coef * aux["z_loss"] + \
                        cfg.router_aux_coef * aux["aux_loss"]
                (loss * loss_scale / args.grad_accum).backward()
                accum_loss += float(lm_loss.item())
                aux_last = aux

            if forward_bad:
                nan_count += 1
                consec_nan += 1
                # Drop any gradients accumulated by earlier micro-batches.
                opt.zero_grad()
                logln(train_log, f"step {step} non-finite FORWARD loss -> skip batch "
                                 f"(consec={consec_nan} total={nan_count})")
                if consec_nan > args.nan_cap:
                    logln(train_log, f"step {step} >{args.nan_cap} CONSECUTIVE bad -> ABORT")
                    break
                continue

            # unscale + NaN guard
            inv = 1.0 / loss_scale
            nan_seen = False
            for p in params:
                if p.grad is None:
                    continue
                p.grad.data.mul_(inv)
                if not torch.isfinite(p.grad.data).all():
                    nan_seen = True
                    break
            if nan_seen:
                nan_count += 1
                consec_nan += 1
                loss_scale = max(loss_scale_min, loss_scale * 0.5)
                n_good = 0
                logln(train_log, f"step {step} NaN/Inf grad -> skip; scale={loss_scale:.0f} "
                                 f"(consec={consec_nan} total={nan_count})")
                # abort only on SUSTAINED divergence (consecutive), not cumulative —
                # occasional fp16 overflow at high loss-scale is benign and recovers.
                if consec_nan > args.nan_cap:
                    logln(train_log, f"step {step} >{args.nan_cap} CONSECUTIVE NaN -> ABORT")
                    break
                continue

            consec_nan = 0
            torch.nn.utils.clip_grad_norm_(params, 1.0)
            opt.step()
            n_good += 1
            if n_good >= grow_every:
                loss_scale = min(loss_scale_max, loss_scale * 2.0)
                n_good = 0
            if aux_last is not None:
                model.step_router_biases(aux_last["counts_per_layer"])

            # Wait for queued GPU work so the wall-time sample is meaningful.
            torch.cuda.synchronize()
            walls.append(time.perf_counter() - t0)
            del walls[:-20]  # only the last 20 samples are ever read
            tokens_seen += tokens_per_step
            avg_loss = accum_loss / args.grad_accum

            if step % 20 == 0 or step < 5:
                cv = float(aux_last["router_cv"].item()) if aux_last is not None else 0.0
                tps = tokens_per_step / (sum(walls) / len(walls))
                logln(train_log, f"step {step} loss={avg_loss:.4f} lr={lr:.2e} "
                                 f"scale={loss_scale:.0f} cv={cv:.3f} tok={tokens_seen} "
                                 f"tok/s={tps:.0f} elapsed={(time.time()-t_start)/3600:.2f}h")

            if step > 0 and step % args.eval_every == 0:
                ev = eval_loss(model, args.tokenizer, args.seq_len, args.batch_size,
                               args.eval_batches, device)
                logln(eval_log, f"step {step} eval_loss={ev:.4f} train_loss={avg_loss:.4f} tok={tokens_seen}")
                if ev < best_eval:
                    best_eval = ev
                    torch.save({"model": model.state_dict(), "cfg": cfg.as_dict(),
                                "step": step, "eval_loss": ev},
                               os.path.join(args.out, "best.pt"))

            if step > 0 and step % args.ckpt_every == 0:
                torch.save({"model": model.state_dict(), "cfg": cfg.as_dict(),
                            "step": step}, os.path.join(args.out, "last.pt"))

            if (time.time() - t_start) / 3600.0 > args.max_wall_hours:
                logln(train_log, f"step {step} wall-cap {args.max_wall_hours}h reached -> stop")
                break

            if args.smoke and step >= 1:
                logln(train_log, f"[smoke] completed {step+1} real steps, loss={avg_loss:.4f}")
                break

        # final save
        torch.save({"model": model.state_dict(), "cfg": cfg.as_dict(),
                    "step": step, "final": True}, os.path.join(args.out, "last.pt"))
        if best_eval == float("inf"):
            # No eval improved on the initial value: still leave a best.pt behind.
            torch.save({"model": model.state_dict(), "cfg": cfg.as_dict(),
                        "step": step}, os.path.join(args.out, "best.pt"))

        summary = {"final_train_loss": avg_loss, "best_eval_loss": best_eval,
                   "steps": step + 1, "tokens_seen": tokens_seen,
                   "active_M": act, "total_M": tot,
                   "wall_hours": (time.time() - t_start) / 3600.0,
                   "planned_tokens": args.token_budget, "total_steps": total_steps}
        with open(os.path.join(args.out, "..", "train_summary.json"), "w") as f:
            json.dump(summary, f, indent=2)
        logln(train_log, f"DONE {json.dumps(summary)}")
if __name__ == "__main__":
main()
|