Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import argparse | |
| from pathlib import Path | |
| import torch | |
# Default training checkpoint to export; can be overridden with --source.
DEFAULT_SOURCE: Path = Path("/fs/nexus-scratch/psando/nanotts-05-10/gpt2/ckpt_025000.pt")
# File name of the default source without its extension; reused below so the
# exported file is named after the checkpoint it came from.
stem: str = DEFAULT_SOURCE.stem
# Default export destination (relative to the current working directory);
# can be overridden with --output.
DEFAULT_OUTPUT: Path = Path(f"checkpoints/{stem}_inference.pt")
def main(argv: list[str] | None = None) -> None:
    """Export a slim nanoTTS inference checkpoint.

    Loads the checkpoint given by ``--source`` onto the CPU, keeps only the
    ``model`` and ``model_args`` entries plus the optional ``iter_num``,
    ``train_loss`` and ``val_loss`` values, records the source path under
    ``source_checkpoint``, and saves the resulting dict to ``--output``.
    A short summary (path, size, keys, validation loss) is printed at the end.

    Args:
        argv: Command-line arguments to parse. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, as a plain ``main()`` call always
            did.

    Raises:
        FileNotFoundError: If the source checkpoint does not exist.
        FileExistsError: If the output already exists and ``--overwrite`` was
            not passed.
        KeyError: If the checkpoint has no ``model`` or ``model_args`` entry.
    """
    parser = argparse.ArgumentParser(description="Export a slim nanoTTS inference checkpoint.")
    parser.add_argument("--source", type=Path, default=DEFAULT_SOURCE)
    parser.add_argument("--output", type=Path, default=DEFAULT_OUTPUT)
    parser.add_argument("--overwrite", action="store_true")
    args = parser.parse_args(argv)

    # Fail fast, before the (potentially large) checkpoint is read.
    if not args.source.exists():
        raise FileNotFoundError(f"Source checkpoint not found: {args.source}")
    if args.output.exists() and not args.overwrite:
        raise FileExistsError(f"Output already exists: {args.output}. Pass --overwrite to replace it.")

    # NOTE(security): torch.load may unpickle arbitrary objects from the file
    # (depending on the installed torch version's weights_only default), so
    # only export checkpoints that come from a trusted source.
    checkpoint = torch.load(args.source, map_location="cpu")

    # Report every missing mandatory entry at once instead of a bare
    # KeyError('model') from the dict lookup below.
    missing = [key for key in ("model", "model_args") if key not in checkpoint]
    if missing:
        raise KeyError(f"Checkpoint {args.source} is missing required keys: {missing}")

    inference_checkpoint = {
        "model": checkpoint["model"],
        "model_args": checkpoint["model_args"],
        # Optional metadata: stored as None when the source does not have it.
        "iter_num": checkpoint.get("iter_num"),
        "train_loss": checkpoint.get("train_loss"),
        "val_loss": checkpoint.get("val_loss"),
        "source_checkpoint": str(args.source),
    }

    args.output.parent.mkdir(parents=True, exist_ok=True)

    # Write to a sibling temp file and rename it into place, so an interrupted
    # export never leaves a truncated file at the final path (which would then
    # trip the "Output already exists" guard on the next run).
    tmp_output = args.output.with_name(args.output.name + ".tmp")
    try:
        torch.save(inference_checkpoint, tmp_output)
        tmp_output.replace(args.output)
    finally:
        # No-op after a successful rename; removes the partial file otherwise.
        tmp_output.unlink(missing_ok=True)

    size_mib = args.output.stat().st_size / 1024 / 1024
    print(f"Wrote {args.output} ({size_mib:.1f} MiB)")
    print(f"Keys: {sorted(inference_checkpoint.keys())}")
    print(f"Val loss: {inference_checkpoint.get('val_loss')}")
# Run the export only when executed as a script, not when imported.
if __name__ == "__main__":
    main()