# pi0_so101_config / sync_checkpoints.py
# abdul004's picture
# Upload sync_checkpoints.py with huggingface_hub
# a597fbe verified
#!/usr/bin/env python3
"""
Checkpoint sync for Pi0.5 training on Vast.ai spot instances.
Run this in a second tmux window BEFORE starting training!
Syncs OpenPi checkpoints to HuggingFace Hub to prevent data loss.
Usage:
    tmux new -s sync
    python sync_checkpoints.py
    # Ctrl+B, D to detach
"""
import time
import shutil
from pathlib import Path
from huggingface_hub import HfApi
# ============ CONFIGURATION ============
# Update these paths to match your training setup.

# Local OpenPi checkpoint directory to watch (matches train.py output).
# Its numerically named sub-directories are treated as checkpoints.
CHECKPOINT_DIR = Path("/root/openpi/checkpoints/pi05_so101/ball_in_cup")

# Hugging Face Hub model repo that receives the checkpoints.
HF_REPO = "abdul004/pi05_so101_checkpoint"

# Plain-text ledger of checkpoint names already uploaded, one name per line.
PUSHED = Path("/tmp/pushed_pi05.txt")
# =======================================

# Shared Hub client used by the repo-creation and upload calls in this script.
api = HfApi()
def create_repo_if_not_exists():
    """Make sure the ``HF_REPO`` model repository exists on the Hub.

    The repo is requested with ``exist_ok=True`` so an already existing
    repository is not treated as an error.  This is a best-effort step: any
    failure is reported on stdout and swallowed so the caller can carry on
    (a later upload will surface a persistent problem).
    """
    try:
        # Keep the try body to the one call that can actually fail.
        api.create_repo(HF_REPO, repo_type="model", exist_ok=True)
    except Exception as e:
        print(f"⚠️ Repo creation: {e}")
    else:
        print(f"✅ Repo {HF_REPO} ready")
def upload_with_retry(ckpt_path: Path, max_retries=5, base_delay=10):
    """Upload one checkpoint folder to the Hub, retrying with backoff.

    Args:
        ckpt_path: Local checkpoint directory.  It is uploaded to
            ``checkpoints/<ckpt_path.name>`` inside ``HF_REPO``.
        max_retries: Maximum number of upload attempts.
        base_delay: Seconds to wait after the first failed attempt; the wait
            doubles after every further failure (10s, 20s, 40s, 80s with the
            defaults).  No wait follows the final attempt.

    Returns:
        True as soon as one attempt succeeds.  False when every attempt
        failed, when ``max_retries`` is 0, or when ``ckpt_path`` is not (or is
        no longer) a directory.
    """
    for attempt in range(max_retries):
        # The folder can disappear between discovery and upload (or between
        # two attempts).  Retrying with backoff cannot fix that, so bail out
        # right away instead of blocking the sync loop for minutes.
        if not ckpt_path.is_dir():
            print(f"⚠️ {ckpt_path.name}: directory no longer exists, skipping upload")
            return False
        try:
            # OpenPi saves checkpoints in numbered directories
            # Each contains: params/, train_state/, assets/
            api.upload_folder(
                folder_path=str(ckpt_path),
                repo_id=HF_REPO,
                path_in_repo=f"checkpoints/{ckpt_path.name}",
                repo_type="model",
            )
        except Exception as e:
            # Deliberately broad: any upload failure counts as retryable.
            print(f"⚠️ Attempt {attempt+1}/{max_retries} failed for {ckpt_path.name}: {e}")
            if attempt < max_retries - 1:
                wait = base_delay * 2 ** attempt
                print(f" Retrying in {wait}s...")
                time.sleep(wait)
        else:
            return True
    return False
def get_checkpoints(checkpoint_dir=None):
    """List checkpoint directories, sorted by ascending step number.

    Args:
        checkpoint_dir: Directory to scan.  Defaults to the module-level
            ``CHECKPOINT_DIR`` when omitted or None.

    Returns:
        A list of ``(step, path)`` tuples, one per real (non-symlink)
        sub-directory whose name parses as an integer.  The list is empty
        when the directory is missing or is not a directory.
    """
    root = CHECKPOINT_DIR if checkpoint_dir is None else Path(checkpoint_dir)
    # is_dir() rather than exists(): a stray regular file at this path would
    # otherwise make iterdir() raise.
    if not root.is_dir():
        return []
    ckpts = []
    for entry in root.iterdir():
        if not entry.is_dir() or entry.is_symlink():
            continue
        # OpenPi uses numeric directory names (1000, 2000, etc.); anything
        # that does not parse as an int is skipped.
        try:
            step = int(entry.name)
        except ValueError:
            continue
        ckpts.append((step, entry))
    # Sort on the step only (numeric order, so 10000 comes after 2000).
    ckpts.sort(key=lambda item: item[0])
    return ckpts
def _load_pushed():
    """Return the set of checkpoint names recorded in the PUSHED ledger."""
    try:
        text = PUSHED.read_text()
    except FileNotFoundError:
        return set()
    # Drop blank lines so an empty ledger yields an empty set, not {''}.
    return {line.strip() for line in text.splitlines() if line.strip()}


def _upload_pending(checkpoints, pushed):
    """Upload every complete checkpoint not yet in ``pushed``; update ledger and set."""
    for _step, ckpt_path in checkpoints:
        name = ckpt_path.name
        if name in pushed:
            continue
        # Check if checkpoint is complete (has params directory)
        if not (ckpt_path / "params").exists():
            print(f" ⏳ {name}: incomplete (waiting for params/)")
            continue
        print(f" 📤 Uploading checkpoint {name}...")
        if upload_with_retry(ckpt_path):
            # Persist first, then mirror in memory so the cleanup step of this
            # same pass already sees the new entry.
            with open(PUSHED, 'a') as f:
                f.write(f"{name}\n")
            pushed.add(name)
            print(f" ✅ Synced checkpoint {name} to HF Hub")
        else:
            print(f" ❌ Failed to sync {name} - will retry next loop")


def _prune_local(checkpoints, pushed, keep_local):
    """Delete already-uploaded local checkpoints, keeping the newest ``keep_local``."""
    stale_count = max(len(checkpoints) - keep_local, 0)
    for _step, ckpt_path in checkpoints[:stale_count]:
        if ckpt_path.name not in pushed:
            continue  # never delete something that is not on the Hub yet
        try:
            shutil.rmtree(ckpt_path)
        except FileNotFoundError:
            continue  # already gone; nothing to clean up
        except OSError as e:
            # A failed delete must not take the whole sync loop down.
            print(f" ⚠️ Could not remove local {ckpt_path.name}: {e}")
            continue
        print(f" 🗑️ Cleaned up local {ckpt_path.name} (already on HF)")


def _sync_once(keep_local):
    """Run one pass: discover checkpoints, upload new ones, prune old ones."""
    pushed = _load_pushed()
    checkpoints = get_checkpoints()
    if not checkpoints:
        return
    print(f"\n[{time.strftime('%H:%M:%S')}] Found {len(checkpoints)} checkpoints")
    _upload_pending(checkpoints, pushed)
    # Cleanup old checkpoints (keep the newest ones locally to save disk)
    _prune_local(checkpoints, pushed, keep_local)


def main(poll_interval=30, keep_local=2):
    """Watch CHECKPOINT_DIR forever and mirror new checkpoints to HF_REPO.

    Prints a banner, makes sure the Hub repo exists, then loops: every
    ``poll_interval`` seconds it uploads checkpoints that are not yet in the
    PUSHED ledger and removes uploaded local checkpoints other than the newest
    ``keep_local``.  The loop only ends when the process is interrupted.

    Args:
        poll_interval: Seconds to sleep between passes.
        keep_local: Number of newest checkpoints never deleted locally
            (expected to be >= 0).
    """
    print("=" * 60)
    print("Pi0.5 Checkpoint Sync for Spot Instances")
    print("=" * 60)
    print(f"\n📁 Watching: {CHECKPOINT_DIR}")
    print(f"☁️ Pushing to: {HF_REPO}")
    print(f"\n⏰ Checking every {poll_interval} seconds...")
    print(" Expected checkpoints at: 1000, 2000, 3000, 4000, 5000")
    print("\n🚀 Start training in another tmux window:")
    print(" XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi05_so101 --exp-name=ball_in_cup")
    print("\n" + "=" * 60)
    create_repo_if_not_exists()
    while True:
        try:
            _sync_once(keep_local)
        except Exception as e:
            # Top-level boundary of a long-running daemon: report and keep
            # going, because a dead sync loop means later checkpoints are
            # never uploaded.
            print(f" ⚠️ Sync pass failed: {e} - will retry next loop")
        time.sleep(poll_interval)


if __name__ == "__main__":
    main()