MRiabov committed on
Commit
a0ad699
·
1 Parent(s): 8223f36

(minor util) strip checkpoint util

Browse files
Files changed (1) hide show
  1. scripts/strip_checkpoint.py +46 -0
scripts/strip_checkpoint.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import argparse
4
+ from pathlib import Path
5
+ import torch
6
+
7
+
8
def main():
    """Strip a training checkpoint down to inference-only FP32 weights.

    Command-line arguments:
        --in   Path to the training checkpoint (.pt) to read.
        --out  Path where the weights-only file (.pt) is written; missing
               parent directories are created.

    Recognized checkpoint layouts, checked in this order:
        1. A dict with a 'model' entry (the project format:
           {'step', 'model', 'optim', 'scaler', 'best_f1'}).
        2. A dict with a 'state_dict' entry.
        3. A dict whose values are all tensors (already a pure state_dict).

    Floating-point tensors are converted to float32; every other value is
    written back unchanged.

    Raises:
        FileNotFoundError: the --in path is not an existing file.
        ValueError: the checkpoint matches none of the recognized layouts.
        TypeError: the 'model' / 'state_dict' entry is not a dict.
    """
    parser = argparse.ArgumentParser(
        description="Strip training checkpoint to inference-only weights (FP32)."
    )
    parser.add_argument("--in", dest="inp", type=str, required=True, help="Path to training checkpoint .pt")
    parser.add_argument("--out", dest="out", type=str, required=True, help="Path to save weights-only .pt")
    args = parser.parse_args()

    in_path = Path(args.inp)
    out_path = Path(args.out)

    # Explicit raise instead of `assert`: assertions are removed under
    # `python -O`, which would silently disable this check.
    if not in_path.is_file():
        raise FileNotFoundError(f"Input file does not exist: {in_path}")

    # NOTE(security): torch.load unpickles the file; only run this tool on
    # checkpoints from a trusted source.
    ckpt = torch.load(str(in_path), map_location="cpu")

    # Primary (project) format: {'step', 'model', 'optim', 'scaler', 'best_f1'}
    if isinstance(ckpt, dict) and "model" in ckpt:
        state_dict = ckpt["model"]
    # Secondary common format: {'state_dict': model.state_dict(), ...}
    elif isinstance(ckpt, dict) and "state_dict" in ckpt:
        state_dict = ckpt["state_dict"]
    # Fallback: checkpoint is already a pure state_dict
    elif isinstance(ckpt, dict) and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
        state_dict = ckpt
    else:
        raise ValueError(
            "Checkpoint is not a recognized format: expected keys 'model' or 'state_dict', "
            "or a pure state_dict (name->Tensor)."
        )

    if not isinstance(state_dict, dict):
        raise TypeError(
            "Expected the model weights to be a dict (name->Tensor), "
            f"got {type(state_dict).__name__}."
        )

    # Ensure FP32 tensors (no casting to bf16/fp16 per request). Non-tensor
    # entries are passed through untouched rather than crashing inside
    # torch.is_floating_point, which only accepts tensors.
    state_dict = {
        name: (
            value.float()
            if isinstance(value, torch.Tensor) and value.is_floating_point()
            else value
        )
        for name, value in state_dict.items()
    }

    # Create the output directory only once there is something valid to
    # write, so a failed run leaves no empty directories behind.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(state_dict, str(out_path))
    print(f"[strip_checkpoint] Saved weights-only to: {out_path}")


if __name__ == "__main__":
    main()