biosim / scripts /run_naive_baseline.py
arminfg's picture
SimLab: lab automation RL env, OpenEnv adapter, Training UI, agents
da63ca8
#!/usr/bin/env python3
"""Run the naive baseline agent on LabEnv and report aggregate metrics."""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from lab_env.env import LabEnv, INITIAL_BUDGET
from agents.naive_agent import NaiveAgent
def run_episode(env: LabEnv, agent: NaiveAgent, seed: int) -> dict:
    """Play one seeded episode to completion and summarize it.

    The environment is reset with ``seed`` and the agent is reset, then the
    agent's selected actions are stepped until the environment reports
    ``terminated`` or ``truncated``.

    Args:
        env: Environment to roll out in.
        agent: Agent whose ``select_action`` picks each action.
        seed: Seed forwarded to ``env.reset``.

    Returns:
        A dict with the summed step rewards (``reward``), whether the final
        ``info["best_result"]`` equals ``"success"`` / ``"partial"``, the
        final ``info["elapsed_minutes"]`` (``minutes``), ``INITIAL_BUDGET``
        minus the final ``info["remaining_budget"]`` (``cost``), and the
        number of steps taken (``steps``).
    """
    observation, info = env.reset(seed=seed)
    agent.reset()

    episode_return = 0.0
    step_count = 0
    done = False
    while not done:
        chosen = agent.select_action(observation)
        observation, reward, terminated, truncated, info = env.step(chosen)
        episode_return += reward
        step_count += 1
        done = terminated or truncated

    # Summary fields are read from the info dict of the last step.
    outcome = info["best_result"]
    return {
        "reward": episode_return,
        "success": outcome == "success",
        "partial": outcome == "partial",
        "minutes": info["elapsed_minutes"],
        "cost": INITIAL_BUDGET - info["remaining_budget"],
        "steps": step_count,
    }
def main(argv: list[str] | None = None) -> None:
    """Evaluate the naive baseline agent and print aggregate metrics.

    Runs ``--episodes`` episodes, seeding episode ``i`` with ``--seed + i``,
    then prints the average reward, success/partial rates, and the average
    time, cost, and step count.

    Args:
        argv: Command-line arguments to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Raises:
        SystemExit: If the arguments are invalid, including an ``--episodes``
            value below 1 (the averages would otherwise divide by zero).
    """
    parser = argparse.ArgumentParser(description="Naive baseline evaluation")
    parser.add_argument("--episodes", type=int, default=200)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args(argv)
    # Reject empty runs up front: every metric below divides by the episode
    # count, so zero episodes used to crash with ZeroDivisionError.
    if args.episodes < 1:
        parser.error("--episodes must be at least 1")

    env = LabEnv()
    try:
        agent = NaiveAgent(num_trials=3, seed=args.seed)
        results = [
            run_episode(env, agent, seed=args.seed + i)
            for i in range(args.episodes)
        ]
    finally:
        # Release the environment even if an episode raises.
        env.close()

    rewards = [r["reward"] for r in results]
    successes = sum(r["success"] for r in results)
    partials = sum(r["partial"] for r in results)
    minutes = [r["minutes"] for r in results]
    costs = [r["cost"] for r in results]
    steps = [r["steps"] for r in results]
    n = len(results)

    print("=" * 50)
    print(" Naive Baseline Results")
    print("=" * 50)
    print(f" Episodes: {n}")
    print(f" Avg reward: {sum(rewards) / n:8.2f}")
    print(f" Success rate: {successes / n:8.2%}")
    print(f" Partial rate: {partials / n:8.2%}")
    print(f" Avg time (min): {sum(minutes) / n:8.1f}")
    print(f" Avg cost ($): {sum(costs) / n:8.1f}")
    print(f" Avg steps: {sum(steps) / n:8.1f}")
    print("=" * 50)
# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()