#!/usr/bin/env python3
"""
Checkpoint sync for Pi0.5 training on Vast.ai spot instances.

Run this in a second tmux window BEFORE starting training!
Syncs OpenPi checkpoints to HuggingFace Hub to prevent data loss.

Usage:
    tmux new -s sync
    python sync_checkpoints.py
    # Ctrl+B, D to detach
"""

import time
import shutil
from pathlib import Path

from huggingface_hub import HfApi

# ============ CONFIGURATION ============
# Update these paths to match your training setup

# OpenPi checkpoint directory (matches train.py output)
CHECKPOINT_DIR = Path("/root/openpi/checkpoints/pi05_so101/ball_in_cup")

# HuggingFace repo for checkpoints
HF_REPO = "abdul004/pi05_so101_checkpoint"

# Track what we've already pushed
PUSHED = Path("/tmp/pushed_pi05.txt")
# =======================================

api = HfApi()


def create_repo_if_not_exists():
    """Create the HF repo if it doesn't exist (idempotent via exist_ok)."""
    try:
        api.create_repo(HF_REPO, repo_type="model", exist_ok=True)
        print(f"✅ Repo {HF_REPO} ready")
    except Exception as e:
        # Non-fatal: if the repo truly can't be created, the first upload
        # attempt will surface the real error with retries.
        print(f"⚠️ Repo creation: {e}")


def upload_with_retry(ckpt_path: Path, max_retries: int = 5) -> bool:
    """Upload one checkpoint directory with exponential-backoff retries.

    OpenPi saves checkpoints in numbered directories, each containing
    params/, train_state/, assets/.

    Args:
        ckpt_path: Local checkpoint directory to upload.
        max_retries: Total attempts before giving up.

    Returns:
        True on success, False if every attempt failed.
    """
    for attempt in range(max_retries):
        try:
            api.upload_folder(
                folder_path=str(ckpt_path),
                repo_id=HF_REPO,
                path_in_repo=f"checkpoints/{ckpt_path.name}",
                repo_type="model",
            )
            return True
        except Exception as e:
            wait = 2 ** attempt * 10  # 10s, 20s, 40s, 80s, 160s
            print(f"⚠️ Attempt {attempt+1}/{max_retries} failed for {ckpt_path.name}: {e}")
            if attempt < max_retries - 1:
                print(f"   Retrying in {wait}s...")
                time.sleep(wait)
    return False


def get_checkpoints() -> list:
    """Return (step, path) pairs for checkpoint dirs, sorted by step.

    OpenPi uses numeric directory names (1000, 2000, etc.); symlinks
    (e.g. a "latest" alias) and non-numeric directories are skipped.
    """
    if not CHECKPOINT_DIR.exists():
        return []
    ckpts = []
    for d in CHECKPOINT_DIR.iterdir():
        if d.is_dir() and not d.is_symlink():
            try:
                ckpts.append((int(d.name), d))
            except ValueError:
                continue
    return sorted(ckpts, key=lambda x: x[0])


def _load_pushed() -> set:
    """Read the set of checkpoint names already synced to the Hub.

    An absent or empty ledger file yields an empty set.  (Naively
    splitting an empty string on newlines would wrongly produce a set
    containing the empty string.)
    """
    if not PUSHED.exists():
        return set()
    return {line for line in PUSHED.read_text().splitlines() if line.strip()}


def _mark_pushed(name: str) -> None:
    """Append a checkpoint name to the pushed-ledger file."""
    with open(PUSHED, 'a') as f:
        f.write(f"{name}\n")


def _sync_once() -> None:
    """One sync pass: upload new complete checkpoints, then prune old ones."""
    pushed = _load_pushed()
    checkpoints = get_checkpoints()
    if not checkpoints:
        return

    print(f"\n[{time.strftime('%H:%M:%S')}] Found {len(checkpoints)} checkpoints")
    for step, ckpt_path in checkpoints:
        if ckpt_path.name in pushed:
            continue
        # Check if checkpoint is complete (has params directory)
        params_dir = ckpt_path / "params"
        if not params_dir.exists():
            print(f"   ⏳ {ckpt_path.name}: incomplete (waiting for params/)")
            continue
        print(f"   📤 Uploading checkpoint {ckpt_path.name}...")
        if upload_with_retry(ckpt_path):
            _mark_pushed(ckpt_path.name)
            pushed.add(ckpt_path.name)  # keep in-memory view current for cleanup below
            print(f"   ✅ Synced checkpoint {ckpt_path.name} to HF Hub")
        else:
            print(f"   ❌ Failed to sync {ckpt_path.name} - will retry next loop")

    # Cleanup old checkpoints (keep last 2 locally to save disk).
    # Only directories confirmed on the Hub are ever deleted.
    if len(checkpoints) > 2:
        for step, ckpt_path in checkpoints[:-2]:
            if ckpt_path.name in pushed and ckpt_path.exists():
                shutil.rmtree(ckpt_path)
                print(f"   🗑️ Cleaned up local {ckpt_path.name} (already on HF)")


def main():
    """Poll the checkpoint directory forever, syncing new checkpoints to HF.

    Each pass is wrapped in a broad exception guard: this script exists to
    survive flaky spot instances, so a transient filesystem or network
    error must not kill the daemon — it logs and retries next poll.
    """
    print("=" * 60)
    print("Pi0.5 Checkpoint Sync for Spot Instances")
    print("=" * 60)
    print(f"\n📁 Watching: {CHECKPOINT_DIR}")
    print(f"☁️ Pushing to: {HF_REPO}")
    print(f"\n⏰ Checking every 30 seconds...")
    print("   Expected checkpoints at: 1000, 2000, 3000, 4000, 5000")
    print("\n🚀 Start training in another tmux window:")
    print("   XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi05_so101 --exp-name=ball_in_cup")
    print("\n" + "=" * 60)

    create_repo_if_not_exists()

    while True:
        try:
            _sync_once()
        except Exception as e:
            # Deliberately broad: a sync daemon that crashes on a transient
            # error (e.g. a dir removed mid-iteration) defeats its purpose.
            print(f"⚠️ Sync pass failed: {e}")
        time.sleep(30)


if __name__ == "__main__":
    main()