| """Strip optimizer_state_dict from a checkpoint, keeping only model weights | |
| and config metadata. | |
| Reason: resuming training.py's standard path restores the optimizer state, | |
| which (in our 6GB / Muon-compile / bf16 setup) reproducibly produces a | |
| NaN/>100-loss on the first forward after load. Reloading model weights | |
| only and letting the optimizer initialize fresh sidesteps the issue. | |
| Output checkpoint also clears `step`, `train_seconds`, `epoch` so the LR | |
| schedule and warmup restart from zero — useful when we want to fine-tune | |
| the trained weights at a new schedule length. | |
| """ | |
| import sys, torch | |
| src, dst = sys.argv[1], sys.argv[2] | |
| ckpt = torch.load(src, map_location="cpu", weights_only=False) | |
| keep = { | |
| "model_state_dict": ckpt.get("model_state_dict", ckpt), | |
| "config": ckpt.get("config"), | |
| # Reset training progress markers so LR schedule warmups cleanly. | |
| "step": 0, | |
| "train_seconds": 0.0, | |
| "smoothed_loss": 0.0, | |
| "bpt_ema": 0.0, | |
| "epoch": 0, | |
| } | |
| # Explicitly do NOT copy optimizer_state_dict. | |
| torch.save(keep, dst) | |
| print(f"Stripped -> {dst} (orig {sum(1 for _ in ckpt)} keys, kept {len(keep)})") | |
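The keep-logic above is small enough to check in isolation, without torch or a real checkpoint file. A minimal sketch, assuming a `strip_checkpoint` helper (hypothetical, not part of the script) with the same dict-building logic, using plain dicts in place of real tensors:

```python
def strip_checkpoint(ckpt: dict) -> dict:
    """Keep model weights and config; drop optimizer state; reset progress.

    Mirrors the `keep` dict built in the script: a bare state_dict (no
    wrapper keys) is kept whole under "model_state_dict".
    """
    return {
        "model_state_dict": ckpt.get("model_state_dict", ckpt),
        "config": ckpt.get("config"),
        "step": 0,
        "train_seconds": 0.0,
        "smoothed_loss": 0.0,
        "bpt_ema": 0.0,
        "epoch": 0,
    }

# Synthetic checkpoint standing in for the real training.py output.
full = {
    "model_state_dict": {"w": [1.0, 2.0]},
    "optimizer_state_dict": {"state": {}},
    "config": {"lr": 3e-4},
    "step": 500,
    "epoch": 3,
}

stripped = strip_checkpoint(full)
assert "optimizer_state_dict" not in stripped   # optimizer state is dropped
assert stripped["step"] == 0 and stripped["epoch"] == 0  # progress reset
assert stripped["model_state_dict"] == {"w": [1.0, 2.0]}  # weights intact

# Bare state_dict (no wrapper keys) is kept whole.
assert strip_checkpoint({"w": [1.0]})["model_state_dict"] == {"w": [1.0]}
```

Pulling the logic into a pure function like this also makes it reusable from other tooling without shelling out to the script.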