"""
Export a slim, release-ready weight file from a training checkpoint.

The training checkpoints saved by `save_checkpoint_pinn` bundle the model
weights together with optimizer state, scaler state, and loss history. For a
public release you only want the model weights.

Usage:
    python export_weights.py --ckpt path/to/train_ckpt.pt --out pytorch_model.bin

Optionally also write safetensors (recommended for distribution); this writes
`pytorch_model.safetensors` alongside `pytorch_model.bin`:
    python export_weights.py --ckpt path/to/train_ckpt.pt --out pytorch_model.bin --safetensors
"""
import argparse
import os
from collections.abc import Mapping

import torch
|
|
|
|
def _extract_state(ckpt):
    """Return the tensor entries of the model state dict stored in *ckpt*.

    Lookup order:

    * a mapping with a ``"model"`` entry (the layout this script expects from
      training checkpoints) yields that entry;
    * otherwise a ``"model_state_dict"`` or ``"state_dict"`` entry is used when
      its value is itself a mapping;
    * otherwise *ckpt* itself is treated as a bare state dict.

    If the selected object is not a mapping but has a ``state_dict()`` method
    (e.g. a whole module pickled into the checkpoint), that method's result is
    used. Non-tensor entries (counters, config values, ...) are dropped.

    Raises:
        TypeError: if no mapping of parameters can be obtained.
        ValueError: if the selected state dict contains no tensors.
    """
    state = ckpt
    if isinstance(ckpt, Mapping):
        if "model" in ckpt:
            state = ckpt["model"]
        else:
            # Only accept fallbacks whose value is a mapping, so a bare state
            # dict (whose values are tensors) is never misread as a wrapper.
            for key in ("model_state_dict", "state_dict"):
                if isinstance(ckpt.get(key), Mapping):
                    state = ckpt[key]
                    break

    if not isinstance(state, Mapping) and callable(getattr(state, "state_dict", None)):
        state = state.state_dict()
    if not isinstance(state, Mapping):
        raise TypeError(
            f"expected a model state dict (mapping), got {type(state).__name__}"
        )

    tensors = {k: v for k, v in state.items() if torch.is_tensor(v)}
    if not tensors:
        # Writing an empty "release" file is almost certainly a wrong layout
        # guess; fail loudly instead.
        raise ValueError(
            "no tensors found in the checkpoint; is it a model state dict?"
        )
    return tensors


def _safetensors_path(out):
    """Return *out* with its file extension replaced by ``.safetensors``.

    Only the extension of the final path component is replaced, so dots in
    directory names are left alone (``runs.v2/model`` ->
    ``runs.v2/model.safetensors``).
    """
    return os.path.splitext(out)[0] + ".safetensors"


def main():
    """Command-line entry point: export model weights from a training checkpoint.

    Loads ``--ckpt`` onto the CPU, extracts the model state dict (tensor
    entries only), and writes it with ``torch.save`` to ``--out``, creating the
    parent directory if needed. With ``--safetensors``, a ``.safetensors`` copy
    is also written next to ``--out``.

    Raises:
        TypeError / ValueError: from ``_extract_state`` when the checkpoint
            does not contain a usable state dict.
        SystemExit: on invalid command-line arguments (via ``argparse``).
    """
    ap = argparse.ArgumentParser(
        description="Export model weights from a training checkpoint."
    )
    ap.add_argument("--ckpt", required=True, help="training checkpoint (.pt) path")
    ap.add_argument("--out", default="pytorch_model.bin", help="output weight file")
    ap.add_argument("--safetensors", action="store_true",
                    help="also write a .safetensors file next to --out")
    args = ap.parse_args()

    st_path = None
    if args.safetensors:
        st_path = _safetensors_path(args.out)
        if st_path == args.out:
            # Otherwise the safetensors file would silently overwrite the
            # torch file just written to the same path.
            ap.error("--out already ends in .safetensors; choose another name "
                     "for the torch weight file")
        # Import up front so a missing optional dependency fails before any
        # output file is written.
        try:
            from safetensors.torch import save_file
        except ImportError:
            ap.error("--safetensors requires the 'safetensors' package")

    # NOTE: depending on the torch version, torch.load may unpickle arbitrary
    # objects; only export checkpoints from trusted sources.
    ckpt = torch.load(args.ckpt, map_location="cpu")
    state = _extract_state(ckpt)
    del ckpt  # drop optimizer state / history before writing

    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torch.save(state, args.out)
    print(f"[saved] {args.out} ({len(state)} tensors)")

    if st_path is not None:
        # safetensors refuses tensors that share storage or are
        # non-contiguous; give each tensor its own contiguous buffer.
        clean = {k: v.contiguous().clone() for k, v in state.items()}
        # "format": "pt" is the conventional tag loaders use to identify
        # PyTorch-origin safetensors files.
        save_file(clean, st_path, metadata={"format": "pt"})
        print(f"[saved] {st_path}")


if __name__ == "__main__":
    main()
|
|