File size: 4,807 Bytes
a597fbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#!/usr/bin/env python3
"""
Checkpoint sync for Pi0.5 training on Vast.ai spot instances.

Run this in a second tmux window BEFORE starting training!
Syncs OpenPi checkpoints to HuggingFace Hub to prevent data loss.

Usage:
    tmux new -s sync
    python sync_checkpoints.py
    # Ctrl+B, D to detach
"""

import time
import shutil
from pathlib import Path
from huggingface_hub import HfApi

# ============ CONFIGURATION ============
# Update these paths to match your training setup

# OpenPi checkpoint directory (matches train.py output). The watcher treats each
# real (non-symlink) subdirectory here whose name is an integer as one
# checkpoint step.
CHECKPOINT_DIR: Path = Path("/root/openpi/checkpoints/pi05_so101/ball_in_cup")

# HuggingFace model repo that receives the checkpoints; each step directory is
# uploaded under checkpoints/<step dir name>/ inside it.
HF_REPO: str = "abdul004/pi05_so101_checkpoint"

# Ledger of checkpoint directory names already uploaded, one name per line.
# Only names listed here are ever deleted locally. It lives in /tmp, so it may
# be cleared on reboot — NOTE(review): confirm re-uploading in that case is fine.
PUSHED: Path = Path("/tmp/pushed_pi05.txt")

# =======================================

# Shared Hub client used for repo creation and every upload. It is built with
# no explicit token, so it presumably relies on a prior `huggingface-cli login`
# or an HF_TOKEN environment variable — verify on the instance.
api: HfApi = HfApi()

def create_repo_if_not_exists():
    """Make sure the Hub model repo ``HF_REPO`` exists, creating it if needed.

    This step is best-effort. ``exist_ok=True`` makes it a no-op for an
    existing repo, and any failure is only reported on stdout. That way the
    sync loop still starts even if this call fails.
    """
    try:
        api.create_repo(repo_id=HF_REPO, repo_type="model", exist_ok=True)
        print(f"✅ Repo {HF_REPO} ready")
    except Exception as exc:  # deliberately broad: failure here is advisory only
        print(f"⚠️ Repo creation: {exc}")

def upload_with_retry(ckpt_path: Path, max_retries=5, *, base_delay=10):
    """Upload one checkpoint directory to ``HF_REPO``, retrying with exponential backoff.

    The directory is uploaded to ``checkpoints/<ckpt_path.name>`` in the repo.

    Args:
        ckpt_path: Local checkpoint step directory to upload.
        max_retries: Total number of upload attempts (not extra retries on top
            of the first). With ``max_retries <= 0`` nothing is attempted.
        base_delay: Seconds to wait after the first failed attempt. The wait
            doubles after each further failure, so the defaults wait 10s, 20s,
            40s and 80s. No wait follows the final attempt.

    Returns:
        True as soon as an attempt succeeds; False once every attempt has
        failed. Failures are printed, never raised.
    """
    for attempt in range(max_retries):
        try:
            # OpenPi saves checkpoints in numbered directories
            # Each contains: params/, train_state/, assets/
            api.upload_folder(
                folder_path=str(ckpt_path),
                repo_id=HF_REPO,
                path_in_repo=f"checkpoints/{ckpt_path.name}",
                repo_type="model",
            )
            return True
        except Exception as e:  # network/Hub errors vary widely; retry on any of them
            print(f"⚠️ Attempt {attempt+1}/{max_retries} failed for {ckpt_path.name}: {e}")
            if attempt < max_retries - 1:
                wait = base_delay * 2 ** attempt
                print(f"   Retrying in {wait}s...")
                time.sleep(wait)
    return False

def get_checkpoints(checkpoint_dir=None):
    """Return the checkpoint step directories found on disk, sorted by step number.

    Args:
        checkpoint_dir: Directory to scan (``Path`` or ``str``). When omitted,
            ``CHECKPOINT_DIR`` is used; it is looked up at call time, so
            existing no-argument callers behave as before.

    Returns:
        A list of ``(step, path)`` tuples in ascending step order. Only real
        (non-symlink) directories whose name parses with ``int()`` are kept.
        Files, symlinks and non-numeric names are skipped. Returns ``[]`` if
        the directory does not exist, including when it vanishes while being
        scanned.
    """
    root = CHECKPOINT_DIR if checkpoint_dir is None else Path(checkpoint_dir)

    # EAFP instead of exists()+iterdir(): the directory may be removed or
    # recreated by the training process between the two calls.
    try:
        entries = list(root.iterdir())
    except FileNotFoundError:
        return []

    ckpts = []
    for d in entries:
        if not d.is_dir() or d.is_symlink():
            continue
        # OpenPi uses numeric directory names (1000, 2000, etc.)
        try:
            step = int(d.name)
        except ValueError:
            continue
        ckpts.append((step, d))

    ckpts.sort(key=lambda item: item[0])
    return ckpts

def _load_pushed():
    """Return the set of checkpoint names recorded in the ``PUSHED`` ledger.

    Blank lines are ignored. A missing ledger file means nothing has been
    pushed yet.
    """
    try:
        text = PUSHED.read_text()
    except FileNotFoundError:
        return set()
    return {line.strip() for line in text.splitlines() if line.strip()}


def _mark_pushed(name):
    """Append one checkpoint directory name to the ``PUSHED`` ledger."""
    with open(PUSHED, "a") as f:
        f.write(f"{name}\n")


def _upload_pending(checkpoints, pushed):
    """Upload each complete checkpoint not yet in ``pushed``.

    On success the name is appended to the ledger first and then added to
    ``pushed`` (mutated in place). The set therefore only ever lists
    checkpoints recorded on disk.
    """
    for _step, ckpt_path in checkpoints:
        if ckpt_path.name in pushed:
            continue

        # Check if checkpoint is complete (has params directory).
        # NOTE(review): params/ existing does not by itself prove every other
        # subdirectory is fully written — confirm against the writer's behavior.
        if not (ckpt_path / "params").exists():
            print(f"   ⏳ {ckpt_path.name}: incomplete (waiting for params/)")
            continue

        print(f"   📤 Uploading checkpoint {ckpt_path.name}...")
        if upload_with_retry(ckpt_path):
            _mark_pushed(ckpt_path.name)
            pushed.add(ckpt_path.name)
            print(f"   ✅ Synced checkpoint {ckpt_path.name} to HF Hub")
        else:
            print(f"   ❌ Failed to sync {ckpt_path.name} - will retry next loop")


def _cleanup_local(checkpoints, pushed):
    """Delete local checkpoints already on the Hub, always keeping the newest two."""
    if len(checkpoints) <= 2:
        return
    for _step, ckpt_path in checkpoints[:-2]:
        if ckpt_path.name not in pushed or not ckpt_path.exists():
            continue
        # Another process (e.g. the trainer pruning old steps) may be deleting
        # the same directory; a failed removal should not stop the watcher.
        try:
            shutil.rmtree(ckpt_path)
        except OSError as e:
            print(f"   ⚠️ Could not remove local {ckpt_path.name}: {e}")
            continue
        print(f"   🗑️ Cleaned up local {ckpt_path.name} (already on HF)")


def _sync_once():
    """Run a single pass: upload pending checkpoints, then prune already-pushed ones."""
    pushed = _load_pushed()
    checkpoints = get_checkpoints()

    if checkpoints:
        print(f"\n[{time.strftime('%H:%M:%S')}] Found {len(checkpoints)} checkpoints")

    _upload_pending(checkpoints, pushed)
    # Cleanup old checkpoints (keep last 2 locally to save disk)
    _cleanup_local(checkpoints, pushed)


def main():
    """Watch ``CHECKPOINT_DIR`` forever and mirror new checkpoints to ``HF_REPO``.

    Each pass, repeated every 30 seconds:
      * uploads every complete, not-yet-pushed checkpoint (see ``_upload_pending``);
      * deletes local copies of pushed checkpoints, keeping the newest two.

    A filesystem error during a pass is reported and the pass is retried on
    the next cycle, so the watcher keeps running. An example is a checkpoint
    directory disappearing mid-scan. Never returns; stop it with Ctrl+C.
    """
    poll_seconds = 30

    print("=" * 60)
    print("Pi0.5 Checkpoint Sync for Spot Instances")
    print("=" * 60)
    print(f"\n📁 Watching: {CHECKPOINT_DIR}")
    print(f"☁️  Pushing to: {HF_REPO}")
    print(f"\n⏰ Checking every {poll_seconds} seconds...")
    print("   Expected checkpoints at: 1000, 2000, 3000, 4000, 5000")
    print("\n🚀 Start training in another tmux window:")
    print("   XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi05_so101 --exp-name=ball_in_cup")
    print("\n" + "=" * 60)

    create_repo_if_not_exists()

    while True:
        try:
            _sync_once()
        except OSError as e:
            print(f"\n[{time.strftime('%H:%M:%S')}] ⚠️ Sync pass failed: {e} - retrying next loop")
        time.sleep(poll_seconds)

# Script entry point: start the sync loop only when run directly, not on import.
if __name__ == "__main__":
    main()