#!/usr/bin/env python3 """Train, evaluate, and visualize REINFORCE vs Naive agent on LabEnv. Produces a 2x2 figure: Top-left: Training reward curve (smoothed) Top-right: Training success-rate curve (smoothed) Bottom-left: Final comparison bar chart (reward, success%, partial%) Bottom-right: Single-episode trace showing the RL agent's actions """ from __future__ import annotations import argparse import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) import numpy as np import matplotlib.pyplot as plt import matplotlib.ticker as mticker from lab_env.env import ( LabEnv, INITIAL_BUDGET, ACTION_SETUP_START, ACTION_SETUP_END, ACTION_RUN_ASSAY, ACTION_ORDER_TIPS, ACTION_ORDER_BUFFER, ACTION_ORDER_POLYMERASE, ACTION_WAIT, ACTION_FINISH, PRESETS, ) from agents.naive_agent import NaiveAgent from agents.rl_agent import ReinforceAgent def smooth(values: list[float], window: int = 50) -> np.ndarray: if len(values) < window: return np.array(values) kernel = np.ones(window) / window return np.convolve(values, kernel, mode="valid") def run_episode_naive(env: LabEnv, agent: NaiveAgent, seed: int) -> dict: obs, info = env.reset(seed=seed) agent.reset() total_reward = 0.0 steps = 0 while True: action = agent.select_action(obs) obs, reward, terminated, truncated, info = env.step(action) total_reward += reward steps += 1 if terminated or truncated: break return { "reward": total_reward, "success": info["best_result"] == "success", "partial": info["best_result"] == "partial", "minutes": info["elapsed_minutes"], "cost": INITIAL_BUDGET - info["remaining_budget"], "steps": steps, } def trace_rl_episode(env: LabEnv, agent: ReinforceAgent, seed: int) -> list[dict]: """Run one episode and return a step-by-step trace for visualization.""" obs, info = env.reset(seed=seed) agent.reset() trace: list[dict] = [] for trial in range(agent.max_trials): if agent._inventory_low(obs): for act in (ACTION_ORDER_TIPS, ACTION_ORDER_BUFFER, ACTION_ORDER_POLYMERASE): obs, rew, done, trunc, info = env.step(act) trace.append({"action": "order", "label": "Order", "result": "", "reward": rew, "minutes": info["elapsed_minutes"]}) if done or trunc: return trace preset = agent._select_preset(obs, deterministic=True) p = PRESETS[preset] label = f"Setup {p['temp']}C/{p['cycles']}cy/{p['ratio'][:4]}" obs, rew, done, trunc, info = env.step(ACTION_SETUP_START + preset) trace.append({"action": "setup", "label": label, "result": "", "reward": rew, "minutes": info["elapsed_minutes"]}) if done or trunc: return trace obs, rew, done, trunc, info = env.step(ACTION_RUN_ASSAY) trace.append({"action": "run", "label": "Run assay", "result": info["last_result"], "reward": rew, "minutes": info["elapsed_minutes"]}) if done or trunc: return trace if info.get("best_result") == "success": obs, rew, _, _, info = env.step(ACTION_FINISH) trace.append({"action": "finish", "label": "Finish", "result": "success", "reward": rew, "minutes": info["elapsed_minutes"]}) return trace if not (done or trunc): obs, rew, _, _, info = env.step(ACTION_FINISH) trace.append({"action": "finish", "label": "Finish", "result": info["best_result"], "reward": rew, "minutes": info["elapsed_minutes"]}) return trace def main() -> None: parser = argparse.ArgumentParser(description="Visualize training & evaluation") parser.add_argument("--train-episodes", type=int, default=2000) parser.add_argument("--eval-episodes", type=int, default=200) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--save", type=str, default="", help="Save figure to path instead of showing") args = parser.parse_args() env = LabEnv() rl_agent = ReinforceAgent(max_trials=4) # ---- Training with metric collection ---- print(f"Training REINFORCE for {args.train_episodes} episodes...") train_rewards: list[float] = [] train_successes: list[float] = [] for ep in range(1, args.train_episodes + 1): result = rl_agent.run_episode(env, seed=args.seed + ep, train=True) train_rewards.append(result["reward"]) train_successes.append(float(result["success"])) if ep % 500 == 0: print(f" ...episode {ep}/{args.train_episodes}") # ---- Evaluation ---- print(f"Evaluating both agents for {args.eval_episodes} episodes...") eval_seed = 999_999 naive_agent = NaiveAgent(num_trials=3, seed=0) rl_eval = [rl_agent.run_episode(env, seed=eval_seed + i, train=False) for i in range(args.eval_episodes)] naive_eval = [run_episode_naive(env, naive_agent, seed=eval_seed + i) for i in range(args.eval_episodes)] # ---- Episode trace ---- trace = trace_rl_episode(env, rl_agent, seed=12345) env.close() # ---- Aggregate ---- def agg(results): n = len(results) return { "reward": sum(r["reward"] for r in results) / n, "success": sum(r["success"] for r in results) / n, "partial": sum(r["partial"] for r in results) / n, "minutes": sum(r["minutes"] for r in results) / n, } rl_stats = agg(rl_eval) naive_stats = agg(naive_eval) # ================================================================== # Plot # ================================================================== fig, axes = plt.subplots(2, 2, figsize=(14, 10)) fig.suptitle("SimLab — Lab Automation RL Environment", fontsize=16, fontweight="bold") # -- Top-left: reward curve -- ax = axes[0, 0] smoothed_r = smooth(train_rewards, window=50) ax.plot(range(len(smoothed_r)), smoothed_r, color="#2563eb", linewidth=1.5) ax.axhline(y=0, color="gray", linestyle="--", alpha=0.5) ax.set_title("Training Reward (smoothed, window=50)") ax.set_xlabel("Episode") ax.set_ylabel("Total Episode Reward") ax.grid(True, alpha=0.3) # -- Top-right: success rate curve -- ax = axes[0, 1] smoothed_s = smooth(train_successes, window=100) * 100 ax.plot(range(len(smoothed_s)), smoothed_s, color="#16a34a", linewidth=1.5) ax.set_title("Training Success Rate (smoothed, window=100)") ax.set_xlabel("Episode") ax.set_ylabel("Success %") ax.yaxis.set_major_formatter(mticker.PercentFormatter()) ax.set_ylim(0, 100) ax.grid(True, alpha=0.3) # -- Bottom-left: comparison bars -- ax = axes[1, 0] metrics = ["Avg Reward", "Success %", "Partial %", "Avg Time (min)"] rl_vals = [rl_stats["reward"], rl_stats["success"] * 100, rl_stats["partial"] * 100, rl_stats["minutes"]] naive_vals = [naive_stats["reward"], naive_stats["success"] * 100, naive_stats["partial"] * 100, naive_stats["minutes"]] x = np.arange(len(metrics)) w = 0.35 bars_rl = ax.bar(x - w / 2, rl_vals, w, label="REINFORCE", color="#2563eb", edgecolor="white") bars_naive = ax.bar(x + w / 2, naive_vals, w, label="Naive", color="#f97316", edgecolor="white") ax.set_xticks(x) ax.set_xticklabels(metrics, fontsize=9) ax.set_title("Evaluation Comparison") ax.legend() ax.grid(True, alpha=0.3, axis="y") for bar_group in (bars_rl, bars_naive): for bar in bar_group: h = bar.get_height() ax.annotate(f"{h:.1f}", xy=(bar.get_x() + bar.get_width() / 2, h), xytext=(0, 4), textcoords="offset points", ha="center", va="bottom", fontsize=8) # -- Bottom-right: episode trace -- ax = axes[1, 1] if trace: y_labels = [] colors = [] for i, step in enumerate(trace): y_labels.append(step["label"]) if step["result"] == "success": colors.append("#16a34a") elif step["result"] == "partial": colors.append("#eab308") elif step["result"] == "fail": colors.append("#dc2626") else: colors.append("#6b7280") y_pos = np.arange(len(trace)) minutes = [s["minutes"] for s in trace] ax.barh(y_pos, minutes, color=colors, edgecolor="white", height=0.6) ax.set_yticks(y_pos) ax.set_yticklabels(y_labels, fontsize=8) ax.invert_yaxis() ax.set_xlabel("Elapsed Minutes") ax.set_title("Single Episode Trace (RL Agent)") for i, step in enumerate(trace): if step["result"] in ("success", "partial", "fail"): ax.annotate(step["result"], xy=(minutes[i], i), xytext=(5, 0), textcoords="offset points", va="center", fontsize=8, fontweight="bold", color=colors[i]) else: ax.text(0.5, 0.5, "No trace data", ha="center", va="center", transform=ax.transAxes) ax.set_title("Single Episode Trace (RL Agent)") plt.tight_layout(rect=[0, 0, 1, 0.95]) if args.save: fig.savefig(args.save, dpi=150, bbox_inches="tight") print(f"Saved to {args.save}") else: plt.show() # Print summary print() print(f" REINFORCE: reward={rl_stats['reward']:.1f} success={rl_stats['success']:.1%} time={rl_stats['minutes']:.0f}m") print(f" Naive: reward={naive_stats['reward']:.1f} success={naive_stats['success']:.1%} time={naive_stats['minutes']:.0f}m") if __name__ == "__main__": main()