# VLAdaptorBench / code / scripts / launch_parallel_oven_label_study.py
# Hugging Face file-viewer residue kept for provenance:
#   uploader: lsnu
#   commit: 150d02a (verified) - "Add files using upload-large-folder tool"
#   views: raw / history / blame
#   size: 7.35 kB
import argparse
import json
import math
import os
import signal
import subprocess
import sys
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from rr_label_study.oven_study import _aggregate_summary, _episode_dirs
def _chunk_specs(
    total_episodes: int,
    episode_offset: int,
    max_episodes: Optional[int],
    num_workers: int,
) -> List[Tuple[int, int]]:
    """Split the selected episode range into contiguous per-worker chunks.

    The selected range starts at ``episode_offset`` and covers at most
    ``max_episodes`` episodes (or everything up to ``total_episodes`` when
    ``max_episodes`` is ``None``).

    Args:
        total_episodes: Number of episodes available in the dataset.
        episode_offset: Index of the first episode to process.
        max_episodes: Optional cap on how many episodes to process.
        num_workers: Upper bound on the number of chunks to produce.

    Returns:
        A list of ``(start, count)`` tuples in ascending order. The list is
        empty when no episode is selected. It may hold fewer than
        ``num_workers`` entries, because every chunk except the last has
        ``ceil(selected / workers)`` episodes.

    Raises:
        ValueError: If ``num_workers`` is smaller than 1 or
            ``episode_offset`` is negative.
    """
    # Reject nonsensical inputs up front instead of failing later with a
    # ZeroDivisionError or silently producing negative start indices.
    if num_workers < 1:
        raise ValueError(f"num_workers must be >= 1, got {num_workers}")
    if episode_offset < 0:
        raise ValueError(f"episode_offset must be >= 0, got {episode_offset}")

    remaining = max(0, total_episodes - episode_offset)
    if max_episodes is not None:
        remaining = min(remaining, max_episodes)
    if remaining <= 0:
        return []

    # Never create more chunks than there are episodes.
    worker_count = min(num_workers, remaining)
    chunk_size = math.ceil(remaining / worker_count)
    end = episode_offset + remaining

    # Stepping by chunk_size visits exactly the non-empty chunk starts; the
    # final chunk is clipped to the end of the selected range.
    return [
        (start, min(chunk_size, end - start))
        for start in range(episode_offset, end, chunk_size)
    ]
def _launch_xvfb(display_num: int, log_path: Path) -> subprocess.Popen:
    """Start an Xvfb server on display ``:display_num``.

    Args:
        display_num: X display number to serve (passed as ``:<display_num>``).
        log_path: File that receives the server's combined stdout/stderr;
            it is truncated if it already exists.

    Returns:
        The ``Popen`` handle of the Xvfb process. It is started in its own
        session (``start_new_session=True``).

    Raises:
        OSError: If the log file cannot be opened or ``Xvfb`` cannot be
            executed.
    """
    command = [
        "Xvfb",
        f":{display_num}",
        "-screen",
        "0",
        "1280x1024x24",
        "+extension",
        "GLX",
        "+render",
        "-noreset",
    ]
    # The child receives its own duplicate of the log descriptor, so the
    # parent's handle can be closed as soon as Popen returns. Using `with`
    # also closes it when Popen raises (e.g. Xvfb is not installed).
    with log_path.open("w", encoding="utf-8") as log_handle:
        return subprocess.Popen(
            command,
            stdout=log_handle,
            stderr=subprocess.STDOUT,
            start_new_session=True,
        )
def _launch_worker(
    worker_dir: Path,
    display_num: int,
    dataset_root: str,
    episode_offset: int,
    max_episodes: int,
    checkpoint_stride: int,
    template_episode_index: int,
    max_frames: Optional[int],
) -> Tuple[subprocess.Popen, subprocess.Popen]:
    """Start one Xvfb server and one ``run_oven_label_study.py`` worker on it.

    Args:
        worker_dir: Directory for this worker; receives ``xvfb.log`` and
            ``worker.log`` and is passed to the worker as ``--result-dir``.
        display_num: X display number used for the Xvfb server and exported
            to the worker as ``DISPLAY``.
        dataset_root: Forwarded as ``--dataset-root``.
        episode_offset: Forwarded as ``--episode-offset``.
        max_episodes: Forwarded as ``--max-episodes``.
        checkpoint_stride: Forwarded as ``--checkpoint-stride``.
        template_episode_index: Forwarded as ``--template-episode-index``.
        max_frames: Forwarded as ``--max-frames`` when not ``None``.

    Returns:
        ``(xvfb, worker)`` process handles. Both run in their own session.

    Raises:
        OSError: If a directory, log file or process cannot be created. The
            Xvfb server started for this worker is stopped before the error
            propagates.
    """
    worker_dir.mkdir(parents=True, exist_ok=True)
    xvfb = _launch_xvfb(display_num, worker_dir.joinpath("xvfb.log"))
    try:
        # Fixed grace period before the worker is pointed at the display.
        time.sleep(1.0)

        # The XDG Base Directory spec requires XDG_RUNTIME_DIR to be 0700.
        # `mode` only applies when the directory is newly created.
        runtime_dir = Path(f"/tmp/rr_label_study_display_{display_num}")
        runtime_dir.mkdir(mode=0o700, parents=True, exist_ok=True)

        command = [
            sys.executable,
            str(PROJECT_ROOT.joinpath("scripts", "run_oven_label_study.py")),
            "--dataset-root",
            dataset_root,
            "--result-dir",
            str(worker_dir),
            "--episode-offset",
            str(episode_offset),
            "--max-episodes",
            str(max_episodes),
            "--checkpoint-stride",
            str(checkpoint_stride),
            "--template-episode-index",
            str(template_episode_index),
        ]
        if max_frames is not None:
            command.extend(["--max-frames", str(max_frames)])

        env = os.environ.copy()
        env["DISPLAY"] = f":{display_num}"
        env["XDG_RUNTIME_DIR"] = str(runtime_dir)

        # The child keeps its own duplicate of the log descriptor; close the
        # parent's copy right after the spawn so handles do not accumulate.
        log_path = worker_dir.joinpath("worker.log")
        with log_path.open("w", encoding="utf-8") as worker_log:
            process = subprocess.Popen(
                command,
                stdout=worker_log,
                stderr=subprocess.STDOUT,
                env=env,
                cwd=str(PROJECT_ROOT),
                start_new_session=True,
            )
    except BaseException:
        # The caller never receives the Xvfb handle when we fail here, so it
        # must be torn down in place or it would be orphaned.
        _stop_process(xvfb)
        raise
    return xvfb, process
def _stop_process(process: subprocess.Popen, timeout: float = 10.0) -> None:
    """Terminate a child's process group, escalating to SIGKILL if needed.

    ``process.pid`` is used as the process-group id, which assumes the child
    was started as a session/group leader (the launchers in this file pass
    ``start_new_session=True``).

    Args:
        process: Handle of the child to stop. Nothing happens if it has
            already exited.
        timeout: Seconds to wait after SIGTERM before sending SIGKILL, and
            again after SIGKILL for the child to be reaped.
    """
    if process.poll() is not None:
        return

    try:
        os.killpg(process.pid, signal.SIGTERM)
    except ProcessLookupError:
        # The group vanished between poll() and killpg().
        return

    try:
        process.wait(timeout=timeout)
        return
    except subprocess.TimeoutExpired:
        pass

    try:
        os.killpg(process.pid, signal.SIGKILL)
    except ProcessLookupError:
        pass

    # Reap the child after SIGKILL so it does not linger as a zombie and so
    # that `returncode` is populated. The wait is bounded so that a process
    # that cannot be reaped in time does not hang the caller.
    try:
        process.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        pass
def _collect_metrics(base_result_dir: Path) -> List[Dict[str, object]]:
    """Load every per-episode metrics file found under the worker directories.

    Files matching ``worker_*/episode*.metrics.json`` below
    ``base_result_dir`` are parsed as JSON and returned in sorted path order.
    """

    def _read_json(json_path: Path) -> Dict[str, object]:
        with json_path.open("r", encoding="utf-8") as stream:
            return json.load(stream)

    metric_files = sorted(base_result_dir.glob("worker_*/episode*.metrics.json"))
    return [_read_json(json_path) for json_path in metric_files]
def _wait_for_workers(
    workers: List[Tuple[subprocess.Popen, subprocess.Popen]],
    worker_meta: List[Dict[str, object]],
    result_dir: Path,
    poll_interval: float = 1.0,
) -> None:
    """Block until every worker exits, raising as soon as any of them fails.

    Each finished worker's exit status is stored in its metadata entry under
    ``"return_code"``.

    Raises:
        RuntimeError: On the first worker observed with a non-zero exit code.
    """
    pending = list(zip(worker_meta, workers))
    while pending:
        still_running = []
        for meta, (xvfb, process) in pending:
            return_code = process.poll()
            if return_code is None:
                still_running.append((meta, (xvfb, process)))
                continue
            meta["return_code"] = return_code
            if return_code != 0:
                worker_index = int(meta["worker_index"])
                worker_log = result_dir.joinpath(f"worker_{worker_index:02d}", "worker.log")
                raise RuntimeError(
                    f"worker {worker_index} failed with code {return_code}; see {worker_log}"
                )
        pending = still_running
        if pending:
            time.sleep(poll_interval)


def main(argv: Optional[List[str]] = None) -> int:
    """Run the oven label study across several workers and merge the results.

    The selected episodes are split into contiguous chunks. Each chunk is
    handed to its own worker process on its own Xvfb display, numbered
    ``--base-display + worker_index``. When all workers have succeeded, the
    per-episode metrics are aggregated and written to the result directory
    as ``parallel_workers.json`` and ``parallel_summary.json``; the summary
    is also printed.

    Args:
        argv: Command-line arguments; ``None`` lets argparse read
            ``sys.argv``.

    Returns:
        ``0`` on success.

    Raises:
        RuntimeError: If no episode is selected or a worker exits with a
            non-zero code.
        SystemExit: From argparse on invalid arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset-root",
        default="/workspace/data/bimanual_take_tray_out_of_oven_train_128",
    )
    parser.add_argument(
        "--result-dir",
        default="/workspace/reveal_retrieve_label_study/results/oven_parallel",
    )
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--episode-offset", type=int, default=0)
    parser.add_argument("--max-episodes", type=int)
    parser.add_argument("--checkpoint-stride", type=int, default=16)
    parser.add_argument("--template-episode-index", type=int, default=0)
    parser.add_argument("--base-display", type=int, default=110)
    parser.add_argument("--max-frames", type=int)
    args = parser.parse_args(argv)

    # A worker count below one cannot be chunked; report it as a usage error
    # rather than letting the chunking arithmetic fail.
    if args.num_workers < 1:
        parser.error("--num-workers must be at least 1")

    dataset_root = Path(args.dataset_root)
    all_episodes = _episode_dirs(dataset_root)
    chunk_specs = _chunk_specs(
        total_episodes=len(all_episodes),
        episode_offset=args.episode_offset,
        max_episodes=args.max_episodes,
        num_workers=args.num_workers,
    )
    if not chunk_specs:
        raise RuntimeError("no episodes selected for parallel run")

    result_dir = Path(args.result_dir)
    result_dir.mkdir(parents=True, exist_ok=True)

    workers: List[Tuple[subprocess.Popen, subprocess.Popen]] = []
    worker_meta: List[Dict[str, object]] = []
    try:
        for worker_index, (episode_offset, episode_count) in enumerate(chunk_specs):
            display_num = args.base_display + worker_index
            worker_dir = result_dir.joinpath(f"worker_{worker_index:02d}")
            xvfb, process = _launch_worker(
                worker_dir=worker_dir,
                display_num=display_num,
                dataset_root=args.dataset_root,
                episode_offset=episode_offset,
                max_episodes=episode_count,
                checkpoint_stride=args.checkpoint_stride,
                template_episode_index=args.template_episode_index,
                max_frames=args.max_frames,
            )
            workers.append((xvfb, process))
            worker_meta.append(
                {
                    "worker_index": worker_index,
                    "display_num": display_num,
                    "episode_offset": episode_offset,
                    "episode_count": episode_count,
                }
            )

        # Poll all workers together so that a crash in any of them aborts the
        # run promptly instead of only after every earlier worker finished.
        _wait_for_workers(workers, worker_meta, result_dir)
    finally:
        # Runs on success, failure and interruption: tear down whatever is
        # still alive, worker first, then its display server.
        for xvfb, process in workers:
            _stop_process(process)
            _stop_process(xvfb)

    # NOTE(review): _collect_metrics globs every worker_* directory under
    # result_dir, so directories left over from an earlier run with more
    # workers would be included -- confirm result_dir is expected to be fresh.
    episode_metrics = _collect_metrics(result_dir)
    summary = _aggregate_summary(episode_metrics)

    with result_dir.joinpath("parallel_workers.json").open("w", encoding="utf-8") as handle:
        json.dump(worker_meta, handle, indent=2)
    with result_dir.joinpath("parallel_summary.json").open("w", encoding="utf-8") as handle:
        json.dump(summary, handle, indent=2)

    print(json.dumps(summary, indent=2))
    return 0
# Script entry point: propagate main()'s return value as the exit status.
if __name__ == "__main__":
    exit_status = main()
    raise SystemExit(exit_status)