|
|
|
|
|
""" |
|
|
Checkpoint sync for Pi0.5 training on Vast.ai spot instances. |
|
|
|
|
|
Run this in a second tmux window BEFORE starting training! |
|
|
Syncs OpenPi checkpoints to HuggingFace Hub to prevent data loss. |
|
|
|
|
|
Usage: |
|
|
tmux new -s sync |
|
|
python sync_checkpoints.py |
|
|
# Ctrl+B, D to detach |
|
|
""" |
|
|
|
|
|
import time |
|
|
import shutil |
|
|
from pathlib import Path |
|
|
from huggingface_hub import HfApi |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Local directory that get_checkpoints() scans for numerically named
# checkpoint sub-directories.
CHECKPOINT_DIR = Path("/root/openpi/checkpoints/pi05_so101/ball_in_cup")

# Hub repo id passed to api.create_repo() and api.upload_folder().
HF_REPO = "abdul004/pi05_so101_checkpoint"

# Ledger file: main() appends one checkpoint name per line after each
# successful upload and reads it back to skip already-synced checkpoints.
PUSHED = Path("/tmp/pushed_pi05.txt")

# Single Hub client shared by the repo-creation and upload helpers below.
api = HfApi()
|
|
|
|
|
def create_repo_if_not_exists():
    """Make sure the Hub model repo named by ``HF_REPO`` exists.

    Calls ``api.create_repo`` with ``exist_ok=True`` and reports the outcome
    on stdout. Any failure is printed as a warning rather than raised, so the
    caller can carry on and let the upload step surface persistent problems.
    """
    try:
        api.create_repo(HF_REPO, repo_type="model", exist_ok=True)
    except Exception as e:
        # Deliberately best-effort: a failure here must not stop the sync loop
        # from starting.
        print(f"⚠️  Repo creation: {e}")
    else:
        print(f"✅ Repo {HF_REPO} ready")
|
|
|
|
|
def upload_with_retry(ckpt_path: Path, max_retries=5, base_delay=10):
    """Upload one checkpoint folder to the Hub, retrying with exponential backoff.

    Args:
        ckpt_path: Local checkpoint directory. Its contents are uploaded to
            ``checkpoints/<directory name>`` inside ``HF_REPO``.
        max_retries: Total number of upload attempts. A value <= 0 makes no
            attempt at all and returns False.
        base_delay: Seconds to wait after the first failed attempt; the wait
            doubles after every further failure (10, 20, 40, ... by default).

    Returns:
        True as soon as an attempt succeeds, False once every attempt failed.
    """
    for attempt in range(max_retries):
        try:
            api.upload_folder(
                folder_path=str(ckpt_path),
                repo_id=HF_REPO,
                path_in_repo=f"checkpoints/{ckpt_path.name}",
                repo_type="model",
            )
        except Exception as e:
            # Broad on purpose: this is the retry boundary, and any failure
            # (network, HTTP, filesystem) should lead to another attempt.
            print(f"⚠️  Attempt {attempt+1}/{max_retries} failed for {ckpt_path.name}: {e}")
            if attempt < max_retries - 1:
                # No sleep after the final attempt; the caller retries later.
                wait = 2 ** attempt * base_delay
                print(f"   Retrying in {wait}s...")
                time.sleep(wait)
        else:
            return True
    return False
|
|
|
|
|
def get_checkpoints(checkpoint_dir=None):
    """List checkpoint directories, sorted by ascending step number.

    A checkpoint is a real (non-symlink) sub-directory whose name parses as
    an integer with ``int()``; every other entry is ignored.

    Args:
        checkpoint_dir: Directory to scan. Defaults to the module-level
            ``CHECKPOINT_DIR`` when omitted or None.

    Returns:
        A list of ``(step, path)`` tuples ordered by ``step``. The list is
        empty when the directory does not exist.
    """
    root = CHECKPOINT_DIR if checkpoint_dir is None else Path(checkpoint_dir)

    try:
        # Materialise inside the try: iterdir() may only fail once iterated,
        # and the directory can disappear between an exists() check and use.
        entries = list(root.iterdir())
    except FileNotFoundError:
        return []

    found = []
    for entry in entries:
        if entry.is_symlink() or not entry.is_dir():
            continue
        try:
            step = int(entry.name)
        except ValueError:
            # Non-numeric names are not checkpoints; skip them.
            continue
        found.append((step, entry))

    found.sort(key=lambda item: item[0])
    return found
|
|
|
|
|
def _load_pushed_names():
    """Return the set of checkpoint names recorded in the ``PUSHED`` ledger."""
    try:
        text = PUSHED.read_text()
    except FileNotFoundError:
        return set()
    # Drop blank lines so an empty ledger yields an empty set, not {""}.
    return {line.strip() for line in text.splitlines() if line.strip()}


def _prune_local(checkpoints, pushed):
    """Delete already-uploaded checkpoints locally, keeping the newest two."""
    if len(checkpoints) <= 2:
        return
    for _step, ckpt_path in checkpoints[:-2]:
        if ckpt_path.name not in pushed or not ckpt_path.exists():
            continue
        try:
            shutil.rmtree(ckpt_path)
        except OSError as e:
            # A failed cleanup (e.g. the folder is removed concurrently) must
            # not kill the sync loop; it is retried on the next pass.
            print(f"  ⚠️  Could not remove local {ckpt_path.name}: {e}")
            continue
        print(f"  🗑️  Cleaned up local {ckpt_path.name} (already on HF)")


def _sync_once():
    """Run one pass: upload every new complete checkpoint, then prune old ones."""
    pushed = _load_pushed_names()
    checkpoints = get_checkpoints()

    if checkpoints:
        print(f"\n[{time.strftime('%H:%M:%S')}] Found {len(checkpoints)} checkpoints")

    for _step, ckpt_path in checkpoints:
        name = ckpt_path.name
        if name in pushed:
            continue

        # A checkpoint without params/ is treated as still being written.
        if not (ckpt_path / "params").exists():
            print(f"  ⏳ {name}: incomplete (waiting for params/)")
            continue

        print(f"  📤 Uploading checkpoint {name}...")

        if not upload_with_retry(ckpt_path):
            print(f"  ❌ Failed to sync {name} - will retry next loop")
            continue

        # Record on disk first so a restart does not re-upload, then mirror
        # it in memory so the prune step below sees it without a re-read.
        with open(PUSHED, 'a') as f:
            f.write(f"{name}\n")
        pushed.add(name)
        print(f"  ✅ Synced checkpoint {name} to HF Hub")

    _prune_local(checkpoints, pushed)


def main(poll_interval=30):
    """Watch ``CHECKPOINT_DIR`` forever and mirror new checkpoints to the Hub.

    Prints a startup banner, makes sure the Hub repo exists, then repeats a
    sync pass followed by a sleep. Each pass uploads checkpoints that are not
    yet in the ``PUSHED`` ledger and that contain a ``params`` entry, then
    deletes uploaded checkpoints locally except for the two newest.

    Args:
        poll_interval: Seconds to sleep between passes (default 30).

    This function never returns normally; stop it with Ctrl+C.
    """
    print("=" * 60)
    print("Pi0.5 Checkpoint Sync for Spot Instances")
    print("=" * 60)
    print(f"\n📁 Watching: {CHECKPOINT_DIR}")
    print(f"☁️  Pushing to: {HF_REPO}")
    print(f"\n⏰ Checking every {poll_interval} seconds...")
    print("   Expected checkpoints at: 1000, 2000, 3000, 4000, 5000")
    print("\n🚀 Start training in another tmux window:")
    print("   XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi05_so101 --exp-name=ball_in_cup")
    print("\n" + "=" * 60)

    create_repo_if_not_exists()

    while True:
        _sync_once()
        time.sleep(poll_interval)
|
|
|
|
|
# Start the sync loop only when run as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|