# biosim/scripts/visualize.py
#!/usr/bin/env python3
"""Train, evaluate, and visualize REINFORCE vs Naive agent on LabEnv.
Produces a 2x2 figure:
Top-left: Training reward curve (smoothed)
Top-right: Training success-rate curve (smoothed)
Bottom-left: Final comparison bar chart (reward, success%, partial%)
Bottom-right: Single-episode trace showing the RL agent's actions
"""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
from lab_env.env import (
LabEnv,
INITIAL_BUDGET,
ACTION_SETUP_START,
ACTION_SETUP_END,
ACTION_RUN_ASSAY,
ACTION_ORDER_TIPS,
ACTION_ORDER_BUFFER,
ACTION_ORDER_POLYMERASE,
ACTION_WAIT,
ACTION_FINISH,
PRESETS,
)
from agents.naive_agent import NaiveAgent
from agents.rl_agent import ReinforceAgent
def smooth(values: list[float], window: int = 50) -> np.ndarray:
    """Return the moving average of *values* over a fixed-size window.

    When there are fewer samples than the window, no full window can be
    formed, so the raw values are returned unchanged (as an array).
    """
    if len(values) < window:
        return np.array(values)
    weights = np.full(window, 1.0 / window)
    return np.convolve(values, weights, mode="valid")
def run_episode_naive(env: LabEnv, agent: NaiveAgent, seed: int) -> dict:
    """Play one full episode with the naive agent and summarize the outcome.

    Returns a dict with the episode's total reward, success/partial flags
    (derived from the env's final ``best_result``), elapsed minutes, budget
    spent, and the number of env steps taken.
    """
    obs, info = env.reset(seed=seed)
    agent.reset()
    cumulative = 0.0
    step_count = 0
    done = False
    while not done:
        obs, reward, terminated, truncated, info = env.step(agent.select_action(obs))
        cumulative += reward
        step_count += 1
        done = terminated or truncated
    best = info["best_result"]
    return {
        "reward": cumulative,
        "success": best == "success",
        "partial": best == "partial",
        "minutes": info["elapsed_minutes"],
        # Cost is what was consumed out of the fixed starting budget.
        "cost": INITIAL_BUDGET - info["remaining_budget"],
        "steps": step_count,
    }
def trace_rl_episode(env: LabEnv, agent: ReinforceAgent, seed: int) -> list[dict]:
    """Run one episode and return a step-by-step trace for visualization.

    Mirrors the RL agent's policy loop, but records one dict per env.step
    (action kind, display label, assay result if any, reward, and elapsed
    minutes) so the bottom-right subplot can render one bar per action.
    """
    obs, info = env.reset(seed=seed)
    agent.reset()
    trace: list[dict] = []
    # NOTE(review): assumes agent.max_trials >= 1 — otherwise `done`/`trunc`
    # below the loop would be unbound. The caller constructs max_trials=4.
    for trial in range(agent.max_trials):
        # Restock all three consumables when the agent's (private) heuristic
        # says inventory is low — semantics live in ReinforceAgent._inventory_low.
        if agent._inventory_low(obs):
            for act in (ACTION_ORDER_TIPS, ACTION_ORDER_BUFFER, ACTION_ORDER_POLYMERASE):
                obs, rew, done, trunc, info = env.step(act)
                trace.append({"action": "order", "label": "Order", "result": "", "reward": rew, "minutes": info["elapsed_minutes"]})
                if done or trunc:
                    return trace
        # Deterministic (greedy) preset choice so the trace is reproducible.
        preset = agent._select_preset(obs, deterministic=True)
        p = PRESETS[preset]
        label = f"Setup {p['temp']}C/{p['cycles']}cy/{p['ratio'][:4]}"
        obs, rew, done, trunc, info = env.step(ACTION_SETUP_START + preset)
        trace.append({"action": "setup", "label": label, "result": "", "reward": rew, "minutes": info["elapsed_minutes"]})
        if done or trunc:
            return trace
        obs, rew, done, trunc, info = env.step(ACTION_RUN_ASSAY)
        trace.append({"action": "run", "label": "Run assay", "result": info["last_result"], "reward": rew, "minutes": info["elapsed_minutes"]})
        if done or trunc:
            return trace
        # Bank a success immediately rather than spending remaining trials.
        if info.get("best_result") == "success":
            obs, rew, _, _, info = env.step(ACTION_FINISH)
            trace.append({"action": "finish", "label": "Finish", "result": "success", "reward": rew, "minutes": info["elapsed_minutes"]})
            return trace
    # Trial budget exhausted without a success: finish with the best result
    # achieved so far, unless the episode already ended on the last step.
    if not (done or trunc):
        obs, rew, _, _, info = env.step(ACTION_FINISH)
        trace.append({"action": "finish", "label": "Finish", "result": info["best_result"], "reward": rew, "minutes": info["elapsed_minutes"]})
    return trace
def main() -> None:
    """CLI entry point: train the REINFORCE agent, evaluate it against the
    naive baseline, and render the 2x2 summary figure (see module docstring).

    Flags: --train-episodes, --eval-episodes, --seed, and --save (a file
    path; when given, the figure is written to disk instead of shown).
    """
    parser = argparse.ArgumentParser(description="Visualize training & evaluation")
    parser.add_argument("--train-episodes", type=int, default=2000)
    parser.add_argument("--eval-episodes", type=int, default=200)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--save", type=str, default="", help="Save figure to path instead of showing")
    args = parser.parse_args()
    env = LabEnv()
    rl_agent = ReinforceAgent(max_trials=4)
    # ---- Training with metric collection ----
    print(f"Training REINFORCE for {args.train_episodes} episodes...")
    train_rewards: list[float] = []
    train_successes: list[float] = []
    for ep in range(1, args.train_episodes + 1):
        # Vary the seed per episode so training sees fresh env randomness.
        result = rl_agent.run_episode(env, seed=args.seed + ep, train=True)
        train_rewards.append(result["reward"])
        train_successes.append(float(result["success"]))
        if ep % 500 == 0:
            print(f" ...episode {ep}/{args.train_episodes}")
    # ---- Evaluation ----
    print(f"Evaluating both agents for {args.eval_episodes} episodes...")
    # Seed range disjoint from training seeds so evaluation episodes are unseen.
    eval_seed = 999_999
    naive_agent = NaiveAgent(num_trials=3, seed=0)
    # Both agents are evaluated on the same per-episode seeds for a fair pairing.
    rl_eval = [rl_agent.run_episode(env, seed=eval_seed + i, train=False) for i in range(args.eval_episodes)]
    naive_eval = [run_episode_naive(env, naive_agent, seed=eval_seed + i) for i in range(args.eval_episodes)]
    # ---- Episode trace ----
    trace = trace_rl_episode(env, rl_agent, seed=12345)
    env.close()
    # ---- Aggregate ----
    def agg(results):
        # Mean of each metric over the evaluation episodes; success/partial
        # are bool-valued per episode, so their means are fractions in [0, 1].
        n = len(results)
        return {
            "reward": sum(r["reward"] for r in results) / n,
            "success": sum(r["success"] for r in results) / n,
            "partial": sum(r["partial"] for r in results) / n,
            "minutes": sum(r["minutes"] for r in results) / n,
        }
    rl_stats = agg(rl_eval)
    naive_stats = agg(naive_eval)
    # ==================================================================
    # Plot
    # ==================================================================
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.suptitle("SimLab — Lab Automation RL Environment", fontsize=16, fontweight="bold")
    # -- Top-left: reward curve --
    ax = axes[0, 0]
    smoothed_r = smooth(train_rewards, window=50)
    ax.plot(range(len(smoothed_r)), smoothed_r, color="#2563eb", linewidth=1.5)
    # Dashed zero line marks the break-even reward level.
    ax.axhline(y=0, color="gray", linestyle="--", alpha=0.5)
    ax.set_title("Training Reward (smoothed, window=50)")
    ax.set_xlabel("Episode")
    ax.set_ylabel("Total Episode Reward")
    ax.grid(True, alpha=0.3)
    # -- Top-right: success rate curve --
    ax = axes[0, 1]
    # Per-episode success flags smoothed to a rate; x100 converts to percent.
    smoothed_s = smooth(train_successes, window=100) * 100
    ax.plot(range(len(smoothed_s)), smoothed_s, color="#16a34a", linewidth=1.5)
    ax.set_title("Training Success Rate (smoothed, window=100)")
    ax.set_xlabel("Episode")
    ax.set_ylabel("Success %")
    ax.yaxis.set_major_formatter(mticker.PercentFormatter())
    ax.set_ylim(0, 100)
    ax.grid(True, alpha=0.3)
    # -- Bottom-left: comparison bars --
    ax = axes[1, 0]
    # NOTE(review): these metrics mix units (reward, %, minutes) on one axis;
    # intentional here since the chart is read per-pair, not across metrics.
    metrics = ["Avg Reward", "Success %", "Partial %", "Avg Time (min)"]
    rl_vals = [rl_stats["reward"], rl_stats["success"] * 100, rl_stats["partial"] * 100, rl_stats["minutes"]]
    naive_vals = [naive_stats["reward"], naive_stats["success"] * 100, naive_stats["partial"] * 100, naive_stats["minutes"]]
    x = np.arange(len(metrics))
    w = 0.35
    # Side-by-side grouped bars: RL shifted left, naive shifted right.
    bars_rl = ax.bar(x - w / 2, rl_vals, w, label="REINFORCE", color="#2563eb", edgecolor="white")
    bars_naive = ax.bar(x + w / 2, naive_vals, w, label="Naive", color="#f97316", edgecolor="white")
    ax.set_xticks(x)
    ax.set_xticklabels(metrics, fontsize=9)
    ax.set_title("Evaluation Comparison")
    ax.legend()
    ax.grid(True, alpha=0.3, axis="y")
    # Label each bar with its numeric value just above the bar top.
    for bar_group in (bars_rl, bars_naive):
        for bar in bar_group:
            h = bar.get_height()
            ax.annotate(f"{h:.1f}", xy=(bar.get_x() + bar.get_width() / 2, h),
                        xytext=(0, 4), textcoords="offset points",
                        ha="center", va="bottom", fontsize=8)
    # -- Bottom-right: episode trace --
    ax = axes[1, 1]
    if trace:
        y_labels = []
        colors = []
        # Color each step by its assay outcome; gray for non-assay steps.
        for i, step in enumerate(trace):
            y_labels.append(step["label"])
            if step["result"] == "success":
                colors.append("#16a34a")
            elif step["result"] == "partial":
                colors.append("#eab308")
            elif step["result"] == "fail":
                colors.append("#dc2626")
            else:
                colors.append("#6b7280")
        y_pos = np.arange(len(trace))
        # Bar length is cumulative elapsed time at that step, giving a
        # timeline-like view of the episode.
        minutes = [s["minutes"] for s in trace]
        ax.barh(y_pos, minutes, color=colors, edgecolor="white", height=0.6)
        ax.set_yticks(y_pos)
        ax.set_yticklabels(y_labels, fontsize=8)
        # Invert so the first action appears at the top of the chart.
        ax.invert_yaxis()
        ax.set_xlabel("Elapsed Minutes")
        ax.set_title("Single Episode Trace (RL Agent)")
        # Annotate assay outcomes at the end of their bars.
        for i, step in enumerate(trace):
            if step["result"] in ("success", "partial", "fail"):
                ax.annotate(step["result"], xy=(minutes[i], i),
                            xytext=(5, 0), textcoords="offset points",
                            va="center", fontsize=8, fontweight="bold",
                            color=colors[i])
    else:
        ax.text(0.5, 0.5, "No trace data", ha="center", va="center", transform=ax.transAxes)
        ax.set_title("Single Episode Trace (RL Agent)")
    # Reserve the top 5% of the figure for the suptitle.
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    if args.save:
        fig.savefig(args.save, dpi=150, bbox_inches="tight")
        print(f"Saved to {args.save}")
    else:
        plt.show()
    # Print summary
    print()
    print(f" REINFORCE: reward={rl_stats['reward']:.1f} success={rl_stats['success']:.1%} time={rl_stats['minutes']:.0f}m")
    print(f" Naive: reward={naive_stats['reward']:.1f} success={naive_stats['success']:.1%} time={naive_stats['minutes']:.0f}m")
# Run only when executed as a script, not on import.
if __name__ == "__main__":
    main()