# robomme-eval-docker / eval_models.py
# Author: pepijn223 (HF Staff)
# Commit 4ef3086 (verified): "Upload eval_models.py"
"""Evaluate baseline and MEM models on RoboMME PickHighlight."""
import argparse
import json
import shlex
import subprocess
import sys
from pathlib import Path
def run_eval(
    model_path: str,
    task: str,
    n_episodes: int,
    output_dir: Path,
    job_name: str,
    device: str = "cuda",
):
    """Run ``lerobot_eval`` for a single policy on a RoboMME task.

    The evaluation runs in a child Python process. Its stdout/stderr are
    captured, echoed to this process's stdout/stderr, and saved as
    ``stdout.log`` / ``stderr.log`` under ``output_dir / job_name``. That
    directory is also passed to lerobot as ``--output_dir``.

    Args:
        model_path: Value for ``--policy.path`` (local directory or Hub repo id).
        task: RoboMME task name, passed as ``--env.task``.
        n_episodes: Number of evaluation episodes (``--eval.n_episodes``).
        output_dir: Parent directory; results go to ``output_dir / job_name``.
        job_name: Sub-directory name and lerobot ``--job_name``.
        device: Value for ``--policy.device``; defaults to ``"cuda"``.

    Returns:
        ``None`` if the child process exits with a non-zero code. Otherwise,
        the parsed ``eval_info.json`` dict if this run wrote a readable one,
        or ``{"stdout": <last 5000 characters of stdout>}`` if it did not.
    """
    out = output_dir / job_name
    out.mkdir(parents=True, exist_ok=True)

    # The output directory may be reused across runs (exist_ok=True). Delete any
    # previous result file so that stale metrics are never reported as this
    # run's results.
    eval_info_path = out / "eval_info.json"
    eval_info_path.unlink(missing_ok=True)

    cmd = [
        sys.executable, "-m", "lerobot.scripts.lerobot_eval",
        f"--policy.path={model_path}",
        "--env.type=robomme",
        f"--env.task={task}",
        "--env.dataset_split=test",
        "--eval.batch_size=1",
        f"--eval.n_episodes={n_episodes}",
        # Renames the env's image observation keys to camera1/camera2,
        # presumably to match the input names the policy expects.
        '--rename_map={"observation.images.image":"observation.images.camera1","observation.images.wrist_image":"observation.images.camera2"}',
        f"--policy.device={device}",
        "--policy.use_amp=false",
        f"--output_dir={out}",
        f"--job_name={job_name}",
    ]

    banner = "=" * 70
    print(f"\n{banner}", flush=True)
    print(f"Running eval for: {job_name}", flush=True)
    print(f"Model: {model_path}", flush=True)
    # shlex.join quotes the JSON rename_map, so the printed line can be pasted
    # back into a shell unchanged.
    print(f"Command: {shlex.join(cmd)}", flush=True)
    print(f"{banner}\n", flush=True)

    # Decode explicitly as UTF-8 and replace invalid bytes. Otherwise stray
    # non-UTF-8 output from the child could raise UnicodeDecodeError here
    # (e.g. under a non-UTF-8 locale) and lose the whole evaluation.
    result = subprocess.run(
        cmd,
        capture_output=True,
        text=True,
        encoding="utf-8",
        errors="replace",
    )
    (out / "stdout.log").write_text(result.stdout, encoding="utf-8")
    (out / "stderr.log").write_text(result.stderr, encoding="utf-8")

    print(result.stdout)
    if result.stderr:
        print(result.stderr, file=sys.stderr)

    if result.returncode != 0:
        print(f"ERROR: Eval failed with code {result.returncode}", file=sys.stderr)
        return None

    if eval_info_path.exists():
        try:
            return json.loads(eval_info_path.read_text(encoding="utf-8"))
        except json.JSONDecodeError as err:
            # Fall through to the raw-stdout result instead of crashing the
            # comparison run.
            print(f"WARNING: could not parse {eval_info_path}: {err}", file=sys.stderr)
    return {"stdout": result.stdout[-5000:]}
def main() -> int:
    """Evaluate the baseline and MEM policies and print a side-by-side summary.

    Each model is evaluated with ``run_eval`` into its own sub-directory of
    ``--output_dir``. The collected results (one entry per job name; ``None``
    for a failed eval) are written to ``eval_comparison.json`` in that
    directory.

    Returns:
        Process exit code: 0 if every model produced a result, 1 if any model
        is reported as FAILED in the summary.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--baseline",
        default="pepijn223/smolvla_robomme_baseline_pickhighlight_5k_migrated",
        help="Baseline policy path or Hub repo id.",
    )
    parser.add_argument(
        "--mem",
        default="pepijn223/smolvla_robomme_mem_k4_pickhighlight_5k_migrated",
        help="MEM (K=4) policy path or Hub repo id.",
    )
    parser.add_argument("--task", default="PickHighlight", help="RoboMME task name.")
    parser.add_argument("--n_episodes", type=int, default=50, help="Episodes per model.")
    parser.add_argument(
        "--output_dir",
        type=Path,
        default=Path("/app/outputs"),
        help="Directory for per-model outputs and eval_comparison.json.",
    )
    args = parser.parse_args()

    output_dir = args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)
    comparison_path = output_dir / "eval_comparison.json"

    # (job name / results key, summary label, policy path)
    models = [
        ("baseline", "Baseline", args.baseline),
        ("mem_k4", "MEM K=4", args.mem),
    ]

    results = {}
    for job_name, _label, model_path in models:
        results[job_name] = run_eval(
            model_path=model_path,
            task=args.task,
            n_episodes=args.n_episodes,
            output_dir=output_dir,
            job_name=job_name,
        )
        # Rewrite after every model so that finished results survive if a
        # later evaluation crashes or the process is killed.
        with open(comparison_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)

    banner = "=" * 70
    print(f"\n{banner}", flush=True)
    print("EVALUATION SUMMARY", flush=True)
    print(banner, flush=True)

    any_failed = False
    for job_name, label, _model_path in models:
        info = results[job_name]
        if not info:
            print(f"{label}: FAILED", flush=True)
            any_failed = True
            continue
        # Aggregate metrics are read from "overall"; "aggregated" is accepted
        # as a fallback. NOTE(review): "aggregated" is believed to be the key
        # used by older lerobot eval_info.json files; confirm.
        metrics = info.get("overall", info.get("aggregated"))
        if isinstance(metrics, dict):
            sr = metrics.get("pc_success", "N/A")
            avg_reward = metrics.get("avg_sum_reward", "N/A")
            print(f"{label}: success_rate={sr}, avg_reward={avg_reward}", flush=True)
        elif "stdout" in info:
            print(f"{label}: Raw output saved (no parsed results)", flush=True)
        else:
            print(f"{label}: FAILED", flush=True)
            any_failed = True

    print(f"\nFull results saved to: {comparison_path}", flush=True)
    return 1 if any_failed else 0


if __name__ == "__main__":
    # Propagate failure to the caller (e.g. a container runner) via the exit code.
    sys.exit(main())