"""Export a slim nanoTTS inference checkpoint from a full training checkpoint.

The exported file keeps only the ``model`` and ``model_args`` entries plus a
few bookkeeping fields (``iter_num``, ``train_loss``, ``val_loss`` and the
path of the source checkpoint). Every other entry of the source checkpoint is
dropped.
"""

from __future__ import annotations

import argparse
from pathlib import Path

import torch

DEFAULT_SOURCE = Path("/fs/nexus-scratch/psando/nanotts-05-10/gpt2/ckpt_025000.pt")
# `stem` is kept as a module-level name because the original script exposed it.
stem = DEFAULT_SOURCE.stem
DEFAULT_OUTPUT = Path(f"checkpoints/{stem}_inference.pt")

# Entries that must be present in the source checkpoint.
_REQUIRED_KEYS = ("model", "model_args")


def _default_output_for(source: Path) -> Path:
    """Return the default export path for *source* (relative to the cwd)."""
    return Path(f"checkpoints/{source.stem}_inference.pt")


def main(argv: list[str] | None = None) -> None:
    """Load a training checkpoint and write its slimmed-down inference copy.

    Args:
        argv: Command-line arguments to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the script behavior.

    Raises:
        FileNotFoundError: The source checkpoint does not exist.
        FileExistsError: The output exists and ``--overwrite`` was not given.
        KeyError: The source checkpoint lacks ``model`` or ``model_args``.
    """
    parser = argparse.ArgumentParser(description="Export a slim nanoTTS inference checkpoint.")
    parser.add_argument("--source", type=Path, default=DEFAULT_SOURCE)
    parser.add_argument(
        "--output",
        type=Path,
        default=None,
        help="Destination file (default: checkpoints/<source stem>_inference.pt).",
    )
    parser.add_argument("--overwrite", action="store_true")
    args = parser.parse_args(argv)

    # Derive the default from the *actual* source so that exporting a
    # non-default checkpoint is not written under the default checkpoint's
    # name. With the default source this equals DEFAULT_OUTPUT.
    if args.output is None:
        args.output = _default_output_for(args.source)

    if not args.source.exists():
        raise FileNotFoundError(f"Source checkpoint not found: {args.source}")
    if args.output.exists() and not args.overwrite:
        raise FileExistsError(f"Output already exists: {args.output}. Pass --overwrite to replace it.")

    # NOTE(security): torch.load unpickles the file; only point --source at
    # checkpoints from a trusted origin.
    checkpoint = torch.load(args.source, map_location="cpu")

    missing = [key for key in _REQUIRED_KEYS if key not in checkpoint]
    if missing:
        raise KeyError(f"Checkpoint {args.source} is missing required key(s): {missing}")

    inference_checkpoint = {
        "model": checkpoint["model"],
        "model_args": checkpoint["model_args"],
        "iter_num": checkpoint.get("iter_num"),
        "train_loss": checkpoint.get("train_loss"),
        "val_loss": checkpoint.get("val_loss"),
        "source_checkpoint": str(args.source),
    }

    args.output.parent.mkdir(parents=True, exist_ok=True)

    # Write to a sibling temp file and rename it into place, so a failed or
    # interrupted save never leaves a truncated file at the final path.
    tmp_output = args.output.with_name(args.output.name + ".tmp")
    try:
        torch.save(inference_checkpoint, tmp_output)
        tmp_output.replace(args.output)
    finally:
        # No-op after a successful rename; cleans up after a failed save.
        tmp_output.unlink(missing_ok=True)

    size_mib = args.output.stat().st_size / 1024 / 1024
    print(f"Wrote {args.output} ({size_mib:.1f} MiB)")
    print(f"Keys: {sorted(inference_checkpoint.keys())}")
    print(f"Val loss: {inference_checkpoint.get('val_loss')}")


if __name__ == "__main__":
    main()