VLAdaptorBench / code /scripts /run_oven_pregrasp_batch.py
lsnu's picture
Add iter22 pregrasp repair reruns and updated benchmark code
7f173cd verified
raw
history blame
3.23 kB
from pathlib import Path
import argparse
import json
import pickle
import sys
import numpy as np
# The directory two levels above this file (i.e. the parent of the scripts/
# folder) is the project root that contains the ``rr_label_study`` package.
# Make it importable when the script is executed directly.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if sys.path.count(str(PROJECT_ROOT)) == 0:
    # Prepend rather than append so the local package takes precedence over
    # any other copy already reachable on sys.path.
    sys.path[:0] = [str(PROJECT_ROOT)]
from rr_label_study.oven_study import (
BimanualTakeTrayOutOfOven,
ReplayCache,
_launch_replay_env,
_load_demo,
_pregrasp_progress_and_distance,
_pregrasp_score_and_success,
)
def _row_path(output_dir: Path, frame_index: int) -> Path:
    """Return the JSON file that holds the result row for ``frame_index``."""
    # Single source of truth for the filename: the resume check and the writer
    # must agree on it, otherwise finished frames would be recomputed.
    return output_dir.joinpath(f"frame_{frame_index:04d}.json")


def _write_row_atomically(row_path: Path, row: dict) -> None:
    """Write ``row`` as JSON to ``row_path`` via a sibling temp file + rename.

    The final ``frame_XXXX.json`` name only appears once the JSON has been
    fully written, so the resume check in ``main`` never mistakes a partially
    written file for a finished frame.
    """
    tmp_path = row_path.with_suffix(".json.tmp")
    with tmp_path.open("w", encoding="utf-8") as handle:
        json.dump(row, handle)
    # Path.replace overwrites an existing target (an atomic rename on POSIX).
    tmp_path.replace(row_path)


def _parse_args(argv: "list[str] | None") -> argparse.Namespace:
    """Parse and validate the command line (``argv=None`` reads ``sys.argv``)."""
    parser = argparse.ArgumentParser(
        description="Compute pregrasp metrics for selected frames of one episode."
    )
    parser.add_argument("--episode-dir", required=True)
    parser.add_argument("--templates-pkl", required=True)
    parser.add_argument("--frame-indices", nargs="+", type=int, required=True)
    parser.add_argument(
        "--checkpoint-stride",
        type=int,
        default=16,
        help="passed to ReplayCache as checkpoint_stride",
    )
    parser.add_argument("--output-dir", required=True)
    parser.add_argument(
        "--progress-every",
        type=int,
        default=8,
        help="print a JSON progress line every N completed frames (and at the end)",
    )
    args = parser.parse_args(argv)
    # A negative index would yield a malformed filename (e.g. "frame_-001.json")
    # and be handed to the replay cache; reject it up front.
    if any(frame_index < 0 for frame_index in args.frame_indices):
        parser.error("--frame-indices must be non-negative")
    # Used as a modulus below; 0 would raise ZeroDivisionError mid-run.
    if args.progress_every < 1:
        parser.error("--progress-every must be >= 1")
    return args


def main(argv: "list[str] | None" = None) -> int:
    """Replay one demo episode and write pregrasp metrics for selected frames.

    For every requested frame that has no ``frame_XXXX.json`` in
    ``--output-dir`` yet, the demo is replayed up to that frame. Pregrasp
    progress/distance and score/success are computed and written as one JSON
    row per frame. Frames that already have a row are skipped, so re-running
    the command resumes an interrupted batch.

    Args:
        argv: Command-line arguments without the program name. ``None`` (the
            default) parses ``sys.argv[1:]``, matching the original behaviour.

    Returns:
        ``0`` on completion, including when there is nothing left to do.
        Invalid arguments terminate via argparse's ``SystemExit(2)``.
    """
    args = _parse_args(argv)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Deduplicate and process frames in ascending order so the replay cache
    # only ever has to move forward through the demo.
    frame_indices = sorted(set(args.frame_indices))
    pending_frame_indices = [
        frame_index
        for frame_index in frame_indices
        if not _row_path(output_dir, frame_index).exists()
    ]
    if not pending_frame_indices:
        # Everything is already done: skip loading data and launching the env.
        return 0

    episode_dir = Path(args.episode_dir)
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only pass --templates-pkl files produced by this project's own tooling.
    with Path(args.templates_pkl).open("rb") as handle:
        templates = pickle.load(handle)
    demo = _load_demo(episode_dir)

    env = _launch_replay_env()
    try:
        task = env.get_task(BimanualTakeTrayOutOfOven)
        cache = ReplayCache(task, demo, checkpoint_stride=args.checkpoint_stride)
        cache.reset()
        total = len(pending_frame_indices)
        for completed, frame_index in enumerate(pending_frame_indices, start=1):
            cache.step_to(frame_index)
            state = cache.current_state()
            pregrasp_progress, pregrasp_distance = _pregrasp_progress_and_distance(
                np.asarray(state.left_gripper_pose, dtype=np.float64),
                np.asarray(state.tray_pose, dtype=np.float64),
                templates,
            )
            # NOTE(review): scored from the live task, presumably reflecting the
            # simulator state reached by step_to above — confirm.
            p_pre, y_pre = _pregrasp_score_and_success(task, templates)
            y_pre_label = float(bool(y_pre))
            row = {
                "frame_index": int(frame_index),
                "pregrasp_progress": float(pregrasp_progress),
                "pregrasp_distance": float(pregrasp_distance),
                "p_pre": float(p_pre),
                # This script does no label repair, so raw and final labels are
                # identical; presumably both are kept so rows share a schema
                # with repaired runs — TODO confirm.
                "y_pre_raw": y_pre_label,
                "y_pre": y_pre_label,
            }
            _write_row_atomically(_row_path(output_dir, frame_index), row)
            if completed == total or completed % args.progress_every == 0:
                print(
                    json.dumps(
                        {
                            "done": completed,
                            "total": total,
                            "frame_index": int(frame_index),
                        }
                    ),
                    flush=True,
                )
    finally:
        # Always release the simulator, even if a frame fails.
        env.shutdown()
    return 0
if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status.
    sys.exit(main())