"""Strip optimizer_state_dict from a checkpoint, keeping only model weights

and config metadata.



Reason: resuming training.py's standard path restores the optimizer state,

which (in our 6GB / Muon-compile / bf16 setup) reproducibly produces a

NaN/>100-loss on the first forward after load. Reloading model weights

only and letting the optimizer initialize fresh sidesteps the issue.



Output checkpoint also clears `step`, `train_seconds`, `epoch` so the LR

schedule and warmup restart from zero — useful when we want to fine-tune

the trained weights at a new schedule length.

"""
import sys

import torch

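# Fail fast with a usage hint instead of a bare IndexError when arguments
# are missing (SRC_CKPT/DST_CKPT are just placeholder names for the hint).
if len(sys.argv) != 3:
    sys.exit(f"usage: {sys.argv[0]} SRC_CKPT DST_CKPT")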
src, dst = sys.argv[1], sys.argv[2]
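# weights_only=False is needed here: the checkpoint holds more than raw
# tensors (at least the config entry), which the restricted weights-only
# unpickler may refuse to load.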
ckpt = torch.load(src, map_location="cpu", weights_only=False)
keep = {
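    # Fall back to treating the whole file as a bare state_dict if it was
    # saved without the wrapper dict.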
    "model_state_dict": ckpt.get("model_state_dict", ckpt),
    "config": ckpt.get("config"),
    # Reset training progress markers so the LR schedule and warmup restart cleanly.
    "step": 0,
    "train_seconds": 0.0,
    "smoothed_loss": 0.0,
    "bpt_ema": 0.0,
    "epoch": 0,
}
# Explicitly do NOT copy optimizer_state_dict.
torch.save(keep, dst)
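# Cheap sanity check: reload what was just written and confirm the
# optimizer state is really gone and the progress counters were reset.
check = torch.load(dst, map_location="cpu", weights_only=False)
assert "optimizer_state_dict" not in check
assert check["step"] == 0 and check["epoch"] == 0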
print(f"Stripped -> {dst} (orig {sum(1 for _ in ckpt)} keys, kept {len(keep)})")