# pinn/export_weights.py
# Uploaded by SeongvinJu via huggingface_hub (commit 5cb4913).
"""
Export a slim, release-ready weight file from a training checkpoint.

The training checkpoints saved by `save_checkpoint_pinn` bundle optimizer state,
scaler state, and loss history. For public release you only want the model
weights. A raw state_dict checkpoint is accepted as well.

Usage:
    python export_weights.py --ckpt path/to/train_ckpt.pt --out pytorch_model.bin

Optionally also write safetensors (recommended for distribution); the
.safetensors file is written next to --out with the same base name:
    python export_weights.py --ckpt path/to/train_ckpt.pt --out pytorch_model.bin --safetensors
"""
import argparse
import os
from collections.abc import Mapping

import torch
def _extract_state_dict(ckpt):
    """Return only the tensor entries of the model weights held by *ckpt*.

    *ckpt* is the object returned by ``torch.load``. It may be:
    - a training-checkpoint mapping whose ``"model"`` entry holds the weights,
    - a raw state_dict, or
    - an ``nn.Module``, in which case its ``state_dict()`` is used.

    Raises:
        TypeError: if no weights mapping can be found in *ckpt*.
    """
    # Only unwrap "model" when it actually holds weights, so a raw state_dict
    # that happens to contain a tensor keyed "model" is not mistaken for a
    # training-checkpoint wrapper.
    if isinstance(ckpt, Mapping) and isinstance(
        ckpt.get("model"), (Mapping, torch.nn.Module)
    ):
        ckpt = ckpt["model"]
    if isinstance(ckpt, torch.nn.Module):
        ckpt = ckpt.state_dict()
    if not isinstance(ckpt, Mapping):
        raise TypeError(
            "expected a state_dict or a checkpoint dict with a 'model' entry, "
            f"got {type(ckpt).__name__}"
        )
    # Keep only tensors; drop any non-tensor entries.
    return {k: v for k, v in ckpt.items() if torch.is_tensor(v)}


def main(argv=None):
    """Export model weights from a training checkpoint (CLI entry point).

    Writes the tensor entries of the checkpoint's weights to ``--out`` with
    ``torch.save``. With ``--safetensors``, it also writes a ``.safetensors``
    copy next to ``--out``. The process exits with status 2 (via
    ``argparse``) when the checkpoint holds no usable weights.

    Args:
        argv: argument list to parse; ``None`` (the default) parses ``sys.argv``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", required=True, help="training checkpoint (.pt) path")
    ap.add_argument("--out", default="pytorch_model.bin", help="output weight file")
    ap.add_argument("--safetensors", action="store_true",
                    help="also write a .safetensors file next to --out")
    args = ap.parse_args(argv)

    save_file = None
    if args.safetensors:
        # Import up front so a missing package fails before any file is written.
        from safetensors.torch import save_file

    # NOTE(review): depending on the torch version / weights_only default,
    # torch.load may execute pickle code; only load checkpoints you trust.
    ckpt = torch.load(args.ckpt, map_location="cpu")
    try:
        state = _extract_state_dict(ckpt)
    except TypeError as err:
        ap.error(f"{args.ckpt}: {err}")
    if not state:
        # An empty export usually means the weights live under an unexpected
        # key; refuse rather than publish an empty weight file.
        ap.error(f"{args.ckpt}: no tensors found; nothing to export")

    torch.save(state, args.out)
    print(f"[saved] {args.out} ({len(state)} tensors)")

    if save_file is not None:
        # splitext strips only the file name's real extension; rsplit(".")
        # could cut into a dotted directory name when --out has no extension.
        st_path = os.path.splitext(args.out)[0] + ".safetensors"
        # safetensors requires contiguous tensors and no shared storage.
        clean = {k: v.contiguous().clone() for k, v in state.items()}
        save_file(clean, st_path)
        print(f"[saved] {st_path}")
if __name__ == "__main__":
    # Run the exporter only when invoked as a script, not when imported.
    main()