Spaces:
Build error
Build error
| """Evaluate baseline and MEM models on RoboMME PickHighlight.""" | |
| import argparse | |
| import json | |
| import subprocess | |
| import sys | |
| from pathlib import Path | |
def run_eval(model_path: str, task: str, n_episodes: int, output_dir: Path, job_name: str):
    """Run lerobot-eval for a single model on a RoboMME task.

    Launches ``lerobot.scripts.lerobot_eval`` in a child interpreter, waits for
    it to finish, and stores the captured streams as ``stdout.log`` and
    ``stderr.log`` under ``output_dir / job_name``.

    Args:
        model_path: Value forwarded as ``--policy.path``.
        task: Value forwarded as ``--env.task``.
        n_episodes: Value forwarded as ``--eval.n_episodes``.
        output_dir: Parent directory; a ``job_name`` sub-directory is created.
        job_name: Sub-directory name, also forwarded as ``--job_name``.

    Returns:
        ``None`` when the child process exits with a non-zero status.
        The parsed contents of ``eval_info.json`` when that file exists in the
        job directory and is readable JSON.
        Otherwise ``{"stdout": ...}`` holding the last 5000 characters of the
        child's standard output.
    """
    # Function-scope import so the block does not depend on the file header.
    import shlex

    out = output_dir / job_name
    out.mkdir(parents=True, exist_ok=True)

    cmd = [
        sys.executable, "-m", "lerobot.scripts.lerobot_eval",
        f"--policy.path={model_path}",
        "--env.type=robomme",
        f"--env.task={task}",
        "--env.dataset_split=test",
        "--eval.batch_size=1",
        f"--eval.n_episodes={n_episodes}",
        # Observation-key renames, passed through as a single JSON argument.
        '--rename_map={"observation.images.image":"observation.images.camera1","observation.images.wrist_image":"observation.images.camera2"}',
        "--policy.device=cuda",
        "--policy.use_amp=false",
        f"--output_dir={out}",
        f"--job_name={job_name}",
    ]

    print(f"\n{'='*70}", flush=True)
    print(f"Running eval for: {job_name}", flush=True)
    print(f"Model: {model_path}", flush=True)
    # shlex.join quotes each argument, so the echoed line can be pasted into a
    # shell; a plain ' '.join would expose the JSON braces/quotes to the shell.
    print(f"Command: {shlex.join(cmd)}", flush=True)
    print(f"{'='*70}\n", flush=True)

    # Decode explicitly and tolerate stray bytes so a single undecodable
    # character in the child's output cannot abort the evaluation.
    result = subprocess.run(
        cmd,
        capture_output=True,
        text=True,
        encoding="utf-8",
        errors="replace",
    )
    (out / "stdout.log").write_text(result.stdout, encoding="utf-8")
    (out / "stderr.log").write_text(result.stderr, encoding="utf-8")

    print(result.stdout)
    if result.stderr:
        print(result.stderr, file=sys.stderr)

    if result.returncode != 0:
        print(f"ERROR: Eval failed with code {result.returncode}", file=sys.stderr)
        return None

    stdout_tail = {"stdout": result.stdout[-5000:]}
    eval_info_path = out / "eval_info.json"
    if not eval_info_path.exists():
        return stdout_tail

    try:
        with open(eval_info_path, encoding="utf-8") as f:
            return json.load(f)
    except (OSError, json.JSONDecodeError) as err:
        # An unreadable results file should not take down the whole comparison;
        # report it and fall back to the same payload used when it is missing.
        print(f"WARNING: could not read {eval_info_path}: {err}", file=sys.stderr)
        return stdout_tail
def main(argv=None):
    """Evaluate the baseline and MEM models and print a side-by-side summary.

    Each model is evaluated with ``run_eval``; the collected results are
    written to ``eval_comparison.json`` inside the output directory and a
    one-line summary per model is printed.

    Args:
        argv: Optional list of command-line arguments. When ``None``,
            argparse reads ``sys.argv[1:]``.

    Returns:
        ``0`` when every model produced a result, ``1`` when at least one
        model is reported as ``FAILED`` in the summary.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--baseline", default="pepijn223/smolvla_robomme_baseline_pickhighlight_5k_migrated")
    parser.add_argument("--mem", default="pepijn223/smolvla_robomme_mem_k4_pickhighlight_5k_migrated")
    parser.add_argument("--task", default="PickHighlight")
    parser.add_argument("--n_episodes", type=int, default=50)
    parser.add_argument("--output_dir", type=Path, default=Path("/app/outputs"))
    args = parser.parse_args(argv)

    output_dir = args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)

    # (job name used for the directory and JSON key, summary label, model path)
    jobs = [
        ("baseline", "Baseline", args.baseline),
        ("mem_k4", "MEM K=4", args.mem),
    ]

    results = {}
    for job_name, _label, model_path in jobs:
        results[job_name] = run_eval(
            model_path=model_path,
            task=args.task,
            n_episodes=args.n_episodes,
            output_dir=output_dir,
            job_name=job_name,
        )

    comparison_path = output_dir / "eval_comparison.json"
    with open(comparison_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    print(f"\n{'='*70}", flush=True)
    print("EVALUATION SUMMARY", flush=True)
    print(f"{'='*70}", flush=True)

    failures = 0
    for job_name, label, _model_path in jobs:
        info = results[job_name]
        if info and "overall" in info:
            overall = info["overall"]
            sr = overall.get("pc_success", "N/A")
            avg_reward = overall.get("avg_sum_reward", "N/A")
            print(f"{label}: success_rate={sr}, avg_reward={avg_reward}")
        elif info and "stdout" in info:
            print(f"{label}: Raw output saved (no parsed results)")
        else:
            failures += 1
            print(f"{label}: FAILED")

    print(f"\nFull results saved to: {comparison_path}")

    # Surface failures through the exit status so callers (CI, job runners)
    # do not mistake a failed evaluation for a successful one.
    return 1 if failures else 0


if __name__ == "__main__":
    sys.exit(main())