Commit ·
da63ca8
0
Parent(s):
SimLab: lab automation RL env, OpenEnv adapter, Training UI, agents
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitignore +26 -0
- .space/DEPLOY.md +39 -0
- Dockerfile +24 -0
- README.md +407 -0
- agents/__init__.py +19 -0
- agents/naive_agent.py +75 -0
- agents/research_generate_agent.py +323 -0
- agents/research_llm_agent.py +406 -0
- agents/rl_agent.py +307 -0
- demo/streamlit_app.py +185 -0
- knowledge/pcr_protocols.json +32 -0
- lab_env/__init__.py +4 -0
- lab_env/env.py +369 -0
- lab_env/openenv_adapter.py +231 -0
- lab_env/spec.py +367 -0
- pyproject.toml +24 -0
- scripts/compare_all_agents.py +139 -0
- scripts/demo_hackathon.sh +24 -0
- scripts/demo_research_agent.py +45 -0
- scripts/run_naive_baseline.py +75 -0
- scripts/run_research_generate_agent.py +91 -0
- scripts/train_and_eval_agent.py +148 -0
- scripts/train_per_protocol.py +82 -0
- scripts/visualize.py +258 -0
- server/app.py +621 -0
- v0ap/.gitignore +10 -0
- v0ap/app/docs/page.tsx +87 -0
- v0ap/app/globals.css +137 -0
- v0ap/app/layout.tsx +68 -0
- v0ap/app/page.tsx +24 -0
- v0ap/app/training/page.tsx +114 -0
- v0ap/app/workflows/[id]/page.tsx +180 -0
- v0ap/app/workflows/page.tsx +24 -0
- v0ap/components.json +21 -0
- v0ap/components/app-sidebar.tsx +91 -0
- v0ap/components/dashboard/performance-chart.tsx +88 -0
- v0ap/components/dashboard/recent-experiments.tsx +79 -0
- v0ap/components/dashboard/stats-cards.tsx +62 -0
- v0ap/components/theme-provider.tsx +14 -0
- v0ap/components/training/comparison-table.tsx +116 -0
- v0ap/components/training/training-chart.tsx +119 -0
- v0ap/components/training/training-controls.tsx +141 -0
- v0ap/components/ui/accordion.tsx +66 -0
- v0ap/components/ui/alert-dialog.tsx +157 -0
- v0ap/components/ui/alert.tsx +66 -0
- v0ap/components/ui/aspect-ratio.tsx +11 -0
- v0ap/components/ui/avatar.tsx +53 -0
- v0ap/components/ui/badge.tsx +46 -0
- v0ap/components/ui/breadcrumb.tsx +109 -0
- v0ap/components/ui/button-group.tsx +83 -0
.gitignore
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*.egg-info/
|
| 5 |
+
.eggs/
|
| 6 |
+
dist/
|
| 7 |
+
build/
|
| 8 |
+
.env
|
| 9 |
+
.venv/
|
| 10 |
+
venv/
|
| 11 |
+
|
| 12 |
+
# Node / Next
|
| 13 |
+
node_modules/
|
| 14 |
+
.next/
|
| 15 |
+
.v0ap/
|
| 16 |
+
|
| 17 |
+
# IDE / OS
|
| 18 |
+
.idea/
|
| 19 |
+
.vscode/
|
| 20 |
+
.DS_Store
|
| 21 |
+
*.log
|
| 22 |
+
|
| 23 |
+
# v0 runtime (if present)
|
| 24 |
+
__v0_runtime_loader.js
|
| 25 |
+
__v0_devtools.tsx
|
| 26 |
+
__v0_jsx-dev-runtime.ts
|
.space/DEPLOY.md
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Deploy SimLab to Hugging Face Spaces
|
| 2 |
+
|
| 3 |
+
Use this to get a **Hugging Face Spaces link** (e.g. `https://YOUR_USERNAME-simlab-env.hf.space`).
|
| 4 |
+
|
| 5 |
+
## Option 1: Docker Space (OpenEnv API only)
|
| 6 |
+
|
| 7 |
+
1. Go to [https://huggingface.co/spaces](https://huggingface.co/spaces) and click **Create new Space**.
|
| 8 |
+
2. Choose:
|
| 9 |
+
- **Name:** e.g. `simlab-env`
|
| 10 |
+
- **SDK:** **Docker**
|
| 11 |
+
- **Visibility:** Public (or Private)
|
| 12 |
+
3. Push this repo (or the contents) to the Space repo, or copy the `Dockerfile` from the repo root into the Space.
|
| 13 |
+
4. In the Space repo, the **Dockerfile** must be at the root. If your Space is a clone of simlab, the root already has the Dockerfile. If you created an empty Space, add a Dockerfile with:
|
| 14 |
+
|
| 15 |
+
```dockerfile
|
| 16 |
+
FROM python:3.11-slim
|
| 17 |
+
WORKDIR /app
|
| 18 |
+
RUN apt-get update && apt-get install -y --no-install-recommends build-essential && rm -rf /var/lib/apt/lists/*
|
| 19 |
+
COPY pyproject.toml ./
|
| 20 |
+
COPY lab_env ./lab_env/
|
| 21 |
+
RUN pip install --no-cache-dir -e .
|
| 22 |
+
ENV PORT=7860
|
| 23 |
+
EXPOSE 7860
|
| 24 |
+
CMD uvicorn lab_env.openenv_adapter:app --host 0.0.0.0 --port ${PORT}
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
5. Build and run. Your Space link will be:
|
| 28 |
+
- **`https://huggingface.co/spaces/YOUR_USERNAME/simlab-env`**
|
| 29 |
+
- or **`https://YOUR_USERNAME-simlab-env.hf.space`**
|
| 30 |
+
|
| 31 |
+
That Space serves the **OpenEnv API** (POST /reset, POST /step, GET /state, GET /metadata). It does **not** serve the full Next.js Training/Workflow UI; for that you run the app locally or host it elsewhere.
|
| 32 |
+
|
| 33 |
+
## Option 2: Link to an existing Space
|
| 34 |
+
|
| 35 |
+
If the OpenEnv org or someone else has already deployed SimLab, the link might be:
|
| 36 |
+
|
| 37 |
+
- **OpenEnv org:** [https://huggingface.co/openenv](https://huggingface.co/openenv) (list of envs; SimLab may be listed there if published)
|
| 38 |
+
|
| 39 |
+
Once your Space is live, use **`https://huggingface.co/spaces/YOUR_USERNAME/simlab-env`** as the Hugging Face Spaces link.
|
Dockerfile
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hugging Face Space — SimLab OpenEnv API
|
| 2 |
+
# Exposes POST /reset, POST /step, GET /state, GET /metadata on port 7860
|
| 3 |
+
|
| 4 |
+
FROM python:3.11-slim
|
| 5 |
+
|
| 6 |
+
WORKDIR /app
|
| 7 |
+
|
| 8 |
+
# Install system deps if needed
|
| 9 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 10 |
+
build-essential \
|
| 11 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 12 |
+
|
| 13 |
+
# Copy package files
|
| 14 |
+
COPY pyproject.toml ./
|
| 15 |
+
COPY lab_env ./lab_env/
|
| 16 |
+
|
| 17 |
+
# Install simlab (and openenv-core, torch, gymnasium, numpy)
|
| 18 |
+
RUN pip install --no-cache-dir -e .
|
| 19 |
+
|
| 20 |
+
# HF Spaces expect the app on port 7860
|
| 21 |
+
ENV PORT=7860
|
| 22 |
+
EXPOSE 7860
|
| 23 |
+
|
| 24 |
+
CMD ["uvicorn", "lab_env.openenv_adapter:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
ADDED
|
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SimLab — Lab Automation RL Environment
|
| 2 |
+
|
| 3 |
+
A self-contained Gymnasium-style reinforcement learning environment that
|
| 4 |
+
simulates **any** wet-lab experiment workflow. The experiment type is defined by
|
| 5 |
+
an **ExperimentSpec** (protocol presets, inventory, rewards, outcome model). The
|
| 6 |
+
default spec is PCR (Polymerase Chain Reaction); you can plug in ELISA, custom
|
| 7 |
+
assays, or any protocol-discovery task under real-world constraints: limited
|
| 8 |
+
time, budget, and finite reagent inventory.
|
| 9 |
+
|
| 10 |
+
Built for the **OpenEnv** ecosystem so it can be wrapped as an HTTP-served,
|
| 11 |
+
sandboxed environment and uploaded to the OpenEnv hub on Hugging Face.
|
| 12 |
+
|
| 13 |
+
**Integrations:** [OpenEnv](https://meta-pytorch.github.io/OpenEnv/) · [Hugging Face](https://huggingface.co/openenv)
|
| 14 |
+
|
| 15 |
+
---
|
| 16 |
+
|
| 17 |
+
## What the Environment Simulates
|
| 18 |
+
|
| 19 |
+
Each episode represents a scientist at the bench trying to get a successful
|
| 20 |
+
result. The environment:
|
| 21 |
+
|
| 22 |
+
- **Samples a hidden optimal protocol** on every `reset()` — the agent never
|
| 23 |
+
sees it directly.
|
| 24 |
+
- Offers **protocol presets** (defined in the spec) the agent can choose from.
|
| 25 |
+
- Lets the agent **run assays** that consume reagents and time, returning
|
| 26 |
+
outcomes (e.g. success / partial / fail) from the spec’s outcome model.
|
| 27 |
+
- **Custom protocols:** Specs with `evaluate_custom_protocol` (PCR, ELISA) allow
|
| 28 |
+
**arbitrary** protocol parameters via `env.run_assay_with_protocol(protocol_dict)` — agents can generate and try any valid params, not just presets.
|
| 29 |
+
- Allows **ordering more reagents** (costs money and time) and **waiting**.
|
| 30 |
+
- Terminates when the agent calls **finish**, runs out of time/budget, or
|
| 31 |
+
exhausts inventory with no way to reorder.
|
| 32 |
+
|
| 33 |
+
**Default (PCR):** 12 presets (3 temps × 2 cycle counts × 2 reagent ratios);
|
| 34 |
+
probabilistic success based on distance to hidden optimum. Other experiments
|
| 35 |
+
use their own presets and outcome logic via a custom `ExperimentSpec`.
|
| 36 |
+
|
| 37 |
+
### Reward structure (default PCR)
|
| 38 |
+
|
| 39 |
+
The reward encodes real lab trade-offs (all configurable per spec):
|
| 40 |
+
|
| 41 |
+
| Signal | Value |
|
| 42 |
+
|---|---|
|
| 43 |
+
| Immediate assay result: success | +15 |
|
| 44 |
+
| Immediate assay result: partial | +5 |
|
| 45 |
+
| Per-assay cost penalty | -3 |
|
| 46 |
+
| Terminal bonus (best = success) | +60 |
|
| 47 |
+
| Terminal bonus (best = partial) | +25 |
|
| 48 |
+
| Terminal penalty (no success/partial) | -20 |
|
| 49 |
+
| Time penalty | -0.25 per minute elapsed |
|
| 50 |
+
|
| 51 |
+
A good agent learns to explore efficiently — try a few presets, read the
|
| 52 |
+
signals from partial/success outcomes, and converge on the best protocol before
|
| 53 |
+
finishing.
|
| 54 |
+
|
| 55 |
+
---
|
| 56 |
+
|
| 57 |
+
## Architecture
|
| 58 |
+
|
| 59 |
+
```
|
| 60 |
+
simlab/
|
| 61 |
+
├── pyproject.toml # Package metadata & dependencies
|
| 62 |
+
├── README.md
|
| 63 |
+
├── lab_env/
|
| 64 |
+
│ ├── __init__.py
|
| 65 |
+
│ ├── spec.py # ExperimentSpec, pcr_experiment_spec()
|
| 66 |
+
│ ├── env.py # LabEnv (Gymnasium interface, any experiment)
|
| 67 |
+
│ └── openenv_adapter.py # OpenEnv types, LabEnvironment, HTTP app
|
| 68 |
+
├── agents/
|
| 69 |
+
│ ├── __init__.py
|
| 70 |
+
│ ├── naive_agent.py # Random-preset baseline
|
| 71 |
+
│ ├── rl_agent.py # REINFORCE policy-gradient agent (PyTorch)
|
| 72 |
+
│ ├── research_llm_agent.py # LLM researcher: presets + research
|
| 73 |
+
│ └── research_generate_agent.py # Research → generate any protocol → run → learn from feedback
|
| 74 |
+
├── knowledge/
|
| 75 |
+
│ └── pcr_protocols.json # Fake “papers” for web_search tool (demo)
|
| 76 |
+
├── demo/
|
| 77 |
+
│ └── streamlit_app.py # Live research dashboard + 3-agent comparison
|
| 78 |
+
└── scripts/
|
| 79 |
+
├── run_naive_baseline.py # Evaluate the naive agent
|
| 80 |
+
├── train_and_eval_agent.py # Train REINFORCE & compare both agents
|
| 81 |
+
├── compare_all_agents.py # Benchmark Naive vs RL vs Research LLM
|
| 82 |
+
├── run_research_generate_agent.py # Research → generate protocol → run → learn (any protocol)
|
| 83 |
+
└── demo_research_agent.py # Terminal demo of research agent
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
### Defining a new experiment
|
| 87 |
+
|
| 88 |
+
Implement an `ExperimentSpec` in `lab_env/spec.py` (or your own module) with:
|
| 89 |
+
|
| 90 |
+
- **presets** — list of protocol dicts (e.g. temperature, cycles, ratio for PCR).
|
| 91 |
+
- **inventory_items** / **orderable_items** — what the lab tracks and can reorder.
|
| 92 |
+
- **initial_inventory**, **order_costs**, **result_labels**.
|
| 93 |
+
- **sample_hidden_optimum(rng)** — returns hidden optimal state (e.g. ideal temp/cycles).
|
| 94 |
+
- **sample_assay_result(hidden, preset_idx, presets, rng)** — returns outcome label.
|
| 95 |
+
- **evaluate_custom_protocol(hidden, protocol_dict, rng)** (optional) — score an arbitrary protocol dict so agents can run any params via `env.run_assay_with_protocol(protocol_dict)`.
|
| 96 |
+
- **protocol_param_schema** (optional) — dict describing params for codegen/LLM (e.g. `{"temp": {"type": "number"}, "cycles": {"type": "integer"}, ...}`).
|
| 97 |
+
|
| 98 |
+
Then use `LabEnv(spec=my_spec)` or pass `spec` into the OpenEnv `LabEnvironment(spec=my_spec)`.
|
| 99 |
+
|
| 100 |
+
### Agent design
|
| 101 |
+
|
| 102 |
+
The **REINFORCE agent** decomposes the problem into a learned and a scripted
|
| 103 |
+
part:
|
| 104 |
+
|
| 105 |
+
- **Learned** — a 2-layer MLP (14 → 64 → 64 → 12) maps the observation to a
|
| 106 |
+
distribution over the 12 protocol presets. Trained with REINFORCE + entropy
|
| 107 |
+
bonus + running-mean baseline.
|
| 108 |
+
- **Scripted** — the episode loop (setup → run assay → check result → order
|
| 109 |
+
if needed → finish on success) is fixed so the agent focuses on the hard
|
| 110 |
+
decision: *which* preset to try.
|
| 111 |
+
|
| 112 |
+
This decomposition lets training converge in ~2000 episodes (a few seconds on
|
| 113 |
+
CPU) while clearly beating the random-preset naive baseline.
|
| 114 |
+
|
| 115 |
+
The **Research LLM agent** adds a self-improving lab scientist: it researches
|
| 116 |
+
protocols (via a `web_search` tool over a local knowledge base), hypothesizes
|
| 117 |
+
new parameter combinations (mapped to presets), runs experiments in LabEnv, and
|
| 118 |
+
updates internal knowledge from results.
|
| 119 |
+
|
| 120 |
+
The **Research & Generate agent** (`research_generate_agent.py`) goes further: it
|
| 121 |
+
**researches** (web_search), **generates** protocol parameters for **any** valid
|
| 122 |
+
values (not limited to presets), **runs** them via `env.run_assay_with_protocol(protocol_dict)`,
|
| 123 |
+
and **learns from feedback** — each run's (protocol, result, reward) is passed
|
| 124 |
+
into the next trial so the agent improves over the episode. Works with any spec
|
| 125 |
+
that has `evaluate_custom_protocol` (PCR, ELISA). Run it with:
|
| 126 |
+
|
| 127 |
+
```bash
|
| 128 |
+
export OPENAI_API_KEY=your_key
|
| 129 |
+
python scripts/run_research_generate_agent.py --episodes 5 --verbose
|
| 130 |
+
```
|
| 131 |
+
|
| 132 |
+
Use `--workflow elisa-readout` for ELISA. Add `knowledge/{name}_protocols.json`
|
| 133 |
+
for more experiment types so research has literature to search.
|
| 134 |
+
|
| 135 |
+
### Training on different protocol sets
|
| 136 |
+
|
| 137 |
+
Each **protocol** (PCR, ELISA, or a custom spec) has its own **presets** and outcome model. The RL agent can train on any of them so you get one policy per protocol set.
|
| 138 |
+
|
| 139 |
+
- **One agent per protocol:** Create an agent with that spec and train it on an env with the same spec. The policy’s input/output sizes come from the spec (e.g. 14-dim obs → 12 presets for PCR; same for ELISA).
|
| 140 |
+
- **Script:** `scripts/train_per_protocol.py` trains a separate REINFORCE agent for each workflow and saves checkpoints (e.g. `checkpoints/pcr-amplification.pt`, `checkpoints/elisa-readout.pt`):
|
| 141 |
+
|
| 142 |
+
```bash
|
| 143 |
+
python scripts/train_per_protocol.py --workflows pcr-amplification elisa-readout --train-episodes 1500
|
| 144 |
+
```
|
| 145 |
+
|
| 146 |
+
- **Using agents to create different protocol sets:** You can define new protocol sets in two ways:
|
| 147 |
+
1. **In code:** Add a new `ExperimentSpec` in `lab_env/spec.py` (or your own module): define `presets`, `sample_hidden_optimum`, `sample_assay_result`, and optionally `evaluate_custom_protocol` + `protocol_param_schema`. Register it in `get_spec_for_workflow()` and run `train_per_protocol.py --workflows your-workflow-id`.
|
| 148 |
+
2. **Generated presets:** Use an LLM or script to produce a list of protocol dicts (e.g. different temps/cycles) and a simple outcome rule; wrap them in an `ExperimentSpec` and train an agent with `ReinforceAgent(spec=my_spec)` on `LabEnv(spec=my_spec)`. The Research & Generate agent already “creates” protocols at run time (arbitrary params); to **train** on a generated set, you’d turn that set into fixed presets in a new spec and train REINFORCE on it.
|
| 149 |
+
|
| 150 |
+
---
|
| 151 |
+
|
| 152 |
+
## Quick Start
|
| 153 |
+
|
| 154 |
+
### Install
|
| 155 |
+
|
| 156 |
+
```bash
|
| 157 |
+
pip install -e .
|
| 158 |
+
```
|
| 159 |
+
|
| 160 |
+
Or just ensure `numpy`, `torch`, and `gymnasium` are installed.
|
| 161 |
+
|
| 162 |
+
### Run the naive baseline
|
| 163 |
+
|
| 164 |
+
```bash
|
| 165 |
+
python scripts/run_naive_baseline.py --episodes 200
|
| 166 |
+
```
|
| 167 |
+
|
| 168 |
+
### Train the REINFORCE agent and compare
|
| 169 |
+
|
| 170 |
+
```bash
|
| 171 |
+
python scripts/train_and_eval_agent.py --train-episodes 2000 --eval-episodes 100
|
| 172 |
+
```
|
| 173 |
+
|
| 174 |
+
### Next.js UI + API server (general UI)
|
| 175 |
+
|
| 176 |
+
Run the FastAPI backend, then the Next.js frontend (with API proxy to the backend):
|
| 177 |
+
|
| 178 |
+
```bash
|
| 179 |
+
# Terminal 1: Python API (agents + LabEnv)
|
| 180 |
+
uvicorn server.app:app --host 0.0.0.0 --port 8000
|
| 181 |
+
|
| 182 |
+
# Terminal 2: Next.js frontend (v0ap)
|
| 183 |
+
cd v0ap && pnpm dev
|
| 184 |
+
```
|
| 185 |
+
|
| 186 |
+
Then open the workflow run page (e.g. `/workflows/pcr-amplification`). The UI shows **Run with AI Agent**, **Run Research Agent** (research → hypothesize → experiment → learn), and **Run Naive Baseline**. The timeline displays which agent was used and each step (Research, Hypothesis, Run Assay, Learn for the research agent). Set `OPENAI_API_KEY` if you use the Research agent.
|
| 187 |
+
|
| 188 |
+
---
|
| 189 |
+
|
| 190 |
+
## Hackathon / live demo — how to show the RL
|
| 191 |
+
|
| 192 |
+
**Pitch in one line:** *“We simulate a lab where an agent has to discover the right protocol; you see it learn with RL and compare to baselines.”*
|
| 193 |
+
|
| 194 |
+
### Setup (do this before going on stage)
|
| 195 |
+
|
| 196 |
+
1. **Start both servers** (two terminals):
|
| 197 |
+
```bash
|
| 198 |
+
# Terminal 1 — API (agents + LabEnv)
|
| 199 |
+
uvicorn server.app:app --host 0.0.0.0 --port 8000
|
| 200 |
+
|
| 201 |
+
# Terminal 2 — UI
|
| 202 |
+
cd v0ap && pnpm dev
|
| 203 |
+
```
|
| 204 |
+
2. Open **http://localhost:3000** (or the URL Next.js prints).
|
| 205 |
+
3. Optional: set `OPENAI_API_KEY` if you want to demo Research / Research & Generate.
|
| 206 |
+
|
| 207 |
+
### Demo flow A — “Watch the RL agent learn” (~2 min)
|
| 208 |
+
|
| 209 |
+
1. Go to **Training** (`/training`).
|
| 210 |
+
2. Say: *“This is our wet-lab sim. The agent doesn’t know the optimal protocol; it has to learn from trial and error.”*
|
| 211 |
+
3. Set **episodes to 500** (slider) for a short run — training finishes in under a minute on a laptop.
|
| 212 |
+
4. Click **Start Training**. Point at:
|
| 213 |
+
- **Progress** and “Episode X of 500”.
|
| 214 |
+
- **Chart**: reward and success rate climbing over episodes.
|
| 215 |
+
5. When it finishes: *“Here’s the comparison: REINFORCE vs random baseline.”* Show the table (success rate, reward, time).
|
| 216 |
+
|
| 217 |
+
### Demo flow B — “Compare agents in the lab” (~1–2 min)
|
| 218 |
+
|
| 219 |
+
1. Go to **PCR Amplification** (`/workflows/pcr-amplification`).
|
| 220 |
+
2. Say: *“Each run is one scientist trying to get a successful experiment under time and budget.”*
|
| 221 |
+
3. Click **Run Naive Baseline** — timeline fills with random preset choices and results.
|
| 222 |
+
4. Then click **Run with AI Agent** (uses the policy you trained in flow A, or a default). Point at the timeline: *“The learned agent picks protocols more purposefully and often gets success sooner.”*
|
| 223 |
+
5. If you have an API key: click **Research & Generate (any protocol)** — *“This one researches, proposes parameters, runs them, and learns from feedback.”*
|
| 224 |
+
|
| 225 |
+
### Tips
|
| 226 |
+
|
| 227 |
+
- **Keep training short on stage:** 500 episodes is enough to show learning; 1000 if you have time.
|
| 228 |
+
- **If the UI is slow:** Run a quick train in the background before the demo, then only show “Run with AI Agent” and the comparison table.
|
| 229 |
+
- **Backup:** Pre-record a 1‑minute screen capture of training + one workflow run; use it if WiFi or live run fails.
|
| 230 |
+
- **Talking points:** Hidden optimal protocol, limited time/budget, REINFORCE policy over presets, Research & Generate for “any protocol” + learning from feedback.
|
| 231 |
+
|
| 232 |
+
### Demo script (optional)
|
| 233 |
+
|
| 234 |
+
From repo root, run `./scripts/demo_hackathon.sh` for a short checklist and the option to start the API in that terminal. Or start both manually:
|
| 235 |
+
|
| 236 |
+
```bash
|
| 237 |
+
# Terminal 1
|
| 238 |
+
uvicorn server.app:app --host 0.0.0.0 --port 8000
|
| 239 |
+
|
| 240 |
+
# Terminal 2
|
| 241 |
+
cd v0ap && pnpm dev
|
| 242 |
+
# Open http://localhost:3000 → /training or /workflows/pcr-amplification
|
| 243 |
+
```
|
| 244 |
+
|
| 245 |
+
---
|
| 246 |
+
|
| 247 |
+
### Research LLM agent (optional, Streamlit)
|
| 248 |
+
|
| 249 |
+
Install demo dependencies (`openai`, `streamlit`) and set `OPENAI_API_KEY`:
|
| 250 |
+
|
| 251 |
+
```bash
|
| 252 |
+
pip install -e ".[demo]"
|
| 253 |
+
export OPENAI_API_KEY=your_key
|
| 254 |
+
streamlit run demo/streamlit_app.py
|
| 255 |
+
```
|
| 256 |
+
|
| 257 |
+
The Streamlit app shows the research flow (research → hypothesize → experiment → learn) and a 3-agent comparison table. To benchmark all agents from the terminal:
|
| 258 |
+
|
| 259 |
+
```bash
|
| 260 |
+
python scripts/compare_all_agents.py --eval-episodes 50
|
| 261 |
+
```
|
| 262 |
+
|
| 263 |
+
### Sample output (train & eval)
|
| 264 |
+
|
| 265 |
+
```
|
| 266 |
+
Metric REINFORCE Naive
|
| 267 |
+
----------------------------------------------
|
| 268 |
+
Avg reward 15.7 5.0
|
| 269 |
+
Success rate 53.0% 43.0%
|
| 270 |
+
Partial rate 19.0% 15.0%
|
| 271 |
+
Avg time 62.8m 63.0m
|
| 272 |
+
Avg cost $0.0 $0.0
|
| 273 |
+
Avg steps 7.0 7.0
|
| 274 |
+
----------------------------------------------
|
| 275 |
+
```
|
| 276 |
+
|
| 277 |
+
---
|
| 278 |
+
|
| 279 |
+
## OpenEnv & Hugging Face — How to show and use
|
| 280 |
+
|
| 281 |
+
SimLab is built for the **OpenEnv** ecosystem and can be served over HTTP and deployed to **Hugging Face** as a standardized agentic environment.
|
| 282 |
+
|
| 283 |
+
### How SimLab uses OpenEnv
|
| 284 |
+
|
| 285 |
+
- **`openenv-core`** is a required dependency (`pyproject.toml`).
|
| 286 |
+
- **`lab_env/openenv_adapter.py`** wraps `LabEnv` in the OpenEnv `Environment` interface:
|
| 287 |
+
- **Types:** `LabAction`, `LabObservation`, `LabState`, `LabEnvironment`
|
| 288 |
+
- **`create_app(LabEnvironment, LabAction, LabObservation, ...)`** — FastAPI app with OpenEnv endpoints
|
| 289 |
+
|
| 290 |
+
### Run the OpenEnv HTTP server
|
| 291 |
+
|
| 292 |
+
```bash
|
| 293 |
+
uvicorn lab_env.openenv_adapter:app --host 0.0.0.0 --port 8000
|
| 294 |
+
```
|
| 295 |
+
|
| 296 |
+
This exposes standard OpenEnv endpoints:
|
| 297 |
+
|
| 298 |
+
| Endpoint | Description |
|
| 299 |
+
|----------------|--------------------------------|
|
| 300 |
+
| `POST /reset` | Reset environment, get initial observation |
|
| 301 |
+
| `POST /step` | Send action, get next observation & reward |
|
| 302 |
+
| `GET /state` | Current state snapshot |
|
| 303 |
+
| `GET /metadata`| Environment name, version, docs |
|
| 304 |
+
| WebSocket `/ws`| Persistent session (optional) |
|
| 305 |
+
|
| 306 |
+
Up to `max_concurrent_envs=4` sessions are supported.
|
| 307 |
+
|
| 308 |
+
### Call the OpenEnv server (show usage)
|
| 309 |
+
|
| 310 |
+
From another process or machine, you can drive SimLab over HTTP:
|
| 311 |
+
|
| 312 |
+
```bash
|
| 313 |
+
# Reset (start new episode)
|
| 314 |
+
curl -s -X POST http://localhost:8000/reset -H "Content-Type: application/json" -d '{"seed": 42}' | jq .
|
| 315 |
+
|
| 316 |
+
# Step (e.g. action 0 = setup preset 0)
|
| 317 |
+
curl -s -X POST http://localhost:8000/step -H "Content-Type: application/json" -d '{"action": 0}' | jq .
|
| 318 |
+
|
| 319 |
+
# Get current state
|
| 320 |
+
curl -s http://localhost:8000/state | jq .
|
| 321 |
+
```
|
| 322 |
+
|
| 323 |
+
From Python (e.g. for demos or integration):
|
| 324 |
+
|
| 325 |
+
```python
|
| 326 |
+
import requests
|
| 327 |
+
|
| 328 |
+
BASE = "http://localhost:8000"
|
| 329 |
+
|
| 330 |
+
# Reset
|
| 331 |
+
r = requests.post(f"{BASE}/reset", json={"seed": 42})
|
| 332 |
+
obs = r.json() # observation with metadata (obs_vector, info, etc.)
|
| 333 |
+
|
| 334 |
+
# Step: setup preset 0, then run assay (action 12 for PCR)
|
| 335 |
+
requests.post(f"{BASE}/step", json={"action": 0})
|
| 336 |
+
r = requests.post(f"{BASE}/step", json={"action": 12})
|
| 337 |
+
print(r.json()) # observation, reward, done
|
| 338 |
+
|
| 339 |
+
# State
|
| 340 |
+
state = requests.get(f"{BASE}/state").json()
|
| 341 |
+
print(state["step_count"], state["best_result"])
|
| 342 |
+
```
|
| 343 |
+
|
| 344 |
+
### Deploy to Hugging Face
|
| 345 |
+
|
| 346 |
+
To **show SimLab on the Hugging Face Hub** as an OpenEnv environment:
|
| 347 |
+
|
| 348 |
+
1. **Option A — Hugging Face Space (Docker)**
|
| 349 |
+
Create a Space with **Docker** as the SDK. Use a `Dockerfile` that installs SimLab and runs:
|
| 350 |
+
```dockerfile
|
| 351 |
+
CMD uvicorn lab_env.openenv_adapter:app --host 0.0.0.0 --port 7860
|
| 352 |
+
```
|
| 353 |
+
Point the Space to your repo and set the port to 7860 (or the port HF expects). Your Space URL (e.g. `https://huggingface.co/spaces/your-username/simlab-env`) is then the public OpenEnv endpoint.
|
| 354 |
+
|
| 355 |
+
2. **Option B — OpenEnv CLI (if you adopt the full OpenEnv layout)**
|
| 356 |
+
The [OpenEnv Packaging & Deploying](https://meta-pytorch.github.io/OpenEnv/auto_getting_started/environment-builder.html) guide uses `openenv init`, `openenv build`, and **`openenv push`** to deploy to the Hub. SimLab currently uses `openenv-core` and a custom adapter; to use `openenv push`, you would add the expected layout (e.g. `openenv.yaml`, `server/` with Dockerfile) and wire the existing `LabEnvironment` + `create_app` into that structure.
|
| 357 |
+
|
| 358 |
+
3. **Link your repo on the Hub**
|
| 359 |
+
In your SimLab repo or any Hugging Face model/Space card, set the **Repository** and **Documentation** URLs to your GitHub repo and add a tag or short description such as: *"OpenEnv-compatible lab automation environment; run with `uvicorn lab_env.openenv_adapter:app` and connect via POST /reset, POST /step."*
|
| 360 |
+
|
| 361 |
+
### References
|
| 362 |
+
|
| 363 |
+
- [OpenEnv documentation](https://meta-pytorch.github.io/OpenEnv/) — framework overview and APIs
|
| 364 |
+
- [OpenEnv on Hugging Face](https://huggingface.co/openenv) — OpenEnv org and environments
|
| 365 |
+
- [Packaging & Deploying (OpenEnv)](https://meta-pytorch.github.io/OpenEnv/auto_getting_started/environment-builder.html) — build, validate, push to Hub
|
| 366 |
+
|
| 367 |
+
---
|
| 368 |
+
|
| 369 |
+
## Environment API Reference
|
| 370 |
+
|
| 371 |
+
```python
|
| 372 |
+
from lab_env import LabEnv, ExperimentSpec, pcr_experiment_spec
|
| 373 |
+
|
| 374 |
+
# Default: PCR experiment (same as before)
|
| 375 |
+
env = LabEnv()
|
| 376 |
+
# Or any experiment from a spec:
|
| 377 |
+
# env = LabEnv(spec=my_experiment_spec)
|
| 378 |
+
|
| 379 |
+
obs, info = env.reset(seed=42)
|
| 380 |
+
|
| 381 |
+
# obs shape and action count come from env.spec (e.g. PCR: 14-dim obs, 18 actions)
|
| 382 |
+
# [0] step_index (normalised)
|
| 383 |
+
# [1] elapsed_minutes (normalised)
|
| 384 |
+
# [2] remaining_budget (normalised)
|
| 385 |
+
# [3..] inventory (one per spec.inventory_items, normalised)
|
| 386 |
+
# [...] last_result one-hot (len(spec.result_labels))
|
| 387 |
+
# [...] has_setup, current_preset_idx (norm), best_result_score
|
| 388 |
+
|
| 389 |
+
# Actions (Discrete, from spec):
|
| 390 |
+
# 0 .. num_presets-1 setup_reaction(preset_index)
|
| 391 |
+
# num_presets run_assay
|
| 392 |
+
# num_presets+1 .. order_reagents (one per orderable_items)
|
| 393 |
+
# ... wait, finish
|
| 394 |
+
|
| 395 |
+
obs, reward, terminated, truncated, info = env.step(0) # setup preset 0
|
| 396 |
+
obs, reward, terminated, truncated, info = env.step(12) # run assay (PCR)
|
| 397 |
+
obs, reward, terminated, truncated, info = env.step(17) # finish (PCR)
|
| 398 |
+
|
| 399 |
+
# Custom protocol (any params; spec must have evaluate_custom_protocol)
|
| 400 |
+
obs, reward, term, trunc, info = env.run_assay_with_protocol({"temp": 57.5, "cycles": 32, "ratio": "conservative"})
|
| 401 |
+
```
|
| 402 |
+
|
| 403 |
+
---
|
| 404 |
+
|
| 405 |
+
## License
|
| 406 |
+
|
| 407 |
+
MIT
|
agents/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from agents.naive_agent import NaiveAgent
|
| 2 |
+
from agents.rl_agent import ReinforceAgent
|
| 3 |
+
|
| 4 |
+
try:
|
| 5 |
+
from agents.research_llm_agent import ResearchLLMAgent
|
| 6 |
+
except ImportError:
|
| 7 |
+
ResearchLLMAgent = None # type: ignore[misc, assignment]
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
from agents.research_generate_agent import ResearchGenerateAgent
|
| 11 |
+
except ImportError:
|
| 12 |
+
ResearchGenerateAgent = None # type: ignore[misc, assignment]
|
| 13 |
+
|
| 14 |
+
__all__ = [
|
| 15 |
+
"NaiveAgent",
|
| 16 |
+
"ReinforceAgent",
|
| 17 |
+
"ResearchLLMAgent",
|
| 18 |
+
"ResearchGenerateAgent",
|
| 19 |
+
]
|
agents/naive_agent.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Naive baseline agent for LabEnv.
|
| 3 |
+
|
| 4 |
+
Follows a fixed strategy:
|
| 5 |
+
1. Pick a random protocol preset.
|
| 6 |
+
2. Run the assay.
|
| 7 |
+
3. Repeat for a fixed number of trials.
|
| 8 |
+
4. Finish.
|
| 9 |
+
If inventory is low, order reagents before running.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
from lab_env.env import (
|
| 17 |
+
ACTION_FINISH,
|
| 18 |
+
ACTION_ORDER_BUFFER,
|
| 19 |
+
ACTION_ORDER_POLYMERASE,
|
| 20 |
+
ACTION_ORDER_TIPS,
|
| 21 |
+
ACTION_RUN_ASSAY,
|
| 22 |
+
ACTION_SETUP_START,
|
| 23 |
+
NUM_PRESETS,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class NaiveAgent:
|
| 28 |
+
"""Baseline agent: random preset selection, fixed trial count, no learning."""
|
| 29 |
+
|
| 30 |
+
def __init__(self, num_trials: int = 3, seed: int | None = None) -> None:
|
| 31 |
+
self.num_trials = num_trials
|
| 32 |
+
self._rng = np.random.default_rng(seed)
|
| 33 |
+
self._trial: int = 0
|
| 34 |
+
self._phase: str = "setup"
|
| 35 |
+
|
| 36 |
+
def reset(self) -> None:
|
| 37 |
+
self._trial = 0
|
| 38 |
+
self._phase = "setup"
|
| 39 |
+
|
| 40 |
+
def select_action(self, obs: np.ndarray) -> int:
|
| 41 |
+
"""Choose the next action based on a scripted strategy."""
|
| 42 |
+
tips = obs[3]
|
| 43 |
+
buffer = obs[4]
|
| 44 |
+
poly = obs[5]
|
| 45 |
+
samples = obs[6]
|
| 46 |
+
inventory_low = min(tips, buffer, poly, samples) < 0.05 # ~1 unit
|
| 47 |
+
|
| 48 |
+
if self._trial >= self.num_trials:
|
| 49 |
+
return ACTION_FINISH
|
| 50 |
+
|
| 51 |
+
if self._phase == "setup":
|
| 52 |
+
if inventory_low:
|
| 53 |
+
return self._order_cheapest(obs)
|
| 54 |
+
preset = int(self._rng.integers(0, NUM_PRESETS))
|
| 55 |
+
self._phase = "run"
|
| 56 |
+
return ACTION_SETUP_START + preset
|
| 57 |
+
|
| 58 |
+
if self._phase == "run":
|
| 59 |
+
self._phase = "setup"
|
| 60 |
+
self._trial += 1
|
| 61 |
+
return ACTION_RUN_ASSAY
|
| 62 |
+
|
| 63 |
+
return ACTION_FINISH
|
| 64 |
+
|
| 65 |
+
def update(self, *_args: object, **_kwargs: object) -> None:
|
| 66 |
+
"""No-op — the naive agent does not learn."""
|
| 67 |
+
|
| 68 |
+
@staticmethod
|
| 69 |
+
def _order_cheapest(obs: np.ndarray) -> int:
|
| 70 |
+
tips, buffer, poly = obs[3], obs[4], obs[5]
|
| 71 |
+
if tips <= buffer and tips <= poly:
|
| 72 |
+
return ACTION_ORDER_TIPS
|
| 73 |
+
if buffer <= poly:
|
| 74 |
+
return ACTION_ORDER_BUFFER
|
| 75 |
+
return ACTION_ORDER_POLYMERASE
|
agents/research_generate_agent.py
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Research & Generate agent: research → generate protocol (any params) → run → learn from feedback.
|
| 3 |
+
|
| 4 |
+
Uses the spec's protocol_param_schema so it works for any experiment type (PCR, ELISA, etc.).
|
| 5 |
+
Generates arbitrary protocol dicts (not limited to presets), runs them via
|
| 6 |
+
env.run_assay_with_protocol(), and learns from (protocol, result, reward) history.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import json
|
| 12 |
+
import os
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Any
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
from openai import OpenAI
|
| 18 |
+
except ImportError:
|
| 19 |
+
OpenAI = None # type: ignore[misc, assignment]
|
| 20 |
+
|
| 21 |
+
from lab_env.env import LabEnv
|
| 22 |
+
from lab_env.spec import ExperimentSpec
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# ---------------------------------------------------------------------------
|
| 26 |
+
# Knowledge base (per-experiment)
|
| 27 |
+
# ---------------------------------------------------------------------------
|
| 28 |
+
|
| 29 |
+
KNOWLEDGE_DIR = Path(__file__).resolve().parent.parent / "knowledge"
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def load_protocols_knowledge(experiment_name: str) -> list[dict[str, Any]]:
|
| 33 |
+
"""Load literature/knowledge for an experiment. Tries {name}_protocols.json then pcr_protocols.json."""
|
| 34 |
+
for name in (experiment_name, "pcr"):
|
| 35 |
+
path = KNOWLEDGE_DIR / f"{name}_protocols.json"
|
| 36 |
+
if path.exists():
|
| 37 |
+
with open(path, encoding="utf-8") as f:
|
| 38 |
+
return json.load(f)
|
| 39 |
+
return []
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def web_search(query: str, experiment_name: str = "pcr", top_k: int = 4) -> str:
|
| 43 |
+
"""Search literature for the given experiment type; return relevant snippets."""
|
| 44 |
+
papers = load_protocols_knowledge(experiment_name)
|
| 45 |
+
query_lower = query.lower()
|
| 46 |
+
scored = []
|
| 47 |
+
for p in papers:
|
| 48 |
+
text = (
|
| 49 |
+
f"{p.get('title','')} {p.get('abstract','')} "
|
| 50 |
+
f"{p.get('keywords','')} {p.get('recommendations','')}"
|
| 51 |
+
).lower()
|
| 52 |
+
score = sum(1 for w in query_lower.split() if len(w) > 2 and w in text)
|
| 53 |
+
if score > 0:
|
| 54 |
+
scored.append((score, p))
|
| 55 |
+
scored.sort(key=lambda x: -x[0])
|
| 56 |
+
if not scored:
|
| 57 |
+
return (
|
| 58 |
+
"No relevant literature found. Try general terms for this experiment type, "
|
| 59 |
+
"e.g. temperature, cycles, protocol parameters."
|
| 60 |
+
)
|
| 61 |
+
out = []
|
| 62 |
+
for _, p in scored[:top_k]:
|
| 63 |
+
out.append(f"[{p.get('title','')}] {p.get('recommendations','')}")
|
| 64 |
+
return "\n".join(out)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# ---------------------------------------------------------------------------
|
| 68 |
+
# Build tool schemas from spec (any protocol)
|
| 69 |
+
# ---------------------------------------------------------------------------
|
| 70 |
+
|
| 71 |
+
def build_tool_schemas(spec: ExperimentSpec) -> list[dict[str, Any]]:
|
| 72 |
+
"""Build OpenAI tool schemas from spec: web_search + run_experiment with protocol_param_schema."""
|
| 73 |
+
schema = spec.protocol_param_schema
|
| 74 |
+
if not schema:
|
| 75 |
+
# Fallback for specs without protocol_param_schema (e.g. custom)
|
| 76 |
+
run_params = {
|
| 77 |
+
"type": "object",
|
| 78 |
+
"properties": {"protocol": {"type": "object", "description": "Protocol parameters as key-value dict"}},
|
| 79 |
+
"required": ["protocol"],
|
| 80 |
+
}
|
| 81 |
+
else:
|
| 82 |
+
run_params = {
|
| 83 |
+
"type": "object",
|
| 84 |
+
"properties": {
|
| 85 |
+
k: {
|
| 86 |
+
"type": v.get("type", "string"),
|
| 87 |
+
"description": v.get("description", ""),
|
| 88 |
+
}
|
| 89 |
+
| ({"enum": v["enum"]} if "enum" in v else {})
|
| 90 |
+
for k, v in schema.items()
|
| 91 |
+
},
|
| 92 |
+
"required": list(schema.keys()),
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
return [
|
| 96 |
+
{
|
| 97 |
+
"type": "function",
|
| 98 |
+
"function": {
|
| 99 |
+
"name": "web_search",
|
| 100 |
+
"description": f"Search scientific literature for {spec.name} protocols and parameter recommendations.",
|
| 101 |
+
"parameters": {
|
| 102 |
+
"type": "object",
|
| 103 |
+
"properties": {
|
| 104 |
+
"query": {
|
| 105 |
+
"type": "string",
|
| 106 |
+
"description": "Search query, e.g. 'optimal temperature and cycles'",
|
| 107 |
+
},
|
| 108 |
+
},
|
| 109 |
+
"required": ["query"],
|
| 110 |
+
},
|
| 111 |
+
},
|
| 112 |
+
},
|
| 113 |
+
{
|
| 114 |
+
"type": "function",
|
| 115 |
+
"function": {
|
| 116 |
+
"name": "run_experiment",
|
| 117 |
+
"description": (
|
| 118 |
+
f"Run one {spec.name} experiment with the given protocol parameters. "
|
| 119 |
+
"You can use any valid values (not limited to presets). The lab returns success/partial/fail."
|
| 120 |
+
),
|
| 121 |
+
"parameters": run_params,
|
| 122 |
+
},
|
| 123 |
+
},
|
| 124 |
+
]
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# ---------------------------------------------------------------------------
|
| 128 |
+
# Research & Generate Agent
|
| 129 |
+
# ---------------------------------------------------------------------------
|
| 130 |
+
|
| 131 |
+
class ResearchGenerateAgent:
|
| 132 |
+
"""
|
| 133 |
+
Agentic flow: research → generate protocol (any params) → run in env → get feedback → learn.
|
| 134 |
+
|
| 135 |
+
Works with any ExperimentSpec that has evaluate_custom_protocol and protocol_param_schema.
|
| 136 |
+
Maintains history of (protocol, result, reward) so the LLM learns from feedback.
|
| 137 |
+
"""
|
| 138 |
+
|
| 139 |
+
def __init__(
|
| 140 |
+
self,
|
| 141 |
+
model: str = "gpt-4o-mini",
|
| 142 |
+
max_trials: int = 6,
|
| 143 |
+
) -> None:
|
| 144 |
+
if OpenAI is None:
|
| 145 |
+
raise ImportError("Optional dependency 'openai' is required. Install with: pip install openai")
|
| 146 |
+
self.model = model
|
| 147 |
+
self.max_trials = max_trials
|
| 148 |
+
self._client: OpenAI | None = None
|
| 149 |
+
self.feedback_history: list[dict[str, Any]] = []
|
| 150 |
+
|
| 151 |
+
@property
|
| 152 |
+
def client(self) -> OpenAI:
|
| 153 |
+
if self._client is None:
|
| 154 |
+
api_key = os.environ.get("OPENAI_API_KEY")
|
| 155 |
+
if not api_key:
|
| 156 |
+
raise RuntimeError(
|
| 157 |
+
"OPENAI_API_KEY environment variable is required for ResearchGenerateAgent"
|
| 158 |
+
)
|
| 159 |
+
self._client = OpenAI(api_key=api_key)
|
| 160 |
+
return self._client
|
| 161 |
+
|
| 162 |
+
def _run_tool(
|
| 163 |
+
self,
|
| 164 |
+
name: str,
|
| 165 |
+
arguments: dict[str, Any],
|
| 166 |
+
env: LabEnv,
|
| 167 |
+
) -> tuple[str, float | None]:
|
| 168 |
+
"""Execute one tool. Returns (result_string, reward_if_run_experiment else None)."""
|
| 169 |
+
spec = env.spec
|
| 170 |
+
if name == "web_search":
|
| 171 |
+
q = arguments.get("query", "")
|
| 172 |
+
result = web_search(q, experiment_name=spec.name)
|
| 173 |
+
return result, None
|
| 174 |
+
|
| 175 |
+
if name == "run_experiment":
|
| 176 |
+
# Protocol dict: use protocol_param_schema keys (e.g. temp, cycles, ratio for PCR)
|
| 177 |
+
protocol = dict(arguments)
|
| 178 |
+
if "protocol" in protocol and isinstance(protocol["protocol"], dict):
|
| 179 |
+
protocol = protocol["protocol"]
|
| 180 |
+
try:
|
| 181 |
+
obs, reward, term, trunc, info = env.run_assay_with_protocol(protocol)
|
| 182 |
+
except ValueError as e:
|
| 183 |
+
return f"Error: {e}", None
|
| 184 |
+
result = info.get("last_result", "fail")
|
| 185 |
+
return (
|
| 186 |
+
f"Ran protocol {protocol}. Result: {result}. Reward: {reward:.1f}. "
|
| 187 |
+
f"Best so far: {info.get('best_result', 'none')}.",
|
| 188 |
+
reward,
|
| 189 |
+
)
|
| 190 |
+
return "Unknown tool.", None
|
| 191 |
+
|
| 192 |
+
def _order_reagents_if_low(self, env: LabEnv, obs: Any, info: dict) -> tuple[Any, dict, float]:
|
| 193 |
+
"""Order reagents when inventory is low; return (obs, info, total_reward)."""
|
| 194 |
+
spec = env.spec
|
| 195 |
+
inv = info.get("inventory", {})
|
| 196 |
+
budget = info.get("remaining_budget", 0)
|
| 197 |
+
order_start = spec.action_order_start()
|
| 198 |
+
order_actions = list(range(order_start, spec.action_order_end()))
|
| 199 |
+
total_rew = 0.0
|
| 200 |
+
for idx, item in enumerate(spec.orderable_items):
|
| 201 |
+
if inv.get(item, 0) < 2:
|
| 202 |
+
cost = spec.order_costs.get(item, (0, float("inf")))[1]
|
| 203 |
+
if budget >= cost:
|
| 204 |
+
action = order_start + idx
|
| 205 |
+
obs, rew, term, trunc, info = env.step(action)
|
| 206 |
+
total_rew += rew
|
| 207 |
+
inv = info.get("inventory", inv)
|
| 208 |
+
budget = info.get("remaining_budget", budget)
|
| 209 |
+
if term or trunc:
|
| 210 |
+
break
|
| 211 |
+
return env._obs(), env._info(), total_rew
|
| 212 |
+
|
| 213 |
+
def run_episode(
|
| 214 |
+
self,
|
| 215 |
+
env: LabEnv,
|
| 216 |
+
seed: int,
|
| 217 |
+
*,
|
| 218 |
+
verbose: bool = False,
|
| 219 |
+
) -> dict[str, Any]:
|
| 220 |
+
"""
|
| 221 |
+
Run one episode: each trial = research (optional) → generate protocol → run → record feedback.
|
| 222 |
+
Feedback history is passed into the next trial so the agent learns.
|
| 223 |
+
"""
|
| 224 |
+
if env.spec.evaluate_custom_protocol is None:
|
| 225 |
+
raise ValueError(
|
| 226 |
+
"This environment spec does not support custom protocols. "
|
| 227 |
+
"Use a spec with evaluate_custom_protocol (e.g. PCR, ELISA)."
|
| 228 |
+
)
|
| 229 |
+
obs, info = env.reset(seed=seed)
|
| 230 |
+
spec = env.spec
|
| 231 |
+
tools = build_tool_schemas(spec)
|
| 232 |
+
self.feedback_history = []
|
| 233 |
+
total_reward = 0.0
|
| 234 |
+
steps = 0
|
| 235 |
+
|
| 236 |
+
for trial in range(self.max_trials):
|
| 237 |
+
if info.get("best_result") == "success":
|
| 238 |
+
obs, rew, _, _, info = env.step(spec.action_finish())
|
| 239 |
+
total_reward += rew
|
| 240 |
+
steps += 1
|
| 241 |
+
break
|
| 242 |
+
|
| 243 |
+
# Order reagents if needed
|
| 244 |
+
obs, info, order_rew = self._order_reagents_if_low(env, obs, info)
|
| 245 |
+
total_reward += order_rew
|
| 246 |
+
if getattr(env, "_terminated", False) or getattr(env, "_truncated", False):
|
| 247 |
+
break
|
| 248 |
+
|
| 249 |
+
# Build prompt with feedback from previous runs (learning)
|
| 250 |
+
feedback_text = ""
|
| 251 |
+
if self.feedback_history:
|
| 252 |
+
feedback_text = "Previous runs this episode (learn from these):\n"
|
| 253 |
+
for i, entry in enumerate(self.feedback_history[-6:], 1):
|
| 254 |
+
feedback_text += (
|
| 255 |
+
f" {i}. Protocol: {entry['protocol']} → "
|
| 256 |
+
f"Result: {entry['result']}, Reward: {entry['reward']:.1f}\n"
|
| 257 |
+
)
|
| 258 |
+
feedback_text += "\n"
|
| 259 |
+
|
| 260 |
+
param_desc = json.dumps(spec.protocol_param_schema, indent=2) if spec.protocol_param_schema else "protocol params (key-value dict)"
|
| 261 |
+
|
| 262 |
+
system_msg = (
|
| 263 |
+
f"You are a lab scientist running a {spec.name} experiment. "
|
| 264 |
+
"You have tools: web_search (research literature), run_experiment (run one assay with your chosen protocol). "
|
| 265 |
+
"Generate a protocol using the parameter schema below. You can use ANY valid values (not just presets). "
|
| 266 |
+
"Use research and feedback from previous runs to improve. Output exactly one run_experiment call per turn.\n\n"
|
| 267 |
+
f"Parameter schema for run_experiment:\n{param_desc}"
|
| 268 |
+
)
|
| 269 |
+
user_msg = (
|
| 270 |
+
f"{feedback_text}"
|
| 271 |
+
f"Current state: best_result={info.get('best_result')}, last_result={info.get('last_result')}, "
|
| 272 |
+
f"inventory={info.get('inventory')}, remaining_budget=${info.get('remaining_budget', 0):.0f}. "
|
| 273 |
+
f"Trial {trial + 1}/{self.max_trials}. Call run_experiment with your next protocol (one call only)."
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
messages = [
|
| 277 |
+
{"role": "system", "content": system_msg},
|
| 278 |
+
{"role": "user", "content": user_msg},
|
| 279 |
+
]
|
| 280 |
+
response = self.client.chat.completions.create(
|
| 281 |
+
model=self.model,
|
| 282 |
+
messages=messages,
|
| 283 |
+
tools=tools,
|
| 284 |
+
tool_choice="required",
|
| 285 |
+
)
|
| 286 |
+
choice = response.choices[0]
|
| 287 |
+
if not choice.message.tool_calls:
|
| 288 |
+
break
|
| 289 |
+
tc = choice.message.tool_calls[0]
|
| 290 |
+
name = tc.function.name
|
| 291 |
+
args = json.loads(tc.function.arguments or "{}")
|
| 292 |
+
result_str, run_reward = self._run_tool(name, args, env)
|
| 293 |
+
if run_reward is not None:
|
| 294 |
+
total_reward += run_reward
|
| 295 |
+
steps += 1
|
| 296 |
+
protocol = dict(args)
|
| 297 |
+
if "protocol" in protocol and isinstance(protocol["protocol"], dict):
|
| 298 |
+
protocol = protocol["protocol"]
|
| 299 |
+
self.feedback_history.append({
|
| 300 |
+
"protocol": protocol,
|
| 301 |
+
"result": env._info().get("last_result", "fail"),
|
| 302 |
+
"reward": run_reward,
|
| 303 |
+
})
|
| 304 |
+
if verbose:
|
| 305 |
+
print(f" Trial {trial + 1}: {protocol} → {self.feedback_history[-1]['result']} (reward {run_reward:.1f})")
|
| 306 |
+
|
| 307 |
+
if getattr(env, "_terminated", False) or getattr(env, "_truncated", False):
|
| 308 |
+
break
|
| 309 |
+
|
| 310 |
+
if not (getattr(env, "_terminated", False) or getattr(env, "_truncated", False)) and info.get("best_result") != "success":
|
| 311 |
+
obs, rew, _, _, info = env.step(spec.action_finish())
|
| 312 |
+
total_reward += rew
|
| 313 |
+
steps += 1
|
| 314 |
+
|
| 315 |
+
return {
|
| 316 |
+
"reward": total_reward,
|
| 317 |
+
"success": info.get("best_result") == "success",
|
| 318 |
+
"partial": info.get("best_result") == "partial",
|
| 319 |
+
"minutes": info.get("elapsed_minutes", 0.0),
|
| 320 |
+
"cost": spec.initial_budget - info.get("remaining_budget", spec.initial_budget),
|
| 321 |
+
"steps": steps,
|
| 322 |
+
"num_protocols_tried": len(self.feedback_history),
|
| 323 |
+
}
|
agents/research_llm_agent.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Research LLM agent for LabEnv.
|
| 3 |
+
|
| 4 |
+
ReAct-style agent that: researches protocols (web_search), hypothesizes params,
|
| 5 |
+
runs experiments in LabEnv (run_experiment), and learns from results (analyze + update_knowledge).
|
| 6 |
+
Uses the same LabEnv action space; continuous hypotheses are mapped to nearest preset.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import json
|
| 12 |
+
import os
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Any
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
from openai import OpenAI
|
| 18 |
+
except ImportError:
|
| 19 |
+
OpenAI = None # type: ignore[misc, assignment]
|
| 20 |
+
|
| 21 |
+
from lab_env.env import (
|
| 22 |
+
ACTION_FINISH,
|
| 23 |
+
ACTION_ORDER_BUFFER,
|
| 24 |
+
ACTION_ORDER_POLYMERASE,
|
| 25 |
+
ACTION_ORDER_TIPS,
|
| 26 |
+
ACTION_RUN_ASSAY,
|
| 27 |
+
ACTION_SETUP_START,
|
| 28 |
+
INITIAL_BUDGET,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# ---------------------------------------------------------------------------
|
| 33 |
+
# Knowledge base (web_search tool)
|
| 34 |
+
# ---------------------------------------------------------------------------
|
| 35 |
+
|
| 36 |
+
def _load_protocols_knowledge() -> list[dict[str, Any]]:
|
| 37 |
+
path = Path(__file__).resolve().parent.parent / "knowledge" / "pcr_protocols.json"
|
| 38 |
+
if not path.exists():
|
| 39 |
+
return []
|
| 40 |
+
with open(path, encoding="utf-8") as f:
|
| 41 |
+
return json.load(f)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _web_search_impl(query: str, top_k: int = 3) -> str:
|
| 45 |
+
"""Search fake literature; return top_k relevant snippets."""
|
| 46 |
+
papers = _load_protocols_knowledge()
|
| 47 |
+
query_lower = query.lower()
|
| 48 |
+
scored = []
|
| 49 |
+
for p in papers:
|
| 50 |
+
text = f"{p.get('title','')} {p.get('abstract','')} {p.get('keywords','')} {p.get('recommendations','')}".lower()
|
| 51 |
+
score = sum(1 for w in query_lower.split() if len(w) > 2 and w in text)
|
| 52 |
+
if score > 0:
|
| 53 |
+
scored.append((score, p))
|
| 54 |
+
scored.sort(key=lambda x: -x[0])
|
| 55 |
+
if not scored:
|
| 56 |
+
return "No relevant literature found. Try general terms: annealing temperature, cycles, PCR protocol."
|
| 57 |
+
out = []
|
| 58 |
+
for _, p in scored[:top_k]:
|
| 59 |
+
out.append(f"[{p.get('title','')}] {p.get('recommendations','')}")
|
| 60 |
+
return "\n".join(out)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# ---------------------------------------------------------------------------
|
| 64 |
+
# Preset mapping: continuous (temp, cycles, ratio) -> nearest preset index
|
| 65 |
+
# ---------------------------------------------------------------------------
|
| 66 |
+
|
| 67 |
+
def _params_to_preset_index(
|
| 68 |
+
presets: list[dict[str, Any]],
|
| 69 |
+
temp: float,
|
| 70 |
+
cycles: int,
|
| 71 |
+
ratio: str,
|
| 72 |
+
) -> int:
|
| 73 |
+
"""Map (temp, cycles, ratio) to nearest preset index."""
|
| 74 |
+
ratio_clean = ratio.strip().lower()
|
| 75 |
+
if "conservative" in ratio_clean or ratio_clean == "conservative":
|
| 76 |
+
ratio_clean = "conservative"
|
| 77 |
+
else:
|
| 78 |
+
ratio_clean = "aggressive"
|
| 79 |
+
|
| 80 |
+
best_idx = 0
|
| 81 |
+
best_dist = float("inf")
|
| 82 |
+
for i, p in enumerate(presets):
|
| 83 |
+
dt = abs(float(p["temp"]) - temp)
|
| 84 |
+
dc = abs(int(p["cycles"]) - cycles)
|
| 85 |
+
dr = 0 if str(p.get("ratio", "")).lower() == ratio_clean else 10
|
| 86 |
+
dist = dt + dc * 0.1 + dr
|
| 87 |
+
if dist < best_dist:
|
| 88 |
+
best_dist = dist
|
| 89 |
+
best_idx = i
|
| 90 |
+
return best_idx
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# ---------------------------------------------------------------------------
|
| 94 |
+
# Tool schemas (OpenEnv-style JSON for LLM)
|
| 95 |
+
# ---------------------------------------------------------------------------
|
| 96 |
+
|
| 97 |
+
TOOL_SCHEMAS = [
|
| 98 |
+
{
|
| 99 |
+
"type": "function",
|
| 100 |
+
"function": {
|
| 101 |
+
"name": "web_search",
|
| 102 |
+
"description": "Search scientific literature for PCR protocols, annealing temperature, cycle number, or reagent ratio recommendations.",
|
| 103 |
+
"parameters": {
|
| 104 |
+
"type": "object",
|
| 105 |
+
"properties": {
|
| 106 |
+
"query": {
|
| 107 |
+
"type": "string",
|
| 108 |
+
"description": "Search query, e.g. 'optimal annealing temp for AT-rich primers'",
|
| 109 |
+
},
|
| 110 |
+
},
|
| 111 |
+
"required": ["query"],
|
| 112 |
+
},
|
| 113 |
+
},
|
| 114 |
+
},
|
| 115 |
+
{
|
| 116 |
+
"type": "function",
|
| 117 |
+
"function": {
|
| 118 |
+
"name": "run_experiment",
|
| 119 |
+
"description": "Run a PCR experiment with the given parameters. Temperature will be mapped to nearest preset (55, 65, or 72°C); cycles to 25 or 35; ratio to conservative or aggressive.",
|
| 120 |
+
"parameters": {
|
| 121 |
+
"type": "object",
|
| 122 |
+
"properties": {
|
| 123 |
+
"temp": {"type": "number", "description": "Annealing temperature in °C (e.g. 57.5)"},
|
| 124 |
+
"cycles": {"type": "integer", "description": "Number of PCR cycles (e.g. 32)"},
|
| 125 |
+
"ratio": {"type": "string", "description": "Reagent ratio: 'conservative' or 'aggressive'"},
|
| 126 |
+
},
|
| 127 |
+
"required": ["temp", "cycles", "ratio"],
|
| 128 |
+
},
|
| 129 |
+
},
|
| 130 |
+
},
|
| 131 |
+
{
|
| 132 |
+
"type": "function",
|
| 133 |
+
"function": {
|
| 134 |
+
"name": "analyze_result",
|
| 135 |
+
"description": "Compare the current result to previous experiments and summarize what we learned.",
|
| 136 |
+
"parameters": {
|
| 137 |
+
"type": "object",
|
| 138 |
+
"properties": {
|
| 139 |
+
"current_result": {"type": "string", "description": "Last assay result: success, partial, or fail"},
|
| 140 |
+
"summary": {"type": "string", "description": "Brief analysis: what this result suggests for next parameters"},
|
| 141 |
+
},
|
| 142 |
+
"required": ["current_result", "summary"],
|
| 143 |
+
},
|
| 144 |
+
},
|
| 145 |
+
},
|
| 146 |
+
{
|
| 147 |
+
"type": "function",
|
| 148 |
+
"function": {
|
| 149 |
+
"name": "update_knowledge",
|
| 150 |
+
"description": "Record what we learned about optimal parameters (temperature range, cycle range, or notes).",
|
| 151 |
+
"parameters": {
|
| 152 |
+
"type": "object",
|
| 153 |
+
"properties": {
|
| 154 |
+
"temp_range": {
|
| 155 |
+
"type": "array",
|
| 156 |
+
"items": {"type": "number"},
|
| 157 |
+
"description": "Optional [low, high] °C range for optimal annealing",
|
| 158 |
+
},
|
| 159 |
+
"cycle_range": {
|
| 160 |
+
"type": "array",
|
| 161 |
+
"items": {"type": "integer"},
|
| 162 |
+
"description": "Optional [low, high] cycle range",
|
| 163 |
+
},
|
| 164 |
+
"notes": {"type": "string", "description": "Optional text note about what we learned"},
|
| 165 |
+
},
|
| 166 |
+
},
|
| 167 |
+
},
|
| 168 |
+
},
|
| 169 |
+
]
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
# ---------------------------------------------------------------------------
|
| 173 |
+
# Research LLM Agent
|
| 174 |
+
# ---------------------------------------------------------------------------
|
| 175 |
+
|
| 176 |
+
class ResearchLLMAgent:
|
| 177 |
+
"""LLM agent that researches, hypothesizes, runs experiments, and learns."""
|
| 178 |
+
|
| 179 |
+
def __init__(
|
| 180 |
+
self,
|
| 181 |
+
model: str = "gpt-4o-mini",
|
| 182 |
+
max_trials: int = 5,
|
| 183 |
+
knowledge_path: str | None = None,
|
| 184 |
+
) -> None:
|
| 185 |
+
if OpenAI is None:
|
| 186 |
+
raise ImportError("Optional dependency 'openai' is required. Install with: pip install openai")
|
| 187 |
+
self.model = model
|
| 188 |
+
self.max_trials = max_trials
|
| 189 |
+
self._client: OpenAI | None = None
|
| 190 |
+
self.knowledge_path = knowledge_path
|
| 191 |
+
|
| 192 |
+
self.knowledge: dict[str, Any] = {
|
| 193 |
+
"temp_range": [50.0, 70.0],
|
| 194 |
+
"cycle_range": [20, 40],
|
| 195 |
+
"past_experiments": [],
|
| 196 |
+
"notes": "",
|
| 197 |
+
}
|
| 198 |
+
self._episode_buffer: list[dict[str, Any]] = [] # last 5 episodes for context
|
| 199 |
+
self._last_research: str = ""
|
| 200 |
+
self._last_hypothesis: dict[str, Any] = {}
|
| 201 |
+
self._last_result: str = ""
|
| 202 |
+
self._last_params: dict[str, Any] = {}
|
| 203 |
+
self._last_run_reward: float = 0.0
|
| 204 |
+
|
| 205 |
+
@property
|
| 206 |
+
def client(self) -> OpenAI:
|
| 207 |
+
if self._client is None:
|
| 208 |
+
api_key = os.environ.get("OPENAI_API_KEY")
|
| 209 |
+
if not api_key:
|
| 210 |
+
raise RuntimeError("OPENAI_API_KEY environment variable is required for ResearchLLMAgent")
|
| 211 |
+
self._client = OpenAI(api_key=api_key)
|
| 212 |
+
return self._client
|
| 213 |
+
|
| 214 |
+
def _run_tool(self, name: str, arguments: dict[str, Any], env: Any) -> str:
|
| 215 |
+
"""Execute one tool and return result string."""
|
| 216 |
+
if name == "web_search":
|
| 217 |
+
q = arguments.get("query", "")
|
| 218 |
+
result = _web_search_impl(q)
|
| 219 |
+
self._last_research = result
|
| 220 |
+
return result
|
| 221 |
+
|
| 222 |
+
if name == "run_experiment":
|
| 223 |
+
temp = float(arguments.get("temp", 60))
|
| 224 |
+
cycles = int(arguments.get("cycles", 30))
|
| 225 |
+
ratio = str(arguments.get("ratio", "conservative"))
|
| 226 |
+
presets = env.spec.presets
|
| 227 |
+
idx = _params_to_preset_index(presets, temp, cycles, ratio)
|
| 228 |
+
preset = presets[idx]
|
| 229 |
+
self._last_params = {"temp": preset["temp"], "cycles": preset["cycles"], "ratio": preset["ratio"]}
|
| 230 |
+
self._last_hypothesis = {"temp": temp, "cycles": cycles, "ratio": ratio}
|
| 231 |
+
|
| 232 |
+
setup_action = ACTION_SETUP_START + idx
|
| 233 |
+
obs, r1, term, trunc, info = env.step(setup_action)
|
| 234 |
+
if term or trunc:
|
| 235 |
+
self._last_run_reward = r1
|
| 236 |
+
return f"Environment ended after setup. Result: {info.get('last_result', 'none')}"
|
| 237 |
+
obs, r2, term, trunc, info = env.step(ACTION_RUN_ASSAY)
|
| 238 |
+
result = info.get("last_result", "fail")
|
| 239 |
+
self._last_result = result
|
| 240 |
+
self._last_run_reward = r1 + r2
|
| 241 |
+
return f"Ran preset {preset}. Result: {result}. Reward: {r1 + r2:.1f}"
|
| 242 |
+
|
| 243 |
+
if name == "analyze_result":
|
| 244 |
+
current = arguments.get("current_result", "")
|
| 245 |
+
summary = arguments.get("summary", "")
|
| 246 |
+
return f"Analysis: {summary}"
|
| 247 |
+
|
| 248 |
+
if name == "update_knowledge":
|
| 249 |
+
if "temp_range" in arguments:
|
| 250 |
+
self.knowledge["temp_range"] = arguments["temp_range"]
|
| 251 |
+
if "cycle_range" in arguments:
|
| 252 |
+
self.knowledge["cycle_range"] = arguments["cycle_range"]
|
| 253 |
+
if "notes" in arguments:
|
| 254 |
+
self.knowledge["notes"] = arguments["notes"]
|
| 255 |
+
return "Knowledge updated."
|
| 256 |
+
|
| 257 |
+
return "Unknown tool."
|
| 258 |
+
|
| 259 |
+
def _inventory_low(self, obs: Any) -> bool:
|
| 260 |
+
return float(min(obs[3], obs[4], obs[5], obs[6])) < 0.08
|
| 261 |
+
|
| 262 |
+
def _order_reagents(self, env: Any, obs: Any, info: dict, steps: int) -> tuple[Any, float, dict, int]:
|
| 263 |
+
total_rew = 0.0
|
| 264 |
+
for action in (ACTION_ORDER_TIPS, ACTION_ORDER_BUFFER, ACTION_ORDER_POLYMERASE):
|
| 265 |
+
obs, rew, term, trunc, info = env.step(action)
|
| 266 |
+
total_rew += rew
|
| 267 |
+
steps += 1
|
| 268 |
+
if term or trunc:
|
| 269 |
+
break
|
| 270 |
+
return obs, total_rew, info, steps
|
| 271 |
+
|
| 272 |
+
def run_episode(
|
| 273 |
+
self,
|
| 274 |
+
env: Any,
|
| 275 |
+
seed: int,
|
| 276 |
+
*,
|
| 277 |
+
verbose: bool = False,
|
| 278 |
+
episode_callback: list[dict[str, Any]] | None = None,
|
| 279 |
+
) -> dict[str, Any]:
|
| 280 |
+
"""Run one episode: Research -> Hypothesize -> Execute -> Learn per trial."""
|
| 281 |
+
obs, info = env.reset(seed=seed)
|
| 282 |
+
total_reward = 0.0
|
| 283 |
+
steps = 0
|
| 284 |
+
presets = env.spec.presets
|
| 285 |
+
last_params_used: dict[str, Any] = {}
|
| 286 |
+
|
| 287 |
+
for trial in range(self.max_trials):
|
| 288 |
+
if info.get("best_result") == "success":
|
| 289 |
+
obs, rew, _, _, info = env.step(ACTION_FINISH)
|
| 290 |
+
total_reward += rew
|
| 291 |
+
steps += 1
|
| 292 |
+
break
|
| 293 |
+
|
| 294 |
+
if self._inventory_low(obs):
|
| 295 |
+
obs, rew, info, steps = self._order_reagents(env, obs, info, steps)
|
| 296 |
+
total_reward += rew
|
| 297 |
+
if getattr(env, "_terminated", False) or getattr(env, "_truncated", False):
|
| 298 |
+
break
|
| 299 |
+
|
| 300 |
+
# Stage 1: Research
|
| 301 |
+
research_query = (
|
| 302 |
+
"optimal annealing temperature and cycle number for PCR"
|
| 303 |
+
if trial == 0
|
| 304 |
+
else f"PCR protocol improvement: last result was {info.get('last_result','none')} with params {last_params_used}"
|
| 305 |
+
)
|
| 306 |
+
research_text = _web_search_impl(research_query)
|
| 307 |
+
self._last_research = research_text
|
| 308 |
+
|
| 309 |
+
# Stage 2: Hypothesize (LLM chooses temp, cycles, ratio)
|
| 310 |
+
state_desc = (
|
| 311 |
+
f"Last result: {info.get('last_result','none')}. "
|
| 312 |
+
f"Best so far: {info.get('best_result','none')}. "
|
| 313 |
+
f"Inventory: {info.get('inventory',{})}. "
|
| 314 |
+
f"Budget: ${info.get('remaining_budget',0):.0f}. "
|
| 315 |
+
f"Knowledge: temp_range={self.knowledge['temp_range']}, cycle_range={self.knowledge['cycle_range']}. "
|
| 316 |
+
f"Past experiments this episode: {self.knowledge['past_experiments'][-5:]}."
|
| 317 |
+
)
|
| 318 |
+
if last_params_used:
|
| 319 |
+
state_desc += f" Last params used: {last_params_used}."
|
| 320 |
+
|
| 321 |
+
sys_msg = (
|
| 322 |
+
"You are a lab scientist optimizing a PCR protocol. You have access to tools: "
|
| 323 |
+
"web_search (already done: use the research below), run_experiment (use this to try one protocol). "
|
| 324 |
+
"Output exactly one run_experiment call with temp (number °C), cycles (integer), ratio ('conservative' or 'aggressive'). "
|
| 325 |
+
"Use the research and past results to pick the best next parameters. "
|
| 326 |
+
"Available presets in the lab are temp in [55, 65, 72], cycles in [25, 35], ratio conservative or aggressive; "
|
| 327 |
+
"your values will be mapped to the nearest preset."
|
| 328 |
+
)
|
| 329 |
+
user_msg = (
|
| 330 |
+
f"Research:\n{research_text}\n\n"
|
| 331 |
+
f"Current state: {state_desc}\n\n"
|
| 332 |
+
"Call run_experiment with your chosen temp, cycles, and ratio (one call only)."
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
messages = [
|
| 336 |
+
{"role": "system", "content": sys_msg},
|
| 337 |
+
{"role": "user", "content": user_msg},
|
| 338 |
+
]
|
| 339 |
+
response = self.client.chat.completions.create(
|
| 340 |
+
model=self.model,
|
| 341 |
+
messages=messages,
|
| 342 |
+
tools=TOOL_SCHEMAS,
|
| 343 |
+
tool_choice={"type": "function", "function": {"name": "run_experiment"}},
|
| 344 |
+
)
|
| 345 |
+
choice = response.choices[0]
|
| 346 |
+
if choice.message.tool_calls:
|
| 347 |
+
tc = choice.message.tool_calls[0]
|
| 348 |
+
name = tc.function.name
|
| 349 |
+
args = json.loads(tc.function.arguments or "{}")
|
| 350 |
+
result_str = self._run_tool(name, args, env)
|
| 351 |
+
total_reward += getattr(self, "_last_run_reward", 0.0)
|
| 352 |
+
last_params_used = dict(self._last_params)
|
| 353 |
+
steps += 2 # setup + run_assay
|
| 354 |
+
obs = env._obs()
|
| 355 |
+
info = env._info()
|
| 356 |
+
|
| 357 |
+
if verbose:
|
| 358 |
+
print(f" Trial {trial+1}: hypothesis {self._last_hypothesis} -> preset {self._last_params} -> {self._last_result}")
|
| 359 |
+
|
| 360 |
+
# Stage 4: Learn (update knowledge from this result)
|
| 361 |
+
self.knowledge["past_experiments"].append(
|
| 362 |
+
(dict(self._last_params), self._last_result, 1.0 if self._last_result == "success" else (0.5 if self._last_result == "partial" else 0.0))
|
| 363 |
+
)
|
| 364 |
+
if len(self.knowledge["past_experiments"]) > 20:
|
| 365 |
+
self.knowledge["past_experiments"] = self.knowledge["past_experiments"][-20:]
|
| 366 |
+
|
| 367 |
+
# Narrow knowledge range by heuristic
|
| 368 |
+
if self._last_result == "success":
|
| 369 |
+
t = self._last_params.get("temp", 60)
|
| 370 |
+
self.knowledge["temp_range"] = [t - 2, t + 2]
|
| 371 |
+
c = self._last_params.get("cycles", 30)
|
| 372 |
+
self.knowledge["cycle_range"] = [max(20, c - 2), min(40, c + 2)]
|
| 373 |
+
elif self._last_result == "partial":
|
| 374 |
+
t = self._last_params.get("temp", 60)
|
| 375 |
+
self.knowledge["temp_range"] = [
|
| 376 |
+
min(self.knowledge["temp_range"][0], t - 1),
|
| 377 |
+
max(self.knowledge["temp_range"][1], t + 1),
|
| 378 |
+
]
|
| 379 |
+
|
| 380 |
+
if episode_callback is not None:
|
| 381 |
+
episode_callback.append({
|
| 382 |
+
"trial": trial + 1,
|
| 383 |
+
"research": self._last_research,
|
| 384 |
+
"hypothesis": self._last_hypothesis,
|
| 385 |
+
"params_used": self._last_params,
|
| 386 |
+
"result": self._last_result,
|
| 387 |
+
})
|
| 388 |
+
else:
|
| 389 |
+
break
|
| 390 |
+
|
| 391 |
+
if getattr(env, "_terminated", False) or getattr(env, "_truncated", False):
|
| 392 |
+
break
|
| 393 |
+
|
| 394 |
+
if not (getattr(env, "_terminated", False) or getattr(env, "_truncated", False)) and info.get("best_result") != "success":
|
| 395 |
+
obs, rew, _, _, info = env.step(ACTION_FINISH)
|
| 396 |
+
total_reward += rew
|
| 397 |
+
steps += 1
|
| 398 |
+
|
| 399 |
+
return {
|
| 400 |
+
"reward": total_reward,
|
| 401 |
+
"success": info.get("best_result") == "success",
|
| 402 |
+
"partial": info.get("best_result") == "partial",
|
| 403 |
+
"minutes": info.get("elapsed_minutes", 0.0),
|
| 404 |
+
"cost": INITIAL_BUDGET - info.get("remaining_budget", 500.0),
|
| 405 |
+
"steps": steps,
|
| 406 |
+
}
|
agents/rl_agent.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
REINFORCE policy-gradient agent for LabEnv.
|
| 3 |
+
|
| 4 |
+
Rather than trying to learn the full 18-action sequential policy (which
|
| 5 |
+
requires discovering the setup->run->finish sequence from scratch), this
|
| 6 |
+
agent decomposes the problem:
|
| 7 |
+
|
| 8 |
+
- **Learned part** — a small MLP policy that maps the current observation to
|
| 9 |
+
a distribution over the 12 protocol presets. Trained with REINFORCE.
|
| 10 |
+
- **Scripted part** — episode logic that executes setup(preset) -> run_assay,
|
| 11 |
+
checks inventory, orders reagents when needed, and finishes after a
|
| 12 |
+
configurable number of trials or when a success is achieved.
|
| 13 |
+
|
| 14 |
+
This decomposition makes training tractable in minutes on a CPU while still
|
| 15 |
+
demonstrating clear improvement over the random-preset naive baseline.
|
| 16 |
+
|
| 17 |
+
The policy network is pure PyTorch — directly compatible with Hugging Face TRL,
|
| 18 |
+
Lightning AI, or any custom training loop.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
import torch.nn as nn
|
| 25 |
+
import torch.nn.functional as F
|
| 26 |
+
from torch.distributions import Categorical
|
| 27 |
+
|
| 28 |
+
import numpy as np
|
| 29 |
+
|
| 30 |
+
from lab_env.env import (
|
| 31 |
+
ACTION_FINISH,
|
| 32 |
+
ACTION_ORDER_BUFFER,
|
| 33 |
+
ACTION_ORDER_POLYMERASE,
|
| 34 |
+
ACTION_ORDER_TIPS,
|
| 35 |
+
ACTION_RUN_ASSAY,
|
| 36 |
+
ACTION_SETUP_START,
|
| 37 |
+
NUM_PRESETS,
|
| 38 |
+
OBS_DIM,
|
| 39 |
+
)
|
| 40 |
+
from lab_env.spec import ExperimentSpec
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class PolicyNetwork(nn.Module):
|
| 44 |
+
"""Two-hidden-layer MLP: obs -> preset logits (12 outputs)."""
|
| 45 |
+
|
| 46 |
+
def __init__(
|
| 47 |
+
self,
|
| 48 |
+
obs_dim: int = OBS_DIM,
|
| 49 |
+
hidden: int = 64,
|
| 50 |
+
n_presets: int = NUM_PRESETS,
|
| 51 |
+
) -> None:
|
| 52 |
+
super().__init__()
|
| 53 |
+
self.fc1 = nn.Linear(obs_dim, hidden)
|
| 54 |
+
self.fc2 = nn.Linear(hidden, hidden)
|
| 55 |
+
self.fc3 = nn.Linear(hidden, n_presets)
|
| 56 |
+
|
| 57 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 58 |
+
x = F.relu(self.fc1(x))
|
| 59 |
+
x = F.relu(self.fc2(x))
|
| 60 |
+
return self.fc3(x)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class ReinforceAgent:
|
| 64 |
+
"""REINFORCE agent that learns which preset to pick each trial.
|
| 65 |
+
|
| 66 |
+
The episode loop (setup -> run -> order-if-needed -> maybe-finish) is
|
| 67 |
+
scripted. Only the preset selection is learned.
|
| 68 |
+
|
| 69 |
+
Pass spec=... to train on a different protocol set (e.g. ELISA or a
|
| 70 |
+
custom spec). Otherwise uses default PCR (12 presets, 14-dim obs).
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
def __init__(
|
| 74 |
+
self,
|
| 75 |
+
lr: float = 3e-3,
|
| 76 |
+
gamma: float = 0.99,
|
| 77 |
+
entropy_coef: float = 0.02,
|
| 78 |
+
max_trials: int = 4,
|
| 79 |
+
device: str = "cpu",
|
| 80 |
+
spec: ExperimentSpec | None = None,
|
| 81 |
+
) -> None:
|
| 82 |
+
self.gamma = gamma
|
| 83 |
+
self.entropy_coef = entropy_coef
|
| 84 |
+
self.max_trials = max_trials
|
| 85 |
+
self.device = torch.device(device)
|
| 86 |
+
self.spec = spec
|
| 87 |
+
|
| 88 |
+
obs_dim = (spec.obs_dim if spec else OBS_DIM)
|
| 89 |
+
n_presets = (spec.num_presets if spec else NUM_PRESETS)
|
| 90 |
+
self.policy = PolicyNetwork(obs_dim=obs_dim, n_presets=n_presets).to(self.device)
|
| 91 |
+
self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)
|
| 92 |
+
|
| 93 |
+
self._log_probs: list[torch.Tensor] = []
|
| 94 |
+
self._entropies: list[torch.Tensor] = []
|
| 95 |
+
self._rewards: list[float] = []
|
| 96 |
+
|
| 97 |
+
self._baseline: float = 0.0
|
| 98 |
+
self._baseline_count: int = 0
|
| 99 |
+
|
| 100 |
+
# ------------------------------------------------------------------
|
| 101 |
+
# Preset selection (the learned part)
|
| 102 |
+
# ------------------------------------------------------------------
|
| 103 |
+
|
| 104 |
+
def _select_preset(
|
| 105 |
+
self, obs: np.ndarray, *, deterministic: bool = False
|
| 106 |
+
) -> int:
|
| 107 |
+
obs_t = torch.as_tensor(
|
| 108 |
+
obs, dtype=torch.float32, device=self.device
|
| 109 |
+
).unsqueeze(0)
|
| 110 |
+
logits = self.policy(obs_t)
|
| 111 |
+
|
| 112 |
+
if deterministic:
|
| 113 |
+
logits = logits * 5.0
|
| 114 |
+
|
| 115 |
+
dist = Categorical(logits=logits)
|
| 116 |
+
action = dist.sample()
|
| 117 |
+
|
| 118 |
+
self._log_probs.append(dist.log_prob(action))
|
| 119 |
+
self._entropies.append(dist.entropy())
|
| 120 |
+
return int(action.item())
|
| 121 |
+
|
| 122 |
+
# ------------------------------------------------------------------
|
| 123 |
+
# Full episode runner (scripted loop + learned preset choice)
|
| 124 |
+
# ------------------------------------------------------------------
|
| 125 |
+
|
| 126 |
+
def run_episode(
|
| 127 |
+
self,
|
| 128 |
+
env: object,
|
| 129 |
+
seed: int,
|
| 130 |
+
*,
|
| 131 |
+
train: bool = True,
|
| 132 |
+
) -> dict[str, object]:
|
| 133 |
+
"""Run a complete episode, returning metrics dict.
|
| 134 |
+
|
| 135 |
+
*env* must be a :class:`LabEnv` (or anything with the same
|
| 136 |
+
``reset`` / ``step`` interface).
|
| 137 |
+
"""
|
| 138 |
+
obs, info = env.reset(seed=seed) # type: ignore[union-attr]
|
| 139 |
+
total_reward = 0.0
|
| 140 |
+
steps = 0
|
| 141 |
+
trial_rewards: list[float] = []
|
| 142 |
+
finish_action = getattr(env.spec, "action_finish", lambda: ACTION_FINISH)()
|
| 143 |
+
|
| 144 |
+
for trial in range(self.max_trials):
|
| 145 |
+
if self._episode_done(info):
|
| 146 |
+
break
|
| 147 |
+
|
| 148 |
+
if self._inventory_low(obs, env):
|
| 149 |
+
obs, rew, info, steps = self._order_reagents(env, obs, info, steps)
|
| 150 |
+
total_reward += rew
|
| 151 |
+
if self._episode_done(info):
|
| 152 |
+
break
|
| 153 |
+
|
| 154 |
+
preset = self._select_preset(obs, deterministic=not train)
|
| 155 |
+
|
| 156 |
+
setup_start = getattr(env.spec, "action_setup_start", lambda: ACTION_SETUP_START)()
|
| 157 |
+
obs, rew_setup, term, trunc, info = env.step(setup_start + preset) # type: ignore[union-attr]
|
| 158 |
+
total_reward += rew_setup
|
| 159 |
+
steps += 1
|
| 160 |
+
if term or trunc:
|
| 161 |
+
trial_rewards.append(rew_setup)
|
| 162 |
+
self._rewards.append(rew_setup)
|
| 163 |
+
break
|
| 164 |
+
|
| 165 |
+
run_assay = getattr(env.spec, "action_run_assay", lambda: ACTION_RUN_ASSAY)()
|
| 166 |
+
obs, rew_run, done, truncated, info = env.step(run_assay) # type: ignore[union-attr]
|
| 167 |
+
total_reward += rew_run
|
| 168 |
+
steps += 1
|
| 169 |
+
trial_rewards.append(rew_run)
|
| 170 |
+
self._rewards.append(rew_run)
|
| 171 |
+
|
| 172 |
+
if done or truncated:
|
| 173 |
+
break
|
| 174 |
+
|
| 175 |
+
if info.get("best_result") == "success":
|
| 176 |
+
obs, rew_finish, _, _, info = env.step(finish_action) # type: ignore[union-attr]
|
| 177 |
+
total_reward += rew_finish
|
| 178 |
+
steps += 1
|
| 179 |
+
self._rewards[-1] += rew_finish
|
| 180 |
+
break
|
| 181 |
+
|
| 182 |
+
else:
|
| 183 |
+
if not self._episode_done(info):
|
| 184 |
+
obs, rew_finish, _, _, info = env.step(finish_action) # type: ignore[union-attr]
|
| 185 |
+
total_reward += rew_finish
|
| 186 |
+
steps += 1
|
| 187 |
+
if self._rewards:
|
| 188 |
+
self._rewards[-1] += rew_finish
|
| 189 |
+
|
| 190 |
+
loss = self.update() if train else 0.0
|
| 191 |
+
|
| 192 |
+
return {
|
| 193 |
+
"reward": total_reward,
|
| 194 |
+
"success": info.get("best_result") == "success",
|
| 195 |
+
"partial": info.get("best_result") == "partial",
|
| 196 |
+
"minutes": info.get("elapsed_minutes", 0.0),
|
| 197 |
+
"cost": 500.0 - info.get("remaining_budget", 500.0),
|
| 198 |
+
"steps": steps,
|
| 199 |
+
"loss": loss,
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
# ------------------------------------------------------------------
|
| 203 |
+
# Helpers for the scripted loop
|
| 204 |
+
# ------------------------------------------------------------------
|
| 205 |
+
|
| 206 |
+
@staticmethod
|
| 207 |
+
def _episode_done(info: dict) -> bool:
|
| 208 |
+
return False
|
| 209 |
+
|
| 210 |
+
@staticmethod
|
| 211 |
+
def _inventory_low(obs: np.ndarray, env: object | None = None) -> bool:
|
| 212 |
+
n_inv = 4
|
| 213 |
+
if env is not None and getattr(env, "spec", None) is not None:
|
| 214 |
+
n_inv = len(env.spec.inventory_items)
|
| 215 |
+
inv_slice = obs[3 : 3 + n_inv]
|
| 216 |
+
return float(min(inv_slice)) < 0.08 if len(inv_slice) else False
|
| 217 |
+
|
| 218 |
+
@staticmethod
|
| 219 |
+
def _order_reagents(
|
| 220 |
+
env: object, obs: np.ndarray, info: dict, steps: int
|
| 221 |
+
) -> tuple[np.ndarray, float, dict, int]:
|
| 222 |
+
total_rew = 0.0
|
| 223 |
+
spec = getattr(env, "spec", None)
|
| 224 |
+
if spec is not None:
|
| 225 |
+
order_actions = range(spec.action_order_start(), spec.action_order_end())
|
| 226 |
+
else:
|
| 227 |
+
order_actions = (ACTION_ORDER_TIPS, ACTION_ORDER_BUFFER, ACTION_ORDER_POLYMERASE)
|
| 228 |
+
for action in order_actions:
|
| 229 |
+
obs, rew, done, truncated, info = env.step(action) # type: ignore[union-attr]
|
| 230 |
+
total_rew += rew
|
| 231 |
+
steps += 1
|
| 232 |
+
if done or truncated:
|
| 233 |
+
break
|
| 234 |
+
return obs, total_rew, info, steps
|
| 235 |
+
|
| 236 |
+
# ------------------------------------------------------------------
|
| 237 |
+
# Learning
|
| 238 |
+
# ------------------------------------------------------------------
|
| 239 |
+
|
| 240 |
+
def reset(self) -> None:
|
| 241 |
+
self._log_probs.clear()
|
| 242 |
+
self._entropies.clear()
|
| 243 |
+
self._rewards.clear()
|
| 244 |
+
|
| 245 |
+
def update(self) -> float:
|
| 246 |
+
"""REINFORCE update over the collected episode. Returns loss."""
|
| 247 |
+
if not self._rewards or not self._log_probs:
|
| 248 |
+
return 0.0
|
| 249 |
+
|
| 250 |
+
n = min(len(self._rewards), len(self._log_probs))
|
| 251 |
+
rewards = self._rewards[:n]
|
| 252 |
+
log_probs = self._log_probs[:n]
|
| 253 |
+
entropies = self._entropies[:n]
|
| 254 |
+
|
| 255 |
+
returns = self._compute_returns(rewards)
|
| 256 |
+
self._update_baseline(returns)
|
| 257 |
+
|
| 258 |
+
returns_t = torch.as_tensor(returns, dtype=torch.float32, device=self.device)
|
| 259 |
+
advantages = returns_t - self._baseline
|
| 260 |
+
|
| 261 |
+
policy_loss = torch.zeros(1, device=self.device)
|
| 262 |
+
entropy_bonus = torch.zeros(1, device=self.device)
|
| 263 |
+
|
| 264 |
+
for lp, adv, ent in zip(log_probs, advantages, entropies):
|
| 265 |
+
policy_loss -= lp * adv.detach()
|
| 266 |
+
entropy_bonus += ent
|
| 267 |
+
|
| 268 |
+
loss = policy_loss - self.entropy_coef * entropy_bonus
|
| 269 |
+
|
| 270 |
+
self.optimizer.zero_grad()
|
| 271 |
+
loss.backward()
|
| 272 |
+
nn.utils.clip_grad_norm_(self.policy.parameters(), max_norm=1.0)
|
| 273 |
+
self.optimizer.step()
|
| 274 |
+
|
| 275 |
+
self.reset()
|
| 276 |
+
return float(loss.item())
|
| 277 |
+
|
| 278 |
+
def _compute_returns(self, rewards: list[float]) -> list[float]:
|
| 279 |
+
returns: list[float] = []
|
| 280 |
+
g = 0.0
|
| 281 |
+
for r in reversed(rewards):
|
| 282 |
+
g = r + self.gamma * g
|
| 283 |
+
returns.insert(0, g)
|
| 284 |
+
return returns
|
| 285 |
+
|
| 286 |
+
def _update_baseline(self, returns: list[float]) -> None:
|
| 287 |
+
episode_return = returns[0] if returns else 0.0
|
| 288 |
+
self._baseline_count += 1
|
| 289 |
+
self._baseline += (episode_return - self._baseline) / self._baseline_count
|
| 290 |
+
|
| 291 |
+
def save(self, path: str) -> None:
|
| 292 |
+
torch.save(
|
| 293 |
+
{
|
| 294 |
+
"policy_state_dict": self.policy.state_dict(),
|
| 295 |
+
"optimizer_state_dict": self.optimizer.state_dict(),
|
| 296 |
+
"baseline": self._baseline,
|
| 297 |
+
"baseline_count": self._baseline_count,
|
| 298 |
+
},
|
| 299 |
+
path,
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
def load(self, path: str) -> None:
|
| 303 |
+
checkpoint = torch.load(path, map_location=self.device, weights_only=True)
|
| 304 |
+
self.policy.load_state_dict(checkpoint["policy_state_dict"])
|
| 305 |
+
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
| 306 |
+
self._baseline = checkpoint["baseline"]
|
| 307 |
+
self._baseline_count = checkpoint["baseline_count"]
|
demo/streamlit_app.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Streamlit demo: Self-Improving Lab Scientist — Research flow and 3-agent comparison.
|
| 3 |
+
Run with: streamlit run demo/streamlit_app.py
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import sys
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
| 12 |
+
|
| 13 |
+
import streamlit as st
|
| 14 |
+
|
| 15 |
+
from lab_env.env import LabEnv, INITIAL_BUDGET
|
| 16 |
+
from agents.naive_agent import NaiveAgent
|
| 17 |
+
from agents.rl_agent import ReinforceAgent
|
| 18 |
+
from agents.research_llm_agent import ResearchLLMAgent
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
st.set_page_config(page_title="SimLab Research Agent", layout="wide")
|
| 22 |
+
|
| 23 |
+
st.title("Self-Improving Lab Scientist")
|
| 24 |
+
st.markdown("**Research Scientist Agent** — Research → Hypothesize → Experiment → Learn")
|
| 25 |
+
|
| 26 |
+
# Session state
|
| 27 |
+
if "episode_history" not in st.session_state:
|
| 28 |
+
st.session_state.episode_history = []
|
| 29 |
+
if "comparison_table" not in st.session_state:
|
| 30 |
+
st.session_state.comparison_table = None
|
| 31 |
+
if "current_run_steps" not in st.session_state:
|
| 32 |
+
st.session_state.current_run_steps = []
|
| 33 |
+
if "last_knowledge" not in st.session_state:
|
| 34 |
+
st.session_state.last_knowledge = None
|
| 35 |
+
|
| 36 |
+
# Sidebar controls
|
| 37 |
+
with st.sidebar:
|
| 38 |
+
st.header("Controls")
|
| 39 |
+
num_episodes = st.slider("Episodes to run", 1, 10, 3)
|
| 40 |
+
seed = st.number_input("Seed", value=42, min_value=0, step=1)
|
| 41 |
+
max_trials = st.slider("Max trials per episode", 2, 8, 5)
|
| 42 |
+
run_research = st.button("Run research agent episodes")
|
| 43 |
+
st.divider()
|
| 44 |
+
st.header("Benchmark")
|
| 45 |
+
compare_episodes = st.slider("Eval episodes for comparison", 10, 100, 30)
|
| 46 |
+
run_compare = st.button("Run 3-agent comparison")
|
| 47 |
+
|
| 48 |
+
# Main: Research flow
|
| 49 |
+
st.header("Research flow")
|
| 50 |
+
|
| 51 |
+
if run_research:
|
| 52 |
+
try:
|
| 53 |
+
env = LabEnv()
|
| 54 |
+
agent = ResearchLLMAgent(max_trials=max_trials)
|
| 55 |
+
progress = st.progress(0, text="Running episodes...")
|
| 56 |
+
for ep in range(1, num_episodes + 1):
|
| 57 |
+
progress.progress(ep / num_episodes, text=f"Episode {ep}/{num_episodes}...")
|
| 58 |
+
callback: list[dict] = []
|
| 59 |
+
result = agent.run_episode(env, seed=seed + ep, episode_callback=callback)
|
| 60 |
+
st.session_state.last_knowledge = dict(agent.knowledge)
|
| 61 |
+
st.session_state.episode_history.append({
|
| 62 |
+
"episode": ep,
|
| 63 |
+
"success": result["success"],
|
| 64 |
+
"partial": result["partial"],
|
| 65 |
+
"reward": result["reward"],
|
| 66 |
+
"cost": result["cost"],
|
| 67 |
+
"steps": result["steps"],
|
| 68 |
+
"callback": callback,
|
| 69 |
+
})
|
| 70 |
+
env.close()
|
| 71 |
+
progress.empty()
|
| 72 |
+
except Exception as e:
|
| 73 |
+
st.error(f"Research agent failed: {e}. Set OPENAI_API_KEY for LLM agent.")
|
| 74 |
+
else:
|
| 75 |
+
if not st.session_state.episode_history:
|
| 76 |
+
st.info("Click **Run research agent episodes** in the sidebar to start.")
|
| 77 |
+
|
| 78 |
+
# Show last run / history
|
| 79 |
+
if st.session_state.episode_history:
|
| 80 |
+
st.subheader("Learning progress")
|
| 81 |
+
cols = st.columns(min(10, len(st.session_state.episode_history)))
|
| 82 |
+
for i, rec in enumerate(st.session_state.episode_history[-10:]):
|
| 83 |
+
with cols[i % len(cols)]:
|
| 84 |
+
label = "SUCCESS" if rec["success"] else "partial" if rec["partial"] else "fail"
|
| 85 |
+
pct = "94%" if rec["success"] else "73%" if rec["partial"] else "12%"
|
| 86 |
+
st.metric(f"Ep{rec['episode']}", label, pct)
|
| 87 |
+
st.divider()
|
| 88 |
+
|
| 89 |
+
# Show latest episode detail (stage cards)
|
| 90 |
+
latest = st.session_state.episode_history[-1]
|
| 91 |
+
st.subheader(f"Episode {latest['episode']} — Research → Experiment → Learn")
|
| 92 |
+
if latest.get("callback"):
|
| 93 |
+
for step in latest["callback"]:
|
| 94 |
+
with st.expander(f"Trial {step['trial']}: {step['result']}", expanded=(step['trial'] == latest['callback'][-1]['trial'])):
|
| 95 |
+
st.markdown("**Research**")
|
| 96 |
+
st.caption(step.get("research", "")[:400] + "..." if len(step.get("research", "")) > 400 else step.get("research", ""))
|
| 97 |
+
st.markdown("**Hypothesis**")
|
| 98 |
+
st.code(step.get("hypothesis", {}))
|
| 99 |
+
st.markdown("**Experiment**")
|
| 100 |
+
st.write(f"Ran preset: {step.get('params_used', {})} → **{step.get('result', '')}**")
|
| 101 |
+
|
| 102 |
+
st.markdown("**Knowledge**")
|
| 103 |
+
if st.session_state.last_knowledge:
|
| 104 |
+
k = st.session_state.last_knowledge
|
| 105 |
+
st.write(f"temp_range = {k.get('temp_range', [])} °C, cycle_range = {k.get('cycle_range', [])}")
|
| 106 |
+
past = k.get("past_experiments", [])[-5:]
|
| 107 |
+
if past:
|
| 108 |
+
st.caption("Last experiments: " + ", ".join(f"{p[0]}→{p[1]}" for p in past))
|
| 109 |
+
else:
|
| 110 |
+
st.caption("Run research episodes to see updated knowledge.")
|
| 111 |
+
|
| 112 |
+
# 3-agent comparison
|
| 113 |
+
st.header("3-agent comparison")
|
| 114 |
+
if run_compare:
|
| 115 |
+
with st.status("Running Naive, RL, and Research LLM agents...", expanded=True) as status:
|
| 116 |
+
try:
|
| 117 |
+
env = LabEnv()
|
| 118 |
+
eval_seed_base = 100_000 + seed
|
| 119 |
+
|
| 120 |
+
st.write("Naive agent...")
|
| 121 |
+
naive_agent = NaiveAgent(num_trials=3, seed=seed)
|
| 122 |
+
naive_results = []
|
| 123 |
+
for i in range(compare_episodes):
|
| 124 |
+
obs, info = env.reset(seed=eval_seed_base + i)
|
| 125 |
+
naive_agent.reset()
|
| 126 |
+
steps = 0
|
| 127 |
+
while True:
|
| 128 |
+
action = naive_agent.select_action(obs)
|
| 129 |
+
obs, reward, term, trunc, info = env.step(action)
|
| 130 |
+
steps += 1
|
| 131 |
+
if term or trunc:
|
| 132 |
+
break
|
| 133 |
+
naive_results.append({
|
| 134 |
+
"success": info["best_result"] == "success",
|
| 135 |
+
"partial": info["best_result"] == "partial",
|
| 136 |
+
"cost": INITIAL_BUDGET - info["remaining_budget"],
|
| 137 |
+
"steps": steps,
|
| 138 |
+
})
|
| 139 |
+
|
| 140 |
+
st.write("Training and evaluating RL agent...")
|
| 141 |
+
rl_agent = ReinforceAgent(max_trials=max_trials)
|
| 142 |
+
for ep in range(500):
|
| 143 |
+
rl_agent.run_episode(env, seed=seed + ep, train=True)
|
| 144 |
+
rl_results = [rl_agent.run_episode(env, seed=eval_seed_base + i, train=False) for i in range(compare_episodes)]
|
| 145 |
+
|
| 146 |
+
st.write("Research LLM agent...")
|
| 147 |
+
llm_agent = ResearchLLMAgent(max_trials=max_trials)
|
| 148 |
+
llm_results = [llm_agent.run_episode(env, seed=eval_seed_base + i) for i in range(compare_episodes)]
|
| 149 |
+
|
| 150 |
+
env.close()
|
| 151 |
+
|
| 152 |
+
def agg(results: list[dict]) -> dict:
|
| 153 |
+
n = len(results)
|
| 154 |
+
succ = sum(r["success"] for r in results) / n
|
| 155 |
+
steps_succ = [r["steps"] for r in results if r["success"]]
|
| 156 |
+
exp_to_succ = sum(steps_succ) / len(steps_succ) if steps_succ else 0
|
| 157 |
+
cost = sum(r["cost"] for r in results) / n
|
| 158 |
+
return {"success_rate": succ, "experiments_to_success": exp_to_succ, "cost": cost}
|
| 159 |
+
|
| 160 |
+
st.session_state.comparison_table = {
|
| 161 |
+
"Naive (random)": agg(naive_results),
|
| 162 |
+
"RL (MLP)": agg(rl_results),
|
| 163 |
+
"LLM Researcher": agg(llm_results),
|
| 164 |
+
}
|
| 165 |
+
status.update(label="Done!", state="complete")
|
| 166 |
+
except Exception as e:
|
| 167 |
+
status.update(label="Error", state="error")
|
| 168 |
+
st.exception(e)
|
| 169 |
+
|
| 170 |
+
if st.session_state.comparison_table:
|
| 171 |
+
st.dataframe(
|
| 172 |
+
[
|
| 173 |
+
{
|
| 174 |
+
"Agent": name,
|
| 175 |
+
"Success rate": f"{data['success_rate']:.0%}",
|
| 176 |
+
"Experiments to success": f"{data['experiments_to_success']:.1f}",
|
| 177 |
+
"Cost/episode": f"${data['cost']:.1f}",
|
| 178 |
+
}
|
| 179 |
+
for name, data in st.session_state.comparison_table.items()
|
| 180 |
+
],
|
| 181 |
+
use_container_width=True,
|
| 182 |
+
hide_index=True,
|
| 183 |
+
)
|
| 184 |
+
else:
|
| 185 |
+
st.caption("Click **Run 3-agent comparison** to benchmark Naive, RL, and LLM agents.")
|
knowledge/pcr_protocols.json
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"title": "Optimization of Annealing Temperature for AT-Rich Primer Pairs",
|
| 4 |
+
"abstract": "We systematically evaluated annealing temperatures between 50 and 62°C for primers with high AT content. Best amplification and specificity were achieved in the 55–58°C range, with 57°C yielding optimal balance for most templates.",
|
| 5 |
+
"keywords": ["annealing", "AT-rich", "temperature", "primers", "PCR"],
|
| 6 |
+
"recommendations": "AT-rich primers: use annealing 55–58°C. Avoid >60°C to prevent poor yield."
|
| 7 |
+
},
|
| 8 |
+
{
|
| 9 |
+
"title": "Cycle Number and Fidelity in Standard PCR",
|
| 10 |
+
"abstract": "Cycle counts from 25 to 40 were compared for amplicon yield and error rate. For most targets, 30–35 cycles gave high yield without excessive nonspecific product. Conservative protocols favor 28–32 cycles.",
|
| 11 |
+
"keywords": ["cycles", "fidelity", "yield", "PCR", "amplification"],
|
| 12 |
+
"recommendations": "High-fidelity applications: 30–35 cycles. Use 25–28 for long amplicons."
|
| 13 |
+
},
|
| 14 |
+
{
|
| 15 |
+
"title": "Reagent Ratios and Primer Dimer Formation",
|
| 16 |
+
"abstract": "Conservative versus aggressive primer-to-template ratios were tested across 200 reactions. Conservative ratio reduced primer dimers and improved specificity in 78% of cases; aggressive ratio increased yield when template was limiting.",
|
| 17 |
+
"keywords": ["ratio", "conservative", "aggressive", "primer dimer", "specificity"],
|
| 18 |
+
"recommendations": "Conservative ratio for long amplicons and low template; aggressive when maximizing yield."
|
| 19 |
+
},
|
| 20 |
+
{
|
| 21 |
+
"title": "Annealing Temperature Gradients for Multiplex PCR",
|
| 22 |
+
"abstract": "Gradient optimization showed that 65°C annealing worked well for GC-rich primers, while 55–58°C suited AT-rich primers. Middle range 60–62°C was a compromise with lower peak efficiency.",
|
| 23 |
+
"keywords": ["gradient", "GC-rich", "AT-rich", "multiplex", "annealing"],
|
| 24 |
+
"recommendations": "GC-rich primers: try 63–67°C. AT-rich: 55–59°C. Avoid 60–62°C as default."
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"title": "Extension Temperature and Cycle Count Interaction",
|
| 28 |
+
"abstract": "Extension at 72°C with 25 versus 35 cycles was compared. Fewer cycles (25–30) reduced background; 32–35 cycles improved sensitivity for low-copy targets. Combined with lower annealing (56–58°C), 32 cycles was optimal in our assay.",
|
| 29 |
+
"keywords": ["extension", "cycles", "sensitivity", "background", "72C"],
|
| 30 |
+
"recommendations": "For difficult templates: annealing 56–58°C, 32 cycles. Reduce to 28–30 if nonspecific bands appear."
|
| 31 |
+
}
|
| 32 |
+
]
|
lab_env/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from lab_env.env import LabEnv
|
| 2 |
+
from lab_env.spec import ExperimentSpec, pcr_experiment_spec, elisa_experiment_spec, get_spec_for_workflow
|
| 3 |
+
|
| 4 |
+
__all__ = ["LabEnv", "ExperimentSpec", "pcr_experiment_spec", "elisa_experiment_spec", "get_spec_for_workflow"]
|
lab_env/env.py
ADDED
|
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LabEnv — A Gymnasium-style simulated wet-lab environment for RL training.
|
| 3 |
+
|
| 4 |
+
Simulates a single experiment workflow (e.g. PCR, ELISA) where the agent must
|
| 5 |
+
discover a hidden optimal protocol under time and budget constraints. The
|
| 6 |
+
experiment type is defined by an ExperimentSpec so any protocol-discovery
|
| 7 |
+
experiment can be modelled.
|
| 8 |
+
|
| 9 |
+
Designed for compatibility with OpenEnv's sandboxed execution model:
|
| 10 |
+
the reset/step/close interface can be served over HTTP via the adapter in
|
| 11 |
+
``openenv_adapter.py`` and uploaded to the OpenEnv hub on Hugging Face as a
|
| 12 |
+
standardized agentic environment for lab-automation research.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
from typing import Any
|
| 18 |
+
|
| 19 |
+
import gymnasium as gym
|
| 20 |
+
import numpy as np
|
| 21 |
+
from gymnasium import spaces
|
| 22 |
+
|
| 23 |
+
from lab_env.spec import ExperimentSpec, pcr_experiment_spec
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# ---------------------------------------------------------------------------
|
| 27 |
+
# Backward compatibility: expose constants for default (PCR) spec
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
|
| 30 |
+
_DEFAULT_SPEC = pcr_experiment_spec()
|
| 31 |
+
NUM_PRESETS: int = _DEFAULT_SPEC.num_presets
|
| 32 |
+
ACTION_SETUP_START: int = _DEFAULT_SPEC.action_setup_start()
|
| 33 |
+
ACTION_SETUP_END: int = _DEFAULT_SPEC.action_setup_end()
|
| 34 |
+
ACTION_RUN_ASSAY: int = _DEFAULT_SPEC.action_run_assay()
|
| 35 |
+
ACTION_ORDER_TIPS: int = _DEFAULT_SPEC.action_order_start() + 0
|
| 36 |
+
ACTION_ORDER_BUFFER: int = _DEFAULT_SPEC.action_order_start() + 1
|
| 37 |
+
ACTION_ORDER_POLYMERASE: int = _DEFAULT_SPEC.action_order_start() + 2
|
| 38 |
+
ACTION_WAIT: int = _DEFAULT_SPEC.action_wait()
|
| 39 |
+
ACTION_FINISH: int = _DEFAULT_SPEC.action_finish()
|
| 40 |
+
NUM_ACTIONS: int = _DEFAULT_SPEC.num_actions
|
| 41 |
+
OBS_DIM: int = _DEFAULT_SPEC.obs_dim
|
| 42 |
+
|
| 43 |
+
# Legacy constants used by scripts
|
| 44 |
+
INITIAL_BUDGET: float = _DEFAULT_SPEC.initial_budget
|
| 45 |
+
RESULT_LABELS = _DEFAULT_SPEC.result_labels
|
| 46 |
+
RESULT_TO_IDX = {label: i for i, label in enumerate(RESULT_LABELS)}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class LabEnv(gym.Env):
|
| 50 |
+
"""Simulated wet-lab environment for any experiment type.
|
| 51 |
+
|
| 52 |
+
The experiment (protocol presets, inventory, rewards, outcome model) is
|
| 53 |
+
defined by an ExperimentSpec. Use LabEnv() for default PCR; use
|
| 54 |
+
LabEnv(spec=my_spec) for custom experiments.
|
| 55 |
+
|
| 56 |
+
Observation (Box, shape from spec):
|
| 57 |
+
[0] step_index (normalised)
|
| 58 |
+
[1] elapsed_minutes (normalised)
|
| 59 |
+
[2] remaining_budget (normalised)
|
| 60 |
+
[3..] inventory (one slot per inventory_items, normalised)
|
| 61 |
+
[...] last_result one-hot (len(result_labels))
|
| 62 |
+
[...] has_setup, current_preset_idx (norm), best_result_score
|
| 63 |
+
|
| 64 |
+
Actions (Discrete, from spec):
|
| 65 |
+
0 .. num_presets-1 setup_reaction(preset_index)
|
| 66 |
+
num_presets run_assay
|
| 67 |
+
num_presets+1 .. order_reagents(item) for each orderable item
|
| 68 |
+
... wait, finish
|
| 69 |
+
"""
|
| 70 |
+
|
| 71 |
+
metadata = {"render_modes": []}
|
| 72 |
+
|
| 73 |
+
def __init__(
|
| 74 |
+
self,
|
| 75 |
+
spec: ExperimentSpec | None = None,
|
| 76 |
+
render_mode: str | None = None,
|
| 77 |
+
) -> None:
|
| 78 |
+
super().__init__()
|
| 79 |
+
self.spec = spec if spec is not None else pcr_experiment_spec()
|
| 80 |
+
self.observation_space = spaces.Box(
|
| 81 |
+
low=0.0, high=1.0, shape=(self.spec.obs_dim,), dtype=np.float32
|
| 82 |
+
)
|
| 83 |
+
self.action_space = spaces.Discrete(self.spec.num_actions)
|
| 84 |
+
|
| 85 |
+
self._rng: np.random.Generator | None = None
|
| 86 |
+
self._current_protocol_override: dict[str, Any] | None = None
|
| 87 |
+
self._reset_state()
|
| 88 |
+
|
| 89 |
+
# ------------------------------------------------------------------
|
| 90 |
+
# Gymnasium API
|
| 91 |
+
# ------------------------------------------------------------------
|
| 92 |
+
|
| 93 |
+
def reset(
|
| 94 |
+
self,
|
| 95 |
+
*,
|
| 96 |
+
seed: int | None = None,
|
| 97 |
+
options: dict[str, Any] | None = None,
|
| 98 |
+
) -> tuple[np.ndarray, dict[str, Any]]:
|
| 99 |
+
super().reset(seed=seed)
|
| 100 |
+
self._rng = np.random.default_rng(seed)
|
| 101 |
+
self._reset_state()
|
| 102 |
+
self._sample_hidden_optimum()
|
| 103 |
+
return self._obs(), self._info()
|
| 104 |
+
|
| 105 |
+
def step(
|
| 106 |
+
self, action: int
|
| 107 |
+
) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]:
|
| 108 |
+
if self._terminated or self._truncated:
|
| 109 |
+
raise RuntimeError("Episode already done — call reset().")
|
| 110 |
+
|
| 111 |
+
reward = 0.0
|
| 112 |
+
self._step_index += 1
|
| 113 |
+
|
| 114 |
+
if self.spec.action_setup_start() <= action < self.spec.action_setup_end():
|
| 115 |
+
reward += self._do_setup(action)
|
| 116 |
+
elif action == self.spec.action_run_assay():
|
| 117 |
+
reward += self._do_run_assay()
|
| 118 |
+
elif self.spec.action_order_start() <= action < self.spec.action_order_end():
|
| 119 |
+
reward += self._do_order(action)
|
| 120 |
+
elif action == self.spec.action_wait():
|
| 121 |
+
reward += self._do_wait()
|
| 122 |
+
elif action == self.spec.action_finish():
|
| 123 |
+
reward += self._do_finish()
|
| 124 |
+
else:
|
| 125 |
+
raise ValueError(f"Invalid action {action}")
|
| 126 |
+
|
| 127 |
+
self._check_forced_termination()
|
| 128 |
+
|
| 129 |
+
if self._terminated or self._truncated:
|
| 130 |
+
reward += self._terminal_reward()
|
| 131 |
+
|
| 132 |
+
return self._obs(), reward, self._terminated, self._truncated, self._info()
|
| 133 |
+
|
| 134 |
+
def run_assay_with_protocol(
|
| 135 |
+
self, protocol: dict[str, Any]
|
| 136 |
+
) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]:
|
| 137 |
+
"""Run one assay with an arbitrary protocol dict (no preset).
|
| 138 |
+
|
| 139 |
+
The spec must have evaluate_custom_protocol set (e.g. PCR/ELISA). Consumes
|
| 140 |
+
inventory and time like a normal assay; outcome is from the spec's outcome
|
| 141 |
+
model. Use this for agent-generated protocols.
|
| 142 |
+
"""
|
| 143 |
+
if self._terminated or self._truncated:
|
| 144 |
+
raise RuntimeError("Episode already done — call reset().")
|
| 145 |
+
if self.spec.evaluate_custom_protocol is None:
|
| 146 |
+
raise ValueError(
|
| 147 |
+
"This spec does not support custom protocols; evaluate_custom_protocol is not set."
|
| 148 |
+
)
|
| 149 |
+
self._step_index += 1
|
| 150 |
+
self._current_protocol_override = dict(protocol)
|
| 151 |
+
self._has_setup = True
|
| 152 |
+
reward = self._do_run_assay()
|
| 153 |
+
self._check_forced_termination()
|
| 154 |
+
if self._terminated or self._truncated:
|
| 155 |
+
reward += self._terminal_reward()
|
| 156 |
+
return self._obs(), reward, self._terminated, self._truncated, self._info()
|
| 157 |
+
|
| 158 |
+
def close(self) -> None:
|
| 159 |
+
pass
|
| 160 |
+
|
| 161 |
+
# ------------------------------------------------------------------
|
| 162 |
+
# Action implementations
|
| 163 |
+
# ------------------------------------------------------------------
|
| 164 |
+
|
| 165 |
+
def _do_setup(self, action: int) -> float:
|
| 166 |
+
preset_idx = action - self.spec.action_setup_start()
|
| 167 |
+
self._current_preset_idx = preset_idx
|
| 168 |
+
self._has_setup = True
|
| 169 |
+
self._elapsed_minutes += 1.0
|
| 170 |
+
return 0.0
|
| 171 |
+
|
| 172 |
+
def _fail_result_label(self) -> str:
|
| 173 |
+
if "fail" in self.spec.result_labels:
|
| 174 |
+
return "fail"
|
| 175 |
+
return self.spec.result_labels[-1] if self.spec.result_labels else "fail"
|
| 176 |
+
|
| 177 |
+
def _do_run_assay(self) -> float:
|
| 178 |
+
if not self._has_setup:
|
| 179 |
+
self._last_result = self._fail_result_label()
|
| 180 |
+
self._elapsed_minutes += self.spec.assay_time_minutes
|
| 181 |
+
return self.spec.assay_penalty
|
| 182 |
+
|
| 183 |
+
inv = self._inventory
|
| 184 |
+
for item in self.spec.inventory_items:
|
| 185 |
+
if inv.get(item, 0) < 1:
|
| 186 |
+
self._last_result = self._fail_result_label()
|
| 187 |
+
return self.spec.assay_penalty
|
| 188 |
+
|
| 189 |
+
for item in self.spec.inventory_items:
|
| 190 |
+
inv[item] = inv.get(item, 0) - 1
|
| 191 |
+
inv[item] = max(0, inv[item])
|
| 192 |
+
|
| 193 |
+
self._elapsed_minutes += self.spec.assay_time_minutes
|
| 194 |
+
|
| 195 |
+
result = self._sample_assay_result()
|
| 196 |
+
self._last_result = result
|
| 197 |
+
self._update_best(result)
|
| 198 |
+
|
| 199 |
+
imm = self.spec.immediate_result_reward.get(result, 0.0)
|
| 200 |
+
return self.spec.assay_penalty + imm
|
| 201 |
+
|
| 202 |
+
def _do_order(self, action: int) -> float:
|
| 203 |
+
idx = action - self.spec.action_order_start()
|
| 204 |
+
if idx < 0 or idx >= len(self.spec.orderable_items):
|
| 205 |
+
return 0.0
|
| 206 |
+
item = self.spec.orderable_items[idx]
|
| 207 |
+
if item not in self.spec.order_costs:
|
| 208 |
+
return 0.0
|
| 209 |
+
qty, cost = self.spec.order_costs[item]
|
| 210 |
+
|
| 211 |
+
if self._remaining_budget < cost:
|
| 212 |
+
return 0.0
|
| 213 |
+
|
| 214 |
+
self._remaining_budget -= cost
|
| 215 |
+
self._inventory[item] = min(
|
| 216 |
+
self._inventory.get(item, 0) + qty, self.spec.max_inventory
|
| 217 |
+
)
|
| 218 |
+
self._elapsed_minutes += self.spec.order_time_minutes
|
| 219 |
+
return 0.0
|
| 220 |
+
|
| 221 |
+
def _do_wait(self) -> float:
|
| 222 |
+
self._elapsed_minutes += self.spec.wait_minutes
|
| 223 |
+
return 0.0
|
| 224 |
+
|
| 225 |
+
def _do_finish(self) -> float:
|
| 226 |
+
self._terminated = True
|
| 227 |
+
return 0.0
|
| 228 |
+
|
| 229 |
+
# ------------------------------------------------------------------
|
| 230 |
+
# Termination
|
| 231 |
+
# ------------------------------------------------------------------
|
| 232 |
+
|
| 233 |
+
def _check_forced_termination(self) -> None:
|
| 234 |
+
if self._terminated:
|
| 235 |
+
return
|
| 236 |
+
if self._elapsed_minutes >= self.spec.max_minutes:
|
| 237 |
+
self._truncated = True
|
| 238 |
+
return
|
| 239 |
+
if self._remaining_budget <= 0:
|
| 240 |
+
self._truncated = True
|
| 241 |
+
return
|
| 242 |
+
if self._step_index >= self.spec.max_steps:
|
| 243 |
+
self._truncated = True
|
| 244 |
+
return
|
| 245 |
+
|
| 246 |
+
inv = self._inventory
|
| 247 |
+
can_run = all(inv.get(item, 0) >= 1 for item in self.spec.inventory_items)
|
| 248 |
+
can_order = any(
|
| 249 |
+
self._remaining_budget >= self.spec.order_costs.get(k, (0, float("inf")))[1]
|
| 250 |
+
for k in self.spec.orderable_items
|
| 251 |
+
)
|
| 252 |
+
if not can_run and not can_order:
|
| 253 |
+
self._truncated = True
|
| 254 |
+
|
| 255 |
+
def _terminal_reward(self) -> float:
|
| 256 |
+
bonus = self.spec.terminal_bonus.get(self._best_result, 0.0)
|
| 257 |
+
time_penalty = self.spec.time_penalty_per_min * self._elapsed_minutes
|
| 258 |
+
no_success = (
|
| 259 |
+
self.spec.no_success_penalty
|
| 260 |
+
if self._best_result in ("none", "fail") or self._best_result not in self.spec.terminal_bonus
|
| 261 |
+
else 0.0
|
| 262 |
+
)
|
| 263 |
+
return bonus + time_penalty + no_success
|
| 264 |
+
|
| 265 |
+
# ------------------------------------------------------------------
|
| 266 |
+
# Outcome model (delegated to spec)
|
| 267 |
+
# ------------------------------------------------------------------
|
| 268 |
+
|
| 269 |
+
def _sample_hidden_optimum(self) -> None:
|
| 270 |
+
if self._rng is None:
|
| 271 |
+
return
|
| 272 |
+
if self.spec.sample_hidden_optimum is not None:
|
| 273 |
+
self._hidden_optimum = self.spec.sample_hidden_optimum(self._rng)
|
| 274 |
+
else:
|
| 275 |
+
self._hidden_optimum = {}
|
| 276 |
+
|
| 277 |
+
def _sample_assay_result(self) -> str:
|
| 278 |
+
if self._rng is None:
|
| 279 |
+
return self.spec.result_labels[1] if len(self.spec.result_labels) > 1 else "fail"
|
| 280 |
+
if self._current_protocol_override is not None and self.spec.evaluate_custom_protocol is not None:
|
| 281 |
+
result = self.spec.evaluate_custom_protocol(
|
| 282 |
+
self._hidden_optimum,
|
| 283 |
+
self._current_protocol_override,
|
| 284 |
+
self._rng,
|
| 285 |
+
)
|
| 286 |
+
self._current_protocol_override = None
|
| 287 |
+
return result
|
| 288 |
+
if self.spec.sample_assay_result is not None:
|
| 289 |
+
return self.spec.sample_assay_result(
|
| 290 |
+
self._hidden_optimum,
|
| 291 |
+
self._current_preset_idx,
|
| 292 |
+
self.spec.presets,
|
| 293 |
+
self._rng,
|
| 294 |
+
)
|
| 295 |
+
# Default: random non-none result
|
| 296 |
+
choices = [r for r in self.spec.result_labels if r != "none"]
|
| 297 |
+
if not choices:
|
| 298 |
+
return "fail"
|
| 299 |
+
return str(self._rng.choice(choices))
|
| 300 |
+
|
| 301 |
+
def _update_best(self, result: str) -> None:
|
| 302 |
+
rank = {"fail": 0, "none": 0, "partial": 1, "success": 2}
|
| 303 |
+
if rank.get(result, 0) > rank.get(self._best_result, 0):
|
| 304 |
+
self._best_result = result
|
| 305 |
+
|
| 306 |
+
# ------------------------------------------------------------------
|
| 307 |
+
# Observation helpers
|
| 308 |
+
# ------------------------------------------------------------------
|
| 309 |
+
|
| 310 |
+
def _result_to_onehot(self, result: str) -> list[float]:
|
| 311 |
+
out = [0.0] * len(self.spec.result_labels)
|
| 312 |
+
for i, label in enumerate(self.spec.result_labels):
|
| 313 |
+
if label == result:
|
| 314 |
+
out[i] = 1.0
|
| 315 |
+
break
|
| 316 |
+
return out
|
| 317 |
+
|
| 318 |
+
def _obs(self) -> np.ndarray:
|
| 319 |
+
inv = self._inventory
|
| 320 |
+
result_onehot = self._result_to_onehot(self._last_result)
|
| 321 |
+
best_score = {"none": 0.0, "fail": 0.0, "partial": 0.5, "success": 1.0}.get(
|
| 322 |
+
self._best_result, 0.0
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
inv_slice = [
|
| 326 |
+
inv.get(item, 0) / self.spec.max_inventory
|
| 327 |
+
for item in self.spec.inventory_items
|
| 328 |
+
]
|
| 329 |
+
obs = np.array(
|
| 330 |
+
[
|
| 331 |
+
self._step_index / self.spec.max_steps,
|
| 332 |
+
self._elapsed_minutes / self.spec.max_minutes,
|
| 333 |
+
self._remaining_budget / self.spec.initial_budget,
|
| 334 |
+
*inv_slice,
|
| 335 |
+
*result_onehot,
|
| 336 |
+
float(self._has_setup),
|
| 337 |
+
(self._current_preset_idx / self.spec.num_presets) if self._has_setup else 0.0,
|
| 338 |
+
best_score,
|
| 339 |
+
],
|
| 340 |
+
dtype=np.float32,
|
| 341 |
+
)
|
| 342 |
+
return obs
|
| 343 |
+
|
| 344 |
+
def _info(self) -> dict[str, Any]:
|
| 345 |
+
return {
|
| 346 |
+
"step_index": self._step_index,
|
| 347 |
+
"elapsed_minutes": self._elapsed_minutes,
|
| 348 |
+
"remaining_budget": self._remaining_budget,
|
| 349 |
+
"inventory": dict(self._inventory),
|
| 350 |
+
"last_result": self._last_result,
|
| 351 |
+
"best_result": self._best_result,
|
| 352 |
+
}
|
| 353 |
+
|
| 354 |
+
# ------------------------------------------------------------------
|
| 355 |
+
# Internal state management
|
| 356 |
+
# ------------------------------------------------------------------
|
| 357 |
+
|
| 358 |
+
def _reset_state(self) -> None:
|
| 359 |
+
self._step_index = 0
|
| 360 |
+
self._elapsed_minutes = 0.0
|
| 361 |
+
self._remaining_budget = self.spec.initial_budget
|
| 362 |
+
self._inventory = dict(self.spec.initial_inventory)
|
| 363 |
+
self._last_result = self.spec.result_labels[0] if self.spec.result_labels else "none"
|
| 364 |
+
self._best_result = self._last_result
|
| 365 |
+
self._has_setup = False
|
| 366 |
+
self._current_preset_idx = 0
|
| 367 |
+
self._terminated = False
|
| 368 |
+
self._truncated = False
|
| 369 |
+
self._hidden_optimum: dict[str, Any] = {}
|
lab_env/openenv_adapter.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OpenEnv adapter for LabEnv.
|
| 3 |
+
|
| 4 |
+
Wraps :class:`LabEnv` in the OpenEnv Environment interface so it can be
|
| 5 |
+
served over HTTP/WebSocket via openenv-core and deployed to the OpenEnv hub
|
| 6 |
+
on Hugging Face.
|
| 7 |
+
|
| 8 |
+
Usage:
|
| 9 |
+
# Run the OpenEnv HTTP server (POST /reset, POST /step, GET /state, WebSocket /ws)
|
| 10 |
+
uvicorn lab_env.openenv_adapter:app --host 0.0.0.0 --port 8000
|
| 11 |
+
|
| 12 |
+
# Or from Python
|
| 13 |
+
from lab_env.openenv_adapter import app
|
| 14 |
+
import uvicorn
|
| 15 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
from typing import Any, Optional
|
| 21 |
+
from uuid import uuid4
|
| 22 |
+
|
| 23 |
+
from openenv.core.env_server import create_app
|
| 24 |
+
from openenv.core.env_server.interfaces import Environment
|
| 25 |
+
from openenv.core.env_server.types import (
|
| 26 |
+
Action,
|
| 27 |
+
EnvironmentMetadata,
|
| 28 |
+
Observation,
|
| 29 |
+
State,
|
| 30 |
+
)
|
| 31 |
+
from pydantic import Field
|
| 32 |
+
|
| 33 |
+
from lab_env.env import LabEnv
|
| 34 |
+
from lab_env.spec import ExperimentSpec, pcr_experiment_spec
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# ---------------------------------------------------------------------------
|
| 38 |
+
# OpenEnv types (Pydantic models for action, observation, state)
|
| 39 |
+
# ---------------------------------------------------------------------------
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class LabAction(Action):
|
| 43 |
+
"""Discrete action index for LabEnv (range depends on experiment spec)."""
|
| 44 |
+
|
| 45 |
+
action: int = Field(..., ge=0, description="Action index (0 to num_actions-1)")
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class LabObservation(Observation):
|
| 49 |
+
"""Observation returned after reset / step. Vector and info live in metadata."""
|
| 50 |
+
|
| 51 |
+
metadata: dict[str, Any] = Field(
|
| 52 |
+
default_factory=dict,
|
| 53 |
+
description="Contains obs_vector, terminated, truncated, info",
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class LabState(State):
|
| 58 |
+
"""Full environment state snapshot for LabEnv."""
|
| 59 |
+
|
| 60 |
+
model_config = {"extra": "allow"}
|
| 61 |
+
|
| 62 |
+
episode_id: Optional[str] = Field(default=None, description="Episode identifier")
|
| 63 |
+
step_count: int = Field(default=0, ge=0, description="Steps taken")
|
| 64 |
+
elapsed_minutes: float = Field(default=0.0, description="Elapsed time (min)")
|
| 65 |
+
remaining_budget: float = Field(default=0.0, description="Remaining budget")
|
| 66 |
+
inventory: dict[str, int] = Field(default_factory=dict, description="Inventory counts")
|
| 67 |
+
last_result: str = Field(default="none", description="Last assay result")
|
| 68 |
+
best_result: str = Field(default="none", description="Best result so far")
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# ---------------------------------------------------------------------------
|
| 72 |
+
# OpenEnv Environment implementation
|
| 73 |
+
# ---------------------------------------------------------------------------
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class LabEnvironment(Environment[LabAction, LabObservation, LabState]):
|
| 77 |
+
"""OpenEnv Environment that wraps a single LabEnv instance.
|
| 78 |
+
|
| 79 |
+
Each session gets its own LabEnv. Compatible with OpenEnv HTTP server
|
| 80 |
+
and WebSocket endpoints.
|
| 81 |
+
"""
|
| 82 |
+
|
| 83 |
+
SUPPORTS_CONCURRENT_SESSIONS = True
|
| 84 |
+
|
| 85 |
+
def __init__(self, spec: Optional[ExperimentSpec] = None) -> None:
|
| 86 |
+
super().__init__()
|
| 87 |
+
self._env = LabEnv(spec=spec)
|
| 88 |
+
self._episode_id: Optional[str] = None
|
| 89 |
+
|
| 90 |
+
def reset(
|
| 91 |
+
self,
|
| 92 |
+
seed: Optional[int] = None,
|
| 93 |
+
episode_id: Optional[str] = None,
|
| 94 |
+
**kwargs: Any,
|
| 95 |
+
) -> LabObservation:
|
| 96 |
+
obs, info = self._env.reset(seed=seed)
|
| 97 |
+
self._episode_id = episode_id or str(uuid4())
|
| 98 |
+
return LabObservation(
|
| 99 |
+
done=False,
|
| 100 |
+
reward=0.0,
|
| 101 |
+
metadata={
|
| 102 |
+
"obs_vector": obs.tolist(),
|
| 103 |
+
"terminated": False,
|
| 104 |
+
"truncated": False,
|
| 105 |
+
"info": info,
|
| 106 |
+
},
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
def step(
|
| 110 |
+
self,
|
| 111 |
+
action: LabAction,
|
| 112 |
+
timeout_s: Optional[float] = None,
|
| 113 |
+
**kwargs: Any,
|
| 114 |
+
) -> LabObservation:
|
| 115 |
+
obs, reward, terminated, truncated, info = self._env.step(action.action)
|
| 116 |
+
return LabObservation(
|
| 117 |
+
done=terminated or truncated,
|
| 118 |
+
reward=float(reward),
|
| 119 |
+
metadata={
|
| 120 |
+
"obs_vector": obs.tolist(),
|
| 121 |
+
"terminated": terminated,
|
| 122 |
+
"truncated": truncated,
|
| 123 |
+
"info": info,
|
| 124 |
+
},
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
@property
|
| 128 |
+
def state(self) -> LabState:
|
| 129 |
+
e = self._env
|
| 130 |
+
return LabState(
|
| 131 |
+
episode_id=self._episode_id,
|
| 132 |
+
step_count=e._step_index,
|
| 133 |
+
elapsed_minutes=e._elapsed_minutes,
|
| 134 |
+
remaining_budget=e._remaining_budget,
|
| 135 |
+
inventory=dict(e._inventory),
|
| 136 |
+
last_result=e._last_result,
|
| 137 |
+
best_result=e._best_result,
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
def get_metadata(self) -> EnvironmentMetadata:
|
| 141 |
+
exp_name = getattr(self._env.spec, "name", "pcr")
|
| 142 |
+
return EnvironmentMetadata(
|
| 143 |
+
name="SimLab",
|
| 144 |
+
description=f"Gymnasium-style simulated wet-lab for protocol discovery ({exp_name})",
|
| 145 |
+
version="0.1.0",
|
| 146 |
+
documentation_url="https://github.com/openrl/simlab",
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
def close(self) -> None:
|
| 150 |
+
self._env.close()
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# ---------------------------------------------------------------------------
|
| 154 |
+
# FastAPI app for OpenEnv HTTP/WebSocket server
|
| 155 |
+
# ---------------------------------------------------------------------------
|
| 156 |
+
|
| 157 |
+
app = create_app(
|
| 158 |
+
LabEnvironment,
|
| 159 |
+
LabAction,
|
| 160 |
+
LabObservation,
|
| 161 |
+
env_name="simlab",
|
| 162 |
+
max_concurrent_envs=4,
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
# ---------------------------------------------------------------------------
|
| 167 |
+
# Legacy session-based adapter (for direct use without HTTP server)
|
| 168 |
+
# ---------------------------------------------------------------------------
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class LabEnvOpenEnvAdapter:
|
| 172 |
+
"""Manages multiple concurrent LabEnv sessions keyed by env_id.
|
| 173 |
+
|
| 174 |
+
Use this when you need to drive LabEnv by env_id (e.g. from another
|
| 175 |
+
service) without going through the OpenEnv HTTP server. For standard
|
| 176 |
+
OpenEnv deployment, use LabEnvironment + the `app` above instead.
|
| 177 |
+
"""
|
| 178 |
+
|
| 179 |
+
def __init__(self) -> None:
|
| 180 |
+
self._envs: dict[str, LabEnv] = {}
|
| 181 |
+
|
| 182 |
+
def env_reset(
|
| 183 |
+
self,
|
| 184 |
+
env_id: str,
|
| 185 |
+
seed: Optional[int] = None,
|
| 186 |
+
) -> dict[str, Any]:
|
| 187 |
+
"""Create or reset an environment instance; return initial observation."""
|
| 188 |
+
env = self._envs.get(env_id)
|
| 189 |
+
if env is None:
|
| 190 |
+
env = LabEnv()
|
| 191 |
+
self._envs[env_id] = env
|
| 192 |
+
obs, info = env.reset(seed=seed)
|
| 193 |
+
return {
|
| 194 |
+
"observation": obs.tolist(),
|
| 195 |
+
"info": info,
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
def env_step(
|
| 199 |
+
self,
|
| 200 |
+
env_id: str,
|
| 201 |
+
action: int,
|
| 202 |
+
) -> dict[str, Any]:
|
| 203 |
+
"""Advance the environment by one action; return transition."""
|
| 204 |
+
env = self._envs[env_id]
|
| 205 |
+
obs, reward, terminated, truncated, info = env.step(action)
|
| 206 |
+
return {
|
| 207 |
+
"observation": obs.tolist(),
|
| 208 |
+
"reward": float(reward),
|
| 209 |
+
"terminated": terminated,
|
| 210 |
+
"truncated": truncated,
|
| 211 |
+
"info": info,
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
def env_state(self, env_id: str) -> dict[str, Any]:
|
| 215 |
+
"""Return a JSON-serializable snapshot of the current state."""
|
| 216 |
+
env = self._envs[env_id]
|
| 217 |
+
return {
|
| 218 |
+
"episode_id": env_id,
|
| 219 |
+
"step_index": env._step_index,
|
| 220 |
+
"elapsed_minutes": env._elapsed_minutes,
|
| 221 |
+
"remaining_budget": env._remaining_budget,
|
| 222 |
+
"inventory": dict(env._inventory),
|
| 223 |
+
"last_result": env._last_result,
|
| 224 |
+
"best_result": env._best_result,
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
def env_close(self, env_id: str) -> None:
|
| 228 |
+
"""Tear down an environment instance."""
|
| 229 |
+
env = self._envs.pop(env_id, None)
|
| 230 |
+
if env is not None:
|
| 231 |
+
env.close()
|
lab_env/spec.py
ADDED
|
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Experiment specification for the generic lab environment.
|
| 3 |
+
|
| 4 |
+
Defines protocol presets, inventory, rewards, and outcome model so LabEnv can
|
| 5 |
+
simulate any experiment type (PCR, ELISA, etc.) from a single spec.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from dataclasses import dataclass, field
|
| 11 |
+
from typing import Any, Callable
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
|
| 17 |
+
class ExperimentSpec:
|
| 18 |
+
"""Specification for a single experiment type (PCR, ELISA, etc.).
|
| 19 |
+
|
| 20 |
+
The environment uses this to build action/observation spaces and dynamics.
|
| 21 |
+
Outcome logic is pluggable via sample_hidden_optimum and sample_assay_result.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
name: str
|
| 25 |
+
"""Short name for this experiment (e.g. 'pcr', 'elisa')."""
|
| 26 |
+
|
| 27 |
+
presets: list[dict[str, Any]]
|
| 28 |
+
"""List of protocol presets the agent can choose (e.g. temp/cycles/ratio for PCR)."""
|
| 29 |
+
|
| 30 |
+
inventory_items: list[str]
|
| 31 |
+
"""Ordered list of inventory item names (tips, buffer, polymerase, samples, ...)."""
|
| 32 |
+
|
| 33 |
+
orderable_items: list[str]
|
| 34 |
+
"""Subset of inventory_items that can be reordered (each gets an order action)."""
|
| 35 |
+
|
| 36 |
+
initial_inventory: dict[str, int]
|
| 37 |
+
"""Starting count per inventory item."""
|
| 38 |
+
|
| 39 |
+
order_costs: dict[str, tuple[int, float]]
|
| 40 |
+
"""For each orderable item: (quantity_per_order, cost_per_order)."""
|
| 41 |
+
|
| 42 |
+
result_labels: list[str]
|
| 43 |
+
"""Possible assay outcomes, e.g. ['none', 'success', 'partial', 'fail']."""
|
| 44 |
+
|
| 45 |
+
# Limits
|
| 46 |
+
max_steps: int = 30
|
| 47 |
+
max_minutes: float = 240.0
|
| 48 |
+
initial_budget: float = 500.0
|
| 49 |
+
max_inventory: int = 20
|
| 50 |
+
|
| 51 |
+
# Time costs
|
| 52 |
+
assay_time_minutes: float = 20.0
|
| 53 |
+
order_time_minutes: float = 5.0
|
| 54 |
+
wait_minutes: float = 15.0
|
| 55 |
+
|
| 56 |
+
# Rewards
|
| 57 |
+
assay_penalty: float = -3.0
|
| 58 |
+
time_penalty_per_min: float = -0.25
|
| 59 |
+
no_success_penalty: float = -20.0
|
| 60 |
+
immediate_result_reward: dict[str, float] = field(default_factory=dict)
|
| 61 |
+
terminal_bonus: dict[str, float] = field(default_factory=dict)
|
| 62 |
+
|
| 63 |
+
# Outcome model: callables that take (rng) or (hidden_state, preset_idx, presets, rng)
|
| 64 |
+
sample_hidden_optimum: Callable[[np.random.Generator], dict[str, Any]] | None = None
|
| 65 |
+
sample_assay_result: (
|
| 66 |
+
Callable[
|
| 67 |
+
[dict[str, Any], int, list[dict[str, Any]], np.random.Generator],
|
| 68 |
+
str,
|
| 69 |
+
]
|
| 70 |
+
| None
|
| 71 |
+
) = None
|
| 72 |
+
|
| 73 |
+
# Custom protocol support: evaluate arbitrary protocol dict (for agent-generated protocols)
|
| 74 |
+
evaluate_custom_protocol: (
|
| 75 |
+
Callable[
|
| 76 |
+
[dict[str, Any], dict[str, Any], np.random.Generator],
|
| 77 |
+
str,
|
| 78 |
+
]
|
| 79 |
+
| None
|
| 80 |
+
) = None
|
| 81 |
+
"""If set, (hidden_optimum, protocol_dict, rng) -> result label. Enables run_assay_with_protocol()."""
|
| 82 |
+
|
| 83 |
+
protocol_param_schema: dict[str, Any] = field(default_factory=dict)
|
| 84 |
+
"""Schema describing protocol params for codegen/LLM: e.g. {"temp": {"type": "number", "description": "°C"}, ...}."""
|
| 85 |
+
|
| 86 |
+
@property
|
| 87 |
+
def num_presets(self) -> int:
|
| 88 |
+
return len(self.presets)
|
| 89 |
+
|
| 90 |
+
@property
|
| 91 |
+
def num_actions(self) -> int:
|
| 92 |
+
return (
|
| 93 |
+
self.num_presets
|
| 94 |
+
+ 1 # run_assay
|
| 95 |
+
+ len(self.orderable_items)
|
| 96 |
+
+ 2 # wait, finish
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
@property
|
| 100 |
+
def obs_dim(self) -> int:
|
| 101 |
+
return (
|
| 102 |
+
3 # step_index, elapsed_minutes, remaining_budget
|
| 103 |
+
+ len(self.inventory_items)
|
| 104 |
+
+ len(self.result_labels)
|
| 105 |
+
+ 3 # has_setup, current_preset_idx (norm), best_result_score
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
def action_setup_start(self) -> int:
|
| 109 |
+
return 0
|
| 110 |
+
|
| 111 |
+
def action_setup_end(self) -> int:
|
| 112 |
+
return self.num_presets
|
| 113 |
+
|
| 114 |
+
def action_run_assay(self) -> int:
|
| 115 |
+
return self.num_presets
|
| 116 |
+
|
| 117 |
+
def action_order_start(self) -> int:
|
| 118 |
+
return self.num_presets + 1
|
| 119 |
+
|
| 120 |
+
def action_order_end(self) -> int:
|
| 121 |
+
return self.num_presets + 1 + len(self.orderable_items)
|
| 122 |
+
|
| 123 |
+
def action_wait(self) -> int:
|
| 124 |
+
return self.num_presets + 1 + len(self.orderable_items)
|
| 125 |
+
|
| 126 |
+
def action_finish(self) -> int:
|
| 127 |
+
return self.num_presets + 2 + len(self.orderable_items)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# ---------------------------------------------------------------------------
|
| 131 |
+
# PCR experiment spec (default / backward compatibility)
|
| 132 |
+
# ---------------------------------------------------------------------------
|
| 133 |
+
|
| 134 |
+
def _pcr_sample_hidden_optimum(rng: np.random.Generator) -> dict[str, Any]:
|
| 135 |
+
temps = [55.0, 65.0, 72.0]
|
| 136 |
+
cycles = [25, 35]
|
| 137 |
+
ratios = ["conservative", "aggressive"]
|
| 138 |
+
opt_temp = float(rng.choice(temps, p=[0.2, 0.5, 0.3])) + rng.uniform(-3.0, 3.0)
|
| 139 |
+
opt_cycles = float(rng.choice(cycles, p=[0.6, 0.4])) + rng.uniform(-2.0, 2.0)
|
| 140 |
+
opt_ratio = str(rng.choice(ratios, p=[0.6, 0.4]))
|
| 141 |
+
return {"temp": opt_temp, "cycles": opt_cycles, "ratio": opt_ratio}
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def _pcr_sample_assay_result(
|
| 145 |
+
hidden: dict[str, Any],
|
| 146 |
+
preset_idx: int,
|
| 147 |
+
presets: list[dict[str, Any]],
|
| 148 |
+
rng: np.random.Generator,
|
| 149 |
+
) -> str:
|
| 150 |
+
preset = presets[preset_idx]
|
| 151 |
+
chosen_temp = float(preset["temp"])
|
| 152 |
+
chosen_cycles = float(preset["cycles"])
|
| 153 |
+
chosen_ratio = str(preset["ratio"])
|
| 154 |
+
opt_temp = hidden["temp"]
|
| 155 |
+
opt_cycles = hidden["cycles"]
|
| 156 |
+
opt_ratio = hidden["ratio"]
|
| 157 |
+
|
| 158 |
+
temp_close = 1.0 - min(abs(chosen_temp - opt_temp) / 20.0, 1.0)
|
| 159 |
+
cycle_close = 1.0 - min(abs(chosen_cycles - opt_cycles) / 15.0, 1.0)
|
| 160 |
+
ratio_match = 1.0 if chosen_ratio == opt_ratio else 0.3
|
| 161 |
+
closeness = temp_close * cycle_close * ratio_match
|
| 162 |
+
|
| 163 |
+
p_success = closeness ** 2
|
| 164 |
+
p_partial = closeness * (1.0 - closeness) * 0.8
|
| 165 |
+
p_fail = 1.0 - p_success - p_partial
|
| 166 |
+
return str(
|
| 167 |
+
rng.choice(["success", "partial", "fail"], p=[p_success, p_partial, p_fail])
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def _pcr_evaluate_custom_protocol(
|
| 172 |
+
hidden: dict[str, Any],
|
| 173 |
+
protocol: dict[str, Any],
|
| 174 |
+
rng: np.random.Generator,
|
| 175 |
+
) -> str:
|
| 176 |
+
"""Evaluate any protocol dict (temp, cycles, ratio) against hidden optimum."""
|
| 177 |
+
chosen_temp = float(protocol.get("temp", 60.0))
|
| 178 |
+
chosen_cycles = float(protocol.get("cycles", 30))
|
| 179 |
+
r = str(protocol.get("ratio", "conservative")).strip().lower()
|
| 180 |
+
chosen_ratio = "conservative" if "conservative" in r else "aggressive"
|
| 181 |
+
opt_temp = hidden["temp"]
|
| 182 |
+
opt_cycles = hidden["cycles"]
|
| 183 |
+
opt_ratio = hidden["ratio"]
|
| 184 |
+
temp_close = 1.0 - min(abs(chosen_temp - opt_temp) / 20.0, 1.0)
|
| 185 |
+
cycle_close = 1.0 - min(abs(chosen_cycles - opt_cycles) / 15.0, 1.0)
|
| 186 |
+
ratio_match = 1.0 if chosen_ratio == opt_ratio else 0.3
|
| 187 |
+
closeness = temp_close * cycle_close * ratio_match
|
| 188 |
+
p_success = closeness ** 2
|
| 189 |
+
p_partial = closeness * (1.0 - closeness) * 0.8
|
| 190 |
+
p_fail = 1.0 - p_success - p_partial
|
| 191 |
+
return str(
|
| 192 |
+
rng.choice(["success", "partial", "fail"], p=[p_success, p_partial, p_fail])
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
PCR_PROTOCOL_SCHEMA = {
|
| 197 |
+
"temp": {"type": "number", "description": "Annealing temperature in °C (e.g. 55–72)"},
|
| 198 |
+
"cycles": {"type": "integer", "description": "Number of PCR cycles (e.g. 25–40)"},
|
| 199 |
+
"ratio": {"type": "string", "enum": ["conservative", "aggressive"], "description": "Reagent ratio"},
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def pcr_experiment_spec() -> ExperimentSpec:
|
| 204 |
+
"""Build the default PCR experiment spec (same behaviour as original LabEnv)."""
|
| 205 |
+
from itertools import product
|
| 206 |
+
|
| 207 |
+
temps = [55.0, 65.0, 72.0]
|
| 208 |
+
cycles = [25, 35]
|
| 209 |
+
ratios = ["conservative", "aggressive"]
|
| 210 |
+
presets = [
|
| 211 |
+
{"temp": t, "cycles": c, "ratio": r}
|
| 212 |
+
for t, c, r in product(temps, cycles, ratios)
|
| 213 |
+
]
|
| 214 |
+
return ExperimentSpec(
|
| 215 |
+
name="pcr",
|
| 216 |
+
presets=presets,
|
| 217 |
+
inventory_items=["tips", "buffer", "polymerase", "samples"],
|
| 218 |
+
orderable_items=["tips", "buffer", "polymerase"],
|
| 219 |
+
initial_inventory={"tips": 10, "buffer": 10, "polymerase": 5, "samples": 8},
|
| 220 |
+
order_costs={
|
| 221 |
+
"tips": (5, 10.0),
|
| 222 |
+
"buffer": (5, 15.0),
|
| 223 |
+
"polymerase": (3, 30.0),
|
| 224 |
+
},
|
| 225 |
+
result_labels=["none", "success", "partial", "fail"],
|
| 226 |
+
max_steps=30,
|
| 227 |
+
max_minutes=240.0,
|
| 228 |
+
initial_budget=500.0,
|
| 229 |
+
max_inventory=20,
|
| 230 |
+
assay_time_minutes=20.0,
|
| 231 |
+
order_time_minutes=5.0,
|
| 232 |
+
wait_minutes=15.0,
|
| 233 |
+
assay_penalty=-3.0,
|
| 234 |
+
time_penalty_per_min=-0.25,
|
| 235 |
+
no_success_penalty=-20.0,
|
| 236 |
+
immediate_result_reward={"success": 15.0, "partial": 5.0, "fail": 0.0},
|
| 237 |
+
terminal_bonus={"success": 60.0, "partial": 25.0},
|
| 238 |
+
sample_hidden_optimum=_pcr_sample_hidden_optimum,
|
| 239 |
+
sample_assay_result=_pcr_sample_assay_result,
|
| 240 |
+
evaluate_custom_protocol=_pcr_evaluate_custom_protocol,
|
| 241 |
+
protocol_param_schema=PCR_PROTOCOL_SCHEMA,
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
# ---------------------------------------------------------------------------
|
| 246 |
+
# ELISA experiment spec (same obs/action shape as PCR for agent compatibility)
|
| 247 |
+
# ---------------------------------------------------------------------------
|
| 248 |
+
|
| 249 |
+
def _elisa_sample_hidden_optimum(rng: np.random.Generator) -> dict[str, Any]:
|
| 250 |
+
coating_hrs = [1.0, 2.0, 3.0]
|
| 251 |
+
temps = [4.0, 37.0]
|
| 252 |
+
blocks = ["bsa", "casein"]
|
| 253 |
+
opt_coating = float(rng.choice(coating_hrs, p=[0.3, 0.5, 0.2])) + rng.uniform(-0.2, 0.2)
|
| 254 |
+
opt_temp = float(rng.choice(temps, p=[0.5, 0.5])) + rng.uniform(-2.0, 2.0)
|
| 255 |
+
opt_block = str(rng.choice(blocks, p=[0.6, 0.4]))
|
| 256 |
+
return {"coating_hr": opt_coating, "temp": opt_temp, "block": opt_block}
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def _elisa_sample_assay_result(
|
| 260 |
+
hidden: dict[str, Any],
|
| 261 |
+
preset_idx: int,
|
| 262 |
+
presets: list[dict[str, Any]],
|
| 263 |
+
rng: np.random.Generator,
|
| 264 |
+
) -> str:
|
| 265 |
+
preset = presets[preset_idx]
|
| 266 |
+
c = float(preset["coating_hr"])
|
| 267 |
+
t = float(preset["temp"])
|
| 268 |
+
b = str(preset["block"])
|
| 269 |
+
oc = hidden["coating_hr"]
|
| 270 |
+
ot = hidden["temp"]
|
| 271 |
+
ob = hidden["block"]
|
| 272 |
+
coat_close = 1.0 - min(abs(c - oc) / 2.0, 1.0)
|
| 273 |
+
temp_close = 1.0 - min(abs(t - ot) / 35.0, 1.0)
|
| 274 |
+
block_match = 1.0 if b == ob else 0.3
|
| 275 |
+
closeness = coat_close * temp_close * block_match
|
| 276 |
+
p_success = closeness ** 2
|
| 277 |
+
p_partial = closeness * (1.0 - closeness) * 0.8
|
| 278 |
+
p_fail = 1.0 - p_success - p_partial
|
| 279 |
+
return str(
|
| 280 |
+
rng.choice(["success", "partial", "fail"], p=[p_success, p_partial, p_fail])
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def _elisa_evaluate_custom_protocol(
|
| 285 |
+
hidden: dict[str, Any],
|
| 286 |
+
protocol: dict[str, Any],
|
| 287 |
+
rng: np.random.Generator,
|
| 288 |
+
) -> str:
|
| 289 |
+
"""Evaluate any protocol dict (coating_hr, temp, block) against hidden optimum."""
|
| 290 |
+
c = float(protocol.get("coating_hr", 2.0))
|
| 291 |
+
t = float(protocol.get("temp", 25.0))
|
| 292 |
+
b = str(protocol.get("block", "bsa")).strip().lower()
|
| 293 |
+
block_clean = "bsa" if "bsa" in b else "casein"
|
| 294 |
+
oc, ot, ob = hidden["coating_hr"], hidden["temp"], hidden["block"]
|
| 295 |
+
coat_close = 1.0 - min(abs(c - oc) / 2.0, 1.0)
|
| 296 |
+
temp_close = 1.0 - min(abs(t - ot) / 35.0, 1.0)
|
| 297 |
+
block_match = 1.0 if block_clean == ob else 0.3
|
| 298 |
+
closeness = coat_close * temp_close * block_match
|
| 299 |
+
p_success = closeness ** 2
|
| 300 |
+
p_partial = closeness * (1.0 - closeness) * 0.8
|
| 301 |
+
p_fail = 1.0 - p_success - p_partial
|
| 302 |
+
return str(
|
| 303 |
+
rng.choice(["success", "partial", "fail"], p=[p_success, p_partial, p_fail])
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
ELISA_PROTOCOL_SCHEMA = {
|
| 308 |
+
"coating_hr": {"type": "number", "description": "Coating time in hours (e.g. 1–3)"},
|
| 309 |
+
"temp": {"type": "number", "description": "Incubation temperature °C (e.g. 4 or 37)"},
|
| 310 |
+
"block": {"type": "string", "enum": ["bsa", "casein"], "description": "Blocking agent"},
|
| 311 |
+
}
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def elisa_experiment_spec() -> ExperimentSpec:
|
| 315 |
+
"""ELISA readout: coating time (hr), temperature (°C), blocking type. Same obs/action dims as PCR."""
|
| 316 |
+
from itertools import product
|
| 317 |
+
|
| 318 |
+
coating_hrs = [1.0, 2.0, 3.0]
|
| 319 |
+
temps = [4.0, 37.0]
|
| 320 |
+
blocks = ["bsa", "casein"]
|
| 321 |
+
presets = [
|
| 322 |
+
{"coating_hr": ch, "temp": t, "block": bl}
|
| 323 |
+
for ch, t, bl in product(coating_hrs, temps, blocks)
|
| 324 |
+
]
|
| 325 |
+
return ExperimentSpec(
|
| 326 |
+
name="elisa",
|
| 327 |
+
presets=presets,
|
| 328 |
+
inventory_items=["tips", "buffer", "polymerase", "samples"],
|
| 329 |
+
orderable_items=["tips", "buffer", "polymerase"],
|
| 330 |
+
initial_inventory={"tips": 10, "buffer": 10, "polymerase": 5, "samples": 8},
|
| 331 |
+
order_costs={
|
| 332 |
+
"tips": (5, 10.0),
|
| 333 |
+
"buffer": (5, 15.0),
|
| 334 |
+
"polymerase": (3, 30.0),
|
| 335 |
+
},
|
| 336 |
+
result_labels=["none", "success", "partial", "fail"],
|
| 337 |
+
max_steps=30,
|
| 338 |
+
max_minutes=240.0,
|
| 339 |
+
initial_budget=500.0,
|
| 340 |
+
max_inventory=20,
|
| 341 |
+
assay_time_minutes=20.0,
|
| 342 |
+
order_time_minutes=5.0,
|
| 343 |
+
wait_minutes=15.0,
|
| 344 |
+
assay_penalty=-3.0,
|
| 345 |
+
time_penalty_per_min=-0.25,
|
| 346 |
+
no_success_penalty=-20.0,
|
| 347 |
+
immediate_result_reward={"success": 15.0, "partial": 5.0, "fail": 0.0},
|
| 348 |
+
terminal_bonus={"success": 60.0, "partial": 25.0},
|
| 349 |
+
sample_hidden_optimum=_elisa_sample_hidden_optimum,
|
| 350 |
+
sample_assay_result=_elisa_sample_assay_result,
|
| 351 |
+
evaluate_custom_protocol=_elisa_evaluate_custom_protocol,
|
| 352 |
+
protocol_param_schema=ELISA_PROTOCOL_SCHEMA,
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
# ---------------------------------------------------------------------------
|
| 357 |
+
# Workflow ID -> spec registry (for UI / API)
|
| 358 |
+
# ---------------------------------------------------------------------------
|
| 359 |
+
|
| 360 |
+
def get_spec_for_workflow(workflow_id: str) -> ExperimentSpec:
|
| 361 |
+
"""Return the experiment spec for a given workflow ID. Unknown IDs default to PCR."""
|
| 362 |
+
_registry: dict[str, Callable[[], ExperimentSpec]] = {
|
| 363 |
+
"pcr-amplification": pcr_experiment_spec,
|
| 364 |
+
"elisa-readout": elisa_experiment_spec,
|
| 365 |
+
}
|
| 366 |
+
factory = _registry.get(workflow_id) or pcr_experiment_spec
|
| 367 |
+
return factory()
|
pyproject.toml
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=68.0", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "simlab"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "Lab Automation RL Environment — a Gymnasium-style simulated wet-lab for agentic RL training"
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.10"
|
| 11 |
+
license = {text = "MIT"}
|
| 12 |
+
dependencies = [
|
| 13 |
+
"numpy>=1.24",
|
| 14 |
+
"torch>=2.0",
|
| 15 |
+
"gymnasium>=0.29",
|
| 16 |
+
"openenv-core>=0.2.0",
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
[project.optional-dependencies]
|
| 20 |
+
dev = ["pytest", "ruff"]
|
| 21 |
+
demo = ["openai", "streamlit"]
|
| 22 |
+
|
| 23 |
+
[tool.setuptools.packages.find]
|
| 24 |
+
include = ["lab_env*", "agents*"]
|
scripts/compare_all_agents.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Benchmark Naive, RL, and Research LLM agents on the same eval seeds."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import sys
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
| 11 |
+
|
| 12 |
+
from lab_env.env import LabEnv, INITIAL_BUDGET
|
| 13 |
+
from agents.naive_agent import NaiveAgent
|
| 14 |
+
from agents.rl_agent import ReinforceAgent
|
| 15 |
+
from agents.research_llm_agent import ResearchLLMAgent
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def run_episode_naive(env: LabEnv, agent: NaiveAgent, seed: int) -> dict:
|
| 19 |
+
obs, info = env.reset(seed=seed)
|
| 20 |
+
agent.reset()
|
| 21 |
+
total_reward = 0.0
|
| 22 |
+
steps = 0
|
| 23 |
+
while True:
|
| 24 |
+
action = agent.select_action(obs)
|
| 25 |
+
obs, reward, terminated, truncated, info = env.step(action)
|
| 26 |
+
total_reward += reward
|
| 27 |
+
steps += 1
|
| 28 |
+
if terminated or truncated:
|
| 29 |
+
break
|
| 30 |
+
return {
|
| 31 |
+
"reward": total_reward,
|
| 32 |
+
"success": info["best_result"] == "success",
|
| 33 |
+
"partial": info["best_result"] == "partial",
|
| 34 |
+
"minutes": info["elapsed_minutes"],
|
| 35 |
+
"cost": INITIAL_BUDGET - info["remaining_budget"],
|
| 36 |
+
"steps": steps,
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def aggregate(results: list[dict]) -> dict:
|
| 41 |
+
n = len(results)
|
| 42 |
+
successes = [r["success"] for r in results]
|
| 43 |
+
steps_to_success = [r["steps"] for r in results if r["success"]] or [0]
|
| 44 |
+
return {
|
| 45 |
+
"n": n,
|
| 46 |
+
"avg_reward": sum(r["reward"] for r in results) / n,
|
| 47 |
+
"success_rate": sum(successes) / n,
|
| 48 |
+
"partial_rate": sum(r["partial"] for r in results) / n,
|
| 49 |
+
"avg_minutes": sum(r["minutes"] for r in results) / n,
|
| 50 |
+
"avg_cost": sum(r["cost"] for r in results) / n,
|
| 51 |
+
"avg_steps": sum(r["steps"] for r in results) / n,
|
| 52 |
+
"experiments_to_success": sum(steps_to_success) / len(steps_to_success) if steps_to_success else 0,
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def main() -> None:
|
| 57 |
+
parser = argparse.ArgumentParser(description="Compare Naive, RL, and Research LLM agents")
|
| 58 |
+
parser.add_argument("--eval-episodes", type=int, default=50, help="Episodes per agent (eval)")
|
| 59 |
+
parser.add_argument("--train-episodes", type=int, default=500, help="RL training episodes before eval")
|
| 60 |
+
parser.add_argument("--seed", type=int, default=42)
|
| 61 |
+
parser.add_argument("--max-trials", type=int, default=5, help="Max trials per episode (RL and LLM)")
|
| 62 |
+
parser.add_argument("--no-llm", action="store_true", help="Skip LLM agent (no API key)")
|
| 63 |
+
args = parser.parse_args()
|
| 64 |
+
|
| 65 |
+
eval_seed_base = 100_000 + args.seed
|
| 66 |
+
env = LabEnv()
|
| 67 |
+
|
| 68 |
+
# ---- Naive ----
|
| 69 |
+
print("Running Naive agent...")
|
| 70 |
+
naive_agent = NaiveAgent(num_trials=3, seed=args.seed)
|
| 71 |
+
naive_results = [
|
| 72 |
+
run_episode_naive(env, naive_agent, eval_seed_base + i)
|
| 73 |
+
for i in range(args.eval_episodes)
|
| 74 |
+
]
|
| 75 |
+
naive_stats = aggregate(naive_results)
|
| 76 |
+
|
| 77 |
+
# ---- RL (train then eval) ----
|
| 78 |
+
print("Training REINFORCE agent...")
|
| 79 |
+
rl_agent = ReinforceAgent(max_trials=args.max_trials)
|
| 80 |
+
for ep in range(1, args.train_episodes + 1):
|
| 81 |
+
rl_agent.run_episode(env, seed=args.seed + ep, train=True)
|
| 82 |
+
if ep % 100 == 0:
|
| 83 |
+
print(f" RL train episode {ep}/{args.train_episodes}")
|
| 84 |
+
print("Evaluating REINFORCE agent...")
|
| 85 |
+
rl_results = [
|
| 86 |
+
rl_agent.run_episode(env, seed=eval_seed_base + i, train=False)
|
| 87 |
+
for i in range(args.eval_episodes)
|
| 88 |
+
]
|
| 89 |
+
rl_stats = aggregate(rl_results)
|
| 90 |
+
|
| 91 |
+
# ---- Research LLM ----
|
| 92 |
+
llm_stats = None
|
| 93 |
+
if not args.no_llm:
|
| 94 |
+
print("Running Research LLM agent...")
|
| 95 |
+
try:
|
| 96 |
+
llm_agent = ResearchLLMAgent(max_trials=args.max_trials)
|
| 97 |
+
llm_results = [
|
| 98 |
+
llm_agent.run_episode(env, seed=eval_seed_base + i)
|
| 99 |
+
for i in range(args.eval_episodes)
|
| 100 |
+
]
|
| 101 |
+
llm_stats = aggregate(llm_results)
|
| 102 |
+
except Exception as e:
|
| 103 |
+
print(f" Skipping LLM agent: {e}")
|
| 104 |
+
|
| 105 |
+
env.close()
|
| 106 |
+
|
| 107 |
+
# ---- Table ----
|
| 108 |
+
header = f"{'Metric':<22} {'Naive':>12} {'RL (MLP)':>12}"
|
| 109 |
+
if llm_stats is not None:
|
| 110 |
+
header += f" {'LLM Researcher':>14}"
|
| 111 |
+
sep = "-" * len(header)
|
| 112 |
+
print()
|
| 113 |
+
print(sep)
|
| 114 |
+
print(" Agent comparison (same eval seeds)")
|
| 115 |
+
print(sep)
|
| 116 |
+
print(header)
|
| 117 |
+
print(sep)
|
| 118 |
+
|
| 119 |
+
def row(label: str, n_val: str, r_val: str, l_val: str | None = None) -> None:
|
| 120 |
+
line = f"{label:<22} {n_val:>12} {r_val:>12}"
|
| 121 |
+
if l_val is not None:
|
| 122 |
+
line += f" {l_val:>14}"
|
| 123 |
+
print(line)
|
| 124 |
+
|
| 125 |
+
row("Success rate", f"{naive_stats['success_rate']:.1%}", f"{rl_stats['success_rate']:.1%}",
|
| 126 |
+
f"{llm_stats['success_rate']:.1%}" if llm_stats else None)
|
| 127 |
+
row("Experiments to success", f"{naive_stats['experiments_to_success']:.1f}", f"{rl_stats['experiments_to_success']:.1f}",
|
| 128 |
+
f"{llm_stats['experiments_to_success']:.1f}" if llm_stats else None)
|
| 129 |
+
row("Cost/episode", f"${naive_stats['avg_cost']:.1f}", f"${rl_stats['avg_cost']:.1f}",
|
| 130 |
+
f"${llm_stats['avg_cost']:.1f}" if llm_stats else None)
|
| 131 |
+
row("Avg reward", f"{naive_stats['avg_reward']:.1f}", f"{rl_stats['avg_reward']:.1f}",
|
| 132 |
+
f"{llm_stats['avg_reward']:.1f}" if llm_stats else None)
|
| 133 |
+
row("Avg steps", f"{naive_stats['avg_steps']:.1f}", f"{rl_stats['avg_steps']:.1f}",
|
| 134 |
+
f"{llm_stats['avg_steps']:.1f}" if llm_stats else None)
|
| 135 |
+
print(sep)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
if __name__ == "__main__":
|
| 139 |
+
main()
|
scripts/demo_hackathon.sh
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Hackathon live demo — start API + remind steps.
|
| 3 |
+
# Run from repo root: ./scripts/demo_hackathon.sh
|
| 4 |
+
|
| 5 |
+
set -e
|
| 6 |
+
cd "$(dirname "$0")/.."
|
| 7 |
+
|
| 8 |
+
echo "=== SimLab Hackathon Demo ==="
|
| 9 |
+
echo ""
|
| 10 |
+
echo "1. Start the API (leave this running):"
|
| 11 |
+
echo " uvicorn server.app:app --host 0.0.0.0 --port 8000"
|
| 12 |
+
echo ""
|
| 13 |
+
echo "2. In another terminal, start the UI:"
|
| 14 |
+
echo " cd v0ap && pnpm dev"
|
| 15 |
+
echo ""
|
| 16 |
+
echo "3. Open http://localhost:3000"
|
| 17 |
+
echo " - Training: /training (set 500 episodes, Start Training, show chart + comparison)"
|
| 18 |
+
echo " - Lab run: /workflows/pcr-amplification (Run Naive → Run with AI Agent)"
|
| 19 |
+
echo ""
|
| 20 |
+
read -p "Start API in this terminal now? [y/N] " -n 1 -r
|
| 21 |
+
echo
|
| 22 |
+
if [[ $REPLY =~ ^[Yy]$ ]]; then
|
| 23 |
+
exec uvicorn server.app:app --host 0.0.0.0 --port 8000
|
| 24 |
+
fi
|
scripts/demo_research_agent.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Run 1–2 episodes of the research LLM agent with verbose terminal output."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import sys
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
| 11 |
+
|
| 12 |
+
from lab_env.env import LabEnv
|
| 13 |
+
from agents.research_llm_agent import ResearchLLMAgent
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def main() -> None:
|
| 17 |
+
parser = argparse.ArgumentParser(description="Demo: Research LLM agent (verbose)")
|
| 18 |
+
parser.add_argument("--episodes", type=int, default=2, help="Number of episodes to run")
|
| 19 |
+
parser.add_argument("--seed", type=int, default=42)
|
| 20 |
+
parser.add_argument("--max-trials", type=int, default=5, help="Max trials per episode")
|
| 21 |
+
args = parser.parse_args()
|
| 22 |
+
|
| 23 |
+
env = LabEnv()
|
| 24 |
+
agent = ResearchLLMAgent(max_trials=args.max_trials)
|
| 25 |
+
|
| 26 |
+
print("=" * 60)
|
| 27 |
+
print(" Research LLM Agent — Self-Improving Lab Scientist Demo")
|
| 28 |
+
print("=" * 60)
|
| 29 |
+
|
| 30 |
+
for ep in range(1, args.episodes + 1):
|
| 31 |
+
print(f"\n--- Episode {ep}/{args.episodes} (seed={args.seed + ep}) ---")
|
| 32 |
+
callback: list[dict] = []
|
| 33 |
+
result = agent.run_episode(env, seed=args.seed + ep, verbose=True, episode_callback=callback)
|
| 34 |
+
for step in callback:
|
| 35 |
+
print(f" Trial {step['trial']}: hypothesis {step['hypothesis']} -> ran {step['params_used']} -> {step['result']}")
|
| 36 |
+
print(f" Outcome: {'SUCCESS' if result['success'] else 'partial' if result['partial'] else 'fail'}")
|
| 37 |
+
print(f" Reward: {result['reward']:.1f} Cost: ${result['cost']:.1f} Steps: {result['steps']}")
|
| 38 |
+
print(f" Knowledge: temp_range={agent.knowledge['temp_range']}, cycle_range={agent.knowledge['cycle_range']}")
|
| 39 |
+
|
| 40 |
+
env.close()
|
| 41 |
+
print("\nDone.")
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
if __name__ == "__main__":
|
| 45 |
+
main()
|
scripts/run_naive_baseline.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Run the naive baseline agent on LabEnv and report aggregate metrics."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import sys
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
| 11 |
+
|
| 12 |
+
from lab_env.env import LabEnv, INITIAL_BUDGET
|
| 13 |
+
from agents.naive_agent import NaiveAgent
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def run_episode(env: LabEnv, agent: NaiveAgent, seed: int) -> dict:
|
| 17 |
+
obs, info = env.reset(seed=seed)
|
| 18 |
+
agent.reset()
|
| 19 |
+
|
| 20 |
+
total_reward = 0.0
|
| 21 |
+
steps = 0
|
| 22 |
+
|
| 23 |
+
while True:
|
| 24 |
+
action = agent.select_action(obs)
|
| 25 |
+
obs, reward, terminated, truncated, info = env.step(action)
|
| 26 |
+
total_reward += reward
|
| 27 |
+
steps += 1
|
| 28 |
+
if terminated or truncated:
|
| 29 |
+
break
|
| 30 |
+
|
| 31 |
+
return {
|
| 32 |
+
"reward": total_reward,
|
| 33 |
+
"success": info["best_result"] == "success",
|
| 34 |
+
"partial": info["best_result"] == "partial",
|
| 35 |
+
"minutes": info["elapsed_minutes"],
|
| 36 |
+
"cost": INITIAL_BUDGET - info["remaining_budget"],
|
| 37 |
+
"steps": steps,
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def main() -> None:
|
| 42 |
+
parser = argparse.ArgumentParser(description="Naive baseline evaluation")
|
| 43 |
+
parser.add_argument("--episodes", type=int, default=200)
|
| 44 |
+
parser.add_argument("--seed", type=int, default=42)
|
| 45 |
+
args = parser.parse_args()
|
| 46 |
+
|
| 47 |
+
env = LabEnv()
|
| 48 |
+
agent = NaiveAgent(num_trials=3, seed=args.seed)
|
| 49 |
+
|
| 50 |
+
results = [run_episode(env, agent, seed=args.seed + i) for i in range(args.episodes)]
|
| 51 |
+
env.close()
|
| 52 |
+
|
| 53 |
+
rewards = [r["reward"] for r in results]
|
| 54 |
+
successes = sum(r["success"] for r in results)
|
| 55 |
+
partials = sum(r["partial"] for r in results)
|
| 56 |
+
minutes = [r["minutes"] for r in results]
|
| 57 |
+
costs = [r["cost"] for r in results]
|
| 58 |
+
steps = [r["steps"] for r in results]
|
| 59 |
+
n = len(results)
|
| 60 |
+
|
| 61 |
+
print("=" * 50)
|
| 62 |
+
print(" Naive Baseline Results")
|
| 63 |
+
print("=" * 50)
|
| 64 |
+
print(f" Episodes: {n}")
|
| 65 |
+
print(f" Avg reward: {sum(rewards) / n:8.2f}")
|
| 66 |
+
print(f" Success rate: {successes / n:8.2%}")
|
| 67 |
+
print(f" Partial rate: {partials / n:8.2%}")
|
| 68 |
+
print(f" Avg time (min): {sum(minutes) / n:8.1f}")
|
| 69 |
+
print(f" Avg cost ($): {sum(costs) / n:8.1f}")
|
| 70 |
+
print(f" Avg steps: {sum(steps) / n:8.1f}")
|
| 71 |
+
print("=" * 50)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
if __name__ == "__main__":
|
| 75 |
+
main()
|
scripts/run_research_generate_agent.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Run the Research & Generate agent: research → generate any protocol → run → learn from feedback.
|
| 4 |
+
|
| 5 |
+
Uses env.run_assay_with_protocol() so the agent can try arbitrary parameter values
|
| 6 |
+
(not limited to presets). Feedback from each run is passed into the next trial.
|
| 7 |
+
Works with any spec that has evaluate_custom_protocol (PCR, ELISA, etc.).
|
| 8 |
+
|
| 9 |
+
Requires: pip install openai, OPENAI_API_KEY set.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import argparse
|
| 15 |
+
import sys
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
| 19 |
+
|
| 20 |
+
from lab_env.env import LabEnv
|
| 21 |
+
from lab_env.spec import get_spec_for_workflow
|
| 22 |
+
from agents.research_generate_agent import ResearchGenerateAgent
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def main() -> None:
|
| 26 |
+
parser = argparse.ArgumentParser(
|
| 27 |
+
description="Run Research & Generate agent (research → generate protocol → run → learn)"
|
| 28 |
+
)
|
| 29 |
+
parser.add_argument(
|
| 30 |
+
"--episodes",
|
| 31 |
+
type=int,
|
| 32 |
+
default=5,
|
| 33 |
+
help="Number of episodes to run",
|
| 34 |
+
)
|
| 35 |
+
parser.add_argument(
|
| 36 |
+
"--workflow",
|
| 37 |
+
type=str,
|
| 38 |
+
default="pcr-amplification",
|
| 39 |
+
choices=["pcr-amplification", "elisa-readout"],
|
| 40 |
+
help="Experiment type (uses spec with custom protocol support)",
|
| 41 |
+
)
|
| 42 |
+
parser.add_argument(
|
| 43 |
+
"--max-trials",
|
| 44 |
+
type=int,
|
| 45 |
+
default=6,
|
| 46 |
+
help="Max protocol attempts per episode",
|
| 47 |
+
)
|
| 48 |
+
parser.add_argument(
|
| 49 |
+
"--seed",
|
| 50 |
+
type=int,
|
| 51 |
+
default=42,
|
| 52 |
+
)
|
| 53 |
+
parser.add_argument(
|
| 54 |
+
"--verbose",
|
| 55 |
+
action="store_true",
|
| 56 |
+
help="Print each trial's protocol and result",
|
| 57 |
+
)
|
| 58 |
+
args = parser.parse_args()
|
| 59 |
+
|
| 60 |
+
spec = get_spec_for_workflow(args.workflow)
|
| 61 |
+
env = LabEnv(spec=spec)
|
| 62 |
+
agent = ResearchGenerateAgent(max_trials=args.max_trials)
|
| 63 |
+
|
| 64 |
+
print(f"Research & Generate agent — workflow={args.workflow}, episodes={args.episodes}")
|
| 65 |
+
print("(Research → generate protocol → run in lab → learn from feedback)\n")
|
| 66 |
+
|
| 67 |
+
results = []
|
| 68 |
+
for ep in range(args.episodes):
|
| 69 |
+
seed = args.seed + ep * 1000
|
| 70 |
+
if args.verbose:
|
| 71 |
+
print(f"--- Episode {ep + 1} (seed={seed}) ---")
|
| 72 |
+
out = agent.run_episode(env, seed=seed, verbose=args.verbose)
|
| 73 |
+
results.append(out)
|
| 74 |
+
if not args.verbose:
|
| 75 |
+
print(
|
| 76 |
+
f" Episode {ep + 1}: reward={out['reward']:.1f}, "
|
| 77 |
+
f"success={out['success']}, protocols_tried={out['num_protocols_tried']}"
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
env.close()
|
| 81 |
+
|
| 82 |
+
n = len(results)
|
| 83 |
+
print("\n--- Summary ---")
|
| 84 |
+
print(f" Success rate: {sum(r['success'] for r in results) / n:.1%}")
|
| 85 |
+
print(f" Partial rate: {sum(r['partial'] for r in results) / n:.1%}")
|
| 86 |
+
print(f" Avg reward: {sum(r['reward'] for r in results) / n:.1f}")
|
| 87 |
+
print(f" Avg protocols: {sum(r['num_protocols_tried'] for r in results) / n:.1f} per episode")
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
if __name__ == "__main__":
|
| 91 |
+
main()
|
scripts/train_and_eval_agent.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Train a REINFORCE agent on LabEnv and compare against the naive baseline."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import sys
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
| 11 |
+
|
| 12 |
+
from lab_env.env import LabEnv, INITIAL_BUDGET
|
| 13 |
+
from agents.naive_agent import NaiveAgent
|
| 14 |
+
from agents.rl_agent import ReinforceAgent
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# ------------------------------------------------------------------
|
| 18 |
+
# Naive episode runner
|
| 19 |
+
# ------------------------------------------------------------------
|
| 20 |
+
|
| 21 |
+
def run_episode_naive(env: LabEnv, agent: NaiveAgent, seed: int) -> dict:
|
| 22 |
+
obs, info = env.reset(seed=seed)
|
| 23 |
+
agent.reset()
|
| 24 |
+
|
| 25 |
+
total_reward = 0.0
|
| 26 |
+
steps = 0
|
| 27 |
+
|
| 28 |
+
while True:
|
| 29 |
+
action = agent.select_action(obs)
|
| 30 |
+
obs, reward, terminated, truncated, info = env.step(action)
|
| 31 |
+
total_reward += reward
|
| 32 |
+
steps += 1
|
| 33 |
+
if terminated or truncated:
|
| 34 |
+
break
|
| 35 |
+
|
| 36 |
+
return {
|
| 37 |
+
"reward": total_reward,
|
| 38 |
+
"success": info["best_result"] == "success",
|
| 39 |
+
"partial": info["best_result"] == "partial",
|
| 40 |
+
"minutes": info["elapsed_minutes"],
|
| 41 |
+
"cost": INITIAL_BUDGET - info["remaining_budget"],
|
| 42 |
+
"steps": steps,
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# ------------------------------------------------------------------
|
| 47 |
+
# Aggregation
|
| 48 |
+
# ------------------------------------------------------------------
|
| 49 |
+
|
| 50 |
+
def aggregate(results: list[dict]) -> dict:
|
| 51 |
+
n = len(results)
|
| 52 |
+
return {
|
| 53 |
+
"n": n,
|
| 54 |
+
"avg_reward": sum(r["reward"] for r in results) / n,
|
| 55 |
+
"success_rate": sum(r["success"] for r in results) / n,
|
| 56 |
+
"partial_rate": sum(r["partial"] for r in results) / n,
|
| 57 |
+
"avg_minutes": sum(r["minutes"] for r in results) / n,
|
| 58 |
+
"avg_cost": sum(r["cost"] for r in results) / n,
|
| 59 |
+
"avg_steps": sum(r["steps"] for r in results) / n,
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# ------------------------------------------------------------------
|
| 64 |
+
# Main
|
| 65 |
+
# ------------------------------------------------------------------
|
| 66 |
+
|
| 67 |
+
def main() -> None:
|
| 68 |
+
parser = argparse.ArgumentParser(description="Train & evaluate REINFORCE agent")
|
| 69 |
+
parser.add_argument("--train-episodes", type=int, default=2000)
|
| 70 |
+
parser.add_argument("--eval-episodes", type=int, default=100)
|
| 71 |
+
parser.add_argument("--log-interval", type=int, default=100)
|
| 72 |
+
parser.add_argument("--seed", type=int, default=42)
|
| 73 |
+
parser.add_argument("--lr", type=float, default=3e-3)
|
| 74 |
+
parser.add_argument("--gamma", type=float, default=0.99)
|
| 75 |
+
parser.add_argument("--max-trials", type=int, default=4)
|
| 76 |
+
args = parser.parse_args()
|
| 77 |
+
|
| 78 |
+
env = LabEnv()
|
| 79 |
+
rl_agent = ReinforceAgent(lr=args.lr, gamma=args.gamma, max_trials=args.max_trials)
|
| 80 |
+
|
| 81 |
+
# ---- Training ----
|
| 82 |
+
print("=" * 60)
|
| 83 |
+
print(" Training REINFORCE agent")
|
| 84 |
+
print("=" * 60)
|
| 85 |
+
|
| 86 |
+
window: list[float] = []
|
| 87 |
+
successes_window: list[bool] = []
|
| 88 |
+
for ep in range(1, args.train_episodes + 1):
|
| 89 |
+
result = rl_agent.run_episode(env, seed=args.seed + ep, train=True)
|
| 90 |
+
window.append(result["reward"])
|
| 91 |
+
successes_window.append(result["success"])
|
| 92 |
+
|
| 93 |
+
if ep % args.log_interval == 0:
|
| 94 |
+
avg = sum(window) / len(window)
|
| 95 |
+
sr = sum(successes_window) / len(successes_window)
|
| 96 |
+
print(
|
| 97 |
+
f" Episode {ep:5d} | avg reward (last {args.log_interval}): "
|
| 98 |
+
f"{avg:7.1f} | success rate: {sr:.0%}"
|
| 99 |
+
)
|
| 100 |
+
window.clear()
|
| 101 |
+
successes_window.clear()
|
| 102 |
+
|
| 103 |
+
# ---- Evaluation ----
|
| 104 |
+
print()
|
| 105 |
+
print("=" * 60)
|
| 106 |
+
print(" Evaluating on fixed seed range")
|
| 107 |
+
print("=" * 60)
|
| 108 |
+
|
| 109 |
+
eval_seed_base = 999_999
|
| 110 |
+
|
| 111 |
+
rl_results = [
|
| 112 |
+
rl_agent.run_episode(env, seed=eval_seed_base + i, train=False)
|
| 113 |
+
for i in range(args.eval_episodes)
|
| 114 |
+
]
|
| 115 |
+
|
| 116 |
+
naive_agent = NaiveAgent(num_trials=3, seed=0)
|
| 117 |
+
naive_results = [
|
| 118 |
+
run_episode_naive(env, naive_agent, seed=eval_seed_base + i)
|
| 119 |
+
for i in range(args.eval_episodes)
|
| 120 |
+
]
|
| 121 |
+
|
| 122 |
+
env.close()
|
| 123 |
+
|
| 124 |
+
rl_stats = aggregate(rl_results)
|
| 125 |
+
naive_stats = aggregate(naive_results)
|
| 126 |
+
|
| 127 |
+
header = f"{'Metric':<20s} {'REINFORCE':>12s} {'Naive':>12s}"
|
| 128 |
+
sep = "-" * len(header)
|
| 129 |
+
rows = [
|
| 130 |
+
("Avg reward", f"{rl_stats['avg_reward']:.1f}", f"{naive_stats['avg_reward']:.1f}"),
|
| 131 |
+
("Success rate", f"{rl_stats['success_rate']:.1%}", f"{naive_stats['success_rate']:.1%}"),
|
| 132 |
+
("Partial rate", f"{rl_stats['partial_rate']:.1%}", f"{naive_stats['partial_rate']:.1%}"),
|
| 133 |
+
("Avg time", f"{rl_stats['avg_minutes']:.1f}m", f"{naive_stats['avg_minutes']:.1f}m"),
|
| 134 |
+
("Avg cost", f"${rl_stats['avg_cost']:.1f}", f"${naive_stats['avg_cost']:.1f}"),
|
| 135 |
+
("Avg steps", f"{rl_stats['avg_steps']:.1f}", f"{naive_stats['avg_steps']:.1f}"),
|
| 136 |
+
]
|
| 137 |
+
|
| 138 |
+
print()
|
| 139 |
+
print(header)
|
| 140 |
+
print(sep)
|
| 141 |
+
for label, rl_val, naive_val in rows:
|
| 142 |
+
print(f"{label:<20s} {rl_val:>12s} {naive_val:>12s}")
|
| 143 |
+
print(sep)
|
| 144 |
+
print()
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
if __name__ == "__main__":
|
| 148 |
+
main()
|
scripts/train_per_protocol.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Train a separate REINFORCE agent for each protocol set (e.g. PCR, ELISA).
|
| 4 |
+
|
| 5 |
+
Each protocol has its own presets and outcome model. Training one agent per
|
| 6 |
+
protocol gives you a policy tailored to that protocol's action/observation
|
| 7 |
+
space. Checkpoints are saved under checkpoints/<workflow_id>.pt.
|
| 8 |
+
|
| 9 |
+
Usage:
|
| 10 |
+
python scripts/train_per_protocol.py --workflows pcr-amplification elisa-readout
|
| 11 |
+
python scripts/train_per_protocol.py --workflows pcr-amplification --train-episodes 1000
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
import sys
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
|
| 20 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
| 21 |
+
|
| 22 |
+
from lab_env.env import LabEnv
|
| 23 |
+
from lab_env.spec import get_spec_for_workflow
|
| 24 |
+
from agents.rl_agent import ReinforceAgent
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def main() -> None:
|
| 28 |
+
parser = argparse.ArgumentParser(
|
| 29 |
+
description="Train one RL agent per protocol set (different presets / specs)"
|
| 30 |
+
)
|
| 31 |
+
parser.add_argument(
|
| 32 |
+
"--workflows",
|
| 33 |
+
nargs="+",
|
| 34 |
+
default=["pcr-amplification", "elisa-readout"],
|
| 35 |
+
help="Workflow IDs to train (each gets its own agent and checkpoint)",
|
| 36 |
+
)
|
| 37 |
+
parser.add_argument("--train-episodes", type=int, default=1500)
|
| 38 |
+
parser.add_argument("--eval-episodes", type=int, default=50)
|
| 39 |
+
parser.add_argument("--lr", type=float, default=3e-3)
|
| 40 |
+
parser.add_argument("--max-trials", type=int, default=4)
|
| 41 |
+
parser.add_argument("--checkpoint-dir", type=str, default="checkpoints")
|
| 42 |
+
parser.add_argument("--seed", type=int, default=42)
|
| 43 |
+
args = parser.parse_args()
|
| 44 |
+
|
| 45 |
+
Path(args.checkpoint_dir).mkdir(parents=True, exist_ok=True)
|
| 46 |
+
|
| 47 |
+
for workflow_id in args.workflows:
|
| 48 |
+
spec = get_spec_for_workflow(workflow_id)
|
| 49 |
+
env = LabEnv(spec=spec)
|
| 50 |
+
agent = ReinforceAgent(
|
| 51 |
+
lr=args.lr,
|
| 52 |
+
max_trials=args.max_trials,
|
| 53 |
+
spec=spec,
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
print(f"\n{'='*60}")
|
| 57 |
+
print(f" Training for protocol: {workflow_id} (presets={spec.num_presets}, obs_dim={spec.obs_dim})")
|
| 58 |
+
print("=" * 60)
|
| 59 |
+
|
| 60 |
+
for ep in range(1, args.train_episodes + 1):
|
| 61 |
+
result = agent.run_episode(env, seed=args.seed + ep, train=True)
|
| 62 |
+
if ep % 200 == 0 or ep == args.train_episodes:
|
| 63 |
+
print(f" Episode {ep:5d} | reward: {result['reward']:7.1f} | success: {result['success']}")
|
| 64 |
+
|
| 65 |
+
checkpoint_path = Path(args.checkpoint_dir) / f"{workflow_id}.pt"
|
| 66 |
+
agent.save(str(checkpoint_path))
|
| 67 |
+
print(f" Saved checkpoint: {checkpoint_path}")
|
| 68 |
+
|
| 69 |
+
# Quick eval
|
| 70 |
+
successes = 0
|
| 71 |
+
for i in range(args.eval_episodes):
|
| 72 |
+
r = agent.run_episode(env, seed=999_000 + i, train=False)
|
| 73 |
+
successes += r["success"]
|
| 74 |
+
print(f" Eval success rate: {successes / args.eval_episodes:.0%}")
|
| 75 |
+
|
| 76 |
+
env.close()
|
| 77 |
+
|
| 78 |
+
print("\nDone. Use each checkpoint with LabEnv(spec=<same_spec>) and ReinforceAgent(spec=spec).load(path).")
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
if __name__ == "__main__":
|
| 82 |
+
main()
|
scripts/visualize.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Train, evaluate, and visualize REINFORCE vs Naive agent on LabEnv.
|
| 3 |
+
|
| 4 |
+
Produces a 2x2 figure:
|
| 5 |
+
Top-left: Training reward curve (smoothed)
|
| 6 |
+
Top-right: Training success-rate curve (smoothed)
|
| 7 |
+
Bottom-left: Final comparison bar chart (reward, success%, partial%)
|
| 8 |
+
Bottom-right: Single-episode trace showing the RL agent's actions
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
import sys
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import matplotlib.pyplot as plt
|
| 21 |
+
import matplotlib.ticker as mticker
|
| 22 |
+
|
| 23 |
+
from lab_env.env import (
|
| 24 |
+
LabEnv,
|
| 25 |
+
INITIAL_BUDGET,
|
| 26 |
+
ACTION_SETUP_START,
|
| 27 |
+
ACTION_SETUP_END,
|
| 28 |
+
ACTION_RUN_ASSAY,
|
| 29 |
+
ACTION_ORDER_TIPS,
|
| 30 |
+
ACTION_ORDER_BUFFER,
|
| 31 |
+
ACTION_ORDER_POLYMERASE,
|
| 32 |
+
ACTION_WAIT,
|
| 33 |
+
ACTION_FINISH,
|
| 34 |
+
PRESETS,
|
| 35 |
+
)
|
| 36 |
+
from agents.naive_agent import NaiveAgent
|
| 37 |
+
from agents.rl_agent import ReinforceAgent
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def smooth(values: list[float], window: int = 50) -> np.ndarray:
|
| 41 |
+
if len(values) < window:
|
| 42 |
+
return np.array(values)
|
| 43 |
+
kernel = np.ones(window) / window
|
| 44 |
+
return np.convolve(values, kernel, mode="valid")
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def run_episode_naive(env: LabEnv, agent: NaiveAgent, seed: int) -> dict:
|
| 48 |
+
obs, info = env.reset(seed=seed)
|
| 49 |
+
agent.reset()
|
| 50 |
+
total_reward = 0.0
|
| 51 |
+
steps = 0
|
| 52 |
+
while True:
|
| 53 |
+
action = agent.select_action(obs)
|
| 54 |
+
obs, reward, terminated, truncated, info = env.step(action)
|
| 55 |
+
total_reward += reward
|
| 56 |
+
steps += 1
|
| 57 |
+
if terminated or truncated:
|
| 58 |
+
break
|
| 59 |
+
return {
|
| 60 |
+
"reward": total_reward,
|
| 61 |
+
"success": info["best_result"] == "success",
|
| 62 |
+
"partial": info["best_result"] == "partial",
|
| 63 |
+
"minutes": info["elapsed_minutes"],
|
| 64 |
+
"cost": INITIAL_BUDGET - info["remaining_budget"],
|
| 65 |
+
"steps": steps,
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def trace_rl_episode(env: LabEnv, agent: ReinforceAgent, seed: int) -> list[dict]:
|
| 70 |
+
"""Run one episode and return a step-by-step trace for visualization."""
|
| 71 |
+
obs, info = env.reset(seed=seed)
|
| 72 |
+
agent.reset()
|
| 73 |
+
trace: list[dict] = []
|
| 74 |
+
|
| 75 |
+
for trial in range(agent.max_trials):
|
| 76 |
+
if agent._inventory_low(obs):
|
| 77 |
+
for act in (ACTION_ORDER_TIPS, ACTION_ORDER_BUFFER, ACTION_ORDER_POLYMERASE):
|
| 78 |
+
obs, rew, done, trunc, info = env.step(act)
|
| 79 |
+
trace.append({"action": "order", "label": "Order", "result": "", "reward": rew, "minutes": info["elapsed_minutes"]})
|
| 80 |
+
if done or trunc:
|
| 81 |
+
return trace
|
| 82 |
+
|
| 83 |
+
preset = agent._select_preset(obs, deterministic=True)
|
| 84 |
+
p = PRESETS[preset]
|
| 85 |
+
label = f"Setup {p['temp']}C/{p['cycles']}cy/{p['ratio'][:4]}"
|
| 86 |
+
|
| 87 |
+
obs, rew, done, trunc, info = env.step(ACTION_SETUP_START + preset)
|
| 88 |
+
trace.append({"action": "setup", "label": label, "result": "", "reward": rew, "minutes": info["elapsed_minutes"]})
|
| 89 |
+
if done or trunc:
|
| 90 |
+
return trace
|
| 91 |
+
|
| 92 |
+
obs, rew, done, trunc, info = env.step(ACTION_RUN_ASSAY)
|
| 93 |
+
trace.append({"action": "run", "label": "Run assay", "result": info["last_result"], "reward": rew, "minutes": info["elapsed_minutes"]})
|
| 94 |
+
if done or trunc:
|
| 95 |
+
return trace
|
| 96 |
+
|
| 97 |
+
if info.get("best_result") == "success":
|
| 98 |
+
obs, rew, _, _, info = env.step(ACTION_FINISH)
|
| 99 |
+
trace.append({"action": "finish", "label": "Finish", "result": "success", "reward": rew, "minutes": info["elapsed_minutes"]})
|
| 100 |
+
return trace
|
| 101 |
+
|
| 102 |
+
if not (done or trunc):
|
| 103 |
+
obs, rew, _, _, info = env.step(ACTION_FINISH)
|
| 104 |
+
trace.append({"action": "finish", "label": "Finish", "result": info["best_result"], "reward": rew, "minutes": info["elapsed_minutes"]})
|
| 105 |
+
|
| 106 |
+
return trace
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def main() -> None:
|
| 110 |
+
parser = argparse.ArgumentParser(description="Visualize training & evaluation")
|
| 111 |
+
parser.add_argument("--train-episodes", type=int, default=2000)
|
| 112 |
+
parser.add_argument("--eval-episodes", type=int, default=200)
|
| 113 |
+
parser.add_argument("--seed", type=int, default=42)
|
| 114 |
+
parser.add_argument("--save", type=str, default="", help="Save figure to path instead of showing")
|
| 115 |
+
args = parser.parse_args()
|
| 116 |
+
|
| 117 |
+
env = LabEnv()
|
| 118 |
+
rl_agent = ReinforceAgent(max_trials=4)
|
| 119 |
+
|
| 120 |
+
# ---- Training with metric collection ----
|
| 121 |
+
print(f"Training REINFORCE for {args.train_episodes} episodes...")
|
| 122 |
+
train_rewards: list[float] = []
|
| 123 |
+
train_successes: list[float] = []
|
| 124 |
+
|
| 125 |
+
for ep in range(1, args.train_episodes + 1):
|
| 126 |
+
result = rl_agent.run_episode(env, seed=args.seed + ep, train=True)
|
| 127 |
+
train_rewards.append(result["reward"])
|
| 128 |
+
train_successes.append(float(result["success"]))
|
| 129 |
+
if ep % 500 == 0:
|
| 130 |
+
print(f" ...episode {ep}/{args.train_episodes}")
|
| 131 |
+
|
| 132 |
+
# ---- Evaluation ----
|
| 133 |
+
print(f"Evaluating both agents for {args.eval_episodes} episodes...")
|
| 134 |
+
eval_seed = 999_999
|
| 135 |
+
naive_agent = NaiveAgent(num_trials=3, seed=0)
|
| 136 |
+
|
| 137 |
+
rl_eval = [rl_agent.run_episode(env, seed=eval_seed + i, train=False) for i in range(args.eval_episodes)]
|
| 138 |
+
naive_eval = [run_episode_naive(env, naive_agent, seed=eval_seed + i) for i in range(args.eval_episodes)]
|
| 139 |
+
|
| 140 |
+
# ---- Episode trace ----
|
| 141 |
+
trace = trace_rl_episode(env, rl_agent, seed=12345)
|
| 142 |
+
|
| 143 |
+
env.close()
|
| 144 |
+
|
| 145 |
+
# ---- Aggregate ----
|
| 146 |
+
def agg(results):
|
| 147 |
+
n = len(results)
|
| 148 |
+
return {
|
| 149 |
+
"reward": sum(r["reward"] for r in results) / n,
|
| 150 |
+
"success": sum(r["success"] for r in results) / n,
|
| 151 |
+
"partial": sum(r["partial"] for r in results) / n,
|
| 152 |
+
"minutes": sum(r["minutes"] for r in results) / n,
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
rl_stats = agg(rl_eval)
|
| 156 |
+
naive_stats = agg(naive_eval)
|
| 157 |
+
|
| 158 |
+
# ==================================================================
|
| 159 |
+
# Plot
|
| 160 |
+
# ==================================================================
|
| 161 |
+
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
|
| 162 |
+
fig.suptitle("SimLab — Lab Automation RL Environment", fontsize=16, fontweight="bold")
|
| 163 |
+
|
| 164 |
+
# -- Top-left: reward curve --
|
| 165 |
+
ax = axes[0, 0]
|
| 166 |
+
smoothed_r = smooth(train_rewards, window=50)
|
| 167 |
+
ax.plot(range(len(smoothed_r)), smoothed_r, color="#2563eb", linewidth=1.5)
|
| 168 |
+
ax.axhline(y=0, color="gray", linestyle="--", alpha=0.5)
|
| 169 |
+
ax.set_title("Training Reward (smoothed, window=50)")
|
| 170 |
+
ax.set_xlabel("Episode")
|
| 171 |
+
ax.set_ylabel("Total Episode Reward")
|
| 172 |
+
ax.grid(True, alpha=0.3)
|
| 173 |
+
|
| 174 |
+
# -- Top-right: success rate curve --
|
| 175 |
+
ax = axes[0, 1]
|
| 176 |
+
smoothed_s = smooth(train_successes, window=100) * 100
|
| 177 |
+
ax.plot(range(len(smoothed_s)), smoothed_s, color="#16a34a", linewidth=1.5)
|
| 178 |
+
ax.set_title("Training Success Rate (smoothed, window=100)")
|
| 179 |
+
ax.set_xlabel("Episode")
|
| 180 |
+
ax.set_ylabel("Success %")
|
| 181 |
+
ax.yaxis.set_major_formatter(mticker.PercentFormatter())
|
| 182 |
+
ax.set_ylim(0, 100)
|
| 183 |
+
ax.grid(True, alpha=0.3)
|
| 184 |
+
|
| 185 |
+
# -- Bottom-left: comparison bars --
|
| 186 |
+
ax = axes[1, 0]
|
| 187 |
+
metrics = ["Avg Reward", "Success %", "Partial %", "Avg Time (min)"]
|
| 188 |
+
rl_vals = [rl_stats["reward"], rl_stats["success"] * 100, rl_stats["partial"] * 100, rl_stats["minutes"]]
|
| 189 |
+
naive_vals = [naive_stats["reward"], naive_stats["success"] * 100, naive_stats["partial"] * 100, naive_stats["minutes"]]
|
| 190 |
+
|
| 191 |
+
x = np.arange(len(metrics))
|
| 192 |
+
w = 0.35
|
| 193 |
+
bars_rl = ax.bar(x - w / 2, rl_vals, w, label="REINFORCE", color="#2563eb", edgecolor="white")
|
| 194 |
+
bars_naive = ax.bar(x + w / 2, naive_vals, w, label="Naive", color="#f97316", edgecolor="white")
|
| 195 |
+
ax.set_xticks(x)
|
| 196 |
+
ax.set_xticklabels(metrics, fontsize=9)
|
| 197 |
+
ax.set_title("Evaluation Comparison")
|
| 198 |
+
ax.legend()
|
| 199 |
+
ax.grid(True, alpha=0.3, axis="y")
|
| 200 |
+
|
| 201 |
+
for bar_group in (bars_rl, bars_naive):
|
| 202 |
+
for bar in bar_group:
|
| 203 |
+
h = bar.get_height()
|
| 204 |
+
ax.annotate(f"{h:.1f}", xy=(bar.get_x() + bar.get_width() / 2, h),
|
| 205 |
+
xytext=(0, 4), textcoords="offset points",
|
| 206 |
+
ha="center", va="bottom", fontsize=8)
|
| 207 |
+
|
| 208 |
+
# -- Bottom-right: episode trace --
|
| 209 |
+
ax = axes[1, 1]
|
| 210 |
+
if trace:
|
| 211 |
+
y_labels = []
|
| 212 |
+
colors = []
|
| 213 |
+
for i, step in enumerate(trace):
|
| 214 |
+
y_labels.append(step["label"])
|
| 215 |
+
if step["result"] == "success":
|
| 216 |
+
colors.append("#16a34a")
|
| 217 |
+
elif step["result"] == "partial":
|
| 218 |
+
colors.append("#eab308")
|
| 219 |
+
elif step["result"] == "fail":
|
| 220 |
+
colors.append("#dc2626")
|
| 221 |
+
else:
|
| 222 |
+
colors.append("#6b7280")
|
| 223 |
+
|
| 224 |
+
y_pos = np.arange(len(trace))
|
| 225 |
+
minutes = [s["minutes"] for s in trace]
|
| 226 |
+
ax.barh(y_pos, minutes, color=colors, edgecolor="white", height=0.6)
|
| 227 |
+
ax.set_yticks(y_pos)
|
| 228 |
+
ax.set_yticklabels(y_labels, fontsize=8)
|
| 229 |
+
ax.invert_yaxis()
|
| 230 |
+
ax.set_xlabel("Elapsed Minutes")
|
| 231 |
+
ax.set_title("Single Episode Trace (RL Agent)")
|
| 232 |
+
|
| 233 |
+
for i, step in enumerate(trace):
|
| 234 |
+
if step["result"] in ("success", "partial", "fail"):
|
| 235 |
+
ax.annotate(step["result"], xy=(minutes[i], i),
|
| 236 |
+
xytext=(5, 0), textcoords="offset points",
|
| 237 |
+
va="center", fontsize=8, fontweight="bold",
|
| 238 |
+
color=colors[i])
|
| 239 |
+
else:
|
| 240 |
+
ax.text(0.5, 0.5, "No trace data", ha="center", va="center", transform=ax.transAxes)
|
| 241 |
+
ax.set_title("Single Episode Trace (RL Agent)")
|
| 242 |
+
|
| 243 |
+
plt.tight_layout(rect=[0, 0, 1, 0.95])
|
| 244 |
+
|
| 245 |
+
if args.save:
|
| 246 |
+
fig.savefig(args.save, dpi=150, bbox_inches="tight")
|
| 247 |
+
print(f"Saved to {args.save}")
|
| 248 |
+
else:
|
| 249 |
+
plt.show()
|
| 250 |
+
|
| 251 |
+
# Print summary
|
| 252 |
+
print()
|
| 253 |
+
print(f" REINFORCE: reward={rl_stats['reward']:.1f} success={rl_stats['success']:.1%} time={rl_stats['minutes']:.0f}m")
|
| 254 |
+
print(f" Naive: reward={naive_stats['reward']:.1f} success={naive_stats['success']:.1%} time={naive_stats['minutes']:.0f}m")
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
if __name__ == "__main__":
|
| 258 |
+
main()
|
server/app.py
ADDED
|
@@ -0,0 +1,621 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FastAPI server bridging the LabEnv Python backend to the Next.js frontend.
|
| 3 |
+
|
| 4 |
+
Endpoints:
|
| 5 |
+
POST /api/training/start — train the agent (SSE stream)
|
| 6 |
+
POST /api/run/ai — run one AI-agent episode
|
| 7 |
+
POST /api/run/naive — run one naive-agent episode
|
| 8 |
+
POST /api/env/reset — reset environment
|
| 9 |
+
POST /api/env/step — take one step
|
| 10 |
+
GET /api/stats — dashboard aggregate stats
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import json
|
| 16 |
+
import sys
|
| 17 |
+
import time
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import Any
|
| 20 |
+
|
| 21 |
+
from fastapi import FastAPI, Request
|
| 22 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 23 |
+
from fastapi.responses import StreamingResponse
|
| 24 |
+
from pydantic import BaseModel
|
| 25 |
+
|
| 26 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
| 27 |
+
|
| 28 |
+
from lab_env.env import (
|
| 29 |
+
LabEnv,
|
| 30 |
+
INITIAL_BUDGET,
|
| 31 |
+
ACTION_SETUP_START,
|
| 32 |
+
ACTION_RUN_ASSAY,
|
| 33 |
+
ACTION_ORDER_TIPS,
|
| 34 |
+
ACTION_ORDER_BUFFER,
|
| 35 |
+
ACTION_ORDER_POLYMERASE,
|
| 36 |
+
ACTION_WAIT,
|
| 37 |
+
ACTION_FINISH,
|
| 38 |
+
)
|
| 39 |
+
from lab_env.spec import pcr_experiment_spec, get_spec_for_workflow
|
| 40 |
+
from agents.naive_agent import NaiveAgent
|
| 41 |
+
from agents.rl_agent import ReinforceAgent
|
| 42 |
+
|
| 43 |
+
# Per-workflow envs (created on first use). RL agent is shared and trained on PCR.
|
| 44 |
+
_envs: dict[str, LabEnv] = {}
|
| 45 |
+
|
| 46 |
+
try:
|
| 47 |
+
from agents.research_llm_agent import ResearchLLMAgent
|
| 48 |
+
HAS_RESEARCH_AGENT = True
|
| 49 |
+
except ImportError:
|
| 50 |
+
ResearchLLMAgent = None
|
| 51 |
+
HAS_RESEARCH_AGENT = False
|
| 52 |
+
|
| 53 |
+
try:
|
| 54 |
+
from agents.research_generate_agent import ResearchGenerateAgent
|
| 55 |
+
HAS_RESEARCH_GENERATE_AGENT = True
|
| 56 |
+
except ImportError:
|
| 57 |
+
ResearchGenerateAgent = None
|
| 58 |
+
HAS_RESEARCH_GENERATE_AGENT = False
|
| 59 |
+
|
| 60 |
+
app = FastAPI(title="SimLab API")
|
| 61 |
+
app.add_middleware(
|
| 62 |
+
CORSMiddleware,
|
| 63 |
+
allow_origins=["*"],
|
| 64 |
+
allow_methods=["*"],
|
| 65 |
+
allow_headers=["*"],
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
rl_agent: ReinforceAgent | None = None
|
| 69 |
+
_trained_agents: dict[str, ReinforceAgent] = {} # workflow_id -> agent (for UI per-protocol training)
|
| 70 |
+
run_history: list[dict] = []
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _get_env(workflow_id: str) -> LabEnv:
|
| 74 |
+
"""Get or create LabEnv for this workflow. Uses spec from get_spec_for_workflow(workflow_id)."""
|
| 75 |
+
if workflow_id not in _envs:
|
| 76 |
+
spec = get_spec_for_workflow(workflow_id)
|
| 77 |
+
_envs[workflow_id] = LabEnv(spec=spec)
|
| 78 |
+
return _envs[workflow_id]
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# ──────────────────────────────────────────────
|
| 82 |
+
# Request / response models
|
| 83 |
+
# ──────────────────────────────────────────────
|
| 84 |
+
|
| 85 |
+
class TrainRequest(BaseModel):
|
| 86 |
+
episodes: int = 2000
|
| 87 |
+
lr: float = 3e-3
|
| 88 |
+
max_trials: int = 4
|
| 89 |
+
eval_episodes: int = 100
|
| 90 |
+
workflow_id: str = "pcr-amplification"
|
| 91 |
+
|
| 92 |
+
class StepRequest(BaseModel):
|
| 93 |
+
action: int
|
| 94 |
+
workflow_id: str = "pcr-amplification"
|
| 95 |
+
|
| 96 |
+
class RunRequest(BaseModel):
|
| 97 |
+
seed: int = 42
|
| 98 |
+
workflow_id: str = "pcr-amplification"
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# ──────────────────────────────────────────────
|
| 102 |
+
# Helpers
|
| 103 |
+
# ──────────────────────────────────────────────
|
| 104 |
+
|
| 105 |
+
def _env_state_dict(env: LabEnv) -> dict[str, Any]:
|
| 106 |
+
info = env._info()
|
| 107 |
+
return {
|
| 108 |
+
"step_index": info["step_index"],
|
| 109 |
+
"elapsed_minutes": info["elapsed_minutes"],
|
| 110 |
+
"remaining_budget": info["remaining_budget"],
|
| 111 |
+
"inventory": info["inventory"],
|
| 112 |
+
"last_result": info["last_result"],
|
| 113 |
+
"best_result": info["best_result"],
|
| 114 |
+
"max_time": 240,
|
| 115 |
+
"max_budget": 500,
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def _trace_episode(env: LabEnv, agent: ReinforceAgent, seed: int) -> dict:
|
| 120 |
+
"""Run an AI episode and produce a step-by-step timeline."""
|
| 121 |
+
presets = env.spec.presets
|
| 122 |
+
obs, info = env.reset(seed=seed)
|
| 123 |
+
agent.reset()
|
| 124 |
+
timeline: list[dict] = []
|
| 125 |
+
presets_tried: dict[int, str] = {}
|
| 126 |
+
|
| 127 |
+
for trial in range(agent.max_trials):
|
| 128 |
+
if agent._inventory_low(obs):
|
| 129 |
+
for act in (ACTION_ORDER_TIPS, ACTION_ORDER_BUFFER, ACTION_ORDER_POLYMERASE):
|
| 130 |
+
obs, rew, done, trunc, info = env.step(act)
|
| 131 |
+
timeline.append({
|
| 132 |
+
"title": "Order Reagents",
|
| 133 |
+
"description": _order_label(act),
|
| 134 |
+
"time": f"{info['elapsed_minutes']:.0f} min",
|
| 135 |
+
"status": "action",
|
| 136 |
+
"icon": "order",
|
| 137 |
+
})
|
| 138 |
+
if done or trunc:
|
| 139 |
+
return _build_run_result(env, info, timeline, presets_tried)
|
| 140 |
+
|
| 141 |
+
preset = agent._select_preset(obs, deterministic=True)
|
| 142 |
+
p = presets[preset]
|
| 143 |
+
label = _preset_label(p)
|
| 144 |
+
|
| 145 |
+
obs, rew, done, trunc, info = env.step(ACTION_SETUP_START + preset)
|
| 146 |
+
timeline.append({
|
| 147 |
+
"title": "Setup",
|
| 148 |
+
"description": label,
|
| 149 |
+
"time": f"{info['elapsed_minutes']:.0f} min",
|
| 150 |
+
"status": "pending",
|
| 151 |
+
"icon": "setup",
|
| 152 |
+
})
|
| 153 |
+
if done or trunc:
|
| 154 |
+
return _build_run_result(env, info, timeline, presets_tried)
|
| 155 |
+
|
| 156 |
+
obs, rew, done, trunc, info = env.step(ACTION_RUN_ASSAY)
|
| 157 |
+
result = info["last_result"]
|
| 158 |
+
presets_tried[preset] = result
|
| 159 |
+
timeline.append({
|
| 160 |
+
"title": "Run Assay",
|
| 161 |
+
"description": _result_description(result),
|
| 162 |
+
"time": f"{info['elapsed_minutes']:.0f} min",
|
| 163 |
+
"status": result,
|
| 164 |
+
"icon": "run",
|
| 165 |
+
})
|
| 166 |
+
if done or trunc:
|
| 167 |
+
return _build_run_result(env, info, timeline, presets_tried)
|
| 168 |
+
|
| 169 |
+
if info.get("best_result") == "success":
|
| 170 |
+
obs, rew, _, _, info = env.step(ACTION_FINISH)
|
| 171 |
+
timeline.append({
|
| 172 |
+
"title": "Finish",
|
| 173 |
+
"description": "Experiment complete — success!",
|
| 174 |
+
"time": f"{info['elapsed_minutes']:.0f} min",
|
| 175 |
+
"status": "success",
|
| 176 |
+
"icon": "finish",
|
| 177 |
+
})
|
| 178 |
+
return _build_run_result(env, info, timeline, presets_tried)
|
| 179 |
+
|
| 180 |
+
obs, rew, _, _, info = env.step(ACTION_FINISH)
|
| 181 |
+
timeline.append({
|
| 182 |
+
"title": "Finish",
|
| 183 |
+
"description": f"Experiment complete — best: {info['best_result']}",
|
| 184 |
+
"time": f"{info['elapsed_minutes']:.0f} min",
|
| 185 |
+
"status": info["best_result"] if info["best_result"] in ("success", "partial") else "fail",
|
| 186 |
+
"icon": "finish",
|
| 187 |
+
})
|
| 188 |
+
return _build_run_result(env, info, timeline, presets_tried)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def _trace_naive_episode(env: LabEnv, agent: NaiveAgent, seed: int) -> dict:
|
| 192 |
+
presets = env.spec.presets
|
| 193 |
+
num_presets = len(presets)
|
| 194 |
+
obs, info = env.reset(seed=seed)
|
| 195 |
+
agent.reset()
|
| 196 |
+
timeline: list[dict] = []
|
| 197 |
+
presets_tried: dict[int, str] = {}
|
| 198 |
+
total_reward = 0.0
|
| 199 |
+
|
| 200 |
+
while True:
|
| 201 |
+
action = agent.select_action(obs)
|
| 202 |
+
obs, reward, done, trunc, info = env.step(action)
|
| 203 |
+
total_reward += reward
|
| 204 |
+
|
| 205 |
+
if ACTION_SETUP_START <= action < ACTION_SETUP_START + num_presets:
|
| 206 |
+
p = presets[action - ACTION_SETUP_START]
|
| 207 |
+
timeline.append({
|
| 208 |
+
"title": "Setup",
|
| 209 |
+
"description": _preset_label(p),
|
| 210 |
+
"time": f"{info['elapsed_minutes']:.0f} min",
|
| 211 |
+
"status": "pending",
|
| 212 |
+
"icon": "setup",
|
| 213 |
+
})
|
| 214 |
+
elif action == ACTION_RUN_ASSAY:
|
| 215 |
+
result = info["last_result"]
|
| 216 |
+
timeline.append({
|
| 217 |
+
"title": "Run Assay",
|
| 218 |
+
"description": _result_description(result),
|
| 219 |
+
"time": f"{info['elapsed_minutes']:.0f} min",
|
| 220 |
+
"status": result,
|
| 221 |
+
"icon": "run",
|
| 222 |
+
})
|
| 223 |
+
elif action in (ACTION_ORDER_TIPS, ACTION_ORDER_BUFFER, ACTION_ORDER_POLYMERASE):
|
| 224 |
+
timeline.append({
|
| 225 |
+
"title": "Order Reagents",
|
| 226 |
+
"description": _order_label(action),
|
| 227 |
+
"time": f"{info['elapsed_minutes']:.0f} min",
|
| 228 |
+
"status": "action",
|
| 229 |
+
"icon": "order",
|
| 230 |
+
})
|
| 231 |
+
elif action == ACTION_FINISH:
|
| 232 |
+
timeline.append({
|
| 233 |
+
"title": "Finish",
|
| 234 |
+
"description": f"Experiment complete — best: {info['best_result']}",
|
| 235 |
+
"time": f"{info['elapsed_minutes']:.0f} min",
|
| 236 |
+
"status": info["best_result"] if info["best_result"] in ("success", "partial") else "fail",
|
| 237 |
+
"icon": "finish",
|
| 238 |
+
})
|
| 239 |
+
|
| 240 |
+
if done or trunc:
|
| 241 |
+
break
|
| 242 |
+
|
| 243 |
+
return _build_run_result(env, info, timeline, presets_tried)
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def _build_run_result(env: LabEnv, info: dict, timeline: list[dict], presets_tried: dict[int, str]) -> dict:
|
| 247 |
+
presets = env.spec.presets
|
| 248 |
+
spec = env.spec
|
| 249 |
+
preset_statuses = []
|
| 250 |
+
for i, p in enumerate(presets):
|
| 251 |
+
row: dict[str, Any] = {
|
| 252 |
+
"id": str(i),
|
| 253 |
+
"status": presets_tried.get(i, "untried"),
|
| 254 |
+
"label": _preset_label(p),
|
| 255 |
+
}
|
| 256 |
+
if "temp" in p:
|
| 257 |
+
row["temp"] = p["temp"]
|
| 258 |
+
row["cycles"] = p["cycles"]
|
| 259 |
+
row["ratio"] = p["ratio"]
|
| 260 |
+
if "coating_hr" in p:
|
| 261 |
+
row["coating_hr"] = p["coating_hr"]
|
| 262 |
+
row["block"] = p.get("block", "")
|
| 263 |
+
preset_statuses.append(row)
|
| 264 |
+
return {
|
| 265 |
+
"state": {
|
| 266 |
+
"elapsed_minutes": info["elapsed_minutes"],
|
| 267 |
+
"remaining_budget": info["remaining_budget"],
|
| 268 |
+
"inventory": info["inventory"],
|
| 269 |
+
"best_result": info["best_result"],
|
| 270 |
+
"max_time": getattr(spec, "max_minutes", 240),
|
| 271 |
+
"max_budget": getattr(spec, "initial_budget", 500),
|
| 272 |
+
},
|
| 273 |
+
"timeline": timeline,
|
| 274 |
+
"presets": preset_statuses,
|
| 275 |
+
"reward": float(INITIAL_BUDGET - info["remaining_budget"]),
|
| 276 |
+
"best_result": info["best_result"],
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def _result_description(result: str) -> str:
|
| 281 |
+
return {"success": "Success!", "partial": "Partial — low yield", "fail": "Failed — no amplification"}.get(result, result)
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def _order_label(action: int) -> str:
|
| 285 |
+
return {ACTION_ORDER_TIPS: "+5 tips", ACTION_ORDER_BUFFER: "+5 buffer", ACTION_ORDER_POLYMERASE: "+3 polymerase"}.get(action, "reagents")
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def _preset_label(preset: dict) -> str:
|
| 289 |
+
"""Human-readable preset description for timeline/UI (PCR or ELISA)."""
|
| 290 |
+
if "coating_hr" in preset:
|
| 291 |
+
return f"{preset['coating_hr']}hr coat / {preset['temp']}°C / {preset.get('block', '')}"
|
| 292 |
+
return f"{preset.get('temp', '?')}°C / {preset.get('cycles', '?')} cyc / {preset.get('ratio', '?')}"
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def _trace_research_episode(env: LabEnv, seed: int, max_trials: int = 5) -> dict:
|
| 296 |
+
"""Run Research LLM agent episode and build timeline (Research → Hypothesis → Experiment → Learn). PCR only."""
|
| 297 |
+
presets = env.spec.presets
|
| 298 |
+
if not HAS_RESEARCH_AGENT:
|
| 299 |
+
return _build_run_result(env, env._info(), [{"title": "Research agent unavailable", "description": "Install openai and set OPENAI_API_KEY", "time": "0 min", "status": "fail", "icon": "run"}], {})
|
| 300 |
+
if env.spec.name != "pcr":
|
| 301 |
+
return _build_run_result(env, env._info(), [{"title": "Research agent", "description": "Research agent is only supported for PCR workflow.", "time": "0 min", "status": "fail", "icon": "run"}], {})
|
| 302 |
+
agent = ResearchLLMAgent(max_trials=max_trials)
|
| 303 |
+
callback: list[dict] = []
|
| 304 |
+
result = agent.run_episode(env, seed=seed, episode_callback=callback)
|
| 305 |
+
info = env._info()
|
| 306 |
+
timeline: list[dict] = []
|
| 307 |
+
presets_tried: dict[int, str] = {}
|
| 308 |
+
|
| 309 |
+
for step in callback:
|
| 310 |
+
research = (step.get("research") or "")[:200]
|
| 311 |
+
if len(step.get("research") or "") > 200:
|
| 312 |
+
research += "..."
|
| 313 |
+
timeline.append({
|
| 314 |
+
"title": "Research",
|
| 315 |
+
"description": research or "Literature search for PCR protocol",
|
| 316 |
+
"time": f"{info.get('elapsed_minutes', 0):.0f} min",
|
| 317 |
+
"status": "action",
|
| 318 |
+
"icon": "research",
|
| 319 |
+
})
|
| 320 |
+
hyp = step.get("hypothesis") or {}
|
| 321 |
+
timeline.append({
|
| 322 |
+
"title": "Hypothesis",
|
| 323 |
+
"description": f"temp={hyp.get('temp', '?')}°C, cycles={hyp.get('cycles', '?')}, ratio={hyp.get('ratio', '?')}",
|
| 324 |
+
"time": f"{info.get('elapsed_minutes', 0):.0f} min",
|
| 325 |
+
"status": "pending",
|
| 326 |
+
"icon": "hypothesis",
|
| 327 |
+
})
|
| 328 |
+
params = step.get("params_used") or {}
|
| 329 |
+
res = step.get("result", "fail")
|
| 330 |
+
timeline.append({
|
| 331 |
+
"title": "Run Assay",
|
| 332 |
+
"description": _result_description(res),
|
| 333 |
+
"time": f"{info.get('elapsed_minutes', 0):.0f} min",
|
| 334 |
+
"status": res,
|
| 335 |
+
"icon": "run",
|
| 336 |
+
})
|
| 337 |
+
for i, p in enumerate(presets):
|
| 338 |
+
if p.get("temp") == params.get("temp") and p.get("cycles") == params.get("cycles") and p.get("ratio") == params.get("ratio"):
|
| 339 |
+
presets_tried[i] = res
|
| 340 |
+
break
|
| 341 |
+
timeline.append({
|
| 342 |
+
"title": "Learn",
|
| 343 |
+
"description": f"temp_range={agent.knowledge.get('temp_range', [])}, cycle_range={agent.knowledge.get('cycle_range', [])}",
|
| 344 |
+
"time": f"{info.get('elapsed_minutes', 0):.0f} min",
|
| 345 |
+
"status": "action",
|
| 346 |
+
"icon": "learn",
|
| 347 |
+
})
|
| 348 |
+
|
| 349 |
+
return _build_run_result(env, info, timeline, presets_tried)
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def _protocol_dict_label(protocol: dict) -> str:
|
| 353 |
+
"""Human-readable label for a protocol dict (PCR or ELISA)."""
|
| 354 |
+
if "coating_hr" in protocol:
|
| 355 |
+
return f"{protocol.get('coating_hr', '?')}hr / {protocol.get('temp', '?')}°C / {protocol.get('block', '?')}"
|
| 356 |
+
return f"{protocol.get('temp', '?')}°C / {protocol.get('cycles', '?')} cyc / {protocol.get('ratio', '?')}"
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def _trace_research_generate_episode(env: LabEnv, seed: int, max_trials: int = 6) -> dict:
|
| 360 |
+
"""Run Research & Generate agent (research → generate any protocol → run → learn). Works for PCR, ELISA, etc."""
|
| 361 |
+
if not HAS_RESEARCH_GENERATE_AGENT:
|
| 362 |
+
return _build_run_result(
|
| 363 |
+
env, env._info(),
|
| 364 |
+
[{"title": "Research & Generate agent unavailable", "description": "Install openai and set OPENAI_API_KEY", "time": "0 min", "status": "fail", "icon": "run"}],
|
| 365 |
+
{},
|
| 366 |
+
)
|
| 367 |
+
if env.spec.evaluate_custom_protocol is None:
|
| 368 |
+
return _build_run_result(
|
| 369 |
+
env, env._info(),
|
| 370 |
+
[{"title": "Research & Generate", "description": "This workflow does not support custom protocols.", "time": "0 min", "status": "fail", "icon": "run"}],
|
| 371 |
+
{},
|
| 372 |
+
)
|
| 373 |
+
agent = ResearchGenerateAgent(max_trials=max_trials)
|
| 374 |
+
agent.run_episode(env, seed=seed, verbose=False)
|
| 375 |
+
info = env._info()
|
| 376 |
+
timeline: list[dict] = []
|
| 377 |
+
preset_statuses: list[dict[str, Any]] = []
|
| 378 |
+
for i, entry in enumerate(agent.feedback_history):
|
| 379 |
+
protocol = entry.get("protocol", {})
|
| 380 |
+
result = entry.get("result", "fail")
|
| 381 |
+
label = _protocol_dict_label(protocol)
|
| 382 |
+
timeline.append({
|
| 383 |
+
"title": "Research & Generate",
|
| 384 |
+
"description": f"Generated: {label}",
|
| 385 |
+
"time": f"{info.get('elapsed_minutes', 0):.0f} min",
|
| 386 |
+
"status": "pending",
|
| 387 |
+
"icon": "research",
|
| 388 |
+
})
|
| 389 |
+
timeline.append({
|
| 390 |
+
"title": "Run Assay",
|
| 391 |
+
"description": _result_description(result),
|
| 392 |
+
"time": f"{info.get('elapsed_minutes', 0):.0f} min",
|
| 393 |
+
"status": result,
|
| 394 |
+
"icon": "run",
|
| 395 |
+
})
|
| 396 |
+
row: dict[str, Any] = {"id": str(i), "status": result, "label": label}
|
| 397 |
+
if "temp" in protocol:
|
| 398 |
+
row["temp"] = protocol.get("temp")
|
| 399 |
+
row["cycles"] = protocol.get("cycles")
|
| 400 |
+
row["ratio"] = protocol.get("ratio", "")
|
| 401 |
+
if "coating_hr" in protocol:
|
| 402 |
+
row["coating_hr"] = protocol.get("coating_hr")
|
| 403 |
+
row["block"] = protocol.get("block", "")
|
| 404 |
+
preset_statuses.append(row)
|
| 405 |
+
timeline.append({
|
| 406 |
+
"title": "Finish",
|
| 407 |
+
"description": f"Best result: {info.get('best_result', 'none')}",
|
| 408 |
+
"time": f"{info.get('elapsed_minutes', 0):.0f} min",
|
| 409 |
+
"status": info["best_result"] if info["best_result"] in ("success", "partial") else "fail",
|
| 410 |
+
"icon": "finish",
|
| 411 |
+
})
|
| 412 |
+
return {
|
| 413 |
+
"state": {
|
| 414 |
+
"elapsed_minutes": info["elapsed_minutes"],
|
| 415 |
+
"remaining_budget": info["remaining_budget"],
|
| 416 |
+
"inventory": info["inventory"],
|
| 417 |
+
"best_result": info["best_result"],
|
| 418 |
+
"max_time": getattr(env.spec, "max_minutes", 240),
|
| 419 |
+
"max_budget": getattr(env.spec, "initial_budget", 500),
|
| 420 |
+
},
|
| 421 |
+
"timeline": timeline,
|
| 422 |
+
"presets": preset_statuses,
|
| 423 |
+
"reward": float(INITIAL_BUDGET - info["remaining_budget"]),
|
| 424 |
+
"best_result": info["best_result"],
|
| 425 |
+
}
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
# ──────────────────────────────────────────────
|
| 429 |
+
# Training endpoint (SSE stream)
|
| 430 |
+
# ──────────────────────────────────────────────
|
| 431 |
+
|
| 432 |
+
@app.post("/api/training/start")
|
| 433 |
+
async def training_start(req: TrainRequest):
|
| 434 |
+
global rl_agent, _trained_agents
|
| 435 |
+
|
| 436 |
+
def generate():
|
| 437 |
+
global rl_agent, _trained_agents
|
| 438 |
+
spec = get_spec_for_workflow(req.workflow_id)
|
| 439 |
+
agent = ReinforceAgent(lr=req.lr, max_trials=req.max_trials, spec=spec)
|
| 440 |
+
train_env = LabEnv(spec=spec)
|
| 441 |
+
|
| 442 |
+
window_rewards: list[float] = []
|
| 443 |
+
window_successes: list[float] = []
|
| 444 |
+
chart_data: list[dict] = []
|
| 445 |
+
log_interval = max(req.episodes // 40, 10)
|
| 446 |
+
|
| 447 |
+
for ep in range(1, req.episodes + 1):
|
| 448 |
+
result = agent.run_episode(train_env, seed=42 + ep, train=True)
|
| 449 |
+
window_rewards.append(result["reward"])
|
| 450 |
+
window_successes.append(float(result["success"]))
|
| 451 |
+
|
| 452 |
+
if ep % log_interval == 0 or ep == req.episodes:
|
| 453 |
+
avg_reward = sum(window_rewards) / len(window_rewards)
|
| 454 |
+
avg_success = sum(window_successes) / len(window_successes) * 100
|
| 455 |
+
chart_data.append({
|
| 456 |
+
"episode": ep,
|
| 457 |
+
"reward": round(avg_reward, 2),
|
| 458 |
+
"successRate": round(avg_success, 1),
|
| 459 |
+
})
|
| 460 |
+
progress = round(ep / req.episodes * 100)
|
| 461 |
+
event = {
|
| 462 |
+
"type": "progress",
|
| 463 |
+
"episode": ep,
|
| 464 |
+
"total": req.episodes,
|
| 465 |
+
"progress": progress,
|
| 466 |
+
"reward": round(avg_reward, 2),
|
| 467 |
+
"successRate": round(avg_success, 1),
|
| 468 |
+
"chartData": chart_data,
|
| 469 |
+
}
|
| 470 |
+
yield f"data: {json.dumps(event)}\n\n"
|
| 471 |
+
window_rewards.clear()
|
| 472 |
+
window_successes.clear()
|
| 473 |
+
|
| 474 |
+
rl_agent = agent
|
| 475 |
+
_trained_agents[req.workflow_id] = agent
|
| 476 |
+
|
| 477 |
+
eval_seed = 999_999
|
| 478 |
+
rl_results = [agent.run_episode(train_env, seed=eval_seed + i, train=False) for i in range(req.eval_episodes)]
|
| 479 |
+
naive = NaiveAgent(num_trials=3, seed=0)
|
| 480 |
+
naive_results = []
|
| 481 |
+
for i in range(req.eval_episodes):
|
| 482 |
+
obs, info = train_env.reset(seed=eval_seed + i)
|
| 483 |
+
naive.reset()
|
| 484 |
+
total_r = 0.0
|
| 485 |
+
while True:
|
| 486 |
+
a = naive.select_action(obs)
|
| 487 |
+
obs, r, d, t, info = train_env.step(a)
|
| 488 |
+
total_r += r
|
| 489 |
+
if d or t:
|
| 490 |
+
break
|
| 491 |
+
naive_results.append({"reward": total_r, "success": info["best_result"] == "success",
|
| 492 |
+
"partial": info["best_result"] == "partial",
|
| 493 |
+
"minutes": info["elapsed_minutes"],
|
| 494 |
+
"cost": 500.0 - info["remaining_budget"]})
|
| 495 |
+
|
| 496 |
+
train_env.close()
|
| 497 |
+
n_rl = len(rl_results)
|
| 498 |
+
n_nv = len(naive_results)
|
| 499 |
+
|
| 500 |
+
def agg(res, n):
|
| 501 |
+
return {
|
| 502 |
+
"reward": round(sum(r["reward"] for r in res) / n, 1),
|
| 503 |
+
"success": round(sum(r["success"] for r in res) / n * 100, 1),
|
| 504 |
+
"partial": round(sum(r["partial"] for r in res) / n * 100, 1),
|
| 505 |
+
"minutes": round(sum(r["minutes"] for r in res) / n, 0),
|
| 506 |
+
"cost": round(sum(r["cost"] for r in res) / n, 1),
|
| 507 |
+
}
|
| 508 |
+
|
| 509 |
+
rl_s = agg(rl_results, n_rl)
|
| 510 |
+
nv_s = agg(naive_results, n_nv)
|
| 511 |
+
|
| 512 |
+
def imp(rl_v, nv_v):
|
| 513 |
+
if nv_v == 0:
|
| 514 |
+
return None
|
| 515 |
+
return round((rl_v - nv_v) / abs(nv_v) * 100)
|
| 516 |
+
|
| 517 |
+
comparison = [
|
| 518 |
+
{"metric": "Avg Reward", "reinforce": rl_s["reward"], "baseline": nv_s["reward"], "improvement": imp(rl_s["reward"], nv_s["reward"]), "unit": ""},
|
| 519 |
+
{"metric": "Success Rate", "reinforce": rl_s["success"], "baseline": nv_s["success"], "improvement": imp(rl_s["success"], nv_s["success"]), "unit": "%"},
|
| 520 |
+
{"metric": "Partial Rate", "reinforce": rl_s["partial"], "baseline": nv_s["partial"], "improvement": imp(rl_s["partial"], nv_s["partial"]), "unit": "%"},
|
| 521 |
+
{"metric": "Avg Time", "reinforce": rl_s["minutes"], "baseline": nv_s["minutes"], "improvement": imp(nv_s["minutes"], rl_s["minutes"]), "unit": "min"},
|
| 522 |
+
{"metric": "Avg Cost", "reinforce": rl_s["cost"], "baseline": nv_s["cost"], "improvement": imp(nv_s["cost"], rl_s["cost"]), "unit": "$"},
|
| 523 |
+
]
|
| 524 |
+
|
| 525 |
+
final_event = {
|
| 526 |
+
"type": "done",
|
| 527 |
+
"chartData": chart_data,
|
| 528 |
+
"comparison": comparison,
|
| 529 |
+
}
|
| 530 |
+
yield f"data: {json.dumps(final_event)}\n\n"
|
| 531 |
+
|
| 532 |
+
return StreamingResponse(generate(), media_type="text/event-stream")
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
# ──────────────────────────────────────────────
|
| 536 |
+
# Run endpoints
|
| 537 |
+
# ──────────────────────────────────────────────
|
| 538 |
+
|
| 539 |
+
@app.post("/api/run/ai")
|
| 540 |
+
async def run_ai(req: RunRequest):
|
| 541 |
+
global rl_agent, _trained_agents
|
| 542 |
+
env = _get_env(req.workflow_id)
|
| 543 |
+
agent = _trained_agents.get(req.workflow_id) or rl_agent
|
| 544 |
+
if agent is None:
|
| 545 |
+
spec = get_spec_for_workflow(req.workflow_id)
|
| 546 |
+
agent = ReinforceAgent(max_trials=4, spec=spec)
|
| 547 |
+
rl_agent = agent
|
| 548 |
+
_trained_agents[req.workflow_id] = agent
|
| 549 |
+
return _trace_episode(env, agent, seed=req.seed)
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
@app.post("/api/run/naive")
|
| 553 |
+
async def run_naive(req: RunRequest):
|
| 554 |
+
env = _get_env(req.workflow_id)
|
| 555 |
+
agent = NaiveAgent(num_trials=3, seed=req.seed)
|
| 556 |
+
return _trace_naive_episode(env, agent, seed=req.seed)
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
@app.post("/api/run/research")
|
| 560 |
+
async def run_research(req: RunRequest):
|
| 561 |
+
"""Run Research LLM agent (research → hypothesize → experiment → learn). PCR workflow only."""
|
| 562 |
+
env = _get_env(req.workflow_id)
|
| 563 |
+
return _trace_research_episode(env, seed=req.seed, max_trials=5)
|
| 564 |
+
|
| 565 |
+
|
| 566 |
+
@app.post("/api/run/research-generate")
|
| 567 |
+
async def run_research_generate(req: RunRequest):
|
| 568 |
+
"""Run Research & Generate agent (research → generate any protocol → run → learn). PCR, ELISA, any spec with evaluate_custom_protocol."""
|
| 569 |
+
env = _get_env(req.workflow_id)
|
| 570 |
+
return _trace_research_generate_episode(env, seed=req.seed, max_trials=6)
|
| 571 |
+
|
| 572 |
+
|
| 573 |
+
# ──────────────────────────────────────────────
|
| 574 |
+
# Step-by-step endpoint
|
| 575 |
+
# ──────────────────────────────────────────────
|
| 576 |
+
|
| 577 |
+
@app.post("/api/env/reset")
|
| 578 |
+
async def env_reset(req: RunRequest):
|
| 579 |
+
env = _get_env(req.workflow_id)
|
| 580 |
+
obs, info = env.reset(seed=req.seed)
|
| 581 |
+
return _env_state_dict(env)
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
@app.post("/api/env/step")
|
| 585 |
+
async def env_step(req: StepRequest):
|
| 586 |
+
env = _get_env(req.workflow_id)
|
| 587 |
+
obs, reward, terminated, truncated, info = env.step(req.action)
|
| 588 |
+
return {
|
| 589 |
+
**_env_state_dict(env),
|
| 590 |
+
"reward": float(reward),
|
| 591 |
+
"terminated": terminated,
|
| 592 |
+
"truncated": truncated,
|
| 593 |
+
}
|
| 594 |
+
|
| 595 |
+
|
| 596 |
+
# ──────────────────────────────────────────────
|
| 597 |
+
# Stats endpoint
|
| 598 |
+
# ──────────────────────────────────────────────
|
| 599 |
+
|
| 600 |
+
@app.get("/api/stats")
|
| 601 |
+
async def get_stats():
|
| 602 |
+
n_runs = len(run_history)
|
| 603 |
+
if n_runs == 0:
|
| 604 |
+
return {
|
| 605 |
+
"active_workflows": 1,
|
| 606 |
+
"total_experiments": 0,
|
| 607 |
+
"success_rate": "—",
|
| 608 |
+
"budget_spent": "$0",
|
| 609 |
+
}
|
| 610 |
+
successes = sum(1 for r in run_history if r.get("best_result") == "success")
|
| 611 |
+
return {
|
| 612 |
+
"active_workflows": 1,
|
| 613 |
+
"total_experiments": n_runs,
|
| 614 |
+
"success_rate": f"{successes / n_runs:.0%}",
|
| 615 |
+
"budget_spent": f"${sum(r.get('cost', 0) for r in run_history):.0f}",
|
| 616 |
+
}
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
if __name__ == "__main__":
|
| 620 |
+
import uvicorn
|
| 621 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
v0ap/.gitignore
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# v0 runtime files
|
| 2 |
+
__v0_runtime_loader.js
|
| 3 |
+
__v0_devtools.tsx
|
| 4 |
+
__v0_jsx-dev-runtime.ts
|
| 5 |
+
|
| 6 |
+
# Common ignores
|
| 7 |
+
node_modules/
|
| 8 |
+
.next/
|
| 9 |
+
.env*.local
|
| 10 |
+
.DS_Store
|
v0ap/app/docs/page.tsx
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { SidebarTrigger } from "@/components/ui/sidebar"
|
| 2 |
+
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"
|
| 3 |
+
import { Badge } from "@/components/ui/badge"
|
| 4 |
+
import { BookOpen, Code, FlaskConical, GraduationCap, Lightbulb, Zap } from "lucide-react"
|
| 5 |
+
|
| 6 |
+
const docs = [
|
| 7 |
+
{
|
| 8 |
+
title: "Getting Started",
|
| 9 |
+
description: "Learn the basics of SimLab and run your first experiment",
|
| 10 |
+
icon: Lightbulb,
|
| 11 |
+
badge: "Beginner",
|
| 12 |
+
},
|
| 13 |
+
{
|
| 14 |
+
title: "Workflow Reference",
|
| 15 |
+
description: "Complete documentation for all available workflows",
|
| 16 |
+
icon: FlaskConical,
|
| 17 |
+
badge: "Reference",
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
title: "RL Agent Architecture",
|
| 21 |
+
description: "Deep dive into the REINFORCE algorithm implementation",
|
| 22 |
+
icon: GraduationCap,
|
| 23 |
+
badge: "Advanced",
|
| 24 |
+
},
|
| 25 |
+
{
|
| 26 |
+
title: "API Documentation",
|
| 27 |
+
description: "REST API endpoints for programmatic access",
|
| 28 |
+
icon: Code,
|
| 29 |
+
badge: "Developer",
|
| 30 |
+
},
|
| 31 |
+
{
|
| 32 |
+
title: "Best Practices",
|
| 33 |
+
description: "Tips for optimizing experiment success rates",
|
| 34 |
+
icon: Zap,
|
| 35 |
+
badge: "Guide",
|
| 36 |
+
},
|
| 37 |
+
{
|
| 38 |
+
title: "OpenEnv Integration",
|
| 39 |
+
description: "Connect SimLab with the OpenEnv ecosystem",
|
| 40 |
+
icon: BookOpen,
|
| 41 |
+
badge: "Integration",
|
| 42 |
+
},
|
| 43 |
+
]
|
| 44 |
+
|
| 45 |
+
export default function DocsPage() {
|
| 46 |
+
return (
|
| 47 |
+
<div className="flex flex-col min-h-screen">
|
| 48 |
+
<header className="sticky top-0 z-10 flex h-14 items-center gap-4 border-b bg-background/95 backdrop-blur supports-[backdrop-filter]:bg-background/60 px-6">
|
| 49 |
+
<SidebarTrigger className="-ml-2" />
|
| 50 |
+
<div className="flex-1">
|
| 51 |
+
<h1 className="text-lg font-semibold">Documentation</h1>
|
| 52 |
+
</div>
|
| 53 |
+
</header>
|
| 54 |
+
<div className="flex-1 p-6">
|
| 55 |
+
<div className="mb-6">
|
| 56 |
+
<h2 className="text-2xl font-bold tracking-tight">Documentation</h2>
|
| 57 |
+
<p className="text-muted-foreground">
|
| 58 |
+
Learn how to use SimLab effectively
|
| 59 |
+
</p>
|
| 60 |
+
</div>
|
| 61 |
+
<div className="grid gap-4 sm:grid-cols-2 lg:grid-cols-3">
|
| 62 |
+
{docs.map((doc) => (
|
| 63 |
+
<Card
|
| 64 |
+
key={doc.title}
|
| 65 |
+
className="border-border/50 hover:border-primary/50 transition-colors cursor-pointer group"
|
| 66 |
+
>
|
| 67 |
+
<CardHeader className="pb-3">
|
| 68 |
+
<div className="flex items-center justify-between">
|
| 69 |
+
<div className="flex h-10 w-10 items-center justify-center rounded-lg bg-primary/10 text-primary group-hover:bg-primary/20 transition-colors">
|
| 70 |
+
<doc.icon className="h-5 w-5" />
|
| 71 |
+
</div>
|
| 72 |
+
<Badge variant="secondary" className="text-xs">
|
| 73 |
+
{doc.badge}
|
| 74 |
+
</Badge>
|
| 75 |
+
</div>
|
| 76 |
+
<CardTitle className="text-base mt-3">{doc.title}</CardTitle>
|
| 77 |
+
<CardDescription className="text-sm">
|
| 78 |
+
{doc.description}
|
| 79 |
+
</CardDescription>
|
| 80 |
+
</CardHeader>
|
| 81 |
+
</Card>
|
| 82 |
+
))}
|
| 83 |
+
</div>
|
| 84 |
+
</div>
|
| 85 |
+
</div>
|
| 86 |
+
)
|
| 87 |
+
}
|
v0ap/app/globals.css
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@import 'tailwindcss';
|
| 2 |
+
@import 'tw-animate-css';
|
| 3 |
+
|
| 4 |
+
@custom-variant dark (&:is(.dark *));
|
| 5 |
+
|
| 6 |
+
:root {
|
| 7 |
+
--background: oklch(0.13 0.02 260);
|
| 8 |
+
--foreground: oklch(0.95 0.01 260);
|
| 9 |
+
--card: oklch(0.16 0.02 260);
|
| 10 |
+
--card-foreground: oklch(0.95 0.01 260);
|
| 11 |
+
--popover: oklch(0.16 0.02 260);
|
| 12 |
+
--popover-foreground: oklch(0.95 0.01 260);
|
| 13 |
+
--primary: oklch(0.65 0.18 230);
|
| 14 |
+
--primary-foreground: oklch(0.98 0 0);
|
| 15 |
+
--secondary: oklch(0.22 0.02 260);
|
| 16 |
+
--secondary-foreground: oklch(0.9 0.01 260);
|
| 17 |
+
--muted: oklch(0.2 0.02 260);
|
| 18 |
+
--muted-foreground: oklch(0.6 0.02 260);
|
| 19 |
+
--accent: oklch(0.22 0.02 260);
|
| 20 |
+
--accent-foreground: oklch(0.95 0.01 260);
|
| 21 |
+
--destructive: oklch(0.55 0.22 25);
|
| 22 |
+
--destructive-foreground: oklch(0.98 0 0);
|
| 23 |
+
--border: oklch(0.25 0.02 260);
|
| 24 |
+
--input: oklch(0.22 0.02 260);
|
| 25 |
+
--ring: oklch(0.65 0.18 230);
|
| 26 |
+
--success: oklch(0.7 0.19 145);
|
| 27 |
+
--success-foreground: oklch(0.15 0.05 145);
|
| 28 |
+
--warning: oklch(0.75 0.18 85);
|
| 29 |
+
--warning-foreground: oklch(0.2 0.05 85);
|
| 30 |
+
--chart-1: oklch(0.65 0.18 230);
|
| 31 |
+
--chart-2: oklch(0.7 0.19 145);
|
| 32 |
+
--chart-3: oklch(0.75 0.18 85);
|
| 33 |
+
--chart-4: oklch(0.55 0.22 25);
|
| 34 |
+
--chart-5: oklch(0.6 0.15 280);
|
| 35 |
+
--radius: 0.625rem;
|
| 36 |
+
--sidebar: oklch(0.11 0.02 260);
|
| 37 |
+
--sidebar-foreground: oklch(0.9 0.01 260);
|
| 38 |
+
--sidebar-primary: oklch(0.65 0.18 230);
|
| 39 |
+
--sidebar-primary-foreground: oklch(0.98 0 0);
|
| 40 |
+
--sidebar-accent: oklch(0.18 0.02 260);
|
| 41 |
+
--sidebar-accent-foreground: oklch(0.95 0.01 260);
|
| 42 |
+
--sidebar-border: oklch(0.22 0.02 260);
|
| 43 |
+
--sidebar-ring: oklch(0.65 0.18 230);
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
.dark {
|
| 47 |
+
--background: oklch(0.13 0.02 260);
|
| 48 |
+
--foreground: oklch(0.95 0.01 260);
|
| 49 |
+
--card: oklch(0.16 0.02 260);
|
| 50 |
+
--card-foreground: oklch(0.95 0.01 260);
|
| 51 |
+
--popover: oklch(0.16 0.02 260);
|
| 52 |
+
--popover-foreground: oklch(0.95 0.01 260);
|
| 53 |
+
--primary: oklch(0.65 0.18 230);
|
| 54 |
+
--primary-foreground: oklch(0.98 0 0);
|
| 55 |
+
--secondary: oklch(0.22 0.02 260);
|
| 56 |
+
--secondary-foreground: oklch(0.9 0.01 260);
|
| 57 |
+
--muted: oklch(0.2 0.02 260);
|
| 58 |
+
--muted-foreground: oklch(0.6 0.02 260);
|
| 59 |
+
--accent: oklch(0.22 0.02 260);
|
| 60 |
+
--accent-foreground: oklch(0.95 0.01 260);
|
| 61 |
+
--destructive: oklch(0.55 0.22 25);
|
| 62 |
+
--destructive-foreground: oklch(0.98 0 0);
|
| 63 |
+
--border: oklch(0.25 0.02 260);
|
| 64 |
+
--input: oklch(0.22 0.02 260);
|
| 65 |
+
--ring: oklch(0.65 0.18 230);
|
| 66 |
+
--success: oklch(0.7 0.19 145);
|
| 67 |
+
--success-foreground: oklch(0.15 0.05 145);
|
| 68 |
+
--warning: oklch(0.75 0.18 85);
|
| 69 |
+
--warning-foreground: oklch(0.2 0.05 85);
|
| 70 |
+
--chart-1: oklch(0.65 0.18 230);
|
| 71 |
+
--chart-2: oklch(0.7 0.19 145);
|
| 72 |
+
--chart-3: oklch(0.75 0.18 85);
|
| 73 |
+
--chart-4: oklch(0.55 0.22 25);
|
| 74 |
+
--chart-5: oklch(0.6 0.15 280);
|
| 75 |
+
--sidebar: oklch(0.11 0.02 260);
|
| 76 |
+
--sidebar-foreground: oklch(0.9 0.01 260);
|
| 77 |
+
--sidebar-primary: oklch(0.65 0.18 230);
|
| 78 |
+
--sidebar-primary-foreground: oklch(0.98 0 0);
|
| 79 |
+
--sidebar-accent: oklch(0.18 0.02 260);
|
| 80 |
+
--sidebar-accent-foreground: oklch(0.95 0.01 260);
|
| 81 |
+
--sidebar-border: oklch(0.22 0.02 260);
|
| 82 |
+
--sidebar-ring: oklch(0.65 0.18 230);
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
@theme inline {
|
| 86 |
+
--font-sans: 'Geist', 'Geist Fallback';
|
| 87 |
+
--font-mono: 'Geist Mono', 'Geist Mono Fallback';
|
| 88 |
+
--color-background: var(--background);
|
| 89 |
+
--color-foreground: var(--foreground);
|
| 90 |
+
--color-card: var(--card);
|
| 91 |
+
--color-card-foreground: var(--card-foreground);
|
| 92 |
+
--color-popover: var(--popover);
|
| 93 |
+
--color-popover-foreground: var(--popover-foreground);
|
| 94 |
+
--color-primary: var(--primary);
|
| 95 |
+
--color-primary-foreground: var(--primary-foreground);
|
| 96 |
+
--color-secondary: var(--secondary);
|
| 97 |
+
--color-secondary-foreground: var(--secondary-foreground);
|
| 98 |
+
--color-muted: var(--muted);
|
| 99 |
+
--color-muted-foreground: var(--muted-foreground);
|
| 100 |
+
--color-accent: var(--accent);
|
| 101 |
+
--color-accent-foreground: var(--accent-foreground);
|
| 102 |
+
--color-destructive: var(--destructive);
|
| 103 |
+
--color-destructive-foreground: var(--destructive-foreground);
|
| 104 |
+
--color-border: var(--border);
|
| 105 |
+
--color-input: var(--input);
|
| 106 |
+
--color-ring: var(--ring);
|
| 107 |
+
--color-success: var(--success);
|
| 108 |
+
--color-success-foreground: var(--success-foreground);
|
| 109 |
+
--color-warning: var(--warning);
|
| 110 |
+
--color-warning-foreground: var(--warning-foreground);
|
| 111 |
+
--color-chart-1: var(--chart-1);
|
| 112 |
+
--color-chart-2: var(--chart-2);
|
| 113 |
+
--color-chart-3: var(--chart-3);
|
| 114 |
+
--color-chart-4: var(--chart-4);
|
| 115 |
+
--color-chart-5: var(--chart-5);
|
| 116 |
+
--radius-sm: calc(var(--radius) - 4px);
|
| 117 |
+
--radius-md: calc(var(--radius) - 2px);
|
| 118 |
+
--radius-lg: var(--radius);
|
| 119 |
+
--radius-xl: calc(var(--radius) + 4px);
|
| 120 |
+
--color-sidebar: var(--sidebar);
|
| 121 |
+
--color-sidebar-foreground: var(--sidebar-foreground);
|
| 122 |
+
--color-sidebar-primary: var(--sidebar-primary);
|
| 123 |
+
--color-sidebar-primary-foreground: var(--sidebar-primary-foreground);
|
| 124 |
+
--color-sidebar-accent: var(--sidebar-accent);
|
| 125 |
+
--color-sidebar-accent-foreground: var(--sidebar-accent-foreground);
|
| 126 |
+
--color-sidebar-border: var(--sidebar-border);
|
| 127 |
+
--color-sidebar-ring: var(--sidebar-ring);
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
@layer base {
|
| 131 |
+
* {
|
| 132 |
+
@apply border-border outline-ring/50;
|
| 133 |
+
}
|
| 134 |
+
body {
|
| 135 |
+
@apply bg-background text-foreground;
|
| 136 |
+
}
|
| 137 |
+
}
|
v0ap/app/layout.tsx
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import type { Metadata, Viewport } from 'next'
|
| 2 |
+
import { Geist, Geist_Mono } from 'next/font/google'
|
| 3 |
+
import { Analytics } from '@vercel/analytics/next'
|
| 4 |
+
import './globals.css'
|
| 5 |
+
|
| 6 |
+
import { ThemeProvider } from '@/components/theme-provider'
|
| 7 |
+
import { SidebarProvider } from '@/components/ui/sidebar'
|
| 8 |
+
import { AppSidebar } from '@/components/app-sidebar'
|
| 9 |
+
|
| 10 |
+
const _geist = Geist({ subsets: ["latin"] });
|
| 11 |
+
const _geistMono = Geist_Mono({ subsets: ["latin"] });
|
| 12 |
+
|
| 13 |
+
export const metadata: Metadata = {
|
| 14 |
+
title: 'SimLab - Lab Automation RL Environment',
|
| 15 |
+
description: 'AI-powered lab automation environment for optimizing wet-lab experiment workflows',
|
| 16 |
+
generator: 'v0.app',
|
| 17 |
+
icons: {
|
| 18 |
+
icon: [
|
| 19 |
+
{
|
| 20 |
+
url: '/icon-light-32x32.png',
|
| 21 |
+
media: '(prefers-color-scheme: light)',
|
| 22 |
+
},
|
| 23 |
+
{
|
| 24 |
+
url: '/icon-dark-32x32.png',
|
| 25 |
+
media: '(prefers-color-scheme: dark)',
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
url: '/icon.svg',
|
| 29 |
+
type: 'image/svg+xml',
|
| 30 |
+
},
|
| 31 |
+
],
|
| 32 |
+
apple: '/apple-icon.png',
|
| 33 |
+
},
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
export const viewport: Viewport = {
|
| 37 |
+
themeColor: [
|
| 38 |
+
{ media: '(prefers-color-scheme: light)', color: '#ffffff' },
|
| 39 |
+
{ media: '(prefers-color-scheme: dark)', color: '#0f172a' },
|
| 40 |
+
],
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
export default function RootLayout({
|
| 44 |
+
children,
|
| 45 |
+
}: Readonly<{
|
| 46 |
+
children: React.ReactNode
|
| 47 |
+
}>) {
|
| 48 |
+
return (
|
| 49 |
+
<html lang="en" suppressHydrationWarning>
|
| 50 |
+
<body className="font-sans antialiased">
|
| 51 |
+
<ThemeProvider
|
| 52 |
+
attribute="class"
|
| 53 |
+
defaultTheme="dark"
|
| 54 |
+
enableSystem
|
| 55 |
+
disableTransitionOnChange
|
| 56 |
+
>
|
| 57 |
+
<SidebarProvider>
|
| 58 |
+
<AppSidebar />
|
| 59 |
+
<main className="flex-1 overflow-auto">
|
| 60 |
+
{children}
|
| 61 |
+
</main>
|
| 62 |
+
</SidebarProvider>
|
| 63 |
+
</ThemeProvider>
|
| 64 |
+
<Analytics />
|
| 65 |
+
</body>
|
| 66 |
+
</html>
|
| 67 |
+
)
|
| 68 |
+
}
|
v0ap/app/page.tsx
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { StatsCards } from "@/components/dashboard/stats-cards"
|
| 2 |
+
import { PerformanceChart } from "@/components/dashboard/performance-chart"
|
| 3 |
+
import { RecentExperiments } from "@/components/dashboard/recent-experiments"
|
| 4 |
+
import { SidebarTrigger } from "@/components/ui/sidebar"
|
| 5 |
+
|
| 6 |
+
export default function DashboardPage() {
|
| 7 |
+
return (
|
| 8 |
+
<div className="flex flex-col min-h-screen">
|
| 9 |
+
<header className="sticky top-0 z-10 flex h-14 items-center gap-4 border-b bg-background/95 backdrop-blur supports-[backdrop-filter]:bg-background/60 px-6">
|
| 10 |
+
<SidebarTrigger className="-ml-2" />
|
| 11 |
+
<div className="flex-1">
|
| 12 |
+
<h1 className="text-lg font-semibold">Dashboard</h1>
|
| 13 |
+
</div>
|
| 14 |
+
</header>
|
| 15 |
+
<div className="flex-1 p-6 space-y-6">
|
| 16 |
+
<StatsCards />
|
| 17 |
+
<div className="grid gap-6 lg:grid-cols-2">
|
| 18 |
+
<PerformanceChart />
|
| 19 |
+
<RecentExperiments />
|
| 20 |
+
</div>
|
| 21 |
+
</div>
|
| 22 |
+
</div>
|
| 23 |
+
)
|
| 24 |
+
}
|
v0ap/app/training/page.tsx
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"use client"
|
| 2 |
+
|
| 3 |
+
import { useState, useCallback } from "react"
|
| 4 |
+
import { TrainingControls } from "@/components/training/training-controls"
|
| 5 |
+
import { TrainingChart } from "@/components/training/training-chart"
|
| 6 |
+
import { ComparisonTable } from "@/components/training/comparison-table"
|
| 7 |
+
import { SidebarTrigger } from "@/components/ui/sidebar"
|
| 8 |
+
|
| 9 |
+
interface ChartPoint {
|
| 10 |
+
episode: number
|
| 11 |
+
reward: number
|
| 12 |
+
successRate: number
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
interface ComparisonRow {
|
| 16 |
+
metric: string
|
| 17 |
+
reinforce: number
|
| 18 |
+
baseline: number
|
| 19 |
+
improvement: number | null
|
| 20 |
+
unit?: string
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
export default function TrainingPage() {
|
| 24 |
+
const [isTraining, setIsTraining] = useState(false)
|
| 25 |
+
const [progress, setProgress] = useState(0)
|
| 26 |
+
const [currentEpisode, setCurrentEpisode] = useState(0)
|
| 27 |
+
const [totalEpisodes, setTotalEpisodes] = useState(0)
|
| 28 |
+
const [chartData, setChartData] = useState<ChartPoint[]>([])
|
| 29 |
+
const [comparison, setComparison] = useState<ComparisonRow[]>([])
|
| 30 |
+
const [isDone, setIsDone] = useState(false)
|
| 31 |
+
|
| 32 |
+
const startTraining = useCallback(
|
| 33 |
+
async (episodes: number, lr: number, maxTrials: number, workflowId: string) => {
|
| 34 |
+
setIsTraining(true)
|
| 35 |
+
setIsDone(false)
|
| 36 |
+
setProgress(0)
|
| 37 |
+
setChartData([])
|
| 38 |
+
setComparison([])
|
| 39 |
+
setTotalEpisodes(episodes)
|
| 40 |
+
|
| 41 |
+
const res = await fetch("/api/training/start", {
|
| 42 |
+
method: "POST",
|
| 43 |
+
headers: { "Content-Type": "application/json" },
|
| 44 |
+
body: JSON.stringify({ episodes, lr, max_trials: maxTrials, workflow_id: workflowId }),
|
| 45 |
+
})
|
| 46 |
+
|
| 47 |
+
const reader = res.body?.getReader()
|
| 48 |
+
const decoder = new TextDecoder()
|
| 49 |
+
let buffer = ""
|
| 50 |
+
|
| 51 |
+
if (!reader) return
|
| 52 |
+
|
| 53 |
+
while (true) {
|
| 54 |
+
const { done, value } = await reader.read()
|
| 55 |
+
if (done) break
|
| 56 |
+
|
| 57 |
+
buffer += decoder.decode(value, { stream: true })
|
| 58 |
+
const lines = buffer.split("\n")
|
| 59 |
+
buffer = lines.pop() || ""
|
| 60 |
+
|
| 61 |
+
for (const line of lines) {
|
| 62 |
+
if (!line.startsWith("data: ")) continue
|
| 63 |
+
try {
|
| 64 |
+
const event = JSON.parse(line.slice(6))
|
| 65 |
+
if (event.type === "progress") {
|
| 66 |
+
setProgress(event.progress)
|
| 67 |
+
setCurrentEpisode(event.episode)
|
| 68 |
+
setChartData(event.chartData)
|
| 69 |
+
} else if (event.type === "done") {
|
| 70 |
+
setChartData(event.chartData)
|
| 71 |
+
setComparison(event.comparison)
|
| 72 |
+
setIsDone(true)
|
| 73 |
+
}
|
| 74 |
+
} catch {}
|
| 75 |
+
}
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
setIsTraining(false)
|
| 79 |
+
},
|
| 80 |
+
[]
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
return (
|
| 84 |
+
<div className="flex flex-col min-h-screen">
|
| 85 |
+
<header className="sticky top-0 z-10 flex h-14 items-center gap-4 border-b bg-background/95 backdrop-blur supports-[backdrop-filter]:bg-background/60 px-6">
|
| 86 |
+
<SidebarTrigger className="-ml-2" />
|
| 87 |
+
<div className="flex-1">
|
| 88 |
+
<h1 className="text-lg font-semibold">Training</h1>
|
| 89 |
+
</div>
|
| 90 |
+
</header>
|
| 91 |
+
<div className="flex-1 p-6 space-y-6">
|
| 92 |
+
<div className="mb-6">
|
| 93 |
+
<h2 className="text-2xl font-bold tracking-tight">Agent Training</h2>
|
| 94 |
+
<p className="text-muted-foreground">
|
| 95 |
+
Train the RL agent to optimize experiment workflows
|
| 96 |
+
</p>
|
| 97 |
+
</div>
|
| 98 |
+
<div className="grid gap-6 lg:grid-cols-[320px_1fr]">
|
| 99 |
+
<TrainingControls
|
| 100 |
+
isTraining={isTraining}
|
| 101 |
+
progress={progress}
|
| 102 |
+
currentEpisode={currentEpisode}
|
| 103 |
+
totalEpisodes={totalEpisodes}
|
| 104 |
+
onStartTraining={startTraining}
|
| 105 |
+
/>
|
| 106 |
+
<div className="space-y-6">
|
| 107 |
+
<TrainingChart data={chartData} />
|
| 108 |
+
{isDone && <ComparisonTable data={comparison} />}
|
| 109 |
+
</div>
|
| 110 |
+
</div>
|
| 111 |
+
</div>
|
| 112 |
+
</div>
|
| 113 |
+
)
|
| 114 |
+
}
|
v0ap/app/workflows/[id]/page.tsx
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"use client"
|
| 2 |
+
|
| 3 |
+
import { useState, use } from "react"
|
| 4 |
+
import { EnvironmentState } from "@/components/workflow-run/environment-state"
|
| 5 |
+
import { ExperimentTimeline } from "@/components/workflow-run/experiment-timeline"
|
| 6 |
+
import { ProtocolSelection } from "@/components/workflow-run/protocol-selection"
|
| 7 |
+
import { SidebarTrigger } from "@/components/ui/sidebar"
|
| 8 |
+
|
| 9 |
+
const workflowNames: Record<string, string> = {
|
| 10 |
+
"pcr-amplification": "PCR Amplification",
|
| 11 |
+
"elisa-readout": "ELISA Readout",
|
| 12 |
+
"dna-extraction": "DNA Extraction",
|
| 13 |
+
"rna-sequencing-prep": "RNA Sequencing Prep",
|
| 14 |
+
"gel-electrophoresis": "Gel Electrophoresis",
|
| 15 |
+
"cell-culture-passage": "Cell Culture Passage",
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
interface EnvState {
|
| 19 |
+
elapsed_minutes: number
|
| 20 |
+
remaining_budget: number
|
| 21 |
+
inventory: Record<string, number>
|
| 22 |
+
best_result: string
|
| 23 |
+
max_time: number
|
| 24 |
+
max_budget: number
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
interface TimelineEntry {
|
| 28 |
+
title: string
|
| 29 |
+
description: string
|
| 30 |
+
time: string
|
| 31 |
+
status: string
|
| 32 |
+
icon: string
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
interface PresetInfo {
|
| 36 |
+
id: string
|
| 37 |
+
temp: number
|
| 38 |
+
cycles: number
|
| 39 |
+
ratio: string
|
| 40 |
+
status: string
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
export default function WorkflowRunPage({
|
| 44 |
+
params,
|
| 45 |
+
}: {
|
| 46 |
+
params: Promise<{ id: string }>
|
| 47 |
+
}) {
|
| 48 |
+
const { id } = use(params)
|
| 49 |
+
const workflowName = workflowNames[id] || "Unknown Workflow"
|
| 50 |
+
|
| 51 |
+
const [envState, setEnvState] = useState<EnvState | null>(null)
|
| 52 |
+
const [timeline, setTimeline] = useState<TimelineEntry[]>([])
|
| 53 |
+
const [presets, setPresets] = useState<PresetInfo[]>([])
|
| 54 |
+
const [isRunning, setIsRunning] = useState(false)
|
| 55 |
+
|
| 56 |
+
const runAI = async () => {
|
| 57 |
+
setIsRunning(true)
|
| 58 |
+
setTimeline([])
|
| 59 |
+
setPresets([])
|
| 60 |
+
setEnvState(null)
|
| 61 |
+
|
| 62 |
+
const seed = Math.floor(Math.random() * 100000)
|
| 63 |
+
const res = await fetch("/api/run/ai", {
|
| 64 |
+
method: "POST",
|
| 65 |
+
headers: { "Content-Type": "application/json" },
|
| 66 |
+
body: JSON.stringify({ seed, workflow_id: id }),
|
| 67 |
+
})
|
| 68 |
+
const data = await res.json()
|
| 69 |
+
|
| 70 |
+
setEnvState(data.state)
|
| 71 |
+
setPresets(data.presets)
|
| 72 |
+
|
| 73 |
+
for (let i = 0; i < data.timeline.length; i++) {
|
| 74 |
+
await new Promise((r) => setTimeout(r, 400))
|
| 75 |
+
setTimeline((prev) => [...prev, data.timeline[i]])
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
setIsRunning(false)
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
const runNaive = async () => {
|
| 82 |
+
setIsRunning(true)
|
| 83 |
+
setTimeline([])
|
| 84 |
+
setPresets([])
|
| 85 |
+
setEnvState(null)
|
| 86 |
+
|
| 87 |
+
const seed = Math.floor(Math.random() * 100000)
|
| 88 |
+
const res = await fetch("/api/run/naive", {
|
| 89 |
+
method: "POST",
|
| 90 |
+
headers: { "Content-Type": "application/json" },
|
| 91 |
+
body: JSON.stringify({ seed, workflow_id: id }),
|
| 92 |
+
})
|
| 93 |
+
const data = await res.json()
|
| 94 |
+
|
| 95 |
+
setEnvState(data.state)
|
| 96 |
+
setPresets(data.presets)
|
| 97 |
+
|
| 98 |
+
for (let i = 0; i < data.timeline.length; i++) {
|
| 99 |
+
await new Promise((r) => setTimeout(r, 400))
|
| 100 |
+
setTimeline((prev) => [...prev, data.timeline[i]])
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
setIsRunning(false)
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
const runResearch = async () => {
|
| 107 |
+
setIsRunning(true)
|
| 108 |
+
setTimeline([])
|
| 109 |
+
setPresets([])
|
| 110 |
+
setEnvState(null)
|
| 111 |
+
|
| 112 |
+
const seed = Math.floor(Math.random() * 100000)
|
| 113 |
+
const res = await fetch("/api/run/research", {
|
| 114 |
+
method: "POST",
|
| 115 |
+
headers: { "Content-Type": "application/json" },
|
| 116 |
+
body: JSON.stringify({ seed, workflow_id: id }),
|
| 117 |
+
})
|
| 118 |
+
const data = await res.json()
|
| 119 |
+
|
| 120 |
+
setEnvState(data.state)
|
| 121 |
+
setPresets(data.presets)
|
| 122 |
+
|
| 123 |
+
for (let i = 0; i < data.timeline.length; i++) {
|
| 124 |
+
await new Promise((r) => setTimeout(r, 400))
|
| 125 |
+
setTimeline((prev) => [...prev, data.timeline[i]])
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
setIsRunning(false)
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
const runResearchGenerate = async () => {
|
| 132 |
+
setIsRunning(true)
|
| 133 |
+
setTimeline([])
|
| 134 |
+
setPresets([])
|
| 135 |
+
setEnvState(null)
|
| 136 |
+
|
| 137 |
+
const seed = Math.floor(Math.random() * 100000)
|
| 138 |
+
const res = await fetch("/api/run/research-generate", {
|
| 139 |
+
method: "POST",
|
| 140 |
+
headers: { "Content-Type": "application/json" },
|
| 141 |
+
body: JSON.stringify({ seed, workflow_id: id }),
|
| 142 |
+
})
|
| 143 |
+
const data = await res.json()
|
| 144 |
+
|
| 145 |
+
setEnvState(data.state)
|
| 146 |
+
setPresets(data.presets)
|
| 147 |
+
|
| 148 |
+
for (let i = 0; i < data.timeline.length; i++) {
|
| 149 |
+
await new Promise((r) => setTimeout(r, 400))
|
| 150 |
+
setTimeline((prev) => [...prev, data.timeline[i]])
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
setIsRunning(false)
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
return (
|
| 157 |
+
<div className="flex flex-col min-h-screen">
|
| 158 |
+
<header className="sticky top-0 z-10 flex h-14 items-center gap-4 border-b bg-background/95 backdrop-blur supports-[backdrop-filter]:bg-background/60 px-6">
|
| 159 |
+
<SidebarTrigger className="-ml-2" />
|
| 160 |
+
<div className="flex-1">
|
| 161 |
+
<h1 className="text-lg font-semibold">{workflowName}</h1>
|
| 162 |
+
</div>
|
| 163 |
+
</header>
|
| 164 |
+
<div className="flex-1 p-6">
|
| 165 |
+
<div className="grid gap-6 lg:grid-cols-[280px_1fr_320px] h-[calc(100vh-8rem)]">
|
| 166 |
+
<EnvironmentState data={envState} />
|
| 167 |
+
<ExperimentTimeline entries={timeline} />
|
| 168 |
+
<ProtocolSelection
|
| 169 |
+
presets={presets}
|
| 170 |
+
isRunning={isRunning}
|
| 171 |
+
onRunAI={runAI}
|
| 172 |
+
onRunNaive={runNaive}
|
| 173 |
+
onRunResearch={runResearch}
|
| 174 |
+
onRunResearchGenerate={runResearchGenerate}
|
| 175 |
+
/>
|
| 176 |
+
</div>
|
| 177 |
+
</div>
|
| 178 |
+
</div>
|
| 179 |
+
)
|
| 180 |
+
}
|
v0ap/app/workflows/page.tsx
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { WorkflowGrid } from "@/components/workflows/workflow-grid"
|
| 2 |
+
import { SidebarTrigger } from "@/components/ui/sidebar"
|
| 3 |
+
|
| 4 |
+
export default function WorkflowsPage() {
|
| 5 |
+
return (
|
| 6 |
+
<div className="flex flex-col min-h-screen">
|
| 7 |
+
<header className="sticky top-0 z-10 flex h-14 items-center gap-4 border-b bg-background/95 backdrop-blur supports-[backdrop-filter]:bg-background/60 px-6">
|
| 8 |
+
<SidebarTrigger className="-ml-2" />
|
| 9 |
+
<div className="flex-1">
|
| 10 |
+
<h1 className="text-lg font-semibold">Workflows</h1>
|
| 11 |
+
</div>
|
| 12 |
+
</header>
|
| 13 |
+
<div className="flex-1 p-6">
|
| 14 |
+
<div className="mb-6">
|
| 15 |
+
<h2 className="text-2xl font-bold tracking-tight">Available Workflows</h2>
|
| 16 |
+
<p className="text-muted-foreground">
|
| 17 |
+
Select a workflow to configure and run with the AI agent
|
| 18 |
+
</p>
|
| 19 |
+
</div>
|
| 20 |
+
<WorkflowGrid />
|
| 21 |
+
</div>
|
| 22 |
+
</div>
|
| 23 |
+
)
|
| 24 |
+
}
|
v0ap/components.json
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"$schema": "https://ui.shadcn.com/schema.json",
|
| 3 |
+
"style": "new-york",
|
| 4 |
+
"rsc": true,
|
| 5 |
+
"tsx": true,
|
| 6 |
+
"tailwind": {
|
| 7 |
+
"config": "",
|
| 8 |
+
"css": "app/globals.css",
|
| 9 |
+
"baseColor": "neutral",
|
| 10 |
+
"cssVariables": true,
|
| 11 |
+
"prefix": ""
|
| 12 |
+
},
|
| 13 |
+
"aliases": {
|
| 14 |
+
"components": "@/components",
|
| 15 |
+
"utils": "@/lib/utils",
|
| 16 |
+
"ui": "@/components/ui",
|
| 17 |
+
"lib": "@/lib",
|
| 18 |
+
"hooks": "@/hooks"
|
| 19 |
+
},
|
| 20 |
+
"iconLibrary": "lucide"
|
| 21 |
+
}
|
v0ap/components/app-sidebar.tsx
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"use client"
|
| 2 |
+
|
| 3 |
+
import Link from "next/link"
|
| 4 |
+
import { usePathname } from "next/navigation"
|
| 5 |
+
import {
|
| 6 |
+
LayoutDashboard,
|
| 7 |
+
FlaskConical,
|
| 8 |
+
GraduationCap,
|
| 9 |
+
FileText,
|
| 10 |
+
Moon,
|
| 11 |
+
Sun,
|
| 12 |
+
Beaker,
|
| 13 |
+
} from "lucide-react"
|
| 14 |
+
|
| 15 |
+
import {
|
| 16 |
+
Sidebar,
|
| 17 |
+
SidebarContent,
|
| 18 |
+
SidebarFooter,
|
| 19 |
+
SidebarGroup,
|
| 20 |
+
SidebarGroupContent,
|
| 21 |
+
SidebarHeader,
|
| 22 |
+
SidebarMenu,
|
| 23 |
+
SidebarMenuButton,
|
| 24 |
+
SidebarMenuItem,
|
| 25 |
+
} from "@/components/ui/sidebar"
|
| 26 |
+
import { Button } from "@/components/ui/button"
|
| 27 |
+
import { useTheme } from "@/components/theme-provider"
|
| 28 |
+
|
| 29 |
+
const navItems = [
|
| 30 |
+
{ title: "Dashboard", href: "/", icon: LayoutDashboard },
|
| 31 |
+
{ title: "Workflows", href: "/workflows", icon: FlaskConical },
|
| 32 |
+
{ title: "Training", href: "/training", icon: GraduationCap },
|
| 33 |
+
{ title: "Docs", href: "/docs", icon: FileText },
|
| 34 |
+
]
|
| 35 |
+
|
| 36 |
+
export function AppSidebar() {
|
| 37 |
+
const pathname = usePathname()
|
| 38 |
+
const { theme, setTheme } = useTheme()
|
| 39 |
+
|
| 40 |
+
return (
|
| 41 |
+
<Sidebar collapsible="icon">
|
| 42 |
+
<SidebarHeader className="p-4">
|
| 43 |
+
<Link href="/" className="flex items-center gap-3 group-data-[collapsible=icon]:justify-center">
|
| 44 |
+
<div className="flex h-9 w-9 items-center justify-center rounded-lg bg-primary/10 text-primary">
|
| 45 |
+
<Beaker className="h-5 w-5" />
|
| 46 |
+
</div>
|
| 47 |
+
<span className="text-lg font-semibold tracking-tight group-data-[collapsible=icon]:hidden">
|
| 48 |
+
SimLab
|
| 49 |
+
</span>
|
| 50 |
+
</Link>
|
| 51 |
+
</SidebarHeader>
|
| 52 |
+
<SidebarContent>
|
| 53 |
+
<SidebarGroup>
|
| 54 |
+
<SidebarGroupContent>
|
| 55 |
+
<SidebarMenu>
|
| 56 |
+
{navItems.map((item) => {
|
| 57 |
+
const isActive = pathname === item.href ||
|
| 58 |
+
(item.href !== "/" && pathname.startsWith(item.href))
|
| 59 |
+
return (
|
| 60 |
+
<SidebarMenuItem key={item.title}>
|
| 61 |
+
<SidebarMenuButton asChild isActive={isActive} tooltip={item.title}>
|
| 62 |
+
<Link href={item.href}>
|
| 63 |
+
<item.icon className="h-4 w-4" />
|
| 64 |
+
<span>{item.title}</span>
|
| 65 |
+
</Link>
|
| 66 |
+
</SidebarMenuButton>
|
| 67 |
+
</SidebarMenuItem>
|
| 68 |
+
)
|
| 69 |
+
})}
|
| 70 |
+
</SidebarMenu>
|
| 71 |
+
</SidebarGroupContent>
|
| 72 |
+
</SidebarGroup>
|
| 73 |
+
</SidebarContent>
|
| 74 |
+
<SidebarFooter className="p-4">
|
| 75 |
+
<Button
|
| 76 |
+
variant="ghost"
|
| 77 |
+
size="icon"
|
| 78 |
+
onClick={() => setTheme(theme === "dark" ? "light" : "dark")}
|
| 79 |
+
className="h-8 w-8 group-data-[collapsible=icon]:mx-auto"
|
| 80 |
+
>
|
| 81 |
+
<Sun className="h-4 w-4 rotate-0 scale-100 transition-all dark:-rotate-90 dark:scale-0" />
|
| 82 |
+
<Moon className="absolute h-4 w-4 rotate-90 scale-0 transition-all dark:rotate-0 dark:scale-100" />
|
| 83 |
+
<span className="sr-only">Toggle theme</span>
|
| 84 |
+
</Button>
|
| 85 |
+
<div className="mt-2 flex items-center gap-2 text-xs text-muted-foreground group-data-[collapsible=icon]:hidden">
|
| 86 |
+
<span className="text-[10px] px-1.5 py-0.5 rounded bg-muted">Powered by OpenEnv</span>
|
| 87 |
+
</div>
|
| 88 |
+
</SidebarFooter>
|
| 89 |
+
</Sidebar>
|
| 90 |
+
)
|
| 91 |
+
}
|
v0ap/components/dashboard/performance-chart.tsx
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"use client"
|
| 2 |
+
|
| 3 |
+
import {
|
| 4 |
+
Area,
|
| 5 |
+
AreaChart,
|
| 6 |
+
CartesianGrid,
|
| 7 |
+
ResponsiveContainer,
|
| 8 |
+
Tooltip,
|
| 9 |
+
XAxis,
|
| 10 |
+
YAxis,
|
| 11 |
+
} from "recharts"
|
| 12 |
+
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"
|
| 13 |
+
import { ChartContainer, ChartTooltipContent } from "@/components/ui/chart"
|
| 14 |
+
|
| 15 |
+
const performanceData = [
|
| 16 |
+
{ episode: 0, successRate: 23 },
|
| 17 |
+
{ episode: 50, successRate: 31 },
|
| 18 |
+
{ episode: 100, successRate: 42 },
|
| 19 |
+
{ episode: 150, successRate: 48 },
|
| 20 |
+
{ episode: 200, successRate: 56 },
|
| 21 |
+
{ episode: 250, successRate: 61 },
|
| 22 |
+
{ episode: 300, successRate: 67 },
|
| 23 |
+
{ episode: 350, successRate: 72 },
|
| 24 |
+
{ episode: 400, successRate: 75 },
|
| 25 |
+
{ episode: 450, successRate: 79 },
|
| 26 |
+
{ episode: 500, successRate: 82 },
|
| 27 |
+
{ episode: 550, successRate: 84 },
|
| 28 |
+
{ episode: 600, successRate: 85 },
|
| 29 |
+
{ episode: 650, successRate: 87 },
|
| 30 |
+
{ episode: 700, successRate: 87.3 },
|
| 31 |
+
]
|
| 32 |
+
|
| 33 |
+
const chartConfig = {
|
| 34 |
+
successRate: {
|
| 35 |
+
label: "Success Rate",
|
| 36 |
+
color: "var(--color-success)",
|
| 37 |
+
},
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
export function PerformanceChart() {
|
| 41 |
+
return (
|
| 42 |
+
<Card className="border-border/50">
|
| 43 |
+
<CardHeader>
|
| 44 |
+
<CardTitle>Agent Performance Over Time</CardTitle>
|
| 45 |
+
<CardDescription>
|
| 46 |
+
Success rate trending upward as the RL agent learns
|
| 47 |
+
</CardDescription>
|
| 48 |
+
</CardHeader>
|
| 49 |
+
<CardContent>
|
| 50 |
+
<ChartContainer config={chartConfig} className="h-[300px] w-full">
|
| 51 |
+
<ResponsiveContainer width="100%" height="100%">
|
| 52 |
+
<AreaChart data={performanceData}>
|
| 53 |
+
<defs>
|
| 54 |
+
<linearGradient id="successGradient" x1="0" y1="0" x2="0" y2="1">
|
| 55 |
+
<stop offset="0%" stopColor="var(--color-success)" stopOpacity={0.3} />
|
| 56 |
+
<stop offset="100%" stopColor="var(--color-success)" stopOpacity={0.05} />
|
| 57 |
+
</linearGradient>
|
| 58 |
+
</defs>
|
| 59 |
+
<CartesianGrid strokeDasharray="3 3" className="stroke-border/50" />
|
| 60 |
+
<XAxis
|
| 61 |
+
dataKey="episode"
|
| 62 |
+
tickLine={false}
|
| 63 |
+
axisLine={false}
|
| 64 |
+
className="text-xs fill-muted-foreground"
|
| 65 |
+
tickFormatter={(value) => `Ep ${value}`}
|
| 66 |
+
/>
|
| 67 |
+
<YAxis
|
| 68 |
+
tickLine={false}
|
| 69 |
+
axisLine={false}
|
| 70 |
+
className="text-xs fill-muted-foreground"
|
| 71 |
+
tickFormatter={(value) => `${value}%`}
|
| 72 |
+
domain={[0, 100]}
|
| 73 |
+
/>
|
| 74 |
+
<Tooltip content={<ChartTooltipContent />} />
|
| 75 |
+
<Area
|
| 76 |
+
type="monotone"
|
| 77 |
+
dataKey="successRate"
|
| 78 |
+
stroke="var(--color-success)"
|
| 79 |
+
strokeWidth={2}
|
| 80 |
+
fill="url(#successGradient)"
|
| 81 |
+
/>
|
| 82 |
+
</AreaChart>
|
| 83 |
+
</ResponsiveContainer>
|
| 84 |
+
</ChartContainer>
|
| 85 |
+
</CardContent>
|
| 86 |
+
</Card>
|
| 87 |
+
)
|
| 88 |
+
}
|
v0ap/components/dashboard/recent-experiments.tsx
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"use client"
|
| 2 |
+
|
| 3 |
+
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"
|
| 4 |
+
import { Badge } from "@/components/ui/badge"
|
| 5 |
+
import {
|
| 6 |
+
Table,
|
| 7 |
+
TableBody,
|
| 8 |
+
TableCell,
|
| 9 |
+
TableHead,
|
| 10 |
+
TableHeader,
|
| 11 |
+
TableRow,
|
| 12 |
+
} from "@/components/ui/table"
|
| 13 |
+
|
| 14 |
+
type ExperimentResult = "success" | "partial" | "fail"
|
| 15 |
+
|
| 16 |
+
interface Experiment {
|
| 17 |
+
id: string
|
| 18 |
+
workflow: string
|
| 19 |
+
preset: string
|
| 20 |
+
result: ExperimentResult
|
| 21 |
+
time: string
|
| 22 |
+
cost: string
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
const recentExperiments: Experiment[] = [
|
| 26 |
+
{ id: "1", workflow: "PCR Amplification", preset: "65°C / 30 cycles", result: "success", time: "42 min", cost: "$23.50" },
|
| 27 |
+
{ id: "2", workflow: "ELISA Readout", preset: "37°C / standard", result: "success", time: "85 min", cost: "$45.00" },
|
| 28 |
+
{ id: "3", workflow: "DNA Extraction", preset: "conservative", result: "partial", time: "38 min", cost: "$18.25" },
|
| 29 |
+
{ id: "4", workflow: "RNA Sequencing Prep", preset: "high-yield", result: "fail", time: "120 min", cost: "$89.00" },
|
| 30 |
+
{ id: "5", workflow: "Gel Electrophoresis", preset: "1% agarose", result: "success", time: "55 min", cost: "$12.00" },
|
| 31 |
+
{ id: "6", workflow: "Cell Culture Passage", preset: "70% confluence", result: "success", time: "25 min", cost: "$8.50" },
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
const resultStyles: Record<ExperimentResult, string> = {
|
| 35 |
+
success: "bg-success/10 text-success border-success/20",
|
| 36 |
+
partial: "bg-warning/10 text-warning border-warning/20",
|
| 37 |
+
fail: "bg-destructive/10 text-destructive border-destructive/20",
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
export function RecentExperiments() {
|
| 41 |
+
return (
|
| 42 |
+
<Card className="border-border/50">
|
| 43 |
+
<CardHeader>
|
| 44 |
+
<CardTitle>Recent Experiments</CardTitle>
|
| 45 |
+
<CardDescription>
|
| 46 |
+
Latest experiment runs and their results
|
| 47 |
+
</CardDescription>
|
| 48 |
+
</CardHeader>
|
| 49 |
+
<CardContent>
|
| 50 |
+
<Table>
|
| 51 |
+
<TableHeader>
|
| 52 |
+
<TableRow className="border-border/50 hover:bg-transparent">
|
| 53 |
+
<TableHead className="text-muted-foreground">Workflow</TableHead>
|
| 54 |
+
<TableHead className="text-muted-foreground">Preset</TableHead>
|
| 55 |
+
<TableHead className="text-muted-foreground">Result</TableHead>
|
| 56 |
+
<TableHead className="text-muted-foreground text-right">Time</TableHead>
|
| 57 |
+
<TableHead className="text-muted-foreground text-right">Cost</TableHead>
|
| 58 |
+
</TableRow>
|
| 59 |
+
</TableHeader>
|
| 60 |
+
<TableBody>
|
| 61 |
+
{recentExperiments.map((exp) => (
|
| 62 |
+
<TableRow key={exp.id} className="border-border/50">
|
| 63 |
+
<TableCell className="font-medium">{exp.workflow}</TableCell>
|
| 64 |
+
<TableCell className="font-mono text-sm text-muted-foreground">{exp.preset}</TableCell>
|
| 65 |
+
<TableCell>
|
| 66 |
+
<Badge variant="outline" className={resultStyles[exp.result]}>
|
| 67 |
+
{exp.result}
|
| 68 |
+
</Badge>
|
| 69 |
+
</TableCell>
|
| 70 |
+
<TableCell className="text-right font-mono text-sm">{exp.time}</TableCell>
|
| 71 |
+
<TableCell className="text-right font-mono text-sm">{exp.cost}</TableCell>
|
| 72 |
+
</TableRow>
|
| 73 |
+
))}
|
| 74 |
+
</TableBody>
|
| 75 |
+
</Table>
|
| 76 |
+
</CardContent>
|
| 77 |
+
</Card>
|
| 78 |
+
)
|
| 79 |
+
}
|
v0ap/components/dashboard/stats-cards.tsx
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"use client"
|
| 2 |
+
|
| 3 |
+
import { Activity, FlaskConical, TrendingUp, DollarSign } from "lucide-react"
|
| 4 |
+
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card"
|
| 5 |
+
|
| 6 |
+
const stats = [
|
| 7 |
+
{
|
| 8 |
+
title: "Active Workflows",
|
| 9 |
+
value: "4",
|
| 10 |
+
description: "Currently running",
|
| 11 |
+
icon: Activity,
|
| 12 |
+
trend: "+2 from last hour",
|
| 13 |
+
color: "text-primary",
|
| 14 |
+
},
|
| 15 |
+
{
|
| 16 |
+
title: "Total Experiments",
|
| 17 |
+
value: "1,247",
|
| 18 |
+
description: "All time",
|
| 19 |
+
icon: FlaskConical,
|
| 20 |
+
trend: "+89 this week",
|
| 21 |
+
color: "text-chart-2",
|
| 22 |
+
},
|
| 23 |
+
{
|
| 24 |
+
title: "Success Rate",
|
| 25 |
+
value: "87.3%",
|
| 26 |
+
description: "Overall performance",
|
| 27 |
+
icon: TrendingUp,
|
| 28 |
+
trend: "+4.2% from baseline",
|
| 29 |
+
color: "text-success",
|
| 30 |
+
},
|
| 31 |
+
{
|
| 32 |
+
title: "Budget Spent",
|
| 33 |
+
value: "$12,450",
|
| 34 |
+
description: "This month",
|
| 35 |
+
icon: DollarSign,
|
| 36 |
+
trend: "$3,200 remaining",
|
| 37 |
+
color: "text-warning",
|
| 38 |
+
},
|
| 39 |
+
]
|
| 40 |
+
|
| 41 |
+
export function StatsCards() {
|
| 42 |
+
return (
|
| 43 |
+
<div className="grid gap-4 md:grid-cols-2 lg:grid-cols-4">
|
| 44 |
+
{stats.map((stat) => (
|
| 45 |
+
<Card key={stat.title} className="border-border/50">
|
| 46 |
+
<CardHeader className="flex flex-row items-center justify-between space-y-0 pb-2">
|
| 47 |
+
<CardTitle className="text-sm font-medium text-muted-foreground">
|
| 48 |
+
{stat.title}
|
| 49 |
+
</CardTitle>
|
| 50 |
+
<stat.icon className={`h-4 w-4 ${stat.color}`} />
|
| 51 |
+
</CardHeader>
|
| 52 |
+
<CardContent>
|
| 53 |
+
<div className="text-2xl font-bold font-mono">{stat.value}</div>
|
| 54 |
+
<p className="text-xs text-muted-foreground mt-1">
|
| 55 |
+
{stat.trend}
|
| 56 |
+
</p>
|
| 57 |
+
</CardContent>
|
| 58 |
+
</Card>
|
| 59 |
+
))}
|
| 60 |
+
</div>
|
| 61 |
+
)
|
| 62 |
+
}
|
v0ap/components/theme-provider.tsx
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'use client'
|
| 2 |
+
|
| 3 |
+
import * as React from 'react'
|
| 4 |
+
import {
|
| 5 |
+
ThemeProvider as NextThemesProvider,
|
| 6 |
+
useTheme as useNextTheme,
|
| 7 |
+
type ThemeProviderProps,
|
| 8 |
+
} from 'next-themes'
|
| 9 |
+
|
| 10 |
+
export function ThemeProvider({ children, ...props }: ThemeProviderProps) {
|
| 11 |
+
return <NextThemesProvider {...props}>{children}</NextThemesProvider>
|
| 12 |
+
}
|
| 13 |
+
|
| 14 |
+
export const useTheme = useNextTheme
|
v0ap/components/training/comparison-table.tsx
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"use client"
|
| 2 |
+
|
| 3 |
+
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"
|
| 4 |
+
import { Badge } from "@/components/ui/badge"
|
| 5 |
+
import {
|
| 6 |
+
Table,
|
| 7 |
+
TableBody,
|
| 8 |
+
TableCell,
|
| 9 |
+
TableHead,
|
| 10 |
+
TableHeader,
|
| 11 |
+
TableRow,
|
| 12 |
+
} from "@/components/ui/table"
|
| 13 |
+
import { ArrowUp, ArrowDown, Minus } from "lucide-react"
|
| 14 |
+
|
| 15 |
+
interface ComparisonRow {
|
| 16 |
+
metric: string
|
| 17 |
+
reinforce: number
|
| 18 |
+
baseline: number
|
| 19 |
+
improvement: number | null
|
| 20 |
+
unit?: string
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
interface ComparisonTableProps {
|
| 24 |
+
data: ComparisonRow[]
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
function ImprovementBadge({ value }: { value: number | null }) {
|
| 28 |
+
if (value === null) {
|
| 29 |
+
return (
|
| 30 |
+
<Badge variant="outline" className="bg-muted/50 text-muted-foreground border-muted">
|
| 31 |
+
<Minus className="h-3 w-3 mr-1" />
|
| 32 |
+
N/A
|
| 33 |
+
</Badge>
|
| 34 |
+
)
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
const isPositive = value > 0
|
| 38 |
+
const absValue = Math.abs(value)
|
| 39 |
+
|
| 40 |
+
return (
|
| 41 |
+
<Badge
|
| 42 |
+
variant="outline"
|
| 43 |
+
className={
|
| 44 |
+
isPositive
|
| 45 |
+
? "bg-success/10 text-success border-success/20"
|
| 46 |
+
: "bg-destructive/10 text-destructive border-destructive/20"
|
| 47 |
+
}
|
| 48 |
+
>
|
| 49 |
+
{isPositive ? (
|
| 50 |
+
<ArrowUp className="h-3 w-3 mr-1" />
|
| 51 |
+
) : (
|
| 52 |
+
<ArrowDown className="h-3 w-3 mr-1" />
|
| 53 |
+
)}
|
| 54 |
+
{absValue}%
|
| 55 |
+
</Badge>
|
| 56 |
+
)
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
export function ComparisonTable({ data }: ComparisonTableProps) {
|
| 60 |
+
if (data.length === 0) return null
|
| 61 |
+
|
| 62 |
+
return (
|
| 63 |
+
<Card className="border-border/50 animate-in fade-in slide-in-from-bottom-4 duration-500">
|
| 64 |
+
<CardHeader>
|
| 65 |
+
<CardTitle>Agent Comparison</CardTitle>
|
| 66 |
+
<CardDescription>
|
| 67 |
+
REINFORCE Agent vs Naive Baseline — evaluated on 100 episodes
|
| 68 |
+
</CardDescription>
|
| 69 |
+
</CardHeader>
|
| 70 |
+
<CardContent>
|
| 71 |
+
<Table>
|
| 72 |
+
<TableHeader>
|
| 73 |
+
<TableRow className="border-border/50 hover:bg-transparent">
|
| 74 |
+
<TableHead className="text-muted-foreground">Metric</TableHead>
|
| 75 |
+
<TableHead className="text-muted-foreground text-right">
|
| 76 |
+
<div className="flex flex-col items-end">
|
| 77 |
+
<span>REINFORCE</span>
|
| 78 |
+
<Badge variant="secondary" className="text-[10px] mt-1">Agent</Badge>
|
| 79 |
+
</div>
|
| 80 |
+
</TableHead>
|
| 81 |
+
<TableHead className="text-muted-foreground text-right">
|
| 82 |
+
<div className="flex flex-col items-end">
|
| 83 |
+
<span>Naive</span>
|
| 84 |
+
<Badge variant="outline" className="text-[10px] mt-1">Baseline</Badge>
|
| 85 |
+
</div>
|
| 86 |
+
</TableHead>
|
| 87 |
+
<TableHead className="text-muted-foreground text-right">Improvement</TableHead>
|
| 88 |
+
</TableRow>
|
| 89 |
+
</TableHeader>
|
| 90 |
+
<TableBody>
|
| 91 |
+
{data.map((row) => (
|
| 92 |
+
<TableRow key={row.metric} className="border-border/50">
|
| 93 |
+
<TableCell className="font-medium">{row.metric}</TableCell>
|
| 94 |
+
<TableCell className="text-right font-mono">
|
| 95 |
+
{row.unit === "$" && "$"}
|
| 96 |
+
{row.reinforce}
|
| 97 |
+
{row.unit === "%" && "%"}
|
| 98 |
+
{row.unit === "min" && " min"}
|
| 99 |
+
</TableCell>
|
| 100 |
+
<TableCell className="text-right font-mono text-muted-foreground">
|
| 101 |
+
{row.unit === "$" && "$"}
|
| 102 |
+
{row.baseline}
|
| 103 |
+
{row.unit === "%" && "%"}
|
| 104 |
+
{row.unit === "min" && " min"}
|
| 105 |
+
</TableCell>
|
| 106 |
+
<TableCell className="text-right">
|
| 107 |
+
<ImprovementBadge value={row.improvement} />
|
| 108 |
+
</TableCell>
|
| 109 |
+
</TableRow>
|
| 110 |
+
))}
|
| 111 |
+
</TableBody>
|
| 112 |
+
</Table>
|
| 113 |
+
</CardContent>
|
| 114 |
+
</Card>
|
| 115 |
+
)
|
| 116 |
+
}
|
v0ap/components/training/training-chart.tsx
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"use client"
|
| 2 |
+
|
| 3 |
+
import {
|
| 4 |
+
Line,
|
| 5 |
+
LineChart,
|
| 6 |
+
CartesianGrid,
|
| 7 |
+
ResponsiveContainer,
|
| 8 |
+
Tooltip,
|
| 9 |
+
XAxis,
|
| 10 |
+
YAxis,
|
| 11 |
+
Legend,
|
| 12 |
+
} from "recharts"
|
| 13 |
+
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"
|
| 14 |
+
import { ChartContainer, ChartTooltipContent } from "@/components/ui/chart"
|
| 15 |
+
|
| 16 |
+
interface ChartPoint {
|
| 17 |
+
episode: number
|
| 18 |
+
reward: number
|
| 19 |
+
successRate: number
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
interface TrainingChartProps {
|
| 23 |
+
data: ChartPoint[]
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
const chartConfig = {
|
| 27 |
+
reward: {
|
| 28 |
+
label: "Reward",
|
| 29 |
+
color: "var(--color-primary)",
|
| 30 |
+
},
|
| 31 |
+
successRate: {
|
| 32 |
+
label: "Success Rate",
|
| 33 |
+
color: "var(--color-success)",
|
| 34 |
+
},
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
export function TrainingChart({ data }: TrainingChartProps) {
|
| 38 |
+
const hasData = data.length > 0
|
| 39 |
+
|
| 40 |
+
const rewardDomain: [number, number] = hasData
|
| 41 |
+
? [
|
| 42 |
+
Math.min(...data.map((d) => d.reward)) - 5,
|
| 43 |
+
Math.max(...data.map((d) => d.reward)) + 5,
|
| 44 |
+
]
|
| 45 |
+
: [-3, 4]
|
| 46 |
+
|
| 47 |
+
return (
|
| 48 |
+
<Card className="border-border/50">
|
| 49 |
+
<CardHeader>
|
| 50 |
+
<CardTitle>Training Progress</CardTitle>
|
| 51 |
+
<CardDescription>
|
| 52 |
+
{hasData
|
| 53 |
+
? `Live reward curve and success rate — ${data[data.length - 1].episode} episodes`
|
| 54 |
+
: "Start training to see live metrics"}
|
| 55 |
+
</CardDescription>
|
| 56 |
+
</CardHeader>
|
| 57 |
+
<CardContent>
|
| 58 |
+
<ChartContainer config={chartConfig} className="h-[350px] w-full">
|
| 59 |
+
{hasData ? (
|
| 60 |
+
<ResponsiveContainer width="100%" height="100%">
|
| 61 |
+
<LineChart data={data}>
|
| 62 |
+
<CartesianGrid strokeDasharray="3 3" className="stroke-border/50" />
|
| 63 |
+
<XAxis
|
| 64 |
+
dataKey="episode"
|
| 65 |
+
tickLine={false}
|
| 66 |
+
axisLine={false}
|
| 67 |
+
className="text-xs fill-muted-foreground"
|
| 68 |
+
/>
|
| 69 |
+
<YAxis
|
| 70 |
+
yAxisId="left"
|
| 71 |
+
tickLine={false}
|
| 72 |
+
axisLine={false}
|
| 73 |
+
className="text-xs fill-muted-foreground"
|
| 74 |
+
domain={rewardDomain}
|
| 75 |
+
tickFormatter={(value) => value.toFixed(0)}
|
| 76 |
+
/>
|
| 77 |
+
<YAxis
|
| 78 |
+
yAxisId="right"
|
| 79 |
+
orientation="right"
|
| 80 |
+
tickLine={false}
|
| 81 |
+
axisLine={false}
|
| 82 |
+
className="text-xs fill-muted-foreground"
|
| 83 |
+
domain={[0, 100]}
|
| 84 |
+
tickFormatter={(value) => `${value}%`}
|
| 85 |
+
/>
|
| 86 |
+
<Tooltip content={<ChartTooltipContent />} />
|
| 87 |
+
<Legend />
|
| 88 |
+
<Line
|
| 89 |
+
yAxisId="left"
|
| 90 |
+
type="monotone"
|
| 91 |
+
dataKey="reward"
|
| 92 |
+
stroke="var(--color-primary)"
|
| 93 |
+
strokeWidth={2}
|
| 94 |
+
dot={false}
|
| 95 |
+
name="Reward"
|
| 96 |
+
isAnimationActive={false}
|
| 97 |
+
/>
|
| 98 |
+
<Line
|
| 99 |
+
yAxisId="right"
|
| 100 |
+
type="monotone"
|
| 101 |
+
dataKey="successRate"
|
| 102 |
+
stroke="var(--color-success)"
|
| 103 |
+
strokeWidth={2}
|
| 104 |
+
dot={false}
|
| 105 |
+
name="Success Rate (%)"
|
| 106 |
+
isAnimationActive={false}
|
| 107 |
+
/>
|
| 108 |
+
</LineChart>
|
| 109 |
+
</ResponsiveContainer>
|
| 110 |
+
) : (
|
| 111 |
+
<div className="h-full flex items-center justify-center text-muted-foreground">
|
| 112 |
+
Configure parameters and click Start Training
|
| 113 |
+
</div>
|
| 114 |
+
)}
|
| 115 |
+
</ChartContainer>
|
| 116 |
+
</CardContent>
|
| 117 |
+
</Card>
|
| 118 |
+
)
|
| 119 |
+
}
|
v0ap/components/training/training-controls.tsx
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"use client"
|
| 2 |
+
|
| 3 |
+
import { useState } from "react"
|
| 4 |
+
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card"
|
| 5 |
+
import { Button } from "@/components/ui/button"
|
| 6 |
+
import { Slider } from "@/components/ui/slider"
|
| 7 |
+
import { Progress } from "@/components/ui/progress"
|
| 8 |
+
import { Play, Square } from "lucide-react"
|
| 9 |
+
import { Field, FieldLabel } from "@/components/ui/field"
|
| 10 |
+
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select"
|
| 11 |
+
|
| 12 |
+
const WORKFLOWS = [
|
| 13 |
+
{ id: "pcr-amplification", label: "PCR Amplification" },
|
| 14 |
+
{ id: "elisa-readout", label: "ELISA Readout" },
|
| 15 |
+
] as const
|
| 16 |
+
|
| 17 |
+
interface TrainingControlsProps {
|
| 18 |
+
isTraining: boolean
|
| 19 |
+
progress: number
|
| 20 |
+
currentEpisode: number
|
| 21 |
+
totalEpisodes: number
|
| 22 |
+
onStartTraining: (episodes: number, lr: number, maxTrials: number, workflowId: string) => void
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
export function TrainingControls({
|
| 26 |
+
isTraining,
|
| 27 |
+
progress,
|
| 28 |
+
currentEpisode,
|
| 29 |
+
totalEpisodes,
|
| 30 |
+
onStartTraining,
|
| 31 |
+
}: TrainingControlsProps) {
|
| 32 |
+
const [episodes, setEpisodes] = useState([1000])
|
| 33 |
+
const [learningRate, setLearningRate] = useState([0.003])
|
| 34 |
+
const [maxTrials, setMaxTrials] = useState([4])
|
| 35 |
+
const [workflowId, setWorkflowId] = useState<string>("pcr-amplification")
|
| 36 |
+
|
| 37 |
+
return (
|
| 38 |
+
<Card className="border-border/50 h-fit">
|
| 39 |
+
<CardHeader className="pb-4">
|
| 40 |
+
<CardTitle className="text-base">Training Configuration</CardTitle>
|
| 41 |
+
</CardHeader>
|
| 42 |
+
<CardContent className="space-y-6">
|
| 43 |
+
<Field>
|
| 44 |
+
<FieldLabel>Protocol</FieldLabel>
|
| 45 |
+
<Select value={workflowId} onValueChange={setWorkflowId} disabled={isTraining}>
|
| 46 |
+
<SelectTrigger className="mt-2">
|
| 47 |
+
<SelectValue placeholder="Select protocol" />
|
| 48 |
+
</SelectTrigger>
|
| 49 |
+
<SelectContent>
|
| 50 |
+
{WORKFLOWS.map((w) => (
|
| 51 |
+
<SelectItem key={w.id} value={w.id}>
|
| 52 |
+
{w.label}
|
| 53 |
+
</SelectItem>
|
| 54 |
+
))}
|
| 55 |
+
</SelectContent>
|
| 56 |
+
</Select>
|
| 57 |
+
</Field>
|
| 58 |
+
<Field>
|
| 59 |
+
<div className="flex items-center justify-between">
|
| 60 |
+
<FieldLabel>Number of Episodes</FieldLabel>
|
| 61 |
+
<span className="text-sm font-mono text-muted-foreground">{episodes[0]}</span>
|
| 62 |
+
</div>
|
| 63 |
+
<Slider
|
| 64 |
+
value={episodes}
|
| 65 |
+
onValueChange={setEpisodes}
|
| 66 |
+
min={100}
|
| 67 |
+
max={5000}
|
| 68 |
+
step={100}
|
| 69 |
+
className="mt-2"
|
| 70 |
+
disabled={isTraining}
|
| 71 |
+
/>
|
| 72 |
+
</Field>
|
| 73 |
+
|
| 74 |
+
<Field>
|
| 75 |
+
<div className="flex items-center justify-between">
|
| 76 |
+
<FieldLabel>Learning Rate</FieldLabel>
|
| 77 |
+
<span className="text-sm font-mono text-muted-foreground">{learningRate[0].toFixed(4)}</span>
|
| 78 |
+
</div>
|
| 79 |
+
<Slider
|
| 80 |
+
value={learningRate}
|
| 81 |
+
onValueChange={setLearningRate}
|
| 82 |
+
min={0.0001}
|
| 83 |
+
max={0.01}
|
| 84 |
+
step={0.0001}
|
| 85 |
+
className="mt-2"
|
| 86 |
+
disabled={isTraining}
|
| 87 |
+
/>
|
| 88 |
+
</Field>
|
| 89 |
+
|
| 90 |
+
<Field>
|
| 91 |
+
<div className="flex items-center justify-between">
|
| 92 |
+
<FieldLabel>Max Trials per Episode</FieldLabel>
|
| 93 |
+
<span className="text-sm font-mono text-muted-foreground">{maxTrials[0]}</span>
|
| 94 |
+
</div>
|
| 95 |
+
<Slider
|
| 96 |
+
value={maxTrials}
|
| 97 |
+
onValueChange={setMaxTrials}
|
| 98 |
+
min={1}
|
| 99 |
+
max={8}
|
| 100 |
+
step={1}
|
| 101 |
+
className="mt-2"
|
| 102 |
+
disabled={isTraining}
|
| 103 |
+
/>
|
| 104 |
+
</Field>
|
| 105 |
+
|
| 106 |
+
{isTraining && (
|
| 107 |
+
<div className="space-y-2 pt-4 border-t border-border/50">
|
| 108 |
+
<div className="flex items-center justify-between text-sm">
|
| 109 |
+
<span className="text-muted-foreground">Training Progress</span>
|
| 110 |
+
<span className="font-mono">{progress}%</span>
|
| 111 |
+
</div>
|
| 112 |
+
<Progress value={progress} className="h-2" />
|
| 113 |
+
<p className="text-xs text-muted-foreground">
|
| 114 |
+
Episode {currentEpisode} of {totalEpisodes}
|
| 115 |
+
</p>
|
| 116 |
+
</div>
|
| 117 |
+
)}
|
| 118 |
+
|
| 119 |
+
<div className="space-y-2 pt-4 border-t border-border/50">
|
| 120 |
+
<Button
|
| 121 |
+
className="w-full"
|
| 122 |
+
onClick={() => onStartTraining(episodes[0], learningRate[0], maxTrials[0], workflowId)}
|
| 123 |
+
disabled={isTraining}
|
| 124 |
+
>
|
| 125 |
+
{isTraining ? (
|
| 126 |
+
<>
|
| 127 |
+
<Square className="h-4 w-4 mr-2 animate-pulse" />
|
| 128 |
+
Training...
|
| 129 |
+
</>
|
| 130 |
+
) : (
|
| 131 |
+
<>
|
| 132 |
+
<Play className="h-4 w-4 mr-2" />
|
| 133 |
+
Start Training
|
| 134 |
+
</>
|
| 135 |
+
)}
|
| 136 |
+
</Button>
|
| 137 |
+
</div>
|
| 138 |
+
</CardContent>
|
| 139 |
+
</Card>
|
| 140 |
+
)
|
| 141 |
+
}
|
v0ap/components/ui/accordion.tsx
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'use client'
|
| 2 |
+
|
| 3 |
+
import * as React from 'react'
|
| 4 |
+
import * as AccordionPrimitive from '@radix-ui/react-accordion'
|
| 5 |
+
import { ChevronDownIcon } from 'lucide-react'
|
| 6 |
+
|
| 7 |
+
import { cn } from '@/lib/utils'
|
| 8 |
+
|
| 9 |
+
function Accordion({
|
| 10 |
+
...props
|
| 11 |
+
}: React.ComponentProps<typeof AccordionPrimitive.Root>) {
|
| 12 |
+
return <AccordionPrimitive.Root data-slot="accordion" {...props} />
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
function AccordionItem({
|
| 16 |
+
className,
|
| 17 |
+
...props
|
| 18 |
+
}: React.ComponentProps<typeof AccordionPrimitive.Item>) {
|
| 19 |
+
return (
|
| 20 |
+
<AccordionPrimitive.Item
|
| 21 |
+
data-slot="accordion-item"
|
| 22 |
+
className={cn('border-b last:border-b-0', className)}
|
| 23 |
+
{...props}
|
| 24 |
+
/>
|
| 25 |
+
)
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
function AccordionTrigger({
|
| 29 |
+
className,
|
| 30 |
+
children,
|
| 31 |
+
...props
|
| 32 |
+
}: React.ComponentProps<typeof AccordionPrimitive.Trigger>) {
|
| 33 |
+
return (
|
| 34 |
+
<AccordionPrimitive.Header className="flex">
|
| 35 |
+
<AccordionPrimitive.Trigger
|
| 36 |
+
data-slot="accordion-trigger"
|
| 37 |
+
className={cn(
|
| 38 |
+
'focus-visible:border-ring focus-visible:ring-ring/50 flex flex-1 items-start justify-between gap-4 rounded-md py-4 text-left text-sm font-medium transition-all outline-none hover:underline focus-visible:ring-[3px] disabled:pointer-events-none disabled:opacity-50 [&[data-state=open]>svg]:rotate-180',
|
| 39 |
+
className,
|
| 40 |
+
)}
|
| 41 |
+
{...props}
|
| 42 |
+
>
|
| 43 |
+
{children}
|
| 44 |
+
<ChevronDownIcon className="text-muted-foreground pointer-events-none size-4 shrink-0 translate-y-0.5 transition-transform duration-200" />
|
| 45 |
+
</AccordionPrimitive.Trigger>
|
| 46 |
+
</AccordionPrimitive.Header>
|
| 47 |
+
)
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
function AccordionContent({
|
| 51 |
+
className,
|
| 52 |
+
children,
|
| 53 |
+
...props
|
| 54 |
+
}: React.ComponentProps<typeof AccordionPrimitive.Content>) {
|
| 55 |
+
return (
|
| 56 |
+
<AccordionPrimitive.Content
|
| 57 |
+
data-slot="accordion-content"
|
| 58 |
+
className="data-[state=closed]:animate-accordion-up data-[state=open]:animate-accordion-down overflow-hidden text-sm"
|
| 59 |
+
{...props}
|
| 60 |
+
>
|
| 61 |
+
<div className={cn('pt-0 pb-4', className)}>{children}</div>
|
| 62 |
+
</AccordionPrimitive.Content>
|
| 63 |
+
)
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
export { Accordion, AccordionItem, AccordionTrigger, AccordionContent }
|
v0ap/components/ui/alert-dialog.tsx
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'use client'
|
| 2 |
+
|
| 3 |
+
import * as React from 'react'
|
| 4 |
+
import * as AlertDialogPrimitive from '@radix-ui/react-alert-dialog'
|
| 5 |
+
|
| 6 |
+
import { cn } from '@/lib/utils'
|
| 7 |
+
import { buttonVariants } from '@/components/ui/button'
|
| 8 |
+
|
| 9 |
+
function AlertDialog({
|
| 10 |
+
...props
|
| 11 |
+
}: React.ComponentProps<typeof AlertDialogPrimitive.Root>) {
|
| 12 |
+
return <AlertDialogPrimitive.Root data-slot="alert-dialog" {...props} />
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
function AlertDialogTrigger({
|
| 16 |
+
...props
|
| 17 |
+
}: React.ComponentProps<typeof AlertDialogPrimitive.Trigger>) {
|
| 18 |
+
return (
|
| 19 |
+
<AlertDialogPrimitive.Trigger data-slot="alert-dialog-trigger" {...props} />
|
| 20 |
+
)
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
function AlertDialogPortal({
|
| 24 |
+
...props
|
| 25 |
+
}: React.ComponentProps<typeof AlertDialogPrimitive.Portal>) {
|
| 26 |
+
return (
|
| 27 |
+
<AlertDialogPrimitive.Portal data-slot="alert-dialog-portal" {...props} />
|
| 28 |
+
)
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
function AlertDialogOverlay({
|
| 32 |
+
className,
|
| 33 |
+
...props
|
| 34 |
+
}: React.ComponentProps<typeof AlertDialogPrimitive.Overlay>) {
|
| 35 |
+
return (
|
| 36 |
+
<AlertDialogPrimitive.Overlay
|
| 37 |
+
data-slot="alert-dialog-overlay"
|
| 38 |
+
className={cn(
|
| 39 |
+
'data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 fixed inset-0 z-50 bg-black/50',
|
| 40 |
+
className,
|
| 41 |
+
)}
|
| 42 |
+
{...props}
|
| 43 |
+
/>
|
| 44 |
+
)
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
function AlertDialogContent({
|
| 48 |
+
className,
|
| 49 |
+
...props
|
| 50 |
+
}: React.ComponentProps<typeof AlertDialogPrimitive.Content>) {
|
| 51 |
+
return (
|
| 52 |
+
<AlertDialogPortal>
|
| 53 |
+
<AlertDialogOverlay />
|
| 54 |
+
<AlertDialogPrimitive.Content
|
| 55 |
+
data-slot="alert-dialog-content"
|
| 56 |
+
className={cn(
|
| 57 |
+
'bg-background data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 fixed top-[50%] left-[50%] z-50 grid w-full max-w-[calc(100%-2rem)] translate-x-[-50%] translate-y-[-50%] gap-4 rounded-lg border p-6 shadow-lg duration-200 sm:max-w-lg',
|
| 58 |
+
className,
|
| 59 |
+
)}
|
| 60 |
+
{...props}
|
| 61 |
+
/>
|
| 62 |
+
</AlertDialogPortal>
|
| 63 |
+
)
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
function AlertDialogHeader({
|
| 67 |
+
className,
|
| 68 |
+
...props
|
| 69 |
+
}: React.ComponentProps<'div'>) {
|
| 70 |
+
return (
|
| 71 |
+
<div
|
| 72 |
+
data-slot="alert-dialog-header"
|
| 73 |
+
className={cn('flex flex-col gap-2 text-center sm:text-left', className)}
|
| 74 |
+
{...props}
|
| 75 |
+
/>
|
| 76 |
+
)
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
function AlertDialogFooter({
|
| 80 |
+
className,
|
| 81 |
+
...props
|
| 82 |
+
}: React.ComponentProps<'div'>) {
|
| 83 |
+
return (
|
| 84 |
+
<div
|
| 85 |
+
data-slot="alert-dialog-footer"
|
| 86 |
+
className={cn(
|
| 87 |
+
'flex flex-col-reverse gap-2 sm:flex-row sm:justify-end',
|
| 88 |
+
className,
|
| 89 |
+
)}
|
| 90 |
+
{...props}
|
| 91 |
+
/>
|
| 92 |
+
)
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
function AlertDialogTitle({
|
| 96 |
+
className,
|
| 97 |
+
...props
|
| 98 |
+
}: React.ComponentProps<typeof AlertDialogPrimitive.Title>) {
|
| 99 |
+
return (
|
| 100 |
+
<AlertDialogPrimitive.Title
|
| 101 |
+
data-slot="alert-dialog-title"
|
| 102 |
+
className={cn('text-lg font-semibold', className)}
|
| 103 |
+
{...props}
|
| 104 |
+
/>
|
| 105 |
+
)
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
function AlertDialogDescription({
|
| 109 |
+
className,
|
| 110 |
+
...props
|
| 111 |
+
}: React.ComponentProps<typeof AlertDialogPrimitive.Description>) {
|
| 112 |
+
return (
|
| 113 |
+
<AlertDialogPrimitive.Description
|
| 114 |
+
data-slot="alert-dialog-description"
|
| 115 |
+
className={cn('text-muted-foreground text-sm', className)}
|
| 116 |
+
{...props}
|
| 117 |
+
/>
|
| 118 |
+
)
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
function AlertDialogAction({
|
| 122 |
+
className,
|
| 123 |
+
...props
|
| 124 |
+
}: React.ComponentProps<typeof AlertDialogPrimitive.Action>) {
|
| 125 |
+
return (
|
| 126 |
+
<AlertDialogPrimitive.Action
|
| 127 |
+
className={cn(buttonVariants(), className)}
|
| 128 |
+
{...props}
|
| 129 |
+
/>
|
| 130 |
+
)
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
function AlertDialogCancel({
|
| 134 |
+
className,
|
| 135 |
+
...props
|
| 136 |
+
}: React.ComponentProps<typeof AlertDialogPrimitive.Cancel>) {
|
| 137 |
+
return (
|
| 138 |
+
<AlertDialogPrimitive.Cancel
|
| 139 |
+
className={cn(buttonVariants({ variant: 'outline' }), className)}
|
| 140 |
+
{...props}
|
| 141 |
+
/>
|
| 142 |
+
)
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
export {
|
| 146 |
+
AlertDialog,
|
| 147 |
+
AlertDialogPortal,
|
| 148 |
+
AlertDialogOverlay,
|
| 149 |
+
AlertDialogTrigger,
|
| 150 |
+
AlertDialogContent,
|
| 151 |
+
AlertDialogHeader,
|
| 152 |
+
AlertDialogFooter,
|
| 153 |
+
AlertDialogTitle,
|
| 154 |
+
AlertDialogDescription,
|
| 155 |
+
AlertDialogAction,
|
| 156 |
+
AlertDialogCancel,
|
| 157 |
+
}
|
v0ap/components/ui/alert.tsx
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import * as React from 'react'
|
| 2 |
+
import { cva, type VariantProps } from 'class-variance-authority'
|
| 3 |
+
|
| 4 |
+
import { cn } from '@/lib/utils'
|
| 5 |
+
|
| 6 |
+
const alertVariants = cva(
|
| 7 |
+
'relative w-full rounded-lg border px-4 py-3 text-sm grid has-[>svg]:grid-cols-[calc(var(--spacing)*4)_1fr] grid-cols-[0_1fr] has-[>svg]:gap-x-3 gap-y-0.5 items-start [&>svg]:size-4 [&>svg]:translate-y-0.5 [&>svg]:text-current',
|
| 8 |
+
{
|
| 9 |
+
variants: {
|
| 10 |
+
variant: {
|
| 11 |
+
default: 'bg-card text-card-foreground',
|
| 12 |
+
destructive:
|
| 13 |
+
'text-destructive bg-card [&>svg]:text-current *:data-[slot=alert-description]:text-destructive/90',
|
| 14 |
+
},
|
| 15 |
+
},
|
| 16 |
+
defaultVariants: {
|
| 17 |
+
variant: 'default',
|
| 18 |
+
},
|
| 19 |
+
},
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
function Alert({
|
| 23 |
+
className,
|
| 24 |
+
variant,
|
| 25 |
+
...props
|
| 26 |
+
}: React.ComponentProps<'div'> & VariantProps<typeof alertVariants>) {
|
| 27 |
+
return (
|
| 28 |
+
<div
|
| 29 |
+
data-slot="alert"
|
| 30 |
+
role="alert"
|
| 31 |
+
className={cn(alertVariants({ variant }), className)}
|
| 32 |
+
{...props}
|
| 33 |
+
/>
|
| 34 |
+
)
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
function AlertTitle({ className, ...props }: React.ComponentProps<'div'>) {
|
| 38 |
+
return (
|
| 39 |
+
<div
|
| 40 |
+
data-slot="alert-title"
|
| 41 |
+
className={cn(
|
| 42 |
+
'col-start-2 line-clamp-1 min-h-4 font-medium tracking-tight',
|
| 43 |
+
className,
|
| 44 |
+
)}
|
| 45 |
+
{...props}
|
| 46 |
+
/>
|
| 47 |
+
)
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
function AlertDescription({
|
| 51 |
+
className,
|
| 52 |
+
...props
|
| 53 |
+
}: React.ComponentProps<'div'>) {
|
| 54 |
+
return (
|
| 55 |
+
<div
|
| 56 |
+
data-slot="alert-description"
|
| 57 |
+
className={cn(
|
| 58 |
+
'text-muted-foreground col-start-2 grid justify-items-start gap-1 text-sm [&_p]:leading-relaxed',
|
| 59 |
+
className,
|
| 60 |
+
)}
|
| 61 |
+
{...props}
|
| 62 |
+
/>
|
| 63 |
+
)
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
export { Alert, AlertTitle, AlertDescription }
|
v0ap/components/ui/aspect-ratio.tsx
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'use client'
|
| 2 |
+
|
| 3 |
+
import * as AspectRatioPrimitive from '@radix-ui/react-aspect-ratio'
|
| 4 |
+
|
| 5 |
+
function AspectRatio({
|
| 6 |
+
...props
|
| 7 |
+
}: React.ComponentProps<typeof AspectRatioPrimitive.Root>) {
|
| 8 |
+
return <AspectRatioPrimitive.Root data-slot="aspect-ratio" {...props} />
|
| 9 |
+
}
|
| 10 |
+
|
| 11 |
+
export { AspectRatio }
|
v0ap/components/ui/avatar.tsx
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'use client'
|
| 2 |
+
|
| 3 |
+
import * as React from 'react'
|
| 4 |
+
import * as AvatarPrimitive from '@radix-ui/react-avatar'
|
| 5 |
+
|
| 6 |
+
import { cn } from '@/lib/utils'
|
| 7 |
+
|
| 8 |
+
function Avatar({
|
| 9 |
+
className,
|
| 10 |
+
...props
|
| 11 |
+
}: React.ComponentProps<typeof AvatarPrimitive.Root>) {
|
| 12 |
+
return (
|
| 13 |
+
<AvatarPrimitive.Root
|
| 14 |
+
data-slot="avatar"
|
| 15 |
+
className={cn(
|
| 16 |
+
'relative flex size-8 shrink-0 overflow-hidden rounded-full',
|
| 17 |
+
className,
|
| 18 |
+
)}
|
| 19 |
+
{...props}
|
| 20 |
+
/>
|
| 21 |
+
)
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
function AvatarImage({
|
| 25 |
+
className,
|
| 26 |
+
...props
|
| 27 |
+
}: React.ComponentProps<typeof AvatarPrimitive.Image>) {
|
| 28 |
+
return (
|
| 29 |
+
<AvatarPrimitive.Image
|
| 30 |
+
data-slot="avatar-image"
|
| 31 |
+
className={cn('aspect-square size-full', className)}
|
| 32 |
+
{...props}
|
| 33 |
+
/>
|
| 34 |
+
)
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
function AvatarFallback({
|
| 38 |
+
className,
|
| 39 |
+
...props
|
| 40 |
+
}: React.ComponentProps<typeof AvatarPrimitive.Fallback>) {
|
| 41 |
+
return (
|
| 42 |
+
<AvatarPrimitive.Fallback
|
| 43 |
+
data-slot="avatar-fallback"
|
| 44 |
+
className={cn(
|
| 45 |
+
'bg-muted flex size-full items-center justify-center rounded-full',
|
| 46 |
+
className,
|
| 47 |
+
)}
|
| 48 |
+
{...props}
|
| 49 |
+
/>
|
| 50 |
+
)
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
export { Avatar, AvatarImage, AvatarFallback }
|
v0ap/components/ui/badge.tsx
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import * as React from 'react'
|
| 2 |
+
import { Slot } from '@radix-ui/react-slot'
|
| 3 |
+
import { cva, type VariantProps } from 'class-variance-authority'
|
| 4 |
+
|
| 5 |
+
import { cn } from '@/lib/utils'
|
| 6 |
+
|
| 7 |
+
const badgeVariants = cva(
|
| 8 |
+
'inline-flex items-center justify-center rounded-md border px-2 py-0.5 text-xs font-medium w-fit whitespace-nowrap shrink-0 [&>svg]:size-3 gap-1 [&>svg]:pointer-events-none focus-visible:border-ring focus-visible:ring-ring/50 focus-visible:ring-[3px] aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive transition-[color,box-shadow] overflow-hidden',
|
| 9 |
+
{
|
| 10 |
+
variants: {
|
| 11 |
+
variant: {
|
| 12 |
+
default:
|
| 13 |
+
'border-transparent bg-primary text-primary-foreground [a&]:hover:bg-primary/90',
|
| 14 |
+
secondary:
|
| 15 |
+
'border-transparent bg-secondary text-secondary-foreground [a&]:hover:bg-secondary/90',
|
| 16 |
+
destructive:
|
| 17 |
+
'border-transparent bg-destructive text-white [a&]:hover:bg-destructive/90 focus-visible:ring-destructive/20 dark:focus-visible:ring-destructive/40 dark:bg-destructive/60',
|
| 18 |
+
outline:
|
| 19 |
+
'text-foreground [a&]:hover:bg-accent [a&]:hover:text-accent-foreground',
|
| 20 |
+
},
|
| 21 |
+
},
|
| 22 |
+
defaultVariants: {
|
| 23 |
+
variant: 'default',
|
| 24 |
+
},
|
| 25 |
+
},
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
function Badge({
|
| 29 |
+
className,
|
| 30 |
+
variant,
|
| 31 |
+
asChild = false,
|
| 32 |
+
...props
|
| 33 |
+
}: React.ComponentProps<'span'> &
|
| 34 |
+
VariantProps<typeof badgeVariants> & { asChild?: boolean }) {
|
| 35 |
+
const Comp = asChild ? Slot : 'span'
|
| 36 |
+
|
| 37 |
+
return (
|
| 38 |
+
<Comp
|
| 39 |
+
data-slot="badge"
|
| 40 |
+
className={cn(badgeVariants({ variant }), className)}
|
| 41 |
+
{...props}
|
| 42 |
+
/>
|
| 43 |
+
)
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
export { Badge, badgeVariants }
|
v0ap/components/ui/breadcrumb.tsx
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import * as React from 'react'
|
| 2 |
+
import { Slot } from '@radix-ui/react-slot'
|
| 3 |
+
import { ChevronRight, MoreHorizontal } from 'lucide-react'
|
| 4 |
+
|
| 5 |
+
import { cn } from '@/lib/utils'
|
| 6 |
+
|
| 7 |
+
function Breadcrumb({ ...props }: React.ComponentProps<'nav'>) {
|
| 8 |
+
return <nav aria-label="breadcrumb" data-slot="breadcrumb" {...props} />
|
| 9 |
+
}
|
| 10 |
+
|
| 11 |
+
function BreadcrumbList({ className, ...props }: React.ComponentProps<'ol'>) {
|
| 12 |
+
return (
|
| 13 |
+
<ol
|
| 14 |
+
data-slot="breadcrumb-list"
|
| 15 |
+
className={cn(
|
| 16 |
+
'text-muted-foreground flex flex-wrap items-center gap-1.5 text-sm break-words sm:gap-2.5',
|
| 17 |
+
className,
|
| 18 |
+
)}
|
| 19 |
+
{...props}
|
| 20 |
+
/>
|
| 21 |
+
)
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
function BreadcrumbItem({ className, ...props }: React.ComponentProps<'li'>) {
|
| 25 |
+
return (
|
| 26 |
+
<li
|
| 27 |
+
data-slot="breadcrumb-item"
|
| 28 |
+
className={cn('inline-flex items-center gap-1.5', className)}
|
| 29 |
+
{...props}
|
| 30 |
+
/>
|
| 31 |
+
)
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
function BreadcrumbLink({
|
| 35 |
+
asChild,
|
| 36 |
+
className,
|
| 37 |
+
...props
|
| 38 |
+
}: React.ComponentProps<'a'> & {
|
| 39 |
+
asChild?: boolean
|
| 40 |
+
}) {
|
| 41 |
+
const Comp = asChild ? Slot : 'a'
|
| 42 |
+
|
| 43 |
+
return (
|
| 44 |
+
<Comp
|
| 45 |
+
data-slot="breadcrumb-link"
|
| 46 |
+
className={cn('hover:text-foreground transition-colors', className)}
|
| 47 |
+
{...props}
|
| 48 |
+
/>
|
| 49 |
+
)
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
function BreadcrumbPage({ className, ...props }: React.ComponentProps<'span'>) {
|
| 53 |
+
return (
|
| 54 |
+
<span
|
| 55 |
+
data-slot="breadcrumb-page"
|
| 56 |
+
role="link"
|
| 57 |
+
aria-disabled="true"
|
| 58 |
+
aria-current="page"
|
| 59 |
+
className={cn('text-foreground font-normal', className)}
|
| 60 |
+
{...props}
|
| 61 |
+
/>
|
| 62 |
+
)
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
function BreadcrumbSeparator({
|
| 66 |
+
children,
|
| 67 |
+
className,
|
| 68 |
+
...props
|
| 69 |
+
}: React.ComponentProps<'li'>) {
|
| 70 |
+
return (
|
| 71 |
+
<li
|
| 72 |
+
data-slot="breadcrumb-separator"
|
| 73 |
+
role="presentation"
|
| 74 |
+
aria-hidden="true"
|
| 75 |
+
className={cn('[&>svg]:size-3.5', className)}
|
| 76 |
+
{...props}
|
| 77 |
+
>
|
| 78 |
+
{children ?? <ChevronRight />}
|
| 79 |
+
</li>
|
| 80 |
+
)
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
function BreadcrumbEllipsis({
|
| 84 |
+
className,
|
| 85 |
+
...props
|
| 86 |
+
}: React.ComponentProps<'span'>) {
|
| 87 |
+
return (
|
| 88 |
+
<span
|
| 89 |
+
data-slot="breadcrumb-ellipsis"
|
| 90 |
+
role="presentation"
|
| 91 |
+
aria-hidden="true"
|
| 92 |
+
className={cn('flex size-9 items-center justify-center', className)}
|
| 93 |
+
{...props}
|
| 94 |
+
>
|
| 95 |
+
<MoreHorizontal className="size-4" />
|
| 96 |
+
<span className="sr-only">More</span>
|
| 97 |
+
</span>
|
| 98 |
+
)
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
export {
|
| 102 |
+
Breadcrumb,
|
| 103 |
+
BreadcrumbList,
|
| 104 |
+
BreadcrumbItem,
|
| 105 |
+
BreadcrumbLink,
|
| 106 |
+
BreadcrumbPage,
|
| 107 |
+
BreadcrumbSeparator,
|
| 108 |
+
BreadcrumbEllipsis,
|
| 109 |
+
}
|
v0ap/components/ui/button-group.tsx
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { Slot } from '@radix-ui/react-slot'
|
| 2 |
+
import { cva, type VariantProps } from 'class-variance-authority'
|
| 3 |
+
|
| 4 |
+
import { cn } from '@/lib/utils'
|
| 5 |
+
import { Separator } from '@/components/ui/separator'
|
| 6 |
+
|
| 7 |
+
const buttonGroupVariants = cva(
|
| 8 |
+
"flex w-fit items-stretch [&>*]:focus-visible:z-10 [&>*]:focus-visible:relative [&>[data-slot=select-trigger]:not([class*='w-'])]:w-fit [&>input]:flex-1 has-[select[aria-hidden=true]:last-child]:[&>[data-slot=select-trigger]:last-of-type]:rounded-r-md has-[>[data-slot=button-group]]:gap-2",
|
| 9 |
+
{
|
| 10 |
+
variants: {
|
| 11 |
+
orientation: {
|
| 12 |
+
horizontal:
|
| 13 |
+
'[&>*:not(:first-child)]:rounded-l-none [&>*:not(:first-child)]:border-l-0 [&>*:not(:last-child)]:rounded-r-none',
|
| 14 |
+
vertical:
|
| 15 |
+
'flex-col [&>*:not(:first-child)]:rounded-t-none [&>*:not(:first-child)]:border-t-0 [&>*:not(:last-child)]:rounded-b-none',
|
| 16 |
+
},
|
| 17 |
+
},
|
| 18 |
+
defaultVariants: {
|
| 19 |
+
orientation: 'horizontal',
|
| 20 |
+
},
|
| 21 |
+
},
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
function ButtonGroup({
|
| 25 |
+
className,
|
| 26 |
+
orientation,
|
| 27 |
+
...props
|
| 28 |
+
}: React.ComponentProps<'div'> & VariantProps<typeof buttonGroupVariants>) {
|
| 29 |
+
return (
|
| 30 |
+
<div
|
| 31 |
+
role="group"
|
| 32 |
+
data-slot="button-group"
|
| 33 |
+
data-orientation={orientation}
|
| 34 |
+
className={cn(buttonGroupVariants({ orientation }), className)}
|
| 35 |
+
{...props}
|
| 36 |
+
/>
|
| 37 |
+
)
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
function ButtonGroupText({
|
| 41 |
+
className,
|
| 42 |
+
asChild = false,
|
| 43 |
+
...props
|
| 44 |
+
}: React.ComponentProps<'div'> & {
|
| 45 |
+
asChild?: boolean
|
| 46 |
+
}) {
|
| 47 |
+
const Comp = asChild ? Slot : 'div'
|
| 48 |
+
|
| 49 |
+
return (
|
| 50 |
+
<Comp
|
| 51 |
+
className={cn(
|
| 52 |
+
"bg-muted flex items-center gap-2 rounded-md border px-4 text-sm font-medium shadow-xs [&_svg]:pointer-events-none [&_svg:not([class*='size-'])]:size-4",
|
| 53 |
+
className,
|
| 54 |
+
)}
|
| 55 |
+
{...props}
|
| 56 |
+
/>
|
| 57 |
+
)
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
function ButtonGroupSeparator({
|
| 61 |
+
className,
|
| 62 |
+
orientation = 'vertical',
|
| 63 |
+
...props
|
| 64 |
+
}: React.ComponentProps<typeof Separator>) {
|
| 65 |
+
return (
|
| 66 |
+
<Separator
|
| 67 |
+
data-slot="button-group-separator"
|
| 68 |
+
orientation={orientation}
|
| 69 |
+
className={cn(
|
| 70 |
+
'bg-input relative !m-0 self-stretch data-[orientation=vertical]:h-auto',
|
| 71 |
+
className,
|
| 72 |
+
)}
|
| 73 |
+
{...props}
|
| 74 |
+
/>
|
| 75 |
+
)
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
export {
|
| 79 |
+
ButtonGroup,
|
| 80 |
+
ButtonGroupSeparator,
|
| 81 |
+
ButtonGroupText,
|
| 82 |
+
buttonGroupVariants,
|
| 83 |
+
}
|