"""
Save TD checkpoints to HuggingFace.

Usage:
    python3 save_checkpoint.py              # saves latest checkpoint
    python3 save_checkpoint.py after_mimo   # saves specific checkpoint
    python3 save_checkpoint.py all          # saves all checkpoints

A checkpoint is any sub-directory of CKPT_DIR that contains a
``model.safetensors`` file. Authentication uses the HF_TOKEN environment
variable when it is set; otherwise HfApi falls back to whatever credentials
are already cached locally. The process exits with status 1 if any requested
checkpoint fails to upload.
"""
import sys
import os
from pathlib import Path

from huggingface_hub import HfApi, login

TOKEN = os.environ.get("HF_TOKEN", "")
REPO = "td-builder/td-qwen3vl-v1"
CKPT_DIR = Path("td_fuse_checkpoints")


def upload_checkpoint(api, name):
    """Upload the folder ``CKPT_DIR/<name>`` to ``REPO/<name>/``.

    Args:
        api: ``HfApi`` instance used for the upload.
        name: Name of the checkpoint sub-directory inside ``CKPT_DIR``.

    Returns:
        True if the upload completed; False if the folder or its
        ``model.safetensors`` is missing, or if the upload call raised.
        Errors are printed rather than raised, so a caller uploading several
        checkpoints can carry on after one failure.
    """
    ckpt_path = CKPT_DIR / name
    if not ckpt_path.exists():
        print(f" ERROR: {ckpt_path} doesn't exist")
        return False

    safetensors = ckpt_path / "model.safetensors"
    if not safetensors.exists():
        print(f" ERROR: No model.safetensors in {ckpt_path}")
        return False

    size_gb = sum(f.stat().st_size for f in ckpt_path.rglob("*") if f.is_file()) / 1e9
    print(f" Uploading {name} ({size_gb:.1f} GB) to {REPO}/{name}/...")
    try:
        api.upload_folder(
            folder_path=str(ckpt_path),
            path_in_repo=name,
            repo_id=REPO,
            commit_message=f"Checkpoint: {name}",
        )
    except Exception as e:  # script boundary: report network/auth/hub errors, don't abort the batch
        print(f" ERROR: upload of {name} failed ({e})")
        return False
    print(f" Done: {name}")
    return True


def _latest_checkpoint(checkpoints):
    """Return the checkpoint whose model.safetensors was modified most recently.

    The weights file's mtime is used instead of the directory's mtime, because a
    directory's mtime only changes when entries are added, removed, or renamed.
    """
    return max(checkpoints, key=lambda n: (CKPT_DIR / n / "model.safetensors").stat().st_mtime)


def _upload_perm_cache(api):
    """Best-effort upload of ``CKPT_DIR/perm_cache`` to ``REPO/perm_cache/``.

    Only runs when the folder exists and contains at least one ``*.npz`` file.
    A failure is reported as a warning and does not affect the exit status.
    """
    # Also upload perm_cache if it exists (tiny files, saves 12 min per re-run)
    perm_cache = CKPT_DIR / "perm_cache"
    if not (perm_cache.exists() and any(perm_cache.glob("*.npz"))):
        return
    try:
        size_kb = sum(f.stat().st_size for f in perm_cache.rglob("*") if f.is_file()) / 1024
        print(f" Uploading perm_cache ({size_kb:.0f} KB) to {REPO}/perm_cache/...")
        api.upload_folder(
            folder_path=str(perm_cache),
            path_in_repo="perm_cache",
            repo_id=REPO,
            commit_message="Permutation cache (saves 12 min Sinkhorn)",
        )
        print(" Done: perm_cache")
    except Exception as e:  # deliberate best-effort: the cache is optional
        print(f" WARNING: perm_cache upload failed ({e})")


def main():
    """Parse argv, upload the selected checkpoint(s), then the perm cache.

    Exits with status 1 if CKPT_DIR is missing, if no checkpoints are found,
    if the requested checkpoint does not exist, or if any upload fails.
    """
    if TOKEN:
        login(token=TOKEN)
    else:
        # Don't call login("") with an empty token; let HfApi use cached credentials instead.
        print("HF_TOKEN not set; using cached HuggingFace credentials (if any)")
    api = HfApi()
    target = sys.argv[1] if len(sys.argv) > 1 else None

    if not CKPT_DIR.exists():
        print(f"No checkpoint directory found at {CKPT_DIR}")
        sys.exit(1)

    # List available checkpoints
    checkpoints = sorted(
        d.name for d in CKPT_DIR.iterdir() if d.is_dir() and (d / "model.safetensors").exists()
    )
    if not checkpoints:
        print("No checkpoints found (need model.safetensors in each folder)")
        sys.exit(1)

    print(f"Available checkpoints: {', '.join(checkpoints)}")

    if target == "all":
        # Upload everything
        to_upload = checkpoints
    elif target:
        # Upload specific one
        if target not in checkpoints:
            print(f"Checkpoint '{target}' not found. Available: {', '.join(checkpoints)}")
            sys.exit(1)
        to_upload = [target]
    else:
        # Upload the latest (most recently written weights)
        latest = _latest_checkpoint(checkpoints)
        print(f"Uploading latest: {latest}")
        to_upload = [latest]

    failed = []
    for name in to_upload:
        if not upload_checkpoint(api, name):
            failed.append(name)

    _upload_perm_cache(api)

    if failed:
        print(f"\nFailed to upload: {', '.join(failed)}")
        sys.exit(1)
    print("\nAll done! Checkpoints saved to HuggingFace.")


if __name__ == "__main__":
    main()