| |
| """Run the naive baseline agent on LabEnv and report aggregate metrics.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import sys |
| from pathlib import Path |
|
|
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) |
|
|
| from lab_env.env import LabEnv, INITIAL_BUDGET |
| from agents.naive_agent import NaiveAgent |
|
|
|
|
def run_episode(env: LabEnv, agent: NaiveAgent, seed: int) -> dict:
    """Play a single seeded episode to completion and summarise it.

    The environment is reset with ``seed`` and the agent is reset. Actions
    chosen by ``agent.select_action`` are then fed to ``env.step`` until it
    reports ``terminated`` or ``truncated``.

    Args:
        env: Environment to roll out in. ``step`` is unpacked as
            ``(obs, reward, terminated, truncated, info)``.
        agent: Policy providing ``reset()`` and ``select_action(obs)``.
        seed: Seed passed to ``env.reset``.

    Returns:
        A dict with keys:

        - ``reward``: summed step rewards.
        - ``success`` / ``partial``: whether the final ``info["best_result"]``
          equals ``"success"`` / ``"partial"``.
        - ``minutes``: the final ``info["elapsed_minutes"]``.
        - ``cost``: ``INITIAL_BUDGET`` minus the final
          ``info["remaining_budget"]``.
        - ``steps``: number of ``env.step`` calls.
    """
    observation, _ = env.reset(seed=seed)
    agent.reset()

    episode_return = 0.0
    step_count = 0
    done = False
    # The loop body always runs at least once, so ``last_info`` is bound
    # before it is read below.
    while not done:
        observation, reward, terminated, truncated, last_info = env.step(
            agent.select_action(observation)
        )
        episode_return += reward
        step_count += 1
        done = terminated or truncated

    outcome = last_info["best_result"]
    return {
        "reward": episode_return,
        "success": outcome == "success",
        "partial": outcome == "partial",
        "minutes": last_info["elapsed_minutes"],
        "cost": INITIAL_BUDGET - last_info["remaining_budget"],
        "steps": step_count,
    }
|
|
|
|
def main(argv: list[str] | None = None) -> None:
    """Evaluate the naive baseline agent on LabEnv and print summary metrics.

    Runs ``--episodes`` episodes; episode ``i`` is reset with seed
    ``--seed + i``. Then prints the average reward, success and partial
    rates, average elapsed minutes, average cost, and average step count.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, so calling
            ``main()`` behaves exactly as a command-line invocation.

    Raises:
        SystemExit: Via ``parser.error`` when ``--episodes`` is not a
            positive integer, or on any other argparse usage error.
    """
    parser = argparse.ArgumentParser(description="Naive baseline evaluation")
    parser.add_argument(
        "--episodes",
        type=int,
        default=200,
        help="number of episodes to run (default: %(default)s)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="base seed; episode i uses seed + i (default: %(default)s)",
    )
    args = parser.parse_args(argv)
    # Every average below divides by the episode count. Reject 0 / negative
    # values up front instead of crashing with ZeroDivisionError at the end.
    if args.episodes < 1:
        parser.error("--episodes must be a positive integer")

    env = LabEnv()
    try:
        agent = NaiveAgent(num_trials=3, seed=args.seed)
        # Each episode gets its own deterministic seed derived from --seed.
        results = [
            run_episode(env, agent, seed=args.seed + i)
            for i in range(args.episodes)
        ]
    finally:
        # Release the environment even if an episode raises.
        env.close()

    rewards = [r["reward"] for r in results]
    successes = sum(r["success"] for r in results)
    partials = sum(r["partial"] for r in results)
    minutes = [r["minutes"] for r in results]
    costs = [r["cost"] for r in results]
    steps = [r["steps"] for r in results]
    n = len(results)

    print("=" * 50)
    print(" Naive Baseline Results")
    print("=" * 50)
    print(f" Episodes: {n}")
    print(f" Avg reward: {sum(rewards) / n:8.2f}")
    print(f" Success rate: {successes / n:8.2%}")
    print(f" Partial rate: {partials / n:8.2%}")
    print(f" Avg time (min): {sum(minutes) / n:8.1f}")
    print(f" Avg cost ($): {sum(costs) / n:8.1f}")
    print(f" Avg steps: {sum(steps) / n:8.1f}")
    print("=" * 50)
|
|
|
|
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
|
|