""" Export a slim, release-ready weight file from a training checkpoint. The training checkpoints saved by `save_checkpoint_pinn` bundle optimizer state, scaler state, and loss history. For public release you only want the model weights. Usage: python export_weights.py --ckpt path/to/train_ckpt.pt --out pytorch_model.bin Optionally also write safetensors (recommended for distribution): python export_weights.py --ckpt path/to/train_ckpt.pt --out pytorch_model.bin --safetensors """ import argparse import torch def main(): ap = argparse.ArgumentParser() ap.add_argument("--ckpt", required=True, help="training checkpoint (.pt) path") ap.add_argument("--out", default="pytorch_model.bin", help="output weight file") ap.add_argument("--safetensors", action="store_true", help="also write a .safetensors file next to --out") args = ap.parse_args() ckpt = torch.load(args.ckpt, map_location="cpu") # Accept either a full training checkpoint dict or a raw state_dict. if isinstance(ckpt, dict) and "model" in ckpt: state = ckpt["model"] else: state = ckpt # Keep only tensors; drop any non-tensor entries. state = {k: v for k, v in state.items() if torch.is_tensor(v)} torch.save(state, args.out) print(f"[saved] {args.out} ({len(state)} tensors)") if args.safetensors: from safetensors.torch import save_file st_path = args.out.rsplit(".", 1)[0] + ".safetensors" # safetensors requires contiguous tensors and no shared storage. clean = {k: v.contiguous().clone() for k, v in state.items()} save_file(clean, st_path) print(f"[saved] {st_path}") if __name__ == "__main__": main()