"""
Save TD checkpoints to HuggingFace.

Usage:
    python3 save_checkpoint.py                  # saves latest checkpoint
    python3 save_checkpoint.py after_mimo       # saves specific checkpoint
    python3 save_checkpoint.py all              # saves all checkpoints
"""

import sys
import os
from pathlib import Path
from huggingface_hub import HfApi, login

TOKEN = os.environ.get("HF_TOKEN", "")
REPO = "td-builder/td-qwen3vl-v1"
CKPT_DIR = Path("td_fuse_checkpoints")

def upload_checkpoint(api, name):
    ckpt_path = CKPT_DIR / name
    if not ckpt_path.exists():
        print(f"  ERROR: {ckpt_path} doesn't exist")
        return False

    safetensors = ckpt_path / "model.safetensors"
    if not safetensors.exists():
        print(f"  ERROR: No model.safetensors in {ckpt_path}")
        return False

    size_gb = sum(f.stat().st_size for f in ckpt_path.rglob("*") if f.is_file()) / 1e9
    print(f"  Uploading {name} ({size_gb:.1f} GB) to {REPO}/{name}/...")

    api.upload_folder(
        folder_path=str(ckpt_path),
        path_in_repo=name,
        repo_id=REPO,
        commit_message=f"Checkpoint: {name}",
    )
    print(f"  Done: {name}")
    return True
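

# Hypothetical helper, not called by main(): a standalone, testable sketch of
# the checkpoint-discovery rule. Assumptions: a subdirectory of ckpt_dir counts
# as a checkpoint only if it holds model.safetensors, and "newest" means the
# most recent mtime of that file. It uses only pathlib, so it can be exercised
# against a temporary directory without touching the Hub.
# Example: checkpoints_by_recency(CKPT_DIR)[0] names the newest checkpoint.
def checkpoints_by_recency(ckpt_dir):
    """Return checkpoint folder names under ckpt_dir, newest first."""
    root = Path(ckpt_dir)
    found = [
        d for d in root.iterdir()
        if d.is_dir() and (d / "model.safetensors").exists()
    ]
    # Sort by the weights file's modification time, most recent first.
    found.sort(key=lambda d: (d / "model.safetensors").stat().st_mtime, reverse=True)
    return [d.name for d in found]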


def main():
    if not TOKEN:
        print("HF_TOKEN is not set; export a Hugging Face write token first.")
        sys.exit(1)
    login(token=TOKEN)
    api = HfApi()

    target = sys.argv[1] if len(sys.argv) > 1 else None

    if not CKPT_DIR.exists():
        print(f"No checkpoint directory found at {CKPT_DIR}")
        sys.exit(1)

    # List available checkpoints
    checkpoints = sorted(
        d.name for d in CKPT_DIR.iterdir()
        if d.is_dir() and (d / "model.safetensors").exists()
    )

    if not checkpoints:
        print("No checkpoints found (need model.safetensors in each folder)")
        sys.exit(1)

    print(f"Available checkpoints: {', '.join(checkpoints)}")

    if target == "all":
        # Upload everything
        for name in checkpoints:
            upload_checkpoint(api, name)
    elif target:
        # Upload specific one
        if target not in checkpoints:
            print(f"Checkpoint '{target}' not found. Available: {', '.join(checkpoints)}")
            sys.exit(1)
        upload_checkpoint(api, target)
    else:
        # Upload the latest (most recently modified)
        latest = max(checkpoints, key=lambda n: (CKPT_DIR / n).stat().st_mtime)
        print(f"Uploading latest: {latest}")
        upload_checkpoint(api, latest)

    # Also upload perm_cache if it exists (tiny files, saves 12 min per re-run)
    perm_cache = CKPT_DIR / "perm_cache"
    if perm_cache.exists() and any(perm_cache.glob("*.npz")):
        try:
            size_kb = sum(f.stat().st_size for f in perm_cache.rglob("*") if f.is_file()) / 1024
            print(f"  Uploading perm_cache ({size_kb:.0f} KB) to {REPO}/perm_cache/...")
            api.upload_folder(
                folder_path=str(perm_cache),
                path_in_repo="perm_cache",
                repo_id=REPO,
                commit_message="Permutation cache (saves 12 min Sinkhorn)",
            )
            print(f"  Done: perm_cache")
        except Exception as e:
            print(f"  WARNING: perm_cache upload failed ({e})")

    print("\nAll done! Checkpoints saved to HuggingFace.")


if __name__ == "__main__":
    main()