| """ |
| Save TD checkpoints to HuggingFace. |
| |
| Usage: |
| python3 save_checkpoint.py # saves latest checkpoint |
| python3 save_checkpoint.py after_mimo # saves specific checkpoint |
| python3 save_checkpoint.py all # saves all checkpoints |
| """ |
|
|
| import sys |
| import os |
| from pathlib import Path |
| from huggingface_hub import HfApi, login |
|
|
# Hugging Face access token taken from the environment; "" when HF_TOKEN is unset.
TOKEN: str = os.getenv("HF_TOKEN", "")
# Hub repository that receives the uploads; each checkpoint lands in its own subfolder.
REPO: str = "td-builder/td-qwen3vl-v1"
# Local directory whose subfolders are the checkpoints to upload.
CKPT_DIR: Path = Path("td_fuse_checkpoints")
|
|
def upload_checkpoint(api, name, *, ckpt_dir=None, repo=None):
    """Upload one local checkpoint folder to a Hugging Face repo.

    The folder ``<ckpt_dir>/<name>`` is uploaded to ``<repo>/<name>/`` in a
    single commit. Problems are reported on stdout and signalled through the
    return value rather than raised, so a caller looping over several
    checkpoints can carry on after one of them fails.

    Args:
        api: An ``HfApi``-like object; only its ``upload_folder`` method is used.
        name: Checkpoint subfolder name; also used as the folder path in the repo.
        ckpt_dir: Local directory holding checkpoint folders (str or Path).
            Defaults to the module-level ``CKPT_DIR``.
        repo: Target repo id. Defaults to the module-level ``REPO``.

    Returns:
        True if the upload succeeded; False if the folder is missing, lacks
        ``model.safetensors``, or the upload itself raised.
    """
    ckpt_dir = CKPT_DIR if ckpt_dir is None else Path(ckpt_dir)
    repo = REPO if repo is None else repo

    ckpt_path = ckpt_dir / name
    if not ckpt_path.exists():
        print(f" ERROR: {ckpt_path} doesn't exist")
        return False

    safetensors = ckpt_path / "model.safetensors"
    if not safetensors.exists():
        print(f" ERROR: No model.safetensors in {ckpt_path}")
        return False

    size_gb = sum(f.stat().st_size for f in ckpt_path.rglob("*") if f.is_file()) / 1e9
    print(f" Uploading {name} ({size_gb:.1f} GB) to {repo}/{name}/...")

    try:
        api.upload_folder(
            folder_path=str(ckpt_path),
            path_in_repo=name,
            repo_id=repo,
            commit_message=f"Checkpoint: {name}",
        )
    except Exception as e:
        # Per-checkpoint boundary: honour the bool contract above instead of
        # letting one network/auth error abort a multi-checkpoint run.
        print(f" ERROR: upload of {name} failed ({e})")
        return False

    print(f" Done: {name}")
    return True
|
|
|
|
def _find_checkpoints():
    """Return the sorted names of CKPT_DIR subfolders that contain model.safetensors."""
    return sorted(
        d.name
        for d in CKPT_DIR.iterdir()
        if d.is_dir() and (d / "model.safetensors").exists()
    )


def _latest_checkpoint(checkpoints):
    """Return the name in *checkpoints* whose model.safetensors was modified most recently."""
    # Compare the weights file's mtime, not the folder's: a directory's mtime
    # only changes when entries are added, removed or renamed, so overwriting
    # model.safetensors in place would not make its folder look newer.
    return max(
        checkpoints,
        key=lambda n: (CKPT_DIR / n / "model.safetensors").stat().st_mtime,
    )


def _upload_perm_cache(api):
    """Best-effort upload of CKPT_DIR/perm_cache when it holds any .npz files.

    Failures only print a warning: the cache merely saves recomputation time,
    so a failed cache upload should not fail the whole run.
    """
    perm_cache = CKPT_DIR / "perm_cache"
    if not (perm_cache.exists() and any(perm_cache.glob("*.npz"))):
        return
    try:
        size_kb = sum(f.stat().st_size for f in perm_cache.rglob("*") if f.is_file()) / 1024
        print(f" Uploading perm_cache ({size_kb:.0f} KB) to {REPO}/perm_cache/...")
        api.upload_folder(
            folder_path=str(perm_cache),
            path_in_repo="perm_cache",
            repo_id=REPO,
            commit_message="Permutation cache (saves 12 min Sinkhorn)",
        )
        print(" Done: perm_cache")
    except Exception as e:
        print(f" WARNING: perm_cache upload failed ({e})")


def main():
    """Upload checkpoint(s) from CKPT_DIR to REPO, chosen by the first CLI argument.

    ``argv[1]`` may be ``all`` (every checkpoint), a checkpoint name, or be
    omitted (the most recently written checkpoint). The perm_cache folder is
    then uploaded on a best-effort basis. Exits with status 1 when no
    checkpoints are found, the requested checkpoint does not exist, or any
    checkpoint upload fails.
    """
    target = sys.argv[1] if len(sys.argv) > 1 else None

    # Validate local state first so a missing/empty directory fails fast,
    # before any network round-trip to the Hub.
    if not CKPT_DIR.exists():
        print(f"No checkpoint directory found at {CKPT_DIR}")
        sys.exit(1)

    checkpoints = _find_checkpoints()
    if not checkpoints:
        print("No checkpoints found (need model.safetensors in each folder)")
        sys.exit(1)

    print(f"Available checkpoints: {', '.join(checkpoints)}")

    if target == "all":
        to_upload = checkpoints
    elif target:
        if target not in checkpoints:
            print(f"Checkpoint '{target}' not found. Available: {', '.join(checkpoints)}")
            sys.exit(1)
        to_upload = [target]
    else:
        latest = _latest_checkpoint(checkpoints)
        print(f"Uploading latest: {latest}")
        to_upload = [latest]

    # TOKEN is "" when HF_TOKEN is unset; login() would try to validate that
    # empty string and reject it. Skip it and let HfApi use any cached
    # credentials (e.g. from `huggingface-cli login`) instead.
    if TOKEN:
        login(token=TOKEN)
    else:
        print("HF_TOKEN not set; using cached Hugging Face credentials, if any.")
    api = HfApi()

    failed = []
    for name in to_upload:
        if not upload_checkpoint(api, name):
            failed.append(name)

    _upload_perm_cache(api)

    if failed:
        print(f"\nFailed to upload: {', '.join(failed)}")
        sys.exit(1)
    print("\nAll done! Checkpoints saved to HuggingFace.")


if __name__ == "__main__":
    main()
|
|