File size: 3,735 Bytes
4ef3086
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""Evaluate baseline and MEM models on RoboMME PickHighlight."""
import argparse
import json
import shlex
import subprocess
import sys
from pathlib import Path


def run_eval(model_path: str, task: str, n_episodes: int, output_dir: Path, job_name: str):
    """Run ``lerobot_eval`` for a single model on a RoboMME task.

    The evaluation is launched as a child process of the current interpreter
    (``python -m lerobot.scripts.lerobot_eval``). Its captured stdout/stderr
    are written to ``stdout.log`` / ``stderr.log`` inside
    ``output_dir / job_name`` and echoed to this process's streams once the
    child has exited.

    Args:
        model_path: Value passed to the child as ``--policy.path``.
        task: Value passed to the child as ``--env.task``.
        n_episodes: Value passed to the child as ``--eval.n_episodes``.
        output_dir: Parent directory; a ``job_name`` subdirectory is created
            in it (including missing parents) and passed as ``--output_dir``.
        job_name: Name of the run; also passed to the child as ``--job_name``.

    Returns:
        ``None`` when the child exits with a non-zero status. Otherwise the
        parsed content of ``eval_info.json`` from the job directory when that
        file exists and holds valid JSON, else a dict
        ``{"stdout": <last 5000 characters of the child's stdout>}``.
    """
    out = output_dir / job_name
    out.mkdir(parents=True, exist_ok=True)

    cmd = [
        sys.executable, "-m", "lerobot.scripts.lerobot_eval",
        f"--policy.path={model_path}",
        "--env.type=robomme",
        f"--env.task={task}",
        "--env.dataset_split=test",
        "--eval.batch_size=1",
        f"--eval.n_episodes={n_episodes}",
        '--rename_map={"observation.images.image":"observation.images.camera1","observation.images.wrist_image":"observation.images.camera2"}',
        "--policy.device=cuda",
        "--policy.use_amp=false",
        f"--output_dir={out}",
        f"--job_name={job_name}",
    ]

    print(f"\n{'='*70}", flush=True)
    print(f"Running eval for: {job_name}", flush=True)
    print(f"Model: {model_path}", flush=True)
    # Quote each argument so the logged command can be pasted into a shell;
    # the rename_map argument contains double quotes and braces.
    print(f"Command: {' '.join(shlex.quote(arg) for arg in cmd)}", flush=True)
    print(f"{'='*70}\n", flush=True)

    # backslashreplace: a stray undecodable byte in the child's output must
    # not raise UnicodeDecodeError and abort the whole comparison.
    result = subprocess.run(
        cmd,
        capture_output=True,
        text=True,
        errors="backslashreplace",
    )

    # Explicit encoding keeps the log files independent of the host locale.
    (out / "stdout.log").write_text(result.stdout, encoding="utf-8")
    (out / "stderr.log").write_text(result.stderr, encoding="utf-8")

    print(result.stdout)
    if result.stderr:
        print(result.stderr, file=sys.stderr)

    if result.returncode != 0:
        print(f"ERROR: Eval failed with code {result.returncode}", file=sys.stderr)
        return None

    eval_info_path = out / "eval_info.json"
    try:
        with open(eval_info_path, encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        # No parsed results were produced; fall back to the raw output below.
        pass
    except (OSError, ValueError) as err:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError.
        # A broken results file should degrade to the raw-output fallback
        # instead of crashing the caller before later evaluations run.
        print(f"WARNING: could not read {eval_info_path}: {err}", file=sys.stderr)

    return {"stdout": result.stdout[-5000:]}


def main():
    """Evaluate the baseline and MEM checkpoints and report both results.

    Command-line options select the two checkpoints, the task, the number of
    episodes and the output directory. Each checkpoint is evaluated in turn
    with ``run_eval``; the collected results are written to
    ``eval_comparison.json`` in the output directory and a short summary is
    printed.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--baseline", default="pepijn223/smolvla_robomme_baseline_pickhighlight_5k_migrated")
    cli.add_argument("--mem", default="pepijn223/smolvla_robomme_mem_k4_pickhighlight_5k_migrated")
    cli.add_argument("--task", default="PickHighlight")
    cli.add_argument("--n_episodes", type=int, default=50)
    cli.add_argument("--output_dir", type=Path, default=Path("/app/outputs"))
    opts = cli.parse_args()

    root = opts.output_dir
    root.mkdir(parents=True, exist_ok=True)

    # One row per evaluation: (results key / job name, summary label, checkpoint).
    runs = (
        ("baseline", "Baseline", opts.baseline),
        ("mem_k4", "MEM K=4", opts.mem),
    )

    results = {}
    for job, _label, checkpoint in runs:
        results[job] = run_eval(
            model_path=checkpoint,
            task=opts.task,
            n_episodes=opts.n_episodes,
            output_dir=root,
            job_name=job,
        )

    comparison_path = root / "eval_comparison.json"
    with open(comparison_path, "w") as f:
        json.dump(results, f, indent=2)

    banner = "=" * 70
    print(f"\n{banner}", flush=True)
    print("EVALUATION SUMMARY", flush=True)
    print(banner, flush=True)

    for job, label, _checkpoint in runs:
        info = results[job]
        if not info:
            # None (the run failed) or an empty result.
            print(f"{label}: FAILED")
            continue
        if "overall" in info:
            overall = info["overall"]
            sr = overall.get("pc_success", "N/A")
            avg_reward = overall.get("avg_sum_reward", "N/A")
            print(f"{label}: success_rate={sr}, avg_reward={avg_reward}")
        elif "stdout" in info:
            print(f"{label}: Raw output saved (no parsed results)")
        else:
            print(f"{label}: FAILED")

    print(f"\nFull results saved to: {comparison_path}")

# Run the comparison only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()