feather-runtime / overlay / scripts / strip_optimizer_state.py
"""Strip optimizer_state_dict from a checkpoint, keeping only model weights
and config metadata.
Reason: resuming training.py's standard path restores the optimizer state,
which (in our 6GB / Muon-compile / bf16 setup) reproducibly produces a
NaN/>100-loss on the first forward after load. Reloading model weights
only and letting the optimizer initialize fresh sidesteps the issue.
Output checkpoint also clears `step`, `train_seconds`, `epoch` so the LR
schedule and warmup restart from zero — useful when we want to fine-tune
the trained weights at a new schedule length.
"""
import sys

import torch

# Basic CLI guard: expect exactly a source checkpoint and a destination path.
if len(sys.argv) != 3:
    sys.exit(f"usage: {sys.argv[0]} SRC_CHECKPOINT DST_CHECKPOINT")
src, dst = sys.argv[1], sys.argv[2]
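# weights_only=False: the checkpoint may carry a pickled config object in
# addition to tensors; map_location="cpu" keeps the load GPU-free.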
ckpt = torch.load(src, map_location="cpu", weights_only=False)
keep = {
    # Fall back to the whole checkpoint in case `src` is a raw state_dict.
    "model_state_dict": ckpt.get("model_state_dict", ckpt),
    "config": ckpt.get("config"),
    # Reset training progress markers so the LR schedule warms up cleanly.
    "step": 0,
    "train_seconds": 0.0,
    "smoothed_loss": 0.0,
    "bpt_ema": 0.0,
    "epoch": 0,
}
# Explicitly do NOT copy optimizer_state_dict.
torch.save(keep, dst)
print(f"Stripped -> {dst} (orig {sum(1 for _ in ckpt)} keys, kept {len(keep)})")