"""Run the oven label study in parallel across Xvfb-backed worker processes.

Each worker gets its own virtual X display and a contiguous slice of episodes,
runs ``scripts/run_oven_label_study.py`` with results written to
``<result-dir>/worker_NN``, and the per-episode metrics produced by all workers
are aggregated into a single summary.
"""

import argparse
import json
import math
import os
import signal
import subprocess
import sys
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple

PROJECT_ROOT = Path(__file__).resolve().parents[1]
# Make the project package importable when this file is executed directly.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from rr_label_study.oven_study import _aggregate_summary, _episode_dirs  # noqa: E402


def _chunk_specs(
    total_episodes: int,
    episode_offset: int,
    max_episodes: Optional[int],
    num_workers: int,
) -> List[Tuple[int, int]]:
    """Split the selected episode range into contiguous per-worker chunks.

    Args:
        total_episodes: Number of episodes available in the dataset.
        episode_offset: Index of the first episode to process (must be >= 0).
        max_episodes: Optional cap on the total number of episodes processed.
        num_workers: Upper bound on the number of chunks (must be >= 1).

    Returns:
        A list of ``(start, count)`` pairs with ``count > 0``, covering the
        selected range in order. Empty when no episodes are selected. Chunks are
        sized with ``ceil``, so fewer than ``num_workers`` chunks may be returned
        (e.g. 9 episodes over 4 workers yields three chunks of 3).

    Raises:
        ValueError: If ``num_workers < 1`` or ``episode_offset < 0``.
    """
    if num_workers < 1:
        raise ValueError(f"num_workers must be >= 1, got {num_workers}")
    if episode_offset < 0:
        raise ValueError(f"episode_offset must be >= 0, got {episode_offset}")

    remaining = max(0, total_episodes - episode_offset)
    if max_episodes is not None:
        remaining = min(remaining, max_episodes)
    if remaining <= 0:
        return []

    worker_count = min(num_workers, remaining)
    chunk_size = math.ceil(remaining / worker_count)
    end = episode_offset + remaining  # exclusive end of the selected range
    specs: List[Tuple[int, int]] = []
    for worker_index in range(worker_count):
        start = episode_offset + worker_index * chunk_size
        count = min(chunk_size, end - start)
        # With ceil-sized chunks the trailing workers can end up with nothing.
        if count > 0:
            specs.append((start, count))
    return specs


def _launch_xvfb(display_num: int, log_path: Path) -> subprocess.Popen:
    """Start an Xvfb server on display ``:display_num``, logging to ``log_path``.

    The server is started in a new session so that ``_stop_process`` can signal
    its whole process group.
    """
    command = [
        "Xvfb",
        f":{display_num}",
        "-screen",
        "0",
        "1280x1024x24",
        "+extension",
        "GLX",
        "+render",
        "-noreset",
    ]
    # The child receives its own copy of the log descriptor, so the parent's
    # handle is closed as soon as Popen returns (or raises) instead of leaking.
    with log_path.open("w", encoding="utf-8") as log_handle:
        return subprocess.Popen(
            command,
            stdout=log_handle,
            stderr=subprocess.STDOUT,
            start_new_session=True,
        )


def _launch_worker(
    worker_dir: Path,
    display_num: int,
    dataset_root: str,
    episode_offset: int,
    max_episodes: int,
    checkpoint_stride: int,
    template_episode_index: int,
    max_frames: Optional[int],
) -> Tuple[subprocess.Popen, subprocess.Popen]:
    """Start an Xvfb display and one label-study worker process bound to it.

    The worker runs ``scripts/run_oven_label_study.py`` with ``--result-dir``
    set to ``worker_dir``. Its combined stdout/stderr goes to
    ``worker_dir/worker.log``, and Xvfb output goes to ``worker_dir/xvfb.log``.

    Returns:
        ``(xvfb, worker)`` process handles.

    Raises:
        Any exception from directory creation or process startup. In that case
        the Xvfb server started here is stopped before the exception
        propagates, so it is not orphaned.
    """
    worker_dir.mkdir(parents=True, exist_ok=True)
    xvfb = _launch_xvfb(display_num, worker_dir.joinpath("xvfb.log"))
    try:
        # Fixed delay so the display is (hopefully) up before the worker
        # connects; there is no explicit readiness check.
        time.sleep(1.0)
        runtime_dir = Path(f"/tmp/rr_label_study_display_{display_num}")
        runtime_dir.mkdir(parents=True, exist_ok=True)

        command = [
            sys.executable,
            str(PROJECT_ROOT.joinpath("scripts", "run_oven_label_study.py")),
            "--dataset-root",
            dataset_root,
            "--result-dir",
            str(worker_dir),
            "--episode-offset",
            str(episode_offset),
            "--max-episodes",
            str(max_episodes),
            "--checkpoint-stride",
            str(checkpoint_stride),
            "--template-episode-index",
            str(template_episode_index),
        ]
        if max_frames is not None:
            command.extend(["--max-frames", str(max_frames)])

        env = os.environ.copy()
        env["DISPLAY"] = f":{display_num}"
        env["XDG_RUNTIME_DIR"] = str(runtime_dir)

        # As with Xvfb, the child keeps its own descriptor; close ours.
        with worker_dir.joinpath("worker.log").open("w", encoding="utf-8") as worker_log:
            process = subprocess.Popen(
                command,
                stdout=worker_log,
                stderr=subprocess.STDOUT,
                env=env,
                cwd=str(PROJECT_ROOT),
                start_new_session=True,
            )
    except BaseException:
        # The caller never receives the xvfb handle in this case, so clean it
        # up here before re-raising.
        _stop_process(xvfb)
        raise
    return xvfb, process


def _stop_process(process: subprocess.Popen, timeout: float = 10.0) -> None:
    """Stop ``process``'s process group: SIGTERM, then SIGKILL after ``timeout``.

    Does nothing if the process has already exited. This assumes the process
    was started with ``start_new_session=True``, so its pid is also its
    process-group id.
    """
    if process.poll() is not None:
        return
    try:
        os.killpg(process.pid, signal.SIGTERM)
    except ProcessLookupError:
        return
    try:
        process.wait(timeout=timeout)
        return
    except subprocess.TimeoutExpired:
        pass

    try:
        os.killpg(process.pid, signal.SIGKILL)
    except ProcessLookupError:
        pass
    # Reap the killed child so it does not linger as a zombie. The wait is
    # bounded so that cleanup (often running in a ``finally``) cannot hang
    # forever on a process stuck in an uninterruptible state.
    try:
        process.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        pass


def _collect_metrics(base_result_dir: Path) -> List[Dict[str, object]]:
    """Load every ``worker_*/episode*.metrics.json`` under ``base_result_dir``.

    Files are read in sorted path order, so the result order is deterministic.
    """
    metrics: List[Dict[str, object]] = []
    for metrics_path in sorted(base_result_dir.glob("worker_*/episode*.metrics.json")):
        with metrics_path.open("r", encoding="utf-8") as handle:
            metrics.append(json.load(handle))
    return metrics


def main(argv: Optional[List[str]] = None) -> int:
    """Launch the parallel oven label study and write aggregated results.

    Splits the selected episodes across up to ``--num-workers`` workers. Each
    worker runs on its own Xvfb display ``--base-display + worker_index``. When
    all workers finish, ``parallel_workers.json`` and ``parallel_summary.json``
    are written into ``--result-dir`` and the summary is printed to stdout.

    Returns:
        0 on success.

    Raises:
        RuntimeError: If no episodes are selected or a worker exits non-zero.
            Every launched worker and Xvfb server is stopped before the error
            propagates, and no summary files are written.
        ValueError: If ``--num-workers`` < 1 or ``--episode-offset`` < 0.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset-root",
        default="/workspace/data/bimanual_take_tray_out_of_oven_train_128",
    )
    parser.add_argument(
        "--result-dir",
        default="/workspace/reveal_retrieve_label_study/results/oven_parallel",
    )
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--episode-offset", type=int, default=0)
    parser.add_argument("--max-episodes", type=int)
    parser.add_argument("--checkpoint-stride", type=int, default=16)
    parser.add_argument("--template-episode-index", type=int, default=0)
    parser.add_argument("--base-display", type=int, default=110)
    parser.add_argument("--max-frames", type=int)
    args = parser.parse_args(argv)

    dataset_root = Path(args.dataset_root)
    all_episodes = _episode_dirs(dataset_root)
    chunk_specs = _chunk_specs(
        total_episodes=len(all_episodes),
        episode_offset=args.episode_offset,
        max_episodes=args.max_episodes,
        num_workers=args.num_workers,
    )
    if not chunk_specs:
        raise RuntimeError("no episodes selected for parallel run")

    result_dir = Path(args.result_dir)
    result_dir.mkdir(parents=True, exist_ok=True)

    workers: List[Tuple[subprocess.Popen, subprocess.Popen]] = []
    worker_meta: List[Dict[str, object]] = []
    try:
        for worker_index, (episode_offset, episode_count) in enumerate(chunk_specs):
            display_num = args.base_display + worker_index
            worker_dir = result_dir.joinpath(f"worker_{worker_index:02d}")
            xvfb, process = _launch_worker(
                worker_dir=worker_dir,
                display_num=display_num,
                dataset_root=args.dataset_root,
                episode_offset=episode_offset,
                max_episodes=episode_count,
                checkpoint_stride=args.checkpoint_stride,
                template_episode_index=args.template_episode_index,
                max_frames=args.max_frames,
            )
            workers.append((xvfb, process))
            worker_meta.append(
                {
                    "worker_index": worker_index,
                    "display_num": display_num,
                    "episode_offset": episode_offset,
                    "episode_count": episode_count,
                }
            )

        # Workers are awaited in launch order, so a failure in a later worker
        # is only noticed once all earlier workers have finished.
        for meta, (_, process) in zip(worker_meta, workers):
            return_code = process.wait()
            meta["return_code"] = return_code
            if return_code != 0:
                worker_index = int(meta["worker_index"])
                worker_log = result_dir.joinpath(f"worker_{worker_index:02d}", "worker.log")
                raise RuntimeError(
                    f"worker {worker_index} failed with code {return_code}; see {worker_log}"
                )
    finally:
        # Always tear down every worker and its Xvfb server, including on error.
        for xvfb, process in workers:
            _stop_process(process)
            _stop_process(xvfb)

    episode_metrics = _collect_metrics(result_dir)
    summary = _aggregate_summary(episode_metrics)
    with result_dir.joinpath("parallel_workers.json").open("w", encoding="utf-8") as handle:
        json.dump(worker_meta, handle, indent=2)
    with result_dir.joinpath("parallel_summary.json").open("w", encoding="utf-8") as handle:
        json.dump(summary, handle, indent=2)
    print(json.dumps(summary, indent=2))
    return 0


if __name__ == "__main__":
    raise SystemExit(main())