# PhysioJEPA / scripts/runpod_launch.py
# guychuk's picture
# Upload folder using huggingface_hub
# 31e2456 verified
"""Launch N RunPod A40 pods, deploy the codebase, kick off training.
Usage:
python scripts/runpod_launch.py --models A B C F --gpu A40 \
--image runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04
For each model letter:
1. create pod
2. wait for SSH
3. copy a tarball of the repo, plus .env, to the pod via scp
4. run pod_bootstrap.sh on the pod in the background (nohup)
5. record pod id + run name in runs/launch_manifest.json
Polling/log retrieval is left to scripts/runpod_status.py.
"""
from __future__ import annotations
import argparse
import json
import os
import shutil
import subprocess
import sys
import tempfile
import time
from pathlib import Path
from dotenv import load_dotenv
load_dotenv()
# RunPod API key, read from the process environment after the load_dotenv()
# call above. The subscript lookup raises KeyError at import time if the
# variable is not set.
RUNPOD_API_KEY = os.environ["RUNPOD_API_KEY"]
# Short names accepted by --gpu, mapped to the GPU identifier strings that are
# handed to `runpodctl pod create --gpu-id`.
GPU_IDS = {
    "A40": "NVIDIA A40",
    "A6000": "NVIDIA RTX A6000",
    "A100": "NVIDIA A100-SXM4-80GB",
    "H100": "NVIDIA H100 80GB HBM3",
}

# Container image used when --image is not given on the command line.
DEFAULT_IMAGE = (
    "runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04"
)
def runpodctl(args: list[str], capture: bool = True) -> str:
    """Run the ``runpodctl`` CLI with *args* and return its stdout.

    Args:
        args: Arguments placed after the ``runpodctl`` executable name.
        capture: When true (the default) stdout and stderr are captured as
            text. When false the child inherits this process's streams and
            the function returns an empty string.

    Returns:
        The captured stdout, or ``""`` when ``capture`` is false.

    Raises:
        RuntimeError: If the command exits with a non-zero status. The message
            carries whatever stderr/stdout was captured.
    """
    # Hand the key to the child explicitly, on top of the inherited environment.
    env = {**os.environ, "RUNPOD_API_KEY": RUNPOD_API_KEY}
    res = subprocess.run(
        ["runpodctl", *args], env=env, capture_output=capture, text=True
    )
    # With capture=False both streams are None; normalise them so the return
    # value honours the `-> str` contract and the error text never says "None".
    stdout = res.stdout or ""
    stderr = res.stderr or ""
    if res.returncode != 0:
        raise RuntimeError(f"runpodctl {' '.join(args)} failed: {stderr}\n{stdout}")
    return stdout
def create_pod(name: str, gpu_id: str, image: str, container_disk: int = 50,
               volume_gb: int = 100) -> dict:
    """Create a single-GPU pod through ``runpodctl pod create``.

    The pod is requested on the COMMUNITY cloud with its volume mounted at
    /workspace, port 22/tcp exposed and the ``--ssh`` flag set.

    Args:
        name: Pod name.
        gpu_id: GPU identifier passed as ``--gpu-id``.
        image: Container image passed as ``--image``.
        container_disk: Container disk size in GB.
        volume_gb: Volume size in GB.

    Returns:
        The command's stdout parsed as JSON.
    """
    # (flag, value) pairs in the order they appear on the command line.
    valued_flags = [
        ("--name", name),
        ("--gpu-id", gpu_id),
        ("--gpu-count", "1"),
        ("--image", image),
        ("--cloud-type", "COMMUNITY"),
        ("--container-disk-in-gb", str(container_disk)),
        ("--volume-in-gb", str(volume_gb)),
        ("--volume-mount-path", "/workspace"),
        ("--ports", "22/tcp"),
    ]
    cli_args = ["pod", "create"]
    for flag, value in valued_flags:
        cli_args += [flag, value]
    cli_args.append("--ssh")
    return json.loads(runpodctl(cli_args))
def wait_for_ssh(pod_id: str, timeout: int = 600) -> tuple[str, int]:
    """Poll ``runpodctl ssh info`` until the pod reports an SSH host and port.

    Args:
        pod_id: Identifier of the pod to query.
        timeout: Maximum number of seconds to keep polling.

    Returns:
        ``(host, port)`` taken from the ``publicIp``/``ip`` and
        ``port``/``sshPort`` fields of the command's JSON output.

    Raises:
        TimeoutError: If no host/port pair was reported within *timeout*
            seconds. The message includes the most recent failure reason.
    """
    poll_interval = 15  # seconds between attempts
    # Monotonic clock: the deadline is unaffected by wall-clock adjustments.
    deadline = time.monotonic() + timeout
    last_err = ""
    while time.monotonic() < deadline:
        try:
            info = json.loads(runpodctl(["ssh", "info", pod_id]))
            host = info.get("publicIp") or info.get("ip")
            port = info.get("port") or info.get("sshPort")
            if host and port:
                return host, int(port)
            # The command succeeded but the pod has no endpoint yet; record
            # that so a timeout message is never empty.
            last_err = "ssh info returned no host/port yet"
        except Exception as e:
            # Deliberately broad: any failure while the pod is still coming up
            # (CLI error, non-JSON output, odd payload) just means "retry".
            last_err = str(e)
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        # Never sleep past the deadline.
        time.sleep(min(poll_interval, remaining))
    raise TimeoutError(f"SSH not ready for {pod_id}: {last_err}")
def ssh(host: str, port: int, cmd: str, user: str = "root", timeout: int = 60) -> str:
    """Execute *cmd* on ``user@host`` over SSH and return its stdout.

    Host-key checking is switched off and no known_hosts file is kept; the
    connection attempt itself is limited to 15 seconds.

    Args:
        host: Remote host name or address.
        port: Remote SSH port.
        cmd: Command line executed by the remote shell.
        user: Remote login name.
        timeout: Seconds allowed for the whole ``ssh`` invocation.

    Returns:
        The remote command's stdout as text.

    Raises:
        RuntimeError: If ``ssh`` exits with a non-zero status.
        subprocess.TimeoutExpired: If the invocation exceeds *timeout*.
    """
    client_options = (
        "StrictHostKeyChecking=no",
        "UserKnownHostsFile=/dev/null",
        "ConnectTimeout=15",
    )
    argv = ["ssh"]
    for option in client_options:
        argv += ["-o", option]
    argv += ["-p", str(port), f"{user}@{host}", cmd]

    proc = subprocess.run(argv, capture_output=True, text=True, timeout=timeout)
    if proc.returncode == 0:
        return proc.stdout
    raise RuntimeError(f"ssh {host}:{port} {cmd!r} failed: {proc.stderr}")
def scp(host: str, port: int, local_path: Path, remote_path: str, user: str = "root") -> None:
    """Copy *local_path* to ``user@host:remote_path`` with ``scp``.

    Directories are copied recursively (``-r``). Host-key checking is switched
    off and no known_hosts file is kept. The transfer is limited to 900 seconds.

    Args:
        host: Remote host name or address.
        port: Remote SSH port.
        local_path: Local file or directory to send.
        remote_path: Destination path on the remote host.
        user: Remote login name.

    Raises:
        RuntimeError: If ``scp`` exits with a non-zero status.
        subprocess.TimeoutExpired: If the transfer exceeds 900 seconds.
    """
    recursive_flag = ["-r"] if local_path.is_dir() else []
    argv = [
        "scp",
        "-o", "StrictHostKeyChecking=no",
        "-o", "UserKnownHostsFile=/dev/null",
        "-P", str(port),
        *recursive_flag,
        str(local_path),
        f"{user}@{host}:{remote_path}",
    ]
    outcome = subprocess.run(argv, capture_output=True, text=True, timeout=900)
    if outcome.returncode != 0:
        raise RuntimeError(f"scp {local_path} -> {host}:{remote_path} failed: {outcome.stderr}")
def deploy_and_launch(host: str, port: int, model: str, run_name: str, repo_root: Path) -> None:
    """Ship the repository to a pod and start the bootstrap script there.

    Steps:
        1. Build a gzip tarball of *repo_root* (bulky directories excluded).
        2. Upload the tarball and ``repo_root/.env`` to /workspace.
        3. Unpack the tarball into /workspace/physiojepa.
        4. Start ``scripts/pod_bootstrap.sh <model> <run_name>`` under nohup,
           logging to /workspace/runs/<run_name>.bootstrap.log.

    Args:
        host: Pod SSH host.
        port: Pod SSH port.
        model: Model letter passed to the bootstrap script.
        run_name: Run name passed to the bootstrap script and used for the log
            file name.
        repo_root: Local repository root to deploy.

    Raises:
        subprocess.CalledProcessError: If the local ``tar`` command fails.
        RuntimeError: If an ``scp`` or ``ssh`` step exits non-zero.
        subprocess.TimeoutExpired: If an ``scp`` or ``ssh`` step times out.
    """
    import shlex  # function-scope import: only this function needs it

    # Resolve first so `.name` is the real directory name even for "." or "..".
    repo_root = repo_root.resolve()

    # build a tarball excluding bulky dirs
    with tempfile.TemporaryDirectory() as td:
        tar = Path(td) / "physiojepa.tar.gz"
        # NOTE(review): "docs/paperes" may be a typo for "docs/papers" — confirm
        # against the repository layout before changing it.
        excludes = [".venv", ".git", "__pycache__", "runs", "cache", "docs/figures",
                    "docs/paperes"]
        excl_args = [arg for pattern in excludes for arg in ("--exclude", pattern)]
        subprocess.run(
            ["tar", "-czf", str(tar), *excl_args, "-C", str(repo_root.parent),
             repo_root.name],
            check=True,
        )
        scp(host, port, tar, "/workspace/physiojepa.tar.gz")

    # also send .env
    env_file = repo_root / ".env"
    scp(host, port, env_file, "/workspace/.env")

    # The archive's top-level directory is whatever the local checkout is
    # called, which need not be "physiojepa". Strip that component and unpack
    # into a fixed directory so the paths below are always valid.
    ssh(host, port, "set -e; cd /workspace && rm -rf physiojepa && mkdir physiojepa && "
        "tar -xzf physiojepa.tar.gz -C physiojepa --strip-components=1 && "
        "rm physiojepa.tar.gz")

    # Quote the caller-supplied values before splicing them into the remote
    # shell line; for plain names the quoting leaves them unchanged.
    q_model = shlex.quote(model)
    q_run_name = shlex.quote(run_name)
    q_log = shlex.quote(f"/workspace/runs/{run_name}.bootstrap.log")

    # background the bootstrap with nohup so SSH disconnect doesn't kill it
    bootstrap = (
        "set -e; mkdir -p /workspace/runs; "
        "cd /workspace/physiojepa && chmod +x scripts/pod_bootstrap.sh && "
        f"nohup bash scripts/pod_bootstrap.sh {q_model} {q_run_name} "
        f"> {q_log} 2>&1 &"
        " disown; echo started; sleep 1"
    )
    ssh(host, port, bootstrap)
def _delete_pod(pod_id: str) -> None:
    """Best-effort pod deletion; a failure is reported on stdout, not raised."""
    try:
        runpodctl(["pod", "delete", pod_id])
    except Exception as exc:
        # Keep going, but make sure the operator knows a billed pod may remain.
        print(f"[launch] WARN: could not delete pod {pod_id}: {exc}")


def main() -> None:
    """Create one pod per requested model, deploy the code and start training.

    Command-line options:
        --models     Model letters to launch (default: A B C F).
        --gpu        GPU short name, one of the keys of ``GPU_IDS``.
        --image      Container image for the pods.
        --repo_root  Local repository to deploy (default: parent of scripts/).
        --manifest   JSON file that records every successfully launched pod.

    A pod whose SSH endpoint never appears, or whose deployment fails, is
    deleted and the loop moves on to the next model. The manifest is rewritten
    after each successful launch.

    Raises:
        RuntimeError: If ``runpodctl`` fails to create a pod, or the create
            response carries no pod id.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--models", nargs="+", default=["A", "B", "C", "F"])
    ap.add_argument("--gpu", default="A40", choices=list(GPU_IDS.keys()))
    ap.add_argument("--image", default=DEFAULT_IMAGE)
    ap.add_argument("--repo_root", default=str(Path(__file__).resolve().parents[1]))
    ap.add_argument("--manifest", default="runs/launch_manifest.json")
    args = ap.parse_args()

    repo_root = Path(args.repo_root)
    manifest_path = Path(args.manifest)
    manifest_path.parent.mkdir(parents=True, exist_ok=True)
    gpu_id = GPU_IDS[args.gpu]
    manifest = []

    for model in args.models:
        # Name the run after the GPU actually requested; with the default
        # --gpu A40 this is the same "e2_<model>_a40" as before.
        run_name = f"e2_{model}_{args.gpu.lower()}"
        pod_name = f"pj-{model.lower()}-{int(time.time()) % 100000:05d}"
        print(f"[launch] creating pod {pod_name} (model={model}, gpu={args.gpu})")
        pod = create_pod(pod_name, gpu_id, args.image)
        pod_id = pod.get("id") or pod.get("podId")
        if not pod_id:
            # Without an id the pod can be neither polled nor deleted; stop
            # before creating more pods that cannot be tracked.
            raise RuntimeError(
                f"no pod id in create response for {pod_name}: {pod!r}; "
                "check the RunPod console for a stray pod"
            )
        print(f"[launch] pod_id={pod_id}, waiting for SSH...")
        try:
            host, port = wait_for_ssh(pod_id)
        except TimeoutError as e:
            print(f"[launch] WARN: {e}; deleting pod and continuing")
            _delete_pod(pod_id)
            continue

        print(f"[launch] SSH up @ {host}:{port}, deploying code")
        try:
            deploy_and_launch(host, port, model, run_name, repo_root)
        except (RuntimeError, subprocess.SubprocessError, OSError) as e:
            # A failed deployment must not leave a billed pod running that the
            # manifest does not know about.
            print(f"[launch] WARN: deploy failed for {model}: {e}; "
                  "deleting pod and continuing")
            _delete_pod(pod_id)
            continue

        manifest.append({"pod_id": pod_id, "pod_name": pod_name, "host": host,
                         "port": port, "model": model, "run_name": run_name,
                         "started_at": time.time()})
        manifest_path.write_text(json.dumps(manifest, indent=2))
        print(f"[launch] {model} kicked off; manifest -> {args.manifest}")

    # Print from memory: the file is only written after a successful launch,
    # so reading it back could fail or show a manifest from an earlier run.
    print(f"[launch] all done. manifest:\n{json.dumps(manifest, indent=2)}")
# Run the launcher only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()