biosim / server /app.py
arminfg's picture
SimLab: lab automation RL env, OpenEnv adapter, Training UI, agents
da63ca8
"""
FastAPI server bridging the LabEnv Python backend to the Next.js frontend.

Endpoints:
    POST /api/training/start        — train the agent (SSE stream)
    POST /api/run/ai                — run one AI-agent episode
    POST /api/run/naive             — run one naive-agent episode
    POST /api/run/research          — run one Research LLM agent episode (PCR only)
    POST /api/run/research-generate — run one Research & Generate agent episode
    POST /api/env/reset             — reset environment
    POST /api/env/step              — take one step
    GET  /api/stats                 — dashboard aggregate stats
"""
from __future__ import annotations
import json
import sys
import time
from pathlib import Path
from typing import Any
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from lab_env.env import (
LabEnv,
INITIAL_BUDGET,
ACTION_SETUP_START,
ACTION_RUN_ASSAY,
ACTION_ORDER_TIPS,
ACTION_ORDER_BUFFER,
ACTION_ORDER_POLYMERASE,
ACTION_WAIT,
ACTION_FINISH,
)
from lab_env.spec import pcr_experiment_spec, get_spec_for_workflow
from agents.naive_agent import NaiveAgent
from agents.rl_agent import ReinforceAgent
# Per-workflow envs (created on first use). RL agent is shared and trained on PCR.
# Keyed by workflow_id and filled lazily by _get_env().
_envs: dict[str, LabEnv] = {}

# Optional agents. Each import is attempted once at module load; when it fails
# the class name is bound to None and the matching HAS_* flag is False. The
# _trace_research* helpers check the flag before instantiating the agent.
try:
    from agents.research_llm_agent import ResearchLLMAgent
    HAS_RESEARCH_AGENT = True
except ImportError:
    ResearchLLMAgent = None
    HAS_RESEARCH_AGENT = False
try:
    from agents.research_generate_agent import ResearchGenerateAgent
    HAS_RESEARCH_GENERATE_AGENT = True
except ImportError:
    ResearchGenerateAgent = None
    HAS_RESEARCH_GENERATE_AGENT = False

app = FastAPI(title="SimLab API")
# NOTE(review): CORS is fully open (any origin, method and header) — presumably
# for local development against the frontend; confirm before exposing publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Most recently trained (or lazily created) agent; /api/run/ai falls back to it
# when no agent is registered for the requested workflow.
rl_agent: ReinforceAgent | None = None
_trained_agents: dict[str, ReinforceAgent] = {} # workflow_id -> agent (for UI per-protocol training)
# Entries are read by /api/stats via the "best_result" and "cost" keys.
run_history: list[dict] = []
def _get_env(workflow_id: str) -> LabEnv:
    """Return the cached LabEnv for *workflow_id*, building it on first request.

    A missing entry is created from ``get_spec_for_workflow(workflow_id)`` and
    stored in the module-level ``_envs`` cache so later calls reuse the same
    environment instance.
    """
    try:
        return _envs[workflow_id]
    except KeyError:
        pass  # not cached yet — build it below, outside the handler
    fresh_env = LabEnv(spec=get_spec_for_workflow(workflow_id))
    _envs[workflow_id] = fresh_env
    return fresh_env
# ──────────────────────────────────────────────
# Request / response models
# ──────────────────────────────────────────────
# Request body for POST /api/training/start.
class TrainRequest(BaseModel):
    episodes: int = 2000  # number of training episodes to run
    lr: float = 3e-3  # forwarded to ReinforceAgent(lr=...)
    max_trials: int = 4  # forwarded to ReinforceAgent(max_trials=...)
    eval_episodes: int = 100  # evaluation episodes run per agent after training
    workflow_id: str = "pcr-amplification"  # selects the spec via get_spec_for_workflow
# Request body for POST /api/env/step.
class StepRequest(BaseModel):
    action: int  # passed unchanged to env.step()
    workflow_id: str = "pcr-amplification"  # selects which cached env is stepped
# Request body shared by the /api/run/* endpoints and POST /api/env/reset.
class RunRequest(BaseModel):
    seed: int = 42  # forwarded as the seed for env.reset() / the agent episode
    workflow_id: str = "pcr-amplification"  # selects which cached env is used
# ──────────────────────────────────────────────
# Helpers
# ──────────────────────────────────────────────
def _env_state_dict(env: LabEnv) -> dict[str, Any]:
    """Return the subset of ``env._info()`` the step-by-step UI needs.

    The time and budget ceilings are read from the environment's spec
    (``max_minutes`` / ``initial_budget``), falling back to 240 and 500 when
    the spec does not define them. This matches what ``_build_run_result``
    reports, so both views agree for workflows whose limits differ from the
    defaults.
    """
    info = env._info()
    # Tolerate an env without a spec: getattr on None yields the fallbacks.
    spec = getattr(env, "spec", None)
    return {
        "step_index": info["step_index"],
        "elapsed_minutes": info["elapsed_minutes"],
        "remaining_budget": info["remaining_budget"],
        "inventory": info["inventory"],
        "last_result": info["last_result"],
        "best_result": info["best_result"],
        "max_time": getattr(spec, "max_minutes", 240),
        "max_budget": getattr(spec, "initial_budget", 500),
    }
def _trace_episode(env: LabEnv, agent: ReinforceAgent, seed: int) -> dict:
    """Replay one deterministic RL-agent episode and return the UI run payload.

    For up to ``agent.max_trials`` trials the agent optionally restocks
    reagents, sets up the preset it selects, and runs the assay. The episode
    stops as soon as the environment reports termination/truncation or the
    best result becomes "success"; otherwise ACTION_FINISH is issued after the
    last trial. Every environment step is mirrored as one timeline entry.
    """
    preset_specs = env.spec.presets
    obs, info = env.reset(seed=seed)
    agent.reset()
    events: list[dict] = []
    outcomes: dict[int, str] = {}

    def log(snapshot: dict, title: str, description: str, status: Any, icon: str) -> None:
        # The time stamp comes from the info dict returned by the step being logged.
        events.append({
            "title": title,
            "description": description,
            "time": f"{snapshot['elapsed_minutes']:.0f} min",
            "status": status,
            "icon": icon,
        })

    restock_actions = (ACTION_ORDER_TIPS, ACTION_ORDER_BUFFER, ACTION_ORDER_POLYMERASE)

    for _ in range(agent.max_trials):
        # Restock all three reagents whenever the agent considers stock low.
        if agent._inventory_low(obs):
            for order in restock_actions:
                obs, _, ended, cut_short, info = env.step(order)
                log(info, "Order Reagents", _order_label(order), "action", "order")
                if ended or cut_short:
                    return _build_run_result(env, info, events, outcomes)

        choice = agent._select_preset(obs, deterministic=True)
        setup_text = _preset_label(preset_specs[choice])
        obs, _, ended, cut_short, info = env.step(ACTION_SETUP_START + choice)
        log(info, "Setup", setup_text, "pending", "setup")
        if ended or cut_short:
            return _build_run_result(env, info, events, outcomes)

        obs, _, ended, cut_short, info = env.step(ACTION_RUN_ASSAY)
        assay_result = info["last_result"]
        outcomes[choice] = assay_result
        log(info, "Run Assay", _result_description(assay_result), assay_result, "run")
        if ended or cut_short:
            return _build_run_result(env, info, events, outcomes)

        # Stop early once a success has been recorded.
        if info.get("best_result") == "success":
            obs, _, _, _, info = env.step(ACTION_FINISH)
            log(info, "Finish", "Experiment complete — success!", "success", "finish")
            return _build_run_result(env, info, events, outcomes)

    # Trials exhausted without a success: finish with whatever was best.
    obs, _, _, _, info = env.step(ACTION_FINISH)
    best = info["best_result"]
    final_status = best if best in ("success", "partial") else "fail"
    log(info, "Finish", f"Experiment complete — best: {info['best_result']}", final_status, "finish")
    return _build_run_result(env, info, events, outcomes)
def _trace_naive_episode(env: LabEnv, agent: NaiveAgent, seed: int) -> dict:
    """Run one naive-agent episode and return the UI run payload.

    The agent picks every action via ``agent.select_action(obs)`` until the
    environment reports termination or truncation. Each recognised action
    (setup, run assay, reagent order, finish) is mirrored as a timeline entry.

    The preset chosen by the most recent setup action is remembered so the
    following assay result can be attributed to it; without this the preset
    table in the result would list every preset as "untried".
    """
    presets = env.spec.presets
    num_presets = len(presets)
    obs, info = env.reset(seed=seed)
    agent.reset()
    timeline: list[dict] = []
    presets_tried: dict[int, str] = {}
    active_preset: int | None = None  # index set up by the latest setup action

    while True:
        action = agent.select_action(obs)
        obs, _reward, done, trunc, info = env.step(action)
        stamp = f"{info['elapsed_minutes']:.0f} min"

        if ACTION_SETUP_START <= action < ACTION_SETUP_START + num_presets:
            # int() keeps the key a plain Python int even if the agent
            # returns a NumPy integer.
            active_preset = int(action - ACTION_SETUP_START)
            timeline.append({
                "title": "Setup",
                "description": _preset_label(presets[active_preset]),
                "time": stamp,
                "status": "pending",
                "icon": "setup",
            })
        elif action == ACTION_RUN_ASSAY:
            result = info["last_result"]
            if active_preset is not None:
                presets_tried[active_preset] = result
            timeline.append({
                "title": "Run Assay",
                "description": _result_description(result),
                "time": stamp,
                "status": result,
                "icon": "run",
            })
        elif action in (ACTION_ORDER_TIPS, ACTION_ORDER_BUFFER, ACTION_ORDER_POLYMERASE):
            timeline.append({
                "title": "Order Reagents",
                "description": _order_label(action),
                "time": stamp,
                "status": "action",
                "icon": "order",
            })
        elif action == ACTION_FINISH:
            best = info["best_result"]
            timeline.append({
                "title": "Finish",
                "description": f"Experiment complete — best: {best}",
                "time": stamp,
                "status": best if best in ("success", "partial") else "fail",
                "icon": "finish",
            })

        if done or trunc:
            break

    return _build_run_result(env, info, timeline, presets_tried)
def _build_run_result(env: LabEnv, info: dict, timeline: list[dict], presets_tried: dict[int, str]) -> dict:
    """Assemble the JSON payload returned by the /api/run/* endpoints.

    Args:
        env: Environment whose spec supplies the preset list and limits.
        info: Info dict from the last environment step (or ``env._info()``).
        timeline: Ordered timeline entries collected during the episode.
        presets_tried: Maps preset index to the result obtained with it;
            presets missing from the map are reported as "untried".

    Returns:
        A dict with "state", "timeline", "presets", "reward" and
        "best_result" keys.
    """
    spec = env.spec
    preset_statuses = []
    for i, p in enumerate(spec.presets):
        row: dict[str, Any] = {
            "id": str(i),
            "status": presets_tried.get(i, "untried"),
            "label": _preset_label(p),
        }
        if "temp" in p:
            row["temp"] = p["temp"]
            # Copy "cycles"/"ratio" only when the preset defines them: a
            # preset may carry "temp" without them (the coating_hr branch of
            # _preset_label reads "temp" too), which used to raise KeyError.
            for key in ("cycles", "ratio"):
                if key in p:
                    row[key] = p[key]
        if "coating_hr" in p:
            row["coating_hr"] = p["coating_hr"]
            row["block"] = p.get("block", "")
        preset_statuses.append(row)
    return {
        "state": {
            "elapsed_minutes": info["elapsed_minutes"],
            "remaining_budget": info["remaining_budget"],
            "inventory": info["inventory"],
            "best_result": info["best_result"],
            "max_time": getattr(spec, "max_minutes", 240),
            "max_budget": getattr(spec, "initial_budget", 500),
        },
        "timeline": timeline,
        "presets": preset_statuses,
        # NOTE(review): despite the key name this is the budget spent
        # (INITIAL_BUDGET minus remaining budget), not the env reward signal.
        "reward": float(INITIAL_BUDGET - info["remaining_budget"]),
        "best_result": info["best_result"],
    }
def _result_description(result: str) -> str:
    """Translate an assay result code into its timeline text.

    Codes other than "success", "partial" and "fail" are returned unchanged.
    """
    descriptions = {
        "success": "Success!",
        "partial": "Partial — low yield",
        "fail": "Failed — no amplification",
    }
    try:
        return descriptions[result]
    except KeyError:
        return result
def _order_label(action: int) -> str:
    """Return the timeline text for a reagent-order action.

    Actions other than the three ACTION_ORDER_* constants yield "reagents".
    """
    labels = {
        ACTION_ORDER_TIPS: "+5 tips",
        ACTION_ORDER_BUFFER: "+5 buffer",
        ACTION_ORDER_POLYMERASE: "+3 polymerase",
    }
    try:
        return labels[action]
    except KeyError:
        return "reagents"
def _preset_label(preset: dict) -> str:
    """Human-readable preset description for timeline/UI (PCR or ELISA).

    A preset containing "coating_hr" is rendered in the ELISA layout
    (coating time / temperature / blocking agent); any other preset uses the
    PCR layout (temperature / cycles / ratio). A missing temperature is shown
    as "?" in both layouts, so building a label never raises for an
    incomplete preset.
    """
    if "coating_hr" in preset:
        return f"{preset['coating_hr']}hr coat / {preset.get('temp', '?')}°C / {preset.get('block', '')}"
    return f"{preset.get('temp', '?')}°C / {preset.get('cycles', '?')} cyc / {preset.get('ratio', '?')}"
def _trace_research_episode(env: LabEnv, seed: int, max_trials: int = 5) -> dict:
    """Run Research LLM agent episode and build timeline (Research → Hypothesis → Experiment → Learn). PCR only.

    When the agent could not be imported, or the env's spec is not named
    "pcr", a result holding a single "fail" timeline entry is returned and no
    episode is run. Otherwise every step the agent reports through its
    callback list becomes four timeline entries. All entries are stamped with
    the elapsed time read from the environment after the episode has finished.
    """
    presets = env.spec.presets

    def refusal(title: str, description: str) -> dict:
        notice = {"title": title, "description": description, "time": "0 min", "status": "fail", "icon": "run"}
        return _build_run_result(env, env._info(), [notice], {})

    if not HAS_RESEARCH_AGENT:
        return refusal("Research agent unavailable", "Install openai and set OPENAI_API_KEY")
    if env.spec.name != "pcr":
        return refusal("Research agent", "Research agent is only supported for PCR workflow.")

    agent = ResearchLLMAgent(max_trials=max_trials)
    steps: list[dict] = []
    agent.run_episode(env, seed=seed, episode_callback=steps)
    info = env._info()
    events: list[dict] = []
    outcomes: dict[int, str] = {}

    def log(title: str, description: str, status: Any, icon: str) -> None:
        events.append({
            "title": title,
            "description": description,
            "time": f"{info.get('elapsed_minutes', 0):.0f} min",
            "status": status,
            "icon": icon,
        })

    for step in steps:
        # Show at most 200 characters of research text, marking truncation.
        raw_research = step.get("research") or ""
        summary = raw_research[:200]
        if len(raw_research) > 200:
            summary += "..."
        log("Research", summary or "Literature search for PCR protocol", "action", "research")

        hyp = step.get("hypothesis") or {}
        log(
            "Hypothesis",
            f"temp={hyp.get('temp', '?')}°C, cycles={hyp.get('cycles', '?')}, ratio={hyp.get('ratio', '?')}",
            "pending",
            "hypothesis",
        )

        used = step.get("params_used") or {}
        outcome = step.get("result", "fail")
        log("Run Assay", _result_description(outcome), outcome, "run")

        # Attribute the outcome to the first preset whose parameters match.
        matched = next(
            (
                idx
                for idx, candidate in enumerate(presets)
                if candidate.get("temp") == used.get("temp")
                and candidate.get("cycles") == used.get("cycles")
                and candidate.get("ratio") == used.get("ratio")
            ),
            None,
        )
        if matched is not None:
            outcomes[matched] = outcome

        log(
            "Learn",
            f"temp_range={agent.knowledge.get('temp_range', [])}, cycle_range={agent.knowledge.get('cycle_range', [])}",
            "action",
            "learn",
        )

    return _build_run_result(env, info, events, outcomes)
def _protocol_dict_label(protocol: dict) -> str:
    """Describe a generated protocol dict for display (PCR or ELISA).

    A protocol containing "coating_hr" uses the ELISA layout; anything else
    uses the PCR layout. Missing fields are rendered as "?".
    """
    def shown(key: str):
        return protocol.get(key, '?')

    if "coating_hr" in protocol:
        return f"{shown('coating_hr')}hr / {shown('temp')}°C / {shown('block')}"
    return f"{shown('temp')}°C / {shown('cycles')} cyc / {shown('ratio')}"
def _trace_research_generate_episode(env: LabEnv, seed: int, max_trials: int = 6) -> dict:
    """Run Research & Generate agent (research → generate any protocol → run → learn). Works for PCR, ELISA, etc.

    When the agent could not be imported, or the spec has no
    ``evaluate_custom_protocol``, a result holding a single "fail" timeline
    entry is returned and no episode is run. Otherwise each entry of
    ``agent.feedback_history`` becomes two timeline entries plus one row in
    the "presets" table (generated protocols take the place of spec presets).
    All entries are stamped with the elapsed time read after the episode.
    """
    def refusal(title: str, description: str) -> dict:
        notice = {"title": title, "description": description, "time": "0 min", "status": "fail", "icon": "run"}
        return _build_run_result(env, env._info(), [notice], {})

    if not HAS_RESEARCH_GENERATE_AGENT:
        return refusal("Research & Generate agent unavailable", "Install openai and set OPENAI_API_KEY")
    if env.spec.evaluate_custom_protocol is None:
        return refusal("Research & Generate", "This workflow does not support custom protocols.")

    agent = ResearchGenerateAgent(max_trials=max_trials)
    agent.run_episode(env, seed=seed, verbose=False)
    info = env._info()
    events: list[dict] = []
    rows: list[dict[str, Any]] = []

    def log(title: str, description: str, status: Any, icon: str) -> None:
        events.append({
            "title": title,
            "description": description,
            "time": f"{info.get('elapsed_minutes', 0):.0f} min",
            "status": status,
            "icon": icon,
        })

    for position, attempt in enumerate(agent.feedback_history):
        protocol = attempt.get("protocol", {})
        outcome = attempt.get("result", "fail")
        label = _protocol_dict_label(protocol)
        log("Research & Generate", f"Generated: {label}", "pending", "research")
        log("Run Assay", _result_description(outcome), outcome, "run")

        row: dict[str, Any] = {"id": str(position), "status": outcome, "label": label}
        if "temp" in protocol:
            row.update(
                temp=protocol.get("temp"),
                cycles=protocol.get("cycles"),
                ratio=protocol.get("ratio", ""),
            )
        if "coating_hr" in protocol:
            row.update(
                coating_hr=protocol.get("coating_hr"),
                block=protocol.get("block", ""),
            )
        rows.append(row)

    finish_text = f"Best result: {info.get('best_result', 'none')}"
    finish_status = info["best_result"] if info["best_result"] in ("success", "partial") else "fail"
    log("Finish", finish_text, finish_status, "finish")

    state = {
        "elapsed_minutes": info["elapsed_minutes"],
        "remaining_budget": info["remaining_budget"],
        "inventory": info["inventory"],
        "best_result": info["best_result"],
        "max_time": getattr(env.spec, "max_minutes", 240),
        "max_budget": getattr(env.spec, "initial_budget", 500),
    }
    return {
        "state": state,
        "timeline": events,
        "presets": rows,
        "reward": float(INITIAL_BUDGET - info["remaining_budget"]),
        "best_result": info["best_result"],
    }
# ──────────────────────────────────────────────
# Training endpoint (SSE stream)
# ──────────────────────────────────────────────
@app.post("/api/training/start")
async def training_start(req: TrainRequest):
    """Train a ReinforceAgent for ``req.workflow_id`` and stream progress as SSE.

    The response body is a ``text/event-stream``. While training runs it
    carries periodic ``{"type": "progress", ...}`` events holding the
    windowed average reward and success rate. After training the new agent is
    stored in ``rl_agent`` and ``_trained_agents[req.workflow_id]``, both the
    agent and a NaiveAgent are evaluated on the same seeds, and one final
    ``{"type": "done", ...}`` event carries the comparison table.
    """

    def generate():
        global rl_agent
        spec = get_spec_for_workflow(req.workflow_id)
        agent = ReinforceAgent(lr=req.lr, max_trials=req.max_trials, spec=spec)
        train_env = LabEnv(spec=spec)
        window_rewards: list[float] = []
        window_successes: list[float] = []
        chart_data: list[dict] = []
        # Roughly 40 progress events per run, never more often than every 10 episodes.
        log_interval = max(req.episodes // 40, 10)

        # try/finally closes the environment even when training raises or the
        # generator is closed early (e.g. the client drops the stream).
        try:
            for ep in range(1, req.episodes + 1):
                result = agent.run_episode(train_env, seed=42 + ep, train=True)
                window_rewards.append(result["reward"])
                window_successes.append(float(result["success"]))
                if ep % log_interval == 0 or ep == req.episodes:
                    avg_reward = sum(window_rewards) / len(window_rewards)
                    avg_success = sum(window_successes) / len(window_successes) * 100
                    chart_data.append({
                        "episode": ep,
                        "reward": round(avg_reward, 2),
                        "successRate": round(avg_success, 1),
                    })
                    progress = round(ep / req.episodes * 100)
                    event = {
                        "type": "progress",
                        "episode": ep,
                        "total": req.episodes,
                        "progress": progress,
                        "reward": round(avg_reward, 2),
                        "successRate": round(avg_success, 1),
                        "chartData": chart_data,
                    }
                    yield f"data: {json.dumps(event)}\n\n"
                    window_rewards.clear()
                    window_successes.clear()

            # Publish the trained agent before evaluation so run endpoints can use it.
            rl_agent = agent
            _trained_agents[req.workflow_id] = agent

            # Evaluate both agents on the same seeds, disjoint from the training seeds.
            eval_seed = 999_999
            rl_results = [
                agent.run_episode(train_env, seed=eval_seed + i, train=False)
                for i in range(req.eval_episodes)
            ]
            naive = NaiveAgent(num_trials=3, seed=0)
            naive_results = []
            for i in range(req.eval_episodes):
                obs, info = train_env.reset(seed=eval_seed + i)
                naive.reset()
                total_r = 0.0
                while True:
                    a = naive.select_action(obs)
                    obs, r, d, t, info = train_env.step(a)
                    total_r += r
                    if d or t:
                        break
                naive_results.append({
                    "reward": total_r,
                    "success": info["best_result"] == "success",
                    "partial": info["best_result"] == "partial",
                    "minutes": info["elapsed_minutes"],
                    # Same budget baseline as _build_run_result uses.
                    "cost": float(INITIAL_BUDGET) - info["remaining_budget"],
                })
        finally:
            train_env.close()

        def agg(res, n):
            """Average each metric over *n* results; all zeros when n is 0."""
            if n == 0:
                # eval_episodes=0 leaves nothing to average; avoid dividing by zero.
                return {"reward": 0.0, "success": 0.0, "partial": 0.0, "minutes": 0.0, "cost": 0.0}
            return {
                "reward": round(sum(r["reward"] for r in res) / n, 1),
                "success": round(sum(r["success"] for r in res) / n * 100, 1),
                "partial": round(sum(r["partial"] for r in res) / n * 100, 1),
                "minutes": round(sum(r["minutes"] for r in res) / n, 0),
                "cost": round(sum(r["cost"] for r in res) / n, 1),
            }

        rl_s = agg(rl_results, len(rl_results))
        nv_s = agg(naive_results, len(naive_results))

        def imp(rl_v, nv_v):
            """Percent change of rl_v relative to nv_v; None when nv_v is 0."""
            if nv_v == 0:
                return None
            return round((rl_v - nv_v) / abs(nv_v) * 100)

        # For time and cost lower is better, so the imp() arguments are swapped.
        comparison = [
            {"metric": "Avg Reward", "reinforce": rl_s["reward"], "baseline": nv_s["reward"], "improvement": imp(rl_s["reward"], nv_s["reward"]), "unit": ""},
            {"metric": "Success Rate", "reinforce": rl_s["success"], "baseline": nv_s["success"], "improvement": imp(rl_s["success"], nv_s["success"]), "unit": "%"},
            {"metric": "Partial Rate", "reinforce": rl_s["partial"], "baseline": nv_s["partial"], "improvement": imp(rl_s["partial"], nv_s["partial"]), "unit": "%"},
            {"metric": "Avg Time", "reinforce": rl_s["minutes"], "baseline": nv_s["minutes"], "improvement": imp(nv_s["minutes"], rl_s["minutes"]), "unit": "min"},
            {"metric": "Avg Cost", "reinforce": rl_s["cost"], "baseline": nv_s["cost"], "improvement": imp(nv_s["cost"], rl_s["cost"]), "unit": "$"},
        ]
        final_event = {
            "type": "done",
            "chartData": chart_data,
            "comparison": comparison,
        }
        yield f"data: {json.dumps(final_event)}\n\n"

    return StreamingResponse(generate(), media_type="text/event-stream")
# ──────────────────────────────────────────────
# Run endpoints
# ──────────────────────────────────────────────
@app.post("/api/run/ai")
async def run_ai(req: RunRequest):
    """Run one traced RL-agent episode for the requested workflow.

    The agent registered for the workflow is preferred, then the shared
    ``rl_agent``. When neither exists a new ReinforceAgent is created without
    training and stored in both places before the episode runs.
    """
    global rl_agent
    workflow = req.workflow_id
    env = _get_env(workflow)
    agent = _trained_agents.get(workflow) or rl_agent
    if agent is not None:
        return _trace_episode(env, agent, seed=req.seed)

    # Nothing trained yet: fall back to a freshly constructed agent.
    agent = ReinforceAgent(max_trials=4, spec=get_spec_for_workflow(workflow))
    rl_agent = agent
    _trained_agents[workflow] = agent
    return _trace_episode(env, agent, seed=req.seed)
@app.post("/api/run/naive")
async def run_naive(req: RunRequest):
    """Run one traced episode with a new NaiveAgent seeded from the request."""
    return _trace_naive_episode(
        _get_env(req.workflow_id),
        NaiveAgent(num_trials=3, seed=req.seed),
        seed=req.seed,
    )
@app.post("/api/run/research")
async def run_research(req: RunRequest):
    """Run Research LLM agent (research → hypothesize → experiment → learn). PCR workflow only."""
    workflow_env = _get_env(req.workflow_id)
    outcome = _trace_research_episode(workflow_env, seed=req.seed, max_trials=5)
    return outcome
@app.post("/api/run/research-generate")
async def run_research_generate(req: RunRequest):
    """Run Research & Generate agent (research → generate any protocol → run → learn). PCR, ELISA, any spec with evaluate_custom_protocol."""
    workflow_env = _get_env(req.workflow_id)
    outcome = _trace_research_generate_episode(workflow_env, seed=req.seed, max_trials=6)
    return outcome
# ──────────────────────────────────────────────
# Step-by-step endpoint
# ──────────────────────────────────────────────
@app.post("/api/env/reset")
async def env_reset(req: RunRequest):
    """Reset the workflow's cached environment with the given seed and return its state."""
    target = _get_env(req.workflow_id)
    _obs, _info = target.reset(seed=req.seed)
    return _env_state_dict(target)
@app.post("/api/env/step")
async def env_step(req: StepRequest):
    """Apply one action to the workflow's cached environment.

    Returns the environment state after the step plus the step's reward and
    its terminated/truncated flags.
    """
    target = _get_env(req.workflow_id)
    _obs, step_reward, is_terminated, is_truncated, _info = target.step(req.action)
    payload = _env_state_dict(target)
    payload.update(
        reward=float(step_reward),
        terminated=is_terminated,
        truncated=is_truncated,
    )
    return payload
# ──────────────────────────────────────────────
# Stats endpoint
# ──────────────────────────────────────────────
@app.get("/api/stats")
async def get_stats():
    """Return dashboard aggregates computed from ``run_history``.

    With an empty history the success rate is "—" and the budget is "$0".
    Otherwise the success rate is the share of entries whose "best_result" is
    "success", and the budget is the sum of the entries' "cost" values.

    NOTE(review): no code visible in this module appends to ``run_history``,
    so the empty-history branch appears to be the only one taken — confirm
    whether the run endpoints are meant to record their results.
    """
    total = len(run_history)
    stats = {"active_workflows": 1, "total_experiments": total}
    if total == 0:
        stats["success_rate"] = "—"
        stats["budget_spent"] = "$0"
        return stats

    wins = len([entry for entry in run_history if entry.get("best_result") == "success"])
    spent = sum(entry.get('cost', 0) for entry in run_history)
    stats["success_rate"] = f"{wins / total:.0%}"
    stats["budget_spent"] = f"${spent:.0f}"
    return stats
# Script entry point: serve the app with uvicorn when this file is run directly.
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when running as a script.
    import uvicorn
    # Listens on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)