td-toolkit / save_checkpoint.py
td-builder's picture
Fixed code: vocab mismatch fix for cross-arch merging (Llama/Falcon)
5d61448 verified
"""
Save TD checkpoints to HuggingFace.

The access token is read from the HF_TOKEN environment variable.

Usage:
    python3 save_checkpoint.py              # saves latest checkpoint
    python3 save_checkpoint.py after_mimo   # saves specific checkpoint
    python3 save_checkpoint.py all          # saves all checkpoints
"""
import sys
import os
from pathlib import Path
from huggingface_hub import HfApi, login
# Hugging Face access token taken from the environment; empty string when
# HF_TOKEN is not set.
TOKEN = os.getenv("HF_TOKEN", "")

# Hub repository that receives the uploads.
REPO = "td-builder/td-qwen3vl-v1"

# Local directory scanned for checkpoint sub-folders.
CKPT_DIR = Path("td_fuse_checkpoints")
def upload_checkpoint(api, name):
    """Upload one checkpoint folder to ``REPO`` under a same-named sub-path.

    Args:
        api: Client object exposing ``upload_folder``.
        name: Checkpoint folder name inside ``CKPT_DIR``; also used as the
            destination path inside the repository.

    Returns:
        ``True`` once the upload call has returned, ``False`` (after printing
        an error) when the folder or its ``model.safetensors`` is missing.
    """
    folder = CKPT_DIR / name
    if not folder.exists():
        print(f" ERROR: {folder} doesn't exist")
        return False

    if not (folder / "model.safetensors").exists():
        print(f" ERROR: No model.safetensors in {folder}")
        return False

    # Add up every regular file below the folder for the progress message.
    total_bytes = 0
    for entry in folder.rglob("*"):
        if entry.is_file():
            total_bytes += entry.stat().st_size
    size_gb = total_bytes / 1e9

    print(f" Uploading {name} ({size_gb:.1f} GB) to {REPO}/{name}/...")
    api.upload_folder(
        folder_path=str(folder),
        path_in_repo=name,
        repo_id=REPO,
        commit_message=f"Checkpoint: {name}",
    )
    print(f" Done: {name}")
    return True
def main():
    """Upload the selected checkpoint(s), then the permutation cache.

    The first command-line argument selects what to upload: ``all`` for every
    checkpoint, a folder name for that single checkpoint, or nothing for the
    most recently modified one.

    Exits with status 1 when the checkpoint directory is missing, when it
    holds no folder containing ``model.safetensors``, when the requested name
    is unknown, or when any checkpoint upload reports failure.
    """
    target = sys.argv[1] if len(sys.argv) > 1 else None

    # Validate local state first so a bad invocation fails fast, before any
    # network authentication is attempted.
    if not CKPT_DIR.is_dir():
        print(f"No checkpoint directory found at {CKPT_DIR}")
        sys.exit(1)

    # A usable checkpoint is a sub-folder that contains model.safetensors.
    checkpoints = sorted(
        d.name
        for d in CKPT_DIR.iterdir()
        if d.is_dir() and (d / "model.safetensors").exists()
    )
    if not checkpoints:
        print("No checkpoints found (need model.safetensors in each folder)")
        sys.exit(1)
    print(f"Available checkpoints: {', '.join(checkpoints)}")

    if target == "all":
        selected = checkpoints
    elif target:
        if target not in checkpoints:
            print(f"Checkpoint '{target}' not found. Available: {', '.join(checkpoints)}")
            sys.exit(1)
        selected = [target]
    else:
        # No argument: pick the folder with the newest modification time.
        latest = max(checkpoints, key=lambda n: (CKPT_DIR / n).stat().st_mtime)
        print(f"Uploading latest: {latest}")
        selected = [latest]

    # An empty token cannot authenticate, so only call login() when HF_TOKEN
    # is set; otherwise HfApi falls back to locally saved credentials.
    if TOKEN:
        login(token=TOKEN)
    else:
        print("HF_TOKEN is not set; relying on locally saved Hugging Face credentials")
    api = HfApi()

    # Remember which uploads reported failure instead of discarding the result.
    failed = [name for name in selected if not upload_checkpoint(api, name)]

    # Also upload perm_cache if it exists (tiny files, saves 12 min per re-run)
    perm_cache = CKPT_DIR / "perm_cache"
    if perm_cache.exists() and any(perm_cache.glob("*.npz")):
        # Best effort: a cache upload problem is reported but must not abort
        # the run, hence the deliberately broad except.
        try:
            size_kb = sum(f.stat().st_size for f in perm_cache.rglob("*") if f.is_file()) / 1024
            print(f" Uploading perm_cache ({size_kb:.0f} KB) to {REPO}/perm_cache/...")
            api.upload_folder(
                folder_path=str(perm_cache),
                path_in_repo="perm_cache",
                repo_id=REPO,
                commit_message="Permutation cache (saves 12 min Sinkhorn)",
            )
            print(" Done: perm_cache")
        except Exception as e:
            print(f" WARNING: perm_cache upload failed ({e})")

    if failed:
        # Do not claim success (or exit 0) when a checkpoint was not uploaded.
        print(f"\nFailed to upload: {', '.join(failed)}")
        sys.exit(1)
    print("\nAll done! Checkpoints saved to HuggingFace.")
# Run the uploader only when executed as a script, not when imported.
if __name__ == "__main__":
    main()