| |
| """Train, evaluate, and visualize REINFORCE vs Naive agent on LabEnv. |
| |
| Produces a 2x2 figure: |
| Top-left: Training reward curve (smoothed) |
| Top-right: Training success-rate curve (smoothed) |
| Bottom-left: Final comparison bar chart (reward, success%, partial%) |
| Bottom-right: Single-episode trace showing the RL agent's actions |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import sys |
| from pathlib import Path |
|
|
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) |
|
|
| import numpy as np |
| import matplotlib.pyplot as plt |
| import matplotlib.ticker as mticker |
|
|
| from lab_env.env import ( |
| LabEnv, |
| INITIAL_BUDGET, |
| ACTION_SETUP_START, |
| ACTION_SETUP_END, |
| ACTION_RUN_ASSAY, |
| ACTION_ORDER_TIPS, |
| ACTION_ORDER_BUFFER, |
| ACTION_ORDER_POLYMERASE, |
| ACTION_WAIT, |
| ACTION_FINISH, |
| PRESETS, |
| ) |
| from agents.naive_agent import NaiveAgent |
| from agents.rl_agent import ReinforceAgent |
|
|
|
|
def smooth(values: list[float], window: int = 50) -> np.ndarray:
    """Return the moving average of *values* over a sliding *window*.

    When there are fewer samples than the window size, the raw values
    are returned unchanged (as an array) instead of an empty result.
    """
    if len(values) < window:
        return np.array(values)
    # Uniform averaging kernel; "valid" mode yields len(values) - window + 1 points.
    kernel = np.full(window, 1.0 / window)
    return np.convolve(values, kernel, mode="valid")
|
|
|
|
def run_episode_naive(env: LabEnv, agent: NaiveAgent, seed: int) -> dict:
    """Roll out one complete episode with the naive agent.

    Returns a summary dict with the total reward, success/partial flags,
    elapsed minutes, budget spent, and step count.
    """
    obs, info = env.reset(seed=seed)
    agent.reset()

    total_reward = 0.0
    steps = 0
    done = False
    while not done:
        obs, reward, terminated, truncated, info = env.step(agent.select_action(obs))
        total_reward += reward
        steps += 1
        done = terminated or truncated

    best = info["best_result"]
    return {
        "reward": total_reward,
        "success": best == "success",
        "partial": best == "partial",
        "minutes": info["elapsed_minutes"],
        # Budget is drawn down during the episode; spend = initial - remaining.
        "cost": INITIAL_BUDGET - info["remaining_budget"],
        "steps": steps,
    }
|
|
|
|
def trace_rl_episode(env: LabEnv, agent: ReinforceAgent, seed: int) -> list[dict]:
    """Run one greedy (deterministic) episode and return a per-step trace.

    Each trace entry is a dict with keys ``action``, ``label``, ``result``,
    ``reward``, and ``minutes`` — consumed by the bottom-right panel of the
    visualization.  Mirrors the agent's own policy loop (restock -> setup ->
    assay -> finish) while recording every environment step.
    """
    obs, info = env.reset(seed=seed)
    agent.reset()
    trace: list[dict] = []
    # Initialize termination flags up front: previously these were only bound
    # inside the trial loop, so max_trials == 0 raised NameError at the final
    # finish check below.
    done = trunc = False

    for _trial in range(agent.max_trials):
        # Restock consumables first when the agent's heuristic flags low stock.
        if agent._inventory_low(obs):
            for act in (ACTION_ORDER_TIPS, ACTION_ORDER_BUFFER, ACTION_ORDER_POLYMERASE):
                obs, rew, done, trunc, info = env.step(act)
                trace.append({"action": "order", "label": "Order", "result": "", "reward": rew, "minutes": info["elapsed_minutes"]})
                # Check after every order so we never step a finished episode.
                if done or trunc:
                    return trace

        # Choose a setup preset greedily and describe it for the plot label.
        preset = agent._select_preset(obs, deterministic=True)
        p = PRESETS[preset]
        label = f"Setup {p['temp']}C/{p['cycles']}cy/{p['ratio'][:4]}"

        obs, rew, done, trunc, info = env.step(ACTION_SETUP_START + preset)
        trace.append({"action": "setup", "label": label, "result": "", "reward": rew, "minutes": info["elapsed_minutes"]})
        if done or trunc:
            return trace

        obs, rew, done, trunc, info = env.step(ACTION_RUN_ASSAY)
        trace.append({"action": "run", "label": "Run assay", "result": info["last_result"], "reward": rew, "minutes": info["elapsed_minutes"]})
        if done or trunc:
            return trace

        # Success: cash out immediately rather than burning more trials.
        if info.get("best_result") == "success":
            obs, rew, _, _, info = env.step(ACTION_FINISH)
            trace.append({"action": "finish", "label": "Finish", "result": "success", "reward": rew, "minutes": info["elapsed_minutes"]})
            return trace

    # Out of trials without a success: finish explicitly if the episode is
    # still live so the trace ends with the terminal reward.
    if not (done or trunc):
        obs, rew, _, _, info = env.step(ACTION_FINISH)
        trace.append({"action": "finish", "label": "Finish", "result": info["best_result"], "reward": rew, "minutes": info["elapsed_minutes"]})

    return trace
|
|
|
|
def main() -> None:
    """CLI entry point: train the REINFORCE agent, evaluate both agents,
    and render a 2x2 summary figure (see module docstring for panel layout)."""
    parser = argparse.ArgumentParser(description="Visualize training & evaluation")
    parser.add_argument("--train-episodes", type=int, default=2000)
    parser.add_argument("--eval-episodes", type=int, default=200)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--save", type=str, default="", help="Save figure to path instead of showing")
    args = parser.parse_args()

    env = LabEnv()
    rl_agent = ReinforceAgent(max_trials=4)

    # ---- Training phase: one training update per episode. ----
    print(f"Training REINFORCE for {args.train_episodes} episodes...")
    train_rewards: list[float] = []
    train_successes: list[float] = []

    for ep in range(1, args.train_episodes + 1):
        # Distinct seed per episode so the agent sees varied initial states.
        result = rl_agent.run_episode(env, seed=args.seed + ep, train=True)
        train_rewards.append(result["reward"])
        train_successes.append(float(result["success"]))
        if ep % 500 == 0:
            print(f" ...episode {ep}/{args.train_episodes}")

    # ---- Evaluation phase: seed block chosen far from the training seeds
    # so both agents are scored on unseen episode configurations. ----
    print(f"Evaluating both agents for {args.eval_episodes} episodes...")
    eval_seed = 999_999
    naive_agent = NaiveAgent(num_trials=3, seed=0)

    # Both agents are evaluated on the SAME seed sequence for a fair comparison.
    rl_eval = [rl_agent.run_episode(env, seed=eval_seed + i, train=False) for i in range(args.eval_episodes)]
    naive_eval = [run_episode_naive(env, naive_agent, seed=eval_seed + i) for i in range(args.eval_episodes)]

    # Single deterministic rollout for the bottom-right trace panel.
    trace = trace_rl_episode(env, rl_agent, seed=12345)

    env.close()

    def agg(results):
        # Mean of each metric across the evaluation episodes.
        n = len(results)
        return {
            "reward": sum(r["reward"] for r in results) / n,
            "success": sum(r["success"] for r in results) / n,
            "partial": sum(r["partial"] for r in results) / n,
            "minutes": sum(r["minutes"] for r in results) / n,
        }

    rl_stats = agg(rl_eval)
    naive_stats = agg(naive_eval)

    # ---- Figure assembly: 2x2 grid as described in the module docstring. ----
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.suptitle("SimLab — Lab Automation RL Environment", fontsize=16, fontweight="bold")

    # Top-left: smoothed training reward curve.
    ax = axes[0, 0]
    smoothed_r = smooth(train_rewards, window=50)
    ax.plot(range(len(smoothed_r)), smoothed_r, color="#2563eb", linewidth=1.5)
    ax.axhline(y=0, color="gray", linestyle="--", alpha=0.5)
    ax.set_title("Training Reward (smoothed, window=50)")
    ax.set_xlabel("Episode")
    ax.set_ylabel("Total Episode Reward")
    ax.grid(True, alpha=0.3)

    # Top-right: smoothed success rate, scaled from [0,1] to percent.
    ax = axes[0, 1]
    smoothed_s = smooth(train_successes, window=100) * 100
    ax.plot(range(len(smoothed_s)), smoothed_s, color="#16a34a", linewidth=1.5)
    ax.set_title("Training Success Rate (smoothed, window=100)")
    ax.set_xlabel("Episode")
    ax.set_ylabel("Success %")
    ax.yaxis.set_major_formatter(mticker.PercentFormatter())
    ax.set_ylim(0, 100)
    ax.grid(True, alpha=0.3)

    # Bottom-left: grouped bar chart comparing the two agents on four metrics.
    ax = axes[1, 0]
    metrics = ["Avg Reward", "Success %", "Partial %", "Avg Time (min)"]
    rl_vals = [rl_stats["reward"], rl_stats["success"] * 100, rl_stats["partial"] * 100, rl_stats["minutes"]]
    naive_vals = [naive_stats["reward"], naive_stats["success"] * 100, naive_stats["partial"] * 100, naive_stats["minutes"]]

    x = np.arange(len(metrics))
    w = 0.35  # bar width; the two groups are offset by half a width each side
    bars_rl = ax.bar(x - w / 2, rl_vals, w, label="REINFORCE", color="#2563eb", edgecolor="white")
    bars_naive = ax.bar(x + w / 2, naive_vals, w, label="Naive", color="#f97316", edgecolor="white")
    ax.set_xticks(x)
    ax.set_xticklabels(metrics, fontsize=9)
    ax.set_title("Evaluation Comparison")
    ax.legend()
    ax.grid(True, alpha=0.3, axis="y")

    # Value labels just above each bar.
    for bar_group in (bars_rl, bars_naive):
        for bar in bar_group:
            h = bar.get_height()
            ax.annotate(f"{h:.1f}", xy=(bar.get_x() + bar.get_width() / 2, h),
                        xytext=(0, 4), textcoords="offset points",
                        ha="center", va="bottom", fontsize=8)

    # Bottom-right: horizontal timeline of the single traced episode, with
    # bars colored by assay outcome (green/yellow/red) or gray for non-assay steps.
    ax = axes[1, 1]
    if trace:
        y_labels = []
        colors = []
        for i, step in enumerate(trace):
            y_labels.append(step["label"])
            if step["result"] == "success":
                colors.append("#16a34a")
            elif step["result"] == "partial":
                colors.append("#eab308")
            elif step["result"] == "fail":
                colors.append("#dc2626")
            else:
                colors.append("#6b7280")

        y_pos = np.arange(len(trace))
        minutes = [s["minutes"] for s in trace]
        ax.barh(y_pos, minutes, color=colors, edgecolor="white", height=0.6)
        ax.set_yticks(y_pos)
        ax.set_yticklabels(y_labels, fontsize=8)
        # First step at the top, later steps below (chronological top-to-bottom).
        ax.invert_yaxis()
        ax.set_xlabel("Elapsed Minutes")
        ax.set_title("Single Episode Trace (RL Agent)")

        # Annotate outcome text at the end of each assay bar.
        for i, step in enumerate(trace):
            if step["result"] in ("success", "partial", "fail"):
                ax.annotate(step["result"], xy=(minutes[i], i),
                            xytext=(5, 0), textcoords="offset points",
                            va="center", fontsize=8, fontweight="bold",
                            color=colors[i])
    else:
        ax.text(0.5, 0.5, "No trace data", ha="center", va="center", transform=ax.transAxes)
        ax.set_title("Single Episode Trace (RL Agent)")

    # Leave headroom at the top for the suptitle.
    plt.tight_layout(rect=[0, 0, 1, 0.95])

    if args.save:
        fig.savefig(args.save, dpi=150, bbox_inches="tight")
        print(f"Saved to {args.save}")
    else:
        plt.show()

    # Console summary of the evaluation results.
    print()
    print(f" REINFORCE: reward={rl_stats['reward']:.1f} success={rl_stats['success']:.1%} time={rl_stats['minutes']:.0f}m")
    print(f" Naive: reward={naive_stats['reward']:.1f} success={naive_stats['success']:.1%} time={naive_stats['minutes']:.0f}m")
|
|
|
|
# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    main()
|
|