File size: 4,807 Bytes
a597fbe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
#!/usr/bin/env python3
"""
Checkpoint sync for Pi0.5 training on Vast.ai spot instances.
Run this in a second tmux window BEFORE starting training!
Syncs OpenPi checkpoints to HuggingFace Hub to prevent data loss.
Usage:
tmux new -s sync
python sync_checkpoints.py
# Ctrl+B, D to detach
"""
import time
import shutil
from pathlib import Path
from huggingface_hub import HfApi
# ============ CONFIGURATION ============
# Update these values to match your training setup before launching.

# Local directory scanned for checkpoints. Sub-directories whose names parse
# as integers (e.g. 1000, 2000) are treated as individual checkpoints.
CHECKPOINT_DIR = Path("/root/openpi/checkpoints/pi05_so101/ball_in_cup")

# HuggingFace Hub model repo that receives the uploaded checkpoints.
HF_REPO = "abdul004/pi05_so101_checkpoint"

# Plain-text ledger of checkpoint names already uploaded, one name per line.
# NOTE(review): this lives under /tmp, so it presumably does not survive a
# replaced spot instance — confirm that re-uploading after a restart is OK.
PUSHED = Path("/tmp/pushed_pi05.txt")
# =======================================

# Single Hub client shared by the repo-creation and upload helpers.
api = HfApi()
def create_repo_if_not_exists():
    """Ensure the target model repo exists on the Hub.

    Calls ``api.create_repo`` with ``exist_ok=True`` so an already existing
    repo is not treated as an error. Any failure is reported on stdout and
    swallowed (best effort); nothing is returned.
    """
    try:
        api.create_repo(HF_REPO, repo_type="model", exist_ok=True)
    except Exception as err:
        # Best effort: report and carry on, later uploads will surface
        # any persistent problem.
        print(f"⚠️ Repo creation: {err}")
    else:
        print(f"✅ Repo {HF_REPO} ready")
def upload_with_retry(ckpt_path: Path, max_retries=5):
    """Upload one checkpoint directory to the Hub, retrying on failure.

    The folder is uploaded to ``checkpoints/<directory name>`` inside
    ``HF_REPO``. After a failed attempt the function sleeps with an
    exponentially growing delay (10s, 20s, 40s, 80s, ...) before trying
    again; no sleep follows the final attempt.

    Args:
        ckpt_path: Local checkpoint directory to upload.
        max_retries: Maximum number of upload attempts.

    Returns:
        True as soon as an attempt succeeds, False if every attempt failed.
    """
    final_attempt = max_retries - 1
    for attempt in range(max_retries):
        try:
            # The whole checkpoint directory is pushed as a single folder.
            api.upload_folder(
                folder_path=str(ckpt_path),
                repo_id=HF_REPO,
                path_in_repo=f"checkpoints/{ckpt_path.name}",
                repo_type="model",
            )
        except Exception as e:
            print(f"⚠️ Attempt {attempt+1}/{max_retries} failed for {ckpt_path.name}: {e}")
            if attempt == final_attempt:
                # Out of attempts: give up without a pointless final sleep.
                break
            wait = 2 ** attempt * 10  # 10s, 20s, 40s, 80s, 160s
            print(f" Retrying in {wait}s...")
            time.sleep(wait)
        else:
            return True
    return False
def get_checkpoints():
    """Collect checkpoint directories under ``CHECKPOINT_DIR``.

    Only real (non-symlink) directories whose name parses as an integer are
    considered checkpoints; every other entry is ignored.

    Returns:
        A list of ``(step, path)`` tuples sorted by ascending step, or an
        empty list when ``CHECKPOINT_DIR`` does not exist.
    """
    if not CHECKPOINT_DIR.exists():
        return []

    found = []
    for entry in CHECKPOINT_DIR.iterdir():
        if not entry.is_dir() or entry.is_symlink():
            continue
        # Checkpoint directories are named after their step (1000, 2000, ...).
        try:
            step = int(entry.name)
        except ValueError:
            continue
        found.append((step, entry))

    found.sort(key=lambda pair: pair[0])
    return found
def _load_pushed():
    """Return the set of checkpoint names recorded in the PUSHED ledger."""
    try:
        text = PUSHED.read_text()
    except FileNotFoundError:
        # Nothing has been uploaded yet (or the ledger was wiped).
        return set()
    # Skip blank lines so an empty ledger does not yield a bogus '' entry.
    return {line.strip() for line in text.splitlines() if line.strip()}


def _prune_local(checkpoints, pushed):
    """Delete all but the two newest checkpoints that are already uploaded."""
    for _step, ckpt_path in checkpoints[:-2]:
        if ckpt_path.name not in pushed or not ckpt_path.exists():
            continue
        try:
            shutil.rmtree(ckpt_path)
        except OSError as e:
            # A failed cleanup must never take down the sync loop: losing
            # the backup daemon is far worse than a little wasted disk.
            print(f" ⚠️ Could not remove local {ckpt_path.name}: {e}")
            continue
        print(f" 🗑️ Cleaned up local {ckpt_path.name} (already on HF)")


def main():
    """Run the checkpoint sync loop forever.

    Prints a startup banner, makes sure the Hub repo exists, then every 30
    seconds:

    1. loads the names of checkpoints already uploaded from ``PUSHED``;
    2. uploads every checkpoint that is not yet uploaded and already
       contains a ``params`` directory, appending its name to ``PUSHED``
       on success;
    3. removes local copies of uploaded checkpoints, always keeping the two
       newest checkpoints on disk.

    A failed upload is retried on the next iteration. A failed local delete
    is reported and skipped instead of aborting the loop. The function never
    returns.
    """
    print("=" * 60)
    print("Pi0.5 Checkpoint Sync for Spot Instances")
    print("=" * 60)
    print(f"\n📁 Watching: {CHECKPOINT_DIR}")
    print(f"☁️ Pushing to: {HF_REPO}")
    print("\n⏰ Checking every 30 seconds...")
    print(" Expected checkpoints at: 1000, 2000, 3000, 4000, 5000")
    print("\n🚀 Start training in another tmux window:")
    print(" XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi05_so101 --exp-name=ball_in_cup")
    print("\n" + "=" * 60)

    create_repo_if_not_exists()

    while True:
        # Read the ledger once per pass; it is kept up to date in memory
        # below so the cleanup step sees uploads made during this pass.
        pushed = _load_pushed()
        checkpoints = get_checkpoints()

        if checkpoints:
            print(f"\n[{time.strftime('%H:%M:%S')}] Found {len(checkpoints)} checkpoints")

            for _step, ckpt_path in checkpoints:
                name = ckpt_path.name
                if name in pushed:
                    continue

                # A checkpoint without params/ is treated as still being
                # written; it is picked up on a later pass.
                if not (ckpt_path / "params").exists():
                    print(f" ⏳ {name}: incomplete (waiting for params/)")
                    continue

                print(f" 📤 Uploading checkpoint {name}...")
                if upload_with_retry(ckpt_path):
                    # Persist first, then mirror in memory.
                    with open(PUSHED, 'a') as f:
                        f.write(f"{name}\n")
                    pushed.add(name)
                    print(f" ✅ Synced checkpoint {name} to HF Hub")
                else:
                    print(f" ❌ Failed to sync {name} - will retry next loop")

            # Keep the last 2 checkpoints locally to save disk.
            _prune_local(checkpoints, pushed)

        time.sleep(30)
# Start the sync loop only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|