# moe100m-physics-tinybpe / train_phys.py
# (uploaded by AlexWortega with huggingface_hub, commit d0edc76)
"""Single-GPU (V100, fp16) from-scratch trainer for the tiny-vocab physics MoE.
Adapted from scaffold/train/train_200m.py — drops DDP / EMA / WSD-resume /
HF-push machinery, keeps the load-bearing bits:
- Muon (matrix) + AdamW (rest) via optim.make_param_groups
- cosine LR schedule with warmup
- fp16 autocast forward, fp32 router math (handled inside model.py), dynamic
loss-scale, NaN-guard (skip step + halve scale; abort after more than nan_cap
consecutive bad steps)
- router aux/z loss added; router bias controller stepped each good step
- chunked / Liger fused CE (from model.py)
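Logs: train.log (every 20 steps and the first 5, plus skip/abort events) and
eval.log (periodic val loss), both written to the parent directory of --out.
Checkpoints (last.pt, best.pt) go to --out (default: ckpts/).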
"""
from __future__ import annotations
import argparse, json, math, os, sys, time
import torch
# Directory containing this script; its sibling ../scaffold directory is put at
# the front of sys.path so the project imports that follow resolve from there.
_HERE = os.path.split(os.path.abspath(__file__))[0]
sys.path.insert(0, os.path.join(_HERE, os.pardir, "scaffold"))
from model import MoEModel # noqa: E402
from optim import Muon, make_param_groups # noqa: E402
from config_100m import make_config # noqa: E402
import data_physics as dp # noqa: E402
def cosine_lr(step, peak, warmup, total, min_lr):
    """Return the learning rate for *step*: linear warmup, then cosine decay.

    For ``step < warmup`` the rate ramps linearly as ``peak * (step + 1) / warmup``.
    From ``warmup`` on it follows a half-cosine from *peak* down to *min_lr*
    over ``total - warmup`` steps (at least 1), and stays at *min_lr* once
    *step* reaches *total*.
    """
    if step >= warmup:
        decay_span = max(1, total - warmup)
        progress = min(1.0, (step - warmup) / decay_span)
        half_swing = 0.5 * (peak - min_lr)
        return min_lr + half_swing * (1 + math.cos(math.pi * progress))
    return peak * (step + 1) / warmup
@torch.no_grad()
def eval_loss(model, tok_path, seq_len, batch_size, n_batches, device):
    """Return the mean validation LM loss over up to *n_batches* finite batches.

    Streams the "val" split from ``dp.batch_iterator`` (non-infinite, no
    shuffle buffer, seed 123) and runs the model under fp16 autocast with
    gradients disabled. Batches whose loss is ``None`` or non-finite are
    skipped and do not count toward *n_batches*.

    Args:
        model: called as ``model(ids, labels=lbl)``; the second element of
            the returned 3-tuple is used as the loss.
        tok_path: tokenizer path forwarded to ``dp.batch_iterator``.
        seq_len: sequence length forwarded to the iterator.
        batch_size: batch size forwarded to the iterator.
        n_batches: number of finite-loss batches to average.
        device: device forwarded to the iterator.

    Returns:
        The mean loss as a Python float, or ``float("nan")`` if no batch
        produced a finite loss. NaN (rather than 0.0) is returned so that a
        failed eval can never compare as a new best (``nan < x`` is False).

    The model's train/eval mode is restored on exit, even on error.
    """
    was_training = model.training
    model.eval()
    total, count = 0.0, 0
    try:
        batches = dp.batch_iterator(tok_path, seq_len, batch_size, split="val",
                                    types=dp.TRAIN_TYPES, device=device,
                                    infinite=False, shuffle_buffer=0, seed=123)
        for ids, lbl in batches:
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                _, loss, _ = model(ids, labels=lbl)
            if loss is None or not torch.isfinite(loss):
                continue
            total += float(loss.item())
            count += 1
            if count >= n_batches:
                break
    finally:
        model.train(was_training)
    return total / count if count else float("nan")
def _save_checkpoint(path, model, cfg, step, **extra):
    """Save ``{"model", "cfg", "step", **extra}`` to *path* via ``torch.save``."""
    payload = {"model": model.state_dict(), "cfg": cfg.as_dict(), "step": step}
    payload.update(extra)
    torch.save(payload, path)


def _unscale_grads_and_check_finite(params, inv_scale):
    """Multiply every existing ``.grad`` in *params* by *inv_scale*, in place.

    Returns True if all grads are finite after unscaling. Stops at the first
    non-finite grad and returns False; grads after it stay scaled, which is
    harmless because the caller skips the step and zeroes grads next step.
    """
    for p in params:
        if p.grad is None:
            continue
        p.grad.data.mul_(inv_scale)
        if not torch.isfinite(p.grad.data).all():
            return False
    return True


def main():
    """Train the tiny-vocab physics MoE from scratch on a single GPU.

    Settings come from the command line (see the ``add_argument`` calls).
    Model weights can optionally be resumed from ``--resume``. The loop then
    runs optimizer steps ``start_step .. total_steps - 1`` with:

    * cosine LR with linear warmup (``cosine_lr``);
    * fp16 autocast and manual dynamic loss scaling. The scale is halved on a
      non-finite grad and doubled after ``grow_every`` good steps, clamped to
      [1, 2**16];
    * skipping of steps whose forward loss or grads are non-finite, aborting
      after more than ``--nan-cap`` consecutive bad steps;
    * router z/aux losses added to the LM loss, and a router-bias update
      after every good step.

    train.log, eval.log and train_summary.json are written to the parent
    directory of ``--out``. last.pt and best.pt are written inside ``--out``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--tokenizer", default="tokenizer.json")
    ap.add_argument("--vocab", type=int, default=512)
    ap.add_argument("--seq-len", type=int, default=1024)
    ap.add_argument("--batch-size", type=int, default=8)
    ap.add_argument("--grad-accum", type=int, default=1)
    ap.add_argument("--peak-lr", type=float, default=6e-4)
    ap.add_argument("--min-lr", type=float, default=3e-5)
    ap.add_argument("--warmup", type=int, default=500)
    ap.add_argument("--token-budget", type=float, default=2.5e9)
    ap.add_argument("--max-steps", type=int, default=0)  # 0 = derive from budget
    ap.add_argument("--eval-every", type=int, default=1000)  # <= 0 disables eval
    ap.add_argument("--eval-batches", type=int, default=30)
    ap.add_argument("--ckpt-every", type=int, default=2000)  # <= 0 disables periodic ckpt
    ap.add_argument("--shuffle-buffer", type=int, default=200)
    ap.add_argument("--nan-cap", type=int, default=50)
    ap.add_argument("--out", default="ckpts")
    ap.add_argument("--max-wall-hours", type=float, default=23.0)
    ap.add_argument("--smoke", action="store_true")
    ap.add_argument("--resume", default="")
    ap.add_argument("--data-seed", type=int, default=0)
    args = ap.parse_args()

    device = "cuda"
    torch.manual_seed(0)
    os.makedirs(args.out, exist_ok=True)
    tokens_per_step = args.batch_size * args.seq_len * args.grad_accum
    total_steps = args.max_steps or int(args.token_budget / tokens_per_step)
    cfg = make_config(args.vocab, max_seq_len=args.seq_len)
    model = MoEModel(cfg).to(device)

    start_step = 0
    if args.resume and os.path.exists(args.resume):
        # SECURITY: weights_only=False unpickles arbitrary objects; only resume
        # from checkpoints you trust.
        # NOTE(review): periodic checkpoints store the index of the step that
        # just completed, so resuming re-runs that step. Optimizer state, loss
        # scale and best_eval are not restored.
        ck = torch.load(args.resume, map_location="cpu", weights_only=False)
        model.load_state_dict(ck["model"])
        start_step = int(ck.get("step", 0))
        print(f"[init] RESUMED weights from {args.resume} @ step {start_step}", flush=True)

    act = model.num_parameters(only_active=True) / 1e6
    tot = model.num_parameters() / 1e6
    print(f"[init] ACTIVE={act:.2f}M TOTAL={tot:.2f}M vocab={cfg.vocab_size} "
          f"total_steps={total_steps} tokens/step={tokens_per_step} start_step={start_step}",
          flush=True)

    matrix, non_matrix = make_param_groups(model)
    opt = Muon(matrix, non_matrix, lr=args.peak_lr, momentum=0.95,
               ns_mode="fp32", weight_decay=0.01, betas=(0.9, 0.95),
               foreach=True)

    loss_scale = 2.0 ** 14
    # cap at 2^16: physics grads occasionally overflow above that, causing a
    # benign-but-frequent NaN-skip oscillation. Lower ceiling = far fewer skips.
    loss_scale_min, loss_scale_max = 2.0 ** 0, 2.0 ** 16
    grow_every = 200
    n_good = 0
    nan_count = 0
    consec_nan = 0

    data = dp.batch_iterator(args.tokenizer, args.seq_len, args.batch_size,
                             split="train", types=dp.TRAIN_TYPES, device=device,
                             infinite=True, shuffle_buffer=args.shuffle_buffer,
                             seed=args.data_seed)

    train_log_path = os.path.join(args.out, "..", "train.log")
    eval_log_path = os.path.join(args.out, "..", "eval.log")
    with open(train_log_path, "a", encoding="utf-8") as train_log, \
            open(eval_log_path, "a", encoding="utf-8") as eval_log:

        def logln(f, s):
            f.write(s + "\n")
            f.flush()
            print(s, flush=True)

        t_start = time.time()
        tokens_seen = start_step * tokens_per_step
        walls = []  # wall time of the most recent good steps (last 20 kept)
        params = list(model.parameters())
        best_eval = float("inf")
        # Defaults used by the final save / summary if no step ever runs
        # (start_step >= total_steps) or no step completes successfully.
        step = start_step - 1
        avg_loss = float("nan")

        for step in range(start_step, total_steps):
            lr = cosine_lr(step, args.peak_lr, args.warmup, total_steps, args.min_lr)
            opt.set_lr(lr)
            opt.zero_grad()
            accum_loss = 0.0
            aux_last = None
            forward_bad = False
            # Time the whole step (all micro-batches), since tokens_per_step
            # already counts every micro-batch.
            t0 = time.perf_counter()
            for _ in range(args.grad_accum):
                ids, lbl = next(data)
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    _, lm_loss, aux = model(ids, labels=lbl)
                # Catch activation/loss explosion AT THE SOURCE: if the forward
                # loss is already non-finite (a pathological high-velocity batch
                # overflowed fp16 inside CE/router), skip this batch entirely
                # instead of letting the NaN propagate into the weights.
                if not torch.isfinite(lm_loss):
                    forward_bad = True
                    break
                loss = lm_loss
                if aux is not None:
                    loss = (loss + cfg.router_z_coef * aux["z_loss"]
                            + cfg.router_aux_coef * aux["aux_loss"])
                (loss * loss_scale / args.grad_accum).backward()
                accum_loss += float(lm_loss.item())
                aux_last = aux

            if forward_bad:
                nan_count += 1
                consec_nan += 1
                opt.zero_grad()
                logln(train_log, f"step {step} non-finite FORWARD loss -> skip batch "
                                 f"(consec={consec_nan} total={nan_count})")
                if consec_nan > args.nan_cap:
                    logln(train_log, f"step {step} >{args.nan_cap} CONSECUTIVE bad -> ABORT")
                    break
                continue

            # unscale + NaN guard
            if not _unscale_grads_and_check_finite(params, 1.0 / loss_scale):
                nan_count += 1
                consec_nan += 1
                loss_scale = max(loss_scale_min, loss_scale * 0.5)
                n_good = 0
                logln(train_log, f"step {step} NaN/Inf grad -> skip; scale={loss_scale:.0f} "
                                 f"(consec={consec_nan} total={nan_count})")
                # abort only on SUSTAINED divergence (consecutive), not
                # cumulative: occasional fp16 overflow at high loss-scale is
                # benign and recovers.
                if consec_nan > args.nan_cap:
                    logln(train_log, f"step {step} >{args.nan_cap} CONSECUTIVE NaN -> ABORT")
                    break
                continue

            consec_nan = 0
            torch.nn.utils.clip_grad_norm_(params, 1.0)
            opt.step()
            n_good += 1
            if n_good >= grow_every:
                loss_scale = min(loss_scale_max, loss_scale * 2.0)
                n_good = 0
            if aux_last is not None:
                model.step_router_biases(aux_last["counts_per_layer"])
            torch.cuda.synchronize()
            walls.append(time.perf_counter() - t0)
            del walls[:-20]  # keep only the throughput window
            tokens_seen += tokens_per_step
            avg_loss = accum_loss / args.grad_accum

            if step % 20 == 0 or step < 5:
                cv = float(aux_last["router_cv"].item()) if aux_last is not None else 0.0
                tps = tokens_per_step / (sum(walls) / len(walls))
                logln(train_log, f"step {step} loss={avg_loss:.4f} lr={lr:.2e} "
                                 f"scale={loss_scale:.0f} cv={cv:.3f} tok={tokens_seen} "
                                 f"tok/s={tps:.0f} elapsed={(time.time()-t_start)/3600:.2f}h")

            if args.eval_every > 0 and step > 0 and step % args.eval_every == 0:
                ev = eval_loss(model, args.tokenizer, args.seq_len, args.batch_size,
                               args.eval_batches, device)
                logln(eval_log, f"step {step} eval_loss={ev:.4f} train_loss={avg_loss:.4f} tok={tokens_seen}")
                if ev < best_eval:
                    best_eval = ev
                    _save_checkpoint(os.path.join(args.out, "best.pt"),
                                     model, cfg, step, eval_loss=ev)

            if args.ckpt_every > 0 and step > 0 and step % args.ckpt_every == 0:
                _save_checkpoint(os.path.join(args.out, "last.pt"), model, cfg, step)

            if (time.time() - t_start) / 3600.0 > args.max_wall_hours:
                logln(train_log, f"step {step} wall-cap {args.max_wall_hours}h reached -> stop")
                break
            if args.smoke and step >= 1:
                logln(train_log, f"[smoke] completed {step+1} real steps, loss={avg_loss:.4f}")
                break

        # final save
        _save_checkpoint(os.path.join(args.out, "last.pt"), model, cfg, step, final=True)
        if best_eval == float("inf"):
            _save_checkpoint(os.path.join(args.out, "best.pt"), model, cfg, step)

        summary = {"final_train_loss": avg_loss, "best_eval_loss": best_eval,
                   "steps": step + 1, "tokens_seen": tokens_seen,
                   "active_M": act, "total_M": tot,
                   "wall_hours": (time.time() - t_start) / 3600.0,
                   "planned_tokens": args.token_budget, "total_steps": total_steps}
        with open(os.path.join(args.out, "..", "train_summary.json"), "w") as f:
            json.dump(summary, f, indent=2)
        logln(train_log, f"DONE {json.dumps(summary)}")


if __name__ == "__main__":
    main()