"""Strip optimizer_state_dict from a checkpoint, keeping only model weights and config metadata. Reason: resuming training.py's standard path restores the optimizer state, which (in our 6GB / Muon-compile / bf16 setup) reproducibly produces a NaN/>100-loss on the first forward after load. Reloading model weights only and letting the optimizer initialize fresh sidesteps the issue. Output checkpoint also clears `step`, `train_seconds`, `epoch` so the LR schedule and warmup restart from zero — useful when we want to fine-tune the trained weights at a new schedule length. """ import sys, torch src, dst = sys.argv[1], sys.argv[2] ckpt = torch.load(src, map_location="cpu", weights_only=False) keep = { "model_state_dict": ckpt.get("model_state_dict", ckpt), "config": ckpt.get("config"), # Reset training progress markers so LR schedule warmups cleanly. "step": 0, "train_seconds": 0.0, "smoothed_loss": 0.0, "bpt_ema": 0.0, "epoch": 0, } # Explicitly do NOT copy optimizer_state_dict. torch.save(keep, dst) print(f"Stripped -> {dst} (orig {sum(1 for _ in ckpt)} keys, kept {len(keep)})")