Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
Upload folder using huggingface_hub
Browse files- README.md +277 -71
- __init__.py +27 -0
- _dashboard_state.json +207 -0
- dashboard.html +522 -0
- dashboard.py +129 -0
- demo.html +1639 -0
- eval_compare.py +174 -0
- models.py +1408 -19
- my_env/README.md +255 -0
- my_env/__init__.py +16 -0
- my_env/client.py +99 -0
- my_env/models.py +28 -0
- my_env/openenv.yaml +7 -0
- my_env/pyproject.toml +45 -0
- my_env/server/Dockerfile +80 -0
- my_env/server/__init__.py +11 -0
- my_env/server/app.py +81 -0
- my_env/server/my_env_environment.py +101 -0
- my_env/server/requirements.txt +6 -0
- pyproject.toml +10 -4
- run_agent.py +978 -292
- server/app.py +21 -7
- server/biology/__init__.py +0 -0
- server/biology/gene_index.py +225 -0
- server/hackathon_environment.py +35 -16
- server/requirements.txt +7 -4
- server/rewards/reward.py +207 -20
- server/rules/engine.py +130 -10
- server/simulator/latent_state.py +10 -0
- server/simulator/noise.py +5 -1
- server/simulator/output_generator.py +213 -27
- server/simulator/transition.py +58 -9
- server/tasks/bio_palette.py +692 -0
- server/tasks/generator.py +15 -2
- server/tasks/procedural_generator.py +501 -0
- server/tasks/scenarios.py +4 -2
- tests/test_environment.py +11 -0
- tests/test_rewards.py +62 -0
- tests/test_rules.py +30 -0
- tests/test_run_agent.py +36 -0
- tests/test_training_script.py +123 -0
- training/__init__.py +0 -2
- training/evaluation.py +3 -2
- training/literature_benchmark.py +6 -19
- training/rollout_collection.py +219 -0
- training_script.py +1250 -0
- uv.lock +0 -0
README.md
CHANGED
|
@@ -21,7 +21,7 @@ The environment is designed as a partially observable Markov decision process (P
|
|
| 21 |
- visible task metadata, resource usage, step history, and intermediate outputs
|
| 22 |
- dense step-wise reward plus terminal reward for conclusion quality
|
| 23 |
|
| 24 |
-
##
|
| 25 |
|
| 26 |
At a high level, each episode looks like this:
|
| 27 |
|
|
@@ -68,11 +68,15 @@ This separation is what makes the environment a POMDP rather than a fully observ
|
|
| 68 |
|
| 69 |
Defines the contracts that all other modules use:
|
| 70 |
|
| 71 |
-
- `
|
| 72 |
-
- `
|
| 73 |
-
- `
|
| 74 |
-
- `
|
| 75 |
-
- `
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
The action vocabulary is intentionally broad enough to mix wet-lab, computational, and meta-planning actions.
|
| 78 |
|
|
@@ -80,27 +84,30 @@ The action vocabulary is intentionally broad enough to mix wet-lab, computationa
|
|
| 80 |
|
| 81 |
This is where episodes come from.
|
| 82 |
|
| 83 |
-
- `scenarios.py` defines a
|
| 84 |
-
- `generator.py` turns a scenario into a `(TaskSpec, FullLatentState)` pair
|
| 85 |
-
- optional domain randomization perturbs budget, time, noise, batch effects, cell proportions, and effect sizes
|
| 86 |
|
| 87 |
-
|
| 88 |
|
| 89 |
-
|
| 90 |
-
-
|
| 91 |
-
|
| 92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
### `server/simulator/`
|
| 95 |
|
| 96 |
This is the simulator itself.
|
| 97 |
|
| 98 |
-
- `latent_state.py` defines hidden
|
| 99 |
-
- `noise.py`
|
| 100 |
-
- `output_generator.py` turns an action plus hidden state into a realistic `IntermediateOutput`
|
| 101 |
-
- `transition.py` applies action costs, updates progress flags, propagates
|
| 102 |
|
| 103 |
-
The output generator does not simply echo the action. It conditions outputs on the hidden state, then injects realistic noise
|
| 104 |
|
| 105 |
### `server/rules/engine.py`
|
| 106 |
|
|
@@ -109,32 +116,41 @@ The rule engine enforces scientific and procedural constraints before each actio
|
|
| 109 |
- hard violations block the action entirely
|
| 110 |
- soft violations allow the action, but reduce output quality and add reward penalties
|
| 111 |
|
| 112 |
-
|
| 113 |
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
|
|
|
| 117 |
|
| 118 |
### `server/rewards/reward.py`
|
| 119 |
|
| 120 |
Rewards are decomposed rather than being a single opaque number.
|
| 121 |
|
| 122 |
-
Per-step reward
|
| 123 |
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
-
|
| 130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
|
| 132 |
Terminal reward adds:
|
| 133 |
|
| 134 |
-
|
| 135 |
-
-
|
| 136 |
-
|
| 137 |
-
|
|
|
|
|
|
|
| 138 |
|
| 139 |
This makes the environment easier to debug, benchmark, and train against.
|
| 140 |
|
|
@@ -145,7 +161,7 @@ This is the orchestration layer that ties everything together.
|
|
| 145 |
On `reset()` it:
|
| 146 |
|
| 147 |
- seeds the noise model
|
| 148 |
-
- generates a task and latent state
|
| 149 |
- clears history, outputs, discoveries, conclusions, and cumulative reward
|
| 150 |
|
| 151 |
On `step()` it:
|
|
@@ -171,12 +187,34 @@ Here is the concrete order of operations for `env.step(action)`:
|
|
| 171 |
5. Otherwise deduct budget and time based on `ACTION_COSTS`.
|
| 172 |
6. Update latent progress flags like `samples_collected`, `qc_performed`, or `de_performed`.
|
| 173 |
7. Generate a structured simulated output for the chosen action.
|
| 174 |
-
8. If there were soft violations, degrade output quality and attach warnings.
|
| 175 |
9. Propagate artifacts back into latent state, such as discovered DE genes or cluster names.
|
| 176 |
10. Compute decomposed reward from state transition plus output quality.
|
| 177 |
11. If the episode is ending, compute terminal reward from completeness and conclusion calibration.
|
| 178 |
12. Return an observation that exposes the visible summary but not the hidden truth.
|
| 179 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
## Typical successful pipeline
|
| 181 |
|
| 182 |
Most scenarios reward a sensible experiment order similar to:
|
|
@@ -193,12 +231,47 @@ Most scenarios reward a sensible experiment order similar to:
|
|
| 193 |
`regulatory_network_inference`, `marker_selection`, `validate_marker`
|
| 194 |
9. `synthesize_conclusion`
|
| 195 |
|
| 196 |
-
The exact best sequence depends on the scenario
|
| 197 |
|
| 198 |
- trajectory scenarios benefit from `trajectory_analysis` and regulatory inference
|
| 199 |
- biomarker scenarios benefit from DE, marker selection, and validation
|
| 200 |
- perturbation scenarios benefit from pathway-level interpretation
|
| 201 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
## Interfaces you can use
|
| 203 |
|
| 204 |
### 1. In-process environment
|
|
@@ -224,6 +297,10 @@ print(obs.latest_output.summary if obs.latest_output else "No output yet")
|
|
| 224 |
print(obs.reward)
|
| 225 |
```
|
| 226 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
### 2. OpenEnv client/server mode
|
| 228 |
|
| 229 |
Use the FastAPI app when you want to serve the environment over HTTP and WebSocket:
|
|
@@ -233,6 +310,16 @@ uv sync --extra dev
|
|
| 233 |
uv run uvicorn server.app:app --reload
|
| 234 |
```
|
| 235 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
Then connect with the client:
|
| 237 |
|
| 238 |
```python
|
|
@@ -247,40 +334,133 @@ with BioExperimentEnv(base_url="http://localhost:8000") as env:
|
|
| 247 |
|
| 248 |
The environment class supports concurrent sessions, but the bundled server is currently configured with `max_concurrent_envs=1` in `server/app.py`.
|
| 249 |
|
| 250 |
-
### 3.
|
| 251 |
|
| 252 |
-
|
| 253 |
|
| 254 |
-
```
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
env = BioExperimentGymEnv()
|
| 258 |
-
obs, info = env.reset()
|
| 259 |
-
obs, reward, terminated, truncated, info = env.step({
|
| 260 |
-
"action_type": 0,
|
| 261 |
-
"confidence": 0.7,
|
| 262 |
-
})
|
| 263 |
```
|
| 264 |
|
| 265 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
|
| 267 |
-
|
| 268 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
|
| 270 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
|
| 272 |
- `training/literature_benchmark.py` runs paper-aligned action sequences and compares outcomes against curated expected findings
|
|
|
|
|
|
|
| 273 |
- `run_agent.py` runs a local language model planner against the environment
|
| 274 |
- `training/trajectory.py` stores trajectories for offline RL, imitation learning, replay, and evaluation
|
| 275 |
- `training/evaluation.py` computes online, benchmark, expert-review, and fidelity-oriented metrics
|
| 276 |
|
| 277 |
-
##
|
| 278 |
|
| 279 |
-
|
| 280 |
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
|
| 285 |
## Why this is useful
|
| 286 |
|
|
@@ -299,31 +479,51 @@ That makes it suitable for:
|
|
| 299 |
- literature-grounded evaluation
|
| 300 |
- comparing structured policies against LLM-driven planners
|
| 301 |
|
| 302 |
-
##
|
| 303 |
|
| 304 |
```text
|
| 305 |
.
|
| 306 |
-
├── client.py # OpenEnv client
|
| 307 |
├── models.py # Shared action / observation / task schemas
|
|
|
|
|
|
|
|
|
|
| 308 |
├── server/
|
| 309 |
-
│ ├── app.py # FastAPI/OpenEnv server
|
|
|
|
| 310 |
│ ├── hackathon_environment.py # Main environment orchestration
|
| 311 |
-
│ ├──
|
| 312 |
-
│ ├──
|
| 313 |
-
│
|
| 314 |
-
│
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
├── training/
|
| 316 |
-
│ ├── evaluation.py #
|
| 317 |
-
│ ├── gym_wrapper.py # Gymnasium wrapper
|
| 318 |
│ ├── literature_benchmark.py # Paper-backed benchmark flow
|
| 319 |
-
│
|
| 320 |
-
└──
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
```
|
| 322 |
|
| 323 |
## Quick sanity check
|
| 324 |
|
| 325 |
-
The current implementation was sanity-checked with:
|
| 326 |
-
|
| 327 |
```bash
|
| 328 |
uv run pytest tests/test_environment.py tests/test_literature_benchmark.py -q
|
| 329 |
```
|
|
@@ -335,3 +535,9 @@ Those tests verify:
|
|
| 335 |
- conclusion termination
|
| 336 |
- literature-backed scenario selection
|
| 337 |
- benchmark matching for curated expected findings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
- visible task metadata, resource usage, step history, and intermediate outputs
|
| 22 |
- dense step-wise reward plus terminal reward for conclusion quality
|
| 23 |
|
| 24 |
+
## How it works
|
| 25 |
|
| 26 |
At a high level, each episode looks like this:
|
| 27 |
|
|
|
|
| 68 |
|
| 69 |
Defines the contracts that all other modules use:
|
| 70 |
|
| 71 |
+
- `ActionType`: 21 discrete experiment steps, grouped into three frozensets — `WET_LAB_ACTIONS` (8), `COMPUTATIONAL_ACTIONS` (10), and `META_ACTIONS` (3)
|
| 72 |
+
- `SubagentType`: 9 sub-agent delegate roles (e.g. `wet_lab_planner`, `computational_analyst`, `causal_reasoning_agent`)
|
| 73 |
+
- `ExperimentAction`: one structured step chosen by the agent; fields include `action_type`, `method`, `parameters`, `justification`, `confidence` (clamped to `[0, 1]`), `invoked_subagent`, `tool_call_spec`, `input_targets`
|
| 74 |
+
- `ExperimentObservation`: what the agent can see after each step; includes `task`, `pipeline_history`, `resource_usage`, `latest_output`, `all_outputs`, `discovered_markers`, `candidate_mechanisms`, `conclusions`, `rule_violations`, `step_reward_breakdown`
|
| 75 |
+
- `TaskSpec`: the problem statement, organism, tissue, conditions, budget, time limit, assays, tools, paper references, and expected findings
|
| 76 |
+
- `IntermediateOutput`: the simulated artifact returned by a step; carries `output_type`, `success`, `quality_score`, `summary`, `data`, `uncertainty`, `warnings`, `artifacts_available`
|
| 77 |
+
- `ConclusionClaim`: structured claims used for final synthesis; carries `claim`, `evidence_steps`, `confidence`, `claim_type`, `supporting_data`
|
| 78 |
+
- `PipelineStepRecord`: compact observable record of one past step stored in history
|
| 79 |
+
- `ResourceUsage`: budget and time tracking visible to the agent
|
| 80 |
|
| 81 |
The action vocabulary is intentionally broad enough to mix wet-lab, computational, and meta-planning actions.
|
| 82 |
|
|
|
|
| 84 |
|
| 85 |
This is where episodes come from.
|
| 86 |
|
| 87 |
+
- `scenarios.py` defines a curated library of four biological scenarios as `Scenario` dataclass objects, each bundling a `TaskSpec`, a `LatentBiologicalState`, a `TechnicalState`, hidden failure conditions, and tags
|
| 88 |
+
- `generator.py` turns a scenario into a `(TaskSpec, FullLatentState)` pair via `TaskGenerator.generate()`; optional domain randomisation perturbs budget (±30%), time (±20%), technical noise, batch effects, cell proportions, and effect sizes
|
|
|
|
| 89 |
|
| 90 |
+
The four scenarios are:
|
| 91 |
|
| 92 |
+
| Name | Difficulty | Tissue | Problem | Budget | Time |
|
| 93 |
+
|---|---|---|---|---|---|
|
| 94 |
+
| `cardiac_disease_de` | easy | heart | Differential expression between healthy and dilated cardiomyopathy cardiomyocytes | $80 K | 120 days |
|
| 95 |
+
| `hematopoiesis_trajectory` | medium | bone marrow | Infer HSC → mature lineage trajectory with three branches | $100 K | 150 days |
|
| 96 |
+
| `perturbation_immune` | hard | synovial fluid | JAK inhibitor effect on T-cell states in rheumatoid arthritis | $120 K | 180 days |
|
| 97 |
+
| `biomarker_validation_lung` | medium | lung | Validate SPP1 as biomarker for pro-fibrotic macrophages in IPF | $90 K | 150 days |
|
| 98 |
+
|
| 99 |
+
Each scenario carries paper references with DOIs, true DE genes with log2FC values, true pathway activities, true regulatory networks, and ground-truth causal mechanisms used for terminal reward calibration.
|
| 100 |
|
| 101 |
### `server/simulator/`
|
| 102 |
|
| 103 |
This is the simulator itself.
|
| 104 |
|
| 105 |
+
- `latent_state.py` defines `FullLatentState`, the root aggregate of all hidden state. Key sub-structures are `LatentBiologicalState` (true DE genes, pathways, gene programs, trajectory, regulatory network, markers, causal mechanisms), `TechnicalState` (dropout, doublets, ambient RNA, sample quality), `ExperimentProgress` (18 boolean milestone flags plus counts), and `ResourceState` (internal budget and time tracking with exhaustion properties)
|
| 106 |
+
- `noise.py` centralises stochasticity in `NoiseModel`. All randomness flows through a single seeded `numpy.Generator`. Methods include `add_expression_noise`, `sample_effect_sizes`, `sample_p_values`, `generate_false_positives`, `generate_false_negatives`, `quality_degradation`, `sample_qc_metric`, `sample_cluster_count`, `shuffle_ranking`, and `coin_flip`
|
| 107 |
+
- `output_generator.py` turns an action plus hidden state into a realistic `IntermediateOutput`. Every action type has a dedicated handler conditioned on the latent state; noise is then injected — dropout in expression data, false positives and false negatives in DE and marker results, over/under-clustering, and pathway contamination
|
| 108 |
+
- `transition.py` applies action costs from `ACTION_COSTS`, updates progress flags, calls the output generator, degrades quality on soft violations, propagates discovered DE genes and cluster names back into latent state, and decides whether the episode is done
|
| 109 |
|
| 110 |
+
The output generator does not simply echo the action. It conditions outputs on the hidden state, then injects realistic noise.
|
| 111 |
|
| 112 |
### `server/rules/engine.py`
|
| 113 |
|
|
|
|
| 116 |
- hard violations block the action entirely
|
| 117 |
- soft violations allow the action, but reduce output quality and add reward penalties
|
| 118 |
|
| 119 |
+
The four rule families are:
|
| 120 |
|
| 121 |
+
1. **Prerequisites (HARD)** — each computational step requires the appropriate upstream milestone flag. For example: `normalize_data` requires `data_filtered`, `differential_expression` requires `data_normalized`, `validate_marker` requires `markers_discovered`
|
| 122 |
+
2. **Resource constraints (HARD/SOFT)** — budget or time exhausted is a hard block; action cost exceeding remaining budget (when budget > 0) is a soft warning
|
| 123 |
+
3. **Redundancy (SOFT)** — repeating an already-completed step such as `run_qc` or `normalize_data`
|
| 124 |
+
4. **Causal validity (SOFT)** — synthesizing conclusions without prior DE or clustering; making causal claims without validation evidence; pathway enrichment before DE
|
| 125 |
|
| 126 |
### `server/rewards/reward.py`
|
| 127 |
|
| 128 |
Rewards are decomposed rather than being a single opaque number.
|
| 129 |
|
| 130 |
+
Per-step reward formula:
|
| 131 |
|
| 132 |
+
```
|
| 133 |
+
R_t = r_validity + r_ordering + r_info_gain + r_efficiency + r_novelty + r_penalty + γ[φ(s_{t+1}) − φ(s_t)]
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
| Component | Weight | Description |
|
| 137 |
+
|---|---|---|
|
| 138 |
+
| `validity` | 0.3 | `1.0` if output succeeded, `−1.0` if hard violation |
|
| 139 |
+
| `ordering` | 0.2 | `1.0` if natural next step, `0.3` otherwise |
|
| 140 |
+
| `info_gain` | 0.4 | `quality_score × (1 − uncertainty)` |
|
| 141 |
+
| `efficiency` | 0.3 | `max(0, 1 − 5 × budget_fraction_used)` |
|
| 142 |
+
| `novelty` | +0.1 | Bonus when no soft violations |
|
| 143 |
+
| `penalty` | −0.15/violation | Per soft violation |
|
| 144 |
+
| `shaping` | γ = 0.99 | Potential-based over 12 progress milestones |
|
| 145 |
|
| 146 |
Terminal reward adds:
|
| 147 |
|
| 148 |
+
| Component | Weight | Description |
|
| 149 |
+
|---|---|---|
|
| 150 |
+
| Pipeline completeness | 3.0 | Fraction of 7 core milestones completed |
|
| 151 |
+
| Calibration | 4.0 | How well conclusions match hidden markers and mechanisms |
|
| 152 |
+
| Budget + time efficiency | 1.0 | Average fraction of budget and time remaining |
|
| 153 |
+
| Overconfidence penalty | −0.5/claim | For high-confidence claims (`> 0.8`) that are wrong |
|
| 154 |
|
| 155 |
This makes the environment easier to debug, benchmark, and train against.
|
| 156 |
|
|
|
|
| 161 |
On `reset()` it:
|
| 162 |
|
| 163 |
- seeds the noise model
|
| 164 |
+
- generates a task and latent state via `TaskGenerator`
|
| 165 |
- clears history, outputs, discoveries, conclusions, and cumulative reward
|
| 166 |
|
| 167 |
On `step()` it:
|
|
|
|
| 187 |
5. Otherwise deduct budget and time based on `ACTION_COSTS`.
|
| 188 |
6. Update latent progress flags like `samples_collected`, `qc_performed`, or `de_performed`.
|
| 189 |
7. Generate a structured simulated output for the chosen action.
|
| 190 |
+
8. If there were soft violations, degrade output quality (×0.5) and attach warnings.
|
| 191 |
9. Propagate artifacts back into latent state, such as discovered DE genes or cluster names.
|
| 192 |
10. Compute decomposed reward from state transition plus output quality.
|
| 193 |
11. If the episode is ending, compute terminal reward from completeness and conclusion calibration.
|
| 194 |
12. Return an observation that exposes the visible summary but not the hidden truth.
|
| 195 |
|
| 196 |
+
## Action costs
|
| 197 |
+
|
| 198 |
+
Each action deducts from the episode's budget and time. Computational steps also accrue compute hours.
|
| 199 |
+
|
| 200 |
+
| Action | Budget | Time (days) |
|
| 201 |
+
|---|---|---|
|
| 202 |
+
| `sequence_cells` | $15,000 | 5 |
|
| 203 |
+
| `prepare_library` | $8,000 | 3 |
|
| 204 |
+
| `collect_sample` | $5,000 | 7 |
|
| 205 |
+
| `validate_marker` | $5,000 | 14 |
|
| 206 |
+
| `culture_cells` | $3,000 | 14 |
|
| 207 |
+
| `perturb_gene` | $2,000 | 3 |
|
| 208 |
+
| `perturb_compound` | $1,000 | 2 |
|
| 209 |
+
| `select_cohort` | $500 | 1 |
|
| 210 |
+
| `run_qc` | $100 | 0.5 |
|
| 211 |
+
| `integrate_batches` | $300 | 1 |
|
| 212 |
+
| `regulatory_network_inference` | $200 | 1 |
|
| 213 |
+
| `cluster_cells` | $150 | 0.5 |
|
| 214 |
+
| `differential_expression`, `trajectory_analysis`, `pathway_enrichment` | $100–200 | 0.5–1 |
|
| 215 |
+
| `filter_data`, `normalize_data`, `marker_selection` | $50–100 | 0.25–0.5 |
|
| 216 |
+
| `synthesize_conclusion`, `design_followup_experiment`, `request_subagent_review` | $0 | 0.25–0.5 |
|
| 217 |
+
|
| 218 |
## Typical successful pipeline
|
| 219 |
|
| 220 |
Most scenarios reward a sensible experiment order similar to:
|
|
|
|
| 231 |
`regulatory_network_inference`, `marker_selection`, `validate_marker`
|
| 232 |
9. `synthesize_conclusion`
|
| 233 |
|
| 234 |
+
The exact best sequence depends on the scenario:
|
| 235 |
|
| 236 |
- trajectory scenarios benefit from `trajectory_analysis` and regulatory inference
|
| 237 |
- biomarker scenarios benefit from DE, marker selection, and validation
|
| 238 |
- perturbation scenarios benefit from pathway-level interpretation
|
| 239 |
|
| 240 |
+
## Episode termination
|
| 241 |
+
|
| 242 |
+
An episode ends when one of the following happens:
|
| 243 |
+
|
| 244 |
+
- the agent chooses `synthesize_conclusion`
|
| 245 |
+
- resources are exhausted
|
| 246 |
+
- the environment reaches `MAX_STEPS` which is currently `30`
|
| 247 |
+
|
| 248 |
+
## Installation
|
| 249 |
+
|
| 250 |
+
Dependencies are managed with `uv`. The package requires Python ≥ 3.10.
|
| 251 |
+
|
| 252 |
+
```bash
|
| 253 |
+
# Core environment only
|
| 254 |
+
uv sync
|
| 255 |
+
|
| 256 |
+
# With dev/test tools
|
| 257 |
+
uv sync --extra dev
|
| 258 |
+
|
| 259 |
+
# With training dependencies (TRL, transformers, torch)
|
| 260 |
+
uv sync --extra train
|
| 261 |
+
|
| 262 |
+
# With bioinformatics extras (scanpy, biopython, gseapy)
|
| 263 |
+
uv sync --extra bio
|
| 264 |
+
```
|
| 265 |
+
|
| 266 |
+
Key dependency groups from `pyproject.toml`:
|
| 267 |
+
|
| 268 |
+
| Group | Key packages |
|
| 269 |
+
|---|---|
|
| 270 |
+
| core | `openenv-core[core]>=0.2.0`, `numpy`, `scipy`, `pydantic>=2.0` |
|
| 271 |
+
| train | `trl>=0.29`, `transformers>=5.3`, `accelerate`, `datasets`, `torch`, `matplotlib` |
|
| 272 |
+
| bio | `scanpy`, `biopython`, `gseapy` |
|
| 273 |
+
| dev | `pytest`, `pytest-cov` |
|
| 274 |
+
|
| 275 |
## Interfaces you can use
|
| 276 |
|
| 277 |
### 1. In-process environment
|
|
|
|
| 297 |
print(obs.reward)
|
| 298 |
```
|
| 299 |
|
| 300 |
+
The constructor accepts:
|
| 301 |
+
- `scenario_name: Optional[str]` — pin to a specific scenario; `None` picks randomly each episode
|
| 302 |
+
- `domain_randomise: bool = True` — perturbs scenario parameters for generalization
|
| 303 |
+
|
| 304 |
### 2. OpenEnv client/server mode
|
| 305 |
|
| 306 |
Use the FastAPI app when you want to serve the environment over HTTP and WebSocket:
|
|
|
|
| 310 |
uv run uvicorn server.app:app --reload
|
| 311 |
```
|
| 312 |
|
| 313 |
+
The server exposes five endpoints:
|
| 314 |
+
|
| 315 |
+
| Method | Path | Description |
|
| 316 |
+
|---|---|---|
|
| 317 |
+
| `POST` | `/reset` | Start a new episode |
|
| 318 |
+
| `POST` | `/step` | Execute one action |
|
| 319 |
+
| `GET` | `/state` | Current environment state |
|
| 320 |
+
| `GET` | `/schema` | Action/observation JSON schemas |
|
| 321 |
+
| `WS` | `/ws` | WebSocket for persistent sessions |
|
| 322 |
+
|
| 323 |
Then connect with the client:
|
| 324 |
|
| 325 |
```python
|
|
|
|
| 334 |
|
| 335 |
The environment class supports concurrent sessions, but the bundled server is currently configured with `max_concurrent_envs=1` in `server/app.py`.
|
| 336 |
|
| 337 |
+
### 3. Running a local agent
|
| 338 |
|
| 339 |
+
`run_agent.py` runs a single interactive episode using a local Hugging Face model:
|
| 340 |
|
| 341 |
+
```bash
|
| 342 |
+
uv run python run_agent.py
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 343 |
```
|
| 344 |
|
| 345 |
+
Configuration is via environment variables:
|
| 346 |
+
|
| 347 |
+
| Variable | Default | Description |
|
| 348 |
+
|---|---|---|
|
| 349 |
+
| `RUN_AGENT_USE_PIPELINE` | `0` | Use HF `pipeline()` path instead of direct generate |
|
| 350 |
+
| `RUN_AGENT_MAX_EPISODE_STEPS` | `12` | Maximum number of planning steps |
|
| 351 |
+
|
| 352 |
+
The local model defaults to `Qwen/Qwen3.5-0.8B` with sampling parameters `temperature=0.7`, `top_p=0.8`, `top_k=20`, `repetition_penalty=1.3`. The episode runs up to `MAX_EPISODE_STEPS = 12` steps. When action parsing fails, the script falls back to an observation-aware action that respects prerequisites.
|
| 353 |
+
|
| 354 |
+
PowerShell note: older PowerShell versions do not support `&&`. Run commands from the target directory directly, or use `;` as the command separator.
|
| 355 |
+
|
| 356 |
+
Windows runtime warnings:
|
| 357 |
+
- If you see HuggingFace symlink-cache warnings, functionality is unaffected; optionally set `HF_HUB_DISABLE_SYMLINKS_WARNING=1`.
|
| 358 |
+
- If you see flash attention / causal-conv fallback warnings, execution continues with a slower PyTorch path.
|
| 359 |
|
| 360 |
+
### 4. GRPO training
|
| 361 |
+
|
| 362 |
+
`training_script.py` follows the TRL GRPO pattern and uses OpenEnv rewards to score generated action JSON against this environment.
|
| 363 |
+
|
| 364 |
+
```bash
|
| 365 |
+
uv sync --extra train
|
| 366 |
+
uv run python training_script.py --dry-run
|
| 367 |
+
uv run python training_script.py --model-id Qwen/Qwen3.5-0.8B
|
| 368 |
+
```
|
| 369 |
|
| 370 |
+
Key arguments:
|
| 371 |
+
|
| 372 |
+
| Argument | Default | Description |
|
| 373 |
+
|---|---|---|
|
| 374 |
+
| `--model-id` | `Qwen/Qwen2.5-7B-Instruct` | Base model to fine-tune |
|
| 375 |
+
| `--output-dir` | `training/grpo-output` | Save directory |
|
| 376 |
+
| `--dataset-episodes` | `8` | Rollout episodes for prompt dataset |
|
| 377 |
+
| `--rollout-steps` | `6` | Steps per episode during collection |
|
| 378 |
+
| `--collection-policy` | `heuristic` | `random` or `heuristic` |
|
| 379 |
+
| `--reward-backend` | `local` | `local` (in-process) or `remote` (live server) |
|
| 380 |
+
| `--base-url` | `http://localhost:8000` | Server URL for remote backend |
|
| 381 |
+
| `--scenario-name` | all | Repeatable; restricts which scenarios are used |
|
| 382 |
+
| `--domain-randomise` | off | Enable domain randomisation |
|
| 383 |
+
| `--num-generations` | `4` | GRPO generations per prompt |
|
| 384 |
+
| `--max-completion-length` | `220` | Max tokens for model completions |
|
| 385 |
+
| `--max-prompt-length` | `768` | Max tokens for prompts |
|
| 386 |
+
| `--learning-rate` | `5e-6` | AdamW learning rate |
|
| 387 |
+
| `--dry-run` | off | Build data and test reward without training |
|
| 388 |
+
|
| 389 |
+
By default the reward function reconstructs prompt states locally so the prompt and reward stay aligned. Switch to a live server-backed reward loop with `--reward-backend remote --base-url http://localhost:8000`.
|
| 390 |
+
|
| 391 |
+
After training, the script saves plots to the output directory:
|
| 392 |
+
|
| 393 |
+
- `training_loss.png`
|
| 394 |
+
- `training_reward.png`
|
| 395 |
+
- `training_metric.png`
|
| 396 |
+
- `training_dashboard.png`
|
| 397 |
+
- `training_plot_manifest.json`
|
| 398 |
+
|
| 399 |
+
Use `--plot-metric-key <logged_key>` to force a specific extra metric on the third chart; otherwise the script auto-selects a useful logged metric such as KL or gradient norm.
|
| 400 |
+
|
| 401 |
+
### 5. Rollout collection
|
| 402 |
+
|
| 403 |
+
`training/rollout_collection.py` collects direct environment rollouts into trajectory files:
|
| 404 |
+
|
| 405 |
+
```bash
|
| 406 |
+
uv run python -m training.rollout_collection
|
| 407 |
+
```
|
| 408 |
+
|
| 409 |
+
This runs N episodes with a `random` or `heuristic` policy, saves JSON trajectories, and prints evaluation metrics.
|
| 410 |
+
|
| 411 |
+
### 6. Benchmark and scripted agents
|
| 412 |
|
| 413 |
- `training/literature_benchmark.py` runs paper-aligned action sequences and compares outcomes against curated expected findings
|
| 414 |
+
- `training/rollout_collection.py` collects direct environment rollouts into trajectory files
|
| 415 |
+
- `training_script.py` trains a GRPO policy with OpenEnv reward calls
|
| 416 |
- `run_agent.py` runs a local language model planner against the environment
|
| 417 |
- `training/trajectory.py` stores trajectories for offline RL, imitation learning, replay, and evaluation
|
| 418 |
- `training/evaluation.py` computes online, benchmark, expert-review, and fidelity-oriented metrics
|
| 419 |
|
| 420 |
+
## Training utilities
|
| 421 |
|
| 422 |
+
### `training/trajectory.py`
|
| 423 |
|
| 424 |
+
Provides `TrajectoryStep`, `Trajectory`, and `TrajectoryDataset` for episode serialization.
|
| 425 |
+
|
| 426 |
+
- `TrajectoryStep` stores `action`, `observation`, `reward`, `done`, `reward_breakdown`, and an optional `latent_snapshot`
|
| 427 |
+
- `Trajectory` accumulates steps with `add_step()`, computes `total_reward`, and exposes `save(path)` / `load(path)`
|
| 428 |
+
- `TrajectoryDataset` wraps a list of trajectories with `filter_successful()`, `save_dir()`, `load_dir()`, and `summary()` (n, success_rate, mean_reward, mean_length, max/min reward)
|
| 429 |
+
|
| 430 |
+
### `training/evaluation.py`
|
| 431 |
+
|
| 432 |
+
`EvaluationSuite` is a stateless class with four families of `@staticmethod` methods:
|
| 433 |
+
|
| 434 |
+
| Family | Method | Metrics |
|
| 435 |
+
|---|---|---|
|
| 436 |
+
| Online RL | `online_metrics(trajectories)` | `mean_return`, `median_return`, `std_return`, `mean_episode_length`, `success_rate` |
|
| 437 |
+
| Offline benchmark | `benchmark_metrics(dataset)` | `pipeline_validity_rate`, `ordering_score`, `action_diversity`, `mean_conclusion_confidence` |
|
| 438 |
+
| Expert review | `expert_review_metrics(...)` | Placeholder; averages provided scores |
|
| 439 |
+
| Simulator fidelity | `simulator_fidelity_metrics(sim, real)` | `reward_distribution_gap` |
|
| 440 |
+
|
| 441 |
+
### `training/literature_benchmark.py`
|
| 442 |
+
|
| 443 |
+
`run_paper_benchmark(problem_statement, scenario_name, domain_randomise)` runs a paper-aligned action pipeline and scores against `expected_findings` using keyword matching. Returns a `PaperBenchmarkResult` with `match_ratio`.
|
| 444 |
+
|
| 445 |
+
## Docker deployment
|
| 446 |
+
|
| 447 |
+
The server ships with a `server/Dockerfile`. It uses a multi-stage build based on `openenv-base`, installs dependencies via `uv`, and starts `uvicorn server.app:app` on port 8000.
|
| 448 |
+
|
| 449 |
+
```bash
|
| 450 |
+
docker build -f server/Dockerfile -t bio-experiment-env .
|
| 451 |
+
docker run -p 8000:8000 bio-experiment-env
|
| 452 |
+
```
|
| 453 |
+
|
| 454 |
+
The `openenv.yaml` file configures the deployment for the OpenEnv platform:
|
| 455 |
+
|
| 456 |
+
```yaml
|
| 457 |
+
spec_version: 1
|
| 458 |
+
name: hackathon
|
| 459 |
+
type: space
|
| 460 |
+
runtime: fastapi
|
| 461 |
+
app: server.app:app
|
| 462 |
+
port: 8000
|
| 463 |
+
```
|
| 464 |
|
| 465 |
## Why this is useful
|
| 466 |
|
|
|
|
| 479 |
- literature-grounded evaluation
|
| 480 |
- comparing structured policies against LLM-driven planners
|
| 481 |
|
| 482 |
+
## Project map
|
| 483 |
|
| 484 |
```text
|
| 485 |
.
|
| 486 |
+
├── client.py # OpenEnv HTTP/WebSocket client
|
| 487 |
├── models.py # Shared action / observation / task schemas
|
| 488 |
+
├── openenv.yaml # OpenEnv platform deployment config
|
| 489 |
+
├── pyproject.toml # Package metadata and dependency groups
|
| 490 |
+
├── run_agent.py # Single-episode interactive agent runner
|
| 491 |
├── server/
|
| 492 |
+
│ ├── app.py # FastAPI/OpenEnv server entry point
|
| 493 |
+
│ ├── Dockerfile # Multi-stage Docker build
|
| 494 |
│ ├── hackathon_environment.py # Main environment orchestration
|
| 495 |
+
│ ├── requirements.txt # Minimal server dependencies
|
| 496 |
+
│ ├── rewards/
|
| 497 |
+
│ │ └── reward.py # Decomposed reward model
|
| 498 |
+
│ ├── rules/
|
| 499 |
+
│ │ └── engine.py # Biological constraint checking
|
| 500 |
+
│ ├── simulator/
|
| 501 |
+
│ │ ├── latent_state.py # Hidden biological, technical, progress, resource state
|
| 502 |
+
│ │ ├── noise.py # Seeded stochastic noise model
|
| 503 |
+
│ │ ├── output_generator.py # Per-action simulated output generation
|
| 504 |
+
│ │ └── transition.py # State transition engine and ACTION_COSTS table
|
| 505 |
+
│ ├── subagents/ # Placeholder for future sub-agent integration
|
| 506 |
+
│ └── tasks/
|
| 507 |
+
│ ├── generator.py # TaskGenerator with domain randomisation
|
| 508 |
+
│ └── scenarios.py # SCENARIO_LIBRARY with 4 curated scenarios
|
| 509 |
├── training/
|
| 510 |
+
│ ├── evaluation.py # EvaluationSuite metrics
|
|
|
|
| 511 |
│ ├── literature_benchmark.py # Paper-backed benchmark flow
|
| 512 |
+
│ ├── rollout_collection.py # Direct rollout collection helper
|
| 513 |
+
│ └── trajectory.py # Trajectory serialization and dataset utilities
|
| 514 |
+
├── training_script.py # TRL GRPO training entry point
|
| 515 |
+
└── tests/
|
| 516 |
+
├── test_environment.py
|
| 517 |
+
├── test_literature_benchmark.py
|
| 518 |
+
├── test_models.py
|
| 519 |
+
├── test_rewards.py
|
| 520 |
+
├── test_rules.py
|
| 521 |
+
├── test_simulator.py
|
| 522 |
+
└── test_training_script.py
|
| 523 |
```
|
| 524 |
|
| 525 |
## Quick sanity check
|
| 526 |
|
|
|
|
|
|
|
| 527 |
```bash
|
| 528 |
uv run pytest tests/test_environment.py tests/test_literature_benchmark.py -q
|
| 529 |
```
|
|
|
|
| 535 |
- conclusion termination
|
| 536 |
- literature-backed scenario selection
|
| 537 |
- benchmark matching for curated expected findings
|
| 538 |
+
|
| 539 |
+
Run the full suite with coverage:
|
| 540 |
+
|
| 541 |
+
```bash
|
| 542 |
+
uv run pytest tests/ --cov -q
|
| 543 |
+
```
|
__init__.py
CHANGED
|
@@ -1,48 +1,75 @@
|
|
| 1 |
try: # pragma: no cover - package import path
|
| 2 |
from .client import BioExperimentEnv
|
| 3 |
from .models import (
|
|
|
|
| 4 |
ActionType,
|
|
|
|
| 5 |
ConclusionClaim,
|
| 6 |
ExpectedFinding,
|
| 7 |
ExperimentAction,
|
| 8 |
ExperimentObservation,
|
| 9 |
IntermediateOutput,
|
|
|
|
|
|
|
| 10 |
OutputType,
|
| 11 |
PaperReference,
|
| 12 |
PipelineStepRecord,
|
| 13 |
ResourceUsage,
|
| 14 |
SubagentType,
|
|
|
|
| 15 |
TaskSpec,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
)
|
| 17 |
except ImportError: # pragma: no cover - direct module import path
|
| 18 |
from client import BioExperimentEnv
|
| 19 |
from models import (
|
|
|
|
| 20 |
ActionType,
|
|
|
|
| 21 |
ConclusionClaim,
|
| 22 |
ExpectedFinding,
|
| 23 |
ExperimentAction,
|
| 24 |
ExperimentObservation,
|
| 25 |
IntermediateOutput,
|
|
|
|
|
|
|
| 26 |
OutputType,
|
| 27 |
PaperReference,
|
| 28 |
PipelineStepRecord,
|
| 29 |
ResourceUsage,
|
| 30 |
SubagentType,
|
|
|
|
| 31 |
TaskSpec,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
)
|
| 33 |
|
| 34 |
__all__ = [
|
|
|
|
| 35 |
"ActionType",
|
|
|
|
| 36 |
"BioExperimentEnv",
|
| 37 |
"ConclusionClaim",
|
| 38 |
"ExpectedFinding",
|
| 39 |
"ExperimentAction",
|
| 40 |
"ExperimentObservation",
|
| 41 |
"IntermediateOutput",
|
|
|
|
|
|
|
| 42 |
"OutputType",
|
| 43 |
"PaperReference",
|
| 44 |
"PipelineStepRecord",
|
| 45 |
"ResourceUsage",
|
| 46 |
"SubagentType",
|
|
|
|
| 47 |
"TaskSpec",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
]
|
|
|
|
| 1 |
try: # pragma: no cover - package import path
|
| 2 |
from .client import BioExperimentEnv
|
| 3 |
from .models import (
|
| 4 |
+
ASSAY_REGISTRY,
|
| 5 |
ActionType,
|
| 6 |
+
AssaySpec,
|
| 7 |
ConclusionClaim,
|
| 8 |
ExpectedFinding,
|
| 9 |
ExperimentAction,
|
| 10 |
ExperimentObservation,
|
| 11 |
IntermediateOutput,
|
| 12 |
+
MODALITY_REGISTRY,
|
| 13 |
+
ModalitySpec,
|
| 14 |
OutputType,
|
| 15 |
PaperReference,
|
| 16 |
PipelineStepRecord,
|
| 17 |
ResourceUsage,
|
| 18 |
SubagentType,
|
| 19 |
+
TOOL_REGISTRY,
|
| 20 |
TaskSpec,
|
| 21 |
+
ToolSpec,
|
| 22 |
+
assays_for_modality,
|
| 23 |
+
tools_by_category,
|
| 24 |
+
tools_for_modality,
|
| 25 |
)
|
| 26 |
except ImportError: # pragma: no cover - direct module import path
|
| 27 |
from client import BioExperimentEnv
|
| 28 |
from models import (
|
| 29 |
+
ASSAY_REGISTRY,
|
| 30 |
ActionType,
|
| 31 |
+
AssaySpec,
|
| 32 |
ConclusionClaim,
|
| 33 |
ExpectedFinding,
|
| 34 |
ExperimentAction,
|
| 35 |
ExperimentObservation,
|
| 36 |
IntermediateOutput,
|
| 37 |
+
MODALITY_REGISTRY,
|
| 38 |
+
ModalitySpec,
|
| 39 |
OutputType,
|
| 40 |
PaperReference,
|
| 41 |
PipelineStepRecord,
|
| 42 |
ResourceUsage,
|
| 43 |
SubagentType,
|
| 44 |
+
TOOL_REGISTRY,
|
| 45 |
TaskSpec,
|
| 46 |
+
ToolSpec,
|
| 47 |
+
assays_for_modality,
|
| 48 |
+
tools_by_category,
|
| 49 |
+
tools_for_modality,
|
| 50 |
)
|
| 51 |
|
| 52 |
__all__ = [
|
| 53 |
+
"ASSAY_REGISTRY",
|
| 54 |
"ActionType",
|
| 55 |
+
"AssaySpec",
|
| 56 |
"BioExperimentEnv",
|
| 57 |
"ConclusionClaim",
|
| 58 |
"ExpectedFinding",
|
| 59 |
"ExperimentAction",
|
| 60 |
"ExperimentObservation",
|
| 61 |
"IntermediateOutput",
|
| 62 |
+
"MODALITY_REGISTRY",
|
| 63 |
+
"ModalitySpec",
|
| 64 |
"OutputType",
|
| 65 |
"PaperReference",
|
| 66 |
"PipelineStepRecord",
|
| 67 |
"ResourceUsage",
|
| 68 |
"SubagentType",
|
| 69 |
+
"TOOL_REGISTRY",
|
| 70 |
"TaskSpec",
|
| 71 |
+
"ToolSpec",
|
| 72 |
+
"assays_for_modality",
|
| 73 |
+
"tools_by_category",
|
| 74 |
+
"tools_for_modality",
|
| 75 |
]
|
_dashboard_state.json
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"timestamp": 1772961529.009918,
|
| 3 |
+
"step": 4,
|
| 4 |
+
"episode_done": false,
|
| 5 |
+
"cumulative_reward": 1.8781281693998408,
|
| 6 |
+
"gen_time_s": 15.02,
|
| 7 |
+
"model_response_raw": "{\n\"action_type\": \"prepare_library\",\n\"method\": \"\",\n\"parameters\": {\n\"cutoff_value\": \"\"\n},\n\"Justification\": \"The error 'Cannot sequence without library preparation' indicates that while sample collection was completed successfully via dry-polymerization technology ('collectible'), mature cDNA libraries have yet to be constructed from those frozen tissues. Library prep requires reverse transcription followed by amplification/cleanup using polymerases/ligation enzymes. Since all other required conditions met but materials remain unprocessed, preparing high-throughput Illumina libraries first allows",
|
| 8 |
+
"model_thinking": "",
|
| 9 |
+
"thinking_enabled": false,
|
| 10 |
+
"task": {
|
| 11 |
+
"problem_statement": "Determine the effect of anti-TNF biologic therapy on cell states in synovium tissue affected by rheumatoid arthritis.",
|
| 12 |
+
"organism": "human",
|
| 13 |
+
"tissue": "synovium",
|
| 14 |
+
"modality": "scRNA-seq",
|
| 15 |
+
"conditions": [
|
| 16 |
+
"untreated_rheumatoid_arthritis",
|
| 17 |
+
"anti_TNF_treated"
|
| 18 |
+
],
|
| 19 |
+
"budget_limit": 80340.57829173621,
|
| 20 |
+
"time_limit_days": 140.67464277202689
|
| 21 |
+
},
|
| 22 |
+
"resources": {
|
| 23 |
+
"budget_used": 13500.0,
|
| 24 |
+
"budget_remaining": 66840.58,
|
| 25 |
+
"time_used_days": 11.0,
|
| 26 |
+
"time_remaining_days": 129.7,
|
| 27 |
+
"samples_consumed": 0,
|
| 28 |
+
"compute_hours_used": 0.0
|
| 29 |
+
},
|
| 30 |
+
"pipeline_history": [
|
| 31 |
+
{
|
| 32 |
+
"step_index": 1,
|
| 33 |
+
"action_type": "collect_sample",
|
| 34 |
+
"method": null,
|
| 35 |
+
"output_summary": "Collected 6 samples (quality=0.54)",
|
| 36 |
+
"success": true,
|
| 37 |
+
"quality_score": 0.541,
|
| 38 |
+
"resource_cost": 5000.0,
|
| 39 |
+
"time_cost_days": 7.0
|
| 40 |
+
},
|
| 41 |
+
{
|
| 42 |
+
"step_index": 2,
|
| 43 |
+
"action_type": "select_cohort",
|
| 44 |
+
"method": null,
|
| 45 |
+
"output_summary": "Selected cohort of 4 samples with criteria {}",
|
| 46 |
+
"success": true,
|
| 47 |
+
"quality_score": 1.0,
|
| 48 |
+
"resource_cost": 500.0,
|
| 49 |
+
"time_cost_days": 1.0
|
| 50 |
+
},
|
| 51 |
+
{
|
| 52 |
+
"step_index": 3,
|
| 53 |
+
"action_type": "sequence_cells",
|
| 54 |
+
"method": null,
|
| 55 |
+
"output_summary": "Action blocked: Cannot sequence without library preparation",
|
| 56 |
+
"success": false,
|
| 57 |
+
"quality_score": 1.0,
|
| 58 |
+
"resource_cost": 15000.0,
|
| 59 |
+
"time_cost_days": 5.0
|
| 60 |
+
},
|
| 61 |
+
{
|
| 62 |
+
"step_index": 4,
|
| 63 |
+
"action_type": "prepare_library",
|
| 64 |
+
"method": null,
|
| 65 |
+
"output_summary": "Library prepared (complexity=0.72)",
|
| 66 |
+
"success": true,
|
| 67 |
+
"quality_score": 0.718,
|
| 68 |
+
"resource_cost": 8000.0,
|
| 69 |
+
"time_cost_days": 3.0
|
| 70 |
+
}
|
| 71 |
+
],
|
| 72 |
+
"current_action": {
|
| 73 |
+
"action_type": "prepare_library",
|
| 74 |
+
"method": null,
|
| 75 |
+
"parameters": {
|
| 76 |
+
"cutoff_value": ""
|
| 77 |
+
},
|
| 78 |
+
"justification": "The error 'Cannot sequence without library preparation' indicates that while sample collection was completed successfully via dry-polymerization technology ('collectible'), mature cDNA libraries have yet to be constructed from those frozen tissues. Library prep requires reverse transcription followed by amplification/cleanup using polymerases/ligation enzymes. Since all other required conditions met but materials remain unprocessed, preparing high-throughput Illumina libraries first allows conversion of input DNA/RNA directly into sequencable formats.",
|
| 79 |
+
"confidence": 0.95
|
| 80 |
+
},
|
| 81 |
+
"latest_output": {
|
| 82 |
+
"summary": "Library prepared (complexity=0.72)",
|
| 83 |
+
"success": true,
|
| 84 |
+
"quality_score": 0.718,
|
| 85 |
+
"uncertainty": 0.0,
|
| 86 |
+
"warnings": [],
|
| 87 |
+
"data_preview": "{\"library_complexity\": 0.7177962559040082, \"method\": \"10x_chromium\"}"
|
| 88 |
+
},
|
| 89 |
+
"discovered_markers": [],
|
| 90 |
+
"candidate_mechanisms": [],
|
| 91 |
+
"rule_violations": [],
|
| 92 |
+
"uncertainty_summary": {
|
| 93 |
+
"avg_uncertainty": 0.0,
|
| 94 |
+
"avg_quality": 0.815
|
| 95 |
+
},
|
| 96 |
+
"reward_breakdown": {},
|
| 97 |
+
"latent": {
|
| 98 |
+
"cell_populations": [
|
| 99 |
+
{
|
| 100 |
+
"name": "fibroblast",
|
| 101 |
+
"proportion": 0.445,
|
| 102 |
+
"marker_genes": [
|
| 103 |
+
"COL1A1",
|
| 104 |
+
"FAP",
|
| 105 |
+
"THY1"
|
| 106 |
+
],
|
| 107 |
+
"state": "activated"
|
| 108 |
+
},
|
| 109 |
+
{
|
| 110 |
+
"name": "CD4_T_cell",
|
| 111 |
+
"proportion": 0.179,
|
| 112 |
+
"marker_genes": [
|
| 113 |
+
"CD3D",
|
| 114 |
+
"CD4",
|
| 115 |
+
"IL7R"
|
| 116 |
+
],
|
| 117 |
+
"state": "quiescent"
|
| 118 |
+
},
|
| 119 |
+
{
|
| 120 |
+
"name": "CD8_T_cell",
|
| 121 |
+
"proportion": 0.139,
|
| 122 |
+
"marker_genes": [
|
| 123 |
+
"CD3D",
|
| 124 |
+
"CD8A",
|
| 125 |
+
"GZMB"
|
| 126 |
+
],
|
| 127 |
+
"state": "activated"
|
| 128 |
+
},
|
| 129 |
+
{
|
| 130 |
+
"name": "B_cell",
|
| 131 |
+
"proportion": 0.142,
|
| 132 |
+
"marker_genes": [
|
| 133 |
+
"CD19",
|
| 134 |
+
"MS4A1",
|
| 135 |
+
"CD79A"
|
| 136 |
+
],
|
| 137 |
+
"state": "quiescent"
|
| 138 |
+
},
|
| 139 |
+
{
|
| 140 |
+
"name": "endothelial",
|
| 141 |
+
"proportion": 0.096,
|
| 142 |
+
"marker_genes": [
|
| 143 |
+
"PECAM1",
|
| 144 |
+
"VWF"
|
| 145 |
+
],
|
| 146 |
+
"state": "quiescent"
|
| 147 |
+
}
|
| 148 |
+
],
|
| 149 |
+
"true_markers": [
|
| 150 |
+
"TNF",
|
| 151 |
+
"IL6",
|
| 152 |
+
"MMP3",
|
| 153 |
+
"CXCL13"
|
| 154 |
+
],
|
| 155 |
+
"causal_mechanisms": [
|
| 156 |
+
"TNF/NF-kB-driven synovial inflammation",
|
| 157 |
+
"Th17-mediated cartilage destruction via MMPs"
|
| 158 |
+
],
|
| 159 |
+
"true_pathways": {
|
| 160 |
+
"JAK_STAT_signalling": 0.785,
|
| 161 |
+
"TNF_signalling": 0.723,
|
| 162 |
+
"Th17_differentiation": 0.633,
|
| 163 |
+
"NF_kB_signalling": 0.826,
|
| 164 |
+
"matrix_metalloproteinase_activity": 0.847
|
| 165 |
+
},
|
| 166 |
+
"true_de_genes_count": 9,
|
| 167 |
+
"true_regulatory_network_size": 16,
|
| 168 |
+
"confounders": {},
|
| 169 |
+
"n_true_cells": 15873,
|
| 170 |
+
"technical": {
|
| 171 |
+
"ambient_rna_fraction": 0.037873267501661645,
|
| 172 |
+
"doublet_rate": 0.03797665930677535,
|
| 173 |
+
"dropout_rate": 0.14738025069803395,
|
| 174 |
+
"sample_quality": 0.9068064354870293,
|
| 175 |
+
"library_complexity": 0.8,
|
| 176 |
+
"capture_efficiency": 0.6
|
| 177 |
+
},
|
| 178 |
+
"progress": {
|
| 179 |
+
"samples_collected": true,
|
| 180 |
+
"cohort_selected": true,
|
| 181 |
+
"cells_cultured": false,
|
| 182 |
+
"library_prepared": true,
|
| 183 |
+
"perturbation_applied": false,
|
| 184 |
+
"cells_sequenced": false,
|
| 185 |
+
"qc_performed": false,
|
| 186 |
+
"data_filtered": false,
|
| 187 |
+
"data_normalized": false,
|
| 188 |
+
"batches_integrated": false,
|
| 189 |
+
"cells_clustered": false,
|
| 190 |
+
"de_performed": false,
|
| 191 |
+
"trajectories_inferred": false,
|
| 192 |
+
"pathways_analyzed": false,
|
| 193 |
+
"networks_inferred": false,
|
| 194 |
+
"markers_discovered": false,
|
| 195 |
+
"markers_validated": false,
|
| 196 |
+
"followup_designed": false,
|
| 197 |
+
"subagent_review_requested": false,
|
| 198 |
+
"conclusion_reached": false,
|
| 199 |
+
"n_cells_sequenced": null,
|
| 200 |
+
"n_cells_after_filter": null,
|
| 201 |
+
"n_clusters_found": null,
|
| 202 |
+
"n_de_genes_found": null,
|
| 203 |
+
"n_markers_found": null
|
| 204 |
+
},
|
| 205 |
+
"hidden_failure_conditions": []
|
| 206 |
+
}
|
| 207 |
+
}
|
dashboard.html
ADDED
|
@@ -0,0 +1,522 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="utf-8" />
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
| 6 |
+
<title>Bio-Experiment Agent Dashboard</title>
|
| 7 |
+
<link rel="preconnect" href="https://fonts.googleapis.com" />
|
| 8 |
+
<link href="https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;600&family=DM+Sans:wght@400;500;700&display=swap" rel="stylesheet" />
|
| 9 |
+
<style>
|
| 10 |
+
:root {
|
| 11 |
+
--bg: #0c0e14;
|
| 12 |
+
--surface: #151822;
|
| 13 |
+
--surface2: #1c2030;
|
| 14 |
+
--border: #2a2f42;
|
| 15 |
+
--text: #e2e4ea;
|
| 16 |
+
--text-dim: #8b90a5;
|
| 17 |
+
--accent: #5ce0d8;
|
| 18 |
+
--accent2: #7c6cf0;
|
| 19 |
+
--green: #4ade80;
|
| 20 |
+
--red: #f87171;
|
| 21 |
+
--amber: #fbbf24;
|
| 22 |
+
--blue: #60a5fa;
|
| 23 |
+
--pink: #f472b6;
|
| 24 |
+
}
|
| 25 |
+
*, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }
|
| 26 |
+
body { background: var(--bg); color: var(--text); font-family: 'DM Sans', system-ui, sans-serif; line-height: 1.5; min-height: 100vh; }
|
| 27 |
+
.mono { font-family: 'JetBrains Mono', monospace; }
|
| 28 |
+
|
| 29 |
+
.header { display: flex; align-items: center; justify-content: space-between; padding: 14px 28px; border-bottom: 1px solid var(--border); background: var(--surface); }
|
| 30 |
+
.header h1 { font-size: 18px; font-weight: 700; letter-spacing: -.3px; }
|
| 31 |
+
.header h1 span { color: var(--accent); }
|
| 32 |
+
.header-right { display: flex; align-items: center; gap: 10px; }
|
| 33 |
+
.status-pill { font-size: 12px; padding: 4px 14px; border-radius: 20px; font-weight: 600; text-transform: uppercase; letter-spacing: .5px; }
|
| 34 |
+
.status-pill.live { background: rgba(76,222,128,.15); color: var(--green); }
|
| 35 |
+
.status-pill.done { background: rgba(248,113,113,.15); color: var(--red); }
|
| 36 |
+
.status-pill.waiting { background: rgba(139,144,165,.15); color: var(--text-dim); }
|
| 37 |
+
|
| 38 |
+
.btn { padding: 6px 16px; border-radius: 8px; border: 1px solid var(--border); background: var(--surface2); color: var(--text); font-size: 12px; font-weight: 600; cursor: pointer; transition: all .15s; }
|
| 39 |
+
.btn:hover { border-color: var(--accent); color: var(--accent); }
|
| 40 |
+
.btn.primary { background: rgba(92,224,216,.12); border-color: var(--accent); color: var(--accent); }
|
| 41 |
+
.btn.primary:hover { background: rgba(92,224,216,.25); }
|
| 42 |
+
.btn.danger { border-color: var(--red); color: var(--red); }
|
| 43 |
+
.btn.danger:hover { background: rgba(248,113,113,.12); }
|
| 44 |
+
|
| 45 |
+
.grid { display: grid; grid-template-columns: 1fr 1fr 1fr; gap: 16px; padding: 20px 28px; max-width: 1600px; }
|
| 46 |
+
@media (max-width: 1100px) { .grid { grid-template-columns: 1fr 1fr; } }
|
| 47 |
+
@media (max-width: 700px) { .grid { grid-template-columns: 1fr; } }
|
| 48 |
+
|
| 49 |
+
.card { background: var(--surface); border: 1px solid var(--border); border-radius: 12px; padding: 18px 20px; overflow: hidden; }
|
| 50 |
+
.card h2 { font-size: 11px; font-weight: 600; text-transform: uppercase; letter-spacing: 1px; color: var(--text-dim); margin-bottom: 12px; }
|
| 51 |
+
.card.span2 { grid-column: span 2; }
|
| 52 |
+
.card.span3 { grid-column: span 3; }
|
| 53 |
+
@media (max-width: 700px) { .card.span2, .card.span3 { grid-column: span 1; } }
|
| 54 |
+
|
| 55 |
+
.gauge-row { display: flex; gap: 14px; flex-wrap: wrap; }
|
| 56 |
+
.gauge { flex: 1; min-width: 130px; background: var(--surface2); border-radius: 10px; padding: 14px; }
|
| 57 |
+
.gauge-label { font-size: 11px; color: var(--text-dim); margin-bottom: 6px; text-transform: uppercase; letter-spacing: .5px; }
|
| 58 |
+
.gauge-value { font-size: 22px; font-weight: 700; }
|
| 59 |
+
.gauge-bar { height: 5px; border-radius: 3px; background: var(--border); margin-top: 8px; overflow: hidden; }
|
| 60 |
+
.gauge-bar-fill { height: 100%; border-radius: 3px; transition: width .6s ease; }
|
| 61 |
+
|
| 62 |
+
.timeline { position: relative; padding-left: 20px; }
|
| 63 |
+
.timeline::before { content: ''; position: absolute; left: 6px; top: 0; bottom: 0; width: 2px; background: var(--border); }
|
| 64 |
+
.timeline-item { position: relative; margin-bottom: 14px; padding-left: 18px; }
|
| 65 |
+
.timeline-item::before { content: ''; position: absolute; left: -18px; top: 6px; width: 10px; height: 10px; border-radius: 50%; border: 2px solid var(--accent); background: var(--bg); }
|
| 66 |
+
.timeline-item.fail::before { border-color: var(--red); }
|
| 67 |
+
.tl-action { font-weight: 600; font-size: 14px; }
|
| 68 |
+
.tl-meta { font-size: 12px; color: var(--text-dim); margin-top: 2px; }
|
| 69 |
+
|
| 70 |
+
.mini-table { width: 100%; font-size: 13px; border-collapse: collapse; }
|
| 71 |
+
.mini-table td { padding: 5px 8px; border-bottom: 1px solid var(--border); vertical-align: top; }
|
| 72 |
+
.mini-table td:first-child { color: var(--text-dim); white-space: nowrap; width: 40%; }
|
| 73 |
+
|
| 74 |
+
.tag-list { display: flex; flex-wrap: wrap; gap: 6px; }
|
| 75 |
+
.tag { font-size: 12px; padding: 3px 10px; border-radius: 6px; background: var(--surface2); border: 1px solid var(--border); font-family: 'JetBrains Mono', monospace; }
|
| 76 |
+
.tag.green { border-color: rgba(76,222,128,.3); color: var(--green); }
|
| 77 |
+
.tag.pink { border-color: rgba(244,114,182,.3); color: var(--pink); }
|
| 78 |
+
.tag.amber { border-color: rgba(251,191,36,.3); color: var(--amber); }
|
| 79 |
+
.tag.red { border-color: rgba(248,113,113,.3); color: var(--red); }
|
| 80 |
+
.tag.match { background: rgba(76,222,128,.15); }
|
| 81 |
+
.tag.miss { background: rgba(248,113,113,.08); }
|
| 82 |
+
|
| 83 |
+
.code-block { background: var(--surface2); border: 1px solid var(--border); border-radius: 8px; padding: 12px 14px; font-family: 'JetBrains Mono', monospace; font-size: 12px; white-space: pre-wrap; word-break: break-all; max-height: 220px; overflow-y: auto; color: var(--text-dim); line-height: 1.6; }
|
| 84 |
+
|
| 85 |
+
.progress-grid { display: grid; grid-template-columns: repeat(auto-fill, minmax(150px, 1fr)); gap: 6px; }
|
| 86 |
+
.progress-item { display: flex; align-items: center; gap: 6px; font-size: 12px; }
|
| 87 |
+
.dot { width: 8px; height: 8px; border-radius: 50%; flex-shrink: 0; background: var(--border); }
|
| 88 |
+
.dot.done { background: var(--green); }
|
| 89 |
+
|
| 90 |
+
.pop-bar-container { margin-bottom: 10px; }
|
| 91 |
+
.pop-bar-label { font-size: 12px; margin-bottom: 3px; display: flex; justify-content: space-between; }
|
| 92 |
+
.pop-bar { height: 14px; border-radius: 4px; background: var(--surface2); overflow: hidden; }
|
| 93 |
+
.pop-bar-fill { height: 100%; border-radius: 4px; }
|
| 94 |
+
|
| 95 |
+
#reward-chart { width: 100%; height: 120px; }
|
| 96 |
+
::-webkit-scrollbar { width: 6px; }
|
| 97 |
+
::-webkit-scrollbar-track { background: transparent; }
|
| 98 |
+
::-webkit-scrollbar-thumb { background: var(--border); border-radius: 3px; }
|
| 99 |
+
|
| 100 |
+
.conclusion-card { background: var(--surface2); border: 1px solid var(--border); border-radius: 10px; padding: 14px 16px; margin-bottom: 12px; }
|
| 101 |
+
.conclusion-card .cc-header { display: flex; justify-content: space-between; align-items: center; margin-bottom: 8px; }
|
| 102 |
+
.cc-type { font-size: 11px; padding: 2px 10px; border-radius: 4px; font-weight: 600; text-transform: uppercase; letter-spacing: .5px; }
|
| 103 |
+
.cc-type.causal { background: rgba(244,114,182,.15); color: var(--pink); }
|
| 104 |
+
.cc-type.correlative { background: rgba(96,165,250,.15); color: var(--blue); }
|
| 105 |
+
.cc-type.descriptive { background: rgba(139,144,165,.15); color: var(--text-dim); }
|
| 106 |
+
.cc-conf { font-family: 'JetBrains Mono', monospace; font-size: 13px; font-weight: 600; }
|
| 107 |
+
.cc-claim { font-size: 14px; margin-bottom: 8px; line-height: 1.5; }
|
| 108 |
+
.cc-section-label { font-size: 10px; color: var(--text-dim); text-transform: uppercase; letter-spacing: .5px; margin-bottom: 3px; margin-top: 8px; }
|
| 109 |
+
|
| 110 |
+
/* ── control panel ────────────────────────────── */
|
| 111 |
+
.control-panel { background: var(--surface); border: 1px solid var(--border); border-radius: 12px; margin: 20px 28px 0; padding: 18px 20px; }
|
| 112 |
+
.control-panel summary { cursor: pointer; font-size: 13px; font-weight: 600; color: var(--accent); }
|
| 113 |
+
.control-panel[open] summary { margin-bottom: 14px; }
|
| 114 |
+
.form-row { display: flex; gap: 12px; margin-bottom: 10px; flex-wrap: wrap; align-items: end; }
|
| 115 |
+
.form-field { display: flex; flex-direction: column; gap: 4px; }
|
| 116 |
+
.form-field label { font-size: 11px; color: var(--text-dim); text-transform: uppercase; letter-spacing: .5px; }
|
| 117 |
+
.form-field input, .form-field textarea, .form-field select {
|
| 118 |
+
background: var(--surface2); border: 1px solid var(--border); border-radius: 6px;
|
| 119 |
+
color: var(--text); padding: 7px 10px; font-size: 13px; font-family: inherit; outline: none;
|
| 120 |
+
}
|
| 121 |
+
.form-field input:focus, .form-field textarea:focus, .form-field select:focus { border-color: var(--accent); }
|
| 122 |
+
.form-field textarea { min-height: 60px; resize: vertical; }
|
| 123 |
+
|
| 124 |
+
/* ── final report ─────────────────────────────── */
|
| 125 |
+
.report-overlay { display: none; position: fixed; inset: 0; z-index: 100; background: rgba(12,14,20,.85); backdrop-filter: blur(6px); overflow-y: auto; padding: 40px 20px; }
|
| 126 |
+
.report-overlay.visible { display: flex; justify-content: center; align-items: flex-start; }
|
| 127 |
+
.report-card { background: var(--surface); border: 1px solid var(--border); border-radius: 16px; padding: 32px 36px; max-width: 900px; width: 100%; }
|
| 128 |
+
.report-card h2 { font-size: 22px; font-weight: 700; margin-bottom: 4px; color: var(--text); text-transform: none; letter-spacing: normal; }
|
| 129 |
+
.report-card .subtitle { font-size: 13px; color: var(--text-dim); margin-bottom: 20px; }
|
| 130 |
+
.report-section { margin-bottom: 20px; }
|
| 131 |
+
.report-section h3 { font-size: 12px; color: var(--accent); text-transform: uppercase; letter-spacing: 1px; margin-bottom: 8px; }
|
| 132 |
+
.comparison-row { display: flex; gap: 20px; margin-bottom: 16px; }
|
| 133 |
+
.comparison-col { flex: 1; }
|
| 134 |
+
.comparison-col h4 { font-size: 11px; color: var(--text-dim); text-transform: uppercase; margin-bottom: 6px; }
|
| 135 |
+
|
| 136 |
+
.pulse { animation: pulse 1.5s ease-in-out infinite; }
|
| 137 |
+
@keyframes pulse { 0%,100% { opacity: 1; } 50% { opacity: .5; } }
|
| 138 |
+
</style>
|
| 139 |
+
</head>
|
| 140 |
+
<body>
|
| 141 |
+
|
| 142 |
+
<div class="header">
|
| 143 |
+
<h1><span>BioExp</span> Agent Dashboard</h1>
|
| 144 |
+
<div class="header-right">
|
| 145 |
+
<span id="thinking-badge" class="mono" style="font-size:11px;color:var(--accent2);display:none">REASONING ON</span>
|
| 146 |
+
<span id="step-label" class="mono" style="font-size:13px;color:var(--text-dim)">Step 0</span>
|
| 147 |
+
<span id="status-pill" class="status-pill waiting">Waiting</span>
|
| 148 |
+
<button class="btn primary" onclick="doRestart()">Restart</button>
|
| 149 |
+
<button class="btn" onclick="showReport()">Report</button>
|
| 150 |
+
</div>
|
| 151 |
+
</div>
|
| 152 |
+
|
| 153 |
+
<!-- Control Panel (collapsible) -->
|
| 154 |
+
<details class="control-panel" id="control-panel">
|
| 155 |
+
<summary>New Task / Custom Ground Truth</summary>
|
| 156 |
+
<div class="form-row">
|
| 157 |
+
<div class="form-field" style="flex:2">
|
| 158 |
+
<label>Scenario (leave blank for random)</label>
|
| 159 |
+
<select id="f-scenario"><option value="">— random —</option></select>
|
| 160 |
+
</div>
|
| 161 |
+
<div class="form-field" style="flex:1">
|
| 162 |
+
<label>True Markers (comma-separated)</label>
|
| 163 |
+
<input id="f-markers" placeholder="e.g. MYH7, TNNT2, ACTA1" />
|
| 164 |
+
</div>
|
| 165 |
+
<div class="form-field" style="flex:1">
|
| 166 |
+
<label>Causal Mechanisms (comma-separated)</label>
|
| 167 |
+
<input id="f-mechanisms" placeholder="e.g. sarcomere dysfunction" />
|
| 168 |
+
</div>
|
| 169 |
+
</div>
|
| 170 |
+
<div class="form-row">
|
| 171 |
+
<div class="form-field" style="flex:2">
|
| 172 |
+
<label>True Pathways (name:score, comma-sep)</label>
|
| 173 |
+
<input id="f-pathways" placeholder="e.g. Wnt_signaling:0.8, MAPK:0.6" />
|
| 174 |
+
</div>
|
| 175 |
+
<div class="form-field">
|
| 176 |
+
<button class="btn primary" onclick="doCustomRun()">Run with Ground Truth</button>
|
| 177 |
+
</div>
|
| 178 |
+
</div>
|
| 179 |
+
</details>
|
| 180 |
+
|
| 181 |
+
<div class="grid">
|
| 182 |
+
<div class="card span2" id="card-task">
|
| 183 |
+
<h2>Task</h2>
|
| 184 |
+
<div id="task-statement" style="font-size:15px;font-weight:500;margin-bottom:8px;">—</div>
|
| 185 |
+
<div id="task-meta" style="font-size:13px;color:var(--text-dim)"></div>
|
| 186 |
+
</div>
|
| 187 |
+
|
| 188 |
+
<div class="card">
|
| 189 |
+
<h2>Reward</h2>
|
| 190 |
+
<div id="reward-value" class="mono" style="font-size:32px;font-weight:700;margin-bottom:6px;">0.000</div>
|
| 191 |
+
<canvas id="reward-chart"></canvas>
|
| 192 |
+
</div>
|
| 193 |
+
|
| 194 |
+
<div class="card span3"><h2>Resources</h2><div class="gauge-row" id="gauges"></div></div>
|
| 195 |
+
|
| 196 |
+
<div class="card span2" style="max-height:460px;overflow-y:auto">
|
| 197 |
+
<h2>Pipeline History <span style="color:var(--accent);font-size:10px">OBSERVABLE</span></h2>
|
| 198 |
+
<div class="timeline" id="timeline"></div>
|
| 199 |
+
</div>
|
| 200 |
+
|
| 201 |
+
<div class="card">
|
| 202 |
+
<h2>Current Action</h2>
|
| 203 |
+
<table class="mini-table" id="action-table"><tbody></tbody></table>
|
| 204 |
+
<h2 style="margin-top:14px" id="thinking-header" style="display:none">Model Reasoning</h2>
|
| 205 |
+
<div class="code-block" id="model-thinking" style="display:none;border-color:rgba(124,108,240,.2);max-height:140px;margin-bottom:10px">—</div>
|
| 206 |
+
<h2 style="margin-top:10px">Model Raw Output</h2>
|
| 207 |
+
<div class="code-block" id="model-response">—</div>
|
| 208 |
+
</div>
|
| 209 |
+
|
| 210 |
+
<div class="card">
|
| 211 |
+
<h2>Discovered Markers <span style="color:var(--accent);font-size:10px">OBSERVABLE</span></h2>
|
| 212 |
+
<div class="tag-list" id="markers-list"><span class="tag" style="color:var(--text-dim)">none yet</span></div>
|
| 213 |
+
<h2 style="margin-top:14px">Candidate Mechanisms</h2>
|
| 214 |
+
<div class="tag-list" id="mechanisms-list"><span class="tag" style="color:var(--text-dim)">none yet</span></div>
|
| 215 |
+
</div>
|
| 216 |
+
|
| 217 |
+
<div class="card">
|
| 218 |
+
<h2>Rule Violations</h2>
|
| 219 |
+
<div id="violations" style="font-size:13px;color:var(--text-dim)">None</div>
|
| 220 |
+
<h2 style="margin-top:14px">Uncertainty Summary</h2>
|
| 221 |
+
<table class="mini-table" id="uncertainty-table"><tbody></tbody></table>
|
| 222 |
+
<h2 style="margin-top:14px">Reward Breakdown</h2>
|
| 223 |
+
<table class="mini-table" id="reward-breakdown-table"><tbody></tbody></table>
|
| 224 |
+
</div>
|
| 225 |
+
|
| 226 |
+
<div class="card">
|
| 227 |
+
<h2>Latest Output</h2>
|
| 228 |
+
<table class="mini-table" id="output-table"><tbody></tbody></table>
|
| 229 |
+
<div class="code-block" id="output-data" style="margin-top:10px;max-height:140px">—</div>
|
| 230 |
+
</div>
|
| 231 |
+
|
| 232 |
+
<div class="card span3" id="card-conclusions" style="display:none;border-color:rgba(76,222,128,.25)">
|
| 233 |
+
<h2 style="color:var(--green)">Synthesized Conclusions</h2>
|
| 234 |
+
<div id="conclusions-list"></div>
|
| 235 |
+
</div>
|
| 236 |
+
|
| 237 |
+
<!-- Ground Truth Comparison (shown when episode done + has conclusions) -->
|
| 238 |
+
<div class="card span3" id="card-gt-comparison" style="display:none;border-color:rgba(251,191,36,.25)">
|
| 239 |
+
<h2 style="color:var(--amber)">Ground Truth Comparison</h2>
|
| 240 |
+
<div class="comparison-row">
|
| 241 |
+
<div class="comparison-col">
|
| 242 |
+
<h4>Agent's Markers</h4>
|
| 243 |
+
<div class="tag-list" id="gt-agent-markers"></div>
|
| 244 |
+
</div>
|
| 245 |
+
<div class="comparison-col">
|
| 246 |
+
<h4>True Markers</h4>
|
| 247 |
+
<div class="tag-list" id="gt-true-markers"></div>
|
| 248 |
+
</div>
|
| 249 |
+
</div>
|
| 250 |
+
<div class="comparison-row">
|
| 251 |
+
<div class="comparison-col">
|
| 252 |
+
<h4>Agent's Mechanisms</h4>
|
| 253 |
+
<div class="tag-list" id="gt-agent-mechs"></div>
|
| 254 |
+
</div>
|
| 255 |
+
<div class="comparison-col">
|
| 256 |
+
<h4>True Mechanisms</h4>
|
| 257 |
+
<div class="tag-list" id="gt-true-mechs"></div>
|
| 258 |
+
</div>
|
| 259 |
+
</div>
|
| 260 |
+
<div id="gt-score" style="margin-top:8px;font-size:14px;font-weight:600"></div>
|
| 261 |
+
</div>
|
| 262 |
+
|
| 263 |
+
<div class="card" style="border-color:rgba(124,108,240,.25)">
|
| 264 |
+
<h2 style="color:var(--accent2)">Cell Populations <span style="font-size:10px">HIDDEN</span></h2>
|
| 265 |
+
<div id="populations"></div>
|
| 266 |
+
</div>
|
| 267 |
+
<div class="card" style="border-color:rgba(124,108,240,.25)">
|
| 268 |
+
<h2 style="color:var(--accent2)">Ground Truth <span style="font-size:10px">HIDDEN</span></h2>
|
| 269 |
+
<div style="margin-bottom:8px"><span style="font-size:11px;color:var(--text-dim);text-transform:uppercase">True Markers</span><div class="tag-list" id="true-markers" style="margin-top:4px"></div></div>
|
| 270 |
+
<div style="margin-bottom:8px"><span style="font-size:11px;color:var(--text-dim);text-transform:uppercase">Causal Mechanisms</span><div class="tag-list" id="true-mechanisms" style="margin-top:4px"></div></div>
|
| 271 |
+
<div><span style="font-size:11px;color:var(--text-dim);text-transform:uppercase">Top Pathways</span><table class="mini-table" id="pathways-table" style="margin-top:4px"><tbody></tbody></table></div>
|
| 272 |
+
</div>
|
| 273 |
+
<div class="card" style="border-color:rgba(124,108,240,.25)">
|
| 274 |
+
<h2 style="color:var(--accent2)">Technical State <span style="font-size:10px">HIDDEN</span></h2>
|
| 275 |
+
<table class="mini-table" id="technical-table"><tbody></tbody></table>
|
| 276 |
+
<h2 style="margin-top:14px;color:var(--accent2)">Failure Conditions <span style="font-size:10px">HIDDEN</span></h2>
|
| 277 |
+
<div class="tag-list" id="failure-conditions"></div>
|
| 278 |
+
</div>
|
| 279 |
+
<div class="card span3" style="border-color:rgba(124,108,240,.25)">
|
| 280 |
+
<h2 style="color:var(--accent2)">Experiment Progress <span style="font-size:10px">HIDDEN</span></h2>
|
| 281 |
+
<div class="progress-grid" id="progress-grid"></div>
|
| 282 |
+
</div>
|
| 283 |
+
</div>
|
| 284 |
+
|
| 285 |
+
<!-- Final Report Overlay -->
|
| 286 |
+
<div class="report-overlay" id="report-overlay" onclick="if(event.target===this)hideReport()">
|
| 287 |
+
<div class="report-card" id="report-content"></div>
|
| 288 |
+
</div>
|
| 289 |
+
|
| 290 |
+
<script>
|
| 291 |
+
const POLL_MS = 1200;
|
| 292 |
+
const POP_COLORS = ['#5ce0d8','#7c6cf0','#f472b6','#60a5fa','#fbbf24','#4ade80','#f87171','#c084fc','#fb923c','#38bdf8'];
|
| 293 |
+
let rewardHistory = [];
|
| 294 |
+
let lastTimestamp = 0;
|
| 295 |
+
let latestState = null;
|
| 296 |
+
|
| 297 |
+
function $(id) { return document.getElementById(id); }
|
| 298 |
+
function setHTML(id, html) { $(id).innerHTML = html; }
|
| 299 |
+
function tagsHTML(arr, cls) {
|
| 300 |
+
if (!arr || !arr.length) return '<span class="tag" style="color:var(--text-dim)">—</span>';
|
| 301 |
+
return arr.map(t => `<span class="tag ${cls||''}">${esc(t)}</span>`).join('');
|
| 302 |
+
}
|
| 303 |
+
function esc(s) { if (s == null) return '—'; const d = document.createElement('div'); d.textContent = String(s); return d.innerHTML; }
|
| 304 |
+
function pct(used, total) { if (!total) return 0; return Math.min(100, Math.max(0, (used / total) * 100)); }
|
| 305 |
+
function gaugeColor(p) { return p < 50 ? 'var(--green)' : p < 80 ? 'var(--amber)' : 'var(--red)'; }
|
| 306 |
+
function fmt(n) { if (n == null) return '0'; return Number(n).toLocaleString('en-US', { maximumFractionDigits: 0 }); }
|
| 307 |
+
function gauge(label, value, pctVal, inv) {
|
| 308 |
+
let bar = '';
|
| 309 |
+
if (pctVal != null) { const c = inv ? gaugeColor(100-pctVal) : gaugeColor(pctVal); bar = `<div class="gauge-bar"><div class="gauge-bar-fill" style="width:${pctVal.toFixed(1)}%;background:${c}"></div></div>`; }
|
| 310 |
+
return `<div class="gauge"><div class="gauge-label">${label}</div><div class="gauge-value mono">${value}</div>${bar}</div>`;
|
| 311 |
+
}
|
| 312 |
+
function miniRows(obj) { return Object.entries(obj).map(([k,v]) => `<tr><td>${esc(k)}</td><td>${esc(v)}</td></tr>`).join(''); }
|
| 313 |
+
|
| 314 |
+
function drawRewardChart(canvas, data) {
|
| 315 |
+
const ctx = canvas.getContext('2d'); const W = canvas.width = canvas.offsetWidth * 2; const H = canvas.height = canvas.offsetHeight * 2;
|
| 316 |
+
ctx.clearRect(0, 0, W, H); if (data.length < 2) return;
|
| 317 |
+
const vals = data.map(d => d.v); const minV = Math.min(0, ...vals); const maxV = Math.max(0.1, ...vals); const range = maxV - minV || 1; const pad = 8;
|
| 318 |
+
ctx.strokeStyle = 'rgba(92,224,216,.4)'; ctx.lineWidth = 2; ctx.beginPath();
|
| 319 |
+
const yZ = H - pad - ((0 - minV) / range) * (H - 2*pad); ctx.moveTo(pad, yZ); ctx.lineTo(W-pad, yZ); ctx.stroke();
|
| 320 |
+
ctx.strokeStyle = '#5ce0d8'; ctx.lineWidth = 3; ctx.beginPath();
|
| 321 |
+
data.forEach((d,i) => { const x = pad+(i/(data.length-1))*(W-2*pad); const y = H-pad-((d.v-minV)/range)*(H-2*pad); i===0?ctx.moveTo(x,y):ctx.lineTo(x,y); }); ctx.stroke();
|
| 322 |
+
data.forEach((d,i) => { const x = pad+(i/(data.length-1))*(W-2*pad); const y = H-pad-((d.v-minV)/range)*(H-2*pad); ctx.fillStyle = d.v>=0?'#4ade80':'#f87171'; ctx.beginPath(); ctx.arc(x,y,5,0,Math.PI*2); ctx.fill(); });
|
| 323 |
+
}
|
| 324 |
+
|
| 325 |
+
function comparedTags(agentArr, trueArr, cls) {
|
| 326 |
+
if (!agentArr || !agentArr.length) return '<span class="tag" style="color:var(--text-dim)">—</span>';
|
| 327 |
+
const trueSet = new Set((trueArr||[]).map(t => t.toUpperCase()));
|
| 328 |
+
return agentArr.map(t => {
|
| 329 |
+
const hit = trueSet.has(t.toUpperCase());
|
| 330 |
+
return `<span class="tag ${cls} ${hit?'match':'miss'}">${esc(t)} ${hit?'✓':'✗'}</span>`;
|
| 331 |
+
}).join('');
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
// ── API actions ──
|
| 335 |
+
async function doRestart() {
|
| 336 |
+
rewardHistory = []; lastTimestamp = 0;
|
| 337 |
+
await fetch('/api/restart', { method: 'POST' });
|
| 338 |
+
}
|
| 339 |
+
|
| 340 |
+
async function doCustomRun() {
|
| 341 |
+
const scenario = $('f-scenario').value || undefined;
|
| 342 |
+
const markers = $('f-markers').value.split(',').map(s=>s.trim()).filter(Boolean);
|
| 343 |
+
const mechs = $('f-mechanisms').value.split(',').map(s=>s.trim()).filter(Boolean);
|
| 344 |
+
const pwRaw = $('f-pathways').value.split(',').map(s=>s.trim()).filter(Boolean);
|
| 345 |
+
const pathways = {};
|
| 346 |
+
pwRaw.forEach(p => { const [k,v] = p.split(':'); if (k && v) pathways[k.trim()] = parseFloat(v); });
|
| 347 |
+
const gt = {};
|
| 348 |
+
if (markers.length) gt.true_markers = markers;
|
| 349 |
+
if (mechs.length) gt.causal_mechanisms = mechs;
|
| 350 |
+
if (Object.keys(pathways).length) gt.true_pathways = pathways;
|
| 351 |
+
rewardHistory = []; lastTimestamp = 0;
|
| 352 |
+
await fetch('/api/run', { method: 'POST', headers: {'Content-Type':'application/json'}, body: JSON.stringify({ scenario_name: scenario, ground_truth: Object.keys(gt).length ? gt : undefined }) });
|
| 353 |
+
}
|
| 354 |
+
|
| 355 |
+
function showReport() {
|
| 356 |
+
const s = latestState; if (!s) return;
|
| 357 |
+
const rc = $('report-content');
|
| 358 |
+
const t = s.task || {};
|
| 359 |
+
const lat = s.latent || {};
|
| 360 |
+
const conc = s.conclusions || [];
|
| 361 |
+
const trueM = lat.true_markers || [];
|
| 362 |
+
const trueMech = lat.causal_mechanisms || [];
|
| 363 |
+
const agentM = s.discovered_markers || [];
|
| 364 |
+
const markerHits = agentM.filter(m => trueM.some(t => t.toUpperCase() === m.toUpperCase()));
|
| 365 |
+
const r = s.resources || {};
|
| 366 |
+
|
| 367 |
+
let html = `<h2>Experiment Report</h2>
|
| 368 |
+
<div class="subtitle">${esc(t.problem_statement)}</div>
|
| 369 |
+
<div class="report-section"><h3>Summary</h3>
|
| 370 |
+
<table class="mini-table"><tbody>
|
| 371 |
+
<tr><td>Status</td><td>${s.episode_done ? 'Completed' : 'In Progress'}</td></tr>
|
| 372 |
+
<tr><td>Steps</td><td>${s.step}</td></tr>
|
| 373 |
+
<tr><td>Cumulative Reward</td><td style="color:${(s.cumulative_reward||0)>=0?'var(--green)':'var(--red)'}">${((s.cumulative_reward||0)>=0?'+':'')}${(s.cumulative_reward||0).toFixed(3)}</td></tr>
|
| 374 |
+
<tr><td>Budget Used</td><td>$${fmt(r.budget_used)} / $${fmt((r.budget_used||0)+(r.budget_remaining||0))}</td></tr>
|
| 375 |
+
<tr><td>Time Used</td><td>${(r.time_used_days||0).toFixed(0)}d / ${((r.time_used_days||0)+(r.time_remaining_days||0)).toFixed(0)}d</td></tr>
|
| 376 |
+
<tr><td>Markers Found</td><td>${agentM.length} (${markerHits.length} match ground truth)</td></tr>
|
| 377 |
+
</tbody></table>
|
| 378 |
+
</div>`;
|
| 379 |
+
|
| 380 |
+
if (conc.length) {
|
| 381 |
+
html += `<div class="report-section"><h3>Conclusions</h3>`;
|
| 382 |
+
conc.forEach(c => {
|
| 383 |
+
html += `<div class="conclusion-card"><div class="cc-header"><span class="cc-type ${(c.claim_type||'').toLowerCase()}">${esc(c.claim_type)}</span><span class="cc-conf" style="color:${c.confidence>=.7?'var(--green)':c.confidence>=.4?'var(--amber)':'var(--red)'}">${((c.confidence||0)*100).toFixed(0)}%</span></div>`;
|
| 384 |
+
if (c.claim) html += `<div class="cc-claim">${esc(c.claim)}</div>`;
|
| 385 |
+
if (c.top_markers?.length) html += `<div class="cc-section-label">Top Markers</div><div class="tag-list">${c.top_markers.map(m=>`<span class="tag green">${esc(m)}</span>`).join('')}</div>`;
|
| 386 |
+
if (c.causal_mechanisms?.length) html += `<div class="cc-section-label">Causal Mechanisms</div><div class="tag-list">${c.causal_mechanisms.map(m=>`<span class="tag pink">${esc(m)}</span>`).join('')}</div>`;
|
| 387 |
+
if (c.predicted_pathways && Object.keys(c.predicted_pathways).length) html += `<div class="cc-section-label">Predicted Pathways</div><table class="mini-table"><tbody>${Object.entries(c.predicted_pathways).map(([k,v])=>`<tr><td>${esc(k)}</td><td>${Number(v).toFixed(3)}</td></tr>`).join('')}</tbody></table>`;
|
| 388 |
+
html += `</div>`;
|
| 389 |
+
});
|
| 390 |
+
html += `</div>`;
|
| 391 |
+
}
|
| 392 |
+
|
| 393 |
+
html += `<div class="report-section"><h3>Ground Truth Comparison</h3>
|
| 394 |
+
<div class="comparison-row"><div class="comparison-col"><h4>Agent's Markers</h4><div class="tag-list">${comparedTags(agentM, trueM, 'green')}</div></div>
|
| 395 |
+
<div class="comparison-col"><h4>True Markers</h4><div class="tag-list">${tagsHTML(trueM,'green')}</div></div></div>
|
| 396 |
+
<div class="comparison-row"><div class="comparison-col"><h4>Agent's Mechanisms</h4><div class="tag-list">${comparedTags(s.candidate_mechanisms, trueMech, 'pink')}</div></div>
|
| 397 |
+
<div class="comparison-col"><h4>True Mechanisms</h4><div class="tag-list">${tagsHTML(trueMech,'pink')}</div></div></div>
|
| 398 |
+
</div>`;
|
| 399 |
+
|
| 400 |
+
const hist = s.pipeline_history || [];
|
| 401 |
+
if (hist.length) {
|
| 402 |
+
html += `<div class="report-section"><h3>Pipeline Steps</h3><table class="mini-table"><tbody>`;
|
| 403 |
+
hist.forEach(h => { html += `<tr><td>${h.success?'✓':'✗'} ${esc(h.action_type)}</td><td>${esc(h.output_summary)} · q=${h.quality_score}</td></tr>`; });
|
| 404 |
+
html += `</tbody></table></div>`;
|
| 405 |
+
}
|
| 406 |
+
|
| 407 |
+
html += `<div style="margin-top:20px;text-align:right"><button class="btn" onclick="hideReport()">Close</button> <button class="btn primary" onclick="doRestart();hideReport()">New Run</button></div>`;
|
| 408 |
+
rc.innerHTML = html;
|
| 409 |
+
$('report-overlay').classList.add('visible');
|
| 410 |
+
}
|
| 411 |
+
function hideReport() { $('report-overlay').classList.remove('visible'); }
|
| 412 |
+
|
| 413 |
+
function renderState(s) {
|
| 414 |
+
latestState = s;
|
| 415 |
+
if (s.error) { $('status-pill').className='status-pill waiting'; $('status-pill').textContent='Waiting'; $('task-statement').textContent=s.error; return; }
|
| 416 |
+
|
| 417 |
+
const pill = $('status-pill');
|
| 418 |
+
if (s.episode_done) { pill.className='status-pill done'; pill.textContent='Done'; } else { pill.className='status-pill live'; pill.textContent='Live'; }
|
| 419 |
+
$('step-label').textContent = `Step ${s.step}`;
|
| 420 |
+
|
| 421 |
+
if (s.thinking_enabled) { $('thinking-badge').style.display = ''; } else { $('thinking-badge').style.display = 'none'; }
|
| 422 |
+
|
| 423 |
+
const t = s.task || {};
|
| 424 |
+
$('task-statement').textContent = t.problem_statement || '—';
|
| 425 |
+
$('task-meta').innerHTML = [t.organism, t.tissue, t.modality, t.conditions ? t.conditions.join(' vs ') : null].filter(Boolean).map(v => `<span class="tag">${esc(v)}</span>`).join(' ');
|
| 426 |
+
|
| 427 |
+
const cum = s.cumulative_reward || 0;
|
| 428 |
+
$('reward-value').textContent = (cum >= 0 ? '+' : '') + cum.toFixed(3);
|
| 429 |
+
$('reward-value').style.color = cum >= 0 ? 'var(--green)' : 'var(--red)';
|
| 430 |
+
if (s.timestamp !== lastTimestamp && s.step > 0) { rewardHistory.push({ step: s.step, v: cum }); lastTimestamp = s.timestamp; }
|
| 431 |
+
drawRewardChart($('reward-chart'), rewardHistory);
|
| 432 |
+
|
| 433 |
+
const r = s.resources || {};
|
| 434 |
+
const bT = (r.budget_used||0)+(r.budget_remaining||0), tT = (r.time_used_days||0)+(r.time_remaining_days||0);
|
| 435 |
+
const bP = pct(r.budget_used, bT), tP = pct(r.time_used_days, tT);
|
| 436 |
+
$('gauges').innerHTML = [gauge('Budget Used',`$${fmt(r.budget_used)}`,bP), gauge('Budget Left',`$${fmt(r.budget_remaining)}`,100-bP,true), gauge('Time Used',`${(r.time_used_days||0).toFixed(0)}d`,tP), gauge('Time Left',`${(r.time_remaining_days||0).toFixed(0)}d`,100-tP,true), gauge('Samples',String(r.samples_consumed||0),null), gauge('Compute',`${(r.compute_hours_used||0).toFixed(1)}h`,null)].join('');
|
| 437 |
+
|
| 438 |
+
const hist = s.pipeline_history || [];
|
| 439 |
+
$('timeline').innerHTML = hist.length ? hist.map(h => `<div class="timeline-item ${!h.success?'fail':''}"><div class="tl-action">${esc(h.action_type)}${h.method?` <span style="color:var(--text-dim);font-weight:400;font-size:12px">${esc(h.method)}</span>`:''}</div><div class="tl-meta">${h.success?'✓':'✗'} ${esc(h.output_summary)} · q=${h.quality_score} · $${fmt(h.resource_cost)} · ${h.time_cost_days}d</div></div>`).join('') : '<div style="color:var(--text-dim);font-size:13px">No steps yet</div>';
|
| 440 |
+
|
| 441 |
+
const a = s.current_action;
|
| 442 |
+
if (a) { $('action-table').querySelector('tbody').innerHTML = miniRows({'Type':a.action_type,'Method':a.method||'—','Confidence':a.confidence?.toFixed(2),'Justification':a.justification||'—','Fallback?':s.used_fallback?'YES':'no'}); }
|
| 443 |
+
|
| 444 |
+
if (s.model_thinking) { $('model-thinking').style.display=''; $('model-thinking').textContent = s.model_thinking; } else { $('model-thinking').style.display='none'; }
|
| 445 |
+
$('model-response').textContent = s.model_response_raw || '—';
|
| 446 |
+
|
| 447 |
+
setHTML('markers-list', tagsHTML(s.discovered_markers, 'green'));
|
| 448 |
+
setHTML('mechanisms-list', tagsHTML(s.candidate_mechanisms, 'pink'));
|
| 449 |
+
|
| 450 |
+
const v = s.rule_violations || [];
|
| 451 |
+
$('violations').innerHTML = v.length ? v.map(x=>`<div class="tag red" style="margin-bottom:4px">${esc(x)}</div>`).join('') : '<span style="color:var(--text-dim)">None</span>';
|
| 452 |
+
$('uncertainty-table').querySelector('tbody').innerHTML = miniRows(s.uncertainty_summary || {});
|
| 453 |
+
const rb = s.reward_breakdown || {};
|
| 454 |
+
$('reward-breakdown-table').querySelector('tbody').innerHTML = miniRows(Object.fromEntries(Object.entries(rb).map(([k,v])=>[k,(v>=0?'+':'')+v.toFixed(4)])));
|
| 455 |
+
|
| 456 |
+
const lo = s.latest_output;
|
| 457 |
+
if (lo) { $('output-table').querySelector('tbody').innerHTML = miniRows({'Summary':lo.summary,'Success':lo.success?'✓':'✗','Quality':lo.quality_score,'Uncertainty':lo.uncertainty,'Warnings':(lo.warnings||[]).join('; ')||'—'}); $('output-data').textContent = lo.data_preview||'—'; }
|
| 458 |
+
|
| 459 |
+
const conc = s.conclusions || [];
|
| 460 |
+
if (conc.length) {
|
| 461 |
+
$('card-conclusions').style.display = '';
|
| 462 |
+
$('conclusions-list').innerHTML = conc.map(c => {
|
| 463 |
+
const confColor = c.confidence>=.7?'var(--green)':c.confidence>=.4?'var(--amber)':'var(--red)';
|
| 464 |
+
let h = `<div class="conclusion-card"><div class="cc-header"><span class="cc-type ${(c.claim_type||'').toLowerCase()}">${esc(c.claim_type||'unknown')}</span><span class="cc-conf" style="color:${confColor}">${((c.confidence||0)*100).toFixed(0)}%</span></div>`;
|
| 465 |
+
if (c.claim) h += `<div class="cc-claim">${esc(c.claim)}</div>`;
|
| 466 |
+
if (c.top_markers?.length) h += `<div class="cc-section-label">Top Markers</div><div class="tag-list">${c.top_markers.map(m=>`<span class="tag green">${esc(m)}</span>`).join('')}</div>`;
|
| 467 |
+
if (c.causal_mechanisms?.length) h += `<div class="cc-section-label">Causal Mechanisms</div><div class="tag-list">${c.causal_mechanisms.map(m=>`<span class="tag pink">${esc(m)}</span>`).join('')}</div>`;
|
| 468 |
+
if (c.predicted_pathways && Object.keys(c.predicted_pathways).length) h += `<div class="cc-section-label">Predicted Pathways</div><table class="mini-table"><tbody>${Object.entries(c.predicted_pathways).map(([k,v])=>`<tr><td>${esc(k)}</td><td>${Number(v).toFixed(3)}</td></tr>`).join('')}</tbody></table>`;
|
| 469 |
+
return h + '</div>';
|
| 470 |
+
}).join('');
|
| 471 |
+
} else { $('card-conclusions').style.display = 'none'; }
|
| 472 |
+
|
| 473 |
+
// Ground truth comparison (visible when done or has conclusions)
|
| 474 |
+
const lat = s.latent;
|
| 475 |
+
if ((s.episode_done || conc.length) && lat) {
|
| 476 |
+
$('card-gt-comparison').style.display = '';
|
| 477 |
+
setHTML('gt-agent-markers', comparedTags(s.discovered_markers, lat.true_markers, 'green'));
|
| 478 |
+
setHTML('gt-true-markers', tagsHTML(lat.true_markers, 'green'));
|
| 479 |
+
setHTML('gt-agent-mechs', comparedTags(s.candidate_mechanisms, lat.causal_mechanisms, 'pink'));
|
| 480 |
+
setHTML('gt-true-mechs', tagsHTML(lat.causal_mechanisms, 'pink'));
|
| 481 |
+
const hits = (s.discovered_markers||[]).filter(m => (lat.true_markers||[]).some(t => t.toUpperCase()===m.toUpperCase()));
|
| 482 |
+
$('gt-score').innerHTML = `Marker accuracy: <span style="color:var(--accent)">${hits.length}</span> / ${(lat.true_markers||[]).length} true markers recovered`;
|
| 483 |
+
} else { $('card-gt-comparison').style.display = 'none'; }
|
| 484 |
+
|
| 485 |
+
if (!lat) return;
|
| 486 |
+
const pops = lat.cell_populations || [];
|
| 487 |
+
$('populations').innerHTML = pops.map((p,i) => { const c = POP_COLORS[i%POP_COLORS.length]; const w = (p.proportion*100).toFixed(1); return `<div class="pop-bar-container"><div class="pop-bar-label"><span>${esc(p.name)} <span style="color:var(--text-dim);font-size:11px">${p.state}</span></span><span class="mono" style="font-size:12px">${w}%</span></div><div class="pop-bar"><div class="pop-bar-fill" style="width:${w}%;background:${c}"></div></div><div class="tag-list" style="margin-top:3px">${p.marker_genes.map(g=>`<span class="tag" style="font-size:11px">${esc(g)}</span>`).join('')}</div></div>`; }).join('') || '<span style="color:var(--text-dim)">—</span>';
|
| 488 |
+
|
| 489 |
+
setHTML('true-markers', tagsHTML(lat.true_markers, 'green'));
|
| 490 |
+
setHTML('true-mechanisms', tagsHTML(lat.causal_mechanisms, 'pink'));
|
| 491 |
+
const pw = lat.true_pathways || {};
|
| 492 |
+
$('pathways-table').querySelector('tbody').innerHTML = miniRows(Object.fromEntries(Object.entries(pw).slice(0,10).map(([k,v])=>[k,v.toFixed(3)])));
|
| 493 |
+
$('technical-table').querySelector('tbody').innerHTML = miniRows(lat.technical || {});
|
| 494 |
+
setHTML('failure-conditions', tagsHTML(lat.hidden_failure_conditions, 'red'));
|
| 495 |
+
const prog = lat.progress || {};
|
| 496 |
+
const bK = Object.entries(prog).filter(([,v])=>typeof v==='boolean'), nK = Object.entries(prog).filter(([,v])=>typeof v!=='boolean');
|
| 497 |
+
$('progress-grid').innerHTML = bK.map(([k,v])=>`<div class="progress-item"><div class="dot ${v?'done':''}"></div>${k.replace(/_/g,' ')}</div>`).join('') + nK.map(([k,v])=>`<div class="progress-item" style="color:var(--accent)"><span class="mono" style="font-size:11px;margin-right:4px">${v??'—'}</span>${k.replace(/_/g,' ')}</div>`).join('');
|
| 498 |
+
|
| 499 |
+
if (s.episode_done && !reportShownForTimestamp && s.timestamp) { reportShownForTimestamp = s.timestamp; setTimeout(showReport, 800); }
|
| 500 |
+
}
|
| 501 |
+
|
| 502 |
+
let reportShownForTimestamp = null;
|
| 503 |
+
|
| 504 |
+
async function loadScenarios() {
|
| 505 |
+
try {
|
| 506 |
+
const res = await fetch('/api/scenarios');
|
| 507 |
+
const data = await res.json();
|
| 508 |
+
const sel = $('f-scenario');
|
| 509 |
+
(data.scenarios || []).forEach(n => { const o = document.createElement('option'); o.value = n; o.textContent = n; sel.appendChild(o); });
|
| 510 |
+
} catch(e) {}
|
| 511 |
+
}
|
| 512 |
+
|
| 513 |
+
async function poll() {
|
| 514 |
+
try { const res = await fetch('/api/state',{cache:'no-store'}); const data = await res.json(); renderState(data); } catch(e) {}
|
| 515 |
+
setTimeout(poll, POLL_MS);
|
| 516 |
+
}
|
| 517 |
+
|
| 518 |
+
loadScenarios();
|
| 519 |
+
poll();
|
| 520 |
+
</script>
|
| 521 |
+
</body>
|
| 522 |
+
</html>
|
dashboard.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Lightweight dashboard server for the bio-experiment agent.
|
| 2 |
+
|
| 3 |
+
No external dependencies — uses only the Python standard library.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python dashboard.py # serves on http://localhost:8050
|
| 7 |
+
python dashboard.py --port 9000
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import json
|
| 14 |
+
from http.server import HTTPServer, SimpleHTTPRequestHandler
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
ROOT = Path(__file__).parent
|
| 18 |
+
STATE_FILE = ROOT / "_dashboard_state.json"
|
| 19 |
+
CMD_FILE = ROOT / "_dashboard_cmd.json"
|
| 20 |
+
DASHBOARD_HTML = ROOT / "dashboard.html"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class DashboardHandler(SimpleHTTPRequestHandler):
|
| 24 |
+
def do_GET(self):
|
| 25 |
+
if self.path == "/" or self.path == "/index.html":
|
| 26 |
+
self._serve_file(DASHBOARD_HTML, "text/html")
|
| 27 |
+
elif self.path == "/api/state":
|
| 28 |
+
self._serve_state()
|
| 29 |
+
elif self.path == "/api/scenarios":
|
| 30 |
+
self._serve_scenarios()
|
| 31 |
+
else:
|
| 32 |
+
self.send_error(404)
|
| 33 |
+
|
| 34 |
+
def do_POST(self):
|
| 35 |
+
if self.path == "/api/restart":
|
| 36 |
+
self._handle_command({"action": "restart"})
|
| 37 |
+
elif self.path == "/api/run":
|
| 38 |
+
body = self._read_body()
|
| 39 |
+
if body is None:
|
| 40 |
+
return
|
| 41 |
+
body["action"] = "restart"
|
| 42 |
+
self._handle_command(body)
|
| 43 |
+
else:
|
| 44 |
+
self.send_error(404)
|
| 45 |
+
|
| 46 |
+
def do_OPTIONS(self):
|
| 47 |
+
self.send_response(204)
|
| 48 |
+
self.send_header("Access-Control-Allow-Origin", "*")
|
| 49 |
+
self.send_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
|
| 50 |
+
self.send_header("Access-Control-Allow-Headers", "Content-Type")
|
| 51 |
+
self.end_headers()
|
| 52 |
+
|
| 53 |
+
def _read_body(self):
|
| 54 |
+
length = int(self.headers.get("Content-Length", 0))
|
| 55 |
+
if length == 0:
|
| 56 |
+
return {}
|
| 57 |
+
raw = self.rfile.read(length)
|
| 58 |
+
try:
|
| 59 |
+
return json.loads(raw)
|
| 60 |
+
except json.JSONDecodeError:
|
| 61 |
+
self._json_response(400, {"error": "Invalid JSON"})
|
| 62 |
+
return None
|
| 63 |
+
|
| 64 |
+
def _handle_command(self, cmd: dict):
|
| 65 |
+
CMD_FILE.write_text(json.dumps(cmd), encoding="utf-8")
|
| 66 |
+
self._json_response(200, {"ok": True, "command": cmd.get("action")})
|
| 67 |
+
|
| 68 |
+
def _serve_state(self):
|
| 69 |
+
self.send_response(200)
|
| 70 |
+
self.send_header("Content-Type", "application/json")
|
| 71 |
+
self.send_header("Access-Control-Allow-Origin", "*")
|
| 72 |
+
self.send_header("Cache-Control", "no-cache")
|
| 73 |
+
self.end_headers()
|
| 74 |
+
try:
|
| 75 |
+
data = STATE_FILE.read_bytes()
|
| 76 |
+
except FileNotFoundError:
|
| 77 |
+
data = b'{"error": "No state file yet. Run run_agent.py to start an episode."}'
|
| 78 |
+
self.wfile.write(data)
|
| 79 |
+
|
| 80 |
+
def _serve_scenarios(self):
|
| 81 |
+
try:
|
| 82 |
+
from server.tasks.scenarios import SCENARIO_LIBRARY
|
| 83 |
+
names = [s.name for s in SCENARIO_LIBRARY]
|
| 84 |
+
except Exception:
|
| 85 |
+
names = []
|
| 86 |
+
self._json_response(200, {"scenarios": names})
|
| 87 |
+
|
| 88 |
+
def _serve_file(self, path: Path, content_type: str):
|
| 89 |
+
try:
|
| 90 |
+
body = path.read_bytes()
|
| 91 |
+
except FileNotFoundError:
|
| 92 |
+
self.send_error(404, f"{path.name} not found")
|
| 93 |
+
return
|
| 94 |
+
self.send_response(200)
|
| 95 |
+
self.send_header("Content-Type", content_type)
|
| 96 |
+
self.send_header("Content-Length", str(len(body)))
|
| 97 |
+
self.end_headers()
|
| 98 |
+
self.wfile.write(body)
|
| 99 |
+
|
| 100 |
+
def _json_response(self, code: int, obj: dict):
|
| 101 |
+
body = json.dumps(obj).encode()
|
| 102 |
+
self.send_response(code)
|
| 103 |
+
self.send_header("Content-Type", "application/json")
|
| 104 |
+
self.send_header("Access-Control-Allow-Origin", "*")
|
| 105 |
+
self.send_header("Content-Length", str(len(body)))
|
| 106 |
+
self.end_headers()
|
| 107 |
+
self.wfile.write(body)
|
| 108 |
+
|
| 109 |
+
def log_message(self, format, *args):
|
| 110 |
+
pass
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def main():
|
| 114 |
+
parser = argparse.ArgumentParser(description="Bio-experiment dashboard server")
|
| 115 |
+
parser.add_argument("--port", type=int, default=8050)
|
| 116 |
+
args = parser.parse_args()
|
| 117 |
+
|
| 118 |
+
server = HTTPServer(("0.0.0.0", args.port), DashboardHandler)
|
| 119 |
+
print(f"Dashboard running at http://localhost:{args.port}")
|
| 120 |
+
print("Waiting for agent state from run_agent.py ...")
|
| 121 |
+
try:
|
| 122 |
+
server.serve_forever()
|
| 123 |
+
except KeyboardInterrupt:
|
| 124 |
+
print("\nShutting down.")
|
| 125 |
+
server.server_close()
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
if __name__ == "__main__":
|
| 129 |
+
main()
|
demo.html
ADDED
|
@@ -0,0 +1,1639 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 6 |
+
<title>BioEnv</title>
|
| 7 |
+
<style>
|
| 8 |
+
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700;800&family=JetBrains+Mono:wght@400;500;600&display=swap');
|
| 9 |
+
|
| 10 |
+
:root {
|
| 11 |
+
--bg: #07090d;
|
| 12 |
+
--bg-surface: #0c0f16;
|
| 13 |
+
--bg-raised: #111827;
|
| 14 |
+
--bg-hover: #1a2235;
|
| 15 |
+
--border: #1e293b;
|
| 16 |
+
--border-active: #334155;
|
| 17 |
+
--text: #e2e8f0;
|
| 18 |
+
--text-dim: #94a3b8;
|
| 19 |
+
--text-muted: #475569;
|
| 20 |
+
--accent: #38bdf8;
|
| 21 |
+
--accent-dim: rgba(56,189,248,0.12);
|
| 22 |
+
--green: #34d399;
|
| 23 |
+
--green-dim: rgba(52,211,153,0.10);
|
| 24 |
+
--amber: #fbbf24;
|
| 25 |
+
--amber-dim: rgba(251,191,36,0.10);
|
| 26 |
+
--red: #f87171;
|
| 27 |
+
--red-dim: rgba(248,113,113,0.10);
|
| 28 |
+
--cyan: #22d3ee;
|
| 29 |
+
--cyan-dim: rgba(34,211,238,0.10);
|
| 30 |
+
--pink: #f472b6;
|
| 31 |
+
--purple: #a78bfa;
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
* { margin: 0; padding: 0; box-sizing: border-box; }
|
| 35 |
+
html, body { height: 100%; overflow: hidden; }
|
| 36 |
+
|
| 37 |
+
body {
|
| 38 |
+
font-family: 'Inter', -apple-system, sans-serif;
|
| 39 |
+
background: var(--bg);
|
| 40 |
+
color: var(--text);
|
| 41 |
+
display: flex;
|
| 42 |
+
flex-direction: column;
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
/* ---- Top Bar ---- */
|
| 46 |
+
.topbar {
|
| 47 |
+
height: 48px;
|
| 48 |
+
min-height: 48px;
|
| 49 |
+
background: var(--bg-surface);
|
| 50 |
+
border-bottom: 1px solid var(--border);
|
| 51 |
+
display: flex;
|
| 52 |
+
align-items: center;
|
| 53 |
+
padding: 0 20px;
|
| 54 |
+
gap: 16px;
|
| 55 |
+
z-index: 10;
|
| 56 |
+
}
|
| 57 |
+
.topbar-logo {
|
| 58 |
+
font-size: 15px;
|
| 59 |
+
font-weight: 800;
|
| 60 |
+
letter-spacing: -0.5px;
|
| 61 |
+
background: linear-gradient(135deg, #38bdf8, #22d3ee);
|
| 62 |
+
-webkit-background-clip: text;
|
| 63 |
+
-webkit-text-fill-color: transparent;
|
| 64 |
+
}
|
| 65 |
+
.topbar-sep { width: 1px; height: 20px; background: var(--border); }
|
| 66 |
+
.topbar-env {
|
| 67 |
+
font-size: 12px;
|
| 68 |
+
color: var(--text-dim);
|
| 69 |
+
font-family: 'JetBrains Mono', monospace;
|
| 70 |
+
}
|
| 71 |
+
.topbar-status {
|
| 72 |
+
display: flex;
|
| 73 |
+
align-items: center;
|
| 74 |
+
gap: 6px;
|
| 75 |
+
margin-left: auto;
|
| 76 |
+
font-size: 12px;
|
| 77 |
+
color: var(--text-dim);
|
| 78 |
+
}
|
| 79 |
+
.status-dot {
|
| 80 |
+
width: 7px; height: 7px;
|
| 81 |
+
border-radius: 50%;
|
| 82 |
+
background: var(--text-muted);
|
| 83 |
+
}
|
| 84 |
+
.status-dot.live {
|
| 85 |
+
background: var(--green);
|
| 86 |
+
box-shadow: 0 0 8px var(--green);
|
| 87 |
+
animation: pulse 2s infinite;
|
| 88 |
+
}
|
| 89 |
+
@keyframes pulse {
|
| 90 |
+
0%, 100% { opacity: 1; }
|
| 91 |
+
50% { opacity: 0.5; }
|
| 92 |
+
}
|
| 93 |
+
.topbar-btn {
|
| 94 |
+
font-size: 12px;
|
| 95 |
+
font-weight: 600;
|
| 96 |
+
padding: 6px 14px;
|
| 97 |
+
border-radius: 6px;
|
| 98 |
+
border: none;
|
| 99 |
+
cursor: pointer;
|
| 100 |
+
transition: all 0.15s;
|
| 101 |
+
font-family: inherit;
|
| 102 |
+
}
|
| 103 |
+
.btn-primary { background: var(--accent); color: #07090d; font-weight: 700; }
|
| 104 |
+
.btn-primary:hover { background: #7dd3fc; }
|
| 105 |
+
.btn-primary:disabled { opacity: 0.4; cursor: not-allowed; }
|
| 106 |
+
.btn-ghost {
|
| 107 |
+
background: transparent;
|
| 108 |
+
color: var(--text-dim);
|
| 109 |
+
border: 1px solid var(--border);
|
| 110 |
+
}
|
| 111 |
+
.btn-ghost:hover { background: var(--bg-hover); color: var(--text); }
|
| 112 |
+
|
| 113 |
+
/* ---- Main Layout ---- */
|
| 114 |
+
.main {
|
| 115 |
+
flex: 1;
|
| 116 |
+
display: grid;
|
| 117 |
+
grid-template-columns: 260px 1fr 340px;
|
| 118 |
+
overflow: hidden;
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
/* ---- Left Sidebar ---- */
|
| 122 |
+
.sidebar {
|
| 123 |
+
background: var(--bg-surface);
|
| 124 |
+
border-right: 1px solid var(--border);
|
| 125 |
+
display: flex;
|
| 126 |
+
flex-direction: column;
|
| 127 |
+
overflow-y: auto;
|
| 128 |
+
}
|
| 129 |
+
.sidebar-section {
|
| 130 |
+
padding: 16px;
|
| 131 |
+
border-bottom: 1px solid var(--border);
|
| 132 |
+
}
|
| 133 |
+
.sidebar-heading {
|
| 134 |
+
font-size: 10px;
|
| 135 |
+
font-weight: 600;
|
| 136 |
+
text-transform: uppercase;
|
| 137 |
+
letter-spacing: 1.5px;
|
| 138 |
+
color: var(--text-muted);
|
| 139 |
+
margin-bottom: 10px;
|
| 140 |
+
}
|
| 141 |
+
.scenario-list { display: flex; flex-direction: column; gap: 4px; }
|
| 142 |
+
.scenario-opt {
|
| 143 |
+
display: flex;
|
| 144 |
+
align-items: center;
|
| 145 |
+
gap: 10px;
|
| 146 |
+
padding: 8px 10px;
|
| 147 |
+
border-radius: 6px;
|
| 148 |
+
cursor: pointer;
|
| 149 |
+
transition: all 0.15s;
|
| 150 |
+
border: 1px solid transparent;
|
| 151 |
+
}
|
| 152 |
+
.scenario-opt:hover { background: var(--bg-hover); }
|
| 153 |
+
.scenario-opt.active {
|
| 154 |
+
background: var(--accent-dim);
|
| 155 |
+
border-color: rgba(56,189,248,0.2);
|
| 156 |
+
}
|
| 157 |
+
.scenario-opt .sc-dot { width: 8px; height: 8px; border-radius: 50%; flex-shrink: 0; }
|
| 158 |
+
.scenario-opt .sc-name {
|
| 159 |
+
font-size: 12px; font-weight: 500; flex: 1;
|
| 160 |
+
white-space: nowrap; overflow: hidden; text-overflow: ellipsis;
|
| 161 |
+
}
|
| 162 |
+
.scenario-opt .sc-diff {
|
| 163 |
+
font-size: 10px; font-weight: 600;
|
| 164 |
+
text-transform: uppercase; letter-spacing: 0.5px;
|
| 165 |
+
}
|
| 166 |
+
.gauge { margin-bottom: 14px; }
|
| 167 |
+
.gauge:last-child { margin-bottom: 0; }
|
| 168 |
+
.gauge-header {
|
| 169 |
+
display: flex; justify-content: space-between;
|
| 170 |
+
align-items: baseline; margin-bottom: 6px;
|
| 171 |
+
}
|
| 172 |
+
.gauge-label { font-size: 12px; color: var(--text-dim); font-weight: 500; }
|
| 173 |
+
.gauge-value {
|
| 174 |
+
font-size: 12px; font-weight: 600;
|
| 175 |
+
font-family: 'JetBrains Mono', monospace;
|
| 176 |
+
}
|
| 177 |
+
.gauge-track {
|
| 178 |
+
height: 4px; background: var(--bg-hover);
|
| 179 |
+
border-radius: 4px; overflow: hidden;
|
| 180 |
+
}
|
| 181 |
+
.gauge-fill {
|
| 182 |
+
height: 100%; border-radius: 4px;
|
| 183 |
+
transition: width 0.8s cubic-bezier(0.4,0,0.2,1);
|
| 184 |
+
}
|
| 185 |
+
.pipeline-steps { display: flex; flex-direction: column; gap: 2px; }
|
| 186 |
+
.pipe-step {
|
| 187 |
+
display: flex; align-items: center; gap: 8px;
|
| 188 |
+
padding: 5px 8px; border-radius: 4px;
|
| 189 |
+
font-size: 11px; font-family: 'JetBrains Mono', monospace;
|
| 190 |
+
color: var(--text-muted);
|
| 191 |
+
opacity: 0; transform: translateX(-8px);
|
| 192 |
+
transition: all 0.3s ease;
|
| 193 |
+
}
|
| 194 |
+
.pipe-step.visible { opacity: 1; transform: translateX(0); }
|
| 195 |
+
.pipe-step.active { color: var(--text); background: var(--accent-dim); }
|
| 196 |
+
.pipe-step.done { color: var(--text-dim); }
|
| 197 |
+
.pipe-step .step-icon {
|
| 198 |
+
width: 16px; height: 16px; border-radius: 50%;
|
| 199 |
+
border: 1.5px solid var(--text-muted);
|
| 200 |
+
display: flex; align-items: center; justify-content: center;
|
| 201 |
+
font-size: 8px; flex-shrink: 0; transition: all 0.3s;
|
| 202 |
+
}
|
| 203 |
+
.pipe-step.done .step-icon {
|
| 204 |
+
background: var(--green-dim); border-color: var(--green); color: var(--green);
|
| 205 |
+
}
|
| 206 |
+
.pipe-step.active .step-icon {
|
| 207 |
+
border-color: var(--accent); background: var(--accent-dim);
|
| 208 |
+
color: var(--accent); animation: pulse 1.5s infinite;
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
/* ---- Center: Lab + Terminal ---- */
|
| 212 |
+
.center {
|
| 213 |
+
display: flex;
|
| 214 |
+
flex-direction: column;
|
| 215 |
+
overflow: hidden;
|
| 216 |
+
background: var(--bg);
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
/* Lab canvas */
|
| 220 |
+
.lab-panel {
|
| 221 |
+
height: 300px;
|
| 222 |
+
min-height: 300px;
|
| 223 |
+
background: var(--bg-surface);
|
| 224 |
+
border-bottom: 1px solid var(--border);
|
| 225 |
+
position: relative;
|
| 226 |
+
overflow: hidden;
|
| 227 |
+
}
|
| 228 |
+
.lab-panel canvas {
|
| 229 |
+
display: block;
|
| 230 |
+
width: 100%;
|
| 231 |
+
height: 100%;
|
| 232 |
+
}
|
| 233 |
+
.lab-label {
|
| 234 |
+
position: absolute;
|
| 235 |
+
top: 8px;
|
| 236 |
+
left: 12px;
|
| 237 |
+
font-size: 10px;
|
| 238 |
+
font-weight: 600;
|
| 239 |
+
text-transform: uppercase;
|
| 240 |
+
letter-spacing: 1.5px;
|
| 241 |
+
color: var(--text-muted);
|
| 242 |
+
z-index: 2;
|
| 243 |
+
pointer-events: none;
|
| 244 |
+
}
|
| 245 |
+
.lab-action-label {
|
| 246 |
+
position: absolute;
|
| 247 |
+
bottom: 10px;
|
| 248 |
+
left: 50%;
|
| 249 |
+
transform: translateX(-50%);
|
| 250 |
+
font-size: 11px;
|
| 251 |
+
font-family: 'JetBrains Mono', monospace;
|
| 252 |
+
color: var(--text-dim);
|
| 253 |
+
background: rgba(12,15,22,0.85);
|
| 254 |
+
padding: 4px 14px;
|
| 255 |
+
border-radius: 100px;
|
| 256 |
+
border: 1px solid var(--border);
|
| 257 |
+
z-index: 2;
|
| 258 |
+
pointer-events: none;
|
| 259 |
+
opacity: 0;
|
| 260 |
+
transition: opacity 0.3s;
|
| 261 |
+
}
|
| 262 |
+
.lab-action-label.visible { opacity: 1; }
|
| 263 |
+
|
| 264 |
+
.center-header {
|
| 265 |
+
height: 36px;
|
| 266 |
+
min-height: 36px;
|
| 267 |
+
display: flex;
|
| 268 |
+
align-items: center;
|
| 269 |
+
padding: 0 16px;
|
| 270 |
+
background: var(--bg-surface);
|
| 271 |
+
border-bottom: 1px solid var(--border);
|
| 272 |
+
gap: 8px;
|
| 273 |
+
}
|
| 274 |
+
.tab {
|
| 275 |
+
font-size: 11px; font-weight: 500;
|
| 276 |
+
padding: 4px 12px; border-radius: 4px;
|
| 277 |
+
color: var(--text-dim); cursor: pointer;
|
| 278 |
+
transition: all 0.15s;
|
| 279 |
+
}
|
| 280 |
+
.tab.active { color: var(--text); background: var(--bg-hover); }
|
| 281 |
+
.tab:hover { color: var(--text); }
|
| 282 |
+
|
| 283 |
+
.terminal {
|
| 284 |
+
flex: 1;
|
| 285 |
+
overflow-y: auto;
|
| 286 |
+
padding: 16px 20px;
|
| 287 |
+
font-family: 'JetBrains Mono', monospace;
|
| 288 |
+
font-size: 12.5px;
|
| 289 |
+
line-height: 1.9;
|
| 290 |
+
scrollbar-width: thin;
|
| 291 |
+
scrollbar-color: var(--border) transparent;
|
| 292 |
+
}
|
| 293 |
+
.terminal::-webkit-scrollbar { width: 6px; }
|
| 294 |
+
.terminal::-webkit-scrollbar-track { background: transparent; }
|
| 295 |
+
.terminal::-webkit-scrollbar-thumb { background: var(--border); border-radius: 3px; }
|
| 296 |
+
|
| 297 |
+
.t-line {
|
| 298 |
+
white-space: pre-wrap;
|
| 299 |
+
opacity: 0;
|
| 300 |
+
animation: lineIn 0.25s ease forwards;
|
| 301 |
+
}
|
| 302 |
+
@keyframes lineIn {
|
| 303 |
+
from { opacity: 0; transform: translateY(4px); }
|
| 304 |
+
to { opacity: 1; transform: translateY(0); }
|
| 305 |
+
}
|
| 306 |
+
.t-prompt { color: var(--green); }
|
| 307 |
+
.t-cmd { color: var(--text); }
|
| 308 |
+
.t-dim { color: var(--text-muted); }
|
| 309 |
+
.t-label { color: var(--accent); }
|
| 310 |
+
.t-str { color: var(--amber); }
|
| 311 |
+
.t-kw { color: var(--pink); }
|
| 312 |
+
.t-fn { color: var(--cyan); }
|
| 313 |
+
.t-num { color: var(--purple); }
|
| 314 |
+
.t-ok { color: var(--green); }
|
| 315 |
+
.t-warn { color: var(--amber); }
|
| 316 |
+
.t-err { color: var(--red); }
|
| 317 |
+
.t-sub { color: var(--text-dim); }
|
| 318 |
+
|
| 319 |
+
/* ---- Right Panel ---- */
|
| 320 |
+
.right {
|
| 321 |
+
background: var(--bg-surface);
|
| 322 |
+
border-left: 1px solid var(--border);
|
| 323 |
+
display: flex;
|
| 324 |
+
flex-direction: column;
|
| 325 |
+
overflow-y: auto;
|
| 326 |
+
scrollbar-width: thin;
|
| 327 |
+
scrollbar-color: var(--border) transparent;
|
| 328 |
+
}
|
| 329 |
+
.panel-section {
|
| 330 |
+
padding: 16px;
|
| 331 |
+
border-bottom: 1px solid var(--border);
|
| 332 |
+
}
|
| 333 |
+
.panel-heading {
|
| 334 |
+
font-size: 10px; font-weight: 600;
|
| 335 |
+
text-transform: uppercase; letter-spacing: 1.5px;
|
| 336 |
+
color: var(--text-muted); margin-bottom: 12px;
|
| 337 |
+
display: flex; align-items: center; justify-content: space-between;
|
| 338 |
+
}
|
| 339 |
+
.reward-row {
|
| 340 |
+
display: flex; align-items: center; gap: 10px; margin-bottom: 8px;
|
| 341 |
+
}
|
| 342 |
+
.reward-row:last-child { margin-bottom: 0; }
|
| 343 |
+
.rw-label {
|
| 344 |
+
font-size: 11px; font-weight: 500; width: 80px;
|
| 345 |
+
color: var(--text-dim); text-align: right;
|
| 346 |
+
}
|
| 347 |
+
.rw-track {
|
| 348 |
+
flex: 1; height: 18px;
|
| 349 |
+
background: rgba(255,255,255,0.03);
|
| 350 |
+
border-radius: 4px; overflow: hidden; position: relative;
|
| 351 |
+
}
|
| 352 |
+
.rw-fill {
|
| 353 |
+
height: 100%; border-radius: 4px; width: 0%;
|
| 354 |
+
transition: width 0.6s cubic-bezier(0.4,0,0.2,1);
|
| 355 |
+
display: flex; align-items: center; justify-content: flex-end;
|
| 356 |
+
padding-right: 6px; font-size: 10px; font-weight: 600;
|
| 357 |
+
font-family: 'JetBrains Mono', monospace;
|
| 358 |
+
color: rgba(255,255,255,0.85); min-width: fit-content;
|
| 359 |
+
}
|
| 360 |
+
.rw-fill.validity { background: linear-gradient(90deg, rgba(52,211,153,0.5), rgba(52,211,153,0.85)); }
|
| 361 |
+
.rw-fill.ordering { background: linear-gradient(90deg, rgba(34,211,238,0.5), rgba(34,211,238,0.85)); }
|
| 362 |
+
.rw-fill.info_gain { background: linear-gradient(90deg, rgba(56,189,248,0.5), rgba(56,189,248,0.85)); }
|
| 363 |
+
.rw-fill.efficiency { background: linear-gradient(90deg, rgba(251,191,36,0.5), rgba(251,191,36,0.85)); }
|
| 364 |
+
.rw-fill.novelty { background: linear-gradient(90deg, rgba(167,139,250,0.5), rgba(167,139,250,0.85)); }
|
| 365 |
+
.rw-fill.penalty { background: linear-gradient(90deg, rgba(248,113,113,0.5), rgba(248,113,113,0.85)); }
|
| 366 |
+
.cumulative-row {
|
| 367 |
+
display: flex; align-items: baseline; justify-content: space-between;
|
| 368 |
+
margin-top: 12px; padding-top: 12px; border-top: 1px solid var(--border);
|
| 369 |
+
}
|
| 370 |
+
.cum-label { font-size: 11px; color: var(--text-dim); }
|
| 371 |
+
.cum-value {
|
| 372 |
+
font-size: 20px; font-weight: 700;
|
| 373 |
+
font-family: 'JetBrains Mono', monospace; color: var(--green);
|
| 374 |
+
}
|
| 375 |
+
.discovery-list { display: flex; flex-direction: column; gap: 6px; }
|
| 376 |
+
.discovery {
|
| 377 |
+
display: flex; align-items: flex-start; gap: 8px;
|
| 378 |
+
padding: 8px 10px; background: var(--bg-raised);
|
| 379 |
+
border-radius: 6px; border: 1px solid var(--border);
|
| 380 |
+
opacity: 0; transform: scale(0.95); transition: all 0.3s ease;
|
| 381 |
+
}
|
| 382 |
+
.discovery.visible { opacity: 1; transform: scale(1); }
|
| 383 |
+
.disc-icon {
|
| 384 |
+
width: 20px; height: 20px; border-radius: 4px;
|
| 385 |
+
display: flex; align-items: center; justify-content: center;
|
| 386 |
+
font-size: 10px; flex-shrink: 0; margin-top: 1px;
|
| 387 |
+
}
|
| 388 |
+
.disc-body { flex: 1; }
|
| 389 |
+
.disc-title { font-size: 11px; font-weight: 600; }
|
| 390 |
+
.disc-detail {
|
| 391 |
+
font-size: 10px; color: var(--text-dim); margin-top: 2px;
|
| 392 |
+
font-family: 'JetBrains Mono', monospace;
|
| 393 |
+
}
|
| 394 |
+
.empty-state {
|
| 395 |
+
font-size: 11px; color: var(--text-muted);
|
| 396 |
+
font-style: italic; padding: 8px 0;
|
| 397 |
+
}
|
| 398 |
+
.step-reward-mini {
|
| 399 |
+
display: flex; align-items: center; justify-content: space-between;
|
| 400 |
+
padding: 6px 10px; background: var(--bg-raised);
|
| 401 |
+
border-radius: 6px; margin-bottom: 4px;
|
| 402 |
+
font-size: 11px; font-family: 'JetBrains Mono', monospace;
|
| 403 |
+
opacity: 0; transition: all 0.3s;
|
| 404 |
+
}
|
| 405 |
+
.step-reward-mini.visible { opacity: 1; }
|
| 406 |
+
.step-reward-mini .srm-name { color: var(--text-dim); }
|
| 407 |
+
.step-reward-mini .srm-val { font-weight: 600; }
|
| 408 |
+
.step-reward-mini .srm-val.pos { color: var(--green); }
|
| 409 |
+
.step-reward-mini .srm-val.neg { color: var(--red); }
|
| 410 |
+
</style>
|
| 411 |
+
</head>
|
| 412 |
+
<body>
|
| 413 |
+
|
| 414 |
+
<!-- Top Bar -->
|
| 415 |
+
<div class="topbar">
|
| 416 |
+
<div class="topbar-logo">BioEnv</div>
|
| 417 |
+
<div class="topbar-sep"></div>
|
| 418 |
+
<div class="topbar-env">biomarker_validation_lung</div>
|
| 419 |
+
<div class="topbar-status">
|
| 420 |
+
<div class="status-dot" id="statusDot"></div>
|
| 421 |
+
<span id="statusText">Ready</span>
|
| 422 |
+
</div>
|
| 423 |
+
<button class="topbar-btn btn-ghost" id="resetBtn" onclick="resetDemo()">Reset</button>
|
| 424 |
+
<button class="topbar-btn btn-primary" id="runBtn" onclick="startDemo()">Run Episode</button>
|
| 425 |
+
</div>
|
| 426 |
+
|
| 427 |
+
<div class="main">
|
| 428 |
+
<!-- Left Sidebar -->
|
| 429 |
+
<div class="sidebar">
|
| 430 |
+
<div class="sidebar-section">
|
| 431 |
+
<div class="sidebar-heading">Scenario</div>
|
| 432 |
+
<div class="scenario-list">
|
| 433 |
+
<div class="scenario-opt" onclick="selectScenario(this)">
|
| 434 |
+
<div class="sc-dot" style="background: var(--green);"></div>
|
| 435 |
+
<span class="sc-name">Cardiac Disease DE</span>
|
| 436 |
+
<span class="sc-diff" style="color: var(--green);">Easy</span>
|
| 437 |
+
</div>
|
| 438 |
+
<div class="scenario-opt" onclick="selectScenario(this)">
|
| 439 |
+
<div class="sc-dot" style="background: var(--amber);"></div>
|
| 440 |
+
<span class="sc-name">Hematopoiesis Trajectory</span>
|
| 441 |
+
<span class="sc-diff" style="color: var(--amber);">Med</span>
|
| 442 |
+
</div>
|
| 443 |
+
<div class="scenario-opt" onclick="selectScenario(this)">
|
| 444 |
+
<div class="sc-dot" style="background: var(--amber);"></div>
|
| 445 |
+
<span class="sc-name">Perturbation Immune</span>
|
| 446 |
+
<span class="sc-diff" style="color: var(--amber);">Med</span>
|
| 447 |
+
</div>
|
| 448 |
+
<div class="scenario-opt active" onclick="selectScenario(this)">
|
| 449 |
+
<div class="sc-dot" style="background: var(--red);"></div>
|
| 450 |
+
<span class="sc-name">Biomarker Validation (Lung)</span>
|
| 451 |
+
<span class="sc-diff" style="color: var(--red);">Hard</span>
|
| 452 |
+
</div>
|
| 453 |
+
</div>
|
| 454 |
+
</div>
|
| 455 |
+
<div class="sidebar-section">
|
| 456 |
+
<div class="sidebar-heading">Environment State</div>
|
| 457 |
+
<div class="gauge">
|
| 458 |
+
<div class="gauge-header">
|
| 459 |
+
<span class="gauge-label">Budget</span>
|
| 460 |
+
<span class="gauge-value" id="budgetVal">$100,000</span>
|
| 461 |
+
</div>
|
| 462 |
+
<div class="gauge-track"><div class="gauge-fill" id="budgetFill" style="width:100%;background:var(--green);"></div></div>
|
| 463 |
+
</div>
|
| 464 |
+
<div class="gauge">
|
| 465 |
+
<div class="gauge-header">
|
| 466 |
+
<span class="gauge-label">Time</span>
|
| 467 |
+
<span class="gauge-value" id="timeVal">180 / 180 days</span>
|
| 468 |
+
</div>
|
| 469 |
+
<div class="gauge-track"><div class="gauge-fill" id="timeFill" style="width:100%;background:var(--cyan);"></div></div>
|
| 470 |
+
</div>
|
| 471 |
+
<div class="gauge">
|
| 472 |
+
<div class="gauge-header">
|
| 473 |
+
<span class="gauge-label">Steps</span>
|
| 474 |
+
<span class="gauge-value" id="stepVal">0 / 30</span>
|
| 475 |
+
</div>
|
| 476 |
+
<div class="gauge-track"><div class="gauge-fill" id="stepFill" style="width:0%;background:var(--accent);"></div></div>
|
| 477 |
+
</div>
|
| 478 |
+
</div>
|
| 479 |
+
<div class="sidebar-section" style="flex:1;overflow-y:auto;">
|
| 480 |
+
<div class="sidebar-heading">Pipeline</div>
|
| 481 |
+
<div class="pipeline-steps" id="pipelineSteps"></div>
|
| 482 |
+
</div>
|
| 483 |
+
</div>
|
| 484 |
+
|
| 485 |
+
<!-- Center: Lab + Terminal -->
|
| 486 |
+
<div class="center">
|
| 487 |
+
<div class="lab-panel">
|
| 488 |
+
<div class="lab-label">Virtual Lab</div>
|
| 489 |
+
<div class="lab-action-label" id="labActionLabel"></div>
|
| 490 |
+
<canvas id="labCanvas"></canvas>
|
| 491 |
+
</div>
|
| 492 |
+
<div class="center-header">
|
| 493 |
+
<div class="tab active">Agent Log</div>
|
| 494 |
+
<div class="tab">Raw JSON</div>
|
| 495 |
+
</div>
|
| 496 |
+
<div class="terminal" id="terminal"></div>
|
| 497 |
+
</div>
|
| 498 |
+
|
| 499 |
+
<!-- Right Panel -->
|
| 500 |
+
<div class="right">
|
| 501 |
+
<div class="panel-section">
|
| 502 |
+
<div class="panel-heading">
|
| 503 |
+
Step Reward
|
| 504 |
+
<span id="stepRewardLabel" style="font-family:'JetBrains Mono',monospace;font-size:11px;color:var(--text-dim);">--</span>
|
| 505 |
+
</div>
|
| 506 |
+
<div id="rewardBars">
|
| 507 |
+
<div class="reward-row"><span class="rw-label">Validity</span><div class="rw-track"><div class="rw-fill validity" id="rw-validity"></div></div></div>
|
| 508 |
+
<div class="reward-row"><span class="rw-label">Ordering</span><div class="rw-track"><div class="rw-fill ordering" id="rw-ordering"></div></div></div>
|
| 509 |
+
<div class="reward-row"><span class="rw-label">Info Gain</span><div class="rw-track"><div class="rw-fill info_gain" id="rw-info_gain"></div></div></div>
|
| 510 |
+
<div class="reward-row"><span class="rw-label">Efficiency</span><div class="rw-track"><div class="rw-fill efficiency" id="rw-efficiency"></div></div></div>
|
| 511 |
+
<div class="reward-row"><span class="rw-label">Novelty</span><div class="rw-track"><div class="rw-fill novelty" id="rw-novelty"></div></div></div>
|
| 512 |
+
<div class="reward-row"><span class="rw-label">Penalty</span><div class="rw-track"><div class="rw-fill penalty" id="rw-penalty"></div></div></div>
|
| 513 |
+
</div>
|
| 514 |
+
<div class="cumulative-row">
|
| 515 |
+
<span class="cum-label">Cumulative Reward</span>
|
| 516 |
+
<span class="cum-value" id="cumReward">0.00</span>
|
| 517 |
+
</div>
|
| 518 |
+
</div>
|
| 519 |
+
<div class="panel-section">
|
| 520 |
+
<div class="panel-heading">Reward History</div>
|
| 521 |
+
<div id="rewardHistory"><div class="empty-state">No steps yet</div></div>
|
| 522 |
+
</div>
|
| 523 |
+
<div class="panel-section">
|
| 524 |
+
<div class="panel-heading">Discoveries</div>
|
| 525 |
+
<div class="discovery-list" id="discoveries"><div class="empty-state">No discoveries yet</div></div>
|
| 526 |
+
</div>
|
| 527 |
+
<div class="panel-section">
|
| 528 |
+
<div class="panel-heading">Violations</div>
|
| 529 |
+
<div id="violations"><div class="empty-state">No violations</div></div>
|
| 530 |
+
</div>
|
| 531 |
+
</div>
|
| 532 |
+
</div>
|
| 533 |
+
|
| 534 |
+
<script>
|
| 535 |
+
// =====================================================
|
| 536 |
+
// VIRTUAL LAB - Canvas rendering
|
| 537 |
+
// =====================================================
|
| 538 |
+
const labCanvas = document.getElementById('labCanvas');
|
| 539 |
+
const ctx = labCanvas.getContext('2d');
|
| 540 |
+
let labW, labH, dpr;
|
| 541 |
+
|
| 542 |
+
function resizeLab() {
|
| 543 |
+
const rect = labCanvas.parentElement.getBoundingClientRect();
|
| 544 |
+
dpr = window.devicePixelRatio || 1;
|
| 545 |
+
labW = rect.width;
|
| 546 |
+
labH = rect.height;
|
| 547 |
+
labCanvas.width = labW * dpr;
|
| 548 |
+
labCanvas.height = labH * dpr;
|
| 549 |
+
ctx.setTransform(dpr, 0, 0, dpr, 0, 0);
|
| 550 |
+
}
|
| 551 |
+
resizeLab();
|
| 552 |
+
window.addEventListener('resize', () => { resizeLab(); });
|
| 553 |
+
|
| 554 |
+
// Lab stations (positions as fractions of canvas, converted in draw)
|
| 555 |
+
const STATIONS = {
|
| 556 |
+
idle: { fx: 0.06, fy: 0.55, label: 'ENTRANCE', icon: 'door', color: '#475569' },
|
| 557 |
+
sample: { fx: 0.20, fy: 0.35, label: 'SAMPLE BENCH', icon: 'bench', color: '#34d399' },
|
| 558 |
+
cohort: { fx: 0.20, fy: 0.75, label: 'COHORT SELECT', icon: 'people', color: '#34d399' },
|
| 559 |
+
prep: { fx: 0.38, fy: 0.35, label: 'LIBRARY PREP', icon: 'flask', color: '#2dd4bf' },
|
| 560 |
+
sequencer: { fx: 0.38, fy: 0.75, label: 'SEQUENCER', icon: 'machine', color: '#22d3ee' },
|
| 561 |
+
computer: { fx: 0.62, fy: 0.50, label: 'COMPUTE', icon: 'screen', color: '#38bdf8' },
|
| 562 |
+
whiteboard: { fx: 0.84, fy: 0.45, label: 'SYNTHESIS', icon: 'board', color: '#a78bfa' },
|
| 563 |
+
};
|
| 564 |
+
|
| 565 |
+
// Map actions to stations
|
| 566 |
+
const ACTION_STATION = {
|
| 567 |
+
collect_sample: 'sample',
|
| 568 |
+
select_cohort: 'cohort',
|
| 569 |
+
prepare_library: 'prep',
|
| 570 |
+
sequence_cells: 'sequencer',
|
| 571 |
+
run_qc: 'computer',
|
| 572 |
+
normalize_data: 'computer',
|
| 573 |
+
cluster_cells: 'computer',
|
| 574 |
+
differential_expression: 'computer',
|
| 575 |
+
pathway_enrichment: 'computer',
|
| 576 |
+
marker_selection: 'computer',
|
| 577 |
+
validate_marker: 'computer',
|
| 578 |
+
synthesize_conclusion: 'whiteboard',
|
| 579 |
+
};
|
| 580 |
+
|
| 581 |
+
// Agent state
|
| 582 |
+
let agent = { x: 0, y: 0, targetX: 0, targetY: 0, station: 'idle', working: false };
|
| 583 |
+
let agentTrail = [];
|
| 584 |
+
let workingTick = 0;
|
| 585 |
+
let terminalLines = []; // fake terminal on computer screen
|
| 586 |
+
let activeStationKey = null;
|
| 587 |
+
let particlesLab = [];
|
| 588 |
+
|
| 589 |
+
function stationPos(key) {
|
| 590 |
+
const s = STATIONS[key];
|
| 591 |
+
return { x: s.fx * labW, y: s.fy * labH };
|
| 592 |
+
}
|
| 593 |
+
|
| 594 |
+
function initAgent() {
|
| 595 |
+
const p = stationPos('idle');
|
| 596 |
+
agent.x = p.x; agent.y = p.y;
|
| 597 |
+
agent.targetX = p.x; agent.targetY = p.y;
|
| 598 |
+
agent.station = 'idle';
|
| 599 |
+
agent.working = false;
|
| 600 |
+
agent.facing = 1;
|
| 601 |
+
agentTrail = [];
|
| 602 |
+
terminalLines = [];
|
| 603 |
+
activeStationKey = null;
|
| 604 |
+
particlesLab = [];
|
| 605 |
+
}
|
| 606 |
+
initAgent();
|
| 607 |
+
|
| 608 |
+
function moveAgentTo(stationKey) {
|
| 609 |
+
const p = stationPos(stationKey);
|
| 610 |
+
agent.targetX = p.x;
|
| 611 |
+
agent.targetY = p.y;
|
| 612 |
+
agent.station = stationKey;
|
| 613 |
+
agent.working = false;
|
| 614 |
+
activeStationKey = stationKey;
|
| 615 |
+
}
|
| 616 |
+
|
| 617 |
+
function setAgentWorking(actionName) {
|
| 618 |
+
agent.working = true;
|
| 619 |
+
workingTick = 0;
|
| 620 |
+
// If at computer, set up terminal lines
|
| 621 |
+
if (agent.station === 'computer') {
|
| 622 |
+
terminalLines = [];
|
| 623 |
+
typeComputerLines(actionName);
|
| 624 |
+
}
|
| 625 |
+
}
|
| 626 |
+
|
| 627 |
+
const COMP_COMMANDS = {
|
| 628 |
+
run_qc: ['$ scanpy.pp.filter_cells()', ' filtering 11847 cells...', ' 10234 passed QC', ' doublet rate: 3.2%'],
|
| 629 |
+
normalize_data: ['$ scran.normalize(adata)', ' computing size factors...', ' log1p transform', ' HVGs: 3000 selected'],
|
| 630 |
+
cluster_cells: ['$ sc.tl.leiden(adata, 0.8)', ' building kNN graph...', ' optimizing modularity', ' 14 clusters found'],
|
| 631 |
+
differential_expression: ['$ DESeq2.run(IPF, Ctrl)', ' fitting GLM...', ' 1847 DE genes', ' SPP1 log2FC=3.42 ***'],
|
| 632 |
+
pathway_enrichment: ['$ gseapy.enrich(de_genes)', ' KEGG + Reactome...', ' ECM-receptor p=4.2e-12', ' TGF-beta p=1.8e-09'],
|
| 633 |
+
marker_selection: ['$ rank_markers(candidates)', ' SPP1 AUROC: 0.94', ' MMP7 AUROC: 0.87', ' COL1A1 AUROC: 0.81'],
|
| 634 |
+
validate_marker: ['$ cross_validate("SPP1")', ' fold 1: 0.93', ' fold 2: 0.89', ' mean AUROC: 0.91 OK'],
|
| 635 |
+
};
|
| 636 |
+
|
| 637 |
+
async function typeComputerLines(actionName) {
|
| 638 |
+
const lines = COMP_COMMANDS[actionName] || ['$ processing...', ' computing...', ' done'];
|
| 639 |
+
for (let i = 0; i < lines.length; i++) {
|
| 640 |
+
await wait(250);
|
| 641 |
+
terminalLines.push(lines[i]);
|
| 642 |
+
if (terminalLines.length > 5) terminalLines.shift();
|
| 643 |
+
}
|
| 644 |
+
}
|
| 645 |
+
|
| 646 |
+
// Particles burst
|
| 647 |
+
function spawnParticles(x, y, color, count = 8) {
|
| 648 |
+
for (let i = 0; i < count; i++) {
|
| 649 |
+
const angle = (Math.PI * 2 / count) * i + Math.random() * 0.5;
|
| 650 |
+
particlesLab.push({
|
| 651 |
+
x, y,
|
| 652 |
+
vx: Math.cos(angle) * (1.5 + Math.random() * 2),
|
| 653 |
+
vy: Math.sin(angle) * (1.5 + Math.random() * 2),
|
| 654 |
+
life: 1,
|
| 655 |
+
color,
|
| 656 |
+
size: 2 + Math.random() * 2,
|
| 657 |
+
});
|
| 658 |
+
}
|
| 659 |
+
}
|
| 660 |
+
|
| 661 |
+
// ---- Draw loop ----
|
| 662 |
+
let frameCount = 0;
|
| 663 |
+
const FLOOR_COLOR = '#0f1520';
|
| 664 |
+
const WALL_COLOR = '#1a2332';
|
| 665 |
+
const FLOOR_TILE_A = '#0d1219';
|
| 666 |
+
const FLOOR_TILE_B = '#10161f';
|
| 667 |
+
|
| 668 |
+
function drawLab() {
|
| 669 |
+
frameCount++;
|
| 670 |
+
ctx.clearRect(0, 0, labW, labH);
|
| 671 |
+
|
| 672 |
+
// Floor - checkerboard tiles
|
| 673 |
+
const tileSize = 24;
|
| 674 |
+
for (let ty = 0; ty < labH; ty += tileSize) {
|
| 675 |
+
for (let tx = 0; tx < labW; tx += tileSize) {
|
| 676 |
+
const checker = ((Math.floor(tx / tileSize) + Math.floor(ty / tileSize)) % 2 === 0);
|
| 677 |
+
ctx.fillStyle = checker ? FLOOR_TILE_A : FLOOR_TILE_B;
|
| 678 |
+
ctx.fillRect(tx, ty, tileSize, tileSize);
|
| 679 |
+
}
|
| 680 |
+
}
|
| 681 |
+
|
| 682 |
+
// Walls - top and bottom border
|
| 683 |
+
ctx.fillStyle = WALL_COLOR;
|
| 684 |
+
ctx.fillRect(0, 0, labW, 18);
|
| 685 |
+
ctx.fillRect(0, labH - 8, labW, 8);
|
| 686 |
+
ctx.strokeStyle = '#253040';
|
| 687 |
+
ctx.lineWidth = 1;
|
| 688 |
+
ctx.beginPath(); ctx.moveTo(0, 18); ctx.lineTo(labW, 18); ctx.stroke();
|
| 689 |
+
|
| 690 |
+
// Draw equipment at each station (behind the person)
|
| 691 |
+
for (const [key, s] of Object.entries(STATIONS)) {
|
| 692 |
+
const pos = stationPos(key);
|
| 693 |
+
const isActive = key === activeStationKey;
|
| 694 |
+
drawEquipment(key, pos.x, pos.y, s.color, isActive);
|
| 695 |
+
}
|
| 696 |
+
|
| 697 |
+
// Draw walking path (subtle floor markings)
|
| 698 |
+
ctx.strokeStyle = 'rgba(56,189,248,0.06)';
|
| 699 |
+
ctx.lineWidth = 16;
|
| 700 |
+
ctx.lineCap = 'round';
|
| 701 |
+
ctx.lineJoin = 'round';
|
| 702 |
+
const pathOrder = ['idle','sample','prep','computer','whiteboard'];
|
| 703 |
+
ctx.beginPath();
|
| 704 |
+
const p0 = stationPos(pathOrder[0]);
|
| 705 |
+
ctx.moveTo(p0.x, p0.y + 10);
|
| 706 |
+
for (let i = 1; i < pathOrder.length; i++) {
|
| 707 |
+
const p = stationPos(pathOrder[i]);
|
| 708 |
+
ctx.lineTo(p.x, p.y + 10);
|
| 709 |
+
}
|
| 710 |
+
ctx.stroke();
|
| 711 |
+
// Lower path
|
| 712 |
+
ctx.beginPath();
|
| 713 |
+
const pl0 = stationPos('idle');
|
| 714 |
+
ctx.moveTo(pl0.x, pl0.y + 10);
|
| 715 |
+
const pl1 = stationPos('cohort');
|
| 716 |
+
ctx.lineTo(pl1.x, pl1.y + 10);
|
| 717 |
+
const pl2 = stationPos('sequencer');
|
| 718 |
+
ctx.lineTo(pl2.x, pl2.y + 10);
|
| 719 |
+
const pl3 = stationPos('computer');
|
| 720 |
+
ctx.lineTo(pl3.x, pl3.y + 10);
|
| 721 |
+
ctx.stroke();
|
| 722 |
+
ctx.lineCap = 'butt';
|
| 723 |
+
|
| 724 |
+
// Floating terminal popup at computer
|
| 725 |
+
if (agent.station === 'computer' && agent.working && terminalLines.length > 0) {
|
| 726 |
+
const cp = stationPos('computer');
|
| 727 |
+
const sx = cp.x + 55, sy = cp.y - 65;
|
| 728 |
+
const sw = 170, sh = 95;
|
| 729 |
+
|
| 730 |
+
// Shadow
|
| 731 |
+
ctx.fillStyle = 'rgba(0,0,0,0.4)';
|
| 732 |
+
roundRect(ctx, sx + 3, sy + 3, sw, sh, 6);
|
| 733 |
+
ctx.fill();
|
| 734 |
+
|
| 735 |
+
ctx.fillStyle = 'rgba(7,9,13,0.97)';
|
| 736 |
+
ctx.strokeStyle = 'rgba(56,189,248,0.3)';
|
| 737 |
+
ctx.lineWidth = 1;
|
| 738 |
+
roundRect(ctx, sx, sy, sw, sh, 6);
|
| 739 |
+
ctx.fill(); ctx.stroke();
|
| 740 |
+
|
| 741 |
+
// Title bar
|
| 742 |
+
ctx.fillStyle = 'rgba(30,41,59,0.5)';
|
| 743 |
+
ctx.fillRect(sx + 1, sy + 1, sw - 2, 14);
|
| 744 |
+
ctx.fillStyle = '#475569';
|
| 745 |
+
ctx.font = '500 7px Inter, sans-serif';
|
| 746 |
+
ctx.textAlign = 'left';
|
| 747 |
+
ctx.fillText('terminal', sx + 6, sy + 10);
|
| 748 |
+
// dots
|
| 749 |
+
ctx.fillStyle = '#f87171'; ctx.beginPath(); ctx.arc(sx + sw - 28, sy + 7, 3, 0, Math.PI*2); ctx.fill();
|
| 750 |
+
ctx.fillStyle = '#fbbf24'; ctx.beginPath(); ctx.arc(sx + sw - 18, sy + 7, 3, 0, Math.PI*2); ctx.fill();
|
| 751 |
+
ctx.fillStyle = '#34d399'; ctx.beginPath(); ctx.arc(sx + sw - 8, sy + 7, 3, 0, Math.PI*2); ctx.fill();
|
| 752 |
+
|
| 753 |
+
ctx.font = '500 9px JetBrains Mono, monospace';
|
| 754 |
+
const startY = sy + 28;
|
| 755 |
+
for (let i = 0; i < terminalLines.length; i++) {
|
| 756 |
+
const line = terminalLines[i];
|
| 757 |
+
ctx.fillStyle = line.startsWith('$') ? '#34d399' : line.includes('***') || line.includes('OK') ? '#34d399' : '#94a3b8';
|
| 758 |
+
ctx.fillText(terminalLines[i].substring(0, 24), sx + 8, startY + i * 14);
|
| 759 |
+
}
|
| 760 |
+
if (frameCount % 60 < 30) {
|
| 761 |
+
ctx.fillStyle = '#34d399';
|
| 762 |
+
ctx.fillRect(sx + 8, startY + terminalLines.length * 14 - 8, 6, 11);
|
| 763 |
+
}
|
| 764 |
+
}
|
| 765 |
+
|
| 766 |
+
// Whiteboard popup
|
| 767 |
+
if (agent.station === 'whiteboard' && agent.working) {
|
| 768 |
+
const wp = stationPos('whiteboard');
|
| 769 |
+
const bx = wp.x - 60, by = wp.y - 75;
|
| 770 |
+
const bw = 120, bh = 72;
|
| 771 |
+
ctx.fillStyle = 'rgba(0,0,0,0.3)';
|
| 772 |
+
roundRect(ctx, bx + 3, by + 3, bw, bh, 6);
|
| 773 |
+
ctx.fill();
|
| 774 |
+
ctx.fillStyle = 'rgba(17,24,39,0.95)';
|
| 775 |
+
ctx.strokeStyle = 'rgba(167,139,250,0.3)';
|
| 776 |
+
ctx.lineWidth = 1;
|
| 777 |
+
roundRect(ctx, bx, by, bw, bh, 6);
|
| 778 |
+
ctx.fill(); ctx.stroke();
|
| 779 |
+
ctx.font = '600 8px JetBrains Mono, monospace';
|
| 780 |
+
ctx.textAlign = 'left';
|
| 781 |
+
ctx.fillStyle = '#a78bfa';
|
| 782 |
+
ctx.fillText('CONCLUSION', bx + 8, by + 14);
|
| 783 |
+
ctx.font = '400 7.5px JetBrains Mono, monospace';
|
| 784 |
+
const synthLines = ['SPP1 validated', 'AUROC = 0.91', 'Confidence: 0.85', 'Match: 4/5'];
|
| 785 |
+
for (let i = 0; i < synthLines.length; i++) {
|
| 786 |
+
ctx.fillStyle = i === 0 ? '#34d399' : '#94a3b8';
|
| 787 |
+
ctx.fillText(synthLines[i], bx + 8, by + 28 + i * 12);
|
| 788 |
+
}
|
| 789 |
+
}
|
| 790 |
+
|
| 791 |
+
// Activity text above active station
|
| 792 |
+
if (agent.working && activeStationKey && activeStationKey !== 'idle') {
|
| 793 |
+
const sp = stationPos(activeStationKey);
|
| 794 |
+
const actTexts = {
|
| 795 |
+
sample: 'collecting tissue...', cohort: 'selecting cohort...',
|
| 796 |
+
prep: 'preparing library...', sequencer: 'sequencing...',
|
| 797 |
+
computer: 'computing...', whiteboard: 'synthesizing...',
|
| 798 |
+
};
|
| 799 |
+
ctx.fillStyle = STATIONS[activeStationKey].color;
|
| 800 |
+
ctx.font = '500 9px JetBrains Mono, monospace';
|
| 801 |
+
ctx.textAlign = 'center';
|
| 802 |
+
ctx.globalAlpha = 0.5 + 0.3 * Math.sin(frameCount * 0.06);
|
| 803 |
+
const yOff = ['sample','prep'].includes(activeStationKey) ? -55 : -50;
|
| 804 |
+
ctx.fillText(actTexts[activeStationKey] || 'working...', sp.x, sp.y + yOff);
|
| 805 |
+
ctx.globalAlpha = 1;
|
| 806 |
+
}
|
| 807 |
+
|
| 808 |
+
// Move agent smoothly
|
| 809 |
+
const dx = agent.targetX - agent.x;
|
| 810 |
+
const dy = agent.targetY - agent.y;
|
| 811 |
+
const dist = Math.sqrt(dx * dx + dy * dy);
|
| 812 |
+
const isWalking = dist > 2;
|
| 813 |
+
if (isWalking) {
|
| 814 |
+
const speed = 0.05;
|
| 815 |
+
agent.x += dx * speed;
|
| 816 |
+
agent.y += dy * speed;
|
| 817 |
+
agent.facing = dx > 0 ? 1 : dx < -0.5 ? -1 : agent.facing;
|
| 818 |
+
}
|
| 819 |
+
|
| 820 |
+
// Draw person
|
| 821 |
+
drawPerson(agent.x, agent.y, isWalking, agent.working, agent.facing || 1);
|
| 822 |
+
|
| 823 |
+
// Particles
|
| 824 |
+
for (let i = particlesLab.length - 1; i >= 0; i--) {
|
| 825 |
+
const p = particlesLab[i];
|
| 826 |
+
p.x += p.vx; p.y += p.vy;
|
| 827 |
+
p.vx *= 0.95; p.vy *= 0.95;
|
| 828 |
+
p.life -= 0.02;
|
| 829 |
+
if (p.life <= 0) { particlesLab.splice(i, 1); continue; }
|
| 830 |
+
ctx.globalAlpha = p.life * 0.6;
|
| 831 |
+
ctx.fillStyle = p.color;
|
| 832 |
+
ctx.beginPath();
|
| 833 |
+
ctx.arc(p.x, p.y, p.size * p.life, 0, Math.PI * 2);
|
| 834 |
+
ctx.fill();
|
| 835 |
+
}
|
| 836 |
+
ctx.globalAlpha = 1;
|
| 837 |
+
|
| 838 |
+
// Station labels
|
| 839 |
+
for (const [key, s] of Object.entries(STATIONS)) {
|
| 840 |
+
if (key === 'idle') continue;
|
| 841 |
+
const pos = stationPos(key);
|
| 842 |
+
const isActive = key === activeStationKey;
|
| 843 |
+
ctx.fillStyle = isActive ? s.color : '#334155';
|
| 844 |
+
ctx.font = `600 ${isActive ? 9 : 8}px Inter, sans-serif`;
|
| 845 |
+
ctx.textAlign = 'center';
|
| 846 |
+
const ly = key === 'cohort' || key === 'sequencer' ? pos.y + 45 : pos.y + 42;
|
| 847 |
+
ctx.fillText(s.label, pos.x, ly);
|
| 848 |
+
}
|
| 849 |
+
|
| 850 |
+
requestAnimationFrame(drawLab);
|
| 851 |
+
}
|
| 852 |
+
|
| 853 |
+
// ---- Draw person (lab coat researcher) ----
|
| 854 |
+
function drawPerson(x, y, walking, working, facing) {
|
| 855 |
+
const f = facing;
|
| 856 |
+
const t = frameCount;
|
| 857 |
+
// Walking cycle
|
| 858 |
+
const walkCycle = walking ? Math.sin(t * 0.15) : 0;
|
| 859 |
+
const bobY = walking ? Math.abs(Math.sin(t * 0.15)) * 2 : 0;
|
| 860 |
+
// Working arm animation
|
| 861 |
+
const workArm = working ? Math.sin(t * 0.08) * 0.3 : 0;
|
| 862 |
+
|
| 863 |
+
const py = y - bobY; // feet position base
|
| 864 |
+
|
| 865 |
+
ctx.save();
|
| 866 |
+
ctx.translate(x, py);
|
| 867 |
+
|
| 868 |
+
// Shadow
|
| 869 |
+
ctx.fillStyle = 'rgba(0,0,0,0.25)';
|
| 870 |
+
ctx.beginPath();
|
| 871 |
+
ctx.ellipse(0, 12, 10, 4, 0, 0, Math.PI * 2);
|
| 872 |
+
ctx.fill();
|
| 873 |
+
|
| 874 |
+
// Legs
|
| 875 |
+
const legSpread = walking ? walkCycle * 5 : 0;
|
| 876 |
+
ctx.strokeStyle = '#1e3a5f';
|
| 877 |
+
ctx.lineWidth = 3;
|
| 878 |
+
ctx.lineCap = 'round';
|
| 879 |
+
// Left leg
|
| 880 |
+
ctx.beginPath();
|
| 881 |
+
ctx.moveTo(-3, 4);
|
| 882 |
+
ctx.lineTo(-3 + legSpread, 12);
|
| 883 |
+
ctx.stroke();
|
| 884 |
+
// Right leg
|
| 885 |
+
ctx.beginPath();
|
| 886 |
+
ctx.moveTo(3, 4);
|
| 887 |
+
ctx.lineTo(3 - legSpread, 12);
|
| 888 |
+
ctx.stroke();
|
| 889 |
+
// Shoes
|
| 890 |
+
ctx.fillStyle = '#1e293b';
|
| 891 |
+
ctx.beginPath(); ctx.arc(-3 + legSpread, 12, 2.5, 0, Math.PI * 2); ctx.fill();
|
| 892 |
+
ctx.beginPath(); ctx.arc(3 - legSpread, 12, 2.5, 0, Math.PI * 2); ctx.fill();
|
| 893 |
+
|
| 894 |
+
// Body / lab coat
|
| 895 |
+
ctx.fillStyle = '#e2e8f0'; // white lab coat
|
| 896 |
+
ctx.beginPath();
|
| 897 |
+
ctx.moveTo(-7, -4);
|
| 898 |
+
ctx.lineTo(-6, 6);
|
| 899 |
+
ctx.lineTo(6, 6);
|
| 900 |
+
ctx.lineTo(7, -4);
|
| 901 |
+
ctx.quadraticCurveTo(7, -10, 0, -10);
|
| 902 |
+
ctx.quadraticCurveTo(-7, -10, -7, -4);
|
| 903 |
+
ctx.fill();
|
| 904 |
+
// Coat outline
|
| 905 |
+
ctx.strokeStyle = '#94a3b8';
|
| 906 |
+
ctx.lineWidth = 0.5;
|
| 907 |
+
ctx.stroke();
|
| 908 |
+
// Coat split at bottom
|
| 909 |
+
ctx.beginPath();
|
| 910 |
+
ctx.moveTo(0, 1);
|
| 911 |
+
ctx.lineTo(0, 6);
|
| 912 |
+
ctx.strokeStyle = '#cbd5e1';
|
| 913 |
+
ctx.lineWidth = 0.5;
|
| 914 |
+
ctx.stroke();
|
| 915 |
+
// Pocket
|
| 916 |
+
ctx.strokeStyle = '#94a3b8';
|
| 917 |
+
ctx.lineWidth = 0.5;
|
| 918 |
+
ctx.strokeRect(f > 0 ? 1 : -5, -1, 4, 3);
|
| 919 |
+
|
| 920 |
+
// Arms
|
| 921 |
+
ctx.strokeStyle = '#e2e8f0';
|
| 922 |
+
ctx.lineWidth = 3.5;
|
| 923 |
+
ctx.lineCap = 'round';
|
| 924 |
+
// Back arm
|
| 925 |
+
const backArmSwing = walking ? -walkCycle * 4 : 0;
|
| 926 |
+
ctx.beginPath();
|
| 927 |
+
ctx.moveTo(-f * 6, -6);
|
| 928 |
+
ctx.lineTo(-f * 6 + backArmSwing, 2);
|
| 929 |
+
ctx.stroke();
|
| 930 |
+
// Front arm (active arm)
|
| 931 |
+
if (working) {
|
| 932 |
+
// Arm reaching forward/up for work
|
| 933 |
+
ctx.beginPath();
|
| 934 |
+
ctx.moveTo(f * 6, -6);
|
| 935 |
+
ctx.lineTo(f * 10 + workArm * 5, -8 + workArm * 3);
|
| 936 |
+
ctx.stroke();
|
| 937 |
+
// Hand/tool
|
| 938 |
+
ctx.fillStyle = '#fde68a';
|
| 939 |
+
ctx.beginPath();
|
| 940 |
+
ctx.arc(f * 10 + workArm * 5, -8 + workArm * 3, 2, 0, Math.PI * 2);
|
| 941 |
+
ctx.fill();
|
| 942 |
+
} else {
|
| 943 |
+
const frontArmSwing = walking ? walkCycle * 4 : 0;
|
| 944 |
+
ctx.beginPath();
|
| 945 |
+
ctx.moveTo(f * 6, -6);
|
| 946 |
+
ctx.lineTo(f * 6 + frontArmSwing, 2);
|
| 947 |
+
ctx.stroke();
|
| 948 |
+
}
|
| 949 |
+
// Skin for hands
|
| 950 |
+
ctx.fillStyle = '#fde68a';
|
| 951 |
+
ctx.beginPath(); ctx.arc(-f * 6 + backArmSwing, 2, 1.8, 0, Math.PI * 2); ctx.fill();
|
| 952 |
+
if (!working) {
|
| 953 |
+
const fs = walking ? walkCycle * 4 : 0;
|
| 954 |
+
ctx.beginPath(); ctx.arc(f * 6 + fs, 2, 1.8, 0, Math.PI * 2); ctx.fill();
|
| 955 |
+
}
|
| 956 |
+
|
| 957 |
+
// Head
|
| 958 |
+
ctx.fillStyle = '#fde68a'; // skin
|
| 959 |
+
ctx.beginPath();
|
| 960 |
+
ctx.arc(0, -15, 7, 0, Math.PI * 2);
|
| 961 |
+
ctx.fill();
|
| 962 |
+
// Hair
|
| 963 |
+
ctx.fillStyle = '#1e293b';
|
| 964 |
+
ctx.beginPath();
|
| 965 |
+
ctx.arc(0, -17, 7, Math.PI, 0);
|
| 966 |
+
ctx.fill();
|
| 967 |
+
// Face details
|
| 968 |
+
ctx.fillStyle = '#1e293b';
|
| 969 |
+
// Eyes
|
| 970 |
+
ctx.beginPath();
|
| 971 |
+
ctx.arc(f * 2.5, -15.5, 1, 0, Math.PI * 2);
|
| 972 |
+
ctx.fill();
|
| 973 |
+
ctx.beginPath();
|
| 974 |
+
ctx.arc(f * -1.5, -15.5, 1, 0, Math.PI * 2);
|
| 975 |
+
ctx.fill();
|
| 976 |
+
// Glasses
|
| 977 |
+
ctx.strokeStyle = '#475569';
|
| 978 |
+
ctx.lineWidth = 0.7;
|
| 979 |
+
ctx.beginPath();
|
| 980 |
+
ctx.arc(f * 2.5, -15.5, 2.5, 0, Math.PI * 2);
|
| 981 |
+
ctx.stroke();
|
| 982 |
+
ctx.beginPath();
|
| 983 |
+
ctx.arc(f * -1.5, -15.5, 2.5, 0, Math.PI * 2);
|
| 984 |
+
ctx.stroke();
|
| 985 |
+
ctx.beginPath();
|
| 986 |
+
ctx.moveTo(f * 0.5, -15.5);
|
| 987 |
+
ctx.lineTo(f * -0.5, -15.5);
|
| 988 |
+
ctx.stroke();
|
| 989 |
+
// Mouth
|
| 990 |
+
if (working) {
|
| 991 |
+
ctx.fillStyle = '#1e293b';
|
| 992 |
+
ctx.beginPath();
|
| 993 |
+
ctx.arc(f * 0.5, -12.5, 1, 0, Math.PI);
|
| 994 |
+
ctx.fill();
|
| 995 |
+
}
|
| 996 |
+
|
| 997 |
+
// ID Badge
|
| 998 |
+
ctx.fillStyle = '#38bdf8';
|
| 999 |
+
ctx.fillRect(f > 0 ? -6 : 2, -3, 4, 5);
|
| 1000 |
+
ctx.fillStyle = '#fff';
|
| 1001 |
+
ctx.font = 'bold 3px Inter, sans-serif';
|
| 1002 |
+
ctx.textAlign = 'center';
|
| 1003 |
+
ctx.fillText('AI', f > 0 ? -4 : 4, 0.5);
|
| 1004 |
+
|
| 1005 |
+
ctx.restore();
|
| 1006 |
+
}
|
| 1007 |
+
|
| 1008 |
+
// ---- Draw lab equipment ----
|
| 1009 |
+
function drawEquipment(stationKey, cx, cy, color, active) {
|
| 1010 |
+
ctx.save();
|
| 1011 |
+
|
| 1012 |
+
switch (stationKey) {
|
| 1013 |
+
case 'idle':
|
| 1014 |
+
// Door frame
|
| 1015 |
+
ctx.strokeStyle = '#334155';
|
| 1016 |
+
ctx.lineWidth = 2;
|
| 1017 |
+
ctx.strokeRect(cx - 12, cy - 30, 24, 40);
|
| 1018 |
+
ctx.fillStyle = '#1a2332';
|
| 1019 |
+
ctx.fillRect(cx - 10, cy - 28, 20, 36);
|
| 1020 |
+
ctx.fillStyle = '#475569';
|
| 1021 |
+
ctx.beginPath(); ctx.arc(cx + 6, cy - 10, 2, 0, Math.PI * 2); ctx.fill();
|
| 1022 |
+
break;
|
| 1023 |
+
|
| 1024 |
+
case 'sample':
|
| 1025 |
+
// Lab bench with sample tubes
|
| 1026 |
+
// Bench surface
|
| 1027 |
+
ctx.fillStyle = '#1a2332';
|
| 1028 |
+
ctx.fillRect(cx - 30, cy - 8, 60, 6);
|
| 1029 |
+
// Bench legs
|
| 1030 |
+
ctx.fillStyle = '#253040';
|
| 1031 |
+
ctx.fillRect(cx - 28, cy - 2, 4, 20);
|
| 1032 |
+
ctx.fillRect(cx + 24, cy - 2, 4, 20);
|
| 1033 |
+
// Tube rack
|
| 1034 |
+
ctx.fillStyle = '#253040';
|
| 1035 |
+
ctx.fillRect(cx - 18, cy - 18, 36, 10);
|
| 1036 |
+
// Test tubes
|
| 1037 |
+
const tubeColors = ['#34d399', '#22d3ee', '#fbbf24', '#f472b6', '#34d399', '#22d3ee'];
|
| 1038 |
+
for (let i = 0; i < 6; i++) {
|
| 1039 |
+
const tx = cx - 14 + i * 6;
|
| 1040 |
+
ctx.fillStyle = active ? tubeColors[i] : '#334155';
|
| 1041 |
+
ctx.globalAlpha = active ? 0.7 : 0.4;
|
| 1042 |
+
ctx.fillRect(tx, cy - 28, 4, 12);
|
| 1043 |
+
// Tube caps
|
| 1044 |
+
ctx.globalAlpha = 1;
|
| 1045 |
+
ctx.fillStyle = active ? tubeColors[i] : '#475569';
|
| 1046 |
+
ctx.fillRect(tx - 0.5, cy - 29, 5, 2);
|
| 1047 |
+
}
|
| 1048 |
+
ctx.globalAlpha = 1;
|
| 1049 |
+
// Pipette if active
|
| 1050 |
+
if (active) {
|
| 1051 |
+
const pipY = cy - 32 + Math.sin(frameCount * 0.08) * 4;
|
| 1052 |
+
ctx.strokeStyle = '#94a3b8';
|
| 1053 |
+
ctx.lineWidth = 2;
|
| 1054 |
+
ctx.beginPath();
|
| 1055 |
+
ctx.moveTo(cx + 5, pipY);
|
| 1056 |
+
ctx.lineTo(cx + 5, pipY - 14);
|
| 1057 |
+
ctx.stroke();
|
| 1058 |
+
ctx.fillStyle = '#64748b';
|
| 1059 |
+
ctx.fillRect(cx + 3, pipY - 18, 5, 6);
|
| 1060 |
+
// Droplet
|
| 1061 |
+
if (frameCount % 60 < 20) {
|
| 1062 |
+
ctx.fillStyle = '#34d399';
|
| 1063 |
+
ctx.globalAlpha = 0.6;
|
| 1064 |
+
ctx.beginPath();
|
| 1065 |
+
ctx.arc(cx + 5, pipY + 3, 1.5, 0, Math.PI * 2);
|
| 1066 |
+
ctx.fill();
|
| 1067 |
+
ctx.globalAlpha = 1;
|
| 1068 |
+
}
|
| 1069 |
+
}
|
| 1070 |
+
break;
|
| 1071 |
+
|
| 1072 |
+
case 'cohort':
|
| 1073 |
+
// Filing cabinet / patient records
|
| 1074 |
+
ctx.fillStyle = '#1a2332';
|
| 1075 |
+
ctx.fillRect(cx - 20, cy - 22, 40, 40);
|
| 1076 |
+
ctx.strokeStyle = '#253040';
|
| 1077 |
+
ctx.lineWidth = 1;
|
| 1078 |
+
for (let i = 0; i < 3; i++) {
|
| 1079 |
+
const dy = cy - 18 + i * 13;
|
| 1080 |
+
ctx.strokeRect(cx - 18, dy, 36, 11);
|
| 1081 |
+
ctx.fillStyle = active ? '#475569' : '#253040';
|
| 1082 |
+
ctx.fillRect(cx - 4, dy + 4, 8, 3);
|
| 1083 |
+
}
|
| 1084 |
+
// Clipboard
|
| 1085 |
+
ctx.fillStyle = '#253040';
|
| 1086 |
+
ctx.fillRect(cx + 24, cy - 16, 14, 20);
|
| 1087 |
+
ctx.strokeStyle = '#475569';
|
| 1088 |
+
ctx.lineWidth = 0.5;
|
| 1089 |
+
for (let i = 0; i < 4; i++) {
|
| 1090 |
+
ctx.beginPath();
|
| 1091 |
+
ctx.moveTo(cx + 27, cy - 12 + i * 4);
|
| 1092 |
+
ctx.lineTo(cx + 35, cy - 12 + i * 4);
|
| 1093 |
+
ctx.stroke();
|
| 1094 |
+
}
|
| 1095 |
+
if (active) {
|
| 1096 |
+
ctx.fillStyle = color;
|
| 1097 |
+
ctx.globalAlpha = 0.5;
|
| 1098 |
+
ctx.beginPath(); ctx.arc(cx + 31, cy - 14, 2, 0, Math.PI * 2); ctx.fill();
|
| 1099 |
+
ctx.globalAlpha = 1;
|
| 1100 |
+
}
|
| 1101 |
+
break;
|
| 1102 |
+
|
| 1103 |
+
case 'prep':
|
| 1104 |
+
// Library prep station - PCR machine + bench
|
| 1105 |
+
// Bench
|
| 1106 |
+
ctx.fillStyle = '#1a2332';
|
| 1107 |
+
ctx.fillRect(cx - 28, cy - 6, 56, 6);
|
| 1108 |
+
ctx.fillStyle = '#253040';
|
| 1109 |
+
ctx.fillRect(cx - 26, cy, 4, 18);
|
| 1110 |
+
ctx.fillRect(cx + 22, cy, 4, 18);
|
| 1111 |
+
// PCR/thermocycler machine
|
| 1112 |
+
ctx.fillStyle = active ? '#192535' : '#172030';
|
| 1113 |
+
ctx.strokeStyle = active ? color : '#253040';
|
| 1114 |
+
ctx.lineWidth = 1;
|
| 1115 |
+
roundRect(ctx, cx - 18, cy - 26, 36, 20, 3);
|
| 1116 |
+
ctx.fill(); ctx.stroke();
|
| 1117 |
+
// Display on machine
|
| 1118 |
+
ctx.fillStyle = active ? 'rgba(45,212,191,0.15)' : 'rgba(30,41,59,0.3)';
|
| 1119 |
+
ctx.fillRect(cx - 14, cy - 22, 16, 8);
|
| 1120 |
+
if (active) {
|
| 1121 |
+
ctx.fillStyle = color;
|
| 1122 |
+
ctx.font = '500 6px JetBrains Mono, monospace';
|
| 1123 |
+
ctx.textAlign = 'left';
|
| 1124 |
+
ctx.fillText('72.0°C', cx - 12, cy - 16);
|
| 1125 |
+
// LED
|
| 1126 |
+
ctx.fillStyle = color;
|
| 1127 |
+
ctx.beginPath(); ctx.arc(cx + 12, cy - 18, 2, 0, Math.PI * 2); ctx.fill();
|
| 1128 |
+
}
|
| 1129 |
+
// Microplate
|
| 1130 |
+
ctx.fillStyle = '#1e293b';
|
| 1131 |
+
ctx.fillRect(cx - 20, cy - 3, 18, 12);
|
| 1132 |
+
ctx.strokeStyle = '#334155';
|
| 1133 |
+
ctx.lineWidth = 0.3;
|
| 1134 |
+
for (let r = 0; r < 3; r++) {
|
| 1135 |
+
for (let c = 0; c < 4; c++) {
|
| 1136 |
+
ctx.beginPath();
|
| 1137 |
+
ctx.arc(cx - 17 + c * 4.5, cy + 1 + r * 3.5, 1.2, 0, Math.PI * 2);
|
| 1138 |
+
ctx.stroke();
|
| 1139 |
+
}
|
| 1140 |
+
}
|
| 1141 |
+
break;
|
| 1142 |
+
|
| 1143 |
+
case 'sequencer':
|
| 1144 |
+
// Big sequencing machine (NovaSeq-like)
|
| 1145 |
+
// Machine body
|
| 1146 |
+
ctx.fillStyle = '#172030';
|
| 1147 |
+
ctx.strokeStyle = active ? color : '#253040';
|
| 1148 |
+
ctx.lineWidth = active ? 1.5 : 1;
|
| 1149 |
+
roundRect(ctx, cx - 24, cy - 28, 48, 44, 4);
|
| 1150 |
+
ctx.fill(); ctx.stroke();
|
| 1151 |
+
// Front panel / screen
|
| 1152 |
+
ctx.fillStyle = active ? 'rgba(34,211,238,0.1)' : 'rgba(30,41,59,0.3)';
|
| 1153 |
+
roundRect(ctx, cx - 18, cy - 22, 36, 18, 2);
|
| 1154 |
+
ctx.fill();
|
| 1155 |
+
if (active) {
|
| 1156 |
+
// Progress bar on screen
|
| 1157 |
+
ctx.fillStyle = 'rgba(34,211,238,0.2)';
|
| 1158 |
+
ctx.fillRect(cx - 14, cy - 12, 28, 4);
|
| 1159 |
+
const progress = (frameCount % 120) / 120;
|
| 1160 |
+
ctx.fillStyle = color;
|
| 1161 |
+
ctx.fillRect(cx - 14, cy - 12, 28 * progress, 4);
|
| 1162 |
+
ctx.fillStyle = color;
|
| 1163 |
+
ctx.font = '500 6px JetBrains Mono, monospace';
|
| 1164 |
+
ctx.textAlign = 'center';
|
| 1165 |
+
ctx.fillText('SEQUENCING', cx, cy - 16);
|
| 1166 |
+
}
|
| 1167 |
+
// Slot
|
| 1168 |
+
ctx.fillStyle = '#0f1520';
|
| 1169 |
+
ctx.fillRect(cx - 10, cy, 20, 4);
|
| 1170 |
+
// Status LEDs
|
| 1171 |
+
ctx.fillStyle = active ? '#34d399' : '#334155';
|
| 1172 |
+
ctx.beginPath(); ctx.arc(cx - 14, cy + 10, 2, 0, Math.PI * 2); ctx.fill();
|
| 1173 |
+
if (active && frameCount % 30 < 15) {
|
| 1174 |
+
ctx.fillStyle = '#fbbf24';
|
| 1175 |
+
} else {
|
| 1176 |
+
ctx.fillStyle = '#334155';
|
| 1177 |
+
}
|
| 1178 |
+
ctx.beginPath(); ctx.arc(cx - 8, cy + 10, 2, 0, Math.PI * 2); ctx.fill();
|
| 1179 |
+
break;
|
| 1180 |
+
|
| 1181 |
+
case 'computer':
|
| 1182 |
+
// Computer desk with dual monitors
|
| 1183 |
+
// Desk
|
| 1184 |
+
ctx.fillStyle = '#1a2332';
|
| 1185 |
+
ctx.fillRect(cx - 36, cy + 2, 72, 5);
|
| 1186 |
+
ctx.fillStyle = '#253040';
|
| 1187 |
+
ctx.fillRect(cx - 32, cy + 7, 4, 16);
|
| 1188 |
+
ctx.fillRect(cx + 28, cy + 7, 4, 16);
|
| 1189 |
+
// Chair
|
| 1190 |
+
ctx.fillStyle = '#1e293b';
|
| 1191 |
+
ctx.beginPath();
|
| 1192 |
+
ctx.arc(cx, cy + 28, 8, 0, Math.PI * 2);
|
| 1193 |
+
ctx.fill();
|
| 1194 |
+
ctx.fillStyle = '#253040';
|
| 1195 |
+
ctx.fillRect(cx - 1, cy + 20, 2, 8);
|
| 1196 |
+
// Monitor 1 (main)
|
| 1197 |
+
ctx.fillStyle = active ? '#0c1219' : '#131c28';
|
| 1198 |
+
ctx.strokeStyle = active ? 'rgba(56,189,248,0.4)' : '#253040';
|
| 1199 |
+
ctx.lineWidth = 1;
|
| 1200 |
+
roundRect(ctx, cx - 30, cy - 28, 32, 24, 2);
|
| 1201 |
+
ctx.fill(); ctx.stroke();
|
| 1202 |
+
// Monitor stand
|
| 1203 |
+
ctx.fillStyle = '#334155';
|
| 1204 |
+
ctx.fillRect(cx - 16, cy - 4, 4, 6);
|
| 1205 |
+
ctx.fillRect(cx - 20, cy + 1, 12, 2);
|
| 1206 |
+
// Monitor 2
|
| 1207 |
+
ctx.fillStyle = active ? '#0c1219' : '#131c28';
|
| 1208 |
+
ctx.strokeStyle = active ? 'rgba(56,189,248,0.3)' : '#253040';
|
| 1209 |
+
roundRect(ctx, cx + 2, cy - 24, 26, 20, 2);
|
| 1210 |
+
ctx.fill(); ctx.stroke();
|
| 1211 |
+
ctx.fillStyle = '#334155';
|
| 1212 |
+
ctx.fillRect(cx + 13, cy - 4, 4, 6);
|
| 1213 |
+
ctx.fillRect(cx + 9, cy + 1, 12, 2);
|
| 1214 |
+
// Screen content
|
| 1215 |
+
if (active) {
|
| 1216 |
+
ctx.fillStyle = 'rgba(56,189,248,0.08)';
|
| 1217 |
+
ctx.fillRect(cx - 28, cy - 26, 28, 20);
|
| 1218 |
+
// Code lines
|
| 1219 |
+
for (let i = 0; i < 5; i++) {
|
| 1220 |
+
ctx.fillStyle = `rgba(56,189,248,${0.15 + i * 0.06})`;
|
| 1221 |
+
const w = 8 + Math.sin(i * 2.3 + frameCount * 0.02) * 6;
|
| 1222 |
+
ctx.fillRect(cx - 26, cy - 24 + i * 4, w, 2);
|
| 1223 |
+
}
|
| 1224 |
+
// Second screen - graph
|
| 1225 |
+
ctx.fillStyle = 'rgba(56,189,248,0.06)';
|
| 1226 |
+
ctx.fillRect(cx + 4, cy - 22, 22, 16);
|
| 1227 |
+
ctx.strokeStyle = 'rgba(34,211,238,0.3)';
|
| 1228 |
+
ctx.lineWidth = 1;
|
| 1229 |
+
ctx.beginPath();
|
| 1230 |
+
ctx.moveTo(cx + 6, cy - 8);
|
| 1231 |
+
for (let i = 0; i < 8; i++) {
|
| 1232 |
+
ctx.lineTo(cx + 6 + i * 2.5, cy - 10 - Math.sin(i * 0.8 + frameCount * 0.03) * 5);
|
| 1233 |
+
}
|
| 1234 |
+
ctx.stroke();
|
| 1235 |
+
}
|
| 1236 |
+
// Keyboard
|
| 1237 |
+
ctx.fillStyle = '#1e293b';
|
| 1238 |
+
ctx.fillRect(cx - 14, cy + 4, 28, 6);
|
| 1239 |
+
// Typing effect
|
| 1240 |
+
if (active && agent.working) {
|
| 1241 |
+
const keyX = cx - 12 + (frameCount % 20) * 1.2;
|
| 1242 |
+
ctx.fillStyle = 'rgba(56,189,248,0.4)';
|
| 1243 |
+
ctx.fillRect(keyX, cy + 5, 3, 4);
|
| 1244 |
+
}
|
| 1245 |
+
break;
|
| 1246 |
+
|
| 1247 |
+
case 'whiteboard':
|
| 1248 |
+
// Whiteboard on wall + standing desk
|
| 1249 |
+
// Board on wall
|
| 1250 |
+
ctx.fillStyle = '#1e293b';
|
| 1251 |
+
ctx.strokeStyle = '#334155';
|
| 1252 |
+
ctx.lineWidth = 1;
|
| 1253 |
+
ctx.fillRect(cx - 28, cy - 34, 56, 32);
|
| 1254 |
+
ctx.strokeRect(cx - 28, cy - 34, 56, 32);
|
| 1255 |
+
// Board content
|
| 1256 |
+
if (active) {
|
| 1257 |
+
ctx.fillStyle = 'rgba(167,139,250,0.1)';
|
| 1258 |
+
ctx.fillRect(cx - 26, cy - 32, 52, 28);
|
| 1259 |
+
// Diagram elements
|
| 1260 |
+
ctx.strokeStyle = 'rgba(167,139,250,0.4)';
|
| 1261 |
+
ctx.lineWidth = 0.8;
|
| 1262 |
+
// Boxes
|
| 1263 |
+
ctx.strokeRect(cx - 20, cy - 28, 14, 8);
|
| 1264 |
+
ctx.strokeRect(cx + 6, cy - 28, 14, 8);
|
| 1265 |
+
ctx.strokeRect(cx - 8, cy - 16, 16, 8);
|
| 1266 |
+
// Arrows
|
| 1267 |
+
ctx.beginPath();
|
| 1268 |
+
ctx.moveTo(cx - 6, cy - 24); ctx.lineTo(cx + 6, cy - 24); ctx.stroke();
|
| 1269 |
+
ctx.beginPath();
|
| 1270 |
+
ctx.moveTo(cx, cy - 20); ctx.lineTo(cx, cy - 16); ctx.stroke();
|
| 1271 |
+
// Checkmark
|
| 1272 |
+
ctx.strokeStyle = '#34d399';
|
| 1273 |
+
ctx.lineWidth = 1.5;
|
| 1274 |
+
ctx.beginPath();
|
| 1275 |
+
ctx.moveTo(cx - 4, cy - 12);
|
| 1276 |
+
ctx.lineTo(cx - 1, cy - 9);
|
| 1277 |
+
ctx.lineTo(cx + 5, cy - 15);
|
| 1278 |
+
ctx.stroke();
|
| 1279 |
+
} else {
|
| 1280 |
+
// Faint lines
|
| 1281 |
+
ctx.strokeStyle = '#253040';
|
| 1282 |
+
ctx.lineWidth = 0.5;
|
| 1283 |
+
for (let i = 0; i < 4; i++) {
|
| 1284 |
+
ctx.beginPath();
|
| 1285 |
+
ctx.moveTo(cx - 22, cy - 28 + i * 7);
|
| 1286 |
+
ctx.lineTo(cx + 22, cy - 28 + i * 7);
|
| 1287 |
+
ctx.stroke();
|
| 1288 |
+
}
|
| 1289 |
+
}
|
| 1290 |
+
// Standing desk
|
| 1291 |
+
ctx.fillStyle = '#1a2332';
|
| 1292 |
+
ctx.fillRect(cx - 16, cy + 2, 32, 4);
|
| 1293 |
+
ctx.fillStyle = '#253040';
|
| 1294 |
+
ctx.fillRect(cx - 2, cy + 6, 4, 14);
|
| 1295 |
+
break;
|
| 1296 |
+
}
|
| 1297 |
+
|
| 1298 |
+
ctx.restore();
|
| 1299 |
+
}
|
| 1300 |
+
|
| 1301 |
+
function roundRect(ctx, x, y, w, h, r) {
|
| 1302 |
+
ctx.beginPath();
|
| 1303 |
+
ctx.moveTo(x + r, y);
|
| 1304 |
+
ctx.lineTo(x + w - r, y);
|
| 1305 |
+
ctx.quadraticCurveTo(x + w, y, x + w, y + r);
|
| 1306 |
+
ctx.lineTo(x + w, y + h - r);
|
| 1307 |
+
ctx.quadraticCurveTo(x + w, y + h, x + w - r, y + h);
|
| 1308 |
+
ctx.lineTo(x + r, y + h);
|
| 1309 |
+
ctx.quadraticCurveTo(x, y + h, x, y + h - r);
|
| 1310 |
+
ctx.lineTo(x, y + r);
|
| 1311 |
+
ctx.quadraticCurveTo(x, y, x + r, y);
|
| 1312 |
+
ctx.closePath();
|
| 1313 |
+
}
|
| 1314 |
+
|
| 1315 |
+
drawLab();
|
| 1316 |
+
|
| 1317 |
+
// =====================================================
|
| 1318 |
+
// EPISODE DATA + APP LOGIC
|
| 1319 |
+
// =====================================================
|
| 1320 |
+
const EPISODE = [
|
| 1321 |
+
{
|
| 1322 |
+
action: 'collect_sample', params: 'n_samples=8, tissue="lung"', category: 'wet',
|
| 1323 |
+
budget: 92400, budgetPct: 92.4, time: 165, timePct: 91.7,
|
| 1324 |
+
output: ['Collected 8 lung tissue samples (4 IPF, 4 control)','Tissue quality: excellent | Storage: -80C'],
|
| 1325 |
+
reward: { validity: 0.90, ordering: 1.00, info_gain: 0.10, efficiency: 0.72, novelty: 1.00, penalty: 0.0 },
|
| 1326 |
+
total: 0.45,
|
| 1327 |
+
},
|
| 1328 |
+
{
|
| 1329 |
+
action: 'select_cohort', params: 'criteria="age_matched, sex_balanced"', category: 'wet',
|
| 1330 |
+
budget: 91800, budgetPct: 91.8, time: 162, timePct: 90.0,
|
| 1331 |
+
output: ['Cohort selected: 4 IPF patients (2M/2F, age 58-67)','Controls matched: 4 healthy donors (2M/2F, age 55-65)'],
|
| 1332 |
+
reward: { validity: 0.85, ordering: 0.90, info_gain: 0.15, efficiency: 0.80, novelty: 0.90, penalty: 0.0 },
|
| 1333 |
+
total: 0.38,
|
| 1334 |
+
},
|
| 1335 |
+
{
|
| 1336 |
+
action: 'prepare_library', params: 'protocol="10x_chromium_v3"', category: 'wet',
|
| 1337 |
+
budget: 84200, budgetPct: 84.2, time: 155, timePct: 86.1,
|
| 1338 |
+
output: ['Library prep complete using 10x Chromium v3','Estimated cell capture: ~12,000 cells','cDNA yield: 42ng (good)'],
|
| 1339 |
+
reward: { validity: 0.95, ordering: 1.00, info_gain: 0.20, efficiency: 0.70, novelty: 0.95, penalty: 0.0 },
|
| 1340 |
+
total: 0.52,
|
| 1341 |
+
},
|
| 1342 |
+
{
|
| 1343 |
+
action: 'sequence_cells', params: 'depth="standard", platform="NovaSeq"', category: 'wet',
|
| 1344 |
+
budget: 68500, budgetPct: 68.5, time: 142, timePct: 78.9,
|
| 1345 |
+
output: ['11,847 cells sequenced | 22,438 genes detected','Median reads/cell: 45,200 | Median genes/cell: 3,842','Sequencing saturation: 78.3%'],
|
| 1346 |
+
reward: { validity: 0.95, ordering: 1.00, info_gain: 0.55, efficiency: 0.60, novelty: 0.90, penalty: 0.0 },
|
| 1347 |
+
total: 0.68,
|
| 1348 |
+
},
|
| 1349 |
+
{
|
| 1350 |
+
action: 'run_qc', params: 'tool="scanpy", min_genes=200', category: 'comp',
|
| 1351 |
+
budget: 68100, budgetPct: 68.1, time: 141, timePct: 78.3,
|
| 1352 |
+
output: ['QC complete: 10,234 / 11,847 cells passed (86.4%)','Removed: 382 doublets (3.2%), 1,231 low-quality cells','Mitochondrial threshold: 20% (flagged 847 cells)'],
|
| 1353 |
+
reward: { validity: 0.95, ordering: 1.00, info_gain: 0.35, efficiency: 0.85, novelty: 0.80, penalty: 0.0 },
|
| 1354 |
+
total: 0.55,
|
| 1355 |
+
},
|
| 1356 |
+
{
|
| 1357 |
+
action: 'normalize_data', params: 'method="scran", log_transform=true', category: 'comp',
|
| 1358 |
+
budget: 67900, budgetPct: 67.9, time: 140, timePct: 77.8,
|
| 1359 |
+
output: ['Size-factor normalization (scran) applied','Log1p transform complete | HVG selection: 3,000 genes'],
|
| 1360 |
+
reward: { validity: 0.90, ordering: 1.00, info_gain: 0.25, efficiency: 0.90, novelty: 0.70, penalty: 0.0 },
|
| 1361 |
+
total: 0.42,
|
| 1362 |
+
},
|
| 1363 |
+
{
|
| 1364 |
+
action: 'cluster_cells', params: 'algorithm="leiden", resolution=0.8', category: 'comp',
|
| 1365 |
+
budget: 67500, budgetPct: 67.5, time: 139, timePct: 77.2,
|
| 1366 |
+
output: ['Leiden clustering: 14 clusters identified','AT1 (8.2%), AT2 (12.1%), Fibroblast (15.7%), Macrophage (18.3%)','Endothelial (9.4%), Basal (6.1%), Ciliated (5.8%), NK/T (7.2%)','Smooth Muscle (4.1%), Mast (2.9%), B cell (3.4%), pDC (2.0%)','Mesothelial (2.6%), Aberrant Basaloid (2.2%)'],
|
| 1367 |
+
reward: { validity: 0.95, ordering: 1.00, info_gain: 0.65, efficiency: 0.85, novelty: 0.85, penalty: 0.0 },
|
| 1368 |
+
total: 0.72,
|
| 1369 |
+
discovery: { title: '14 cell populations identified', detail: 'Including Aberrant Basaloid cells (IPF-associated)', color: 'var(--cyan)', bg: 'var(--cyan-dim)' },
|
| 1370 |
+
},
|
| 1371 |
+
{
|
| 1372 |
+
action: 'differential_expression', params: 'method="DESeq2", contrast="IPF_vs_Ctrl"', category: 'comp',
|
| 1373 |
+
budget: 67000, budgetPct: 67.0, time: 137, timePct: 76.1,
|
| 1374 |
+
output: ['1,847 DE genes (|log2FC| > 1, padj < 0.05)','Top upregulated in IPF:',' SPP1 log2FC=3.42 padj=1.2e-18',' MMP7 log2FC=2.89 padj=3.4e-15',' COL1A1 log2FC=2.67 padj=8.7e-14',' TGFB1 log2FC=1.95 padj=2.1e-09','Top downregulated: AGER (-3.1), SFTPC (-2.8), HOPX (-2.3)'],
|
| 1375 |
+
reward: { validity: 0.95, ordering: 1.00, info_gain: 0.78, efficiency: 0.80, novelty: 0.88, penalty: 0.0 },
|
| 1376 |
+
total: 0.82,
|
| 1377 |
+
discovery: { title: 'SPP1 strongly upregulated in IPF', detail: 'log2FC=3.42, padj=1.2e-18', color: 'var(--pink)', bg: 'rgba(244,114,182,0.10)' },
|
| 1378 |
+
},
|
| 1379 |
+
{
|
| 1380 |
+
action: 'pathway_enrichment', params: 'tool="gseapy", gene_sets="KEGG,Reactome"', category: 'comp',
|
| 1381 |
+
budget: 66600, budgetPct: 66.6, time: 136, timePct: 75.6,
|
| 1382 |
+
output: ['Top enriched pathways (IPF vs Control):',' ECM-receptor interaction padj=4.2e-12',' TGF-beta signaling padj=1.8e-09',' PI3K-Akt signaling padj=3.1e-07',' Focal adhesion padj=8.9e-07','SPP1 participates in 3/4 top pathways'],
|
| 1383 |
+
reward: { validity: 0.90, ordering: 1.00, info_gain: 0.60, efficiency: 0.85, novelty: 0.75, penalty: 0.0 },
|
| 1384 |
+
total: 0.58,
|
| 1385 |
+
discovery: { title: 'SPP1 in ECM/TGF-beta/PI3K pathways', detail: 'Core fibrosis signaling axis confirmed', color: 'var(--purple)', bg: 'rgba(167,139,250,0.10)' },
|
| 1386 |
+
},
|
| 1387 |
+
{
|
| 1388 |
+
action: 'marker_selection', params: 'candidates=["SPP1","MMP7","COL1A1"]', category: 'comp',
|
| 1389 |
+
budget: 66200, budgetPct: 66.2, time: 135, timePct: 75.0,
|
| 1390 |
+
output: ['Marker ranking by discriminative power:',' 1. SPP1 - AUROC: 0.94, specificity: 0.89',' 2. MMP7 - AUROC: 0.87, specificity: 0.82',' 3. COL1A1 - AUROC: 0.81, specificity: 0.76','SPP1 selected as primary biomarker candidate'],
|
| 1391 |
+
reward: { validity: 0.90, ordering: 1.00, info_gain: 0.50, efficiency: 0.88, novelty: 0.70, penalty: 0.0 },
|
| 1392 |
+
total: 0.55,
|
| 1393 |
+
},
|
| 1394 |
+
{
|
| 1395 |
+
action: 'validate_marker', params: 'gene="SPP1", method="cross_validation"', category: 'comp',
|
| 1396 |
+
budget: 65200, budgetPct: 65.2, time: 130, timePct: 72.2,
|
| 1397 |
+
output: ['SPP1 Biomarker Validation Report:',' 5-fold CV AUROC: 0.91 (+/- 0.03)',' Sensitivity: 0.88',' Specificity: 0.87',' Positive LR: 6.77',' Expression in Aberrant Basaloid: 94.2% of cells',' Status: VALIDATED as IPF biomarker'],
|
| 1398 |
+
reward: { validity: 0.95, ordering: 1.00, info_gain: 0.72, efficiency: 0.82, novelty: 0.85, penalty: 0.0 },
|
| 1399 |
+
total: 0.76,
|
| 1400 |
+
discovery: { title: 'SPP1 validated as IPF biomarker', detail: 'AUROC=0.91, specificity=0.87', color: 'var(--green)', bg: 'var(--green-dim)' },
|
| 1401 |
+
},
|
| 1402 |
+
{
|
| 1403 |
+
action: 'synthesize_conclusion', params: 'confidence=0.85', category: 'meta',
|
| 1404 |
+
budget: 65000, budgetPct: 65.0, time: 129, timePct: 71.7,
|
| 1405 |
+
output: ['CONCLUSION (confidence: 0.85):','','SPP1 is a validated biomarker for IPF with strong','discriminative power (AUROC=0.91). It is upregulated','3.42-fold in IPF lungs, concentrated in Aberrant Basaloid','cells (94.2%), and participates in ECM-receptor, TGF-beta,','and PI3K-Akt signaling pathways.','','Literature match: 4/5 expected findings confirmed','Calibration: Well-calibrated (no overconfidence penalty)'],
|
| 1406 |
+
reward: { validity: 1.00, ordering: 1.00, info_gain: 0.40, efficiency: 0.90, novelty: 0.50, penalty: 0.0 },
|
| 1407 |
+
total: 0.91, terminal: true,
|
| 1408 |
+
},
|
| 1409 |
+
];
|
| 1410 |
+
|
| 1411 |
+
// State
|
| 1412 |
+
let running = false;
|
| 1413 |
+
let cumReward = 0;
|
| 1414 |
+
|
| 1415 |
+
// DOM refs
|
| 1416 |
+
const terminalEl = document.getElementById('terminal');
|
| 1417 |
+
const statusDot = document.getElementById('statusDot');
|
| 1418 |
+
const statusText = document.getElementById('statusText');
|
| 1419 |
+
const runBtn = document.getElementById('runBtn');
|
| 1420 |
+
const labActionLabel = document.getElementById('labActionLabel');
|
| 1421 |
+
|
| 1422 |
+
// Helpers
|
| 1423 |
+
function addLine(html) {
|
| 1424 |
+
const div = document.createElement('div');
|
| 1425 |
+
div.className = 't-line';
|
| 1426 |
+
div.innerHTML = html || ' ';
|
| 1427 |
+
terminalEl.appendChild(div);
|
| 1428 |
+
terminalEl.scrollTop = terminalEl.scrollHeight;
|
| 1429 |
+
}
|
| 1430 |
+
|
| 1431 |
+
function setGauge(id, value, pct, color) {
|
| 1432 |
+
document.getElementById(id + 'Val').textContent = value;
|
| 1433 |
+
const fill = document.getElementById(id + 'Fill');
|
| 1434 |
+
fill.style.width = pct + '%';
|
| 1435 |
+
if (color) fill.style.background = color;
|
| 1436 |
+
}
|
| 1437 |
+
|
| 1438 |
+
function setRewardBars(r) {
|
| 1439 |
+
for (const key of ['validity','ordering','info_gain','efficiency','novelty','penalty']) {
|
| 1440 |
+
const el = document.getElementById('rw-' + key);
|
| 1441 |
+
el.style.width = (r[key] * 100) + '%';
|
| 1442 |
+
el.textContent = r[key] > 0.01 ? r[key].toFixed(2) : '';
|
| 1443 |
+
}
|
| 1444 |
+
}
|
| 1445 |
+
|
| 1446 |
+
function clearRewardBars() {
|
| 1447 |
+
for (const key of ['validity','ordering','info_gain','efficiency','novelty','penalty']) {
|
| 1448 |
+
const el = document.getElementById('rw-' + key);
|
| 1449 |
+
el.style.width = '0%';
|
| 1450 |
+
el.textContent = '';
|
| 1451 |
+
}
|
| 1452 |
+
}
|
| 1453 |
+
|
| 1454 |
+
function addPipeStep(step, index) {
|
| 1455 |
+
const el = document.createElement('div');
|
| 1456 |
+
el.className = 'pipe-step';
|
| 1457 |
+
el.id = 'pipe-' + index;
|
| 1458 |
+
const catColor = step.category === 'wet' ? 'var(--green)' : step.category === 'comp' ? 'var(--accent)' : 'var(--pink)';
|
| 1459 |
+
el.innerHTML = `<div class="step-icon" style="color:${catColor};border-color:${catColor};">${index + 1}</div><span>${step.action}</span>`;
|
| 1460 |
+
document.getElementById('pipelineSteps').appendChild(el);
|
| 1461 |
+
requestAnimationFrame(() => el.classList.add('visible'));
|
| 1462 |
+
return el;
|
| 1463 |
+
}
|
| 1464 |
+
|
| 1465 |
+
function addDiscovery(d) {
|
| 1466 |
+
const c = document.getElementById('discoveries');
|
| 1467 |
+
if (c.querySelector('.empty-state')) c.innerHTML = '';
|
| 1468 |
+
const el = document.createElement('div');
|
| 1469 |
+
el.className = 'discovery';
|
| 1470 |
+
el.innerHTML = `<div class="disc-icon" style="background:${d.bg};color:${d.color};">◆</div><div class="disc-body"><div class="disc-title">${d.title}</div><div class="disc-detail">${d.detail}</div></div>`;
|
| 1471 |
+
c.appendChild(el);
|
| 1472 |
+
requestAnimationFrame(() => el.classList.add('visible'));
|
| 1473 |
+
}
|
| 1474 |
+
|
| 1475 |
+
function addRewardHistory(step, index) {
|
| 1476 |
+
const c = document.getElementById('rewardHistory');
|
| 1477 |
+
if (c.querySelector('.empty-state')) c.innerHTML = '';
|
| 1478 |
+
const el = document.createElement('div');
|
| 1479 |
+
el.className = 'step-reward-mini';
|
| 1480 |
+
el.innerHTML = `<span class="srm-name">${index + 1}. ${step.action}</span><span class="srm-val ${step.total >= 0 ? 'pos' : 'neg'}">${step.total >= 0 ? '+' : ''}${step.total.toFixed(2)}</span>`;
|
| 1481 |
+
c.appendChild(el);
|
| 1482 |
+
requestAnimationFrame(() => el.classList.add('visible'));
|
| 1483 |
+
}
|
| 1484 |
+
|
| 1485 |
+
function selectScenario(el) {
|
| 1486 |
+
if (running) return;
|
| 1487 |
+
document.querySelectorAll('.scenario-opt').forEach(e => e.classList.remove('active'));
|
| 1488 |
+
el.classList.add('active');
|
| 1489 |
+
}
|
| 1490 |
+
|
| 1491 |
+
function wait(ms) { return new Promise(r => setTimeout(r, ms)); }
|
| 1492 |
+
|
| 1493 |
+
// ---- Run ----
|
| 1494 |
+
async function startDemo() {
|
| 1495 |
+
if (running) return;
|
| 1496 |
+
running = true;
|
| 1497 |
+
runBtn.disabled = true;
|
| 1498 |
+
runBtn.textContent = 'Running...';
|
| 1499 |
+
statusDot.classList.add('live');
|
| 1500 |
+
statusText.textContent = 'Running';
|
| 1501 |
+
terminalEl.innerHTML = '';
|
| 1502 |
+
cumReward = 0;
|
| 1503 |
+
document.getElementById('pipelineSteps').innerHTML = '';
|
| 1504 |
+
document.getElementById('discoveries').innerHTML = '<div class="empty-state">No discoveries yet</div>';
|
| 1505 |
+
document.getElementById('rewardHistory').innerHTML = '<div class="empty-state">No steps yet</div>';
|
| 1506 |
+
document.getElementById('violations').innerHTML = '<div class="empty-state">No violations</div>';
|
| 1507 |
+
clearRewardBars();
|
| 1508 |
+
document.getElementById('cumReward').textContent = '0.00';
|
| 1509 |
+
document.getElementById('stepRewardLabel').textContent = '--';
|
| 1510 |
+
initAgent();
|
| 1511 |
+
|
| 1512 |
+
addLine('<span class="t-label">[BioEnv]</span> <span class="t-dim">Initializing environment...</span>');
|
| 1513 |
+
await wait(500);
|
| 1514 |
+
addLine('<span class="t-label">[BioEnv]</span> Scenario: <span class="t-str">biomarker_validation_lung</span> (Hard)');
|
| 1515 |
+
await wait(200);
|
| 1516 |
+
addLine('<span class="t-label">[BioEnv]</span> Organism: <span class="t-str">Homo sapiens</span> | Tissue: <span class="t-str">Lung</span>');
|
| 1517 |
+
await wait(200);
|
| 1518 |
+
addLine('<span class="t-label">[BioEnv]</span> Budget: <span class="t-num">$100,000</span> | Time: <span class="t-num">180 days</span> | Max steps: <span class="t-num">30</span>');
|
| 1519 |
+
await wait(200);
|
| 1520 |
+
addLine('<span class="t-label">[BioEnv]</span> Task: Validate <span class="t-kw">SPP1</span> as biomarker for idiopathic pulmonary fibrosis');
|
| 1521 |
+
await wait(400);
|
| 1522 |
+
addLine('');
|
| 1523 |
+
|
| 1524 |
+
for (let i = 0; i < EPISODE.length; i++) {
|
| 1525 |
+
await runStep(i);
|
| 1526 |
+
await wait(500);
|
| 1527 |
+
}
|
| 1528 |
+
|
| 1529 |
+
// Done
|
| 1530 |
+
moveAgentTo('idle');
|
| 1531 |
+
labActionLabel.classList.remove('visible');
|
| 1532 |
+
addLine('');
|
| 1533 |
+
addLine('<span class="t-label">[BioEnv]</span> <span class="t-ok">Episode complete!</span>');
|
| 1534 |
+
addLine('<span class="t-label">[BioEnv]</span> Total reward: <span class="t-ok">+' + cumReward.toFixed(2) + '</span> | Steps: <span class="t-num">' + EPISODE.length + '</span> | Budget remaining: <span class="t-num">$65,000</span>');
|
| 1535 |
+
addLine('<span class="t-label">[BioEnv]</span> Literature match: <span class="t-ok">4/5 expected findings confirmed</span>');
|
| 1536 |
+
addLine('<span class="t-label">[BioEnv]</span> Calibration: <span class="t-ok">Well-calibrated</span> (no overconfidence penalty)');
|
| 1537 |
+
|
| 1538 |
+
statusDot.classList.remove('live');
|
| 1539 |
+
statusText.textContent = 'Complete';
|
| 1540 |
+
runBtn.textContent = 'Run Episode';
|
| 1541 |
+
runBtn.disabled = false;
|
| 1542 |
+
running = false;
|
| 1543 |
+
}
|
| 1544 |
+
|
| 1545 |
+
async function runStep(i) {
|
| 1546 |
+
const step = EPISODE[i];
|
| 1547 |
+
const station = ACTION_STATION[step.action] || 'computer';
|
| 1548 |
+
|
| 1549 |
+
// Move agent in lab
|
| 1550 |
+
moveAgentTo(station);
|
| 1551 |
+
labActionLabel.textContent = step.action + '()';
|
| 1552 |
+
labActionLabel.classList.add('visible');
|
| 1553 |
+
await wait(800); // wait for agent to travel
|
| 1554 |
+
|
| 1555 |
+
// Start working animation
|
| 1556 |
+
setAgentWorking(step.action);
|
| 1557 |
+
spawnParticles(agent.targetX, agent.targetY, STATIONS[station].color);
|
| 1558 |
+
|
| 1559 |
+
// Pipeline sidebar
|
| 1560 |
+
const pipeEl = addPipeStep(step, i);
|
| 1561 |
+
if (i > 0) {
|
| 1562 |
+
const prev = document.getElementById('pipe-' + (i - 1));
|
| 1563 |
+
prev.classList.remove('active');
|
| 1564 |
+
prev.classList.add('done');
|
| 1565 |
+
prev.querySelector('.step-icon').innerHTML = '✓';
|
| 1566 |
+
}
|
| 1567 |
+
pipeEl.classList.add('active');
|
| 1568 |
+
|
| 1569 |
+
// Gauges
|
| 1570 |
+
setGauge('budget', '$' + step.budget.toLocaleString(), step.budgetPct,
|
| 1571 |
+
step.budgetPct > 50 ? 'var(--green)' : step.budgetPct > 25 ? 'var(--amber)' : 'var(--red)');
|
| 1572 |
+
setGauge('time', step.time + ' / 180 days', step.timePct, 'var(--cyan)');
|
| 1573 |
+
setGauge('step', (i + 1) + ' / 30', ((i + 1) / 30 * 100), 'var(--accent)');
|
| 1574 |
+
|
| 1575 |
+
// Terminal output
|
| 1576 |
+
const catTag = step.category === 'wet' ? '<span class="t-ok">WET</span>'
|
| 1577 |
+
: step.category === 'comp' ? '<span class="t-label">CMP</span>'
|
| 1578 |
+
: '<span class="t-kw">META</span>';
|
| 1579 |
+
addLine(`<span class="t-dim">Step ${i + 1}</span> ${catTag} <span class="t-fn">${step.action}</span>(<span class="t-str">${step.params}</span>)`);
|
| 1580 |
+
await wait(300);
|
| 1581 |
+
|
| 1582 |
+
for (const line of step.output) {
|
| 1583 |
+
addLine(' <span class="t-sub">' + line + '</span>');
|
| 1584 |
+
await wait(80);
|
| 1585 |
+
}
|
| 1586 |
+
|
| 1587 |
+
// Reward
|
| 1588 |
+
cumReward += step.total;
|
| 1589 |
+
document.getElementById('stepRewardLabel').textContent = 'Step ' + (i + 1) + ': ' + step.action;
|
| 1590 |
+
setRewardBars(step.reward);
|
| 1591 |
+
document.getElementById('cumReward').textContent = cumReward.toFixed(2);
|
| 1592 |
+
addRewardHistory(step, i);
|
| 1593 |
+
|
| 1594 |
+
const rewardStr = step.total >= 0
|
| 1595 |
+
? '<span class="t-ok">+' + step.total.toFixed(2) + '</span>'
|
| 1596 |
+
: '<span class="t-err">' + step.total.toFixed(2) + '</span>';
|
| 1597 |
+
addLine(` <span class="t-dim">reward: ${rewardStr} <span class="t-dim">(cumulative: ${cumReward.toFixed(2)})</span></span>`);
|
| 1598 |
+
addLine('');
|
| 1599 |
+
|
| 1600 |
+
if (step.discovery) addDiscovery(step.discovery);
|
| 1601 |
+
|
| 1602 |
+
// Done working
|
| 1603 |
+
agent.working = false;
|
| 1604 |
+
spawnParticles(agent.targetX, agent.targetY, '#34d399', 6);
|
| 1605 |
+
|
| 1606 |
+
if (step.terminal) {
|
| 1607 |
+
pipeEl.classList.remove('active');
|
| 1608 |
+
pipeEl.classList.add('done');
|
| 1609 |
+
pipeEl.querySelector('.step-icon').innerHTML = '✓';
|
| 1610 |
+
}
|
| 1611 |
+
}
|
| 1612 |
+
|
| 1613 |
+
function resetDemo() {
|
| 1614 |
+
if (running) return;
|
| 1615 |
+
terminalEl.innerHTML = '';
|
| 1616 |
+
cumReward = 0;
|
| 1617 |
+
document.getElementById('pipelineSteps').innerHTML = '';
|
| 1618 |
+
document.getElementById('discoveries').innerHTML = '<div class="empty-state">No discoveries yet</div>';
|
| 1619 |
+
document.getElementById('rewardHistory').innerHTML = '<div class="empty-state">No steps yet</div>';
|
| 1620 |
+
document.getElementById('violations').innerHTML = '<div class="empty-state">No violations</div>';
|
| 1621 |
+
clearRewardBars();
|
| 1622 |
+
document.getElementById('cumReward').textContent = '0.00';
|
| 1623 |
+
document.getElementById('stepRewardLabel').textContent = '--';
|
| 1624 |
+
setGauge('budget', '$100,000', 100, 'var(--green)');
|
| 1625 |
+
setGauge('time', '180 / 180 days', 100, 'var(--cyan)');
|
| 1626 |
+
setGauge('step', '0 / 30', 0, 'var(--accent)');
|
| 1627 |
+
statusDot.classList.remove('live');
|
| 1628 |
+
statusText.textContent = 'Ready';
|
| 1629 |
+
labActionLabel.classList.remove('visible');
|
| 1630 |
+
initAgent();
|
| 1631 |
+
addLine('<span class="t-dim">Environment reset. Click "Run Episode" to start.</span>');
|
| 1632 |
+
}
|
| 1633 |
+
|
| 1634 |
+
// Init
|
| 1635 |
+
addLine('<span class="t-dim">BioEnv v1.0 | biomarker_validation_lung</span>');
|
| 1636 |
+
addLine('<span class="t-dim">Click "Run Episode" to start the demo.</span>');
|
| 1637 |
+
</script>
|
| 1638 |
+
</body>
|
| 1639 |
+
</html>
|
eval_compare.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Compare base vs trained model on the same prompts."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
import random
|
| 8 |
+
from typing import Dict, List
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 12 |
+
|
| 13 |
+
from training_script import (
|
| 14 |
+
SYSTEM_PROMPT,
|
| 15 |
+
OpenEnvReward,
|
| 16 |
+
build_prompt_examples,
|
| 17 |
+
completion_to_text,
|
| 18 |
+
parse_action_completion,
|
| 19 |
+
selected_scenarios,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def generate_completions(
|
| 24 |
+
model,
|
| 25 |
+
tokenizer,
|
| 26 |
+
prompts: List[str],
|
| 27 |
+
max_new_tokens: int = 220,
|
| 28 |
+
) -> List[str]:
|
| 29 |
+
completions = []
|
| 30 |
+
for prompt in prompts:
|
| 31 |
+
messages = [
|
| 32 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 33 |
+
{"role": "user", "content": prompt},
|
| 34 |
+
]
|
| 35 |
+
input_text = tokenizer.apply_chat_template(
|
| 36 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 37 |
+
)
|
| 38 |
+
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
|
| 39 |
+
with torch.no_grad():
|
| 40 |
+
output = model.generate(
|
| 41 |
+
**inputs,
|
| 42 |
+
max_new_tokens=max_new_tokens,
|
| 43 |
+
do_sample=True,
|
| 44 |
+
temperature=0.7,
|
| 45 |
+
top_p=0.9,
|
| 46 |
+
)
|
| 47 |
+
generated = output[0][inputs["input_ids"].shape[1]:]
|
| 48 |
+
completions.append(tokenizer.decode(generated, skip_special_tokens=True))
|
| 49 |
+
return completions
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def evaluate_model(
|
| 53 |
+
model,
|
| 54 |
+
tokenizer,
|
| 55 |
+
examples: List[Dict[str, str]],
|
| 56 |
+
reward_fn: OpenEnvReward,
|
| 57 |
+
label: str,
|
| 58 |
+
) -> Dict[str, float]:
|
| 59 |
+
prompts = [ex["prompt"] for ex in examples]
|
| 60 |
+
completions = generate_completions(model, tokenizer, prompts)
|
| 61 |
+
|
| 62 |
+
rewards = []
|
| 63 |
+
valid_actions = 0
|
| 64 |
+
for comp, ex in zip(completions, examples):
|
| 65 |
+
reward = reward_fn(
|
| 66 |
+
completions=[comp],
|
| 67 |
+
scenario_name=[ex.get("scenario_name")],
|
| 68 |
+
history_actions=[ex.get("history_actions")],
|
| 69 |
+
)[0]
|
| 70 |
+
rewards.append(reward)
|
| 71 |
+
if parse_action_completion(comp) is not None:
|
| 72 |
+
valid_actions += 1
|
| 73 |
+
|
| 74 |
+
avg_reward = sum(rewards) / len(rewards) if rewards else 0
|
| 75 |
+
valid_pct = valid_actions / len(completions) * 100 if completions else 0
|
| 76 |
+
|
| 77 |
+
print(f"\n{'='*50}")
|
| 78 |
+
print(f" {label}")
|
| 79 |
+
print(f"{'='*50}")
|
| 80 |
+
print(f" Samples: {len(completions)}")
|
| 81 |
+
print(f" Avg reward: {avg_reward:.4f}")
|
| 82 |
+
print(f" Min reward: {min(rewards):.4f}")
|
| 83 |
+
print(f" Max reward: {max(rewards):.4f}")
|
| 84 |
+
print(f" Valid actions: {valid_actions}/{len(completions)} ({valid_pct:.1f}%)")
|
| 85 |
+
print()
|
| 86 |
+
|
| 87 |
+
# Show a few example completions
|
| 88 |
+
for i, (comp, r) in enumerate(zip(completions[:3], rewards[:3])):
|
| 89 |
+
print(f" Example {i+1} (reward={r:.2f}):")
|
| 90 |
+
print(f" {comp[:200]}")
|
| 91 |
+
print()
|
| 92 |
+
|
| 93 |
+
return {"avg_reward": avg_reward, "valid_pct": valid_pct, "rewards": rewards}
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def main():
|
| 97 |
+
parser = argparse.ArgumentParser(description="Compare base vs trained model")
|
| 98 |
+
parser.add_argument("--base-model", default="Qwen/Qwen3.5-0.8B",
|
| 99 |
+
help="Base model ID from HuggingFace")
|
| 100 |
+
parser.add_argument("--trained-model", default="./grpo-output",
|
| 101 |
+
help="Path to trained model (local dir or HF repo)")
|
| 102 |
+
parser.add_argument("--num-samples", type=int, default=16,
|
| 103 |
+
help="Number of eval prompts")
|
| 104 |
+
parser.add_argument("--seed", type=int, default=42)
|
| 105 |
+
parser.add_argument("--trust-remote-code", action="store_true")
|
| 106 |
+
args = parser.parse_args()
|
| 107 |
+
|
| 108 |
+
random.seed(args.seed)
|
| 109 |
+
|
| 110 |
+
# Build eval prompts
|
| 111 |
+
scenarios = selected_scenarios(None)
|
| 112 |
+
examples = build_prompt_examples(
|
| 113 |
+
dataset_episodes=args.num_samples,
|
| 114 |
+
rollout_steps=1, # one prompt per episode
|
| 115 |
+
collection_policy="heuristic",
|
| 116 |
+
scenario_names=scenarios,
|
| 117 |
+
seed=args.seed,
|
| 118 |
+
domain_randomise=False,
|
| 119 |
+
)
|
| 120 |
+
print(f"Built {len(examples)} eval prompts across {len(scenarios)} scenarios")
|
| 121 |
+
|
| 122 |
+
reward_fn = OpenEnvReward(reward_backend="local", base_url="")
|
| 123 |
+
|
| 124 |
+
# Evaluate base model
|
| 125 |
+
print(f"\nLoading base model: {args.base_model}")
|
| 126 |
+
base_tokenizer = AutoTokenizer.from_pretrained(
|
| 127 |
+
args.base_model, trust_remote_code=args.trust_remote_code
|
| 128 |
+
)
|
| 129 |
+
if base_tokenizer.pad_token is None:
|
| 130 |
+
base_tokenizer.pad_token = base_tokenizer.eos_token
|
| 131 |
+
base_model = AutoModelForCausalLM.from_pretrained(
|
| 132 |
+
args.base_model,
|
| 133 |
+
trust_remote_code=args.trust_remote_code,
|
| 134 |
+
torch_dtype=torch.bfloat16,
|
| 135 |
+
device_map="auto",
|
| 136 |
+
)
|
| 137 |
+
base_results = evaluate_model(
|
| 138 |
+
base_model, base_tokenizer, examples, reward_fn, "BASE MODEL"
|
| 139 |
+
)
|
| 140 |
+
del base_model
|
| 141 |
+
torch.cuda.empty_cache()
|
| 142 |
+
|
| 143 |
+
# Evaluate trained model
|
| 144 |
+
print(f"\nLoading trained model: {args.trained_model}")
|
| 145 |
+
trained_tokenizer = AutoTokenizer.from_pretrained(
|
| 146 |
+
args.trained_model, trust_remote_code=args.trust_remote_code
|
| 147 |
+
)
|
| 148 |
+
if trained_tokenizer.pad_token is None:
|
| 149 |
+
trained_tokenizer.pad_token = trained_tokenizer.eos_token
|
| 150 |
+
trained_model = AutoModelForCausalLM.from_pretrained(
|
| 151 |
+
args.trained_model,
|
| 152 |
+
trust_remote_code=args.trust_remote_code,
|
| 153 |
+
torch_dtype=torch.bfloat16,
|
| 154 |
+
device_map="auto",
|
| 155 |
+
)
|
| 156 |
+
trained_results = evaluate_model(
|
| 157 |
+
trained_model, trained_tokenizer, examples, reward_fn, "TRAINED MODEL"
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
# Summary
|
| 161 |
+
delta = trained_results["avg_reward"] - base_results["avg_reward"]
|
| 162 |
+
print(f"{'='*50}")
|
| 163 |
+
print(f" COMPARISON SUMMARY")
|
| 164 |
+
print(f"{'='*50}")
|
| 165 |
+
print(f" Base avg reward: {base_results['avg_reward']:.4f}")
|
| 166 |
+
print(f" Trained avg reward: {trained_results['avg_reward']:.4f}")
|
| 167 |
+
print(f" Delta: {delta:+.4f}")
|
| 168 |
+
print(f" Base valid actions: {base_results['valid_pct']:.1f}%")
|
| 169 |
+
print(f" Trained valid: {trained_results['valid_pct']:.1f}%")
|
| 170 |
+
print()
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
if __name__ == "__main__":
|
| 174 |
+
main()
|
models.py
CHANGED
|
@@ -73,6 +73,1025 @@ META_ACTIONS = frozenset({
|
|
| 73 |
})
|
| 74 |
|
| 75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
class SubagentType(str, Enum):
|
| 77 |
WET_LAB_PLANNER = "wet_lab_planner"
|
| 78 |
COMPUTATIONAL_ANALYST = "computational_analyst"
|
|
@@ -96,29 +1115,61 @@ class ExperimentAction(Action):
|
|
| 96 |
"""
|
| 97 |
|
| 98 |
action_type: ActionType = Field(
|
| 99 |
-
...,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
)
|
| 101 |
input_targets: List[str] = Field(
|
| 102 |
default_factory=list,
|
| 103 |
-
description=
|
|
|
|
|
|
|
|
|
|
| 104 |
)
|
| 105 |
method: Optional[str] = Field(
|
| 106 |
-
None,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
)
|
| 108 |
parameters: Dict[str, Any] = Field(
|
| 109 |
-
default_factory=dict,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
)
|
| 111 |
expected_output_type: Optional[str] = Field(
|
| 112 |
-
None,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
)
|
| 114 |
justification: Optional[str] = Field(
|
| 115 |
-
None,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
)
|
| 117 |
invoked_subagent: Optional[SubagentType] = Field(
|
| 118 |
None, description="Sub-agent to delegate to, if any"
|
| 119 |
)
|
| 120 |
tool_call_spec: Optional[Dict[str, Any]] = Field(
|
| 121 |
-
None,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
)
|
| 123 |
confidence: float = Field(
|
| 124 |
0.5, ge=0.0, le=1.0, description="Agent confidence in this step"
|
|
@@ -216,14 +1267,22 @@ class TaskSpec(BaseModel):
|
|
| 216 |
organism: str = "human"
|
| 217 |
tissue: str = "blood"
|
| 218 |
conditions: List[str] = Field(default_factory=list)
|
| 219 |
-
available_assays: List[str] = Field(
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
budget_limit: float = 100_000.0
|
| 228 |
time_limit_days: float = 180.0
|
| 229 |
prior_observations: List[str] = Field(default_factory=list)
|
|
@@ -234,7 +1293,10 @@ class TaskSpec(BaseModel):
|
|
| 234 |
|
| 235 |
|
| 236 |
class ConclusionClaim(BaseModel):
|
| 237 |
-
claim: str
|
|
|
|
|
|
|
|
|
|
| 238 |
evidence_steps: List[int] = Field(default_factory=list)
|
| 239 |
confidence: float = Field(0.5, ge=0.0, le=1.0)
|
| 240 |
claim_type: str = "correlational"
|
|
@@ -254,9 +1316,26 @@ class ExperimentObservation(Observation):
|
|
| 254 |
task: TaskSpec = Field(default_factory=TaskSpec)
|
| 255 |
step_index: int = 0
|
| 256 |
pipeline_history: List[PipelineStepRecord] = Field(default_factory=list)
|
| 257 |
-
available_assays: List[str] = Field(
|
| 258 |
-
|
| 259 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
latest_output: Optional[IntermediateOutput] = None
|
| 261 |
all_outputs: List[IntermediateOutput] = Field(default_factory=list)
|
| 262 |
discovered_markers: List[str] = Field(default_factory=list)
|
|
@@ -266,3 +1345,313 @@ class ExperimentObservation(Observation):
|
|
| 266 |
conclusions: List[ConclusionClaim] = Field(default_factory=list)
|
| 267 |
rule_violations: List[str] = Field(default_factory=list)
|
| 268 |
step_reward_breakdown: Dict[str, float] = Field(default_factory=dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
})
|
| 74 |
|
| 75 |
|
| 76 |
+
# ── Tool, Assay & Modality Registries ──────────────────────────────────────
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class ToolCategory(str, Enum):
|
| 80 |
+
ALIGNMENT = "alignment"
|
| 81 |
+
PREPROCESSING = "preprocessing"
|
| 82 |
+
NORMALIZATION = "normalization"
|
| 83 |
+
DIMENSIONALITY_REDUCTION = "dimensionality_reduction"
|
| 84 |
+
CLUSTERING = "clustering"
|
| 85 |
+
DIFFERENTIAL_EXPRESSION = "differential_expression"
|
| 86 |
+
TRAJECTORY = "trajectory"
|
| 87 |
+
GENE_REGULATORY_NETWORK = "gene_regulatory_network"
|
| 88 |
+
CELL_COMMUNICATION = "cell_communication"
|
| 89 |
+
SPATIAL = "spatial"
|
| 90 |
+
MULTIMODAL_INTEGRATION = "multimodal_integration"
|
| 91 |
+
GENE_SET_ANALYSIS = "gene_set_analysis"
|
| 92 |
+
VARIANT_CALLING = "variant_calling"
|
| 93 |
+
PEAK_CALLING = "peak_calling"
|
| 94 |
+
IMPUTATION = "imputation"
|
| 95 |
+
BATCH_CORRECTION = "batch_correction"
|
| 96 |
+
CELL_TYPE_ANNOTATION = "cell_type_annotation"
|
| 97 |
+
SIMULATION = "simulation"
|
| 98 |
+
VISUALIZATION = "visualization"
|
| 99 |
+
QUALITY_CONTROL = "quality_control"
|
| 100 |
+
PERTURBATION_ANALYSIS = "perturbation_analysis"
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class ToolSpec(BaseModel):
|
| 104 |
+
"""Registry entry describing a bioinformatics tool."""
|
| 105 |
+
|
| 106 |
+
name: str
|
| 107 |
+
category: ToolCategory
|
| 108 |
+
modalities: List[str] = Field(default_factory=list)
|
| 109 |
+
description: str = ""
|
| 110 |
+
input_types: List[str] = Field(default_factory=list)
|
| 111 |
+
output_types: List[str] = Field(default_factory=list)
|
| 112 |
+
typical_runtime_hours: float = 0.1
|
| 113 |
+
typical_cost_usd: float = 0.0
|
| 114 |
+
requires_gpu: bool = False
|
| 115 |
+
open_source: bool = True
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
TOOL_REGISTRY: Dict[str, ToolSpec] = {
|
| 119 |
+
# ── Alignment & quantification ──
|
| 120 |
+
"CellRanger": ToolSpec(
|
| 121 |
+
name="CellRanger",
|
| 122 |
+
category=ToolCategory.ALIGNMENT,
|
| 123 |
+
modalities=["scRNA-seq", "scATAC-seq", "CITE-seq", "scMultiome"],
|
| 124 |
+
description="10x Genomics pipeline for alignment, barcode demux, and counting",
|
| 125 |
+
input_types=["fastq"],
|
| 126 |
+
output_types=["count_matrix", "bam"],
|
| 127 |
+
typical_runtime_hours=4.0,
|
| 128 |
+
open_source=False,
|
| 129 |
+
),
|
| 130 |
+
"STARsolo": ToolSpec(
|
| 131 |
+
name="STARsolo",
|
| 132 |
+
category=ToolCategory.ALIGNMENT,
|
| 133 |
+
modalities=["scRNA-seq", "scATAC-seq"],
|
| 134 |
+
description="Drop-seq / 10x-compatible aligner built into STAR",
|
| 135 |
+
input_types=["fastq"],
|
| 136 |
+
output_types=["count_matrix", "bam"],
|
| 137 |
+
typical_runtime_hours=3.0,
|
| 138 |
+
),
|
| 139 |
+
"kallisto_bustools": ToolSpec(
|
| 140 |
+
name="kallisto_bustools",
|
| 141 |
+
category=ToolCategory.ALIGNMENT,
|
| 142 |
+
modalities=["scRNA-seq"],
|
| 143 |
+
description="Pseudoalignment-based lightweight quantification",
|
| 144 |
+
input_types=["fastq"],
|
| 145 |
+
output_types=["count_matrix"],
|
| 146 |
+
typical_runtime_hours=1.0,
|
| 147 |
+
),
|
| 148 |
+
"Salmon_alevin": ToolSpec(
|
| 149 |
+
name="Salmon_alevin",
|
| 150 |
+
category=ToolCategory.ALIGNMENT,
|
| 151 |
+
modalities=["scRNA-seq"],
|
| 152 |
+
description="Quasi-mapping quantification for single-cell RNA-seq",
|
| 153 |
+
input_types=["fastq"],
|
| 154 |
+
output_types=["count_matrix"],
|
| 155 |
+
typical_runtime_hours=1.5,
|
| 156 |
+
),
|
| 157 |
+
"spaceranger": ToolSpec(
|
| 158 |
+
name="spaceranger",
|
| 159 |
+
category=ToolCategory.ALIGNMENT,
|
| 160 |
+
modalities=["spatial_transcriptomics"],
|
| 161 |
+
description="10x Visium spatial alignment and quantification",
|
| 162 |
+
input_types=["fastq", "image"],
|
| 163 |
+
output_types=["count_matrix", "spatial_coords"],
|
| 164 |
+
typical_runtime_hours=3.0,
|
| 165 |
+
open_source=False,
|
| 166 |
+
),
|
| 167 |
+
# ── Preprocessing / analysis frameworks ──
|
| 168 |
+
"Scanpy": ToolSpec(
|
| 169 |
+
name="Scanpy",
|
| 170 |
+
category=ToolCategory.PREPROCESSING,
|
| 171 |
+
modalities=["scRNA-seq", "scATAC-seq", "spatial_transcriptomics"],
|
| 172 |
+
description="Python single-cell analysis framework",
|
| 173 |
+
input_types=["count_matrix", "h5ad"],
|
| 174 |
+
output_types=["h5ad", "embedding", "cluster_result"],
|
| 175 |
+
typical_runtime_hours=0.5,
|
| 176 |
+
),
|
| 177 |
+
"Seurat": ToolSpec(
|
| 178 |
+
name="Seurat",
|
| 179 |
+
category=ToolCategory.PREPROCESSING,
|
| 180 |
+
modalities=["scRNA-seq", "CITE-seq", "spatial_transcriptomics", "scATAC-seq"],
|
| 181 |
+
description="R single-cell analysis toolkit with multimodal support",
|
| 182 |
+
input_types=["count_matrix", "h5seurat"],
|
| 183 |
+
output_types=["h5seurat", "embedding", "cluster_result"],
|
| 184 |
+
typical_runtime_hours=0.5,
|
| 185 |
+
),
|
| 186 |
+
"Bioconductor_SingleCellExperiment": ToolSpec(
|
| 187 |
+
name="Bioconductor_SingleCellExperiment",
|
| 188 |
+
category=ToolCategory.PREPROCESSING,
|
| 189 |
+
modalities=["scRNA-seq"],
|
| 190 |
+
description="R/Bioconductor framework for single-cell experiments",
|
| 191 |
+
input_types=["count_matrix"],
|
| 192 |
+
output_types=["sce_object"],
|
| 193 |
+
typical_runtime_hours=0.3,
|
| 194 |
+
),
|
| 195 |
+
# ── Normalization ──
|
| 196 |
+
"scran": ToolSpec(
|
| 197 |
+
name="scran",
|
| 198 |
+
category=ToolCategory.NORMALIZATION,
|
| 199 |
+
modalities=["scRNA-seq"],
|
| 200 |
+
description="Pool-based size-factor normalization",
|
| 201 |
+
input_types=["count_matrix"],
|
| 202 |
+
output_types=["normalized_matrix"],
|
| 203 |
+
),
|
| 204 |
+
"sctransform": ToolSpec(
|
| 205 |
+
name="sctransform",
|
| 206 |
+
category=ToolCategory.NORMALIZATION,
|
| 207 |
+
modalities=["scRNA-seq"],
|
| 208 |
+
description="Variance-stabilizing transformation via regularized NB regression",
|
| 209 |
+
input_types=["count_matrix"],
|
| 210 |
+
output_types=["normalized_matrix"],
|
| 211 |
+
),
|
| 212 |
+
# ── Dimensionality reduction ──
|
| 213 |
+
"scVI": ToolSpec(
|
| 214 |
+
name="scVI",
|
| 215 |
+
category=ToolCategory.DIMENSIONALITY_REDUCTION,
|
| 216 |
+
modalities=["scRNA-seq", "CITE-seq", "scATAC-seq"],
|
| 217 |
+
description="Deep generative model for scRNA-seq (variational inference)",
|
| 218 |
+
input_types=["count_matrix"],
|
| 219 |
+
output_types=["latent_embedding"],
|
| 220 |
+
requires_gpu=True,
|
| 221 |
+
),
|
| 222 |
+
"UMAP": ToolSpec(
|
| 223 |
+
name="UMAP",
|
| 224 |
+
category=ToolCategory.DIMENSIONALITY_REDUCTION,
|
| 225 |
+
modalities=["scRNA-seq", "scATAC-seq", "CITE-seq", "spatial_transcriptomics"],
|
| 226 |
+
description="Uniform manifold approximation for 2D/3D visualization",
|
| 227 |
+
input_types=["pca_embedding", "latent_embedding"],
|
| 228 |
+
output_types=["2d_embedding"],
|
| 229 |
+
),
|
| 230 |
+
# ── Clustering ──
|
| 231 |
+
"Leiden": ToolSpec(
|
| 232 |
+
name="Leiden",
|
| 233 |
+
category=ToolCategory.CLUSTERING,
|
| 234 |
+
modalities=["scRNA-seq", "scATAC-seq", "CITE-seq"],
|
| 235 |
+
description="Community detection via the Leiden algorithm",
|
| 236 |
+
input_types=["knn_graph"],
|
| 237 |
+
output_types=["cluster_result"],
|
| 238 |
+
),
|
| 239 |
+
"Louvain": ToolSpec(
|
| 240 |
+
name="Louvain",
|
| 241 |
+
category=ToolCategory.CLUSTERING,
|
| 242 |
+
modalities=["scRNA-seq", "scATAC-seq"],
|
| 243 |
+
description="Community detection via Louvain modularity optimization",
|
| 244 |
+
input_types=["knn_graph"],
|
| 245 |
+
output_types=["cluster_result"],
|
| 246 |
+
),
|
| 247 |
+
# ── Differential expression ──
|
| 248 |
+
"DESeq2": ToolSpec(
|
| 249 |
+
name="DESeq2",
|
| 250 |
+
category=ToolCategory.DIFFERENTIAL_EXPRESSION,
|
| 251 |
+
modalities=["bulk_rna_seq", "scRNA-seq"],
|
| 252 |
+
description="Negative binomial GLM-based differential expression",
|
| 253 |
+
input_types=["count_matrix"],
|
| 254 |
+
output_types=["de_result"],
|
| 255 |
+
),
|
| 256 |
+
"MAST": ToolSpec(
|
| 257 |
+
name="MAST",
|
| 258 |
+
category=ToolCategory.DIFFERENTIAL_EXPRESSION,
|
| 259 |
+
modalities=["scRNA-seq"],
|
| 260 |
+
description="Two-part hurdle model for scRNA-seq DE testing",
|
| 261 |
+
input_types=["count_matrix"],
|
| 262 |
+
output_types=["de_result"],
|
| 263 |
+
),
|
| 264 |
+
"edgeR": ToolSpec(
|
| 265 |
+
name="edgeR",
|
| 266 |
+
category=ToolCategory.DIFFERENTIAL_EXPRESSION,
|
| 267 |
+
modalities=["bulk_rna_seq", "scRNA-seq"],
|
| 268 |
+
description="Empirical Bayes quasi-likelihood DE testing",
|
| 269 |
+
input_types=["count_matrix"],
|
| 270 |
+
output_types=["de_result"],
|
| 271 |
+
),
|
| 272 |
+
"Wilcoxon": ToolSpec(
|
| 273 |
+
name="Wilcoxon",
|
| 274 |
+
category=ToolCategory.DIFFERENTIAL_EXPRESSION,
|
| 275 |
+
modalities=["scRNA-seq"],
|
| 276 |
+
description="Rank-sum test for marker gene detection",
|
| 277 |
+
input_types=["count_matrix"],
|
| 278 |
+
output_types=["de_result"],
|
| 279 |
+
),
|
| 280 |
+
# ── Trajectory & RNA velocity ──
|
| 281 |
+
"Monocle3": ToolSpec(
|
| 282 |
+
name="Monocle3",
|
| 283 |
+
category=ToolCategory.TRAJECTORY,
|
| 284 |
+
modalities=["scRNA-seq"],
|
| 285 |
+
description="Reversed graph embedding for pseudotime trajectories",
|
| 286 |
+
input_types=["count_matrix", "embedding"],
|
| 287 |
+
output_types=["trajectory_result", "pseudotime"],
|
| 288 |
+
),
|
| 289 |
+
"scVelo": ToolSpec(
|
| 290 |
+
name="scVelo",
|
| 291 |
+
category=ToolCategory.TRAJECTORY,
|
| 292 |
+
modalities=["scRNA-seq"],
|
| 293 |
+
description="RNA velocity estimation via spliced/unspliced dynamics",
|
| 294 |
+
input_types=["count_matrix"],
|
| 295 |
+
output_types=["velocity_result"],
|
| 296 |
+
),
|
| 297 |
+
"CellRank": ToolSpec(
|
| 298 |
+
name="CellRank",
|
| 299 |
+
category=ToolCategory.TRAJECTORY,
|
| 300 |
+
modalities=["scRNA-seq"],
|
| 301 |
+
description="Fate probability estimation combining velocity and transcriptomics",
|
| 302 |
+
input_types=["velocity_result", "count_matrix"],
|
| 303 |
+
output_types=["fate_probabilities"],
|
| 304 |
+
),
|
| 305 |
+
"Slingshot": ToolSpec(
|
| 306 |
+
name="Slingshot",
|
| 307 |
+
category=ToolCategory.TRAJECTORY,
|
| 308 |
+
modalities=["scRNA-seq"],
|
| 309 |
+
description="Minimum spanning tree-based trajectory inference",
|
| 310 |
+
input_types=["embedding", "cluster_result"],
|
| 311 |
+
output_types=["trajectory_result", "pseudotime"],
|
| 312 |
+
),
|
| 313 |
+
"PAGA": ToolSpec(
|
| 314 |
+
name="PAGA",
|
| 315 |
+
category=ToolCategory.TRAJECTORY,
|
| 316 |
+
modalities=["scRNA-seq"],
|
| 317 |
+
description="Partition-based graph abstraction for topology estimation",
|
| 318 |
+
input_types=["knn_graph", "cluster_result"],
|
| 319 |
+
output_types=["trajectory_result"],
|
| 320 |
+
),
|
| 321 |
+
# ── Gene regulatory networks ──
|
| 322 |
+
"SCENIC": ToolSpec(
|
| 323 |
+
name="SCENIC",
|
| 324 |
+
category=ToolCategory.GENE_REGULATORY_NETWORK,
|
| 325 |
+
modalities=["scRNA-seq"],
|
| 326 |
+
description="Single-cell regulatory network inference and clustering",
|
| 327 |
+
input_types=["count_matrix"],
|
| 328 |
+
output_types=["regulon_result", "network_result"],
|
| 329 |
+
typical_runtime_hours=6.0,
|
| 330 |
+
),
|
| 331 |
+
"CellOracle": ToolSpec(
|
| 332 |
+
name="CellOracle",
|
| 333 |
+
category=ToolCategory.GENE_REGULATORY_NETWORK,
|
| 334 |
+
modalities=["scRNA-seq", "scATAC-seq", "scMultiome"],
|
| 335 |
+
description="GRN-based in-silico perturbation prediction",
|
| 336 |
+
input_types=["count_matrix", "peak_matrix"],
|
| 337 |
+
output_types=["network_result", "perturbation_prediction"],
|
| 338 |
+
typical_runtime_hours=4.0,
|
| 339 |
+
),
|
| 340 |
+
# ── Cell-cell communication ──
|
| 341 |
+
"CellChat": ToolSpec(
|
| 342 |
+
name="CellChat",
|
| 343 |
+
category=ToolCategory.CELL_COMMUNICATION,
|
| 344 |
+
modalities=["scRNA-seq", "spatial_transcriptomics"],
|
| 345 |
+
description="Ligand-receptor interaction inference with communication patterns",
|
| 346 |
+
input_types=["count_matrix", "cluster_result"],
|
| 347 |
+
output_types=["communication_result"],
|
| 348 |
+
),
|
| 349 |
+
"NicheNet": ToolSpec(
|
| 350 |
+
name="NicheNet",
|
| 351 |
+
category=ToolCategory.CELL_COMMUNICATION,
|
| 352 |
+
modalities=["scRNA-seq"],
|
| 353 |
+
description="Ligand-target link prediction using prior knowledge",
|
| 354 |
+
input_types=["count_matrix", "de_result"],
|
| 355 |
+
output_types=["communication_result"],
|
| 356 |
+
),
|
| 357 |
+
"LIANA": ToolSpec(
|
| 358 |
+
name="LIANA",
|
| 359 |
+
category=ToolCategory.CELL_COMMUNICATION,
|
| 360 |
+
modalities=["scRNA-seq", "spatial_transcriptomics"],
|
| 361 |
+
description="Framework unifying multiple ligand-receptor methods",
|
| 362 |
+
input_types=["count_matrix", "cluster_result"],
|
| 363 |
+
output_types=["communication_result"],
|
| 364 |
+
),
|
| 365 |
+
# ── Spatial analysis ──
|
| 366 |
+
"squidpy": ToolSpec(
|
| 367 |
+
name="squidpy",
|
| 368 |
+
category=ToolCategory.SPATIAL,
|
| 369 |
+
modalities=["spatial_transcriptomics"],
|
| 370 |
+
description="Spatial omics analysis (neighborhood, co-occurrence, image features)",
|
| 371 |
+
input_types=["count_matrix", "spatial_coords"],
|
| 372 |
+
output_types=["spatial_result"],
|
| 373 |
+
),
|
| 374 |
+
"cell2location": ToolSpec(
|
| 375 |
+
name="cell2location",
|
| 376 |
+
category=ToolCategory.SPATIAL,
|
| 377 |
+
modalities=["spatial_transcriptomics"],
|
| 378 |
+
description="Spatial deconvolution mapping cell types to tissue locations",
|
| 379 |
+
input_types=["count_matrix", "spatial_coords", "reference_h5ad"],
|
| 380 |
+
output_types=["deconvolution_result"],
|
| 381 |
+
requires_gpu=True,
|
| 382 |
+
),
|
| 383 |
+
"BANKSY": ToolSpec(
|
| 384 |
+
name="BANKSY",
|
| 385 |
+
category=ToolCategory.SPATIAL,
|
| 386 |
+
modalities=["spatial_transcriptomics"],
|
| 387 |
+
description="Spatially-aware clustering combining cell and neighbor features",
|
| 388 |
+
input_types=["count_matrix", "spatial_coords"],
|
| 389 |
+
output_types=["cluster_result"],
|
| 390 |
+
),
|
| 391 |
+
# ── Multimodal integration ──
|
| 392 |
+
"Harmony": ToolSpec(
|
| 393 |
+
name="Harmony",
|
| 394 |
+
category=ToolCategory.BATCH_CORRECTION,
|
| 395 |
+
modalities=["scRNA-seq", "scATAC-seq", "CITE-seq"],
|
| 396 |
+
description="Fast iterative batch correction on PCA embeddings",
|
| 397 |
+
input_types=["pca_embedding"],
|
| 398 |
+
output_types=["corrected_embedding"],
|
| 399 |
+
),
|
| 400 |
+
"scanorama": ToolSpec(
|
| 401 |
+
name="scanorama",
|
| 402 |
+
category=ToolCategory.BATCH_CORRECTION,
|
| 403 |
+
modalities=["scRNA-seq"],
|
| 404 |
+
description="Panoramic stitching of scRNA-seq batches",
|
| 405 |
+
input_types=["count_matrix"],
|
| 406 |
+
output_types=["corrected_embedding", "corrected_matrix"],
|
| 407 |
+
),
|
| 408 |
+
"BBKNN": ToolSpec(
|
| 409 |
+
name="BBKNN",
|
| 410 |
+
category=ToolCategory.BATCH_CORRECTION,
|
| 411 |
+
modalities=["scRNA-seq"],
|
| 412 |
+
description="Batch-balanced KNN graph construction",
|
| 413 |
+
input_types=["pca_embedding"],
|
| 414 |
+
output_types=["knn_graph"],
|
| 415 |
+
),
|
| 416 |
+
"WNN": ToolSpec(
|
| 417 |
+
name="WNN",
|
| 418 |
+
category=ToolCategory.MULTIMODAL_INTEGRATION,
|
| 419 |
+
modalities=["CITE-seq", "scMultiome"],
|
| 420 |
+
description="Weighted nearest neighbors for multimodal integration (Seurat v4+)",
|
| 421 |
+
input_types=["rna_embedding", "protein_embedding"],
|
| 422 |
+
output_types=["multimodal_embedding"],
|
| 423 |
+
),
|
| 424 |
+
"MOFA+": ToolSpec(
|
| 425 |
+
name="MOFA+",
|
| 426 |
+
category=ToolCategory.MULTIMODAL_INTEGRATION,
|
| 427 |
+
modalities=["scMultiome", "CITE-seq"],
|
| 428 |
+
description="Multi-omics factor analysis for unsupervised integration",
|
| 429 |
+
input_types=["count_matrix", "peak_matrix"],
|
| 430 |
+
output_types=["factor_result"],
|
| 431 |
+
),
|
| 432 |
+
"ArchR": ToolSpec(
|
| 433 |
+
name="ArchR",
|
| 434 |
+
category=ToolCategory.PREPROCESSING,
|
| 435 |
+
modalities=["scATAC-seq", "scMultiome"],
|
| 436 |
+
description="Full-featured scATAC-seq analysis framework in R",
|
| 437 |
+
input_types=["fragments", "bam"],
|
| 438 |
+
output_types=["peak_matrix", "gene_activity_matrix"],
|
| 439 |
+
typical_runtime_hours=2.0,
|
| 440 |
+
),
|
| 441 |
+
"Signac": ToolSpec(
|
| 442 |
+
name="Signac",
|
| 443 |
+
category=ToolCategory.PREPROCESSING,
|
| 444 |
+
modalities=["scATAC-seq", "scMultiome"],
|
| 445 |
+
description="Seurat extension for chromatin accessibility analysis",
|
| 446 |
+
input_types=["fragments", "peak_matrix"],
|
| 447 |
+
output_types=["peak_matrix", "motif_result"],
|
| 448 |
+
),
|
| 449 |
+
"chromVAR": ToolSpec(
|
| 450 |
+
name="chromVAR",
|
| 451 |
+
category=ToolCategory.PEAK_CALLING,
|
| 452 |
+
modalities=["scATAC-seq", "scMultiome"],
|
| 453 |
+
description="TF motif accessibility deviation scoring",
|
| 454 |
+
input_types=["peak_matrix"],
|
| 455 |
+
output_types=["motif_deviation_scores"],
|
| 456 |
+
),
|
| 457 |
+
# ── Gene set / pathway analysis ──
|
| 458 |
+
"GSEA": ToolSpec(
|
| 459 |
+
name="GSEA",
|
| 460 |
+
category=ToolCategory.GENE_SET_ANALYSIS,
|
| 461 |
+
modalities=["bulk_rna_seq", "scRNA-seq"],
|
| 462 |
+
description="Gene Set Enrichment Analysis (preranked or phenotype-based)",
|
| 463 |
+
input_types=["de_result", "ranked_gene_list"],
|
| 464 |
+
output_types=["pathway_result"],
|
| 465 |
+
),
|
| 466 |
+
"clusterProfiler": ToolSpec(
|
| 467 |
+
name="clusterProfiler",
|
| 468 |
+
category=ToolCategory.GENE_SET_ANALYSIS,
|
| 469 |
+
modalities=["bulk_rna_seq", "scRNA-seq"],
|
| 470 |
+
description="ORA & GSEA with GO, KEGG, Reactome, and custom gene sets",
|
| 471 |
+
input_types=["de_result", "gene_list"],
|
| 472 |
+
output_types=["pathway_result"],
|
| 473 |
+
),
|
| 474 |
+
"decoupleR": ToolSpec(
|
| 475 |
+
name="decoupleR",
|
| 476 |
+
category=ToolCategory.GENE_SET_ANALYSIS,
|
| 477 |
+
modalities=["scRNA-seq", "bulk_rna_seq", "spatial_transcriptomics"],
|
| 478 |
+
description="Unified framework for functional activity inference (TF, pathway)",
|
| 479 |
+
input_types=["count_matrix", "de_result"],
|
| 480 |
+
output_types=["activity_scores"],
|
| 481 |
+
),
|
| 482 |
+
# ── Cell type annotation ──
|
| 483 |
+
"celltypist": ToolSpec(
|
| 484 |
+
name="celltypist",
|
| 485 |
+
category=ToolCategory.CELL_TYPE_ANNOTATION,
|
| 486 |
+
modalities=["scRNA-seq"],
|
| 487 |
+
description="Automated cell type classification with pre-trained models",
|
| 488 |
+
input_types=["count_matrix"],
|
| 489 |
+
output_types=["annotation_result"],
|
| 490 |
+
),
|
| 491 |
+
"SingleR": ToolSpec(
|
| 492 |
+
name="SingleR",
|
| 493 |
+
category=ToolCategory.CELL_TYPE_ANNOTATION,
|
| 494 |
+
modalities=["scRNA-seq"],
|
| 495 |
+
description="Reference-based cell type annotation using correlation",
|
| 496 |
+
input_types=["count_matrix", "reference_dataset"],
|
| 497 |
+
output_types=["annotation_result"],
|
| 498 |
+
),
|
| 499 |
+
"scArches": ToolSpec(
|
| 500 |
+
name="scArches",
|
| 501 |
+
category=ToolCategory.CELL_TYPE_ANNOTATION,
|
| 502 |
+
modalities=["scRNA-seq", "scATAC-seq", "CITE-seq"],
|
| 503 |
+
description="Reference mapping and label transfer via deep learning",
|
| 504 |
+
input_types=["count_matrix", "reference_model"],
|
| 505 |
+
output_types=["annotation_result", "latent_embedding"],
|
| 506 |
+
requires_gpu=True,
|
| 507 |
+
),
|
| 508 |
+
# ── Imputation ──
|
| 509 |
+
"MAGIC": ToolSpec(
|
| 510 |
+
name="MAGIC",
|
| 511 |
+
category=ToolCategory.IMPUTATION,
|
| 512 |
+
modalities=["scRNA-seq"],
|
| 513 |
+
description="Markov affinity-based graph imputation of dropout zeros",
|
| 514 |
+
input_types=["count_matrix"],
|
| 515 |
+
output_types=["imputed_matrix"],
|
| 516 |
+
),
|
| 517 |
+
# ── Perturbation analysis ──
|
| 518 |
+
"MILO": ToolSpec(
|
| 519 |
+
name="MILO",
|
| 520 |
+
category=ToolCategory.PERTURBATION_ANALYSIS,
|
| 521 |
+
modalities=["scRNA-seq"],
|
| 522 |
+
description="Differential abundance testing on KNN graph neighborhoods",
|
| 523 |
+
input_types=["count_matrix", "knn_graph"],
|
| 524 |
+
output_types=["da_result"],
|
| 525 |
+
),
|
| 526 |
+
"Mixscape": ToolSpec(
|
| 527 |
+
name="Mixscape",
|
| 528 |
+
category=ToolCategory.PERTURBATION_ANALYSIS,
|
| 529 |
+
modalities=["Perturb-seq", "CROP-seq"],
|
| 530 |
+
description="Seurat extension for CRISPR screen perturbation analysis",
|
| 531 |
+
input_types=["count_matrix", "guide_assignments"],
|
| 532 |
+
output_types=["perturbation_result"],
|
| 533 |
+
),
|
| 534 |
+
"MIMOSCA": ToolSpec(
|
| 535 |
+
name="MIMOSCA",
|
| 536 |
+
category=ToolCategory.PERTURBATION_ANALYSIS,
|
| 537 |
+
modalities=["Perturb-seq", "CROP-seq"],
|
| 538 |
+
description="Multi-input multi-output single-cell analysis for screens",
|
| 539 |
+
input_types=["count_matrix", "guide_assignments"],
|
| 540 |
+
output_types=["perturbation_result"],
|
| 541 |
+
),
|
| 542 |
+
# ── Quality control ──
|
| 543 |
+
"scrublet": ToolSpec(
|
| 544 |
+
name="scrublet",
|
| 545 |
+
category=ToolCategory.QUALITY_CONTROL,
|
| 546 |
+
modalities=["scRNA-seq"],
|
| 547 |
+
description="Computational doublet detection via synthetic doublets",
|
| 548 |
+
input_types=["count_matrix"],
|
| 549 |
+
output_types=["doublet_scores"],
|
| 550 |
+
),
|
| 551 |
+
"DoubletFinder": ToolSpec(
|
| 552 |
+
name="DoubletFinder",
|
| 553 |
+
category=ToolCategory.QUALITY_CONTROL,
|
| 554 |
+
modalities=["scRNA-seq"],
|
| 555 |
+
description="Artificial nearest-neighbor doublet detection",
|
| 556 |
+
input_types=["count_matrix"],
|
| 557 |
+
output_types=["doublet_scores"],
|
| 558 |
+
),
|
| 559 |
+
"SoupX": ToolSpec(
|
| 560 |
+
name="SoupX",
|
| 561 |
+
category=ToolCategory.QUALITY_CONTROL,
|
| 562 |
+
modalities=["scRNA-seq"],
|
| 563 |
+
description="Ambient RNA contamination estimation and removal",
|
| 564 |
+
input_types=["count_matrix", "raw_count_matrix"],
|
| 565 |
+
output_types=["corrected_matrix"],
|
| 566 |
+
),
|
| 567 |
+
"DecontX": ToolSpec(
|
| 568 |
+
name="DecontX",
|
| 569 |
+
category=ToolCategory.QUALITY_CONTROL,
|
| 570 |
+
modalities=["scRNA-seq"],
|
| 571 |
+
description="Bayesian ambient RNA decontamination",
|
| 572 |
+
input_types=["count_matrix"],
|
| 573 |
+
output_types=["corrected_matrix"],
|
| 574 |
+
),
|
| 575 |
+
# ── Simulation ──
|
| 576 |
+
"Splatter": ToolSpec(
|
| 577 |
+
name="Splatter",
|
| 578 |
+
category=ToolCategory.SIMULATION,
|
| 579 |
+
modalities=["scRNA-seq"],
|
| 580 |
+
description="Flexible scRNA-seq data simulation framework",
|
| 581 |
+
input_types=["simulation_params"],
|
| 582 |
+
output_types=["simulated_count_matrix"],
|
| 583 |
+
),
|
| 584 |
+
}
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
class Modality(str, Enum):
|
| 588 |
+
SCRNA_SEQ = "scRNA-seq"
|
| 589 |
+
SCATAC_SEQ = "scATAC-seq"
|
| 590 |
+
CITE_SEQ = "CITE-seq"
|
| 591 |
+
SPATIAL_TRANSCRIPTOMICS = "spatial_transcriptomics"
|
| 592 |
+
BULK_RNA_SEQ = "bulk_rna_seq"
|
| 593 |
+
SCRNA_MULTIOME = "scMultiome"
|
| 594 |
+
PERTURB_SEQ = "Perturb-seq"
|
| 595 |
+
CROP_SEQ = "CROP-seq"
|
| 596 |
+
SMART_SEQ2 = "Smart-seq2"
|
| 597 |
+
SLIDE_SEQ = "Slide-seq"
|
| 598 |
+
MERFISH = "MERFISH"
|
| 599 |
+
SEQFISH = "seqFISH"
|
| 600 |
+
PATCH_SEQ = "Patch-seq"
|
| 601 |
+
SHARE_SEQ = "SHARE-seq"
|
| 602 |
+
SNARE_SEQ = "SNARE-seq"
|
| 603 |
+
SC_HI_C = "scHi-C"
|
| 604 |
+
SCBS_SEQ = "scBS-seq"
|
| 605 |
+
SCNMT_SEQ = "scNMT-seq"
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
class ModalitySpec(BaseModel):
|
| 609 |
+
"""Registry entry for a single-cell or bulk assay modality."""
|
| 610 |
+
|
| 611 |
+
name: str
|
| 612 |
+
modality: Modality
|
| 613 |
+
measurement: str = ""
|
| 614 |
+
resolution: str = "single-cell"
|
| 615 |
+
multiplexable: bool = False
|
| 616 |
+
typical_cells: str = "1k-20k"
|
| 617 |
+
typical_cost_per_sample_usd: float = 5000.0
|
| 618 |
+
compatible_tools: List[str] = Field(default_factory=list)
|
| 619 |
+
description: str = ""
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
MODALITY_REGISTRY: Dict[str, ModalitySpec] = {
|
| 623 |
+
"scRNA-seq": ModalitySpec(
|
| 624 |
+
name="scRNA-seq",
|
| 625 |
+
modality=Modality.SCRNA_SEQ,
|
| 626 |
+
measurement="mRNA transcripts",
|
| 627 |
+
typical_cells="5k-20k",
|
| 628 |
+
typical_cost_per_sample_usd=5000.0,
|
| 629 |
+
compatible_tools=[
|
| 630 |
+
"CellRanger", "STARsolo", "kallisto_bustools", "Scanpy", "Seurat",
|
| 631 |
+
"scVI", "Leiden", "DESeq2", "MAST", "Monocle3", "scVelo", "SCENIC",
|
| 632 |
+
"CellChat", "GSEA", "celltypist", "scrublet",
|
| 633 |
+
],
|
| 634 |
+
description="Droplet-based single-cell RNA sequencing (e.g. 10x Chromium)",
|
| 635 |
+
),
|
| 636 |
+
"scATAC-seq": ModalitySpec(
|
| 637 |
+
name="scATAC-seq",
|
| 638 |
+
modality=Modality.SCATAC_SEQ,
|
| 639 |
+
measurement="open chromatin regions",
|
| 640 |
+
typical_cells="5k-15k",
|
| 641 |
+
typical_cost_per_sample_usd=6000.0,
|
| 642 |
+
compatible_tools=[
|
| 643 |
+
"CellRanger", "ArchR", "Signac", "chromVAR", "Scanpy", "Leiden",
|
| 644 |
+
],
|
| 645 |
+
description="Single-cell Assay for Transposase-Accessible Chromatin",
|
| 646 |
+
),
|
| 647 |
+
"CITE-seq": ModalitySpec(
|
| 648 |
+
name="CITE-seq",
|
| 649 |
+
modality=Modality.CITE_SEQ,
|
| 650 |
+
measurement="mRNA + surface proteins (ADT)",
|
| 651 |
+
multiplexable=True,
|
| 652 |
+
typical_cells="5k-20k",
|
| 653 |
+
typical_cost_per_sample_usd=8000.0,
|
| 654 |
+
compatible_tools=[
|
| 655 |
+
"CellRanger", "Seurat", "WNN", "MOFA+", "Scanpy", "Leiden",
|
| 656 |
+
],
|
| 657 |
+
description="Cellular Indexing of Transcriptomes and Epitopes by Sequencing",
|
| 658 |
+
),
|
| 659 |
+
"spatial_transcriptomics": ModalitySpec(
|
| 660 |
+
name="spatial_transcriptomics",
|
| 661 |
+
modality=Modality.SPATIAL_TRANSCRIPTOMICS,
|
| 662 |
+
measurement="spatially resolved transcripts",
|
| 663 |
+
resolution="spot (55µm) or subcellular",
|
| 664 |
+
typical_cells="1k-10k spots",
|
| 665 |
+
typical_cost_per_sample_usd=7000.0,
|
| 666 |
+
compatible_tools=[
|
| 667 |
+
"spaceranger", "squidpy", "cell2location", "BANKSY", "Scanpy", "Seurat",
|
| 668 |
+
],
|
| 669 |
+
description="Spatially resolved transcriptomics (Visium, MERFISH, Slide-seq, etc.)",
|
| 670 |
+
),
|
| 671 |
+
"bulk_rna_seq": ModalitySpec(
|
| 672 |
+
name="bulk_rna_seq",
|
| 673 |
+
modality=Modality.BULK_RNA_SEQ,
|
| 674 |
+
measurement="aggregate mRNA across cells",
|
| 675 |
+
resolution="bulk",
|
| 676 |
+
typical_cells="N/A",
|
| 677 |
+
typical_cost_per_sample_usd=500.0,
|
| 678 |
+
compatible_tools=["DESeq2", "edgeR", "GSEA", "clusterProfiler"],
|
| 679 |
+
description="Standard bulk RNA sequencing",
|
| 680 |
+
),
|
| 681 |
+
"scMultiome": ModalitySpec(
|
| 682 |
+
name="scMultiome",
|
| 683 |
+
modality=Modality.SCRNA_MULTIOME,
|
| 684 |
+
measurement="mRNA + open chromatin (joint)",
|
| 685 |
+
typical_cells="5k-15k",
|
| 686 |
+
typical_cost_per_sample_usd=10000.0,
|
| 687 |
+
compatible_tools=[
|
| 688 |
+
"CellRanger", "ArchR", "Signac", "Seurat", "MOFA+", "CellOracle",
|
| 689 |
+
],
|
| 690 |
+
description="10x Multiome (joint scRNA + scATAC from same cell)",
|
| 691 |
+
),
|
| 692 |
+
"Perturb-seq": ModalitySpec(
|
| 693 |
+
name="Perturb-seq",
|
| 694 |
+
modality=Modality.PERTURB_SEQ,
|
| 695 |
+
measurement="mRNA + CRISPR guide assignment",
|
| 696 |
+
multiplexable=True,
|
| 697 |
+
typical_cells="10k-100k",
|
| 698 |
+
typical_cost_per_sample_usd=15000.0,
|
| 699 |
+
compatible_tools=[
|
| 700 |
+
"CellRanger", "Scanpy", "Seurat", "Mixscape", "MIMOSCA",
|
| 701 |
+
],
|
| 702 |
+
description="Pooled CRISPR screens with single-cell RNA readout",
|
| 703 |
+
),
|
| 704 |
+
"CROP-seq": ModalitySpec(
|
| 705 |
+
name="CROP-seq",
|
| 706 |
+
modality=Modality.CROP_SEQ,
|
| 707 |
+
measurement="mRNA + CRISPR guide assignment",
|
| 708 |
+
multiplexable=True,
|
| 709 |
+
typical_cells="10k-50k",
|
| 710 |
+
typical_cost_per_sample_usd=12000.0,
|
| 711 |
+
compatible_tools=[
|
| 712 |
+
"CellRanger", "Scanpy", "Seurat", "Mixscape", "MIMOSCA",
|
| 713 |
+
],
|
| 714 |
+
description="CRISPR dropout screen with single-cell RNA readout",
|
| 715 |
+
),
|
| 716 |
+
"Smart-seq2": ModalitySpec(
|
| 717 |
+
name="Smart-seq2",
|
| 718 |
+
modality=Modality.SMART_SEQ2,
|
| 719 |
+
measurement="full-length mRNA transcripts",
|
| 720 |
+
typical_cells="100-1000",
|
| 721 |
+
typical_cost_per_sample_usd=10000.0,
|
| 722 |
+
compatible_tools=["Scanpy", "Seurat", "DESeq2", "MAST", "Monocle3"],
|
| 723 |
+
description="Plate-based full-length scRNA-seq with high sensitivity",
|
| 724 |
+
),
|
| 725 |
+
"MERFISH": ModalitySpec(
|
| 726 |
+
name="MERFISH",
|
| 727 |
+
modality=Modality.MERFISH,
|
| 728 |
+
measurement="in situ mRNA (imaging-based)",
|
| 729 |
+
resolution="subcellular",
|
| 730 |
+
typical_cells="10k-1M",
|
| 731 |
+
typical_cost_per_sample_usd=20000.0,
|
| 732 |
+
compatible_tools=["squidpy", "Scanpy", "BANKSY"],
|
| 733 |
+
description="Multiplexed Error-Robust FISH for spatial transcriptomics",
|
| 734 |
+
),
|
| 735 |
+
"Slide-seq": ModalitySpec(
|
| 736 |
+
name="Slide-seq",
|
| 737 |
+
modality=Modality.SLIDE_SEQ,
|
| 738 |
+
measurement="spatially resolved mRNA (bead array)",
|
| 739 |
+
resolution="10µm",
|
| 740 |
+
typical_cells="10k-50k beads",
|
| 741 |
+
typical_cost_per_sample_usd=8000.0,
|
| 742 |
+
compatible_tools=["squidpy", "cell2location", "Scanpy"],
|
| 743 |
+
description="Near-cellular spatial transcriptomics on bead arrays",
|
| 744 |
+
),
|
| 745 |
+
"Patch-seq": ModalitySpec(
|
| 746 |
+
name="Patch-seq",
|
| 747 |
+
modality=Modality.PATCH_SEQ,
|
| 748 |
+
measurement="mRNA + electrophysiology + morphology",
|
| 749 |
+
typical_cells="10-500",
|
| 750 |
+
typical_cost_per_sample_usd=50000.0,
|
| 751 |
+
compatible_tools=["Scanpy", "Seurat"],
|
| 752 |
+
description="Combined patch-clamp electrophysiology and scRNA-seq",
|
| 753 |
+
),
|
| 754 |
+
"scHi-C": ModalitySpec(
|
| 755 |
+
name="scHi-C",
|
| 756 |
+
modality=Modality.SC_HI_C,
|
| 757 |
+
measurement="3D chromatin contacts",
|
| 758 |
+
typical_cells="1k-10k",
|
| 759 |
+
typical_cost_per_sample_usd=15000.0,
|
| 760 |
+
compatible_tools=["Scanpy"],
|
| 761 |
+
description="Single-cell chromosome conformation capture",
|
| 762 |
+
),
|
| 763 |
+
"scBS-seq": ModalitySpec(
|
| 764 |
+
name="scBS-seq",
|
| 765 |
+
modality=Modality.SCBS_SEQ,
|
| 766 |
+
measurement="DNA methylation (CpG)",
|
| 767 |
+
typical_cells="100-5k",
|
| 768 |
+
typical_cost_per_sample_usd=12000.0,
|
| 769 |
+
compatible_tools=["Scanpy"],
|
| 770 |
+
description="Single-cell bisulfite sequencing for DNA methylation",
|
| 771 |
+
),
|
| 772 |
+
"scNMT-seq": ModalitySpec(
|
| 773 |
+
name="scNMT-seq",
|
| 774 |
+
modality=Modality.SCNMT_SEQ,
|
| 775 |
+
measurement="nucleosome + methylation + transcription (joint)",
|
| 776 |
+
typical_cells="100-1k",
|
| 777 |
+
typical_cost_per_sample_usd=25000.0,
|
| 778 |
+
compatible_tools=["MOFA+", "Scanpy"],
|
| 779 |
+
description="Joint single-cell nucleosome, methylation, and transcription",
|
| 780 |
+
),
|
| 781 |
+
}
|
| 782 |
+
|
| 783 |
+
|
| 784 |
+
class AssayCategory(str, Enum):
|
| 785 |
+
SEQUENCING = "sequencing"
|
| 786 |
+
IMAGING = "imaging"
|
| 787 |
+
PERTURBATION = "perturbation"
|
| 788 |
+
FUNCTIONAL = "functional"
|
| 789 |
+
EPIGENOMICS = "epigenomics"
|
| 790 |
+
PROTEOMICS = "proteomics"
|
| 791 |
+
METABOLOMICS = "metabolomics"
|
| 792 |
+
|
| 793 |
+
|
| 794 |
+
class AssaySpec(BaseModel):
|
| 795 |
+
"""Registry entry for a laboratory assay or protocol."""
|
| 796 |
+
|
| 797 |
+
name: str
|
| 798 |
+
category: AssayCategory
|
| 799 |
+
modalities: List[str] = Field(default_factory=list)
|
| 800 |
+
description: str = ""
|
| 801 |
+
typical_duration_days: float = 1.0
|
| 802 |
+
typical_cost_usd: float = 1000.0
|
| 803 |
+
requires_live_cells: bool = False
|
| 804 |
+
requires_fresh_tissue: bool = False
|
| 805 |
+
throughput: str = "medium"
|
| 806 |
+
outputs: List[str] = Field(default_factory=list)
|
| 807 |
+
|
| 808 |
+
|
| 809 |
+
ASSAY_REGISTRY: Dict[str, AssaySpec] = {
|
| 810 |
+
"10x_chromium": AssaySpec(
|
| 811 |
+
name="10x_chromium",
|
| 812 |
+
category=AssayCategory.SEQUENCING,
|
| 813 |
+
modalities=["scRNA-seq", "scATAC-seq", "CITE-seq", "scMultiome"],
|
| 814 |
+
description="10x Genomics Chromium droplet-based single-cell partitioning",
|
| 815 |
+
typical_duration_days=2.0,
|
| 816 |
+
typical_cost_usd=5000.0,
|
| 817 |
+
requires_live_cells=True,
|
| 818 |
+
throughput="high (500-20k cells)",
|
| 819 |
+
outputs=["fastq", "count_matrix"],
|
| 820 |
+
),
|
| 821 |
+
"smart-seq2": AssaySpec(
|
| 822 |
+
name="smart-seq2",
|
| 823 |
+
category=AssayCategory.SEQUENCING,
|
| 824 |
+
modalities=["Smart-seq2"],
|
| 825 |
+
description="Plate-based full-length cDNA scRNA-seq",
|
| 826 |
+
typical_duration_days=3.0,
|
| 827 |
+
typical_cost_usd=10000.0,
|
| 828 |
+
requires_live_cells=True,
|
| 829 |
+
throughput="low (96-384 cells)",
|
| 830 |
+
outputs=["fastq", "count_matrix"],
|
| 831 |
+
),
|
| 832 |
+
"smart-seq3": AssaySpec(
|
| 833 |
+
name="smart-seq3",
|
| 834 |
+
category=AssayCategory.SEQUENCING,
|
| 835 |
+
modalities=["Smart-seq2"],
|
| 836 |
+
description="Improved full-length scRNA-seq with UMIs",
|
| 837 |
+
typical_duration_days=3.0,
|
| 838 |
+
typical_cost_usd=10000.0,
|
| 839 |
+
requires_live_cells=True,
|
| 840 |
+
throughput="low (96-384 cells)",
|
| 841 |
+
outputs=["fastq", "count_matrix"],
|
| 842 |
+
),
|
| 843 |
+
"bulk_rna_seq": AssaySpec(
|
| 844 |
+
name="bulk_rna_seq",
|
| 845 |
+
category=AssayCategory.SEQUENCING,
|
| 846 |
+
modalities=["bulk_rna_seq"],
|
| 847 |
+
description="Standard bulk RNA sequencing with poly-A or ribo-depletion",
|
| 848 |
+
typical_duration_days=3.0,
|
| 849 |
+
typical_cost_usd=500.0,
|
| 850 |
+
throughput="high",
|
| 851 |
+
outputs=["fastq", "count_matrix"],
|
| 852 |
+
),
|
| 853 |
+
"atac-seq": AssaySpec(
|
| 854 |
+
name="atac-seq",
|
| 855 |
+
category=AssayCategory.EPIGENOMICS,
|
| 856 |
+
modalities=["scATAC-seq"],
|
| 857 |
+
description="Assay for Transposase-Accessible Chromatin using sequencing",
|
| 858 |
+
typical_duration_days=2.0,
|
| 859 |
+
typical_cost_usd=6000.0,
|
| 860 |
+
requires_live_cells=True,
|
| 861 |
+
outputs=["fastq", "fragments", "peak_matrix"],
|
| 862 |
+
),
|
| 863 |
+
"cite-seq": AssaySpec(
|
| 864 |
+
name="cite-seq",
|
| 865 |
+
category=AssayCategory.PROTEOMICS,
|
| 866 |
+
modalities=["CITE-seq"],
|
| 867 |
+
description="Simultaneous RNA + surface protein via DNA-barcoded antibodies",
|
| 868 |
+
typical_duration_days=2.0,
|
| 869 |
+
typical_cost_usd=8000.0,
|
| 870 |
+
requires_live_cells=True,
|
| 871 |
+
throughput="high (5k-20k cells)",
|
| 872 |
+
outputs=["fastq", "count_matrix", "adt_matrix"],
|
| 873 |
+
),
|
| 874 |
+
"10x_multiome": AssaySpec(
|
| 875 |
+
name="10x_multiome",
|
| 876 |
+
category=AssayCategory.SEQUENCING,
|
| 877 |
+
modalities=["scMultiome"],
|
| 878 |
+
description="Joint scRNA-seq + scATAC-seq from the same cell",
|
| 879 |
+
typical_duration_days=2.0,
|
| 880 |
+
typical_cost_usd=10000.0,
|
| 881 |
+
requires_live_cells=True,
|
| 882 |
+
throughput="high (5k-15k cells)",
|
| 883 |
+
outputs=["fastq", "count_matrix", "fragments"],
|
| 884 |
+
),
|
| 885 |
+
"visium": AssaySpec(
|
| 886 |
+
name="visium",
|
| 887 |
+
category=AssayCategory.SEQUENCING,
|
| 888 |
+
modalities=["spatial_transcriptomics"],
|
| 889 |
+
description="10x Visium spatially barcoded capture on tissue sections",
|
| 890 |
+
typical_duration_days=3.0,
|
| 891 |
+
typical_cost_usd=7000.0,
|
| 892 |
+
requires_fresh_tissue=True,
|
| 893 |
+
throughput="medium (1k-5k spots)",
|
| 894 |
+
outputs=["fastq", "count_matrix", "spatial_coords", "image"],
|
| 895 |
+
),
|
| 896 |
+
"visium_hd": AssaySpec(
|
| 897 |
+
name="visium_hd",
|
| 898 |
+
category=AssayCategory.SEQUENCING,
|
| 899 |
+
modalities=["spatial_transcriptomics"],
|
| 900 |
+
description="High-definition Visium with 2µm bin resolution",
|
| 901 |
+
typical_duration_days=3.0,
|
| 902 |
+
typical_cost_usd=10000.0,
|
| 903 |
+
requires_fresh_tissue=True,
|
| 904 |
+
throughput="high",
|
| 905 |
+
outputs=["fastq", "count_matrix", "spatial_coords", "image"],
|
| 906 |
+
),
|
| 907 |
+
"merfish": AssaySpec(
|
| 908 |
+
name="merfish",
|
| 909 |
+
category=AssayCategory.IMAGING,
|
| 910 |
+
modalities=["MERFISH"],
|
| 911 |
+
description="Multiplexed Error-Robust FISH imaging-based spatial",
|
| 912 |
+
typical_duration_days=5.0,
|
| 913 |
+
typical_cost_usd=20000.0,
|
| 914 |
+
requires_fresh_tissue=True,
|
| 915 |
+
throughput="high (100-1000 genes, millions of transcripts)",
|
| 916 |
+
outputs=["transcript_coords", "cell_segmentation"],
|
| 917 |
+
),
|
| 918 |
+
"seqfish_plus": AssaySpec(
|
| 919 |
+
name="seqfish_plus",
|
| 920 |
+
category=AssayCategory.IMAGING,
|
| 921 |
+
modalities=["seqFISH"],
|
| 922 |
+
description="Sequential FISH for imaging-based spatial transcriptomics",
|
| 923 |
+
typical_duration_days=5.0,
|
| 924 |
+
typical_cost_usd=15000.0,
|
| 925 |
+
requires_fresh_tissue=True,
|
| 926 |
+
outputs=["transcript_coords"],
|
| 927 |
+
),
|
| 928 |
+
"slide-seq": AssaySpec(
|
| 929 |
+
name="slide-seq",
|
| 930 |
+
category=AssayCategory.SEQUENCING,
|
| 931 |
+
modalities=["Slide-seq"],
|
| 932 |
+
description="Near-cellular spatial transcriptomics on bead arrays",
|
| 933 |
+
typical_duration_days=3.0,
|
| 934 |
+
typical_cost_usd=8000.0,
|
| 935 |
+
requires_fresh_tissue=True,
|
| 936 |
+
outputs=["count_matrix", "spatial_coords"],
|
| 937 |
+
),
|
| 938 |
+
"perturb-seq": AssaySpec(
|
| 939 |
+
name="perturb-seq",
|
| 940 |
+
category=AssayCategory.PERTURBATION,
|
| 941 |
+
modalities=["Perturb-seq"],
|
| 942 |
+
description="Pooled CRISPR screen + scRNA-seq readout",
|
| 943 |
+
typical_duration_days=14.0,
|
| 944 |
+
typical_cost_usd=15000.0,
|
| 945 |
+
requires_live_cells=True,
|
| 946 |
+
throughput="high (10k-100k cells)",
|
| 947 |
+
outputs=["fastq", "count_matrix", "guide_assignments"],
|
| 948 |
+
),
|
| 949 |
+
"crop-seq": AssaySpec(
|
| 950 |
+
name="crop-seq",
|
| 951 |
+
category=AssayCategory.PERTURBATION,
|
| 952 |
+
modalities=["CROP-seq"],
|
| 953 |
+
description="CRISPR dropout screening with scRNA-seq readout",
|
| 954 |
+
typical_duration_days=14.0,
|
| 955 |
+
typical_cost_usd=12000.0,
|
| 956 |
+
requires_live_cells=True,
|
| 957 |
+
throughput="high (10k-50k cells)",
|
| 958 |
+
outputs=["fastq", "count_matrix", "guide_assignments"],
|
| 959 |
+
),
|
| 960 |
+
"patch-seq": AssaySpec(
|
| 961 |
+
name="patch-seq",
|
| 962 |
+
category=AssayCategory.FUNCTIONAL,
|
| 963 |
+
modalities=["Patch-seq"],
|
| 964 |
+
description="Patch-clamp electrophysiology + scRNA-seq on same neuron",
|
| 965 |
+
typical_duration_days=7.0,
|
| 966 |
+
typical_cost_usd=50000.0,
|
| 967 |
+
requires_live_cells=True,
|
| 968 |
+
throughput="very low (10-100 cells)",
|
| 969 |
+
outputs=["fastq", "count_matrix", "ephys_trace", "morphology"],
|
| 970 |
+
),
|
| 971 |
+
"sc_hi_c": AssaySpec(
|
| 972 |
+
name="sc_hi_c",
|
| 973 |
+
category=AssayCategory.EPIGENOMICS,
|
| 974 |
+
modalities=["scHi-C"],
|
| 975 |
+
description="Single-cell chromosome conformation capture",
|
| 976 |
+
typical_duration_days=5.0,
|
| 977 |
+
typical_cost_usd=15000.0,
|
| 978 |
+
outputs=["contact_matrix"],
|
| 979 |
+
),
|
| 980 |
+
"sc_bisulfite": AssaySpec(
|
| 981 |
+
name="sc_bisulfite",
|
| 982 |
+
category=AssayCategory.EPIGENOMICS,
|
| 983 |
+
modalities=["scBS-seq"],
|
| 984 |
+
description="Single-cell bisulfite sequencing for DNA methylation profiling",
|
| 985 |
+
typical_duration_days=5.0,
|
| 986 |
+
typical_cost_usd=12000.0,
|
| 987 |
+
outputs=["methylation_matrix"],
|
| 988 |
+
),
|
| 989 |
+
"sc_nmt_seq": AssaySpec(
|
| 990 |
+
name="sc_nmt_seq",
|
| 991 |
+
category=AssayCategory.EPIGENOMICS,
|
| 992 |
+
modalities=["scNMT-seq"],
|
| 993 |
+
description="Joint nucleosome occupancy, methylation, and transcription",
|
| 994 |
+
typical_duration_days=7.0,
|
| 995 |
+
typical_cost_usd=25000.0,
|
| 996 |
+
requires_live_cells=True,
|
| 997 |
+
throughput="low (100-1k cells)",
|
| 998 |
+
outputs=["count_matrix", "methylation_matrix", "accessibility_matrix"],
|
| 999 |
+
),
|
| 1000 |
+
"flow_cytometry": AssaySpec(
|
| 1001 |
+
name="flow_cytometry",
|
| 1002 |
+
category=AssayCategory.FUNCTIONAL,
|
| 1003 |
+
modalities=[],
|
| 1004 |
+
description="Fluorescence-based cell sorting and phenotyping",
|
| 1005 |
+
typical_duration_days=1.0,
|
| 1006 |
+
typical_cost_usd=500.0,
|
| 1007 |
+
requires_live_cells=True,
|
| 1008 |
+
throughput="very high (millions of cells)",
|
| 1009 |
+
outputs=["cell_counts", "sorted_cells"],
|
| 1010 |
+
),
|
| 1011 |
+
"mass_cytometry_CyTOF": AssaySpec(
|
| 1012 |
+
name="mass_cytometry_CyTOF",
|
| 1013 |
+
category=AssayCategory.PROTEOMICS,
|
| 1014 |
+
modalities=[],
|
| 1015 |
+
description="Mass-tag cytometry for 40+ protein markers per cell",
|
| 1016 |
+
typical_duration_days=2.0,
|
| 1017 |
+
typical_cost_usd=3000.0,
|
| 1018 |
+
requires_live_cells=True,
|
| 1019 |
+
throughput="high (100k-1M cells)",
|
| 1020 |
+
outputs=["protein_expression_matrix"],
|
| 1021 |
+
),
|
| 1022 |
+
"western_blot": AssaySpec(
|
| 1023 |
+
name="western_blot",
|
| 1024 |
+
category=AssayCategory.PROTEOMICS,
|
| 1025 |
+
modalities=[],
|
| 1026 |
+
description="Protein detection and semi-quantification by size separation",
|
| 1027 |
+
typical_duration_days=2.0,
|
| 1028 |
+
typical_cost_usd=200.0,
|
| 1029 |
+
outputs=["band_image", "relative_quantification"],
|
| 1030 |
+
),
|
| 1031 |
+
"qPCR": AssaySpec(
|
| 1032 |
+
name="qPCR",
|
| 1033 |
+
category=AssayCategory.FUNCTIONAL,
|
| 1034 |
+
modalities=[],
|
| 1035 |
+
description="Quantitative PCR for targeted gene expression validation",
|
| 1036 |
+
typical_duration_days=1.0,
|
| 1037 |
+
typical_cost_usd=100.0,
|
| 1038 |
+
throughput="low (target genes)",
|
| 1039 |
+
outputs=["ct_values", "fold_change"],
|
| 1040 |
+
),
|
| 1041 |
+
"immunofluorescence": AssaySpec(
|
| 1042 |
+
name="immunofluorescence",
|
| 1043 |
+
category=AssayCategory.IMAGING,
|
| 1044 |
+
modalities=[],
|
| 1045 |
+
description="Antibody-based fluorescence imaging of proteins in situ",
|
| 1046 |
+
typical_duration_days=2.0,
|
| 1047 |
+
typical_cost_usd=500.0,
|
| 1048 |
+
outputs=["fluorescence_image"],
|
| 1049 |
+
),
|
| 1050 |
+
"elisa": AssaySpec(
|
| 1051 |
+
name="elisa",
|
| 1052 |
+
category=AssayCategory.PROTEOMICS,
|
| 1053 |
+
modalities=[],
|
| 1054 |
+
description="Enzyme-linked immunosorbent assay for secreted protein quantification",
|
| 1055 |
+
typical_duration_days=1.0,
|
| 1056 |
+
typical_cost_usd=300.0,
|
| 1057 |
+
throughput="medium (96-384 well)",
|
| 1058 |
+
outputs=["protein_concentration"],
|
| 1059 |
+
),
|
| 1060 |
+
"cell_viability_assay": AssaySpec(
|
| 1061 |
+
name="cell_viability_assay",
|
| 1062 |
+
category=AssayCategory.FUNCTIONAL,
|
| 1063 |
+
modalities=[],
|
| 1064 |
+
description="MTT/CellTiter-Glo viability and proliferation measurement",
|
| 1065 |
+
typical_duration_days=1.0,
|
| 1066 |
+
typical_cost_usd=200.0,
|
| 1067 |
+
requires_live_cells=True,
|
| 1068 |
+
throughput="high (96-384 well)",
|
| 1069 |
+
outputs=["viability_scores"],
|
| 1070 |
+
),
|
| 1071 |
+
}
|
| 1072 |
+
|
| 1073 |
+
|
| 1074 |
+
# ── Registry helper functions ──────────────────────────────────────────────
|
| 1075 |
+
|
| 1076 |
+
|
| 1077 |
+
def tools_for_modality(modality: str) -> List[ToolSpec]:
|
| 1078 |
+
"""Return all registered tools compatible with a given modality."""
|
| 1079 |
+
return [t for t in TOOL_REGISTRY.values() if modality in t.modalities]
|
| 1080 |
+
|
| 1081 |
+
|
| 1082 |
+
def assays_for_modality(modality: str) -> List[AssaySpec]:
|
| 1083 |
+
"""Return all registered assays that produce a given modality."""
|
| 1084 |
+
return [a for a in ASSAY_REGISTRY.values() if modality in a.modalities]
|
| 1085 |
+
|
| 1086 |
+
|
| 1087 |
+
def tools_by_category(category: ToolCategory) -> List[ToolSpec]:
|
| 1088 |
+
"""Return all registered tools in a given category."""
|
| 1089 |
+
return [t for t in TOOL_REGISTRY.values() if t.category == category]
|
| 1090 |
+
|
| 1091 |
+
|
| 1092 |
+
# ── Sub-agents ───────────────────────────────────────────────────���─────────
|
| 1093 |
+
|
| 1094 |
+
|
| 1095 |
class SubagentType(str, Enum):
|
| 1096 |
WET_LAB_PLANNER = "wet_lab_planner"
|
| 1097 |
COMPUTATIONAL_ANALYST = "computational_analyst"
|
|
|
|
| 1115 |
"""
|
| 1116 |
|
| 1117 |
action_type: ActionType = Field(
|
| 1118 |
+
...,
|
| 1119 |
+
description=(
|
| 1120 |
+
"Discrete simulator step type. The environment enforces scientific "
|
| 1121 |
+
"prerequisites between steps, so actions should follow a valid "
|
| 1122 |
+
"pipeline order."
|
| 1123 |
+
),
|
| 1124 |
)
|
| 1125 |
input_targets: List[str] = Field(
|
| 1126 |
default_factory=list,
|
| 1127 |
+
description=(
|
| 1128 |
+
"Optional references to prior samples, outputs, or artifacts that "
|
| 1129 |
+
"this step consumes."
|
| 1130 |
+
),
|
| 1131 |
)
|
| 1132 |
method: Optional[str] = Field(
|
| 1133 |
+
None,
|
| 1134 |
+
description=(
|
| 1135 |
+
"Optional named tool or protocol (for example 'Seurat' or "
|
| 1136 |
+
"'CellRanger'). Prefer methods compatible with the current "
|
| 1137 |
+
"modality and available tool list because tool choice can change "
|
| 1138 |
+
"runtime, cost, and scientific fit."
|
| 1139 |
+
),
|
| 1140 |
)
|
| 1141 |
parameters: Dict[str, Any] = Field(
|
| 1142 |
+
default_factory=dict,
|
| 1143 |
+
description=(
|
| 1144 |
+
"Action-specific settings such as comparison labels, perturbation "
|
| 1145 |
+
"targets, or analysis options. Use only parameters that materially "
|
| 1146 |
+
"change the scientific step."
|
| 1147 |
+
),
|
| 1148 |
)
|
| 1149 |
expected_output_type: Optional[str] = Field(
|
| 1150 |
+
None,
|
| 1151 |
+
description=(
|
| 1152 |
+
"Optional expected artifact or summary that should result from the "
|
| 1153 |
+
"step, such as a count matrix, QC report, DE table, or validation "
|
| 1154 |
+
"result."
|
| 1155 |
+
),
|
| 1156 |
)
|
| 1157 |
justification: Optional[str] = Field(
|
| 1158 |
+
None,
|
| 1159 |
+
description=(
|
| 1160 |
+
"Short scientific rationale explaining why this is the right next "
|
| 1161 |
+
"step in the current environment state."
|
| 1162 |
+
),
|
| 1163 |
)
|
| 1164 |
invoked_subagent: Optional[SubagentType] = Field(
|
| 1165 |
None, description="Sub-agent to delegate to, if any"
|
| 1166 |
)
|
| 1167 |
tool_call_spec: Optional[Dict[str, Any]] = Field(
|
| 1168 |
+
None,
|
| 1169 |
+
description=(
|
| 1170 |
+
"Optional structured tool invocation payload when the action needs "
|
| 1171 |
+
"a more explicit tool execution plan."
|
| 1172 |
+
),
|
| 1173 |
)
|
| 1174 |
confidence: float = Field(
|
| 1175 |
0.5, ge=0.0, le=1.0, description="Agent confidence in this step"
|
|
|
|
| 1267 |
organism: str = "human"
|
| 1268 |
tissue: str = "blood"
|
| 1269 |
conditions: List[str] = Field(default_factory=list)
|
| 1270 |
+
available_assays: List[str] = Field(
|
| 1271 |
+
default_factory=lambda: list(ASSAY_REGISTRY.keys()),
|
| 1272 |
+
description=(
|
| 1273 |
+
"Assays that are scientifically compatible with this task's "
|
| 1274 |
+
"modality. These are the relevant assay choices for the episode, "
|
| 1275 |
+
"not an unrestricted catalog."
|
| 1276 |
+
),
|
| 1277 |
+
)
|
| 1278 |
+
available_tools: List[str] = Field(
|
| 1279 |
+
default_factory=lambda: list(TOOL_REGISTRY.keys()),
|
| 1280 |
+
description=(
|
| 1281 |
+
"Tools filtered to those compatible with the current task "
|
| 1282 |
+
"modality. The agent should treat this list as the preferred tool "
|
| 1283 |
+
"set for the episode."
|
| 1284 |
+
),
|
| 1285 |
+
)
|
| 1286 |
budget_limit: float = 100_000.0
|
| 1287 |
time_limit_days: float = 180.0
|
| 1288 |
prior_observations: List[str] = Field(default_factory=list)
|
|
|
|
| 1293 |
|
| 1294 |
|
| 1295 |
class ConclusionClaim(BaseModel):
|
| 1296 |
+
claim: str = ""
|
| 1297 |
+
top_markers: List[str] = Field(default_factory=list)
|
| 1298 |
+
causal_mechanisms: List[str] = Field(default_factory=list)
|
| 1299 |
+
predicted_pathways: Dict[str, float] = Field(default_factory=dict)
|
| 1300 |
evidence_steps: List[int] = Field(default_factory=list)
|
| 1301 |
confidence: float = Field(0.5, ge=0.0, le=1.0)
|
| 1302 |
claim_type: str = "correlational"
|
|
|
|
| 1316 |
task: TaskSpec = Field(default_factory=TaskSpec)
|
| 1317 |
step_index: int = 0
|
| 1318 |
pipeline_history: List[PipelineStepRecord] = Field(default_factory=list)
|
| 1319 |
+
available_assays: List[str] = Field(
|
| 1320 |
+
default_factory=list,
|
| 1321 |
+
description=(
|
| 1322 |
+
"Episode-specific assay choices already filtered to the current "
|
| 1323 |
+
"modality and task context."
|
| 1324 |
+
),
|
| 1325 |
+
)
|
| 1326 |
+
available_tools: List[str] = Field(
|
| 1327 |
+
default_factory=list,
|
| 1328 |
+
description=(
|
| 1329 |
+
"Episode-specific compatible tools. These are the methods the "
|
| 1330 |
+
"agent should prefer instead of inventing incompatible tools."
|
| 1331 |
+
),
|
| 1332 |
+
)
|
| 1333 |
+
resource_usage: ResourceUsage = Field(
|
| 1334 |
+
default_factory=ResourceUsage,
|
| 1335 |
+
description=(
|
| 1336 |
+
"Running budget, time, and compute usage after previous actions."
|
| 1337 |
+
),
|
| 1338 |
+
)
|
| 1339 |
latest_output: Optional[IntermediateOutput] = None
|
| 1340 |
all_outputs: List[IntermediateOutput] = Field(default_factory=list)
|
| 1341 |
discovered_markers: List[str] = Field(default_factory=list)
|
|
|
|
| 1345 |
conclusions: List[ConclusionClaim] = Field(default_factory=list)
|
| 1346 |
rule_violations: List[str] = Field(default_factory=list)
|
| 1347 |
step_reward_breakdown: Dict[str, float] = Field(default_factory=dict)
|
| 1348 |
+
|
| 1349 |
+
|
| 1350 |
+
AGENT_ACTION_GUIDANCE: Dict[ActionType, str] = {
|
| 1351 |
+
ActionType.COLLECT_SAMPLE: (
|
| 1352 |
+
"Wet-lab entry point. One successful collection usually provides enough "
|
| 1353 |
+
"material to continue unless the output shows poor yield or quality."
|
| 1354 |
+
),
|
| 1355 |
+
ActionType.SELECT_COHORT: (
|
| 1356 |
+
"Use when subject stratification is part of the scientific question "
|
| 1357 |
+
"before downstream experimental work."
|
| 1358 |
+
),
|
| 1359 |
+
ActionType.PREPARE_LIBRARY: (
|
| 1360 |
+
"Requires collected samples and converts biological material into "
|
| 1361 |
+
"sequence-ready libraries."
|
| 1362 |
+
),
|
| 1363 |
+
ActionType.CULTURE_CELLS: (
|
| 1364 |
+
"Requires collected samples and adds substantial time; use only when "
|
| 1365 |
+
"live-cell expansion or later perturbation is needed."
|
| 1366 |
+
),
|
| 1367 |
+
ActionType.PERTURB_GENE: (
|
| 1368 |
+
"Requires samples. Use for causal tests, not as a default discovery "
|
| 1369 |
+
"step."
|
| 1370 |
+
),
|
| 1371 |
+
ActionType.PERTURB_COMPOUND: (
|
| 1372 |
+
"Requires samples. Best for mechanistic follow-up or treatment "
|
| 1373 |
+
"response questions."
|
| 1374 |
+
),
|
| 1375 |
+
ActionType.SEQUENCE_CELLS: (
|
| 1376 |
+
"Requires prepared libraries and produces the raw sequencing-derived "
|
| 1377 |
+
"artifacts used by downstream QC and analysis."
|
| 1378 |
+
),
|
| 1379 |
+
ActionType.RUN_QC: (
|
| 1380 |
+
"Requires sequencing and returns summarized quality metrics such as "
|
| 1381 |
+
"doublets, mitochondrial fraction, and ambient RNA."
|
| 1382 |
+
),
|
| 1383 |
+
ActionType.FILTER_DATA: (
|
| 1384 |
+
"Requires QC and removes poor-quality cells, changing downstream cell "
|
| 1385 |
+
"counts and data retention."
|
| 1386 |
+
),
|
| 1387 |
+
ActionType.NORMALIZE_DATA: (
|
| 1388 |
+
"Requires filtered data and unlocks clustering, differential "
|
| 1389 |
+
"expression, trajectory, and network analyses."
|
| 1390 |
+
),
|
| 1391 |
+
ActionType.INTEGRATE_BATCHES: (
|
| 1392 |
+
"Requires normalized data. Use when batch effects are likely to "
|
| 1393 |
+
"confound interpretation; it is not always necessary."
|
| 1394 |
+
),
|
| 1395 |
+
ActionType.CLUSTER_CELLS: (
|
| 1396 |
+
"Requires normalized data and identifies cell populations or states "
|
| 1397 |
+
"for downstream interpretation."
|
| 1398 |
+
),
|
| 1399 |
+
ActionType.DIFFERENTIAL_EXPRESSION: (
|
| 1400 |
+
"Requires normalized data and is the main route to candidate genes "
|
| 1401 |
+
"for pathway analysis and marker selection."
|
| 1402 |
+
),
|
| 1403 |
+
ActionType.TRAJECTORY_ANALYSIS: (
|
| 1404 |
+
"Requires normalized data and is most useful when lineage progression "
|
| 1405 |
+
"or pseudotime is central to the task."
|
| 1406 |
+
),
|
| 1407 |
+
ActionType.PATHWAY_ENRICHMENT: (
|
| 1408 |
+
"Requires differential expression. Results are less reliable without a "
|
| 1409 |
+
"strong DE gene list."
|
| 1410 |
+
),
|
| 1411 |
+
ActionType.REGULATORY_NETWORK_INFERENCE: (
|
| 1412 |
+
"Requires normalized data and is most helpful once cell states or "
|
| 1413 |
+
"trajectories are already characterized."
|
| 1414 |
+
),
|
| 1415 |
+
ActionType.MARKER_SELECTION: (
|
| 1416 |
+
"Requires differential expression and turns candidate genes into a "
|
| 1417 |
+
"short list for validation."
|
| 1418 |
+
),
|
| 1419 |
+
ActionType.VALIDATE_MARKER: (
|
| 1420 |
+
"Requires discovered markers and is an expensive wet-lab confirmation "
|
| 1421 |
+
"step that should follow strong computational evidence."
|
| 1422 |
+
),
|
| 1423 |
+
ActionType.DESIGN_FOLLOWUP: (
|
| 1424 |
+
"Use to propose targeted next experiments once remaining uncertainty "
|
| 1425 |
+
"is clear."
|
| 1426 |
+
),
|
| 1427 |
+
ActionType.REQUEST_SUBAGENT_REVIEW: (
|
| 1428 |
+
"Use for critique or planning support, not as a substitute for "
|
| 1429 |
+
"missing experimental evidence."
|
| 1430 |
+
),
|
| 1431 |
+
ActionType.SYNTHESIZE_CONCLUSION: (
|
| 1432 |
+
"Use once the evidence is sufficient. Do not spend budget on redundant "
|
| 1433 |
+
"steps just because more actions are possible."
|
| 1434 |
+
),
|
| 1435 |
+
}
|
| 1436 |
+
|
| 1437 |
+
AGENT_ENVIRONMENT_RULES: List[str] = [
|
| 1438 |
+
(
|
| 1439 |
+
"Each successful action already returns summarized scientific evidence, "
|
| 1440 |
+
"so repeated sampling or repeated analysis is not the default."
|
| 1441 |
+
),
|
| 1442 |
+
(
|
| 1443 |
+
"Repeat a step only when the task demands it or when prior outputs show "
|
| 1444 |
+
"poor quality, insufficient yield, unresolved batch effects, or another "
|
| 1445 |
+
"clear failure mode."
|
| 1446 |
+
),
|
| 1447 |
+
(
|
| 1448 |
+
"The available tool and assay lists are already filtered to the current "
|
| 1449 |
+
"task modality, so prefer them over inventing incompatible methods."
|
| 1450 |
+
),
|
| 1451 |
+
(
|
| 1452 |
+
"Hard scientific prerequisites are enforced by the environment, so "
|
| 1453 |
+
"invalid pipeline orderings will be blocked."
|
| 1454 |
+
),
|
| 1455 |
+
]
|
| 1456 |
+
|
| 1457 |
+
_TOOL_CATEGORY_AGENT_NOTES: Dict[ToolCategory, str] = {
|
| 1458 |
+
ToolCategory.ALIGNMENT: (
|
| 1459 |
+
"Best immediately after sequencing to turn FASTQ-like inputs into "
|
| 1460 |
+
"count-style matrices for downstream analysis."
|
| 1461 |
+
),
|
| 1462 |
+
ToolCategory.PREPROCESSING: (
|
| 1463 |
+
"Useful for general single-cell data handling before specialized "
|
| 1464 |
+
"downstream analyses."
|
| 1465 |
+
),
|
| 1466 |
+
ToolCategory.NORMALIZATION: (
|
| 1467 |
+
"Applies after filtering to produce normalized matrices for downstream "
|
| 1468 |
+
"modeling."
|
| 1469 |
+
),
|
| 1470 |
+
ToolCategory.DIMENSIONALITY_REDUCTION: (
|
| 1471 |
+
"Builds latent embeddings that support clustering or trajectory work."
|
| 1472 |
+
),
|
| 1473 |
+
ToolCategory.CLUSTERING: (
|
| 1474 |
+
"Best once data are normalized and the goal is to resolve cell states "
|
| 1475 |
+
"or populations."
|
| 1476 |
+
),
|
| 1477 |
+
ToolCategory.DIFFERENTIAL_EXPRESSION: (
|
| 1478 |
+
"Tests contrasts and produces ranked genes for biological "
|
| 1479 |
+
"interpretation."
|
| 1480 |
+
),
|
| 1481 |
+
ToolCategory.TRAJECTORY: (
|
| 1482 |
+
"Useful when the task asks about developmental progression, state "
|
| 1483 |
+
"transitions, or pseudotime."
|
| 1484 |
+
),
|
| 1485 |
+
ToolCategory.GENE_REGULATORY_NETWORK: (
|
| 1486 |
+
"Most useful after normalized data and some cell-state structure are "
|
| 1487 |
+
"already established."
|
| 1488 |
+
),
|
| 1489 |
+
ToolCategory.GENE_SET_ANALYSIS: (
|
| 1490 |
+
"Best after differential expression to interpret gene lists at the "
|
| 1491 |
+
"pathway level."
|
| 1492 |
+
),
|
| 1493 |
+
ToolCategory.BATCH_CORRECTION: (
|
| 1494 |
+
"Use when batch effects would confound interpretation; unnecessary use "
|
| 1495 |
+
"adds extra steps."
|
| 1496 |
+
),
|
| 1497 |
+
ToolCategory.MULTIMODAL_INTEGRATION: (
|
| 1498 |
+
"Useful only when combining modalities or batches is part of the "
|
| 1499 |
+
"scientific question."
|
| 1500 |
+
),
|
| 1501 |
+
ToolCategory.QUALITY_CONTROL: (
|
| 1502 |
+
"Helps identify low-quality cells or technical artifacts before "
|
| 1503 |
+
"filtering."
|
| 1504 |
+
),
|
| 1505 |
+
ToolCategory.CELL_TYPE_ANNOTATION: (
|
| 1506 |
+
"Best after clustering when assigning biological identities to groups."
|
| 1507 |
+
),
|
| 1508 |
+
ToolCategory.PERTURBATION_ANALYSIS: (
|
| 1509 |
+
"Use when perturbations were actually applied and the goal is to model "
|
| 1510 |
+
"their transcriptional effects."
|
| 1511 |
+
),
|
| 1512 |
+
ToolCategory.SPATIAL: (
|
| 1513 |
+
"Only useful when the modality includes spatial coordinates or tissue "
|
| 1514 |
+
"context."
|
| 1515 |
+
),
|
| 1516 |
+
}
|
| 1517 |
+
|
| 1518 |
+
|
| 1519 |
+
def _format_currency(value: float) -> str:
|
| 1520 |
+
return f"${value:,.0f}"
|
| 1521 |
+
|
| 1522 |
+
|
| 1523 |
+
def _format_runtime_hours(hours: float) -> str:
|
| 1524 |
+
if hours < 1.0:
|
| 1525 |
+
return f"{int(round(hours * 60))}m"
|
| 1526 |
+
if float(hours).is_integer():
|
| 1527 |
+
return f"{int(hours)}h"
|
| 1528 |
+
return f"{hours:.1f}h"
|
| 1529 |
+
|
| 1530 |
+
|
| 1531 |
+
def describe_tool_for_agent(tool_name: str) -> str:
|
| 1532 |
+
"""Return a compact environment-aware tool description for prompts."""
|
| 1533 |
+
tool = TOOL_REGISTRY.get(tool_name)
|
| 1534 |
+
if tool is None:
|
| 1535 |
+
return tool_name
|
| 1536 |
+
|
| 1537 |
+
parts = [f"{tool.name}: {tool.description}."]
|
| 1538 |
+
if tool.input_types or tool.output_types:
|
| 1539 |
+
inputs = ", ".join(tool.input_types) or "upstream artifacts"
|
| 1540 |
+
outputs = ", ".join(tool.output_types) or "analysis artifacts"
|
| 1541 |
+
parts.append(f"Consumes {inputs}; yields {outputs}.")
|
| 1542 |
+
|
| 1543 |
+
category_note = _TOOL_CATEGORY_AGENT_NOTES.get(tool.category)
|
| 1544 |
+
if category_note:
|
| 1545 |
+
parts.append(category_note)
|
| 1546 |
+
|
| 1547 |
+
resource_bits: List[str] = []
|
| 1548 |
+
if tool.typical_cost_usd > 0:
|
| 1549 |
+
resource_bits.append(_format_currency(tool.typical_cost_usd))
|
| 1550 |
+
if tool.typical_runtime_hours > 0:
|
| 1551 |
+
resource_bits.append(_format_runtime_hours(tool.typical_runtime_hours))
|
| 1552 |
+
if tool.requires_gpu:
|
| 1553 |
+
resource_bits.append("GPU")
|
| 1554 |
+
if resource_bits:
|
| 1555 |
+
parts.append(f"Typical resources: {', '.join(resource_bits)}.")
|
| 1556 |
+
|
| 1557 |
+
return " ".join(parts)
|
| 1558 |
+
|
| 1559 |
+
|
| 1560 |
+
def describe_assay_for_agent(assay_name: str) -> str:
|
| 1561 |
+
"""Return a compact environment-aware assay description for prompts."""
|
| 1562 |
+
assay = ASSAY_REGISTRY.get(assay_name)
|
| 1563 |
+
if assay is None:
|
| 1564 |
+
return assay_name
|
| 1565 |
+
|
| 1566 |
+
parts = [f"{assay.name}: {assay.description}."]
|
| 1567 |
+
if assay.outputs:
|
| 1568 |
+
parts.append(f"Produces {', '.join(assay.outputs)}.")
|
| 1569 |
+
|
| 1570 |
+
requirements: List[str] = []
|
| 1571 |
+
if assay.requires_live_cells:
|
| 1572 |
+
requirements.append("live cells")
|
| 1573 |
+
if assay.requires_fresh_tissue:
|
| 1574 |
+
requirements.append("fresh tissue")
|
| 1575 |
+
if requirements:
|
| 1576 |
+
parts.append(f"Requires {' and '.join(requirements)}.")
|
| 1577 |
+
|
| 1578 |
+
parts.append(
|
| 1579 |
+
"Typical resources: "
|
| 1580 |
+
f"{_format_currency(assay.typical_cost_usd)}, "
|
| 1581 |
+
f"{assay.typical_duration_days:.1f}d."
|
| 1582 |
+
)
|
| 1583 |
+
return " ".join(parts)
|
| 1584 |
+
|
| 1585 |
+
|
| 1586 |
+
def build_agent_system_prompt() -> str:
|
| 1587 |
+
"""Build the shared agent system prompt for training and inference."""
|
| 1588 |
+
lines = [
|
| 1589 |
+
"You are an expert biologist planning a single-cell experiment pipeline.",
|
| 1590 |
+
"",
|
| 1591 |
+
"At each turn you see the experiment state and must pick the next scientifically justified step.",
|
| 1592 |
+
"",
|
| 1593 |
+
"Environment-specific reasoning rules:",
|
| 1594 |
+
]
|
| 1595 |
+
lines.extend(f" - {rule}" for rule in AGENT_ENVIRONMENT_RULES)
|
| 1596 |
+
lines.append("")
|
| 1597 |
+
lines.append("Action guidance:")
|
| 1598 |
+
lines.extend(
|
| 1599 |
+
f" - {action_type.value}: {AGENT_ACTION_GUIDANCE[action_type]}"
|
| 1600 |
+
for action_type in ActionType
|
| 1601 |
+
)
|
| 1602 |
+
lines.extend([
|
| 1603 |
+
"",
|
| 1604 |
+
"Respond with ONLY valid JSON, nothing else:",
|
| 1605 |
+
'{"action_type": "...", "method": null, "parameters": {}, "justification": "...", "confidence": 0.8}',
|
| 1606 |
+
"",
|
| 1607 |
+
"For synthesize_conclusion, use structured claims:",
|
| 1608 |
+
'{"action_type": "synthesize_conclusion", "parameters": {"claims": [{"top_markers": ["GENE1", "GENE2"], "causal_mechanisms": ["mechanism description"], "predicted_pathways": {"pathway_name": 0.8}, "confidence": 0.8, "claim_type": "causal", "claim": "optional free text"}]}, "justification": "...", "confidence": 0.8}',
|
| 1609 |
+
])
|
| 1610 |
+
return "\n".join(lines)
|
| 1611 |
+
|
| 1612 |
+
|
| 1613 |
+
def build_agent_observation_context(
|
| 1614 |
+
obs: ExperimentObservation,
|
| 1615 |
+
*,
|
| 1616 |
+
max_tools: int = 6,
|
| 1617 |
+
max_assays: int = 3,
|
| 1618 |
+
) -> str:
|
| 1619 |
+
"""Summarize modality-specific tool and assay context for the agent."""
|
| 1620 |
+
sections: List[str] = []
|
| 1621 |
+
|
| 1622 |
+
modality_spec = MODALITY_REGISTRY.get(obs.task.modality)
|
| 1623 |
+
if modality_spec is not None:
|
| 1624 |
+
sections.append(
|
| 1625 |
+
"Modality context: "
|
| 1626 |
+
f"{modality_spec.name} measures {modality_spec.measurement} at "
|
| 1627 |
+
f"{modality_spec.resolution} resolution; typical scale "
|
| 1628 |
+
f"{modality_spec.typical_cells}."
|
| 1629 |
+
)
|
| 1630 |
+
else:
|
| 1631 |
+
sections.append(f"Modality context: {obs.task.modality}.")
|
| 1632 |
+
|
| 1633 |
+
tool_names = list(dict.fromkeys(obs.available_tools or obs.task.available_tools))
|
| 1634 |
+
if tool_names:
|
| 1635 |
+
sections.append("Available tools (already filtered to this modality):")
|
| 1636 |
+
for tool_name in tool_names[:max_tools]:
|
| 1637 |
+
sections.append(f" - {describe_tool_for_agent(tool_name)}")
|
| 1638 |
+
if len(tool_names) > max_tools:
|
| 1639 |
+
remainder = ", ".join(tool_names[max_tools:max_tools + 6])
|
| 1640 |
+
sections.append(
|
| 1641 |
+
" - Additional compatible tools not shown in full: "
|
| 1642 |
+
f"{remainder}"
|
| 1643 |
+
)
|
| 1644 |
+
|
| 1645 |
+
assay_names = list(dict.fromkeys(obs.available_assays or obs.task.available_assays))
|
| 1646 |
+
if assay_names:
|
| 1647 |
+
sections.append("Available assays:")
|
| 1648 |
+
for assay_name in assay_names[:max_assays]:
|
| 1649 |
+
sections.append(f" - {describe_assay_for_agent(assay_name)}")
|
| 1650 |
+
if len(assay_names) > max_assays:
|
| 1651 |
+
remainder = ", ".join(assay_names[max_assays:max_assays + 4])
|
| 1652 |
+
sections.append(
|
| 1653 |
+
" - Additional compatible assays not shown in full: "
|
| 1654 |
+
f"{remainder}"
|
| 1655 |
+
)
|
| 1656 |
+
|
| 1657 |
+
return "\n".join(sections)
|
my_env/README.md
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: My Env Environment Server
|
| 3 |
+
emoji: 🎤
|
| 4 |
+
colorFrom: gray
|
| 5 |
+
colorTo: indigo
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: false
|
| 8 |
+
app_port: 8000
|
| 9 |
+
base_path: /web
|
| 10 |
+
tags:
|
| 11 |
+
- openenv
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# My Env Environment
|
| 15 |
+
|
| 16 |
+
A simple test environment that echoes back messages. Perfect for testing the env APIs as well as demonstrating environment usage patterns.
|
| 17 |
+
|
| 18 |
+
## Quick Start
|
| 19 |
+
|
| 20 |
+
The simplest way to use the My Env environment is through the `MyEnv` class:
|
| 21 |
+
|
| 22 |
+
```python
|
| 23 |
+
from my_env import MyAction, MyEnv
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
# Create environment from Docker image
|
| 27 |
+
my_envenv = MyEnv.from_docker_image("my_env-env:latest")
|
| 28 |
+
|
| 29 |
+
# Reset
|
| 30 |
+
result = my_envenv.reset()
|
| 31 |
+
print(f"Reset: {result.observation.echoed_message}")
|
| 32 |
+
|
| 33 |
+
# Send multiple messages
|
| 34 |
+
messages = ["Hello, World!", "Testing echo", "Final message"]
|
| 35 |
+
|
| 36 |
+
for msg in messages:
|
| 37 |
+
result = my_envenv.step(MyAction(message=msg))
|
| 38 |
+
print(f"Sent: '{msg}'")
|
| 39 |
+
print(f" → Echoed: '{result.observation.echoed_message}'")
|
| 40 |
+
print(f" → Length: {result.observation.message_length}")
|
| 41 |
+
print(f" → Reward: {result.reward}")
|
| 42 |
+
|
| 43 |
+
finally:
|
| 44 |
+
# Always clean up
|
| 45 |
+
my_envenv.close()
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
That's it! The `MyEnv.from_docker_image()` method handles:
|
| 49 |
+
- Starting the Docker container
|
| 50 |
+
- Waiting for the server to be ready
|
| 51 |
+
- Connecting to the environment
|
| 52 |
+
- Container cleanup when you call `close()`
|
| 53 |
+
|
| 54 |
+
## Building the Docker Image
|
| 55 |
+
|
| 56 |
+
Before using the environment, you need to build the Docker image:
|
| 57 |
+
|
| 58 |
+
```bash
|
| 59 |
+
# From project root
|
| 60 |
+
docker build -t my_env-env:latest -f server/Dockerfile .
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
## Deploying to Hugging Face Spaces
|
| 64 |
+
|
| 65 |
+
You can easily deploy your OpenEnv environment to Hugging Face Spaces using the `openenv push` command:
|
| 66 |
+
|
| 67 |
+
```bash
|
| 68 |
+
# From the environment directory (where openenv.yaml is located)
|
| 69 |
+
openenv push
|
| 70 |
+
|
| 71 |
+
# Or specify options
|
| 72 |
+
openenv push --namespace my-org --private
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
The `openenv push` command will:
|
| 76 |
+
1. Validate that the directory is an OpenEnv environment (checks for `openenv.yaml`)
|
| 77 |
+
2. Prepare a custom build for Hugging Face Docker space (enables web interface)
|
| 78 |
+
3. Upload to Hugging Face (ensuring you're logged in)
|
| 79 |
+
|
| 80 |
+
### Prerequisites
|
| 81 |
+
|
| 82 |
+
- Authenticate with Hugging Face: The command will prompt for login if not already authenticated
|
| 83 |
+
|
| 84 |
+
### Options
|
| 85 |
+
|
| 86 |
+
- `--directory`, `-d`: Directory containing the OpenEnv environment (defaults to current directory)
|
| 87 |
+
- `--repo-id`, `-r`: Repository ID in format 'username/repo-name' (defaults to 'username/env-name' from openenv.yaml)
|
| 88 |
+
- `--base-image`, `-b`: Base Docker image to use (overrides Dockerfile FROM)
|
| 89 |
+
- `--private`: Deploy the space as private (default: public)
|
| 90 |
+
|
| 91 |
+
### Examples
|
| 92 |
+
|
| 93 |
+
```bash
|
| 94 |
+
# Push to your personal namespace (defaults to username/env-name from openenv.yaml)
|
| 95 |
+
openenv push
|
| 96 |
+
|
| 97 |
+
# Push to a specific repository
|
| 98 |
+
openenv push --repo-id my-org/my-env
|
| 99 |
+
|
| 100 |
+
# Push with a custom base image
|
| 101 |
+
openenv push --base-image ghcr.io/meta-pytorch/openenv-base:latest
|
| 102 |
+
|
| 103 |
+
# Push as a private space
|
| 104 |
+
openenv push --private
|
| 105 |
+
|
| 106 |
+
# Combine options
|
| 107 |
+
openenv push --repo-id my-org/my-env --base-image custom-base:latest --private
|
| 108 |
+
```
|
| 109 |
+
|
| 110 |
+
After deployment, your space will be available at:
|
| 111 |
+
`https://huggingface.co/spaces/<repo-id>`
|
| 112 |
+
|
| 113 |
+
The deployed space includes:
|
| 114 |
+
- **Web Interface** at `/web` - Interactive UI for exploring the environment
|
| 115 |
+
- **API Documentation** at `/docs` - Full OpenAPI/Swagger interface
|
| 116 |
+
- **Health Check** at `/health` - Container health monitoring
|
| 117 |
+
- **WebSocket** at `/ws` - Persistent session endpoint for low-latency interactions
|
| 118 |
+
|
| 119 |
+
## Environment Details
|
| 120 |
+
|
| 121 |
+
### Action
|
| 122 |
+
**MyAction**: Contains a single field
|
| 123 |
+
- `message` (str) - The message to echo back
|
| 124 |
+
|
| 125 |
+
### Observation
|
| 126 |
+
**MyObservation**: Contains the echo response and metadata
|
| 127 |
+
- `echoed_message` (str) - The message echoed back
|
| 128 |
+
- `message_length` (int) - Length of the message
|
| 129 |
+
- `reward` (float) - Reward based on message length (length × 0.1)
|
| 130 |
+
- `done` (bool) - Always False for echo environment
|
| 131 |
+
- `metadata` (dict) - Additional info like step count
|
| 132 |
+
|
| 133 |
+
### Reward
|
| 134 |
+
The reward is calculated as: `message_length × 0.1`
|
| 135 |
+
- "Hi" → reward: 0.2
|
| 136 |
+
- "Hello, World!" → reward: 1.3
|
| 137 |
+
- Empty message → reward: 0.0
|
| 138 |
+
|
| 139 |
+
## Advanced Usage
|
| 140 |
+
|
| 141 |
+
### Connecting to an Existing Server
|
| 142 |
+
|
| 143 |
+
If you already have a My Env environment server running, you can connect directly:
|
| 144 |
+
|
| 145 |
+
```python
|
| 146 |
+
from my_env import MyEnv
|
| 147 |
+
|
| 148 |
+
# Connect to existing server
|
| 149 |
+
my_envenv = MyEnv(base_url="<ENV_HTTP_URL_HERE>")
|
| 150 |
+
|
| 151 |
+
# Use as normal
|
| 152 |
+
result = my_envenv.reset()
|
| 153 |
+
result = my_envenv.step(MyAction(message="Hello!"))
|
| 154 |
+
```
|
| 155 |
+
|
| 156 |
+
Note: When connecting to an existing server, `my_envenv.close()` will NOT stop the server.
|
| 157 |
+
|
| 158 |
+
### Using the Context Manager
|
| 159 |
+
|
| 160 |
+
The client supports context manager usage for automatic connection management:
|
| 161 |
+
|
| 162 |
+
```python
|
| 163 |
+
from my_env import MyAction, MyEnv
|
| 164 |
+
|
| 165 |
+
# Connect with context manager (auto-connects and closes)
|
| 166 |
+
with MyEnv(base_url="http://localhost:8000") as env:
|
| 167 |
+
result = env.reset()
|
| 168 |
+
print(f"Reset: {result.observation.echoed_message}")
|
| 169 |
+
# Multiple steps with low latency
|
| 170 |
+
for msg in ["Hello", "World", "!"]:
|
| 171 |
+
result = env.step(MyAction(message=msg))
|
| 172 |
+
print(f"Echoed: {result.observation.echoed_message}")
|
| 173 |
+
```
|
| 174 |
+
|
| 175 |
+
The client uses WebSocket connections for:
|
| 176 |
+
- **Lower latency**: No HTTP connection overhead per request
|
| 177 |
+
- **Persistent session**: Server maintains your environment state
|
| 178 |
+
- **Efficient for episodes**: Better for many sequential steps
|
| 179 |
+
|
| 180 |
+
### Concurrent WebSocket Sessions
|
| 181 |
+
|
| 182 |
+
The server supports multiple concurrent WebSocket connections. To enable this,
|
| 183 |
+
modify `server/app.py` to use factory mode:
|
| 184 |
+
|
| 185 |
+
```python
|
| 186 |
+
# In server/app.py - use factory mode for concurrent sessions
|
| 187 |
+
app = create_app(
|
| 188 |
+
MyEnvironment, # Pass class, not instance
|
| 189 |
+
MyAction,
|
| 190 |
+
MyObservation,
|
| 191 |
+
max_concurrent_envs=4, # Allow 4 concurrent sessions
|
| 192 |
+
)
|
| 193 |
+
```
|
| 194 |
+
|
| 195 |
+
Then multiple clients can connect simultaneously:
|
| 196 |
+
|
| 197 |
+
```python
|
| 198 |
+
from my_env import MyAction, MyEnv
|
| 199 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 200 |
+
|
| 201 |
+
def run_episode(client_id: int):
|
| 202 |
+
with MyEnv(base_url="http://localhost:8000") as env:
|
| 203 |
+
result = env.reset()
|
| 204 |
+
for i in range(10):
|
| 205 |
+
result = env.step(MyAction(message=f"Client {client_id}, step {i}"))
|
| 206 |
+
return client_id, result.observation.message_length
|
| 207 |
+
|
| 208 |
+
# Run 4 episodes concurrently
|
| 209 |
+
with ThreadPoolExecutor(max_workers=4) as executor:
|
| 210 |
+
results = list(executor.map(run_episode, range(4)))
|
| 211 |
+
```
|
| 212 |
+
|
| 213 |
+
## Development & Testing
|
| 214 |
+
|
| 215 |
+
### Direct Environment Testing
|
| 216 |
+
|
| 217 |
+
Test the environment logic directly without starting the HTTP server:
|
| 218 |
+
|
| 219 |
+
```bash
|
| 220 |
+
# From the server directory
|
| 221 |
+
python3 server/my_env_environment.py
|
| 222 |
+
```
|
| 223 |
+
|
| 224 |
+
This verifies that:
|
| 225 |
+
- Environment resets correctly
|
| 226 |
+
- Step executes actions properly
|
| 227 |
+
- State tracking works
|
| 228 |
+
- Rewards are calculated correctly
|
| 229 |
+
|
| 230 |
+
### Running Locally
|
| 231 |
+
|
| 232 |
+
Run the server locally for development:
|
| 233 |
+
|
| 234 |
+
```bash
|
| 235 |
+
uvicorn server.app:app --reload
|
| 236 |
+
```
|
| 237 |
+
|
| 238 |
+
## Project Structure
|
| 239 |
+
|
| 240 |
+
```
|
| 241 |
+
my_env/
|
| 242 |
+
├── .dockerignore # Docker build exclusions
|
| 243 |
+
├── __init__.py # Module exports
|
| 244 |
+
├── README.md # This file
|
| 245 |
+
├── openenv.yaml # OpenEnv manifest
|
| 246 |
+
├── pyproject.toml # Project metadata and dependencies
|
| 247 |
+
├── uv.lock # Locked dependencies (generated)
|
| 248 |
+
├── client.py # MyEnv client
|
| 249 |
+
├── models.py # Action and Observation models
|
| 250 |
+
└── server/
|
| 251 |
+
├── __init__.py # Server module exports
|
| 252 |
+
├── my_env_environment.py # Core environment logic
|
| 253 |
+
├── app.py # FastAPI application (HTTP + WebSocket endpoints)
|
| 254 |
+
└── Dockerfile # Container image definition
|
| 255 |
+
```
|
my_env/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""My Env Environment."""
|
| 8 |
+
|
| 9 |
+
from .client import MyEnv
|
| 10 |
+
from .models import MyAction, MyObservation
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"MyAction",
|
| 14 |
+
"MyObservation",
|
| 15 |
+
"MyEnv",
|
| 16 |
+
]
|
my_env/client.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""My Env Environment Client."""
|
| 8 |
+
|
| 9 |
+
from typing import Dict
|
| 10 |
+
|
| 11 |
+
from openenv.core.client_types import StepResult
|
| 12 |
+
from openenv.core.env_server.types import State
|
| 13 |
+
from openenv.core import EnvClient
|
| 14 |
+
|
| 15 |
+
from .models import MyAction, MyObservation
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class MyEnv(
|
| 19 |
+
EnvClient[MyAction, MyObservation]
|
| 20 |
+
):
|
| 21 |
+
"""
|
| 22 |
+
Client for the My Env Environment.
|
| 23 |
+
|
| 24 |
+
This client maintains a persistent WebSocket connection to the environment server,
|
| 25 |
+
enabling efficient multi-step interactions with lower latency.
|
| 26 |
+
Each client instance has its own dedicated environment session on the server.
|
| 27 |
+
|
| 28 |
+
Example:
|
| 29 |
+
>>> # Connect to a running server
|
| 30 |
+
>>> with MyEnv(base_url="http://localhost:8000") as client:
|
| 31 |
+
... result = client.reset()
|
| 32 |
+
... print(result.observation.echoed_message)
|
| 33 |
+
...
|
| 34 |
+
... result = client.step(MyAction(message="Hello!"))
|
| 35 |
+
... print(result.observation.echoed_message)
|
| 36 |
+
|
| 37 |
+
Example with Docker:
|
| 38 |
+
>>> # Automatically start container and connect
|
| 39 |
+
>>> client = MyEnv.from_docker_image("my_env-env:latest")
|
| 40 |
+
>>> try:
|
| 41 |
+
... result = client.reset()
|
| 42 |
+
... result = client.step(MyAction(message="Test"))
|
| 43 |
+
... finally:
|
| 44 |
+
... client.close()
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
def _step_payload(self, action: MyAction) -> Dict:
|
| 48 |
+
"""
|
| 49 |
+
Convert MyAction to JSON payload for step message.
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
action: MyAction instance
|
| 53 |
+
|
| 54 |
+
Returns:
|
| 55 |
+
Dictionary representation suitable for JSON encoding
|
| 56 |
+
"""
|
| 57 |
+
return {
|
| 58 |
+
"message": action.message,
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
def _parse_result(self, payload: Dict) -> StepResult[MyObservation]:
|
| 62 |
+
"""
|
| 63 |
+
Parse server response into StepResult[MyObservation].
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
payload: JSON response data from server
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
StepResult with MyObservation
|
| 70 |
+
"""
|
| 71 |
+
obs_data = payload.get("observation", {})
|
| 72 |
+
observation = MyObservation(
|
| 73 |
+
echoed_message=obs_data.get("echoed_message", ""),
|
| 74 |
+
message_length=obs_data.get("message_length", 0),
|
| 75 |
+
done=payload.get("done", False),
|
| 76 |
+
reward=payload.get("reward"),
|
| 77 |
+
metadata=obs_data.get("metadata", {}),
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
return StepResult(
|
| 81 |
+
observation=observation,
|
| 82 |
+
reward=payload.get("reward"),
|
| 83 |
+
done=payload.get("done", False),
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
def _parse_state(self, payload: Dict) -> State:
|
| 87 |
+
"""
|
| 88 |
+
Parse server response into State object.
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
payload: JSON response from state request
|
| 92 |
+
|
| 93 |
+
Returns:
|
| 94 |
+
State object with episode_id and step_count
|
| 95 |
+
"""
|
| 96 |
+
return State(
|
| 97 |
+
episode_id=payload.get("episode_id"),
|
| 98 |
+
step_count=payload.get("step_count", 0),
|
| 99 |
+
)
|
my_env/models.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Data models for the My Env Environment.
|
| 9 |
+
|
| 10 |
+
The my_env environment is a simple test environment that echoes back messages.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from pydantic import Field
|
| 14 |
+
|
| 15 |
+
from openenv.core.env_server.types import Action, Observation
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class MyAction(Action):
|
| 19 |
+
"""Action for the My Env environment - just a message to echo."""
|
| 20 |
+
|
| 21 |
+
message: str = Field(..., description="Message to echo back")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class MyObservation(Observation):
|
| 25 |
+
"""Observation from the My Env environment - the echoed message."""
|
| 26 |
+
|
| 27 |
+
echoed_message: str = Field(default="", description="The echoed message")
|
| 28 |
+
message_length: int = Field(default=0, description="Length of the echoed message")
|
my_env/openenv.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: my_env
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: server.app:app
|
| 6 |
+
port: 8000
|
| 7 |
+
|
my_env/pyproject.toml
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
[build-system]
|
| 8 |
+
requires = ["setuptools>=45", "wheel"]
|
| 9 |
+
build-backend = "setuptools.build_meta"
|
| 10 |
+
|
| 11 |
+
[project]
|
| 12 |
+
name = "openenv-my_env"
|
| 13 |
+
version = "0.1.0"
|
| 14 |
+
description = "My Env environment for OpenEnv"
|
| 15 |
+
requires-python = ">=3.10"
|
| 16 |
+
dependencies = [
|
| 17 |
+
# Core OpenEnv runtime (provides FastAPI server + HTTP client types)
|
| 18 |
+
# install from github
|
| 19 |
+
# "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
|
| 20 |
+
"openenv-core[core]>=0.2.0",
|
| 21 |
+
# Environment-specific dependencies
|
| 22 |
+
# Add all dependencies needed for your environment here
|
| 23 |
+
# Examples:
|
| 24 |
+
# "numpy>=1.19.0",
|
| 25 |
+
# "torch>=2.0.0",
|
| 26 |
+
# "gymnasium>=0.29.0",
|
| 27 |
+
# "openspiel>=1.0.0",
|
| 28 |
+
# "smolagents>=1.22.0,<2",
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
[project.optional-dependencies]
|
| 32 |
+
dev = [
|
| 33 |
+
"pytest>=8.0.0",
|
| 34 |
+
"pytest-cov>=4.0.0",
|
| 35 |
+
]
|
| 36 |
+
|
| 37 |
+
[project.scripts]
|
| 38 |
+
# Server entry point - enables running via: uv run --project . server
|
| 39 |
+
# or: python -m my_env.server.app
|
| 40 |
+
server = "my_env.server.app:main"
|
| 41 |
+
|
| 42 |
+
[tool.setuptools]
|
| 43 |
+
include-package-data = true
|
| 44 |
+
packages = ["my_env", "my_env.server"]
|
| 45 |
+
package-dir = { "my_env" = ".", "my_env.server" = "server" }
|
my_env/server/Dockerfile
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Multi-stage build using openenv-base
|
| 8 |
+
# This Dockerfile is flexible and works for both:
|
| 9 |
+
# - In-repo environments (with local OpenEnv sources)
|
| 10 |
+
# - Standalone environments (with openenv from PyPI/Git)
|
| 11 |
+
# The build script (openenv build) handles context detection and sets appropriate build args.
|
| 12 |
+
|
| 13 |
+
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
|
| 14 |
+
FROM ${BASE_IMAGE} AS builder
|
| 15 |
+
|
| 16 |
+
WORKDIR /app
|
| 17 |
+
|
| 18 |
+
# Ensure git is available (required for installing dependencies from VCS)
|
| 19 |
+
RUN apt-get update && \
|
| 20 |
+
apt-get install -y --no-install-recommends git && \
|
| 21 |
+
rm -rf /var/lib/apt/lists/*
|
| 22 |
+
|
| 23 |
+
# Build argument to control whether we're building standalone or in-repo
|
| 24 |
+
ARG BUILD_MODE=in-repo
|
| 25 |
+
ARG ENV_NAME=my_env
|
| 26 |
+
|
| 27 |
+
# Copy environment code (always at root of build context)
|
| 28 |
+
COPY . /app/env
|
| 29 |
+
|
| 30 |
+
# For in-repo builds, openenv is already vendored in the build context
|
| 31 |
+
# For standalone builds, openenv will be installed via pyproject.toml
|
| 32 |
+
WORKDIR /app/env
|
| 33 |
+
|
| 34 |
+
# Ensure uv is available (for local builds where base image lacks it)
|
| 35 |
+
RUN if ! command -v uv >/dev/null 2>&1; then \
|
| 36 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
| 37 |
+
mv /root/.local/bin/uv /usr/local/bin/uv && \
|
| 38 |
+
mv /root/.local/bin/uvx /usr/local/bin/uvx; \
|
| 39 |
+
fi
|
| 40 |
+
|
| 41 |
+
# Install dependencies using uv sync
|
| 42 |
+
# If uv.lock exists, use it; otherwise resolve on the fly
|
| 43 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 44 |
+
if [ -f uv.lock ]; then \
|
| 45 |
+
uv sync --frozen --no-install-project --no-editable; \
|
| 46 |
+
else \
|
| 47 |
+
uv sync --no-install-project --no-editable; \
|
| 48 |
+
fi
|
| 49 |
+
|
| 50 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 51 |
+
if [ -f uv.lock ]; then \
|
| 52 |
+
uv sync --frozen --no-editable; \
|
| 53 |
+
else \
|
| 54 |
+
uv sync --no-editable; \
|
| 55 |
+
fi
|
| 56 |
+
|
| 57 |
+
# Final runtime stage
|
| 58 |
+
FROM ${BASE_IMAGE}
|
| 59 |
+
|
| 60 |
+
WORKDIR /app
|
| 61 |
+
|
| 62 |
+
# Copy the virtual environment from builder
|
| 63 |
+
COPY --from=builder /app/env/.venv /app/.venv
|
| 64 |
+
|
| 65 |
+
# Copy the environment code
|
| 66 |
+
COPY --from=builder /app/env /app/env
|
| 67 |
+
|
| 68 |
+
# Set PATH to use the virtual environment
|
| 69 |
+
ENV PATH="/app/.venv/bin:$PATH"
|
| 70 |
+
|
| 71 |
+
# Set PYTHONPATH so imports work correctly
|
| 72 |
+
ENV PYTHONPATH="/app/env:$PYTHONPATH"
|
| 73 |
+
|
| 74 |
+
# Health check
|
| 75 |
+
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
| 76 |
+
CMD curl -f http://localhost:8000/health || exit 1
|
| 77 |
+
|
| 78 |
+
# Run the FastAPI server
|
| 79 |
+
# The module path is constructed to work with the /app/env structure
|
| 80 |
+
CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
|
my_env/server/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""My Env environment server components."""
|
| 8 |
+
|
| 9 |
+
from .my_env_environment import MyEnvironment
|
| 10 |
+
|
| 11 |
+
__all__ = ["MyEnvironment"]
|
my_env/server/app.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
FastAPI application for the My Env Environment.
|
| 9 |
+
|
| 10 |
+
This module creates an HTTP server that exposes the MyEnvironment
|
| 11 |
+
over HTTP and WebSocket endpoints, compatible with EnvClient.
|
| 12 |
+
|
| 13 |
+
Endpoints:
|
| 14 |
+
- POST /reset: Reset the environment
|
| 15 |
+
- POST /step: Execute an action
|
| 16 |
+
- GET /state: Get current environment state
|
| 17 |
+
- GET /schema: Get action/observation schemas
|
| 18 |
+
- WS /ws: WebSocket endpoint for persistent sessions
|
| 19 |
+
|
| 20 |
+
Usage:
|
| 21 |
+
# Development (with auto-reload):
|
| 22 |
+
uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
|
| 23 |
+
|
| 24 |
+
# Production:
|
| 25 |
+
uvicorn server.app:app --host 0.0.0.0 --port 8000 --workers 4
|
| 26 |
+
|
| 27 |
+
# Or run directly:
|
| 28 |
+
python -m server.app
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
try:
|
| 32 |
+
from openenv.core.env_server.http_server import create_app
|
| 33 |
+
except Exception as e: # pragma: no cover
|
| 34 |
+
raise ImportError(
|
| 35 |
+
"openenv is required for the web interface. Install dependencies with '\n uv sync\n'"
|
| 36 |
+
) from e
|
| 37 |
+
|
| 38 |
+
# Import from local models.py (PYTHONPATH includes /app/env in Docker)
|
| 39 |
+
from models import MyAction, MyObservation
|
| 40 |
+
from .my_env_environment import MyEnvironment
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# Create the app with web interface and README integration
|
| 44 |
+
app = create_app(
|
| 45 |
+
MyEnvironment,
|
| 46 |
+
MyAction,
|
| 47 |
+
MyObservation,
|
| 48 |
+
env_name="my_env",
|
| 49 |
+
max_concurrent_envs=1, # increase this number to allow more concurrent WebSocket sessions
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def main(host: str = "0.0.0.0", port: int = 8000):
|
| 54 |
+
"""
|
| 55 |
+
Entry point for direct execution via uv run or python -m.
|
| 56 |
+
|
| 57 |
+
This function enables running the server without Docker:
|
| 58 |
+
uv run --project . server
|
| 59 |
+
uv run --project . server --port 8001
|
| 60 |
+
python -m my_env.server.app
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
host: Host address to bind to (default: "0.0.0.0")
|
| 64 |
+
port: Port number to listen on (default: 8000)
|
| 65 |
+
|
| 66 |
+
For production deployments, consider using uvicorn directly with
|
| 67 |
+
multiple workers:
|
| 68 |
+
uvicorn my_env.server.app:app --workers 4
|
| 69 |
+
"""
|
| 70 |
+
import uvicorn
|
| 71 |
+
|
| 72 |
+
uvicorn.run(app, host=host, port=port)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
if __name__ == "__main__":
|
| 76 |
+
import argparse
|
| 77 |
+
|
| 78 |
+
parser = argparse.ArgumentParser()
|
| 79 |
+
parser.add_argument("--port", type=int, default=8000)
|
| 80 |
+
args = parser.parse_args()
|
| 81 |
+
main(port=args.port)
|
my_env/server/my_env_environment.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
My Env Environment Implementation.
|
| 9 |
+
|
| 10 |
+
A simple test environment that echoes back messages sent to it.
|
| 11 |
+
Perfect for testing HTTP server infrastructure.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from uuid import uuid4
|
| 15 |
+
|
| 16 |
+
from openenv.core.env_server.interfaces import Environment
|
| 17 |
+
from openenv.core.env_server.types import State
|
| 18 |
+
|
| 19 |
+
from models import MyAction, MyObservation
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class MyEnvironment(Environment):
|
| 23 |
+
"""
|
| 24 |
+
A simple echo environment that echoes back messages.
|
| 25 |
+
|
| 26 |
+
This environment is designed for testing the HTTP server infrastructure.
|
| 27 |
+
It maintains minimal state and simply echoes back whatever message it receives.
|
| 28 |
+
|
| 29 |
+
Example:
|
| 30 |
+
>>> env = MyEnvironment()
|
| 31 |
+
>>> obs = env.reset()
|
| 32 |
+
>>> print(obs.echoed_message) # "My Env environment ready!"
|
| 33 |
+
>>>
|
| 34 |
+
>>> obs = env.step(MyAction(message="Hello"))
|
| 35 |
+
>>> print(obs.echoed_message) # "Hello"
|
| 36 |
+
>>> print(obs.message_length) # 5
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
# Enable concurrent WebSocket sessions.
|
| 40 |
+
# Set to True if your environment isolates state between instances.
|
| 41 |
+
# When True, multiple WebSocket clients can connect simultaneously, each
|
| 42 |
+
# getting their own environment instance (when using factory mode in app.py).
|
| 43 |
+
SUPPORTS_CONCURRENT_SESSIONS: bool = True
|
| 44 |
+
|
| 45 |
+
def __init__(self):
|
| 46 |
+
"""Initialize the my_env environment."""
|
| 47 |
+
self._state = State(episode_id=str(uuid4()), step_count=0)
|
| 48 |
+
self._reset_count = 0
|
| 49 |
+
|
| 50 |
+
def reset(self) -> MyObservation:
|
| 51 |
+
"""
|
| 52 |
+
Reset the environment.
|
| 53 |
+
|
| 54 |
+
Returns:
|
| 55 |
+
MyObservation with a ready message
|
| 56 |
+
"""
|
| 57 |
+
self._state = State(episode_id=str(uuid4()), step_count=0)
|
| 58 |
+
self._reset_count += 1
|
| 59 |
+
|
| 60 |
+
return MyObservation(
|
| 61 |
+
echoed_message="My Env environment ready!",
|
| 62 |
+
message_length=0,
|
| 63 |
+
done=False,
|
| 64 |
+
reward=0.0,
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
def step(self, action: MyAction) -> MyObservation: # type: ignore[override]
|
| 68 |
+
"""
|
| 69 |
+
Execute a step in the environment by echoing the message.
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
action: MyAction containing the message to echo
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
MyObservation with the echoed message and its length
|
| 76 |
+
"""
|
| 77 |
+
self._state.step_count += 1
|
| 78 |
+
|
| 79 |
+
message = action.message
|
| 80 |
+
length = len(message)
|
| 81 |
+
|
| 82 |
+
# Simple reward: longer messages get higher rewards
|
| 83 |
+
reward = length * 0.1
|
| 84 |
+
|
| 85 |
+
return MyObservation(
|
| 86 |
+
echoed_message=message,
|
| 87 |
+
message_length=length,
|
| 88 |
+
done=False,
|
| 89 |
+
reward=reward,
|
| 90 |
+
metadata={"original_message": message, "step": self._state.step_count},
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
@property
|
| 94 |
+
def state(self) -> State:
|
| 95 |
+
"""
|
| 96 |
+
Get the current environment state.
|
| 97 |
+
|
| 98 |
+
Returns:
|
| 99 |
+
Current State with episode_id and step_count
|
| 100 |
+
"""
|
| 101 |
+
return self._state
|
my_env/server/requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv[core]>=0.2.0
|
| 2 |
+
fastapi>=0.115.0
|
| 3 |
+
uvicorn>=0.24.0
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
|
pyproject.toml
CHANGED
|
@@ -21,9 +21,6 @@ dependencies = [
|
|
| 21 |
]
|
| 22 |
|
| 23 |
[project.optional-dependencies]
|
| 24 |
-
train = [
|
| 25 |
-
"gymnasium>=0.29.0",
|
| 26 |
-
]
|
| 27 |
bio = [
|
| 28 |
"biopython>=1.84",
|
| 29 |
"gseapy>=1.1.3",
|
|
@@ -32,7 +29,16 @@ bio = [
|
|
| 32 |
dev = [
|
| 33 |
"pytest>=8.0.0",
|
| 34 |
"pytest-cov>=4.0.0",
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
]
|
| 37 |
|
| 38 |
[project.scripts]
|
|
|
|
| 21 |
]
|
| 22 |
|
| 23 |
[project.optional-dependencies]
|
|
|
|
|
|
|
|
|
|
| 24 |
bio = [
|
| 25 |
"biopython>=1.84",
|
| 26 |
"gseapy>=1.1.3",
|
|
|
|
| 29 |
dev = [
|
| 30 |
"pytest>=8.0.0",
|
| 31 |
"pytest-cov>=4.0.0",
|
| 32 |
+
]
|
| 33 |
+
train = [
|
| 34 |
+
"accelerate>=1.13.0",
|
| 35 |
+
"bitsandbytes>=0.45.0",
|
| 36 |
+
"datasets>=4.6.1",
|
| 37 |
+
"matplotlib>=3.10.8",
|
| 38 |
+
"peft>=0.15.0",
|
| 39 |
+
"torch>=2.10.0",
|
| 40 |
+
"transformers>=5.3.0",
|
| 41 |
+
"trl>=0.29.0",
|
| 42 |
]
|
| 43 |
|
| 44 |
[project.scripts]
|
run_agent.py
CHANGED
|
@@ -1,292 +1,978 @@
|
|
| 1 |
-
"""Run the bio-experiment environment with Qwen3.5-
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
import json
|
| 6 |
-
import
|
| 7 |
-
import
|
| 8 |
-
import time
|
| 9 |
-
from
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
from
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
"""
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
{
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
)
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
if
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
if
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Run the bio-experiment environment with Qwen3.5-0.8B as the planning agent."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
import re
|
| 8 |
+
import time
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Any, Dict, List, Optional
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
| 14 |
+
|
| 15 |
+
from models import (
|
| 16 |
+
ActionType,
|
| 17 |
+
ExperimentAction,
|
| 18 |
+
ExperimentObservation,
|
| 19 |
+
build_agent_observation_context,
|
| 20 |
+
build_agent_system_prompt,
|
| 21 |
+
)
|
| 22 |
+
from server.hackathon_environment import BioExperimentEnvironment
|
| 23 |
+
|
| 24 |
+
DASHBOARD_STATE_PATH = Path(__file__).parent / "_dashboard_state.json"
|
| 25 |
+
DASHBOARD_CMD_PATH = Path(__file__).parent / "_dashboard_cmd.json"
|
| 26 |
+
|
| 27 |
+
USE_PIPELINE = os.getenv("RUN_AGENT_USE_PIPELINE", "0").strip().lower() not in {"0", "false", "off"}
|
| 28 |
+
|
| 29 |
+
def _parse_thinking_flag() -> bool:
|
| 30 |
+
import sys
|
| 31 |
+
if "--no-thinking" in sys.argv:
|
| 32 |
+
return False
|
| 33 |
+
if "--thinking" in sys.argv:
|
| 34 |
+
return True
|
| 35 |
+
return os.getenv("RUN_AGENT_ENABLE_THINKING", "1").strip().lower() not in {"0", "false", "off"}
|
| 36 |
+
|
| 37 |
+
ENABLE_THINKING = _parse_thinking_flag()
|
| 38 |
+
|
| 39 |
+
MODEL_ID = "Qwen/Qwen3.5-2B"
|
| 40 |
+
MAX_EPISODE_STEPS = int(os.getenv("RUN_AGENT_MAX_EPISODE_STEPS", "20"))
|
| 41 |
+
PIPELINE_TASK = "text-generation"
|
| 42 |
+
|
| 43 |
+
ACTION_TYPES = [a.value for a in ActionType]
|
| 44 |
+
ACTION_TYPE_ALIASES = {
|
| 45 |
+
"collect_samples": ActionType.COLLECT_SAMPLE.value,
|
| 46 |
+
"collect_sample_from_bone_marrow": ActionType.COLLECT_SAMPLE.value,
|
| 47 |
+
"collect_samples_from_bone_marrow": ActionType.COLLECT_SAMPLE.value,
|
| 48 |
+
"prepare_sc_library": ActionType.PREPARE_LIBRARY.value,
|
| 49 |
+
"sequence_single_cells": ActionType.SEQUENCE_CELLS.value,
|
| 50 |
+
"qc": ActionType.RUN_QC.value,
|
| 51 |
+
"run_quality_control": ActionType.RUN_QC.value,
|
| 52 |
+
"cluster": ActionType.CLUSTER_CELLS.value,
|
| 53 |
+
"de_analysis": ActionType.DIFFERENTIAL_EXPRESSION.value,
|
| 54 |
+
"differential_expression_analysis": ActionType.DIFFERENTIAL_EXPRESSION.value,
|
| 55 |
+
"trajectory_inference": ActionType.TRAJECTORY_ANALYSIS.value,
|
| 56 |
+
"infer_trajectory": ActionType.TRAJECTORY_ANALYSIS.value,
|
| 57 |
+
"network_inference": ActionType.REGULATORY_NETWORK_INFERENCE.value,
|
| 58 |
+
"select_markers": ActionType.MARKER_SELECTION.value,
|
| 59 |
+
"final_conclusion": ActionType.SYNTHESIZE_CONCLUSION.value,
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
SYSTEM_PROMPT = build_agent_system_prompt()
|
| 63 |
+
|
| 64 |
+
STANDARD_PIPELINE_ORDER = [
|
| 65 |
+
ActionType.COLLECT_SAMPLE,
|
| 66 |
+
ActionType.SELECT_COHORT,
|
| 67 |
+
ActionType.PREPARE_LIBRARY,
|
| 68 |
+
ActionType.SEQUENCE_CELLS,
|
| 69 |
+
ActionType.RUN_QC,
|
| 70 |
+
ActionType.FILTER_DATA,
|
| 71 |
+
ActionType.NORMALIZE_DATA,
|
| 72 |
+
ActionType.INTEGRATE_BATCHES,
|
| 73 |
+
ActionType.CLUSTER_CELLS,
|
| 74 |
+
ActionType.DIFFERENTIAL_EXPRESSION,
|
| 75 |
+
ActionType.PATHWAY_ENRICHMENT,
|
| 76 |
+
ActionType.MARKER_SELECTION,
|
| 77 |
+
ActionType.TRAJECTORY_ANALYSIS,
|
| 78 |
+
ActionType.REGULATORY_NETWORK_INFERENCE,
|
| 79 |
+
ActionType.SYNTHESIZE_CONCLUSION,
|
| 80 |
+
]
|
| 81 |
+
|
| 82 |
+
MODEL_RESPONSE_PREVIEW_CHARS = int(
|
| 83 |
+
os.getenv("RUN_AGENT_MODEL_RESPONSE_PREVIEW_CHARS", "240")
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def compact_preview(value: Any, max_chars: int = 160) -> str:
|
| 88 |
+
try:
|
| 89 |
+
text = json.dumps(value, ensure_ascii=True, sort_keys=True)
|
| 90 |
+
except TypeError:
|
| 91 |
+
text = str(value)
|
| 92 |
+
text = re.sub(r"\s+", " ", text).strip()
|
| 93 |
+
if len(text) <= max_chars:
|
| 94 |
+
return text
|
| 95 |
+
return text[: max_chars - 3] + "..."
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def format_observation(obs: ExperimentObservation) -> str:
|
| 99 |
+
parts = [
|
| 100 |
+
f"TASK: {obs.task.problem_statement}",
|
| 101 |
+
f"Organism: {obs.task.organism} | Tissue: {obs.task.tissue}",
|
| 102 |
+
f"Conditions: {', '.join(obs.task.conditions) or 'N/A'}",
|
| 103 |
+
f"Step: {obs.step_index} | Budget: ${obs.resource_usage.budget_remaining:,.0f} | Time: {obs.resource_usage.time_remaining_days:.0f}d",
|
| 104 |
+
]
|
| 105 |
+
context = build_agent_observation_context(obs, max_tools=5, max_assays=2)
|
| 106 |
+
if context:
|
| 107 |
+
parts.append(context)
|
| 108 |
+
if obs.pipeline_history:
|
| 109 |
+
last5 = obs.pipeline_history[-5:]
|
| 110 |
+
parts.append("Recent history:")
|
| 111 |
+
for h in last5:
|
| 112 |
+
tag = "OK" if h.success else "FAIL"
|
| 113 |
+
line = f" [{tag}] {h.action_type.value}"
|
| 114 |
+
if h.method:
|
| 115 |
+
line += f" ({h.method})"
|
| 116 |
+
line += f": {h.output_summary[:80]}"
|
| 117 |
+
parts.append(line)
|
| 118 |
+
|
| 119 |
+
completed = {h.action_type for h in obs.pipeline_history if h.success}
|
| 120 |
+
if completed:
|
| 121 |
+
parts.append(f"Completed steps (do NOT repeat): {', '.join(sorted(a.value for a in completed))}")
|
| 122 |
+
remaining = [a.value for a in STANDARD_PIPELINE_ORDER if a not in completed]
|
| 123 |
+
if remaining:
|
| 124 |
+
parts.append(f"Remaining steps (choose one): {', '.join(remaining)}")
|
| 125 |
+
|
| 126 |
+
if obs.latest_output and obs.latest_output.data:
|
| 127 |
+
parts.append(
|
| 128 |
+
f"Latest data: {compact_preview(obs.latest_output.data, 200)}"
|
| 129 |
+
)
|
| 130 |
+
if obs.rule_violations:
|
| 131 |
+
parts.append(f"VIOLATIONS: {obs.rule_violations}")
|
| 132 |
+
if obs.discovered_markers:
|
| 133 |
+
parts.append(f"Markers found so far: {obs.discovered_markers[:5]}")
|
| 134 |
+
|
| 135 |
+
parts.append(
|
| 136 |
+
'Output ONLY a single JSON object with these exact keys, no comments, no extra text:\n'
|
| 137 |
+
'{"action_type": "<one of the remaining steps>", "method": null, "parameters": {}, "justification": "<why>", "confidence": 0.8}'
|
| 138 |
+
)
|
| 139 |
+
return "\n".join(parts)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def _repair_truncated_json(text: str) -> Optional[str]:
|
| 143 |
+
"""Try to repair JSON truncated mid-value (common with small LLMs)."""
|
| 144 |
+
s = text.strip()
|
| 145 |
+
if not s.startswith("{"):
|
| 146 |
+
return None
|
| 147 |
+
|
| 148 |
+
# Drop dangling partial keys or empty key/value stubs at the tail.
|
| 149 |
+
s = re.sub(r',\s*"[^"\n]*$', '', s)
|
| 150 |
+
s = re.sub(r',\s*"[^"\n]*"\s*:\s*$', '', s)
|
| 151 |
+
|
| 152 |
+
in_string = False
|
| 153 |
+
escape = False
|
| 154 |
+
for ch in s:
|
| 155 |
+
if escape:
|
| 156 |
+
escape = False
|
| 157 |
+
continue
|
| 158 |
+
if ch == "\\":
|
| 159 |
+
escape = True
|
| 160 |
+
continue
|
| 161 |
+
if ch == '"':
|
| 162 |
+
in_string = not in_string
|
| 163 |
+
|
| 164 |
+
if in_string:
|
| 165 |
+
s += '"'
|
| 166 |
+
|
| 167 |
+
open_braces = s.count("{") - s.count("}")
|
| 168 |
+
open_brackets = s.count("[") - s.count("]")
|
| 169 |
+
s += "]" * max(0, open_brackets)
|
| 170 |
+
s += "}" * max(0, open_braces)
|
| 171 |
+
|
| 172 |
+
try:
|
| 173 |
+
obj = json.loads(s)
|
| 174 |
+
if isinstance(obj, dict):
|
| 175 |
+
return s
|
| 176 |
+
except json.JSONDecodeError:
|
| 177 |
+
pass
|
| 178 |
+
|
| 179 |
+
s = re.sub(r',\s*([}\]])', r'\1', s)
|
| 180 |
+
try:
|
| 181 |
+
obj = json.loads(s)
|
| 182 |
+
if isinstance(obj, dict):
|
| 183 |
+
return s
|
| 184 |
+
except json.JSONDecodeError:
|
| 185 |
+
pass
|
| 186 |
+
return None
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def _normalize_jsonish_text(text: str) -> str:
|
| 190 |
+
"""Normalize common near-JSON artifacts emitted by small local models."""
|
| 191 |
+
text = _strip_js_comments(text)
|
| 192 |
+
text = re.sub(r'(?<=:\s)\bNone\b', 'null', text)
|
| 193 |
+
text = re.sub(r'(?<=:\s)\bTrue\b', 'true', text)
|
| 194 |
+
text = re.sub(r'(?<=:\s)\bFalse\b', 'false', text)
|
| 195 |
+
text = re.sub(r'"([^"\n]+?):"\s*,', r'"\1": "",', text)
|
| 196 |
+
return text
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def _strip_js_comments(text: str) -> str:
|
| 200 |
+
"""Remove // and /* */ comments that small LLMs inject into JSON."""
|
| 201 |
+
text = re.sub(r'//[^\n]*', '', text)
|
| 202 |
+
text = re.sub(r'/\*.*?\*/', '', text, flags=re.DOTALL)
|
| 203 |
+
return text
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def extract_json_object(text: str) -> Optional[Dict[str, Any]]:
|
| 207 |
+
stripped = _normalize_jsonish_text(text).strip()
|
| 208 |
+
fence_prefix = "```"
|
| 209 |
+
if stripped.startswith(fence_prefix) and stripped.endswith(fence_prefix):
|
| 210 |
+
lines = stripped.splitlines()
|
| 211 |
+
if len(lines) >= 3:
|
| 212 |
+
stripped = "\n".join(lines[1:-1]).strip()
|
| 213 |
+
|
| 214 |
+
candidates: List[str] = [stripped]
|
| 215 |
+
start = stripped.find("{")
|
| 216 |
+
while start != -1:
|
| 217 |
+
depth = 0
|
| 218 |
+
for idx in range(start, len(stripped)):
|
| 219 |
+
char = stripped[idx]
|
| 220 |
+
if char == "{":
|
| 221 |
+
depth += 1
|
| 222 |
+
elif char == "}":
|
| 223 |
+
depth -= 1
|
| 224 |
+
if depth == 0:
|
| 225 |
+
candidates.append(stripped[start:idx + 1])
|
| 226 |
+
break
|
| 227 |
+
start = stripped.find("{", start + 1)
|
| 228 |
+
|
| 229 |
+
first_brace = stripped.find("{")
|
| 230 |
+
if first_brace != -1:
|
| 231 |
+
repaired = _repair_truncated_json(stripped[first_brace:])
|
| 232 |
+
if repaired is not None:
|
| 233 |
+
candidates.append(repaired)
|
| 234 |
+
|
| 235 |
+
candidates.sort(key=len, reverse=True)
|
| 236 |
+
|
| 237 |
+
for candidate in candidates:
|
| 238 |
+
try:
|
| 239 |
+
parsed = json.loads(candidate)
|
| 240 |
+
except json.JSONDecodeError:
|
| 241 |
+
continue
|
| 242 |
+
if isinstance(parsed, dict):
|
| 243 |
+
return parsed
|
| 244 |
+
|
| 245 |
+
return None
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def _edit_distance(a: str, b: str) -> int:
|
| 249 |
+
if len(a) < len(b):
|
| 250 |
+
return _edit_distance(b, a)
|
| 251 |
+
if not b:
|
| 252 |
+
return len(a)
|
| 253 |
+
prev = list(range(len(b) + 1))
|
| 254 |
+
for i, ca in enumerate(a):
|
| 255 |
+
curr = [i + 1]
|
| 256 |
+
for j, cb in enumerate(b):
|
| 257 |
+
curr.append(min(prev[j + 1] + 1, curr[j] + 1, prev[j] + (ca != cb)))
|
| 258 |
+
prev = curr
|
| 259 |
+
return prev[-1]
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def get_payload_value(payload: Dict[str, Any], *names: str) -> Any:
|
| 263 |
+
for name in names:
|
| 264 |
+
if name in payload:
|
| 265 |
+
return payload[name]
|
| 266 |
+
|
| 267 |
+
lowered = {
|
| 268 |
+
str(key).lower(): value
|
| 269 |
+
for key, value in payload.items()
|
| 270 |
+
}
|
| 271 |
+
for name in names:
|
| 272 |
+
if name.lower() in lowered:
|
| 273 |
+
return lowered[name.lower()]
|
| 274 |
+
|
| 275 |
+
for key, value in lowered.items():
|
| 276 |
+
for name in names:
|
| 277 |
+
threshold = max(2, len(name) // 3)
|
| 278 |
+
if _edit_distance(key, name.lower()) <= threshold:
|
| 279 |
+
return value
|
| 280 |
+
return None
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def normalize_optional_string(value: Any) -> Optional[str]:
|
| 284 |
+
if value is None or isinstance(value, bool):
|
| 285 |
+
return None
|
| 286 |
+
if isinstance(value, str):
|
| 287 |
+
value = value.strip()
|
| 288 |
+
return value or None
|
| 289 |
+
if isinstance(value, (int, float)):
|
| 290 |
+
return str(value)
|
| 291 |
+
return compact_preview(value, 80)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def normalize_action_type(raw_action_type: Any) -> Optional[str]:
|
| 295 |
+
if not isinstance(raw_action_type, str):
|
| 296 |
+
return None
|
| 297 |
+
|
| 298 |
+
candidate = raw_action_type.strip().lower()
|
| 299 |
+
if candidate in ACTION_TYPES:
|
| 300 |
+
return candidate
|
| 301 |
+
if candidate in ACTION_TYPE_ALIASES:
|
| 302 |
+
return ACTION_TYPE_ALIASES[candidate]
|
| 303 |
+
|
| 304 |
+
candidate = re.sub(r"[^a-z0-9]+", "_", candidate).strip("_")
|
| 305 |
+
if candidate in ACTION_TYPES:
|
| 306 |
+
return candidate
|
| 307 |
+
if candidate in ACTION_TYPE_ALIASES:
|
| 308 |
+
return ACTION_TYPE_ALIASES[candidate]
|
| 309 |
+
|
| 310 |
+
heuristics = [
|
| 311 |
+
(("collect", "sample"), ActionType.COLLECT_SAMPLE.value),
|
| 312 |
+
(("library",), ActionType.PREPARE_LIBRARY.value),
|
| 313 |
+
(("sequence",), ActionType.SEQUENCE_CELLS.value),
|
| 314 |
+
(("qc",), ActionType.RUN_QC.value),
|
| 315 |
+
(("quality", "control"), ActionType.RUN_QC.value),
|
| 316 |
+
(("filter",), ActionType.FILTER_DATA.value),
|
| 317 |
+
(("normal",), ActionType.NORMALIZE_DATA.value),
|
| 318 |
+
(("integrat", "batch"), ActionType.INTEGRATE_BATCHES.value),
|
| 319 |
+
(("cluster",), ActionType.CLUSTER_CELLS.value),
|
| 320 |
+
(("differential", "expression"), ActionType.DIFFERENTIAL_EXPRESSION.value),
|
| 321 |
+
(("pathway",), ActionType.PATHWAY_ENRICHMENT.value),
|
| 322 |
+
(("trajectory",), ActionType.TRAJECTORY_ANALYSIS.value),
|
| 323 |
+
(("network",), ActionType.REGULATORY_NETWORK_INFERENCE.value),
|
| 324 |
+
(("marker",), ActionType.MARKER_SELECTION.value),
|
| 325 |
+
(("validat", "marker"), ActionType.VALIDATE_MARKER.value),
|
| 326 |
+
(("followup",), ActionType.DESIGN_FOLLOWUP.value),
|
| 327 |
+
(("review",), ActionType.REQUEST_SUBAGENT_REVIEW.value),
|
| 328 |
+
(("conclusion",), ActionType.SYNTHESIZE_CONCLUSION.value),
|
| 329 |
+
]
|
| 330 |
+
for fragments, normalized in heuristics:
|
| 331 |
+
if all(fragment in candidate for fragment in fragments):
|
| 332 |
+
return normalized
|
| 333 |
+
return None
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def should_block_failed_reattempt(
|
| 337 |
+
history: List[Any], action_type: ActionType
|
| 338 |
+
) -> bool:
|
| 339 |
+
last_failed_idx = None
|
| 340 |
+
last_success_idx = None
|
| 341 |
+
|
| 342 |
+
for idx, record in enumerate(history):
|
| 343 |
+
if record.action_type != action_type:
|
| 344 |
+
continue
|
| 345 |
+
if record.success:
|
| 346 |
+
last_success_idx = idx
|
| 347 |
+
else:
|
| 348 |
+
last_failed_idx = idx
|
| 349 |
+
|
| 350 |
+
if last_failed_idx is None:
|
| 351 |
+
return False
|
| 352 |
+
|
| 353 |
+
# Allow retry after the same action has already succeeded once, or after the
|
| 354 |
+
# pipeline made progress with a different successful step since the failure.
|
| 355 |
+
if last_success_idx is not None and last_success_idx > last_failed_idx:
|
| 356 |
+
return False
|
| 357 |
+
for record in history[last_failed_idx + 1:]:
|
| 358 |
+
if record.success and record.action_type != action_type:
|
| 359 |
+
return False
|
| 360 |
+
return True
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
def parse_action(text: str) -> Optional[ExperimentAction]:
|
| 364 |
+
d = extract_json_object(text)
|
| 365 |
+
if d is not None:
|
| 366 |
+
action_type = normalize_action_type(get_payload_value(d, "action_type"))
|
| 367 |
+
if action_type is None:
|
| 368 |
+
return None
|
| 369 |
+
|
| 370 |
+
parameters = get_payload_value(d, "parameters", "params") or {}
|
| 371 |
+
if not isinstance(parameters, dict):
|
| 372 |
+
parameters = {}
|
| 373 |
+
|
| 374 |
+
confidence = get_payload_value(d, "confidence")
|
| 375 |
+
if confidence is None:
|
| 376 |
+
confidence = 0.5
|
| 377 |
+
try:
|
| 378 |
+
confidence = float(confidence)
|
| 379 |
+
except (TypeError, ValueError):
|
| 380 |
+
confidence = 0.5
|
| 381 |
+
|
| 382 |
+
justification = get_payload_value(
|
| 383 |
+
d, "justification", "reasoning", "rationale", "reason"
|
| 384 |
+
)
|
| 385 |
+
if justification is not None and not isinstance(justification, str):
|
| 386 |
+
justification = compact_preview(justification, 200)
|
| 387 |
+
method = normalize_optional_string(get_payload_value(d, "method"))
|
| 388 |
+
|
| 389 |
+
return ExperimentAction(
|
| 390 |
+
action_type=ActionType(action_type),
|
| 391 |
+
method=method,
|
| 392 |
+
parameters=parameters,
|
| 393 |
+
justification=justification,
|
| 394 |
+
confidence=min(1.0, max(0.0, confidence)),
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
action_match = re.search(
|
| 398 |
+
r'["\']action_type["\']\s*:\s*["\']([^"\']+)',
|
| 399 |
+
text,
|
| 400 |
+
re.IGNORECASE,
|
| 401 |
+
)
|
| 402 |
+
if not action_match:
|
| 403 |
+
return None
|
| 404 |
+
|
| 405 |
+
action_type = normalize_action_type(action_match.group(1))
|
| 406 |
+
if action_type is None:
|
| 407 |
+
return None
|
| 408 |
+
|
| 409 |
+
method_match = re.search(
|
| 410 |
+
r'["\']method["\']\s*:\s*("((?:[^"\\]|\\.)*)"|null|none|true|false|-?\d+(?:\.\d+)?)',
|
| 411 |
+
text,
|
| 412 |
+
re.IGNORECASE,
|
| 413 |
+
)
|
| 414 |
+
confidence_match = re.search(
|
| 415 |
+
r'["\']confidence["\']\s*:\s*([0-9]*\.?[0-9]+)',
|
| 416 |
+
text,
|
| 417 |
+
re.IGNORECASE,
|
| 418 |
+
)
|
| 419 |
+
justification_match = re.search(
|
| 420 |
+
r'["\'](?:justif\w*|reasoning|rationale|reason)["\']\s*:\s*"((?:[^"\\]|\\.)*)',
|
| 421 |
+
text,
|
| 422 |
+
re.DOTALL | re.IGNORECASE,
|
| 423 |
+
)
|
| 424 |
+
|
| 425 |
+
confidence = 0.5
|
| 426 |
+
if confidence_match:
|
| 427 |
+
try:
|
| 428 |
+
confidence = float(confidence_match.group(1))
|
| 429 |
+
except ValueError:
|
| 430 |
+
confidence = 0.5
|
| 431 |
+
|
| 432 |
+
justification = None
|
| 433 |
+
if justification_match:
|
| 434 |
+
try:
|
| 435 |
+
justification = json.loads(f'"{justification_match.group(1)}"')
|
| 436 |
+
except json.JSONDecodeError:
|
| 437 |
+
justification = justification_match.group(1)
|
| 438 |
+
|
| 439 |
+
method = None
|
| 440 |
+
if method_match:
|
| 441 |
+
raw_method = method_match.group(1)
|
| 442 |
+
if raw_method.startswith('"') and raw_method.endswith('"'):
|
| 443 |
+
try:
|
| 444 |
+
method = json.loads(raw_method)
|
| 445 |
+
except json.JSONDecodeError:
|
| 446 |
+
method = raw_method.strip('"')
|
| 447 |
+
elif raw_method.lower() not in {"null", "none", "true", "false"}:
|
| 448 |
+
method = raw_method
|
| 449 |
+
method = normalize_optional_string(method)
|
| 450 |
+
|
| 451 |
+
return ExperimentAction(
|
| 452 |
+
action_type=ActionType(action_type),
|
| 453 |
+
method=method,
|
| 454 |
+
parameters={},
|
| 455 |
+
justification=justification,
|
| 456 |
+
confidence=min(1.0, max(0.0, confidence)),
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
def should_force_terminal_conclusion(
|
| 461 |
+
action: ExperimentAction,
|
| 462 |
+
completed_types: set[ActionType],
|
| 463 |
+
) -> bool:
|
| 464 |
+
meta_repeatables = {
|
| 465 |
+
ActionType.DESIGN_FOLLOWUP,
|
| 466 |
+
ActionType.REQUEST_SUBAGENT_REVIEW,
|
| 467 |
+
}
|
| 468 |
+
return (
|
| 469 |
+
action.action_type in meta_repeatables
|
| 470 |
+
and action.action_type in completed_types
|
| 471 |
+
and ActionType.SYNTHESIZE_CONCLUSION not in completed_types
|
| 472 |
+
)
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
def write_dashboard_state(
|
| 477 |
+
env: BioExperimentEnvironment,
|
| 478 |
+
obs: ExperimentObservation,
|
| 479 |
+
*,
|
| 480 |
+
step: int,
|
| 481 |
+
cumulative_reward: float,
|
| 482 |
+
model_response: str = "",
|
| 483 |
+
model_thinking: str = "",
|
| 484 |
+
action: Optional[ExperimentAction] = None,
|
| 485 |
+
gen_time: float = 0.0,
|
| 486 |
+
episode_done: bool = False,
|
| 487 |
+
) -> None:
|
| 488 |
+
"""Serialise the full world state (observable + latent) for the dashboard."""
|
| 489 |
+
latent = env._latent
|
| 490 |
+
snapshot: Dict[str, Any] = {
|
| 491 |
+
"timestamp": time.time(),
|
| 492 |
+
"step": step,
|
| 493 |
+
"episode_done": episode_done,
|
| 494 |
+
"cumulative_reward": cumulative_reward,
|
| 495 |
+
"gen_time_s": round(gen_time, 2),
|
| 496 |
+
"model_response_raw": model_response[:600],
|
| 497 |
+
"model_thinking": model_thinking[:800],
|
| 498 |
+
"thinking_enabled": ENABLE_THINKING,
|
| 499 |
+
}
|
| 500 |
+
|
| 501 |
+
snapshot["task"] = {
|
| 502 |
+
"problem_statement": obs.task.problem_statement,
|
| 503 |
+
"organism": obs.task.organism,
|
| 504 |
+
"tissue": obs.task.tissue,
|
| 505 |
+
"modality": obs.task.modality,
|
| 506 |
+
"conditions": obs.task.conditions,
|
| 507 |
+
"budget_limit": obs.task.budget_limit,
|
| 508 |
+
"time_limit_days": obs.task.time_limit_days,
|
| 509 |
+
}
|
| 510 |
+
|
| 511 |
+
snapshot["resources"] = {
|
| 512 |
+
"budget_used": round(obs.resource_usage.budget_used, 2),
|
| 513 |
+
"budget_remaining": round(obs.resource_usage.budget_remaining, 2),
|
| 514 |
+
"time_used_days": round(obs.resource_usage.time_used_days, 1),
|
| 515 |
+
"time_remaining_days": round(obs.resource_usage.time_remaining_days, 1),
|
| 516 |
+
"samples_consumed": obs.resource_usage.samples_consumed,
|
| 517 |
+
"compute_hours_used": round(obs.resource_usage.compute_hours_used, 2),
|
| 518 |
+
}
|
| 519 |
+
|
| 520 |
+
snapshot["pipeline_history"] = [
|
| 521 |
+
{
|
| 522 |
+
"step_index": h.step_index,
|
| 523 |
+
"action_type": h.action_type.value,
|
| 524 |
+
"method": h.method,
|
| 525 |
+
"output_summary": h.output_summary[:120],
|
| 526 |
+
"success": h.success,
|
| 527 |
+
"quality_score": round(h.quality_score, 3),
|
| 528 |
+
"resource_cost": round(h.resource_cost, 2),
|
| 529 |
+
"time_cost_days": round(h.time_cost_days, 1),
|
| 530 |
+
}
|
| 531 |
+
for h in obs.pipeline_history
|
| 532 |
+
]
|
| 533 |
+
|
| 534 |
+
if action:
|
| 535 |
+
snapshot["current_action"] = {
|
| 536 |
+
"action_type": action.action_type.value,
|
| 537 |
+
"method": action.method,
|
| 538 |
+
"parameters": action.parameters,
|
| 539 |
+
"justification": action.justification,
|
| 540 |
+
"confidence": action.confidence,
|
| 541 |
+
}
|
| 542 |
+
|
| 543 |
+
if obs.latest_output:
|
| 544 |
+
lo = obs.latest_output
|
| 545 |
+
snapshot["latest_output"] = {
|
| 546 |
+
"summary": lo.summary,
|
| 547 |
+
"success": lo.success,
|
| 548 |
+
"quality_score": round(lo.quality_score, 3),
|
| 549 |
+
"uncertainty": round(lo.uncertainty, 3),
|
| 550 |
+
"warnings": lo.warnings,
|
| 551 |
+
"data_preview": compact_preview(lo.data, 300) if lo.data else None,
|
| 552 |
+
}
|
| 553 |
+
|
| 554 |
+
snapshot["discovered_markers"] = obs.discovered_markers[:20]
|
| 555 |
+
snapshot["candidate_mechanisms"] = obs.candidate_mechanisms[:20]
|
| 556 |
+
snapshot["rule_violations"] = obs.rule_violations
|
| 557 |
+
snapshot["uncertainty_summary"] = {
|
| 558 |
+
k: round(v, 3) for k, v in obs.uncertainty_summary.items()
|
| 559 |
+
}
|
| 560 |
+
snapshot["reward_breakdown"] = {
|
| 561 |
+
k: round(v, 4) for k, v in obs.step_reward_breakdown.items()
|
| 562 |
+
}
|
| 563 |
+
|
| 564 |
+
if obs.conclusions:
|
| 565 |
+
snapshot["conclusions"] = [
|
| 566 |
+
{
|
| 567 |
+
"claim": c.claim,
|
| 568 |
+
"claim_type": c.claim_type,
|
| 569 |
+
"confidence": c.confidence,
|
| 570 |
+
"top_markers": c.top_markers,
|
| 571 |
+
"causal_mechanisms": c.causal_mechanisms,
|
| 572 |
+
"predicted_pathways": c.predicted_pathways,
|
| 573 |
+
}
|
| 574 |
+
for c in obs.conclusions
|
| 575 |
+
]
|
| 576 |
+
|
| 577 |
+
if latent:
|
| 578 |
+
bio = latent.biology
|
| 579 |
+
snapshot["latent"] = {
|
| 580 |
+
"cell_populations": [
|
| 581 |
+
{
|
| 582 |
+
"name": cp.name,
|
| 583 |
+
"proportion": round(cp.proportion, 3),
|
| 584 |
+
"marker_genes": cp.marker_genes[:8],
|
| 585 |
+
"state": cp.state,
|
| 586 |
+
}
|
| 587 |
+
for cp in bio.cell_populations
|
| 588 |
+
],
|
| 589 |
+
"true_markers": bio.true_markers,
|
| 590 |
+
"causal_mechanisms": bio.causal_mechanisms,
|
| 591 |
+
"true_pathways": {
|
| 592 |
+
k: round(v, 3) for k, v in list(bio.true_pathways.items())[:15]
|
| 593 |
+
},
|
| 594 |
+
"true_de_genes_count": sum(
|
| 595 |
+
len(genes) for genes in bio.true_de_genes.values()
|
| 596 |
+
),
|
| 597 |
+
"true_regulatory_network_size": sum(
|
| 598 |
+
len(targets) for targets in bio.true_regulatory_network.values()
|
| 599 |
+
),
|
| 600 |
+
"confounders": bio.confounders,
|
| 601 |
+
"n_true_cells": bio.n_true_cells,
|
| 602 |
+
"technical": {
|
| 603 |
+
"ambient_rna_fraction": latent.technical.ambient_rna_fraction,
|
| 604 |
+
"doublet_rate": latent.technical.doublet_rate,
|
| 605 |
+
"dropout_rate": latent.technical.dropout_rate,
|
| 606 |
+
"sample_quality": latent.technical.sample_quality,
|
| 607 |
+
"library_complexity": latent.technical.library_complexity,
|
| 608 |
+
"capture_efficiency": latent.technical.capture_efficiency,
|
| 609 |
+
},
|
| 610 |
+
"progress": latent.progress.model_dump(),
|
| 611 |
+
"hidden_failure_conditions": latent.hidden_failure_conditions,
|
| 612 |
+
}
|
| 613 |
+
|
| 614 |
+
try:
|
| 615 |
+
DASHBOARD_STATE_PATH.write_text(
|
| 616 |
+
json.dumps(snapshot, indent=2, default=str), encoding="utf-8"
|
| 617 |
+
)
|
| 618 |
+
except Exception:
|
| 619 |
+
pass
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
def log(msg: str) -> None:
|
| 623 |
+
print(msg, flush=True)
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
def build_observation_prompt(obs: ExperimentObservation) -> str:
|
| 627 |
+
return format_observation(obs)
|
| 628 |
+
|
| 629 |
+
|
| 630 |
+
def run_with_pipeline(pipe, prompt: str) -> str:
|
| 631 |
+
try:
|
| 632 |
+
_pipe_max = 2048 if ENABLE_THINKING else 300
|
| 633 |
+
result = pipe(prompt, max_new_tokens=_pipe_max, return_full_text=False)
|
| 634 |
+
except Exception:
|
| 635 |
+
return ""
|
| 636 |
+
|
| 637 |
+
if isinstance(result, list) and result:
|
| 638 |
+
result = result[0]
|
| 639 |
+
if isinstance(result, dict):
|
| 640 |
+
text = result.get("generated_text") or result.get("text") or result.get("answer")
|
| 641 |
+
elif isinstance(result, str):
|
| 642 |
+
text = result
|
| 643 |
+
else:
|
| 644 |
+
text = ""
|
| 645 |
+
return text.strip() if isinstance(text, str) else ""
|
| 646 |
+
|
| 647 |
+
|
| 648 |
+
def resolve_torch_runtime() -> Dict[str, Any]:
|
| 649 |
+
use_cuda = torch.cuda.is_available()
|
| 650 |
+
bf16 = bool(getattr(torch.cuda, "is_bf16_supported", lambda: False)()) if use_cuda else False
|
| 651 |
+
dtype = torch.bfloat16 if bf16 else (
|
| 652 |
+
torch.float16 if use_cuda else torch.float32
|
| 653 |
+
)
|
| 654 |
+
return {
|
| 655 |
+
"use_cuda": use_cuda,
|
| 656 |
+
"device": "cuda:0" if use_cuda else "cpu",
|
| 657 |
+
"dtype": dtype,
|
| 658 |
+
"device_map": "auto" if use_cuda else None,
|
| 659 |
+
"device_name": torch.cuda.get_device_name(0) if use_cuda else "cpu",
|
| 660 |
+
}
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
def main():
|
| 664 |
+
tokenizer = None
|
| 665 |
+
model = None
|
| 666 |
+
eos_ids: List[int] = []
|
| 667 |
+
active_pipeline = None
|
| 668 |
+
|
| 669 |
+
runtime = resolve_torch_runtime()
|
| 670 |
+
log(
|
| 671 |
+
f"Using local model runtime: device={runtime['device']} "
|
| 672 |
+
f"name={runtime['device_name']} dtype={runtime['dtype']}"
|
| 673 |
+
)
|
| 674 |
+
|
| 675 |
+
if USE_PIPELINE:
|
| 676 |
+
log(f"Loading pipeline ({PIPELINE_TASK}) for {MODEL_ID} ...")
|
| 677 |
+
try:
|
| 678 |
+
active_pipeline = pipeline(
|
| 679 |
+
PIPELINE_TASK,
|
| 680 |
+
model=MODEL_ID,
|
| 681 |
+
trust_remote_code=True,
|
| 682 |
+
dtype=runtime["dtype"],
|
| 683 |
+
device=0 if runtime["use_cuda"] else -1,
|
| 684 |
+
)
|
| 685 |
+
log("Pipeline loaded.")
|
| 686 |
+
except Exception as exc:
|
| 687 |
+
log(f"Pipeline load failed ({exc}), falling back to tokenizer+model.")
|
| 688 |
+
|
| 689 |
+
if active_pipeline is None:
|
| 690 |
+
log(f"Loading tokenizer for {MODEL_ID} ...")
|
| 691 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 692 |
+
MODEL_ID, trust_remote_code=True,
|
| 693 |
+
)
|
| 694 |
+
log("Tokenizer loaded. Loading model (this may download files on first run) ...")
|
| 695 |
+
|
| 696 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 697 |
+
MODEL_ID,
|
| 698 |
+
dtype=runtime["dtype"],
|
| 699 |
+
device_map=runtime["device_map"],
|
| 700 |
+
trust_remote_code=True,
|
| 701 |
+
)
|
| 702 |
+
log(f"Model loaded. Device: {model.device}")
|
| 703 |
+
|
| 704 |
+
if tokenizer.eos_token_id is not None:
|
| 705 |
+
eos_ids.append(tokenizer.eos_token_id)
|
| 706 |
+
extra = tokenizer.convert_tokens_to_ids(["<|im_end|>", "<|endoftext|>"])
|
| 707 |
+
for tid in extra:
|
| 708 |
+
if isinstance(tid, int) and tid not in eos_ids:
|
| 709 |
+
eos_ids.append(tid)
|
| 710 |
+
log(f"EOS token ids: {eos_ids}")
|
| 711 |
+
|
| 712 |
+
def check_dashboard_command() -> Optional[Dict[str, Any]]:
|
| 713 |
+
"""Read and consume a command file written by the dashboard."""
|
| 714 |
+
try:
|
| 715 |
+
raw = DASHBOARD_CMD_PATH.read_text(encoding="utf-8")
|
| 716 |
+
DASHBOARD_CMD_PATH.unlink(missing_ok=True)
|
| 717 |
+
return json.loads(raw)
|
| 718 |
+
except (FileNotFoundError, json.JSONDecodeError):
|
| 719 |
+
return None
|
| 720 |
+
|
| 721 |
+
def run_episode(
|
| 722 |
+
scenario_name: Optional[str] = None,
|
| 723 |
+
custom_ground_truth: Optional[Dict[str, Any]] = None,
|
| 724 |
+
):
|
| 725 |
+
env = BioExperimentEnvironment(scenario_name=scenario_name)
|
| 726 |
+
obs = env.reset()
|
| 727 |
+
|
| 728 |
+
if custom_ground_truth and env._latent:
|
| 729 |
+
gt = custom_ground_truth
|
| 730 |
+
bio = env._latent.biology
|
| 731 |
+
if gt.get("true_markers"):
|
| 732 |
+
bio.true_markers = gt["true_markers"]
|
| 733 |
+
if gt.get("causal_mechanisms"):
|
| 734 |
+
bio.causal_mechanisms = gt["causal_mechanisms"]
|
| 735 |
+
if gt.get("true_pathways"):
|
| 736 |
+
bio.true_pathways = {
|
| 737 |
+
k: float(v) for k, v in gt["true_pathways"].items()
|
| 738 |
+
}
|
| 739 |
+
|
| 740 |
+
log("\n" + "=" * 70)
|
| 741 |
+
log(f"TASK: {obs.task.problem_statement}")
|
| 742 |
+
log(f"Conditions: {obs.task.conditions}")
|
| 743 |
+
log(f"Budget: ${obs.task.budget_limit:,.0f} | Time: {obs.task.time_limit_days:.0f} days")
|
| 744 |
+
if ENABLE_THINKING:
|
| 745 |
+
log("Reasoning mode: ENABLED")
|
| 746 |
+
log("=" * 70)
|
| 747 |
+
|
| 748 |
+
cumulative_reward = 0.0
|
| 749 |
+
write_dashboard_state(env, obs, step=0, cumulative_reward=0.0)
|
| 750 |
+
|
| 751 |
+
for step in range(MAX_EPISODE_STEPS):
|
| 752 |
+
cmd = check_dashboard_command()
|
| 753 |
+
if cmd and cmd.get("action") == "restart":
|
| 754 |
+
log("\n[DASHBOARD] Restart requested — ending episode early.")
|
| 755 |
+
break
|
| 756 |
+
|
| 757 |
+
user_msg = build_observation_prompt(obs)
|
| 758 |
+
|
| 759 |
+
messages = [
|
| 760 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 761 |
+
{"role": "user", "content": user_msg},
|
| 762 |
+
]
|
| 763 |
+
|
| 764 |
+
if active_pipeline is not None:
|
| 765 |
+
prompt = f"{SYSTEM_PROMPT}\n\n{user_msg}"
|
| 766 |
+
else:
|
| 767 |
+
try:
|
| 768 |
+
prompt = tokenizer.apply_chat_template(
|
| 769 |
+
messages,
|
| 770 |
+
tokenize=False,
|
| 771 |
+
add_generation_prompt=True,
|
| 772 |
+
enable_thinking=ENABLE_THINKING,
|
| 773 |
+
)
|
| 774 |
+
except TypeError:
|
| 775 |
+
prompt = tokenizer.apply_chat_template(
|
| 776 |
+
messages,
|
| 777 |
+
tokenize=False,
|
| 778 |
+
add_generation_prompt=True,
|
| 779 |
+
)
|
| 780 |
+
|
| 781 |
+
t0 = time.time()
|
| 782 |
+
if active_pipeline is not None:
|
| 783 |
+
response = run_with_pipeline(active_pipeline, prompt)
|
| 784 |
+
if not response:
|
| 785 |
+
response = format_observation(obs)
|
| 786 |
+
else:
|
| 787 |
+
assert tokenizer is not None and model is not None
|
| 788 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 789 |
+
n_input = inputs["input_ids"].shape[1]
|
| 790 |
+
max_new = 2048 if ENABLE_THINKING else 300
|
| 791 |
+
with torch.no_grad():
|
| 792 |
+
output_ids = model.generate(
|
| 793 |
+
**inputs,
|
| 794 |
+
max_new_tokens=max_new,
|
| 795 |
+
do_sample=True,
|
| 796 |
+
temperature=0.7,
|
| 797 |
+
top_p=0.8,
|
| 798 |
+
top_k=20,
|
| 799 |
+
repetition_penalty=1.3,
|
| 800 |
+
eos_token_id=eos_ids if eos_ids else None,
|
| 801 |
+
)
|
| 802 |
+
new_tokens = output_ids[0][n_input:]
|
| 803 |
+
response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
|
| 804 |
+
gen_time = time.time() - t0
|
| 805 |
+
|
| 806 |
+
thinking = ""
|
| 807 |
+
if ENABLE_THINKING:
|
| 808 |
+
think_match = re.search(
|
| 809 |
+
r"<think>(.*?)</think>", response, re.DOTALL
|
| 810 |
+
)
|
| 811 |
+
if think_match:
|
| 812 |
+
thinking = think_match.group(1).strip()
|
| 813 |
+
response = response[think_match.end():].strip()
|
| 814 |
+
elif response.startswith("<think>"):
|
| 815 |
+
parts = response.split("</think>", 1)
|
| 816 |
+
if len(parts) == 2:
|
| 817 |
+
thinking = parts[0].replace("<think>", "").strip()
|
| 818 |
+
response = parts[1].strip()
|
| 819 |
+
|
| 820 |
+
is_last_step = (step == MAX_EPISODE_STEPS - 1)
|
| 821 |
+
|
| 822 |
+
action = parse_action(response)
|
| 823 |
+
if action is None:
|
| 824 |
+
if is_last_step:
|
| 825 |
+
log(f"\n [!] Parse failed on final step — forcing synthesize_conclusion.")
|
| 826 |
+
action = ExperimentAction(
|
| 827 |
+
action_type=ActionType.SYNTHESIZE_CONCLUSION,
|
| 828 |
+
justification="forced terminal conclusion",
|
| 829 |
+
confidence=0.5,
|
| 830 |
+
)
|
| 831 |
+
else:
|
| 832 |
+
log(f"\n [!] Parse failed, skipping step. Raw: {response[:150]}")
|
| 833 |
+
continue
|
| 834 |
+
|
| 835 |
+
completed_types = {
|
| 836 |
+
r.action_type for r in obs.pipeline_history if r.success
|
| 837 |
+
}
|
| 838 |
+
failed_types = {
|
| 839 |
+
r.action_type
|
| 840 |
+
for r in obs.pipeline_history
|
| 841 |
+
if not r.success
|
| 842 |
+
}
|
| 843 |
+
|
| 844 |
+
if should_force_terminal_conclusion(action, completed_types):
|
| 845 |
+
log(
|
| 846 |
+
f"\n [!] repeated completed meta step {action.action_type.value} "
|
| 847 |
+
f"— forcing synthesize_conclusion."
|
| 848 |
+
)
|
| 849 |
+
action = ExperimentAction(
|
| 850 |
+
action_type=ActionType.SYNTHESIZE_CONCLUSION,
|
| 851 |
+
justification="repeated completed meta step forced terminal conclusion",
|
| 852 |
+
confidence=action.confidence,
|
| 853 |
+
)
|
| 854 |
+
completed_types = {
|
| 855 |
+
r.action_type for r in obs.pipeline_history if r.success
|
| 856 |
+
}
|
| 857 |
+
|
| 858 |
+
skip_reason = None
|
| 859 |
+
if action.action_type in completed_types:
|
| 860 |
+
skip_reason = (
|
| 861 |
+
f"blocked repeat of completed step {action.action_type.value}"
|
| 862 |
+
)
|
| 863 |
+
elif action.action_type in failed_types:
|
| 864 |
+
if should_block_failed_reattempt(
|
| 865 |
+
obs.pipeline_history, action.action_type
|
| 866 |
+
):
|
| 867 |
+
skip_reason = (
|
| 868 |
+
f"blocked re-attempt of failed step {action.action_type.value}"
|
| 869 |
+
)
|
| 870 |
+
|
| 871 |
+
if skip_reason:
|
| 872 |
+
if is_last_step:
|
| 873 |
+
log(f"\n [!] {skip_reason} on final step — forcing synthesize_conclusion.")
|
| 874 |
+
action = ExperimentAction(
|
| 875 |
+
action_type=ActionType.SYNTHESIZE_CONCLUSION,
|
| 876 |
+
justification="forced terminal conclusion",
|
| 877 |
+
confidence=0.5,
|
| 878 |
+
)
|
| 879 |
+
else:
|
| 880 |
+
log(f"\n [!] {skip_reason}, skipping step.")
|
| 881 |
+
continue
|
| 882 |
+
|
| 883 |
+
if is_last_step and action.action_type != ActionType.SYNTHESIZE_CONCLUSION:
|
| 884 |
+
log(f"\n [!] Final step — overriding {action.action_type.value} with synthesize_conclusion.")
|
| 885 |
+
action = ExperimentAction(
|
| 886 |
+
action_type=ActionType.SYNTHESIZE_CONCLUSION,
|
| 887 |
+
justification="forced terminal conclusion",
|
| 888 |
+
confidence=action.confidence,
|
| 889 |
+
)
|
| 890 |
+
|
| 891 |
+
log(f"\nStep {step + 1}: {action.action_type.value} ({gen_time:.1f}s)")
|
| 892 |
+
if thinking:
|
| 893 |
+
log(f" Thinking: {thinking[:200]}")
|
| 894 |
+
if action.justification:
|
| 895 |
+
log(f" Rationale: {action.justification}")
|
| 896 |
+
else:
|
| 897 |
+
log(" Rationale: [model did not provide one]")
|
| 898 |
+
if action.parameters:
|
| 899 |
+
log(f" Parameters: {compact_preview(action.parameters, 200)}")
|
| 900 |
+
elif not action.justification and response:
|
| 901 |
+
log(
|
| 902 |
+
f" Model response: "
|
| 903 |
+
f"{compact_preview(response, MODEL_RESPONSE_PREVIEW_CHARS)}"
|
| 904 |
+
)
|
| 905 |
+
|
| 906 |
+
obs = env.step(action)
|
| 907 |
+
|
| 908 |
+
if obs.latest_output:
|
| 909 |
+
lo = obs.latest_output
|
| 910 |
+
status = "OK" if lo.success else "FAIL"
|
| 911 |
+
log(f" [{status}] {lo.summary}")
|
| 912 |
+
if lo.warnings:
|
| 913 |
+
log(f" Warnings: {lo.warnings}")
|
| 914 |
+
|
| 915 |
+
step_reward = obs.reward
|
| 916 |
+
cumulative_reward += step_reward
|
| 917 |
+
log(f" Reward: {step_reward:+.3f} (cum: {cumulative_reward:+.3f})")
|
| 918 |
+
log(f" Budget: ${obs.resource_usage.budget_remaining:,.0f} | Time: {obs.resource_usage.time_remaining_days:.0f}d")
|
| 919 |
+
|
| 920 |
+
write_dashboard_state(
|
| 921 |
+
env, obs,
|
| 922 |
+
step=step + 1,
|
| 923 |
+
cumulative_reward=cumulative_reward,
|
| 924 |
+
model_response=response,
|
| 925 |
+
model_thinking=thinking,
|
| 926 |
+
action=action,
|
| 927 |
+
gen_time=gen_time,
|
| 928 |
+
episode_done=obs.done,
|
| 929 |
+
)
|
| 930 |
+
|
| 931 |
+
if obs.rule_violations:
|
| 932 |
+
log(f" Violations: {obs.rule_violations}")
|
| 933 |
+
|
| 934 |
+
if obs.done:
|
| 935 |
+
break
|
| 936 |
+
|
| 937 |
+
log(f"\n{'=' * 70}")
|
| 938 |
+
log("EPISODE COMPLETE" if obs.done else f"MAX STEPS ({MAX_EPISODE_STEPS})")
|
| 939 |
+
log(f" Steps: {obs.step_index}")
|
| 940 |
+
log(f" Total reward: {cumulative_reward:+.3f}")
|
| 941 |
+
log(f" Budget used: ${obs.resource_usage.budget_used:,.0f}")
|
| 942 |
+
log(f" Time used: {obs.resource_usage.time_used_days:.0f} days")
|
| 943 |
+
if obs.conclusions:
|
| 944 |
+
log(" Conclusions:")
|
| 945 |
+
for c in obs.conclusions:
|
| 946 |
+
log(f" [{c.claim_type}, conf={c.confidence:.2f}] {c.claim}")
|
| 947 |
+
if c.top_markers:
|
| 948 |
+
log(f" Markers: {c.top_markers}")
|
| 949 |
+
if c.causal_mechanisms:
|
| 950 |
+
log(f" Mechanisms: {c.causal_mechanisms}")
|
| 951 |
+
if c.predicted_pathways:
|
| 952 |
+
log(f" Pathways: {c.predicted_pathways}")
|
| 953 |
+
log("=" * 70)
|
| 954 |
+
|
| 955 |
+
DASHBOARD_CMD_PATH.unlink(missing_ok=True)
|
| 956 |
+
run_episode()
|
| 957 |
+
|
| 958 |
+
while True:
|
| 959 |
+
log("\nWaiting for dashboard command (restart / new task) ...")
|
| 960 |
+
while True:
|
| 961 |
+
cmd = check_dashboard_command()
|
| 962 |
+
if cmd:
|
| 963 |
+
break
|
| 964 |
+
time.sleep(1.0)
|
| 965 |
+
|
| 966 |
+
action_type = cmd.get("action", "restart")
|
| 967 |
+
if action_type == "quit":
|
| 968 |
+
log("Quit requested.")
|
| 969 |
+
break
|
| 970 |
+
|
| 971 |
+
scenario = cmd.get("scenario_name")
|
| 972 |
+
ground_truth = cmd.get("ground_truth")
|
| 973 |
+
log(f"\n[DASHBOARD] {action_type} — scenario={scenario}")
|
| 974 |
+
run_episode(scenario_name=scenario, custom_ground_truth=ground_truth)
|
| 975 |
+
|
| 976 |
+
|
| 977 |
+
if __name__ == "__main__":
|
| 978 |
+
main()
|
server/app.py
CHANGED
|
@@ -6,8 +6,12 @@ Endpoints:
|
|
| 6 |
- GET /state: Get current environment state
|
| 7 |
- GET /schema: Get action/observation schemas
|
| 8 |
- WS /ws: WebSocket endpoint for persistent sessions
|
|
|
|
| 9 |
"""
|
| 10 |
|
|
|
|
|
|
|
|
|
|
| 11 |
try:
|
| 12 |
from openenv.core.env_server.http_server import create_app
|
| 13 |
except Exception as e: # pragma: no cover
|
|
@@ -16,6 +20,7 @@ except Exception as e: # pragma: no cover
|
|
| 16 |
"Install dependencies with 'uv sync'"
|
| 17 |
) from e
|
| 18 |
|
|
|
|
| 19 |
from models import ExperimentAction, ExperimentObservation
|
| 20 |
from .hackathon_environment import BioExperimentEnvironment
|
| 21 |
|
|
@@ -24,12 +29,24 @@ app = create_app(
|
|
| 24 |
ExperimentAction,
|
| 25 |
ExperimentObservation,
|
| 26 |
env_name="bio_experiment",
|
| 27 |
-
max_concurrent_envs=
|
| 28 |
)
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
-
def main(host: str = "0.0.0.0", port: int =
|
| 32 |
import uvicorn
|
|
|
|
|
|
|
| 33 |
uvicorn.run(app, host=host, port=port)
|
| 34 |
|
| 35 |
|
|
@@ -37,9 +54,6 @@ if __name__ == "__main__":
|
|
| 37 |
import argparse
|
| 38 |
parser = argparse.ArgumentParser()
|
| 39 |
parser.add_argument("--host", default="0.0.0.0")
|
| 40 |
-
parser.add_argument("--port", type=int, default=
|
| 41 |
args = parser.parse_args()
|
| 42 |
-
|
| 43 |
-
main()
|
| 44 |
-
else:
|
| 45 |
-
main(host=args.host, port=args.port)
|
|
|
|
| 6 |
- GET /state: Get current environment state
|
| 7 |
- GET /schema: Get action/observation schemas
|
| 8 |
- WS /ws: WebSocket endpoint for persistent sessions
|
| 9 |
+
- GET / Demo UI
|
| 10 |
"""
|
| 11 |
|
| 12 |
+
import os
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
try:
|
| 16 |
from openenv.core.env_server.http_server import create_app
|
| 17 |
except Exception as e: # pragma: no cover
|
|
|
|
| 20 |
"Install dependencies with 'uv sync'"
|
| 21 |
) from e
|
| 22 |
|
| 23 |
+
from fastapi.responses import HTMLResponse
|
| 24 |
from models import ExperimentAction, ExperimentObservation
|
| 25 |
from .hackathon_environment import BioExperimentEnvironment
|
| 26 |
|
|
|
|
| 29 |
ExperimentAction,
|
| 30 |
ExperimentObservation,
|
| 31 |
env_name="bio_experiment",
|
| 32 |
+
max_concurrent_envs=int(os.environ.get("MAX_ENVS", "4")),
|
| 33 |
)
|
| 34 |
|
| 35 |
+
# Serve demo UI at root
|
| 36 |
+
DEMO_HTML = Path(__file__).resolve().parent.parent / "demo.html"
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@app.get("/", response_class=HTMLResponse)
|
| 40 |
+
async def demo_ui():
|
| 41 |
+
if DEMO_HTML.exists():
|
| 42 |
+
return HTMLResponse(content=DEMO_HTML.read_text(), status_code=200)
|
| 43 |
+
return HTMLResponse(content="<h1>BioEnv API</h1><p>Visit /docs for API documentation.</p>", status_code=200)
|
| 44 |
+
|
| 45 |
|
| 46 |
+
def main(host: str = "0.0.0.0", port: int = None):
|
| 47 |
import uvicorn
|
| 48 |
+
if port is None:
|
| 49 |
+
port = int(os.environ.get("PORT", "8000"))
|
| 50 |
uvicorn.run(app, host=host, port=port)
|
| 51 |
|
| 52 |
|
|
|
|
| 54 |
import argparse
|
| 55 |
parser = argparse.ArgumentParser()
|
| 56 |
parser.add_argument("--host", default="0.0.0.0")
|
| 57 |
+
parser.add_argument("--port", type=int, default=None)
|
| 58 |
args = parser.parse_args()
|
| 59 |
+
main(host=args.host, port=args.port)
|
|
|
|
|
|
|
|
|
server/biology/__init__.py
ADDED
|
File without changes
|
server/biology/gene_index.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pathway-aware gene similarity index for structured reward scoring.
|
| 2 |
+
|
| 3 |
+
Uses gseapy pathway libraries (KEGG + Reactome) to build binary pathway
|
| 4 |
+
membership vectors per gene, enabling cosine-similarity-based set scoring
|
| 5 |
+
instead of substring matching.
|
| 6 |
+
|
| 7 |
+
Mechanism comparison uses sentence-transformers for semantic similarity.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import logging
|
| 13 |
+
from functools import lru_cache
|
| 14 |
+
from typing import Dict, List, Optional, Tuple
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
_PATHWAY_SETS: Optional[Dict[str, List[str]]] = None
|
| 21 |
+
_PATHWAY_NAMES: Optional[List[str]] = None
|
| 22 |
+
_GENE_TO_PATHWAY_IDX: Optional[Dict[str, List[int]]] = None
|
| 23 |
+
_N_PATHWAYS: int = 0
|
| 24 |
+
|
| 25 |
+
_SENTENCE_MODEL = None
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _ensure_pathway_index() -> None:
|
| 29 |
+
"""Lazily build the inverted gene→pathway index on first use."""
|
| 30 |
+
global _PATHWAY_SETS, _PATHWAY_NAMES, _GENE_TO_PATHWAY_IDX, _N_PATHWAYS
|
| 31 |
+
|
| 32 |
+
if _PATHWAY_NAMES is not None:
|
| 33 |
+
return
|
| 34 |
+
|
| 35 |
+
try:
|
| 36 |
+
import gseapy as gp
|
| 37 |
+
except ImportError:
|
| 38 |
+
logger.warning("gseapy not installed; pathway scoring will use fallback.")
|
| 39 |
+
_PATHWAY_SETS = {}
|
| 40 |
+
_PATHWAY_NAMES = []
|
| 41 |
+
_GENE_TO_PATHWAY_IDX = {}
|
| 42 |
+
_N_PATHWAYS = 0
|
| 43 |
+
return
|
| 44 |
+
|
| 45 |
+
combined: Dict[str, List[str]] = {}
|
| 46 |
+
for lib_name in ("KEGG_2021_Human", "Reactome_2022"):
|
| 47 |
+
try:
|
| 48 |
+
combined.update(gp.get_library(lib_name))
|
| 49 |
+
except Exception as exc:
|
| 50 |
+
logger.warning("Failed to load %s: %s", lib_name, exc)
|
| 51 |
+
|
| 52 |
+
_PATHWAY_SETS = combined
|
| 53 |
+
_PATHWAY_NAMES = sorted(combined.keys())
|
| 54 |
+
_N_PATHWAYS = len(_PATHWAY_NAMES)
|
| 55 |
+
|
| 56 |
+
inv: Dict[str, List[int]] = {}
|
| 57 |
+
for idx, pw_name in enumerate(_PATHWAY_NAMES):
|
| 58 |
+
for gene in combined[pw_name]:
|
| 59 |
+
gene_upper = gene.upper().strip()
|
| 60 |
+
inv.setdefault(gene_upper, []).append(idx)
|
| 61 |
+
|
| 62 |
+
_GENE_TO_PATHWAY_IDX = inv
|
| 63 |
+
logger.info(
|
| 64 |
+
"Pathway index built: %d pathways, %d genes indexed.",
|
| 65 |
+
_N_PATHWAYS, len(inv),
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def _ensure_sentence_model():
|
| 70 |
+
"""Lazily load the sentence-transformer model."""
|
| 71 |
+
global _SENTENCE_MODEL
|
| 72 |
+
if _SENTENCE_MODEL is not None:
|
| 73 |
+
return
|
| 74 |
+
|
| 75 |
+
try:
|
| 76 |
+
from sentence_transformers import SentenceTransformer
|
| 77 |
+
_SENTENCE_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
|
| 78 |
+
except ImportError:
|
| 79 |
+
logger.warning(
|
| 80 |
+
"sentence-transformers not installed; mechanism scoring will use fallback."
|
| 81 |
+
)
|
| 82 |
+
_SENTENCE_MODEL = None
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def gene_vector(gene: str) -> np.ndarray:
|
| 86 |
+
"""L2-normalised binary pathway membership vector for *gene*."""
|
| 87 |
+
_ensure_pathway_index()
|
| 88 |
+
vec = np.zeros(_N_PATHWAYS, dtype=np.float32)
|
| 89 |
+
indices = _GENE_TO_PATHWAY_IDX.get(gene.upper().strip(), [])
|
| 90 |
+
if indices:
|
| 91 |
+
vec[indices] = 1.0
|
| 92 |
+
norm = np.linalg.norm(vec)
|
| 93 |
+
if norm > 0:
|
| 94 |
+
vec /= norm
|
| 95 |
+
return vec
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def pathway_similarity(g1: str, g2: str) -> float:
|
| 99 |
+
"""Cosine similarity between two genes in pathway space."""
|
| 100 |
+
v1 = gene_vector(g1)
|
| 101 |
+
v2 = gene_vector(g2)
|
| 102 |
+
dot = float(np.dot(v1, v2))
|
| 103 |
+
return max(0.0, min(1.0, dot))
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def marker_set_score(
|
| 107 |
+
predicted: List[str],
|
| 108 |
+
truth: List[str],
|
| 109 |
+
sigma: float = 0.3,
|
| 110 |
+
) -> float:
|
| 111 |
+
"""Pathway-weighted Gaussian set similarity for marker genes.
|
| 112 |
+
|
| 113 |
+
For each true marker, finds the best-matching predicted gene by
|
| 114 |
+
pathway cosine similarity, then applies a Gaussian kernel:
|
| 115 |
+
score_i = exp(-d^2 / (2 * sigma^2)) where d = 1 - sim
|
| 116 |
+
Returns the mean score over all true markers.
|
| 117 |
+
"""
|
| 118 |
+
if not truth:
|
| 119 |
+
return 0.0
|
| 120 |
+
if not predicted:
|
| 121 |
+
return 0.0
|
| 122 |
+
|
| 123 |
+
_ensure_pathway_index()
|
| 124 |
+
|
| 125 |
+
if _N_PATHWAYS == 0:
|
| 126 |
+
return _fallback_marker_score(predicted, truth)
|
| 127 |
+
|
| 128 |
+
pred_vecs = [gene_vector(g) for g in predicted]
|
| 129 |
+
scores: List[float] = []
|
| 130 |
+
|
| 131 |
+
for true_gene in truth:
|
| 132 |
+
tv = gene_vector(true_gene)
|
| 133 |
+
best_sim = 0.0
|
| 134 |
+
for pv in pred_vecs:
|
| 135 |
+
sim = float(np.dot(tv, pv))
|
| 136 |
+
if sim > best_sim:
|
| 137 |
+
best_sim = sim
|
| 138 |
+
d = 1.0 - best_sim
|
| 139 |
+
scores.append(float(np.exp(-(d ** 2) / (2.0 * sigma ** 2))))
|
| 140 |
+
|
| 141 |
+
return sum(scores) / len(scores)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def _fallback_marker_score(predicted: List[str], truth: List[str]) -> float:
|
| 145 |
+
"""Exact-match fallback when pathway data is unavailable."""
|
| 146 |
+
pred_set = {g.upper().strip() for g in predicted}
|
| 147 |
+
hits = sum(1 for g in truth if g.upper().strip() in pred_set)
|
| 148 |
+
return hits / len(truth) if truth else 0.0
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def mechanism_set_score(predicted: List[str], truth: List[str]) -> float:
|
| 152 |
+
"""Sentence-transformer semantic similarity for mechanism strings.
|
| 153 |
+
|
| 154 |
+
For each truth mechanism, finds the best-matching predicted mechanism
|
| 155 |
+
by cosine similarity and returns the mean of best matches.
|
| 156 |
+
"""
|
| 157 |
+
if not truth:
|
| 158 |
+
return 0.0
|
| 159 |
+
if not predicted:
|
| 160 |
+
return 0.0
|
| 161 |
+
|
| 162 |
+
_ensure_sentence_model()
|
| 163 |
+
|
| 164 |
+
if _SENTENCE_MODEL is None:
|
| 165 |
+
return _fallback_mechanism_score(predicted, truth)
|
| 166 |
+
|
| 167 |
+
pred_embs = _SENTENCE_MODEL.encode(predicted, convert_to_numpy=True)
|
| 168 |
+
truth_embs = _SENTENCE_MODEL.encode(truth, convert_to_numpy=True)
|
| 169 |
+
|
| 170 |
+
pred_norms = pred_embs / (
|
| 171 |
+
np.linalg.norm(pred_embs, axis=1, keepdims=True) + 1e-9
|
| 172 |
+
)
|
| 173 |
+
truth_norms = truth_embs / (
|
| 174 |
+
np.linalg.norm(truth_embs, axis=1, keepdims=True) + 1e-9
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
sim_matrix = truth_norms @ pred_norms.T
|
| 178 |
+
best_per_truth = sim_matrix.max(axis=1)
|
| 179 |
+
return float(np.mean(np.clip(best_per_truth, 0.0, 1.0)))
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def _fallback_mechanism_score(predicted: List[str], truth: List[str]) -> float:
|
| 183 |
+
"""Token-overlap fallback when sentence-transformers is unavailable."""
|
| 184 |
+
scores: List[float] = []
|
| 185 |
+
for t in truth:
|
| 186 |
+
t_tokens = set(t.lower().split())
|
| 187 |
+
best = 0.0
|
| 188 |
+
for p in predicted:
|
| 189 |
+
p_tokens = set(p.lower().split())
|
| 190 |
+
union = t_tokens | p_tokens
|
| 191 |
+
if union:
|
| 192 |
+
overlap = len(t_tokens & p_tokens) / len(union)
|
| 193 |
+
best = max(best, overlap)
|
| 194 |
+
scores.append(best)
|
| 195 |
+
return sum(scores) / len(scores) if scores else 0.0
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def score_pathways(
|
| 199 |
+
predicted: Dict[str, float],
|
| 200 |
+
truth: Dict[str, float],
|
| 201 |
+
) -> float:
|
| 202 |
+
"""Score predicted pathway activations against ground truth.
|
| 203 |
+
|
| 204 |
+
Uses normalised key matching with activity-level weighting.
|
| 205 |
+
"""
|
| 206 |
+
if not truth:
|
| 207 |
+
return 0.0
|
| 208 |
+
if not predicted:
|
| 209 |
+
return 0.0
|
| 210 |
+
|
| 211 |
+
pred_norm = {k.lower().strip(): v for k, v in predicted.items()}
|
| 212 |
+
total_weight = 0.0
|
| 213 |
+
weighted_score = 0.0
|
| 214 |
+
|
| 215 |
+
for pw, true_activity in truth.items():
|
| 216 |
+
pw_key = pw.lower().strip()
|
| 217 |
+
weight = true_activity
|
| 218 |
+
total_weight += weight
|
| 219 |
+
if pw_key in pred_norm:
|
| 220 |
+
pred_activity = pred_norm[pw_key]
|
| 221 |
+
diff = abs(pred_activity - true_activity)
|
| 222 |
+
match_score = max(0.0, 1.0 - diff)
|
| 223 |
+
weighted_score += weight * match_score
|
| 224 |
+
|
| 225 |
+
return weighted_score / total_weight if total_weight > 0 else 0.0
|
server/hackathon_environment.py
CHANGED
|
@@ -28,7 +28,7 @@ from server.rules.engine import RuleEngine
|
|
| 28 |
from server.rewards.reward import RewardBreakdown, RewardComputer
|
| 29 |
from server.simulator.latent_state import FullLatentState
|
| 30 |
from server.simulator.noise import NoiseModel
|
| 31 |
-
from server.simulator.transition import ACTION_COSTS, TransitionEngine
|
| 32 |
from server.tasks.generator import TaskGenerator
|
| 33 |
|
| 34 |
|
|
@@ -70,8 +70,8 @@ class BioExperimentEnvironment(Environment):
|
|
| 70 |
|
| 71 |
# ── Environment interface ───────────────────────────────────────────
|
| 72 |
|
| 73 |
-
def reset(self) -> ExperimentObservation:
|
| 74 |
-
seed = hash(uuid4()) % (2**31)
|
| 75 |
self._noise.reseed(seed)
|
| 76 |
self._state = State(episode_id=str(uuid4()), step_count=0)
|
| 77 |
|
|
@@ -116,7 +116,7 @@ class BioExperimentEnvironment(Environment):
|
|
| 116 |
action, prev_state, self._latent, result.output, hard_v, soft_v,
|
| 117 |
)
|
| 118 |
|
| 119 |
-
cost_budget, cost_time =
|
| 120 |
self._history.append(PipelineStepRecord(
|
| 121 |
step_index=self._state.step_count,
|
| 122 |
action_type=action.action_type,
|
|
@@ -143,7 +143,11 @@ class BioExperimentEnvironment(Environment):
|
|
| 143 |
terminal_rb = RewardBreakdown()
|
| 144 |
if done:
|
| 145 |
terminal_rb = self._rewards.terminal_reward(
|
| 146 |
-
self._latent,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
)
|
| 148 |
|
| 149 |
total_reward = step_rb.total + terminal_rb.total
|
|
@@ -158,6 +162,7 @@ class BioExperimentEnvironment(Environment):
|
|
| 158 |
latest_output=result.output,
|
| 159 |
rule_violations=hard_v + soft_v,
|
| 160 |
reward_breakdown=breakdown,
|
|
|
|
| 161 |
)
|
| 162 |
|
| 163 |
@property
|
|
@@ -179,10 +184,18 @@ class BioExperimentEnvironment(Environment):
|
|
| 179 |
latest_output: Optional[IntermediateOutput] = None,
|
| 180 |
rule_violations: Optional[List[str]] = None,
|
| 181 |
reward_breakdown: Optional[Dict[str, float]] = None,
|
|
|
|
| 182 |
) -> ExperimentObservation:
|
| 183 |
assert self._task is not None
|
| 184 |
assert self._latent is not None
|
| 185 |
res = self._latent.resources
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
return ExperimentObservation(
|
| 187 |
task=self._task,
|
| 188 |
step_index=self._state.step_count,
|
|
@@ -205,14 +218,10 @@ class BioExperimentEnvironment(Environment):
|
|
| 205 |
subagent_outputs=list(self._subagent_outputs),
|
| 206 |
conclusions=list(self._conclusions),
|
| 207 |
rule_violations=rule_violations or [],
|
| 208 |
-
step_reward_breakdown=
|
| 209 |
done=done,
|
| 210 |
reward=reward,
|
| 211 |
-
metadata=
|
| 212 |
-
"episode_id": self._state.episode_id,
|
| 213 |
-
"step": self._state.step_count,
|
| 214 |
-
"cumulative_reward": self._cumulative_reward,
|
| 215 |
-
},
|
| 216 |
)
|
| 217 |
|
| 218 |
def _compute_uncertainty_summary(self) -> Dict[str, float]:
|
|
@@ -228,12 +237,22 @@ class BioExperimentEnvironment(Environment):
|
|
| 228 |
) -> None:
|
| 229 |
if action.action_type == ActionType.MARKER_SELECTION:
|
| 230 |
markers = output.data.get("markers", [])
|
| 231 |
-
self._discovered_markers
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
if action.action_type == ActionType.REGULATORY_NETWORK_INFERENCE:
|
| 233 |
regs = output.data.get("top_regulators", [])
|
| 234 |
-
self._candidate_mechanisms
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
if action.action_type == ActionType.PATHWAY_ENRICHMENT:
|
| 236 |
pathways = output.data.get("top_pathways", [])
|
| 237 |
-
self._candidate_mechanisms
|
| 238 |
-
|
| 239 |
-
|
|
|
|
|
|
|
|
|
| 28 |
from server.rewards.reward import RewardBreakdown, RewardComputer
|
| 29 |
from server.simulator.latent_state import FullLatentState
|
| 30 |
from server.simulator.noise import NoiseModel
|
| 31 |
+
from server.simulator.transition import ACTION_COSTS, TransitionEngine, compute_action_cost
|
| 32 |
from server.tasks.generator import TaskGenerator
|
| 33 |
|
| 34 |
|
|
|
|
| 70 |
|
| 71 |
# ── Environment interface ───────────────────────────────────────────
|
| 72 |
|
| 73 |
+
def reset(self, seed: Optional[int] = None) -> ExperimentObservation:
|
| 74 |
+
seed = seed if seed is not None else hash(uuid4()) % (2**31)
|
| 75 |
self._noise.reseed(seed)
|
| 76 |
self._state = State(episode_id=str(uuid4()), step_count=0)
|
| 77 |
|
|
|
|
| 116 |
action, prev_state, self._latent, result.output, hard_v, soft_v,
|
| 117 |
)
|
| 118 |
|
| 119 |
+
cost_budget, cost_time = compute_action_cost(action)
|
| 120 |
self._history.append(PipelineStepRecord(
|
| 121 |
step_index=self._state.step_count,
|
| 122 |
action_type=action.action_type,
|
|
|
|
| 143 |
terminal_rb = RewardBreakdown()
|
| 144 |
if done:
|
| 145 |
terminal_rb = self._rewards.terminal_reward(
|
| 146 |
+
self._latent,
|
| 147 |
+
self._conclusions,
|
| 148 |
+
self._task.success_criteria,
|
| 149 |
+
discovered_markers=self._discovered_markers,
|
| 150 |
+
candidate_mechanisms=self._candidate_mechanisms,
|
| 151 |
)
|
| 152 |
|
| 153 |
total_reward = step_rb.total + terminal_rb.total
|
|
|
|
| 162 |
latest_output=result.output,
|
| 163 |
rule_violations=hard_v + soft_v,
|
| 164 |
reward_breakdown=breakdown,
|
| 165 |
+
metadata_extra={"reward_breakdown": breakdown},
|
| 166 |
)
|
| 167 |
|
| 168 |
@property
|
|
|
|
| 184 |
latest_output: Optional[IntermediateOutput] = None,
|
| 185 |
rule_violations: Optional[List[str]] = None,
|
| 186 |
reward_breakdown: Optional[Dict[str, float]] = None,
|
| 187 |
+
metadata_extra: Optional[Dict[str, Any]] = None,
|
| 188 |
) -> ExperimentObservation:
|
| 189 |
assert self._task is not None
|
| 190 |
assert self._latent is not None
|
| 191 |
res = self._latent.resources
|
| 192 |
+
meta: Dict[str, Any] = {
|
| 193 |
+
"episode_id": self._state.episode_id,
|
| 194 |
+
"step": self._state.step_count,
|
| 195 |
+
"cumulative_reward": self._cumulative_reward,
|
| 196 |
+
}
|
| 197 |
+
if metadata_extra:
|
| 198 |
+
meta.update(metadata_extra)
|
| 199 |
return ExperimentObservation(
|
| 200 |
task=self._task,
|
| 201 |
step_index=self._state.step_count,
|
|
|
|
| 218 |
subagent_outputs=list(self._subagent_outputs),
|
| 219 |
conclusions=list(self._conclusions),
|
| 220 |
rule_violations=rule_violations or [],
|
| 221 |
+
step_reward_breakdown={},
|
| 222 |
done=done,
|
| 223 |
reward=reward,
|
| 224 |
+
metadata=meta,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
)
|
| 226 |
|
| 227 |
def _compute_uncertainty_summary(self) -> Dict[str, float]:
|
|
|
|
| 237 |
) -> None:
|
| 238 |
if action.action_type == ActionType.MARKER_SELECTION:
|
| 239 |
markers = output.data.get("markers", [])
|
| 240 |
+
existing = set(self._discovered_markers)
|
| 241 |
+
for m in markers:
|
| 242 |
+
if m not in existing:
|
| 243 |
+
self._discovered_markers.append(m)
|
| 244 |
+
existing.add(m)
|
| 245 |
if action.action_type == ActionType.REGULATORY_NETWORK_INFERENCE:
|
| 246 |
regs = output.data.get("top_regulators", [])
|
| 247 |
+
existing = set(self._candidate_mechanisms)
|
| 248 |
+
for r in regs:
|
| 249 |
+
if r not in existing:
|
| 250 |
+
self._candidate_mechanisms.append(r)
|
| 251 |
+
existing.add(r)
|
| 252 |
if action.action_type == ActionType.PATHWAY_ENRICHMENT:
|
| 253 |
pathways = output.data.get("top_pathways", [])
|
| 254 |
+
existing = set(self._candidate_mechanisms)
|
| 255 |
+
for p in pathways:
|
| 256 |
+
if isinstance(p, dict) and p["pathway"] not in existing:
|
| 257 |
+
self._candidate_mechanisms.append(p["pathway"])
|
| 258 |
+
existing.add(p["pathway"])
|
server/requirements.txt
CHANGED
|
@@ -1,6 +1,9 @@
|
|
| 1 |
-
openenv[core]>=0.2.0
|
| 2 |
fastapi>=0.115.0
|
| 3 |
uvicorn>=0.24.0
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv-core[core]>=0.2.0
|
| 2 |
fastapi>=0.115.0
|
| 3 |
uvicorn>=0.24.0
|
| 4 |
+
numpy>=1.24.0
|
| 5 |
+
scipy>=1.10.0
|
| 6 |
+
pydantic>=2.0.0
|
| 7 |
+
gseapy>=1.1.0
|
| 8 |
+
sentence-transformers>=3.0.0
|
| 9 |
+
scikit-learn>=1.4.0
|
server/rewards/reward.py
CHANGED
|
@@ -15,7 +15,7 @@ Potential-based shaping
|
|
| 15 |
|
| 16 |
The final step reward is:
|
| 17 |
R_t = r_validity + r_ordering + r_info_gain + r_efficiency
|
| 18 |
-
+ r_novelty + r_penalty +
|
| 19 |
|
| 20 |
The terminal reward adds:
|
| 21 |
R_T += r_terminal
|
|
@@ -32,9 +32,15 @@ from models import (
|
|
| 32 |
ExperimentAction,
|
| 33 |
IntermediateOutput,
|
| 34 |
META_ACTIONS,
|
|
|
|
| 35 |
WET_LAB_ACTIONS,
|
| 36 |
)
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
from server.simulator.latent_state import FullLatentState
|
| 39 |
|
| 40 |
|
|
@@ -84,20 +90,16 @@ class RewardComputer:
|
|
| 84 |
|
| 85 |
Parameters
|
| 86 |
----------
|
| 87 |
-
gamma : float
|
| 88 |
-
Discount factor for potential-based shaping (default 0.99).
|
| 89 |
efficiency_weight : float
|
| 90 |
Relative importance of resource efficiency.
|
| 91 |
"""
|
| 92 |
|
| 93 |
def __init__(
|
| 94 |
self,
|
| 95 |
-
gamma: float = 0.99,
|
| 96 |
efficiency_weight: float = 0.3,
|
| 97 |
info_gain_weight: float = 0.4,
|
| 98 |
validity_weight: float = 0.3,
|
| 99 |
):
|
| 100 |
-
self.gamma = gamma
|
| 101 |
self.w_eff = efficiency_weight
|
| 102 |
self.w_ig = info_gain_weight
|
| 103 |
self.w_val = validity_weight
|
|
@@ -124,11 +126,19 @@ class RewardComputer:
|
|
| 124 |
|
| 125 |
rb.validity = self.w_val * (1.0 if output.success else 0.0)
|
| 126 |
|
| 127 |
-
|
| 128 |
-
rb.ordering = 0.2 *
|
|
|
|
|
|
|
| 129 |
|
| 130 |
# information gain proxy: quality × (1 - uncertainty)
|
| 131 |
rb.info_gain = self.w_ig * output.quality_score * (1.0 - output.uncertainty)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
|
| 133 |
# efficiency: normalised cost relative to budget
|
| 134 |
budget_frac = (
|
|
@@ -141,13 +151,25 @@ class RewardComputer:
|
|
| 141 |
if not soft_violations:
|
| 142 |
rb.novelty = 0.1
|
| 143 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
# penalties
|
| 145 |
rb.penalty = -0.15 * len(soft_violations)
|
| 146 |
-
|
| 147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
phi_prev = self._potential(prev_state)
|
| 149 |
phi_next = self._potential(next_state)
|
| 150 |
-
rb.shaping =
|
| 151 |
|
| 152 |
return rb
|
| 153 |
|
|
@@ -158,8 +180,12 @@ class RewardComputer:
|
|
| 158 |
state: FullLatentState,
|
| 159 |
conclusions: List[ConclusionClaim],
|
| 160 |
task_success_criteria: List[str],
|
|
|
|
|
|
|
| 161 |
) -> RewardBreakdown:
|
| 162 |
rb = RewardBreakdown()
|
|
|
|
|
|
|
| 163 |
|
| 164 |
# pipeline completeness (0-1)
|
| 165 |
completeness = self._completeness(state)
|
|
@@ -183,11 +209,22 @@ class RewardComputer:
|
|
| 183 |
overconf = self._overconfidence_penalty(state, conclusions)
|
| 184 |
rb.components["overconfidence_penalty"] = overconf
|
| 185 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
rb.terminal = (
|
| 187 |
3.0 * completeness
|
| 188 |
+ 4.0 * calibration
|
| 189 |
-
+ 1.0 *
|
| 190 |
+ overconf
|
|
|
|
| 191 |
)
|
| 192 |
return rb
|
| 193 |
|
|
@@ -196,7 +233,7 @@ class RewardComputer:
|
|
| 196 |
def _ordering_score(
|
| 197 |
self, action: ExperimentAction, s: FullLatentState
|
| 198 |
) -> float:
|
| 199 |
-
"""Heuristic: 1.0 if
|
| 200 |
at = action.action_type
|
| 201 |
p = s.progress
|
| 202 |
NATURAL_NEXT = {
|
|
@@ -215,10 +252,26 @@ class RewardComputer:
|
|
| 215 |
p.de_performed or p.cells_clustered
|
| 216 |
) and not p.conclusion_reached,
|
| 217 |
}
|
| 218 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
|
| 220 |
def _potential(self, s: FullLatentState) -> float:
|
| 221 |
-
"""Progress potential φ(s) — counts completed milestones.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
p = s.progress
|
| 223 |
milestones = [
|
| 224 |
p.samples_collected,
|
|
@@ -252,9 +305,38 @@ class RewardComputer:
|
|
| 252 |
def _calibration(
|
| 253 |
self, s: FullLatentState, conclusions: List[ConclusionClaim]
|
| 254 |
) -> float:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
if not conclusions:
|
| 256 |
return 0.0
|
| 257 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
true_mechanisms = set(s.biology.causal_mechanisms)
|
| 259 |
true_markers = set(s.biology.true_markers)
|
| 260 |
score = 0.0
|
|
@@ -270,16 +352,121 @@ class RewardComputer:
|
|
| 270 |
score -= 0.3
|
| 271 |
return max(0.0, min(1.0, score / max(n, 1)))
|
| 272 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
def _overconfidence_penalty(
|
| 274 |
self, s: FullLatentState, conclusions: List[ConclusionClaim]
|
| 275 |
) -> float:
|
| 276 |
-
"""Penalise high-confidence claims that disagree with ground truth.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
penalty = 0.0
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
|
|
|
| 281 |
for c in conclusions:
|
| 282 |
-
|
| 283 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
penalty -= 0.5 * c.confidence
|
|
|
|
| 285 |
return penalty
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
The final step reward is:
|
| 17 |
R_t = r_validity + r_ordering + r_info_gain + r_efficiency
|
| 18 |
+
+ r_novelty + r_penalty + [φ(s_{t+1}) − φ(s_t)]
|
| 19 |
|
| 20 |
The terminal reward adds:
|
| 21 |
R_T += r_terminal
|
|
|
|
| 32 |
ExperimentAction,
|
| 33 |
IntermediateOutput,
|
| 34 |
META_ACTIONS,
|
| 35 |
+
TOOL_REGISTRY,
|
| 36 |
WET_LAB_ACTIONS,
|
| 37 |
)
|
| 38 |
|
| 39 |
+
from server.biology.gene_index import (
|
| 40 |
+
marker_set_score,
|
| 41 |
+
mechanism_set_score,
|
| 42 |
+
score_pathways,
|
| 43 |
+
)
|
| 44 |
from server.simulator.latent_state import FullLatentState
|
| 45 |
|
| 46 |
|
|
|
|
| 90 |
|
| 91 |
Parameters
|
| 92 |
----------
|
|
|
|
|
|
|
| 93 |
efficiency_weight : float
|
| 94 |
Relative importance of resource efficiency.
|
| 95 |
"""
|
| 96 |
|
| 97 |
def __init__(
|
| 98 |
self,
|
|
|
|
| 99 |
efficiency_weight: float = 0.3,
|
| 100 |
info_gain_weight: float = 0.4,
|
| 101 |
validity_weight: float = 0.3,
|
| 102 |
):
|
|
|
|
| 103 |
self.w_eff = efficiency_weight
|
| 104 |
self.w_ig = info_gain_weight
|
| 105 |
self.w_val = validity_weight
|
|
|
|
| 126 |
|
| 127 |
rb.validity = self.w_val * (1.0 if output.success else 0.0)
|
| 128 |
|
| 129 |
+
ordering_score = self._ordering_score(action, prev_state)
|
| 130 |
+
rb.ordering = 0.2 * ordering_score
|
| 131 |
+
if ordering_score < 0:
|
| 132 |
+
rb.penalty += ordering_score * 0.3
|
| 133 |
|
| 134 |
# information gain proxy: quality × (1 - uncertainty)
|
| 135 |
rb.info_gain = self.w_ig * output.quality_score * (1.0 - output.uncertainty)
|
| 136 |
+
if action.action_type in META_ACTIONS and not (
|
| 137 |
+
prev_state.progress.de_performed
|
| 138 |
+
or prev_state.progress.cells_clustered
|
| 139 |
+
):
|
| 140 |
+
# Meta actions before substantive analysis should not dominate reward.
|
| 141 |
+
rb.info_gain *= 0.2
|
| 142 |
|
| 143 |
# efficiency: normalised cost relative to budget
|
| 144 |
budget_frac = (
|
|
|
|
| 151 |
if not soft_violations:
|
| 152 |
rb.novelty = 0.1
|
| 153 |
|
| 154 |
+
# tool-modality fit bonus/penalty
|
| 155 |
+
tool_fit = self._tool_fit_score(action, prev_state)
|
| 156 |
+
rb.components["tool_fit"] = tool_fit
|
| 157 |
+
rb.validity += 0.15 * tool_fit
|
| 158 |
+
|
| 159 |
# penalties
|
| 160 |
rb.penalty = -0.15 * len(soft_violations)
|
| 161 |
+
if action.action_type in META_ACTIONS and not (
|
| 162 |
+
prev_state.progress.de_performed
|
| 163 |
+
or prev_state.progress.cells_clustered
|
| 164 |
+
):
|
| 165 |
+
rb.penalty -= 0.25
|
| 166 |
+
rb.components["premature_meta_action_penalty"] = -0.25
|
| 167 |
+
|
| 168 |
+
# potential-based shaping (γ=1 so it doesn't depend on the
|
| 169 |
+
# training algorithm's discount factor)
|
| 170 |
phi_prev = self._potential(prev_state)
|
| 171 |
phi_next = self._potential(next_state)
|
| 172 |
+
rb.shaping = phi_next - phi_prev
|
| 173 |
|
| 174 |
return rb
|
| 175 |
|
|
|
|
| 180 |
state: FullLatentState,
|
| 181 |
conclusions: List[ConclusionClaim],
|
| 182 |
task_success_criteria: List[str],
|
| 183 |
+
discovered_markers: Optional[List[str]] = None,
|
| 184 |
+
candidate_mechanisms: Optional[List[str]] = None,
|
| 185 |
) -> RewardBreakdown:
|
| 186 |
rb = RewardBreakdown()
|
| 187 |
+
discovered_markers = discovered_markers or []
|
| 188 |
+
candidate_mechanisms = candidate_mechanisms or []
|
| 189 |
|
| 190 |
# pipeline completeness (0-1)
|
| 191 |
completeness = self._completeness(state)
|
|
|
|
| 209 |
overconf = self._overconfidence_penalty(state, conclusions)
|
| 210 |
rb.components["overconfidence_penalty"] = overconf
|
| 211 |
|
| 212 |
+
discovery_alignment = self._discovery_alignment(
|
| 213 |
+
state,
|
| 214 |
+
discovered_markers,
|
| 215 |
+
candidate_mechanisms,
|
| 216 |
+
)
|
| 217 |
+
discovery_error_penalty = -2.5 * (1.0 - discovery_alignment)
|
| 218 |
+
rb.components["discovery_alignment"] = discovery_alignment
|
| 219 |
+
rb.components["discovery_error_penalty"] = discovery_error_penalty
|
| 220 |
+
|
| 221 |
+
eff_bonus = (budget_eff + time_eff) / 2.0 if completeness >= 0.3 else 0.0
|
| 222 |
rb.terminal = (
|
| 223 |
3.0 * completeness
|
| 224 |
+ 4.0 * calibration
|
| 225 |
+
+ 1.0 * eff_bonus
|
| 226 |
+ overconf
|
| 227 |
+
+ discovery_error_penalty
|
| 228 |
)
|
| 229 |
return rb
|
| 230 |
|
|
|
|
| 233 |
def _ordering_score(
|
| 234 |
self, action: ExperimentAction, s: FullLatentState
|
| 235 |
) -> float:
|
| 236 |
+
"""Heuristic: 1.0 if natural next, 0.3 if acceptable, -1.0 if premature."""
|
| 237 |
at = action.action_type
|
| 238 |
p = s.progress
|
| 239 |
NATURAL_NEXT = {
|
|
|
|
| 252 |
p.de_performed or p.cells_clustered
|
| 253 |
) and not p.conclusion_reached,
|
| 254 |
}
|
| 255 |
+
if NATURAL_NEXT.get(at, False):
|
| 256 |
+
return 1.0
|
| 257 |
+
|
| 258 |
+
has_evidence = any([
|
| 259 |
+
p.cells_clustered, p.de_performed, p.trajectories_inferred,
|
| 260 |
+
p.pathways_analyzed, p.networks_inferred, p.markers_discovered,
|
| 261 |
+
])
|
| 262 |
+
if at in META_ACTIONS and not has_evidence:
|
| 263 |
+
return -1.0
|
| 264 |
+
|
| 265 |
+
return 0.3
|
| 266 |
|
| 267 |
def _potential(self, s: FullLatentState) -> float:
|
| 268 |
+
"""Progress potential φ(s) — counts completed milestones.
|
| 269 |
+
|
| 270 |
+
Returns 0.0 at terminal states so that the shaping signal
|
| 271 |
+
telescopes correctly over the episode.
|
| 272 |
+
"""
|
| 273 |
+
if s.progress.conclusion_reached:
|
| 274 |
+
return 0.0
|
| 275 |
p = s.progress
|
| 276 |
milestones = [
|
| 277 |
p.samples_collected,
|
|
|
|
| 305 |
def _calibration(
|
| 306 |
self, s: FullLatentState, conclusions: List[ConclusionClaim]
|
| 307 |
) -> float:
|
| 308 |
+
"""Structured set-similarity calibration against hidden ground truth.
|
| 309 |
+
|
| 310 |
+
Uses pathway-weighted Gaussian similarity for markers, semantic
|
| 311 |
+
similarity for mechanisms, and activity-weighted matching for pathways.
|
| 312 |
+
Falls back to legacy substring matching when structured fields are empty.
|
| 313 |
+
"""
|
| 314 |
if not conclusions:
|
| 315 |
return 0.0
|
| 316 |
|
| 317 |
+
pred_markers = [g for c in conclusions for g in c.top_markers]
|
| 318 |
+
pred_mechs = [m for c in conclusions for m in c.causal_mechanisms]
|
| 319 |
+
pred_pathways = {
|
| 320 |
+
p: v for c in conclusions for p, v in c.predicted_pathways.items()
|
| 321 |
+
}
|
| 322 |
+
|
| 323 |
+
has_structured = bool(pred_markers or pred_mechs or pred_pathways)
|
| 324 |
+
|
| 325 |
+
if has_structured:
|
| 326 |
+
m_score = marker_set_score(pred_markers, s.biology.true_markers)
|
| 327 |
+
mech_score = mechanism_set_score(
|
| 328 |
+
pred_mechs, s.biology.causal_mechanisms
|
| 329 |
+
)
|
| 330 |
+
pw_score = score_pathways(pred_pathways, s.biology.true_pathways)
|
| 331 |
+
return 0.50 * m_score + 0.35 * mech_score + 0.15 * pw_score
|
| 332 |
+
|
| 333 |
+
return self._legacy_calibration(s, conclusions)
|
| 334 |
+
|
| 335 |
+
@staticmethod
|
| 336 |
+
def _legacy_calibration(
|
| 337 |
+
s: FullLatentState, conclusions: List[ConclusionClaim]
|
| 338 |
+
) -> float:
|
| 339 |
+
"""Substring-based calibration kept for backward compatibility."""
|
| 340 |
true_mechanisms = set(s.biology.causal_mechanisms)
|
| 341 |
true_markers = set(s.biology.true_markers)
|
| 342 |
score = 0.0
|
|
|
|
| 352 |
score -= 0.3
|
| 353 |
return max(0.0, min(1.0, score / max(n, 1)))
|
| 354 |
|
| 355 |
+
_METHOD_TO_TOOL: Dict[str, str] = {
|
| 356 |
+
"scanpy.pp.calculate_qc_metrics": "Scanpy",
|
| 357 |
+
"scanpy.pp.filter_cells": "Scanpy",
|
| 358 |
+
"scanpy.pp.filter_genes": "Scanpy",
|
| 359 |
+
"scanpy.pp.normalize_total": "Scanpy",
|
| 360 |
+
"scanpy.pp.log1p": "Scanpy",
|
| 361 |
+
"scanpy.pp.highly_variable_genes": "Scanpy",
|
| 362 |
+
"scanpy.pp.neighbors": "Scanpy",
|
| 363 |
+
"scanpy.tl.leiden": "Leiden",
|
| 364 |
+
"scanpy.tl.louvain": "Louvain",
|
| 365 |
+
"scanpy.tl.rank_genes_groups": "Scanpy",
|
| 366 |
+
"scanpy.tl.paga": "PAGA",
|
| 367 |
+
"scanpy.tl.umap": "UMAP",
|
| 368 |
+
"gseapy.prerank": "Scanpy",
|
| 369 |
+
"gseapy.gsea": "Scanpy",
|
| 370 |
+
"10x_chromium": "CellRanger",
|
| 371 |
+
"NovaSeq": "CellRanger",
|
| 372 |
+
}
|
| 373 |
+
|
| 374 |
+
@staticmethod
|
| 375 |
+
def _tool_fit_score(
|
| 376 |
+
action: ExperimentAction, s: FullLatentState
|
| 377 |
+
) -> float:
|
| 378 |
+
"""Score how well the chosen tool matches the task modality.
|
| 379 |
+
|
| 380 |
+
Returns +1.0 for a perfect match, 0.0 if no tool specified,
|
| 381 |
+
-1.0 for a known tool used on an incompatible modality.
|
| 382 |
+
"""
|
| 383 |
+
method = action.method
|
| 384 |
+
if not method:
|
| 385 |
+
return 0.0
|
| 386 |
+
resolved = RewardComputer._METHOD_TO_TOOL.get(method, method)
|
| 387 |
+
tool_spec = TOOL_REGISTRY.get(resolved)
|
| 388 |
+
if tool_spec is None:
|
| 389 |
+
return -0.5
|
| 390 |
+
modality = getattr(s, "task_modality", None)
|
| 391 |
+
if not modality or not tool_spec.modalities:
|
| 392 |
+
return 0.0
|
| 393 |
+
if modality in tool_spec.modalities:
|
| 394 |
+
return 1.0
|
| 395 |
+
return -1.0
|
| 396 |
+
|
| 397 |
def _overconfidence_penalty(
|
| 398 |
self, s: FullLatentState, conclusions: List[ConclusionClaim]
|
| 399 |
) -> float:
|
| 400 |
+
"""Penalise high-confidence claims that disagree with ground truth.
|
| 401 |
+
|
| 402 |
+
Checks structured fields (top_markers, causal_mechanisms) first;
|
| 403 |
+
falls back to claim substring matching for backward compatibility.
|
| 404 |
+
"""
|
| 405 |
penalty = 0.0
|
| 406 |
+
true_markers_lower = {m.lower() for m in s.biology.true_markers}
|
| 407 |
+
true_mechs_lower = {m.lower() for m in s.biology.causal_mechanisms}
|
| 408 |
+
true_set = true_markers_lower | true_mechs_lower
|
| 409 |
+
|
| 410 |
for c in conclusions:
|
| 411 |
+
if c.confidence <= 0.8:
|
| 412 |
+
continue
|
| 413 |
+
|
| 414 |
+
has_structured = bool(c.top_markers or c.causal_mechanisms)
|
| 415 |
+
if has_structured:
|
| 416 |
+
marker_hit = any(
|
| 417 |
+
g.upper().strip() in {m.upper() for m in s.biology.true_markers}
|
| 418 |
+
for g in c.top_markers
|
| 419 |
+
)
|
| 420 |
+
mech_hit = any(
|
| 421 |
+
any(kw in m.lower() for kw in t.lower().split())
|
| 422 |
+
for m in c.causal_mechanisms
|
| 423 |
+
for t in s.biology.causal_mechanisms
|
| 424 |
+
)
|
| 425 |
+
is_correct = marker_hit or mech_hit
|
| 426 |
+
else:
|
| 427 |
+
is_correct = any(t in c.claim.lower() for t in true_set)
|
| 428 |
+
|
| 429 |
+
if not is_correct:
|
| 430 |
penalty -= 0.5 * c.confidence
|
| 431 |
+
|
| 432 |
return penalty
|
| 433 |
+
|
| 434 |
+
def _discovery_alignment(
|
| 435 |
+
self,
|
| 436 |
+
s: FullLatentState,
|
| 437 |
+
discovered_markers: List[str],
|
| 438 |
+
candidate_mechanisms: List[str],
|
| 439 |
+
) -> float:
|
| 440 |
+
"""Symmetric end-of-episode similarity for discovered biology.
|
| 441 |
+
|
| 442 |
+
Forward scoring measures recall against hidden truth. Reverse scoring
|
| 443 |
+
measures how well the agent's discoveries map back onto real biology,
|
| 444 |
+
which penalizes extra hallucinated markers or mechanisms.
|
| 445 |
+
"""
|
| 446 |
+
components: List[float] = []
|
| 447 |
+
|
| 448 |
+
if s.biology.true_markers or discovered_markers:
|
| 449 |
+
marker_recall = marker_set_score(
|
| 450 |
+
discovered_markers,
|
| 451 |
+
s.biology.true_markers,
|
| 452 |
+
)
|
| 453 |
+
marker_precision = marker_set_score(
|
| 454 |
+
s.biology.true_markers,
|
| 455 |
+
discovered_markers,
|
| 456 |
+
)
|
| 457 |
+
components.append((marker_recall + marker_precision) / 2.0)
|
| 458 |
+
|
| 459 |
+
if s.biology.causal_mechanisms or candidate_mechanisms:
|
| 460 |
+
mechanism_recall = mechanism_set_score(
|
| 461 |
+
candidate_mechanisms,
|
| 462 |
+
s.biology.causal_mechanisms,
|
| 463 |
+
)
|
| 464 |
+
mechanism_precision = mechanism_set_score(
|
| 465 |
+
s.biology.causal_mechanisms,
|
| 466 |
+
candidate_mechanisms,
|
| 467 |
+
)
|
| 468 |
+
components.append((mechanism_recall + mechanism_precision) / 2.0)
|
| 469 |
+
|
| 470 |
+
if not components:
|
| 471 |
+
return 1.0
|
| 472 |
+
return sum(components) / len(components)
|
server/rules/engine.py
CHANGED
|
@@ -10,7 +10,7 @@ from dataclasses import dataclass
|
|
| 10 |
from enum import Enum
|
| 11 |
from typing import List
|
| 12 |
|
| 13 |
-
from models import ActionType, ExperimentAction
|
| 14 |
|
| 15 |
from server.simulator.latent_state import FullLatentState
|
| 16 |
|
|
@@ -32,6 +32,19 @@ class RuleEngine:
|
|
| 32 |
latent state before each action is applied.
|
| 33 |
"""
|
| 34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
def check(
|
| 36 |
self, action: ExperimentAction, state: FullLatentState
|
| 37 |
) -> List[RuleViolation]:
|
|
@@ -40,6 +53,7 @@ class RuleEngine:
|
|
| 40 |
violations.extend(self._check_resource_constraints(action, state))
|
| 41 |
violations.extend(self._check_redundancy(action, state))
|
| 42 |
violations.extend(self._check_causal_validity(action, state))
|
|
|
|
| 43 |
return violations
|
| 44 |
|
| 45 |
def hard_violations(self, violations: List[RuleViolation]) -> List[str]:
|
|
@@ -106,6 +120,9 @@ class RuleEngine:
|
|
| 106 |
ActionType.CULTURE_CELLS: [
|
| 107 |
("samples_collected", "Cannot culture without samples"),
|
| 108 |
],
|
|
|
|
|
|
|
|
|
|
| 109 |
}
|
| 110 |
|
| 111 |
for flag, msg in REQUIRES.get(at, []):
|
|
@@ -127,22 +144,22 @@ class RuleEngine:
|
|
| 127 |
vs.append(RuleViolation(
|
| 128 |
rule_id="budget_exhausted",
|
| 129 |
severity=Severity.HARD,
|
| 130 |
-
message="Budget exhausted
|
| 131 |
))
|
| 132 |
if s.resources.time_exhausted:
|
| 133 |
vs.append(RuleViolation(
|
| 134 |
rule_id="time_exhausted",
|
| 135 |
severity=Severity.HARD,
|
| 136 |
-
message="Time limit reached
|
| 137 |
))
|
| 138 |
|
| 139 |
remaining = s.resources.budget_remaining
|
| 140 |
-
from server.simulator.transition import
|
| 141 |
-
cost, _ =
|
| 142 |
if cost > remaining and remaining > 0:
|
| 143 |
vs.append(RuleViolation(
|
| 144 |
rule_id="budget_insufficient",
|
| 145 |
-
severity=Severity.
|
| 146 |
message=f"Action costs ${cost:,.0f} but only ${remaining:,.0f} remains",
|
| 147 |
))
|
| 148 |
return vs
|
|
@@ -163,13 +180,23 @@ class RuleEngine:
|
|
| 163 |
ActionType.RUN_QC: "qc_performed",
|
| 164 |
ActionType.FILTER_DATA: "data_filtered",
|
| 165 |
ActionType.NORMALIZE_DATA: "data_normalized",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
}
|
| 167 |
flag = REDUNDANT.get(at)
|
| 168 |
if flag and getattr(p, flag, False):
|
| 169 |
vs.append(RuleViolation(
|
| 170 |
rule_id=f"redundant_{at.value}",
|
| 171 |
-
severity=Severity.
|
| 172 |
-
message=f"Step '{at.value}' already completed — redundant action",
|
| 173 |
))
|
| 174 |
return vs
|
| 175 |
|
|
@@ -179,12 +206,36 @@ class RuleEngine:
|
|
| 179 |
self, action: ExperimentAction, s: FullLatentState
|
| 180 |
) -> List[RuleViolation]:
|
| 181 |
vs: List[RuleViolation] = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
if action.action_type == ActionType.SYNTHESIZE_CONCLUSION:
|
| 183 |
if not s.progress.de_performed and not s.progress.cells_clustered:
|
| 184 |
vs.append(RuleViolation(
|
| 185 |
rule_id="premature_conclusion",
|
| 186 |
-
severity=Severity.
|
| 187 |
-
message="
|
| 188 |
))
|
| 189 |
|
| 190 |
claims = action.parameters.get("claims", [])
|
|
@@ -206,3 +257,72 @@ class RuleEngine:
|
|
| 206 |
message="Pathway enrichment without DE may yield unreliable results",
|
| 207 |
))
|
| 208 |
return vs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
from enum import Enum
|
| 11 |
from typing import List
|
| 12 |
|
| 13 |
+
from models import ActionType, ExperimentAction, TOOL_REGISTRY
|
| 14 |
|
| 15 |
from server.simulator.latent_state import FullLatentState
|
| 16 |
|
|
|
|
| 32 |
latent state before each action is applied.
|
| 33 |
"""
|
| 34 |
|
| 35 |
+
@staticmethod
|
| 36 |
+
def _has_analysis_evidence(s: FullLatentState) -> bool:
|
| 37 |
+
p = s.progress
|
| 38 |
+
return any([
|
| 39 |
+
p.cells_clustered,
|
| 40 |
+
p.de_performed,
|
| 41 |
+
p.trajectories_inferred,
|
| 42 |
+
p.pathways_analyzed,
|
| 43 |
+
p.networks_inferred,
|
| 44 |
+
p.markers_discovered,
|
| 45 |
+
p.markers_validated,
|
| 46 |
+
])
|
| 47 |
+
|
| 48 |
def check(
|
| 49 |
self, action: ExperimentAction, state: FullLatentState
|
| 50 |
) -> List[RuleViolation]:
|
|
|
|
| 53 |
violations.extend(self._check_resource_constraints(action, state))
|
| 54 |
violations.extend(self._check_redundancy(action, state))
|
| 55 |
violations.extend(self._check_causal_validity(action, state))
|
| 56 |
+
violations.extend(self._check_tool_compatibility(action, state))
|
| 57 |
return violations
|
| 58 |
|
| 59 |
def hard_violations(self, violations: List[RuleViolation]) -> List[str]:
|
|
|
|
| 120 |
ActionType.CULTURE_CELLS: [
|
| 121 |
("samples_collected", "Cannot culture without samples"),
|
| 122 |
],
|
| 123 |
+
ActionType.SYNTHESIZE_CONCLUSION: [
|
| 124 |
+
("data_normalized", "Cannot synthesize conclusions before data normalization"),
|
| 125 |
+
],
|
| 126 |
}
|
| 127 |
|
| 128 |
for flag, msg in REQUIRES.get(at, []):
|
|
|
|
| 144 |
vs.append(RuleViolation(
|
| 145 |
rule_id="budget_exhausted",
|
| 146 |
severity=Severity.HARD,
|
| 147 |
+
message="Budget exhausted - no further actions possible",
|
| 148 |
))
|
| 149 |
if s.resources.time_exhausted:
|
| 150 |
vs.append(RuleViolation(
|
| 151 |
rule_id="time_exhausted",
|
| 152 |
severity=Severity.HARD,
|
| 153 |
+
message="Time limit reached - no further actions possible",
|
| 154 |
))
|
| 155 |
|
| 156 |
remaining = s.resources.budget_remaining
|
| 157 |
+
from server.simulator.transition import compute_action_cost
|
| 158 |
+
cost, _ = compute_action_cost(action)
|
| 159 |
if cost > remaining and remaining > 0:
|
| 160 |
vs.append(RuleViolation(
|
| 161 |
rule_id="budget_insufficient",
|
| 162 |
+
severity=Severity.HARD,
|
| 163 |
message=f"Action costs ${cost:,.0f} but only ${remaining:,.0f} remains",
|
| 164 |
))
|
| 165 |
return vs
|
|
|
|
| 180 |
ActionType.RUN_QC: "qc_performed",
|
| 181 |
ActionType.FILTER_DATA: "data_filtered",
|
| 182 |
ActionType.NORMALIZE_DATA: "data_normalized",
|
| 183 |
+
ActionType.CLUSTER_CELLS: "cells_clustered",
|
| 184 |
+
ActionType.DIFFERENTIAL_EXPRESSION: "de_performed",
|
| 185 |
+
ActionType.TRAJECTORY_ANALYSIS: "trajectories_inferred",
|
| 186 |
+
ActionType.PATHWAY_ENRICHMENT: "pathways_analyzed",
|
| 187 |
+
ActionType.REGULATORY_NETWORK_INFERENCE: "networks_inferred",
|
| 188 |
+
ActionType.MARKER_SELECTION: "markers_discovered",
|
| 189 |
+
ActionType.VALIDATE_MARKER: "markers_validated",
|
| 190 |
+
ActionType.DESIGN_FOLLOWUP: "followup_designed",
|
| 191 |
+
ActionType.REQUEST_SUBAGENT_REVIEW: "subagent_review_requested",
|
| 192 |
+
ActionType.SYNTHESIZE_CONCLUSION: "conclusion_reached",
|
| 193 |
}
|
| 194 |
flag = REDUNDANT.get(at)
|
| 195 |
if flag and getattr(p, flag, False):
|
| 196 |
vs.append(RuleViolation(
|
| 197 |
rule_id=f"redundant_{at.value}",
|
| 198 |
+
severity=Severity.HARD,
|
| 199 |
+
message=f"Step '{at.value}' already completed — redundant action blocked",
|
| 200 |
))
|
| 201 |
return vs
|
| 202 |
|
|
|
|
| 206 |
self, action: ExperimentAction, s: FullLatentState
|
| 207 |
) -> List[RuleViolation]:
|
| 208 |
vs: List[RuleViolation] = []
|
| 209 |
+
has_analysis_evidence = self._has_analysis_evidence(s)
|
| 210 |
+
|
| 211 |
+
if action.action_type == ActionType.DESIGN_FOLLOWUP:
|
| 212 |
+
if not has_analysis_evidence:
|
| 213 |
+
vs.append(RuleViolation(
|
| 214 |
+
rule_id="premature_followup_design",
|
| 215 |
+
severity=Severity.HARD,
|
| 216 |
+
message=(
|
| 217 |
+
"Follow-up design without prior analysis is blocked; "
|
| 218 |
+
"complete wet-lab and computational steps first"
|
| 219 |
+
),
|
| 220 |
+
))
|
| 221 |
+
|
| 222 |
+
if action.action_type == ActionType.REQUEST_SUBAGENT_REVIEW:
|
| 223 |
+
if not has_analysis_evidence:
|
| 224 |
+
vs.append(RuleViolation(
|
| 225 |
+
rule_id="premature_subagent_review",
|
| 226 |
+
severity=Severity.HARD,
|
| 227 |
+
message=(
|
| 228 |
+
"Subagent review without prior analysis is blocked; "
|
| 229 |
+
"generate evidence first"
|
| 230 |
+
),
|
| 231 |
+
))
|
| 232 |
+
|
| 233 |
if action.action_type == ActionType.SYNTHESIZE_CONCLUSION:
|
| 234 |
if not s.progress.de_performed and not s.progress.cells_clustered:
|
| 235 |
vs.append(RuleViolation(
|
| 236 |
rule_id="premature_conclusion",
|
| 237 |
+
severity=Severity.HARD,
|
| 238 |
+
message="Cannot synthesise conclusion without substantive analysis",
|
| 239 |
))
|
| 240 |
|
| 241 |
claims = action.parameters.get("claims", [])
|
|
|
|
| 257 |
message="Pathway enrichment without DE may yield unreliable results",
|
| 258 |
))
|
| 259 |
return vs
|
| 260 |
+
|
| 261 |
+
# ── tool / modality compatibility ────────────────────────────────────
|
| 262 |
+
|
| 263 |
+
_KNOWN_METHODS = {
|
| 264 |
+
"scanpy.pp.calculate_qc_metrics", "scanpy.pp.filter_cells",
|
| 265 |
+
"scanpy.pp.filter_genes", "scanpy.pp.normalize_total",
|
| 266 |
+
"scanpy.pp.log1p", "scanpy.pp.highly_variable_genes",
|
| 267 |
+
"scanpy.pp.neighbors", "scanpy.tl.leiden", "scanpy.tl.louvain",
|
| 268 |
+
"scanpy.tl.rank_genes_groups", "scanpy.tl.paga", "scanpy.tl.umap",
|
| 269 |
+
"gseapy.prerank", "gseapy.gsea", "10x_chromium", "NovaSeq",
|
| 270 |
+
}
|
| 271 |
+
_METHOD_TO_TOOL = {
|
| 272 |
+
"scanpy.pp.calculate_qc_metrics": "Scanpy",
|
| 273 |
+
"scanpy.pp.filter_cells": "Scanpy",
|
| 274 |
+
"scanpy.pp.filter_genes": "Scanpy",
|
| 275 |
+
"scanpy.pp.normalize_total": "Scanpy",
|
| 276 |
+
"scanpy.pp.log1p": "Scanpy",
|
| 277 |
+
"scanpy.pp.highly_variable_genes": "Scanpy",
|
| 278 |
+
"scanpy.pp.neighbors": "Scanpy",
|
| 279 |
+
"scanpy.tl.leiden": "Leiden",
|
| 280 |
+
"scanpy.tl.louvain": "Louvain",
|
| 281 |
+
"scanpy.tl.rank_genes_groups": "Scanpy",
|
| 282 |
+
"scanpy.tl.paga": "PAGA",
|
| 283 |
+
"scanpy.tl.umap": "UMAP",
|
| 284 |
+
"gseapy.prerank": "Scanpy",
|
| 285 |
+
"gseapy.gsea": "Scanpy",
|
| 286 |
+
"10x_chromium": "CellRanger",
|
| 287 |
+
"NovaSeq": "CellRanger",
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
def _check_tool_compatibility(
|
| 291 |
+
self, action: ExperimentAction, s: FullLatentState
|
| 292 |
+
) -> List[RuleViolation]:
|
| 293 |
+
"""Warn when the chosen tool is incompatible with the task modality."""
|
| 294 |
+
vs: List[RuleViolation] = []
|
| 295 |
+
method = action.method
|
| 296 |
+
if not method:
|
| 297 |
+
return vs
|
| 298 |
+
|
| 299 |
+
resolved = self._METHOD_TO_TOOL.get(method, method)
|
| 300 |
+
tool_spec = TOOL_REGISTRY.get(resolved)
|
| 301 |
+
if tool_spec is None and method not in self._KNOWN_METHODS:
|
| 302 |
+
vs.append(RuleViolation(
|
| 303 |
+
rule_id="unknown_tool",
|
| 304 |
+
severity=Severity.SOFT,
|
| 305 |
+
message=f"Tool '{method}' is not in the registry — results may be unreliable",
|
| 306 |
+
))
|
| 307 |
+
return vs
|
| 308 |
+
if tool_spec is None:
|
| 309 |
+
return vs
|
| 310 |
+
|
| 311 |
+
# Check modality compatibility (modality lives on the task, which is
|
| 312 |
+
# stored in the latent state's associated TaskSpec — but the latent
|
| 313 |
+
# state doesn't carry the TaskSpec directly. We can still check via
|
| 314 |
+
# the action's own context or fall back gracefully).
|
| 315 |
+
task_modality = getattr(s, "task_modality", None)
|
| 316 |
+
if task_modality and tool_spec.modalities:
|
| 317 |
+
if task_modality not in tool_spec.modalities:
|
| 318 |
+
vs.append(RuleViolation(
|
| 319 |
+
rule_id="tool_modality_mismatch",
|
| 320 |
+
severity=Severity.SOFT,
|
| 321 |
+
message=(
|
| 322 |
+
f"Tool '{method}' is designed for "
|
| 323 |
+
f"{', '.join(tool_spec.modalities)} but task modality "
|
| 324 |
+
f"is '{task_modality}'"
|
| 325 |
+
),
|
| 326 |
+
))
|
| 327 |
+
|
| 328 |
+
return vs
|
server/simulator/latent_state.py
CHANGED
|
@@ -88,8 +88,11 @@ class ExperimentProgress(BaseModel):
|
|
| 88 |
networks_inferred: bool = False
|
| 89 |
markers_discovered: bool = False
|
| 90 |
markers_validated: bool = False
|
|
|
|
|
|
|
| 91 |
conclusion_reached: bool = False
|
| 92 |
|
|
|
|
| 93 |
n_cells_after_filter: Optional[int] = None
|
| 94 |
n_clusters_found: Optional[int] = None
|
| 95 |
n_de_genes_found: Optional[int] = None
|
|
@@ -139,5 +142,12 @@ class FullLatentState(BaseModel):
|
|
| 139 |
mechanism_confidence: Dict[str, float] = Field(default_factory=dict)
|
| 140 |
discovered_de_genes: List[str] = Field(default_factory=list)
|
| 141 |
discovered_clusters: List[str] = Field(default_factory=list)
|
|
|
|
| 142 |
step_count: int = 0
|
| 143 |
rng_seed: int = 42
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
networks_inferred: bool = False
|
| 89 |
markers_discovered: bool = False
|
| 90 |
markers_validated: bool = False
|
| 91 |
+
followup_designed: bool = False
|
| 92 |
+
subagent_review_requested: bool = False
|
| 93 |
conclusion_reached: bool = False
|
| 94 |
|
| 95 |
+
n_cells_sequenced: Optional[int] = None
|
| 96 |
n_cells_after_filter: Optional[int] = None
|
| 97 |
n_clusters_found: Optional[int] = None
|
| 98 |
n_de_genes_found: Optional[int] = None
|
|
|
|
| 142 |
mechanism_confidence: Dict[str, float] = Field(default_factory=dict)
|
| 143 |
discovered_de_genes: List[str] = Field(default_factory=list)
|
| 144 |
discovered_clusters: List[str] = Field(default_factory=list)
|
| 145 |
+
task_modality: str = "scRNA-seq"
|
| 146 |
step_count: int = 0
|
| 147 |
rng_seed: int = 42
|
| 148 |
+
|
| 149 |
+
# Transient fields for passing sampled values from the transition engine
|
| 150 |
+
# to the output generator within a single step (not serialized).
|
| 151 |
+
last_retain_frac: Optional[float] = Field(None, exclude=True)
|
| 152 |
+
last_n_clusters: Optional[int] = Field(None, exclude=True)
|
| 153 |
+
last_perturbation_efficiency: Optional[float] = Field(None, exclude=True)
|
server/simulator/noise.py
CHANGED
|
@@ -30,7 +30,11 @@ class NoiseModel:
|
|
| 30 |
) -> Dict[str, float]:
|
| 31 |
noisy: Dict[str, float] = {}
|
| 32 |
for gene, value in true_values.items():
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
noisy[gene] = 0.0
|
| 35 |
else:
|
| 36 |
sigma = noise_level * abs(value) + 0.1
|
|
|
|
| 30 |
) -> Dict[str, float]:
|
| 31 |
noisy: Dict[str, float] = {}
|
| 32 |
for gene, value in true_values.items():
|
| 33 |
+
# Dropout probability is inversely proportional to expression
|
| 34 |
+
# magnitude: lowly expressed genes drop out much more readily,
|
| 35 |
+
# matching the zero-inflation pattern in real scRNA-seq data.
|
| 36 |
+
p_drop = dropout_rate / (1.0 + abs(value))
|
| 37 |
+
if self.rng.random() < p_drop:
|
| 38 |
noisy[gene] = 0.0
|
| 39 |
else:
|
| 40 |
sigma = noise_level * abs(value) + 0.1
|
server/simulator/output_generator.py
CHANGED
|
@@ -14,6 +14,15 @@ from models import (
|
|
| 14 |
from .latent_state import FullLatentState
|
| 15 |
from .noise import NoiseModel
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
class OutputGenerator:
|
| 19 |
"""Creates structured ``IntermediateOutput`` objects conditioned on the
|
|
@@ -91,7 +100,13 @@ class OutputGenerator:
|
|
| 91 |
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 92 |
) -> IntermediateOutput:
|
| 93 |
days = action.parameters.get("days", 7)
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
return IntermediateOutput(
|
| 96 |
output_type=OutputType.CULTURE_RESULT,
|
| 97 |
step_index=idx,
|
|
@@ -101,20 +116,54 @@ class OutputGenerator:
|
|
| 101 |
artifacts_available=["cultured_cells"],
|
| 102 |
)
|
| 103 |
|
| 104 |
-
def
|
| 105 |
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 106 |
) -> IntermediateOutput:
|
|
|
|
|
|
|
| 107 |
target = action.parameters.get("target", "unknown")
|
| 108 |
-
efficiency = self.noise.sample_qc_metric(0.
|
|
|
|
| 109 |
return IntermediateOutput(
|
| 110 |
output_type=OutputType.PERTURBATION_RESULT,
|
| 111 |
step_index=idx,
|
| 112 |
quality_score=efficiency,
|
| 113 |
-
summary=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
data={
|
| 115 |
"target": target,
|
| 116 |
"efficiency": efficiency,
|
| 117 |
"type": action.action_type.value,
|
|
|
|
|
|
|
| 118 |
},
|
| 119 |
artifacts_available=["perturbed_cells"],
|
| 120 |
)
|
|
@@ -122,11 +171,18 @@ class OutputGenerator:
|
|
| 122 |
def _sequence_cells(
|
| 123 |
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 124 |
) -> IntermediateOutput:
|
|
|
|
| 125 |
depth = s.technical.sequencing_depth_factor
|
| 126 |
-
n_cells = self.noise.sample_count(
|
| 127 |
s.biology.n_true_cells * s.technical.capture_efficiency
|
| 128 |
)
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
median_umi = self.noise.sample_count(int(3000 * depth))
|
| 131 |
quality = self.noise.quality_degradation(
|
| 132 |
s.technical.sample_quality,
|
|
@@ -157,7 +213,18 @@ class OutputGenerator:
|
|
| 157 |
doublet_frac = self.noise.sample_qc_metric(
|
| 158 |
s.technical.doublet_rate, 0.01, 0.0, 0.2
|
| 159 |
)
|
| 160 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
ambient_frac = self.noise.sample_qc_metric(
|
| 162 |
s.technical.ambient_rna_fraction, 0.01, 0.0, 0.2
|
| 163 |
)
|
|
@@ -186,9 +253,9 @@ class OutputGenerator:
|
|
| 186 |
def _filter_data(
|
| 187 |
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 188 |
) -> IntermediateOutput:
|
| 189 |
-
retain_frac = self.noise.sample_qc_metric(0.85, 0.05, 0.5, 1.0)
|
| 190 |
-
n_before = s.biology.n_true_cells
|
| 191 |
-
n_after = max(100, int(n_before * retain_frac))
|
| 192 |
return IntermediateOutput(
|
| 193 |
output_type=OutputType.COUNT_MATRIX_SUMMARY,
|
| 194 |
step_index=idx,
|
|
@@ -238,14 +305,15 @@ class OutputGenerator:
|
|
| 238 |
) -> IntermediateOutput:
|
| 239 |
n_true = len(s.biology.cell_populations) or 5
|
| 240 |
quality = self.noise.quality_degradation(0.8, [0.95])
|
| 241 |
-
n_clusters = self.noise.sample_cluster_count(n_true, quality)
|
| 242 |
cluster_names = [f"cluster_{i}" for i in range(n_clusters)]
|
| 243 |
-
|
|
|
|
| 244 |
return IntermediateOutput(
|
| 245 |
output_type=OutputType.CLUSTER_RESULT,
|
| 246 |
step_index=idx,
|
| 247 |
quality_score=quality,
|
| 248 |
-
summary=f"Found {n_clusters} clusters
|
| 249 |
data={
|
| 250 |
"n_clusters": n_clusters,
|
| 251 |
"cluster_names": cluster_names,
|
|
@@ -260,10 +328,22 @@ class OutputGenerator:
|
|
| 260 |
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 261 |
) -> IntermediateOutput:
|
| 262 |
comparison = action.parameters.get("comparison", "disease_vs_healthy")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
true_effects = s.biology.true_de_genes.get(comparison, {})
|
| 264 |
|
| 265 |
n_cells = s.progress.n_cells_after_filter or s.biology.n_true_cells
|
| 266 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
observed = self.noise.sample_effect_sizes(true_effects, n_cells, noise_level)
|
| 268 |
|
| 269 |
fp_genes = self.noise.generate_false_positives(5000, 0.002 + noise_level * 0.01)
|
|
@@ -299,10 +379,16 @@ class OutputGenerator:
|
|
| 299 |
quality = self.noise.quality_degradation(0.7 if has_trajectory else 0.3, [0.9])
|
| 300 |
summary_data: Dict[str, Any] = {"method": action.method or "monocle3"}
|
| 301 |
if has_trajectory:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
summary_data.update({
|
| 303 |
-
"n_lineages":
|
| 304 |
"pseudotime_range": [0.0, 1.0],
|
| 305 |
-
"branching_detected":
|
| 306 |
})
|
| 307 |
else:
|
| 308 |
summary_data["n_lineages"] = self.noise.sample_count(1) + 1
|
|
@@ -323,24 +409,38 @@ class OutputGenerator:
|
|
| 323 |
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 324 |
) -> IntermediateOutput:
|
| 325 |
true_pathways = s.biology.true_pathways
|
| 326 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 327 |
observed: Dict[str, float] = {}
|
| 328 |
for pw, activity in true_pathways.items():
|
| 329 |
observed[pw] = activity + float(self.noise.rng.normal(0, noise_level))
|
| 330 |
|
| 331 |
-
for i in range(self.noise.sample_count(
|
| 332 |
observed[f"FP_PATHWAY_{i}"] = float(self.noise.rng.uniform(0.3, 0.6))
|
| 333 |
|
| 334 |
top = sorted(observed.items(), key=lambda kv: kv[1], reverse=True)[:15]
|
|
|
|
| 335 |
return IntermediateOutput(
|
| 336 |
output_type=OutputType.PATHWAY_RESULT,
|
| 337 |
step_index=idx,
|
| 338 |
-
quality_score=self.noise.quality_degradation(
|
| 339 |
summary=f"Pathway enrichment: {len(top)} significant pathways",
|
| 340 |
data={
|
| 341 |
"method": action.method or "GSEA",
|
| 342 |
"top_pathways": [
|
| 343 |
-
{"pathway": p, "score": round(
|
| 344 |
],
|
| 345 |
},
|
| 346 |
uncertainty=noise_level,
|
|
@@ -353,6 +453,25 @@ class OutputGenerator:
|
|
| 353 |
true_net = s.biology.true_regulatory_network
|
| 354 |
n_edges_true = sum(len(v) for v in true_net.values())
|
| 355 |
noise_edges = self.noise.sample_count(max(5, int(n_edges_true * 0.3)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 356 |
return IntermediateOutput(
|
| 357 |
output_type=OutputType.NETWORK_RESULT,
|
| 358 |
step_index=idx,
|
|
@@ -362,7 +481,7 @@ class OutputGenerator:
|
|
| 362 |
"method": action.method or "SCENIC",
|
| 363 |
"n_regulons": len(true_net) + self.noise.sample_count(3),
|
| 364 |
"n_edges": n_edges_true + noise_edges,
|
| 365 |
-
"top_regulators":
|
| 366 |
},
|
| 367 |
uncertainty=0.35,
|
| 368 |
artifacts_available=["regulon_table", "grn_adjacency"],
|
|
@@ -407,8 +526,11 @@ class OutputGenerator:
|
|
| 407 |
"marker": marker,
|
| 408 |
"validated": validated,
|
| 409 |
"assay": action.method or "qPCR",
|
|
|
|
|
|
|
|
|
|
| 410 |
"effect_size": self.noise.sample_qc_metric(
|
| 411 |
-
|
| 412 |
),
|
| 413 |
},
|
| 414 |
artifacts_available=["validation_data"],
|
|
@@ -417,22 +539,54 @@ class OutputGenerator:
|
|
| 417 |
def _design_followup(
|
| 418 |
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 419 |
) -> IntermediateOutput:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 420 |
return IntermediateOutput(
|
| 421 |
output_type=OutputType.FOLLOWUP_DESIGN,
|
| 422 |
step_index=idx,
|
| 423 |
-
|
| 424 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 425 |
artifacts_available=["followup_proposal"],
|
| 426 |
)
|
| 427 |
|
| 428 |
def _subagent_review(
|
| 429 |
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 430 |
) -> IntermediateOutput:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 431 |
return IntermediateOutput(
|
| 432 |
output_type=OutputType.SUBAGENT_REPORT,
|
| 433 |
step_index=idx,
|
|
|
|
| 434 |
summary=f"Subagent review ({action.invoked_subagent or 'general'})",
|
| 435 |
-
data={
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 436 |
artifacts_available=["subagent_report"],
|
| 437 |
)
|
| 438 |
|
|
@@ -469,14 +623,46 @@ class OutputGenerator:
|
|
| 469 |
sizes[0] += diff
|
| 470 |
return sizes
|
| 471 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 472 |
|
| 473 |
_HANDLERS = {
|
| 474 |
ActionType.COLLECT_SAMPLE: OutputGenerator._collect_sample,
|
| 475 |
ActionType.SELECT_COHORT: OutputGenerator._select_cohort,
|
| 476 |
ActionType.PREPARE_LIBRARY: OutputGenerator._prepare_library,
|
| 477 |
ActionType.CULTURE_CELLS: OutputGenerator._culture_cells,
|
| 478 |
-
ActionType.PERTURB_GENE: OutputGenerator.
|
| 479 |
-
ActionType.PERTURB_COMPOUND: OutputGenerator.
|
| 480 |
ActionType.SEQUENCE_CELLS: OutputGenerator._sequence_cells,
|
| 481 |
ActionType.RUN_QC: OutputGenerator._run_qc,
|
| 482 |
ActionType.FILTER_DATA: OutputGenerator._filter_data,
|
|
|
|
| 14 |
from .latent_state import FullLatentState
|
| 15 |
from .noise import NoiseModel
|
| 16 |
|
| 17 |
+
# Pool of common transcription factors used to generate realistic false-positive
|
| 18 |
+
# regulators, so the agent cannot trivially distinguish true vs. false hits by
|
| 19 |
+
# gene-name format alone.
|
| 20 |
+
_NOISE_TFS: List[str] = [
|
| 21 |
+
"NR3C1", "KLF4", "EGR1", "IRF1", "FOSL2", "JUN", "FOS", "ATF3",
|
| 22 |
+
"NFKB1", "RELA", "SP1", "MYC", "MAX", "E2F1", "CTCF", "YY1",
|
| 23 |
+
"TP53", "STAT5A", "SMAD3", "TCF7L2", "NFE2L2", "HIF1A", "CREB1",
|
| 24 |
+
]
|
| 25 |
+
|
| 26 |
|
| 27 |
class OutputGenerator:
|
| 28 |
"""Creates structured ``IntermediateOutput`` objects conditioned on the
|
|
|
|
| 100 |
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 101 |
) -> IntermediateOutput:
|
| 102 |
days = action.parameters.get("days", 7)
|
| 103 |
+
# Viability decays with culture duration: each day adds ~0.5%
|
| 104 |
+
# cumulative stress, reflecting senescence, media depletion, and
|
| 105 |
+
# passaging artefacts common in primary cell cultures.
|
| 106 |
+
decay = 0.005 * days
|
| 107 |
+
viability = self.noise.sample_qc_metric(
|
| 108 |
+
max(0.50, 0.95 - decay), 0.05, 0.30, 1.0
|
| 109 |
+
)
|
| 110 |
return IntermediateOutput(
|
| 111 |
output_type=OutputType.CULTURE_RESULT,
|
| 112 |
step_index=idx,
|
|
|
|
| 116 |
artifacts_available=["cultured_cells"],
|
| 117 |
)
|
| 118 |
|
| 119 |
+
def _perturb_gene(
|
| 120 |
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 121 |
) -> IntermediateOutput:
|
| 122 |
+
"""Genetic perturbation (CRISPR/RNAi): high on-target efficiency,
|
| 123 |
+
binary effect, non-trivial off-target risk."""
|
| 124 |
target = action.parameters.get("target", "unknown")
|
| 125 |
+
efficiency = s.last_perturbation_efficiency if s.last_perturbation_efficiency is not None else self.noise.sample_qc_metric(0.80, 0.12, 0.0, 1.0)
|
| 126 |
+
off_target_risk = self.noise.sample_qc_metric(0.10, 0.05, 0.0, 0.5)
|
| 127 |
return IntermediateOutput(
|
| 128 |
output_type=OutputType.PERTURBATION_RESULT,
|
| 129 |
step_index=idx,
|
| 130 |
quality_score=efficiency,
|
| 131 |
+
summary=(
|
| 132 |
+
f"Genetic perturbation of {target} "
|
| 133 |
+
f"(efficiency={efficiency:.2f}, off-target risk={off_target_risk:.2f})"
|
| 134 |
+
),
|
| 135 |
+
data={
|
| 136 |
+
"target": target,
|
| 137 |
+
"efficiency": efficiency,
|
| 138 |
+
"type": action.action_type.value,
|
| 139 |
+
"off_target_risk": off_target_risk,
|
| 140 |
+
},
|
| 141 |
+
artifacts_available=["perturbed_cells"],
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
def _perturb_compound(
|
| 145 |
+
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 146 |
+
) -> IntermediateOutput:
|
| 147 |
+
"""Small-molecule perturbation: dose-dependent, partial on-target
|
| 148 |
+
activity, systemic effects possible."""
|
| 149 |
+
target = action.parameters.get("target", "unknown")
|
| 150 |
+
dose_um = action.parameters.get("dose_uM", 1.0)
|
| 151 |
+
efficiency = s.last_perturbation_efficiency if s.last_perturbation_efficiency is not None else self.noise.sample_qc_metric(0.70, 0.15, 0.0, 1.0)
|
| 152 |
+
on_target_frac = self.noise.sample_qc_metric(0.75, 0.10, 0.0, 1.0)
|
| 153 |
+
return IntermediateOutput(
|
| 154 |
+
output_type=OutputType.PERTURBATION_RESULT,
|
| 155 |
+
step_index=idx,
|
| 156 |
+
quality_score=efficiency * on_target_frac,
|
| 157 |
+
summary=(
|
| 158 |
+
f"Compound perturbation targeting {target} at {dose_um} µM "
|
| 159 |
+
f"(efficiency={efficiency:.2f}, on-target={on_target_frac:.2f})"
|
| 160 |
+
),
|
| 161 |
data={
|
| 162 |
"target": target,
|
| 163 |
"efficiency": efficiency,
|
| 164 |
"type": action.action_type.value,
|
| 165 |
+
"dose_uM": dose_um,
|
| 166 |
+
"on_target_fraction": on_target_frac,
|
| 167 |
},
|
| 168 |
artifacts_available=["perturbed_cells"],
|
| 169 |
)
|
|
|
|
| 171 |
def _sequence_cells(
|
| 172 |
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 173 |
) -> IntermediateOutput:
|
| 174 |
+
import math
|
| 175 |
depth = s.technical.sequencing_depth_factor
|
| 176 |
+
n_cells = s.progress.n_cells_sequenced or self.noise.sample_count(
|
| 177 |
s.biology.n_true_cells * s.technical.capture_efficiency
|
| 178 |
)
|
| 179 |
+
# Gene detection saturates with sequencing depth: follows a
|
| 180 |
+
# 1 - exp(-k) saturation curve, scaled by library complexity.
|
| 181 |
+
max_genes = 20_000
|
| 182 |
+
saturation_arg = depth * s.technical.library_complexity * 0.8
|
| 183 |
+
n_genes = self.noise.sample_count(
|
| 184 |
+
int(max_genes * (1.0 - math.exp(-saturation_arg)))
|
| 185 |
+
)
|
| 186 |
median_umi = self.noise.sample_count(int(3000 * depth))
|
| 187 |
quality = self.noise.quality_degradation(
|
| 188 |
s.technical.sample_quality,
|
|
|
|
| 213 |
doublet_frac = self.noise.sample_qc_metric(
|
| 214 |
s.technical.doublet_rate, 0.01, 0.0, 0.2
|
| 215 |
)
|
| 216 |
+
# Mitochondrial fraction reflects cellular stress: activated,
|
| 217 |
+
# inflammatory, or pro-fibrotic populations have elevated mito
|
| 218 |
+
# transcription compared to quiescent/resting cells.
|
| 219 |
+
_stressed_states = {"activated", "stressed", "pro-fibrotic", "inflammatory"}
|
| 220 |
+
has_stressed_cells = any(
|
| 221 |
+
p.state in _stressed_states for p in s.biology.cell_populations
|
| 222 |
+
)
|
| 223 |
+
# Means are kept close (0.09 vs 0.06) with a wider SD (0.03) so the
|
| 224 |
+
# mito fraction is informative but not a near-perfect oracle for
|
| 225 |
+
# stressed-cell presence.
|
| 226 |
+
mito_mean = 0.09 if has_stressed_cells else 0.06
|
| 227 |
+
mito_frac = self.noise.sample_qc_metric(mito_mean, 0.03, 0.0, 0.3)
|
| 228 |
ambient_frac = self.noise.sample_qc_metric(
|
| 229 |
s.technical.ambient_rna_fraction, 0.01, 0.0, 0.2
|
| 230 |
)
|
|
|
|
| 253 |
def _filter_data(
|
| 254 |
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 255 |
) -> IntermediateOutput:
|
| 256 |
+
retain_frac = s.last_retain_frac if s.last_retain_frac is not None else self.noise.sample_qc_metric(0.85, 0.05, 0.5, 1.0)
|
| 257 |
+
n_before = s.progress.n_cells_sequenced or s.biology.n_true_cells
|
| 258 |
+
n_after = s.progress.n_cells_after_filter or max(100, int(n_before * retain_frac))
|
| 259 |
return IntermediateOutput(
|
| 260 |
output_type=OutputType.COUNT_MATRIX_SUMMARY,
|
| 261 |
step_index=idx,
|
|
|
|
| 305 |
) -> IntermediateOutput:
|
| 306 |
n_true = len(s.biology.cell_populations) or 5
|
| 307 |
quality = self.noise.quality_degradation(0.8, [0.95])
|
| 308 |
+
n_clusters = s.last_n_clusters if s.last_n_clusters is not None else self.noise.sample_cluster_count(n_true, quality)
|
| 309 |
cluster_names = [f"cluster_{i}" for i in range(n_clusters)]
|
| 310 |
+
n_cells = s.progress.n_cells_after_filter or s.biology.n_true_cells
|
| 311 |
+
sizes = self._partition_by_population(n_cells, n_clusters, s.biology.cell_populations)
|
| 312 |
return IntermediateOutput(
|
| 313 |
output_type=OutputType.CLUSTER_RESULT,
|
| 314 |
step_index=idx,
|
| 315 |
quality_score=quality,
|
| 316 |
+
summary=f"Found {n_clusters} clusters",
|
| 317 |
data={
|
| 318 |
"n_clusters": n_clusters,
|
| 319 |
"cluster_names": cluster_names,
|
|
|
|
| 328 |
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 329 |
) -> IntermediateOutput:
|
| 330 |
comparison = action.parameters.get("comparison", "disease_vs_healthy")
|
| 331 |
+
# Fall back to the first available comparison key if the requested one
|
| 332 |
+
# is absent, rather than silently returning an empty effect dict.
|
| 333 |
+
if comparison not in s.biology.true_de_genes and s.biology.true_de_genes:
|
| 334 |
+
comparison = next(iter(s.biology.true_de_genes))
|
| 335 |
true_effects = s.biology.true_de_genes.get(comparison, {})
|
| 336 |
|
| 337 |
n_cells = s.progress.n_cells_after_filter or s.biology.n_true_cells
|
| 338 |
+
batch_noise = (
|
| 339 |
+
sum(s.technical.batch_effects.values())
|
| 340 |
+
/ max(len(s.technical.batch_effects), 1)
|
| 341 |
+
)
|
| 342 |
+
noise_level = (
|
| 343 |
+
s.technical.dropout_rate
|
| 344 |
+
+ 0.1 * (1.0 - s.technical.sample_quality)
|
| 345 |
+
+ 0.5 * batch_noise
|
| 346 |
+
)
|
| 347 |
observed = self.noise.sample_effect_sizes(true_effects, n_cells, noise_level)
|
| 348 |
|
| 349 |
fp_genes = self.noise.generate_false_positives(5000, 0.002 + noise_level * 0.01)
|
|
|
|
| 379 |
quality = self.noise.quality_degradation(0.7 if has_trajectory else 0.3, [0.9])
|
| 380 |
summary_data: Dict[str, Any] = {"method": action.method or "monocle3"}
|
| 381 |
if has_trajectory:
|
| 382 |
+
true_n_lineages = s.biology.true_trajectory.get("n_lineages", 1)
|
| 383 |
+
true_branching = s.biology.true_trajectory.get("branching", False)
|
| 384 |
+
# Perturb lineage count by ±1 and flip the branching flag with 20%
|
| 385 |
+
# probability so the output is informative but not an exact oracle.
|
| 386 |
+
noisy_n_lineages = max(1, true_n_lineages + int(self.noise.rng.choice([-1, 0, 0, 1])))
|
| 387 |
+
noisy_branching = true_branching if not self.noise.coin_flip(0.20) else not true_branching
|
| 388 |
summary_data.update({
|
| 389 |
+
"n_lineages": noisy_n_lineages,
|
| 390 |
"pseudotime_range": [0.0, 1.0],
|
| 391 |
+
"branching_detected": noisy_branching,
|
| 392 |
})
|
| 393 |
else:
|
| 394 |
summary_data["n_lineages"] = self.noise.sample_count(1) + 1
|
|
|
|
| 409 |
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 410 |
) -> IntermediateOutput:
|
| 411 |
true_pathways = s.biology.true_pathways
|
| 412 |
+
# Pathway enrichment quality is tightly coupled to the quality of the
|
| 413 |
+
# preceding DE step: more DE genes found → better gene-set coverage →
|
| 414 |
+
# lower noise and fewer spurious pathway hits.
|
| 415 |
+
de_genes_found = s.progress.n_de_genes_found or 0
|
| 416 |
+
de_was_run = s.progress.de_performed
|
| 417 |
+
if de_was_run and de_genes_found > 0:
|
| 418 |
+
# Noise shrinks as the DE gene list grows (more signal in input).
|
| 419 |
+
noise_level = max(0.05, 0.25 - 0.001 * min(de_genes_found, 200))
|
| 420 |
+
n_fp_mean = max(1, int(5 - de_genes_found / 50))
|
| 421 |
+
else:
|
| 422 |
+
# Without a DE step, enrichment is unreliable.
|
| 423 |
+
noise_level = 0.40
|
| 424 |
+
n_fp_mean = 8
|
| 425 |
+
|
| 426 |
observed: Dict[str, float] = {}
|
| 427 |
for pw, activity in true_pathways.items():
|
| 428 |
observed[pw] = activity + float(self.noise.rng.normal(0, noise_level))
|
| 429 |
|
| 430 |
+
for i in range(self.noise.sample_count(n_fp_mean)):
|
| 431 |
observed[f"FP_PATHWAY_{i}"] = float(self.noise.rng.uniform(0.3, 0.6))
|
| 432 |
|
| 433 |
top = sorted(observed.items(), key=lambda kv: kv[1], reverse=True)[:15]
|
| 434 |
+
base_quality = 0.80 if de_was_run else 0.45
|
| 435 |
return IntermediateOutput(
|
| 436 |
output_type=OutputType.PATHWAY_RESULT,
|
| 437 |
step_index=idx,
|
| 438 |
+
quality_score=self.noise.quality_degradation(base_quality, [0.95]),
|
| 439 |
summary=f"Pathway enrichment: {len(top)} significant pathways",
|
| 440 |
data={
|
| 441 |
"method": action.method or "GSEA",
|
| 442 |
"top_pathways": [
|
| 443 |
+
{"pathway": p, "score": round(sc, 3)} for p, sc in top
|
| 444 |
],
|
| 445 |
},
|
| 446 |
uncertainty=noise_level,
|
|
|
|
| 453 |
true_net = s.biology.true_regulatory_network
|
| 454 |
n_edges_true = sum(len(v) for v in true_net.values())
|
| 455 |
noise_edges = self.noise.sample_count(max(5, int(n_edges_true * 0.3)))
|
| 456 |
+
|
| 457 |
+
true_tfs = list(true_net.keys())
|
| 458 |
+
# Drop ~25% of true regulators (false-negative rate).
|
| 459 |
+
fn_set = set(self.noise.generate_false_negatives(true_tfs, 0.25))
|
| 460 |
+
observed_tfs = [tf for tf in true_tfs if tf not in fn_set]
|
| 461 |
+
# Inject realistic false-positive TFs drawn from a background pool so
|
| 462 |
+
# the agent cannot distinguish true from false hits by name format.
|
| 463 |
+
fp_candidates = [t for t in _NOISE_TFS if t not in set(true_tfs)]
|
| 464 |
+
n_fp = self.noise.sample_count(max(2, int(len(true_tfs) * 0.5) + 2))
|
| 465 |
+
if fp_candidates and n_fp > 0:
|
| 466 |
+
chosen = self.noise.rng.choice(
|
| 467 |
+
fp_candidates,
|
| 468 |
+
size=min(n_fp, len(fp_candidates)),
|
| 469 |
+
replace=False,
|
| 470 |
+
)
|
| 471 |
+
observed_tfs.extend(chosen.tolist())
|
| 472 |
+
# Shuffle so rank order does not reveal true-vs-false identity.
|
| 473 |
+
observed_tfs = self.noise.shuffle_ranking(observed_tfs, 0.5)
|
| 474 |
+
|
| 475 |
return IntermediateOutput(
|
| 476 |
output_type=OutputType.NETWORK_RESULT,
|
| 477 |
step_index=idx,
|
|
|
|
| 481 |
"method": action.method or "SCENIC",
|
| 482 |
"n_regulons": len(true_net) + self.noise.sample_count(3),
|
| 483 |
"n_edges": n_edges_true + noise_edges,
|
| 484 |
+
"top_regulators": observed_tfs[:10],
|
| 485 |
},
|
| 486 |
uncertainty=0.35,
|
| 487 |
artifacts_available=["regulon_table", "grn_adjacency"],
|
|
|
|
| 526 |
"marker": marker,
|
| 527 |
"validated": validated,
|
| 528 |
"assay": action.method or "qPCR",
|
| 529 |
+
# Means are kept close (0.85 vs 0.45) with a wide SD (0.4)
|
| 530 |
+
# so the effect size is correlated with, but not a near-perfect
|
| 531 |
+
# oracle for, true marker membership.
|
| 532 |
"effect_size": self.noise.sample_qc_metric(
|
| 533 |
+
0.85 if is_true else 0.45, 0.4, -0.5, 5.0
|
| 534 |
),
|
| 535 |
},
|
| 536 |
artifacts_available=["validation_data"],
|
|
|
|
| 539 |
def _design_followup(
|
| 540 |
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 541 |
) -> IntermediateOutput:
|
| 542 |
+
evidence_signals = sum([
|
| 543 |
+
int(s.progress.cells_clustered),
|
| 544 |
+
int(s.progress.de_performed),
|
| 545 |
+
int(s.progress.trajectories_inferred),
|
| 546 |
+
int(s.progress.pathways_analyzed),
|
| 547 |
+
int(s.progress.networks_inferred),
|
| 548 |
+
int(s.progress.markers_discovered),
|
| 549 |
+
int(s.progress.markers_validated),
|
| 550 |
+
])
|
| 551 |
return IntermediateOutput(
|
| 552 |
output_type=OutputType.FOLLOWUP_DESIGN,
|
| 553 |
step_index=idx,
|
| 554 |
+
quality_score=min(0.75, 0.2 + 0.08 * evidence_signals),
|
| 555 |
+
summary=(
|
| 556 |
+
f"Follow-up experiment design proposed "
|
| 557 |
+
f"(evidence_signals={evidence_signals})"
|
| 558 |
+
),
|
| 559 |
+
data={
|
| 560 |
+
"proposal": action.parameters,
|
| 561 |
+
"evidence_signals": evidence_signals,
|
| 562 |
+
},
|
| 563 |
+
uncertainty=max(0.25, 0.8 - 0.08 * evidence_signals),
|
| 564 |
artifacts_available=["followup_proposal"],
|
| 565 |
)
|
| 566 |
|
| 567 |
def _subagent_review(
|
| 568 |
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 569 |
) -> IntermediateOutput:
|
| 570 |
+
evidence_signals = sum([
|
| 571 |
+
int(s.progress.cells_clustered),
|
| 572 |
+
int(s.progress.de_performed),
|
| 573 |
+
int(s.progress.trajectories_inferred),
|
| 574 |
+
int(s.progress.pathways_analyzed),
|
| 575 |
+
int(s.progress.networks_inferred),
|
| 576 |
+
int(s.progress.markers_discovered),
|
| 577 |
+
int(s.progress.markers_validated),
|
| 578 |
+
])
|
| 579 |
return IntermediateOutput(
|
| 580 |
output_type=OutputType.SUBAGENT_REPORT,
|
| 581 |
step_index=idx,
|
| 582 |
+
quality_score=min(0.7, 0.15 + 0.07 * evidence_signals),
|
| 583 |
summary=f"Subagent review ({action.invoked_subagent or 'general'})",
|
| 584 |
+
data={
|
| 585 |
+
"subagent": action.invoked_subagent,
|
| 586 |
+
"notes": "Review complete.",
|
| 587 |
+
"evidence_signals": evidence_signals,
|
| 588 |
+
},
|
| 589 |
+
uncertainty=max(0.3, 0.85 - 0.08 * evidence_signals),
|
| 590 |
artifacts_available=["subagent_report"],
|
| 591 |
)
|
| 592 |
|
|
|
|
| 623 |
sizes[0] += diff
|
| 624 |
return sizes
|
| 625 |
|
| 626 |
+
def _partition_by_population(
|
| 627 |
+
self,
|
| 628 |
+
total: int,
|
| 629 |
+
k: int,
|
| 630 |
+
populations: list,
|
| 631 |
+
) -> List[int]:
|
| 632 |
+
"""Partition cells into k clusters using true population proportions
|
| 633 |
+
as Dirichlet concentration parameters, so majority cell types produce
|
| 634 |
+
larger clusters rather than uniformly random sizes."""
|
| 635 |
+
if k <= 0:
|
| 636 |
+
return []
|
| 637 |
+
if populations:
|
| 638 |
+
# Use true proportions as Dirichlet alpha — larger proportions
|
| 639 |
+
# concentrate probability mass, yielding realistic size ratios.
|
| 640 |
+
raw = [max(p.proportion, 1e-3) for p in populations]
|
| 641 |
+
# Align alpha length to k: repeat/truncate as needed.
|
| 642 |
+
if len(raw) >= k:
|
| 643 |
+
alpha = raw[:k]
|
| 644 |
+
else:
|
| 645 |
+
alpha = raw + [sum(raw) / len(raw)] * (k - len(raw))
|
| 646 |
+
# Scale alpha so the total magnitude is proportional to k,
|
| 647 |
+
# giving reasonable Dirichlet variance.
|
| 648 |
+
scale = k / max(sum(alpha), 1e-6)
|
| 649 |
+
alpha = [a * scale for a in alpha]
|
| 650 |
+
else:
|
| 651 |
+
alpha = [1.0] * k
|
| 652 |
+
fracs = self.noise.rng.dirichlet(alpha=alpha)
|
| 653 |
+
sizes = [max(1, int(total * f)) for f in fracs]
|
| 654 |
+
diff = total - sum(sizes)
|
| 655 |
+
sizes[0] += diff
|
| 656 |
+
return sizes
|
| 657 |
+
|
| 658 |
|
| 659 |
_HANDLERS = {
|
| 660 |
ActionType.COLLECT_SAMPLE: OutputGenerator._collect_sample,
|
| 661 |
ActionType.SELECT_COHORT: OutputGenerator._select_cohort,
|
| 662 |
ActionType.PREPARE_LIBRARY: OutputGenerator._prepare_library,
|
| 663 |
ActionType.CULTURE_CELLS: OutputGenerator._culture_cells,
|
| 664 |
+
ActionType.PERTURB_GENE: OutputGenerator._perturb_gene,
|
| 665 |
+
ActionType.PERTURB_COMPOUND: OutputGenerator._perturb_compound,
|
| 666 |
ActionType.SEQUENCE_CELLS: OutputGenerator._sequence_cells,
|
| 667 |
ActionType.RUN_QC: OutputGenerator._run_qc,
|
| 668 |
ActionType.FILTER_DATA: OutputGenerator._filter_data,
|
server/simulator/transition.py
CHANGED
|
@@ -15,6 +15,7 @@ from models import (
|
|
| 15 |
ExperimentAction,
|
| 16 |
IntermediateOutput,
|
| 17 |
OutputType,
|
|
|
|
| 18 |
)
|
| 19 |
|
| 20 |
from .latent_state import FullLatentState
|
|
@@ -22,7 +23,8 @@ from .noise import NoiseModel
|
|
| 22 |
from .output_generator import OutputGenerator
|
| 23 |
|
| 24 |
|
| 25 |
-
|
|
|
|
| 26 |
ActionType.COLLECT_SAMPLE: (5_000, 7.0),
|
| 27 |
ActionType.SELECT_COHORT: ( 500, 1.0),
|
| 28 |
ActionType.PREPARE_LIBRARY: (8_000, 3.0),
|
|
@@ -41,11 +43,30 @@ ACTION_COSTS: Dict[ActionType, Tuple[float, float]] = {
|
|
| 41 |
ActionType.REGULATORY_NETWORK_INFERENCE: ( 300, 1.0),
|
| 42 |
ActionType.MARKER_SELECTION: ( 100, 0.5),
|
| 43 |
ActionType.VALIDATE_MARKER: (5_000, 14.0),
|
| 44 |
-
ActionType.DESIGN_FOLLOWUP: (
|
| 45 |
-
ActionType.REQUEST_SUBAGENT_REVIEW: (
|
| 46 |
ActionType.SYNTHESIZE_CONCLUSION: ( 0, 0.5),
|
| 47 |
}
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
@dataclass
|
| 51 |
class TransitionResult:
|
|
@@ -138,9 +159,7 @@ class TransitionEngine:
|
|
| 138 |
def _apply_resource_cost(
|
| 139 |
self, s: FullLatentState, action: ExperimentAction
|
| 140 |
) -> None:
|
| 141 |
-
budget_cost, time_cost =
|
| 142 |
-
action.action_type, (0.0, 0.0)
|
| 143 |
-
)
|
| 144 |
s.resources.budget_used += budget_cost
|
| 145 |
s.resources.time_used_days += time_cost
|
| 146 |
if action.action_type in {
|
|
@@ -176,6 +195,8 @@ class TransitionEngine:
|
|
| 176 |
ActionType.REGULATORY_NETWORK_INFERENCE: "networks_inferred",
|
| 177 |
ActionType.MARKER_SELECTION: "markers_discovered",
|
| 178 |
ActionType.VALIDATE_MARKER: "markers_validated",
|
|
|
|
|
|
|
| 179 |
ActionType.SYNTHESIZE_CONCLUSION: "conclusion_reached",
|
| 180 |
}
|
| 181 |
flag = _MAP.get(at)
|
|
@@ -188,16 +209,43 @@ class TransitionEngine:
|
|
| 188 |
|
| 189 |
if at == ActionType.SEQUENCE_CELLS:
|
| 190 |
s.resources.sequencing_lanes_used += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
|
| 192 |
if at == ActionType.FILTER_DATA:
|
| 193 |
retain = self.noise.sample_qc_metric(0.85, 0.05, 0.5, 1.0)
|
| 194 |
-
p.
|
| 195 |
-
|
| 196 |
-
|
| 197 |
|
| 198 |
if at == ActionType.CLUSTER_CELLS:
|
| 199 |
n_true = len(s.biology.cell_populations) or 5
|
| 200 |
p.n_clusters_found = self.noise.sample_cluster_count(n_true, 0.8)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
|
| 202 |
def _propagate_artifacts(
|
| 203 |
self,
|
|
@@ -208,6 +256,7 @@ class TransitionEngine:
|
|
| 208 |
if action.action_type == ActionType.DIFFERENTIAL_EXPRESSION:
|
| 209 |
top = output.data.get("top_genes", [])
|
| 210 |
s.discovered_de_genes = [g["gene"] for g in top[:20]]
|
|
|
|
| 211 |
|
| 212 |
if action.action_type == ActionType.CLUSTER_CELLS:
|
| 213 |
s.discovered_clusters = output.data.get("cluster_names", [])
|
|
|
|
| 15 |
ExperimentAction,
|
| 16 |
IntermediateOutput,
|
| 17 |
OutputType,
|
| 18 |
+
TOOL_REGISTRY,
|
| 19 |
)
|
| 20 |
|
| 21 |
from .latent_state import FullLatentState
|
|
|
|
| 23 |
from .output_generator import OutputGenerator
|
| 24 |
|
| 25 |
|
| 26 |
+
# Fallback costs per ActionType when the agent doesn't specify a known tool.
|
| 27 |
+
_BASE_ACTION_COSTS: Dict[ActionType, Tuple[float, float]] = {
|
| 28 |
ActionType.COLLECT_SAMPLE: (5_000, 7.0),
|
| 29 |
ActionType.SELECT_COHORT: ( 500, 1.0),
|
| 30 |
ActionType.PREPARE_LIBRARY: (8_000, 3.0),
|
|
|
|
| 43 |
ActionType.REGULATORY_NETWORK_INFERENCE: ( 300, 1.0),
|
| 44 |
ActionType.MARKER_SELECTION: ( 100, 0.5),
|
| 45 |
ActionType.VALIDATE_MARKER: (5_000, 14.0),
|
| 46 |
+
ActionType.DESIGN_FOLLOWUP: ( 100, 0.5),
|
| 47 |
+
ActionType.REQUEST_SUBAGENT_REVIEW: ( 50, 0.25),
|
| 48 |
ActionType.SYNTHESIZE_CONCLUSION: ( 0, 0.5),
|
| 49 |
}
|
| 50 |
|
| 51 |
+
# Kept as public alias so existing imports (e.g. hackathon_environment) still work.
|
| 52 |
+
ACTION_COSTS = _BASE_ACTION_COSTS
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def compute_action_cost(action: ExperimentAction) -> Tuple[float, float]:
|
| 56 |
+
"""Return (budget_cost, time_cost_days) for an action.
|
| 57 |
+
|
| 58 |
+
If the action specifies a ``method`` that exists in ``TOOL_REGISTRY``,
|
| 59 |
+
the tool's ``typical_cost_usd`` and ``typical_runtime_hours`` are used
|
| 60 |
+
(converted to days). Otherwise we fall back to the per-ActionType base
|
| 61 |
+
cost table.
|
| 62 |
+
"""
|
| 63 |
+
tool_spec = TOOL_REGISTRY.get(action.method or "")
|
| 64 |
+
if tool_spec is not None:
|
| 65 |
+
budget = tool_spec.typical_cost_usd
|
| 66 |
+
time_days = tool_spec.typical_runtime_hours / 24.0
|
| 67 |
+
return (budget, time_days)
|
| 68 |
+
return _BASE_ACTION_COSTS.get(action.action_type, (0.0, 0.0))
|
| 69 |
+
|
| 70 |
|
| 71 |
@dataclass
|
| 72 |
class TransitionResult:
|
|
|
|
| 159 |
def _apply_resource_cost(
|
| 160 |
self, s: FullLatentState, action: ExperimentAction
|
| 161 |
) -> None:
|
| 162 |
+
budget_cost, time_cost = compute_action_cost(action)
|
|
|
|
|
|
|
| 163 |
s.resources.budget_used += budget_cost
|
| 164 |
s.resources.time_used_days += time_cost
|
| 165 |
if action.action_type in {
|
|
|
|
| 195 |
ActionType.REGULATORY_NETWORK_INFERENCE: "networks_inferred",
|
| 196 |
ActionType.MARKER_SELECTION: "markers_discovered",
|
| 197 |
ActionType.VALIDATE_MARKER: "markers_validated",
|
| 198 |
+
ActionType.DESIGN_FOLLOWUP: "followup_designed",
|
| 199 |
+
ActionType.REQUEST_SUBAGENT_REVIEW: "subagent_review_requested",
|
| 200 |
ActionType.SYNTHESIZE_CONCLUSION: "conclusion_reached",
|
| 201 |
}
|
| 202 |
flag = _MAP.get(at)
|
|
|
|
| 209 |
|
| 210 |
if at == ActionType.SEQUENCE_CELLS:
|
| 211 |
s.resources.sequencing_lanes_used += 1
|
| 212 |
+
p.n_cells_sequenced = self.noise.sample_count(
|
| 213 |
+
s.biology.n_true_cells * s.technical.capture_efficiency
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
if at in {ActionType.PERTURB_GENE, ActionType.PERTURB_COMPOUND}:
|
| 217 |
+
self._apply_perturbation_effects(s, action)
|
| 218 |
|
| 219 |
if at == ActionType.FILTER_DATA:
|
| 220 |
retain = self.noise.sample_qc_metric(0.85, 0.05, 0.5, 1.0)
|
| 221 |
+
base = p.n_cells_sequenced or s.biology.n_true_cells
|
| 222 |
+
p.n_cells_after_filter = max(100, int(base * retain))
|
| 223 |
+
s.last_retain_frac = retain
|
| 224 |
|
| 225 |
if at == ActionType.CLUSTER_CELLS:
|
| 226 |
n_true = len(s.biology.cell_populations) or 5
|
| 227 |
p.n_clusters_found = self.noise.sample_cluster_count(n_true, 0.8)
|
| 228 |
+
s.last_n_clusters = p.n_clusters_found
|
| 229 |
+
|
| 230 |
+
def _apply_perturbation_effects(
|
| 231 |
+
self, s: FullLatentState, action: ExperimentAction
|
| 232 |
+
) -> None:
|
| 233 |
+
"""Fold perturbation-specific gene effects into true_de_genes so
|
| 234 |
+
downstream DE analysis reflects the perturbed biology."""
|
| 235 |
+
target = action.parameters.get("target", "")
|
| 236 |
+
effects = s.biology.perturbation_effects.get(target, {})
|
| 237 |
+
if not effects:
|
| 238 |
+
return
|
| 239 |
+
# Efficiency drawn from the same distribution as the output handler
|
| 240 |
+
# so latent state and observable output are coherent.
|
| 241 |
+
if action.action_type == ActionType.PERTURB_GENE:
|
| 242 |
+
efficiency = self.noise.sample_qc_metric(0.80, 0.12, 0.0, 1.0)
|
| 243 |
+
else:
|
| 244 |
+
efficiency = self.noise.sample_qc_metric(0.70, 0.15, 0.0, 1.0)
|
| 245 |
+
s.last_perturbation_efficiency = efficiency
|
| 246 |
+
for gene_map in s.biology.true_de_genes.values():
|
| 247 |
+
for gene, delta in effects.items():
|
| 248 |
+
gene_map[gene] = gene_map.get(gene, 0.0) + delta * efficiency
|
| 249 |
|
| 250 |
def _propagate_artifacts(
|
| 251 |
self,
|
|
|
|
| 256 |
if action.action_type == ActionType.DIFFERENTIAL_EXPRESSION:
|
| 257 |
top = output.data.get("top_genes", [])
|
| 258 |
s.discovered_de_genes = [g["gene"] for g in top[:20]]
|
| 259 |
+
s.progress.n_de_genes_found = output.data.get("n_significant", 0)
|
| 260 |
|
| 261 |
if action.action_type == ActionType.CLUSTER_CELLS:
|
| 262 |
s.discovered_clusters = output.data.get("cluster_names", [])
|
server/tasks/bio_palette.py
ADDED
|
@@ -0,0 +1,692 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Curated biological building blocks for procedural scenario generation.
|
| 2 |
+
|
| 3 |
+
Provides tissue-specific cell types, disease profiles, pathway libraries,
|
| 4 |
+
regulatory network templates, and perturbation effect profiles. The
|
| 5 |
+
procedural generator composes these into complete ``Scenario`` objects
|
| 6 |
+
with fully populated ``LatentBiologicalState``.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
from dataclasses import dataclass, field
|
| 12 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# ── Cell type templates ─────────────────────────────────────────────────────
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
|
| 19 |
+
class CellTypeTemplate:
|
| 20 |
+
name: str
|
| 21 |
+
marker_genes: List[str]
|
| 22 |
+
proportion_range: Tuple[float, float] = (0.05, 0.30)
|
| 23 |
+
states: List[str] = field(default_factory=lambda: ["quiescent"])
|
| 24 |
+
disease_responsive: bool = False
|
| 25 |
+
response_range: Tuple[float, float] = (0.5, 1.5)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
TISSUE_CELL_TYPES: Dict[str, List[CellTypeTemplate]] = {
|
| 29 |
+
"heart": [
|
| 30 |
+
CellTypeTemplate("cardiomyocyte", ["TNNT2", "MYH7", "ACTC1"], (0.25, 0.40), ["contractile", "stressed"]),
|
| 31 |
+
CellTypeTemplate("cardiac_fibroblast", ["COL1A1", "DCN", "LUM"], (0.15, 0.30), ["quiescent", "activated"], True, (1.1, 1.8)),
|
| 32 |
+
CellTypeTemplate("endothelial", ["PECAM1", "VWF", "CDH5"], (0.10, 0.20), ["quiescent"]),
|
| 33 |
+
CellTypeTemplate("macrophage", ["CD68", "CD163", "CSF1R"], (0.05, 0.15), ["quiescent", "activated", "inflammatory"], True, (1.2, 2.0)),
|
| 34 |
+
CellTypeTemplate("smooth_muscle", ["ACTA2", "MYH11", "TAGLN"], (0.08, 0.18), ["quiescent"]),
|
| 35 |
+
CellTypeTemplate("pericyte", ["PDGFRB", "RGS5", "NOTCH3"], (0.03, 0.10), ["quiescent"]),
|
| 36 |
+
],
|
| 37 |
+
"lung": [
|
| 38 |
+
CellTypeTemplate("AT2", ["SFTPC", "SFTPB", "ABCA3"], (0.15, 0.25), ["normal", "stressed"]),
|
| 39 |
+
CellTypeTemplate("AT1", ["AGER", "PDPN", "CAV1"], (0.10, 0.18), ["normal"]),
|
| 40 |
+
CellTypeTemplate("alveolar_macrophage", ["MARCO", "FABP4", "MCEMP1"], (0.10, 0.20), ["resident", "activated"]),
|
| 41 |
+
CellTypeTemplate("fibroblast", ["COL1A1", "COL3A1", "POSTN"], (0.12, 0.25), ["quiescent", "activated"], True, (1.2, 2.0)),
|
| 42 |
+
CellTypeTemplate("endothelial", ["PECAM1", "CLDN5", "VWF"], (0.08, 0.15), ["quiescent"]),
|
| 43 |
+
CellTypeTemplate("T_cell", ["CD3D", "CD3E", "IL7R"], (0.08, 0.18), ["quiescent", "activated"]),
|
| 44 |
+
CellTypeTemplate("ciliated", ["FOXJ1", "DNAH5", "TPPP3"], (0.05, 0.12), ["normal"]),
|
| 45 |
+
],
|
| 46 |
+
"brain": [
|
| 47 |
+
CellTypeTemplate("excitatory_neuron", ["SLC17A7", "CAMK2A", "NRGN"], (0.25, 0.40), ["normal", "stressed"]),
|
| 48 |
+
CellTypeTemplate("inhibitory_neuron", ["GAD1", "GAD2", "SLC32A1"], (0.12, 0.22), ["normal"]),
|
| 49 |
+
CellTypeTemplate("astrocyte", ["GFAP", "AQP4", "SLC1A3"], (0.10, 0.20), ["quiescent", "activated"], True, (1.2, 1.8)),
|
| 50 |
+
CellTypeTemplate("microglia", ["CX3CR1", "P2RY12", "TMEM119"], (0.05, 0.12), ["homeostatic", "activated", "inflammatory"], True, (1.3, 2.5)),
|
| 51 |
+
CellTypeTemplate("oligodendrocyte", ["MBP", "PLP1", "MOG"], (0.10, 0.18), ["myelinating"]),
|
| 52 |
+
CellTypeTemplate("OPC", ["PDGFRA", "CSPG4", "OLIG2"], (0.03, 0.08), ["progenitor"]),
|
| 53 |
+
CellTypeTemplate("endothelial", ["CLDN5", "FLT1", "PECAM1"], (0.03, 0.08), ["quiescent"]),
|
| 54 |
+
],
|
| 55 |
+
"liver": [
|
| 56 |
+
CellTypeTemplate("hepatocyte", ["ALB", "APOB", "CYP3A4"], (0.55, 0.70), ["normal", "stressed"]),
|
| 57 |
+
CellTypeTemplate("cholangiocyte", ["KRT19", "KRT7", "EPCAM"], (0.05, 0.10), ["normal"]),
|
| 58 |
+
CellTypeTemplate("kupffer_cell", ["CD68", "CLEC4F", "MARCO"], (0.08, 0.15), ["quiescent", "activated", "inflammatory"], True, (1.2, 2.0)),
|
| 59 |
+
CellTypeTemplate("stellate_cell", ["ACTA2", "LRAT", "PDGFRB"], (0.05, 0.12), ["quiescent", "activated"], True, (1.3, 2.0)),
|
| 60 |
+
CellTypeTemplate("endothelial", ["PECAM1", "LYVE1", "STAB2"], (0.05, 0.10), ["quiescent"]),
|
| 61 |
+
CellTypeTemplate("NK_cell", ["NKG7", "GNLY", "KLRD1"], (0.03, 0.08), ["quiescent", "activated"]),
|
| 62 |
+
],
|
| 63 |
+
"bone_marrow": [
|
| 64 |
+
CellTypeTemplate("HSC", ["CD34", "KIT", "THY1"], (0.03, 0.08), ["stem"]),
|
| 65 |
+
CellTypeTemplate("CMP", ["CD34", "FLT3"], (0.08, 0.15), ["progenitor"]),
|
| 66 |
+
CellTypeTemplate("GMP", ["CSF3R", "CEBPA"], (0.08, 0.15), ["progenitor"]),
|
| 67 |
+
CellTypeTemplate("MEP", ["GATA1", "KLF1"], (0.06, 0.12), ["progenitor"]),
|
| 68 |
+
CellTypeTemplate("erythrocyte", ["HBA1", "HBB", "GYPA"], (0.15, 0.25), ["mature"]),
|
| 69 |
+
CellTypeTemplate("neutrophil", ["ELANE", "MPO", "CTSG"], (0.12, 0.22), ["mature"]),
|
| 70 |
+
CellTypeTemplate("monocyte", ["CD14", "CSF1R", "FCGR3A"], (0.10, 0.18), ["mature"]),
|
| 71 |
+
CellTypeTemplate("megakaryocyte", ["ITGA2B", "GP1BA", "PF4"], (0.05, 0.12), ["mature"]),
|
| 72 |
+
],
|
| 73 |
+
"kidney": [
|
| 74 |
+
CellTypeTemplate("proximal_tubule", ["SLC34A1", "LRP2", "CUBN"], (0.30, 0.45), ["normal", "stressed"]),
|
| 75 |
+
CellTypeTemplate("distal_tubule", ["SLC12A3", "CALB1"], (0.10, 0.18), ["normal"]),
|
| 76 |
+
CellTypeTemplate("collecting_duct", ["AQP2", "FXYD4"], (0.08, 0.15), ["normal"]),
|
| 77 |
+
CellTypeTemplate("podocyte", ["NPHS1", "NPHS2", "WT1"], (0.05, 0.10), ["normal", "stressed"]),
|
| 78 |
+
CellTypeTemplate("endothelial", ["PECAM1", "EMCN", "FLT1"], (0.05, 0.12), ["quiescent"]),
|
| 79 |
+
CellTypeTemplate("macrophage", ["CD68", "CD163", "CSF1R"], (0.05, 0.10), ["quiescent", "inflammatory"], True, (1.3, 2.0)),
|
| 80 |
+
CellTypeTemplate("fibroblast", ["COL1A1", "PDGFRA", "DCN"], (0.05, 0.12), ["quiescent", "activated"], True, (1.2, 1.8)),
|
| 81 |
+
],
|
| 82 |
+
"colon": [
|
| 83 |
+
CellTypeTemplate("colonocyte", ["CA2", "AQP8", "SLC26A3"], (0.25, 0.40), ["normal", "stressed"]),
|
| 84 |
+
CellTypeTemplate("goblet_cell", ["MUC2", "TFF3", "FCGBP"], (0.10, 0.18), ["secretory"]),
|
| 85 |
+
CellTypeTemplate("stem_cell", ["LGR5", "ASCL2", "OLFM4"], (0.05, 0.10), ["stem"]),
|
| 86 |
+
CellTypeTemplate("T_cell", ["CD3D", "CD3E", "IL7R"], (0.10, 0.18), ["quiescent", "activated"]),
|
| 87 |
+
CellTypeTemplate("macrophage", ["CD68", "CD163", "CSF1R"], (0.05, 0.12), ["quiescent", "inflammatory"], True, (1.3, 2.0)),
|
| 88 |
+
CellTypeTemplate("fibroblast", ["COL1A1", "COL3A1", "VIM"], (0.08, 0.15), ["quiescent", "activated"], True, (1.2, 1.8)),
|
| 89 |
+
CellTypeTemplate("endothelial", ["PECAM1", "VWF", "CDH5"], (0.05, 0.10), ["quiescent"]),
|
| 90 |
+
],
|
| 91 |
+
"pancreas": [
|
| 92 |
+
CellTypeTemplate("beta_cell", ["INS", "MAFA", "NKX6-1"], (0.25, 0.40), ["normal", "stressed"], True, (0.4, 0.8)),
|
| 93 |
+
CellTypeTemplate("alpha_cell", ["GCG", "ARX", "IRX2"], (0.15, 0.25), ["normal"]),
|
| 94 |
+
CellTypeTemplate("delta_cell", ["SST", "HHEX"], (0.05, 0.10), ["normal"]),
|
| 95 |
+
CellTypeTemplate("ductal", ["KRT19", "SOX9", "CFTR"], (0.10, 0.18), ["normal"]),
|
| 96 |
+
CellTypeTemplate("acinar", ["PRSS1", "CPA1", "CELA3A"], (0.10, 0.20), ["normal"]),
|
| 97 |
+
CellTypeTemplate("stellate", ["ACTA2", "PDGFRA", "COL1A1"], (0.05, 0.10), ["quiescent", "activated"], True, (1.2, 1.8)),
|
| 98 |
+
CellTypeTemplate("macrophage", ["CD68", "CD163"], (0.03, 0.08), ["quiescent", "inflammatory"]),
|
| 99 |
+
],
|
| 100 |
+
"skin": [
|
| 101 |
+
CellTypeTemplate("keratinocyte", ["KRT14", "KRT5", "KRT1"], (0.40, 0.55), ["basal", "differentiated"]),
|
| 102 |
+
CellTypeTemplate("melanocyte", ["MLANA", "PMEL", "TYR"], (0.05, 0.10), ["normal", "activated"]),
|
| 103 |
+
CellTypeTemplate("fibroblast", ["COL1A1", "COL3A1", "DCN"], (0.10, 0.20), ["quiescent", "activated"]),
|
| 104 |
+
CellTypeTemplate("T_cell", ["CD3D", "CD3E", "IL7R"], (0.08, 0.15), ["quiescent", "activated"]),
|
| 105 |
+
CellTypeTemplate("macrophage", ["CD68", "CD163", "CSF1R"], (0.05, 0.10), ["quiescent", "inflammatory"]),
|
| 106 |
+
CellTypeTemplate("endothelial", ["PECAM1", "VWF"], (0.05, 0.10), ["quiescent"]),
|
| 107 |
+
],
|
| 108 |
+
"breast": [
|
| 109 |
+
CellTypeTemplate("luminal_epithelial", ["KRT8", "KRT18", "EPCAM"], (0.25, 0.40), ["normal", "stressed"]),
|
| 110 |
+
CellTypeTemplate("basal_epithelial", ["KRT14", "KRT5", "TP63"], (0.10, 0.20), ["normal"]),
|
| 111 |
+
CellTypeTemplate("fibroblast", ["COL1A1", "COL3A1", "FAP"], (0.10, 0.20), ["quiescent", "activated"], True, (1.2, 1.8)),
|
| 112 |
+
CellTypeTemplate("T_cell", ["CD3D", "CD3E", "CD8A"], (0.08, 0.15), ["quiescent", "activated", "exhausted"]),
|
| 113 |
+
CellTypeTemplate("macrophage", ["CD68", "CD163", "CSF1R"], (0.05, 0.12), ["quiescent", "inflammatory"], True, (1.3, 2.0)),
|
| 114 |
+
CellTypeTemplate("endothelial", ["PECAM1", "VWF", "CDH5"], (0.05, 0.10), ["quiescent"]),
|
| 115 |
+
],
|
| 116 |
+
"synovium": [
|
| 117 |
+
CellTypeTemplate("fibroblast", ["COL1A1", "FAP", "THY1"], (0.20, 0.30), ["quiescent", "activated"], True, (1.2, 1.8)),
|
| 118 |
+
CellTypeTemplate("CD4_T_cell", ["CD3D", "CD4", "IL7R"], (0.12, 0.22), ["quiescent", "activated"]),
|
| 119 |
+
CellTypeTemplate("CD8_T_cell", ["CD3D", "CD8A", "GZMB"], (0.08, 0.15), ["quiescent", "activated"]),
|
| 120 |
+
CellTypeTemplate("macrophage", ["CD68", "CD163", "MARCO"], (0.10, 0.18), ["quiescent", "inflammatory"], True, (1.3, 2.0)),
|
| 121 |
+
CellTypeTemplate("B_cell", ["CD19", "MS4A1", "CD79A"], (0.05, 0.12), ["quiescent"]),
|
| 122 |
+
CellTypeTemplate("endothelial", ["PECAM1", "VWF"], (0.05, 0.10), ["quiescent"]),
|
| 123 |
+
CellTypeTemplate("mast_cell", ["KIT", "TPSAB1", "CPA3"], (0.03, 0.08), ["quiescent"]),
|
| 124 |
+
],
|
| 125 |
+
"aorta": [
|
| 126 |
+
CellTypeTemplate("smooth_muscle", ["ACTA2", "MYH11", "TAGLN"], (0.30, 0.45), ["contractile", "synthetic"], True, (0.6, 0.9)),
|
| 127 |
+
CellTypeTemplate("endothelial", ["PECAM1", "VWF", "CDH5"], (0.15, 0.25), ["quiescent", "activated"]),
|
| 128 |
+
CellTypeTemplate("macrophage", ["CD68", "CD163", "TREM2"], (0.08, 0.15), ["quiescent", "inflammatory"], True, (1.5, 2.5)),
|
| 129 |
+
CellTypeTemplate("fibroblast", ["COL1A1", "LUM", "DCN"], (0.08, 0.15), ["quiescent", "activated"]),
|
| 130 |
+
CellTypeTemplate("T_cell", ["CD3D", "CD3E", "IL7R"], (0.05, 0.12), ["quiescent", "activated"]),
|
| 131 |
+
CellTypeTemplate("dendritic_cell", ["FCER1A", "CD1C", "CLEC10A"], (0.03, 0.08), ["quiescent"]),
|
| 132 |
+
],
|
| 133 |
+
"blood": [
|
| 134 |
+
CellTypeTemplate("CD4_T_cell", ["CD3D", "CD4", "IL7R"], (0.15, 0.25), ["quiescent", "activated"]),
|
| 135 |
+
CellTypeTemplate("CD8_T_cell", ["CD3D", "CD8A", "GZMB"], (0.10, 0.18), ["quiescent", "activated"]),
|
| 136 |
+
CellTypeTemplate("B_cell", ["CD19", "MS4A1", "CD79A"], (0.08, 0.15), ["quiescent"]),
|
| 137 |
+
CellTypeTemplate("NK_cell", ["NKG7", "GNLY", "KLRD1"], (0.05, 0.12), ["quiescent", "activated"]),
|
| 138 |
+
CellTypeTemplate("monocyte", ["CD14", "CSF1R", "FCGR3A"], (0.15, 0.25), ["classical", "non_classical"]),
|
| 139 |
+
CellTypeTemplate("neutrophil", ["ELANE", "MPO", "CTSG"], (0.10, 0.20), ["mature"]),
|
| 140 |
+
CellTypeTemplate("platelet", ["ITGA2B", "GP1BA", "PF4"], (0.03, 0.08), ["normal"]),
|
| 141 |
+
],
|
| 142 |
+
"spleen": [
|
| 143 |
+
CellTypeTemplate("B_cell", ["CD19", "MS4A1", "CD79A"], (0.20, 0.35), ["quiescent", "activated"]),
|
| 144 |
+
CellTypeTemplate("T_cell", ["CD3D", "CD3E", "IL7R"], (0.15, 0.25), ["quiescent", "activated"]),
|
| 145 |
+
CellTypeTemplate("macrophage", ["CD68", "CD163", "CSF1R"], (0.10, 0.18), ["quiescent", "inflammatory"]),
|
| 146 |
+
CellTypeTemplate("dendritic_cell", ["FCER1A", "CD1C", "CLEC10A"], (0.05, 0.10), ["quiescent"]),
|
| 147 |
+
CellTypeTemplate("NK_cell", ["NKG7", "GNLY", "KLRD1"], (0.05, 0.12), ["quiescent"]),
|
| 148 |
+
CellTypeTemplate("endothelial", ["PECAM1", "STAB2"], (0.05, 0.10), ["quiescent"]),
|
| 149 |
+
],
|
| 150 |
+
"thymus": [
|
| 151 |
+
CellTypeTemplate("double_negative_T", ["CD3D", "PTCRA"], (0.10, 0.18), ["progenitor"]),
|
| 152 |
+
CellTypeTemplate("double_positive_T", ["CD3D", "CD4", "CD8A"], (0.30, 0.45), ["progenitor"]),
|
| 153 |
+
CellTypeTemplate("CD4_SP", ["CD3D", "CD4", "IL7R"], (0.10, 0.18), ["mature"]),
|
| 154 |
+
CellTypeTemplate("CD8_SP", ["CD3D", "CD8A", "CD8B"], (0.08, 0.15), ["mature"]),
|
| 155 |
+
CellTypeTemplate("thymic_epithelial", ["FOXN1", "KRT5", "KRT8"], (0.05, 0.12), ["cortical", "medullary"]),
|
| 156 |
+
CellTypeTemplate("dendritic_cell", ["FCER1A", "CD1C"], (0.03, 0.08), ["quiescent"]),
|
| 157 |
+
CellTypeTemplate("macrophage", ["CD68", "CD163"], (0.03, 0.08), ["quiescent"]),
|
| 158 |
+
],
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
# ── Disease profiles ────────────────────────────────────────────────────────
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
@dataclass
|
| 166 |
+
class DiseaseProfile:
|
| 167 |
+
name: str
|
| 168 |
+
display_name: str
|
| 169 |
+
tissue: str
|
| 170 |
+
condition_name: str
|
| 171 |
+
de_genes: Dict[str, Tuple[float, float]]
|
| 172 |
+
pathways: Dict[str, Tuple[float, float]]
|
| 173 |
+
markers: List[str]
|
| 174 |
+
mechanism_templates: List[str]
|
| 175 |
+
responding_cell_types: List[str] = field(default_factory=list)
|
| 176 |
+
hidden_failure_templates: List[str] = field(default_factory=list)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
DISEASE_PROFILES: Dict[str, DiseaseProfile] = {
|
| 180 |
+
"dilated_cardiomyopathy": DiseaseProfile(
|
| 181 |
+
name="dilated_cardiomyopathy",
|
| 182 |
+
display_name="dilated cardiomyopathy",
|
| 183 |
+
tissue="heart",
|
| 184 |
+
condition_name="dilated_cardiomyopathy",
|
| 185 |
+
de_genes={
|
| 186 |
+
"NPPA": (1.8, 3.5), "NPPB": (2.0, 4.0), "MYH7": (1.0, 2.5),
|
| 187 |
+
"COL1A1": (1.0, 2.2), "COL3A1": (0.8, 1.8), "POSTN": (1.5, 3.0),
|
| 188 |
+
"CCL2": (0.8, 1.8), "IL6": (0.5, 1.5), "TGFB1": (0.7, 1.6),
|
| 189 |
+
"ANKRD1": (1.5, 3.0), "XIRP2": (-2.0, -0.8), "MYL2": (-1.5, -0.5),
|
| 190 |
+
},
|
| 191 |
+
pathways={
|
| 192 |
+
"cardiac_muscle_contraction": (0.3, 0.6),
|
| 193 |
+
"extracellular_matrix_organisation": (0.7, 0.95),
|
| 194 |
+
"inflammatory_response": (0.5, 0.8),
|
| 195 |
+
"TGF_beta_signalling": (0.6, 0.85),
|
| 196 |
+
"apoptosis": (0.4, 0.65),
|
| 197 |
+
},
|
| 198 |
+
markers=["NPPA", "NPPB", "POSTN", "COL1A1"],
|
| 199 |
+
mechanism_templates=[
|
| 200 |
+
"TGF-beta-driven fibrosis",
|
| 201 |
+
"inflammatory macrophage infiltration",
|
| 202 |
+
],
|
| 203 |
+
responding_cell_types=["cardiac_fibroblast", "macrophage"],
|
| 204 |
+
),
|
| 205 |
+
"IPF": DiseaseProfile(
|
| 206 |
+
name="IPF",
|
| 207 |
+
display_name="idiopathic pulmonary fibrosis",
|
| 208 |
+
tissue="lung",
|
| 209 |
+
condition_name="IPF",
|
| 210 |
+
de_genes={
|
| 211 |
+
"SPP1": (2.0, 4.0), "MERTK": (0.8, 2.0), "MMP9": (1.0, 2.5),
|
| 212 |
+
"TREM2": (0.8, 2.0), "COL1A1": (1.5, 3.0), "COL3A1": (1.2, 2.5),
|
| 213 |
+
"POSTN": (1.5, 3.5), "SFTPC": (-2.0, -0.8), "AGER": (-2.5, -1.0),
|
| 214 |
+
},
|
| 215 |
+
pathways={
|
| 216 |
+
"extracellular_matrix_organisation": (0.75, 0.95),
|
| 217 |
+
"integrin_signalling": (0.6, 0.85),
|
| 218 |
+
"macrophage_activation": (0.65, 0.9),
|
| 219 |
+
"Wnt_signalling": (0.4, 0.7),
|
| 220 |
+
},
|
| 221 |
+
markers=["SPP1", "MERTK", "POSTN", "MMP9"],
|
| 222 |
+
mechanism_templates=[
|
| 223 |
+
"SPP1+ macrophage-driven fibroblast activation",
|
| 224 |
+
"integrin-mediated SPP1 signalling in fibrosis",
|
| 225 |
+
],
|
| 226 |
+
responding_cell_types=["fibroblast", "alveolar_macrophage"],
|
| 227 |
+
),
|
| 228 |
+
"Alzheimer": DiseaseProfile(
|
| 229 |
+
name="Alzheimer",
|
| 230 |
+
display_name="Alzheimer's disease",
|
| 231 |
+
tissue="brain",
|
| 232 |
+
condition_name="Alzheimer",
|
| 233 |
+
de_genes={
|
| 234 |
+
"TREM2": (1.0, 2.5), "APOE": (1.2, 2.8), "CLU": (0.8, 2.0),
|
| 235 |
+
"C1QA": (1.0, 2.2), "C1QB": (0.9, 2.0), "GFAP": (1.5, 3.0),
|
| 236 |
+
"AQP4": (0.6, 1.5), "SLC17A7": (-1.5, -0.5), "NRGN": (-2.0, -0.8),
|
| 237 |
+
"SNAP25": (-1.2, -0.4),
|
| 238 |
+
},
|
| 239 |
+
pathways={
|
| 240 |
+
"complement_cascade": (0.7, 0.9),
|
| 241 |
+
"neuroinflammation": (0.65, 0.85),
|
| 242 |
+
"amyloid_processing": (0.6, 0.8),
|
| 243 |
+
"synaptic_signalling": (0.3, 0.5),
|
| 244 |
+
"lipid_metabolism": (0.5, 0.7),
|
| 245 |
+
},
|
| 246 |
+
markers=["TREM2", "APOE", "GFAP", "C1QA"],
|
| 247 |
+
mechanism_templates=[
|
| 248 |
+
"TREM2-mediated microglial activation in amyloid clearance",
|
| 249 |
+
"complement-driven synaptic pruning",
|
| 250 |
+
"reactive astrogliosis amplifying neuroinflammation",
|
| 251 |
+
],
|
| 252 |
+
responding_cell_types=["microglia", "astrocyte"],
|
| 253 |
+
),
|
| 254 |
+
"colorectal_cancer": DiseaseProfile(
|
| 255 |
+
name="colorectal_cancer",
|
| 256 |
+
display_name="colorectal cancer",
|
| 257 |
+
tissue="colon",
|
| 258 |
+
condition_name="colorectal_cancer",
|
| 259 |
+
de_genes={
|
| 260 |
+
"MYC": (1.5, 3.0), "KRAS": (0.8, 1.8), "TP53": (-1.5, -0.5),
|
| 261 |
+
"APC": (-1.8, -0.8), "CDH1": (-1.2, -0.4), "VIM": (1.0, 2.5),
|
| 262 |
+
"MKI67": (1.5, 3.0), "CD44": (1.0, 2.2), "LGR5": (0.8, 2.0),
|
| 263 |
+
},
|
| 264 |
+
pathways={
|
| 265 |
+
"Wnt_signalling": (0.75, 0.95),
|
| 266 |
+
"cell_cycle": (0.7, 0.9),
|
| 267 |
+
"EMT": (0.6, 0.85),
|
| 268 |
+
"p53_signalling": (0.3, 0.5),
|
| 269 |
+
"MAPK_signalling": (0.55, 0.75),
|
| 270 |
+
},
|
| 271 |
+
markers=["MYC", "CD44", "VIM", "MKI67"],
|
| 272 |
+
mechanism_templates=[
|
| 273 |
+
"Wnt/beta-catenin-driven tumour stem cell expansion",
|
| 274 |
+
"epithelial-mesenchymal transition promoting invasion",
|
| 275 |
+
],
|
| 276 |
+
responding_cell_types=["stem_cell", "macrophage", "fibroblast"],
|
| 277 |
+
),
|
| 278 |
+
"type2_diabetes": DiseaseProfile(
|
| 279 |
+
name="type2_diabetes",
|
| 280 |
+
display_name="type 2 diabetes",
|
| 281 |
+
tissue="pancreas",
|
| 282 |
+
condition_name="type2_diabetes",
|
| 283 |
+
de_genes={
|
| 284 |
+
"INS": (-2.0, -0.8), "MAFA": (-1.5, -0.5), "PDX1": (-1.2, -0.4),
|
| 285 |
+
"UCN3": (-1.8, -0.6), "GCG": (0.8, 2.0), "ARX": (0.5, 1.5),
|
| 286 |
+
"IAPP": (0.6, 1.8), "TXNIP": (1.0, 2.5), "DDIT3": (0.8, 2.0),
|
| 287 |
+
},
|
| 288 |
+
pathways={
|
| 289 |
+
"insulin_signalling": (0.3, 0.5),
|
| 290 |
+
"ER_stress_response": (0.7, 0.9),
|
| 291 |
+
"oxidative_stress": (0.6, 0.8),
|
| 292 |
+
"glucagon_signalling": (0.6, 0.8),
|
| 293 |
+
"apoptosis": (0.5, 0.7),
|
| 294 |
+
},
|
| 295 |
+
markers=["INS", "TXNIP", "IAPP", "DDIT3"],
|
| 296 |
+
mechanism_templates=[
|
| 297 |
+
"ER stress-induced beta cell apoptosis",
|
| 298 |
+
"glucotoxicity-driven beta cell dedifferentiation",
|
| 299 |
+
],
|
| 300 |
+
responding_cell_types=["beta_cell", "stellate"],
|
| 301 |
+
),
|
| 302 |
+
"AML": DiseaseProfile(
|
| 303 |
+
name="AML",
|
| 304 |
+
display_name="acute myeloid leukemia",
|
| 305 |
+
tissue="bone_marrow",
|
| 306 |
+
condition_name="AML",
|
| 307 |
+
de_genes={
|
| 308 |
+
"FLT3": (1.5, 3.0), "NPM1": (0.8, 2.0), "IDH2": (0.6, 1.5),
|
| 309 |
+
"RUNX1": (-1.5, -0.5), "CEBPA": (-1.2, -0.4), "KIT": (1.0, 2.5),
|
| 310 |
+
"WT1": (1.2, 2.8), "MYC": (0.8, 2.0),
|
| 311 |
+
},
|
| 312 |
+
pathways={
|
| 313 |
+
"hematopoietic_cell_lineage": (0.3, 0.5),
|
| 314 |
+
"MAPK_signalling": (0.7, 0.9),
|
| 315 |
+
"PI3K_AKT_signalling": (0.65, 0.85),
|
| 316 |
+
"cell_cycle": (0.7, 0.9),
|
| 317 |
+
"apoptosis": (0.3, 0.5),
|
| 318 |
+
},
|
| 319 |
+
markers=["FLT3", "NPM1", "WT1", "KIT"],
|
| 320 |
+
mechanism_templates=[
|
| 321 |
+
"FLT3-ITD-driven proliferative advantage",
|
| 322 |
+
"myeloid differentiation block via RUNX1 loss",
|
| 323 |
+
],
|
| 324 |
+
responding_cell_types=["HSC", "GMP"],
|
| 325 |
+
),
|
| 326 |
+
"rheumatoid_arthritis": DiseaseProfile(
|
| 327 |
+
name="rheumatoid_arthritis",
|
| 328 |
+
display_name="rheumatoid arthritis",
|
| 329 |
+
tissue="synovium",
|
| 330 |
+
condition_name="rheumatoid_arthritis",
|
| 331 |
+
de_genes={
|
| 332 |
+
"IFNG": (1.0, 2.5), "TBX21": (0.8, 1.8), "IL17A": (1.0, 2.2),
|
| 333 |
+
"RORC": (0.6, 1.5), "TNF": (1.2, 2.5), "IL6": (1.0, 2.2),
|
| 334 |
+
"MMP3": (1.5, 3.0), "MMP1": (1.2, 2.5), "CXCL13": (1.0, 2.5),
|
| 335 |
+
},
|
| 336 |
+
pathways={
|
| 337 |
+
"JAK_STAT_signalling": (0.7, 0.9),
|
| 338 |
+
"TNF_signalling": (0.7, 0.9),
|
| 339 |
+
"Th17_differentiation": (0.6, 0.8),
|
| 340 |
+
"NF_kB_signalling": (0.65, 0.85),
|
| 341 |
+
"matrix_metalloproteinase_activity": (0.7, 0.9),
|
| 342 |
+
},
|
| 343 |
+
markers=["TNF", "IL6", "MMP3", "CXCL13"],
|
| 344 |
+
mechanism_templates=[
|
| 345 |
+
"TNF/NF-kB-driven synovial inflammation",
|
| 346 |
+
"Th17-mediated cartilage destruction via MMPs",
|
| 347 |
+
],
|
| 348 |
+
responding_cell_types=["fibroblast", "macrophage", "CD4_T_cell"],
|
| 349 |
+
),
|
| 350 |
+
"hepatocellular_carcinoma": DiseaseProfile(
|
| 351 |
+
name="hepatocellular_carcinoma",
|
| 352 |
+
display_name="hepatocellular carcinoma",
|
| 353 |
+
tissue="liver",
|
| 354 |
+
condition_name="HCC",
|
| 355 |
+
de_genes={
|
| 356 |
+
"GPC3": (2.0, 4.0), "AFP": (1.5, 3.5), "EPCAM": (1.0, 2.5),
|
| 357 |
+
"MYC": (1.0, 2.5), "VEGFA": (1.2, 2.8), "MKI67": (1.5, 3.0),
|
| 358 |
+
"ALB": (-2.0, -0.8), "CYP3A4": (-1.8, -0.6), "APOB": (-1.5, -0.5),
|
| 359 |
+
},
|
| 360 |
+
pathways={
|
| 361 |
+
"Wnt_signalling": (0.7, 0.9),
|
| 362 |
+
"cell_cycle": (0.75, 0.95),
|
| 363 |
+
"angiogenesis": (0.6, 0.8),
|
| 364 |
+
"PI3K_AKT_signalling": (0.65, 0.85),
|
| 365 |
+
"p53_signalling": (0.3, 0.5),
|
| 366 |
+
},
|
| 367 |
+
markers=["GPC3", "AFP", "VEGFA", "MKI67"],
|
| 368 |
+
mechanism_templates=[
|
| 369 |
+
"Wnt/beta-catenin-driven hepatocyte dedifferentiation",
|
| 370 |
+
"VEGF-mediated tumour angiogenesis",
|
| 371 |
+
],
|
| 372 |
+
responding_cell_types=["kupffer_cell", "stellate_cell"],
|
| 373 |
+
hidden_failure_templates=[
|
| 374 |
+
"Tumour heterogeneity may confound DE in mixed biopsies",
|
| 375 |
+
],
|
| 376 |
+
),
|
| 377 |
+
"atherosclerosis": DiseaseProfile(
|
| 378 |
+
name="atherosclerosis",
|
| 379 |
+
display_name="atherosclerosis",
|
| 380 |
+
tissue="aorta",
|
| 381 |
+
condition_name="atherosclerosis",
|
| 382 |
+
de_genes={
|
| 383 |
+
"TREM2": (1.5, 3.0), "CD9": (1.0, 2.2), "LGALS3": (1.2, 2.5),
|
| 384 |
+
"APOE": (0.8, 2.0), "MMP9": (1.0, 2.5), "IL1B": (0.8, 2.0),
|
| 385 |
+
"ACTA2": (-1.5, -0.5), "MYH11": (-2.0, -0.8), "CNN1": (-1.5, -0.5),
|
| 386 |
+
},
|
| 387 |
+
pathways={
|
| 388 |
+
"lipid_metabolism": (0.7, 0.9),
|
| 389 |
+
"inflammatory_response": (0.65, 0.85),
|
| 390 |
+
"foam_cell_formation": (0.7, 0.9),
|
| 391 |
+
"smooth_muscle_contraction": (0.3, 0.5),
|
| 392 |
+
"complement_cascade": (0.5, 0.7),
|
| 393 |
+
},
|
| 394 |
+
markers=["TREM2", "LGALS3", "MMP9", "CD9"],
|
| 395 |
+
mechanism_templates=[
|
| 396 |
+
"TREM2+ macrophage-driven foam cell formation",
|
| 397 |
+
"smooth muscle cell phenotypic switching in plaque",
|
| 398 |
+
],
|
| 399 |
+
responding_cell_types=["macrophage", "smooth_muscle"],
|
| 400 |
+
),
|
| 401 |
+
"breast_cancer": DiseaseProfile(
|
| 402 |
+
name="breast_cancer",
|
| 403 |
+
display_name="breast cancer",
|
| 404 |
+
tissue="breast",
|
| 405 |
+
condition_name="breast_cancer",
|
| 406 |
+
de_genes={
|
| 407 |
+
"ERBB2": (1.5, 3.5), "ESR1": (-1.5, 1.5), "MKI67": (1.5, 3.0),
|
| 408 |
+
"MYC": (1.0, 2.5), "CDH1": (-1.5, -0.3), "VIM": (0.8, 2.2),
|
| 409 |
+
"CD274": (0.8, 2.0), "FOXP3": (0.6, 1.5), "GZMB": (0.8, 2.0),
|
| 410 |
+
},
|
| 411 |
+
pathways={
|
| 412 |
+
"cell_cycle": (0.7, 0.9),
|
| 413 |
+
"PI3K_AKT_signalling": (0.65, 0.85),
|
| 414 |
+
"EMT": (0.55, 0.8),
|
| 415 |
+
"immune_checkpoint": (0.5, 0.75),
|
| 416 |
+
"estrogen_signalling": (0.3, 0.7),
|
| 417 |
+
},
|
| 418 |
+
markers=["ERBB2", "MKI67", "CD274", "VIM"],
|
| 419 |
+
mechanism_templates=[
|
| 420 |
+
"ERBB2-driven proliferative signalling",
|
| 421 |
+
"immune evasion via PD-L1 upregulation",
|
| 422 |
+
],
|
| 423 |
+
responding_cell_types=["macrophage", "T_cell", "fibroblast"],
|
| 424 |
+
),
|
| 425 |
+
"multiple_sclerosis": DiseaseProfile(
|
| 426 |
+
name="multiple_sclerosis",
|
| 427 |
+
display_name="multiple sclerosis",
|
| 428 |
+
tissue="brain",
|
| 429 |
+
condition_name="multiple_sclerosis",
|
| 430 |
+
de_genes={
|
| 431 |
+
"CD68": (1.0, 2.5), "CXCL10": (1.2, 2.8), "STAT1": (0.8, 2.0),
|
| 432 |
+
"IRF1": (0.8, 1.8), "MBP": (-2.0, -0.8), "PLP1": (-1.8, -0.6),
|
| 433 |
+
"MOG": (-1.5, -0.5), "GFAP": (1.0, 2.5), "C3": (0.8, 2.0),
|
| 434 |
+
},
|
| 435 |
+
pathways={
|
| 436 |
+
"interferon_signalling": (0.7, 0.9),
|
| 437 |
+
"neuroinflammation": (0.7, 0.9),
|
| 438 |
+
"complement_cascade": (0.6, 0.8),
|
| 439 |
+
"myelination": (0.2, 0.4),
|
| 440 |
+
"T_cell_activation": (0.6, 0.8),
|
| 441 |
+
},
|
| 442 |
+
markers=["CXCL10", "STAT1", "GFAP", "C3"],
|
| 443 |
+
mechanism_templates=[
|
| 444 |
+
"interferon-driven microglial activation in demyelination",
|
| 445 |
+
"complement-mediated oligodendrocyte damage",
|
| 446 |
+
],
|
| 447 |
+
responding_cell_types=["microglia", "astrocyte"],
|
| 448 |
+
),
|
| 449 |
+
"diabetic_nephropathy": DiseaseProfile(
|
| 450 |
+
name="diabetic_nephropathy",
|
| 451 |
+
display_name="diabetic nephropathy",
|
| 452 |
+
tissue="kidney",
|
| 453 |
+
condition_name="diabetic_nephropathy",
|
| 454 |
+
de_genes={
|
| 455 |
+
"HAVCR1": (1.5, 3.0), "LCN2": (1.2, 2.8), "COL4A1": (1.0, 2.5),
|
| 456 |
+
"VEGFA": (0.8, 2.0), "NPHS1": (-1.8, -0.6), "NPHS2": (-1.5, -0.5),
|
| 457 |
+
"WT1": (-1.2, -0.4), "TGFB1": (1.0, 2.2), "FN1": (1.2, 2.5),
|
| 458 |
+
},
|
| 459 |
+
pathways={
|
| 460 |
+
"TGF_beta_signalling": (0.7, 0.9),
|
| 461 |
+
"extracellular_matrix_organisation": (0.7, 0.9),
|
| 462 |
+
"oxidative_stress": (0.6, 0.8),
|
| 463 |
+
"VEGF_signalling": (0.5, 0.7),
|
| 464 |
+
"apoptosis": (0.5, 0.7),
|
| 465 |
+
},
|
| 466 |
+
markers=["HAVCR1", "LCN2", "TGFB1", "FN1"],
|
| 467 |
+
mechanism_templates=[
|
| 468 |
+
"TGF-beta-driven glomerular fibrosis",
|
| 469 |
+
"podocyte loss via oxidative stress",
|
| 470 |
+
],
|
| 471 |
+
responding_cell_types=["fibroblast", "macrophage"],
|
| 472 |
+
),
|
| 473 |
+
"melanoma": DiseaseProfile(
|
| 474 |
+
name="melanoma",
|
| 475 |
+
display_name="melanoma",
|
| 476 |
+
tissue="skin",
|
| 477 |
+
condition_name="melanoma",
|
| 478 |
+
de_genes={
|
| 479 |
+
"MLANA": (1.5, 3.0), "PMEL": (1.2, 2.5), "SOX10": (1.0, 2.2),
|
| 480 |
+
"MKI67": (1.5, 3.0), "CD274": (0.8, 2.0), "PDCD1": (0.8, 2.0),
|
| 481 |
+
"GZMB": (0.8, 2.0), "HAVCR2": (0.6, 1.5), "LAG3": (0.6, 1.5),
|
| 482 |
+
},
|
| 483 |
+
pathways={
|
| 484 |
+
"MAPK_signalling": (0.7, 0.9),
|
| 485 |
+
"immune_checkpoint": (0.6, 0.85),
|
| 486 |
+
"cell_cycle": (0.7, 0.9),
|
| 487 |
+
"melanogenesis": (0.5, 0.7),
|
| 488 |
+
"T_cell_exhaustion": (0.55, 0.8),
|
| 489 |
+
},
|
| 490 |
+
markers=["MLANA", "CD274", "GZMB", "MKI67"],
|
| 491 |
+
mechanism_templates=[
|
| 492 |
+
"MAPK-driven melanocyte proliferation",
|
| 493 |
+
"T cell exhaustion via immune checkpoint upregulation",
|
| 494 |
+
],
|
| 495 |
+
responding_cell_types=["T_cell", "macrophage"],
|
| 496 |
+
),
|
| 497 |
+
}
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
# ── Pathway library ─────────────────────────────────────────────────────────
|
| 501 |
+
|
| 502 |
+
PATHWAY_LIBRARY: Dict[str, List[str]] = {
|
| 503 |
+
"TGF_beta_signalling": ["TGFB1", "TGFB2", "SMAD2", "SMAD3", "SMAD4", "ACVR1"],
|
| 504 |
+
"Wnt_signalling": ["WNT3A", "CTNNB1", "APC", "AXIN2", "LGR5", "TCF7L2"],
|
| 505 |
+
"MAPK_signalling": ["KRAS", "BRAF", "MAP2K1", "MAPK1", "MAPK3", "FOS", "JUN"],
|
| 506 |
+
"JAK_STAT_signalling": ["JAK1", "JAK2", "STAT1", "STAT3", "STAT5A", "SOCS1", "SOCS3"],
|
| 507 |
+
"PI3K_AKT_signalling": ["PIK3CA", "AKT1", "MTOR", "PTEN", "TSC2"],
|
| 508 |
+
"NF_kB_signalling": ["NFKB1", "RELA", "IKBKB", "TNF", "IL1B"],
|
| 509 |
+
"cell_cycle": ["CDK4", "CDK6", "CCND1", "CCNE1", "RB1", "E2F1", "MKI67"],
|
| 510 |
+
"apoptosis": ["BCL2", "BAX", "BAK1", "CASP3", "CASP9", "TP53", "BID"],
|
| 511 |
+
"inflammatory_response": ["TNF", "IL6", "IL1B", "CCL2", "CXCL8", "NFKB1"],
|
| 512 |
+
"extracellular_matrix_organisation": ["COL1A1", "COL3A1", "FN1", "POSTN", "MMP2", "MMP9", "TIMP1"],
|
| 513 |
+
"complement_cascade": ["C1QA", "C1QB", "C3", "C4A", "C5", "CFB"],
|
| 514 |
+
"neuroinflammation": ["TREM2", "CX3CR1", "P2RY12", "IL1B", "TNF", "C1QA"],
|
| 515 |
+
"synaptic_signalling": ["SLC17A7", "GRIA1", "GRIN1", "DLG4", "SNAP25", "SYP"],
|
| 516 |
+
"hematopoietic_cell_lineage": ["CD34", "KIT", "FLT3", "GATA1", "CEBPA", "SPI1"],
|
| 517 |
+
"insulin_signalling": ["INS", "INSR", "IRS1", "PIK3CA", "AKT1", "SLC2A4"],
|
| 518 |
+
"ER_stress_response": ["DDIT3", "ATF4", "XBP1", "HSPA5", "EIF2AK3"],
|
| 519 |
+
"oxidative_stress": ["SOD1", "SOD2", "CAT", "GPX1", "NFE2L2", "HMOX1"],
|
| 520 |
+
"angiogenesis": ["VEGFA", "VEGFB", "KDR", "FLT1", "ANGPT1", "ANGPT2"],
|
| 521 |
+
"EMT": ["CDH1", "CDH2", "VIM", "SNAI1", "SNAI2", "TWIST1", "ZEB1"],
|
| 522 |
+
"immune_checkpoint": ["CD274", "PDCD1", "CTLA4", "HAVCR2", "LAG3", "TIGIT"],
|
| 523 |
+
"T_cell_activation": ["CD3D", "CD28", "LCK", "ZAP70", "IL2", "IFNG"],
|
| 524 |
+
"T_cell_exhaustion": ["PDCD1", "HAVCR2", "LAG3", "TIGIT", "TOX", "ENTPD1"],
|
| 525 |
+
"TNF_signalling": ["TNF", "TNFRSF1A", "TRADD", "RIPK1", "NFKB1", "CASP8"],
|
| 526 |
+
"Th17_differentiation": ["IL17A", "IL17F", "RORC", "IL23R", "CCR6", "STAT3"],
|
| 527 |
+
"interferon_signalling": ["IFNG", "IFNB1", "STAT1", "IRF1", "IRF7", "MX1", "OAS1"],
|
| 528 |
+
"lipid_metabolism": ["APOE", "APOB", "LDLR", "HMGCR", "ABCA1", "PPARG"],
|
| 529 |
+
"myelination": ["MBP", "PLP1", "MOG", "MAG", "OLIG2", "SOX10"],
|
| 530 |
+
"foam_cell_formation": ["CD36", "MSR1", "ABCA1", "APOE", "LGALS3", "TREM2"],
|
| 531 |
+
"smooth_muscle_contraction": ["ACTA2", "MYH11", "TAGLN", "CNN1", "MYLK"],
|
| 532 |
+
"glucagon_signalling": ["GCG", "GCGR", "CREB1", "PCK1", "G6PC"],
|
| 533 |
+
"matrix_metalloproteinase_activity": ["MMP1", "MMP2", "MMP3", "MMP9", "TIMP1", "TIMP2"],
|
| 534 |
+
"estrogen_signalling": ["ESR1", "ESR2", "PGR", "GREB1", "TFF1"],
|
| 535 |
+
"melanogenesis": ["MITF", "TYR", "TYRP1", "DCT", "MLANA", "PMEL"],
|
| 536 |
+
"VEGF_signalling": ["VEGFA", "VEGFB", "KDR", "FLT1", "NRP1"],
|
| 537 |
+
}
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
# ── Regulatory network templates ────────────────────────────────────────────
|
| 541 |
+
|
| 542 |
+
REGULATORY_TEMPLATES: Dict[str, Dict[str, List[str]]] = {
|
| 543 |
+
"erythroid": {
|
| 544 |
+
"GATA1": ["KLF1", "HBB", "HBA1", "GYPA", "ALAS2"],
|
| 545 |
+
"KLF1": ["HBB", "HBA1", "SLC4A1"],
|
| 546 |
+
},
|
| 547 |
+
"myeloid": {
|
| 548 |
+
"CEBPA": ["CSF3R", "ELANE", "MPO", "CTSG"],
|
| 549 |
+
"SPI1": ["CSF1R", "CD14", "FCGR3A", "CD68"],
|
| 550 |
+
},
|
| 551 |
+
"lymphoid": {
|
| 552 |
+
"TCF7": ["CD3D", "CD3E", "IL7R", "LEF1"],
|
| 553 |
+
"PAX5": ["CD19", "MS4A1", "CD79A"],
|
| 554 |
+
},
|
| 555 |
+
"fibrotic": {
|
| 556 |
+
"SMAD3": ["COL1A1", "COL3A1", "FN1", "POSTN"],
|
| 557 |
+
"TGFB1": ["ACTA2", "COL1A1", "CTGF"],
|
| 558 |
+
},
|
| 559 |
+
"inflammatory": {
|
| 560 |
+
"NFKB1": ["TNF", "IL6", "IL1B", "CCL2", "CXCL8"],
|
| 561 |
+
"STAT1": ["IRF1", "CXCL10", "MX1", "OAS1"],
|
| 562 |
+
},
|
| 563 |
+
"stem_cell": {
|
| 564 |
+
"RUNX1": ["CD34", "KIT", "FLT3"],
|
| 565 |
+
"MYC": ["CDK4", "CCND1", "E2F1"],
|
| 566 |
+
},
|
| 567 |
+
"neuronal": {
|
| 568 |
+
"NEUROD1": ["SLC17A7", "NRGN", "SNAP25"],
|
| 569 |
+
"DLX1": ["GAD1", "GAD2", "SLC32A1"],
|
| 570 |
+
},
|
| 571 |
+
}
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
# ── Perturbation templates ──────────────────────────────────────────────────
|
| 575 |
+
|
| 576 |
+
@dataclass
|
| 577 |
+
class PerturbationTemplate:
|
| 578 |
+
name: str
|
| 579 |
+
target_pathway: str
|
| 580 |
+
gene_effects: Dict[str, float]
|
| 581 |
+
description: str
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
PERTURBATION_TEMPLATES: Dict[str, PerturbationTemplate] = {
|
| 585 |
+
"JAK_inhibitor": PerturbationTemplate(
|
| 586 |
+
name="JAK_inhibitor",
|
| 587 |
+
target_pathway="JAK_STAT_signalling",
|
| 588 |
+
gene_effects={"STAT1": -0.8, "STAT3": -0.7, "IFNG": -1.5, "IL17A": -1.3, "SOCS1": 1.2},
|
| 589 |
+
description="JAK inhibitor treatment",
|
| 590 |
+
),
|
| 591 |
+
"anti_TNF": PerturbationTemplate(
|
| 592 |
+
name="anti_TNF",
|
| 593 |
+
target_pathway="TNF_signalling",
|
| 594 |
+
gene_effects={"TNF": -1.5, "IL6": -1.0, "IL1B": -0.8, "MMP3": -1.2, "SOCS3": 0.8},
|
| 595 |
+
description="anti-TNF biologic therapy",
|
| 596 |
+
),
|
| 597 |
+
"PD1_blockade": PerturbationTemplate(
|
| 598 |
+
name="PD1_blockade",
|
| 599 |
+
target_pathway="immune_checkpoint",
|
| 600 |
+
gene_effects={"PDCD1": -1.0, "GZMB": 1.5, "IFNG": 1.2, "PRF1": 1.0, "TNF": 0.8},
|
| 601 |
+
description="anti-PD-1 immune checkpoint blockade",
|
| 602 |
+
),
|
| 603 |
+
"BRAF_inhibitor": PerturbationTemplate(
|
| 604 |
+
name="BRAF_inhibitor",
|
| 605 |
+
target_pathway="MAPK_signalling",
|
| 606 |
+
gene_effects={"BRAF": -0.5, "MAPK1": -1.0, "MKI67": -1.5, "CCND1": -1.2, "FOS": -0.8},
|
| 607 |
+
description="BRAF inhibitor treatment",
|
| 608 |
+
),
|
| 609 |
+
"TGFb_inhibitor": PerturbationTemplate(
|
| 610 |
+
name="TGFb_inhibitor",
|
| 611 |
+
target_pathway="TGF_beta_signalling",
|
| 612 |
+
gene_effects={"TGFB1": -0.8, "COL1A1": -1.2, "COL3A1": -1.0, "POSTN": -1.5, "ACTA2": -0.8},
|
| 613 |
+
description="TGF-beta pathway inhibitor",
|
| 614 |
+
),
|
| 615 |
+
"mTOR_inhibitor": PerturbationTemplate(
|
| 616 |
+
name="mTOR_inhibitor",
|
| 617 |
+
target_pathway="PI3K_AKT_signalling",
|
| 618 |
+
gene_effects={"MTOR": -0.8, "AKT1": -0.6, "MKI67": -1.2, "CCND1": -1.0, "HIF1A": -0.7},
|
| 619 |
+
description="mTOR inhibitor treatment",
|
| 620 |
+
),
|
| 621 |
+
"CRISPR_TP53_KO": PerturbationTemplate(
|
| 622 |
+
name="CRISPR_TP53_KO",
|
| 623 |
+
target_pathway="p53_signalling",
|
| 624 |
+
gene_effects={"TP53": -2.0, "BAX": -1.0, "CDKN1A": -1.5, "MDM2": -0.8, "MKI67": 1.0},
|
| 625 |
+
description="CRISPR knockout of TP53",
|
| 626 |
+
),
|
| 627 |
+
}
|
| 628 |
+
|
| 629 |
+
|
| 630 |
+
# ── Trajectory templates ────────────────────────────────────────────────────
|
| 631 |
+
|
| 632 |
+
@dataclass
|
| 633 |
+
class TrajectoryTemplate:
|
| 634 |
+
"""Template for a developmental trajectory through cell populations."""
|
| 635 |
+
root_population: str
|
| 636 |
+
branches: List[List[str]]
|
| 637 |
+
n_lineages: int
|
| 638 |
+
tissue: str
|
| 639 |
+
|
| 640 |
+
|
| 641 |
+
TRAJECTORY_TEMPLATES: List[TrajectoryTemplate] = [
|
| 642 |
+
TrajectoryTemplate(
|
| 643 |
+
root_population="HSC",
|
| 644 |
+
branches=[
|
| 645 |
+
["HSC", "CMP", "GMP", "neutrophil"],
|
| 646 |
+
["HSC", "CMP", "GMP", "monocyte"],
|
| 647 |
+
["HSC", "MEP", "erythrocyte"],
|
| 648 |
+
["HSC", "MEP", "megakaryocyte"],
|
| 649 |
+
],
|
| 650 |
+
n_lineages=3,
|
| 651 |
+
tissue="bone_marrow",
|
| 652 |
+
),
|
| 653 |
+
TrajectoryTemplate(
|
| 654 |
+
root_population="double_negative_T",
|
| 655 |
+
branches=[
|
| 656 |
+
["double_negative_T", "double_positive_T", "CD4_SP"],
|
| 657 |
+
["double_negative_T", "double_positive_T", "CD8_SP"],
|
| 658 |
+
],
|
| 659 |
+
n_lineages=2,
|
| 660 |
+
tissue="thymus",
|
| 661 |
+
),
|
| 662 |
+
TrajectoryTemplate(
|
| 663 |
+
root_population="stem_cell",
|
| 664 |
+
branches=[
|
| 665 |
+
["stem_cell", "colonocyte"],
|
| 666 |
+
["stem_cell", "goblet_cell"],
|
| 667 |
+
],
|
| 668 |
+
n_lineages=2,
|
| 669 |
+
tissue="colon",
|
| 670 |
+
),
|
| 671 |
+
TrajectoryTemplate(
|
| 672 |
+
root_population="OPC",
|
| 673 |
+
branches=[
|
| 674 |
+
["OPC", "oligodendrocyte"],
|
| 675 |
+
],
|
| 676 |
+
n_lineages=1,
|
| 677 |
+
tissue="brain",
|
| 678 |
+
),
|
| 679 |
+
]
|
| 680 |
+
|
| 681 |
+
|
| 682 |
+
# ── Hidden failure condition templates ──────────────────────────────────────
|
| 683 |
+
|
| 684 |
+
HIDDEN_FAILURE_TEMPLATES: List[str] = [
|
| 685 |
+
"High ambient RNA may confound DE in low-abundance transcripts",
|
| 686 |
+
"Strong batch effects between conditions may inflate false positives",
|
| 687 |
+
"Low cell viability in disease samples reduces statistical power",
|
| 688 |
+
"Doublet contamination in dense populations obscures rare cell types",
|
| 689 |
+
"Sample degradation during processing introduces 3' bias artefacts",
|
| 690 |
+
"Dissociation-induced gene expression changes confound stress signatures",
|
| 691 |
+
"Unbalanced sample sizes between conditions reduce DE sensitivity",
|
| 692 |
+
]
|
server/tasks/generator.py
CHANGED
|
@@ -12,7 +12,7 @@ from typing import List, Optional, Tuple
|
|
| 12 |
|
| 13 |
import numpy as np
|
| 14 |
|
| 15 |
-
from models import TaskSpec
|
| 16 |
|
| 17 |
from server.simulator.latent_state import (
|
| 18 |
CellPopulation,
|
|
@@ -24,6 +24,7 @@ from server.simulator.latent_state import (
|
|
| 24 |
TechnicalState,
|
| 25 |
)
|
| 26 |
from .scenarios import SCENARIO_LIBRARY, Scenario
|
|
|
|
| 27 |
|
| 28 |
|
| 29 |
class TaskGenerator:
|
|
@@ -34,7 +35,10 @@ class TaskGenerator:
|
|
| 34 |
scenarios: Optional[List[Scenario]] = None,
|
| 35 |
domain_randomise: bool = True,
|
| 36 |
):
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
| 38 |
self.domain_randomise = domain_randomise
|
| 39 |
|
| 40 |
def generate(
|
|
@@ -58,6 +62,14 @@ class TaskGenerator:
|
|
| 58 |
if self.domain_randomise:
|
| 59 |
self._randomise(rng, task, biology, technical)
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
latent = FullLatentState(
|
| 62 |
biology=biology,
|
| 63 |
technical=technical,
|
|
@@ -67,6 +79,7 @@ class TaskGenerator:
|
|
| 67 |
time_limit_days=task.time_limit_days,
|
| 68 |
),
|
| 69 |
hidden_failure_conditions=list(scenario.hidden_failure_conditions),
|
|
|
|
| 70 |
rng_seed=seed or 0,
|
| 71 |
)
|
| 72 |
return task, latent
|
|
|
|
| 12 |
|
| 13 |
import numpy as np
|
| 14 |
|
| 15 |
+
from models import TaskSpec, tools_for_modality, assays_for_modality
|
| 16 |
|
| 17 |
from server.simulator.latent_state import (
|
| 18 |
CellPopulation,
|
|
|
|
| 24 |
TechnicalState,
|
| 25 |
)
|
| 26 |
from .scenarios import SCENARIO_LIBRARY, Scenario
|
| 27 |
+
from .procedural_generator import generate_procedural_scenarios
|
| 28 |
|
| 29 |
|
| 30 |
class TaskGenerator:
|
|
|
|
| 35 |
scenarios: Optional[List[Scenario]] = None,
|
| 36 |
domain_randomise: bool = True,
|
| 37 |
):
|
| 38 |
+
if scenarios is not None:
|
| 39 |
+
self.scenarios = scenarios
|
| 40 |
+
else:
|
| 41 |
+
self.scenarios = list(SCENARIO_LIBRARY) + generate_procedural_scenarios(n=20, seed=42)
|
| 42 |
self.domain_randomise = domain_randomise
|
| 43 |
|
| 44 |
def generate(
|
|
|
|
| 62 |
if self.domain_randomise:
|
| 63 |
self._randomise(rng, task, biology, technical)
|
| 64 |
|
| 65 |
+
# Filter available tools/assays to those compatible with the modality.
|
| 66 |
+
compatible_tools = [t.name for t in tools_for_modality(task.modality)]
|
| 67 |
+
compatible_assays = [a.name for a in assays_for_modality(task.modality)]
|
| 68 |
+
if compatible_tools:
|
| 69 |
+
task.available_tools = compatible_tools
|
| 70 |
+
if compatible_assays:
|
| 71 |
+
task.available_assays = compatible_assays
|
| 72 |
+
|
| 73 |
latent = FullLatentState(
|
| 74 |
biology=biology,
|
| 75 |
technical=technical,
|
|
|
|
| 79 |
time_limit_days=task.time_limit_days,
|
| 80 |
),
|
| 81 |
hidden_failure_conditions=list(scenario.hidden_failure_conditions),
|
| 82 |
+
task_modality=task.modality,
|
| 83 |
rng_seed=seed or 0,
|
| 84 |
)
|
| 85 |
return task, latent
|
server/tasks/procedural_generator.py
ADDED
|
@@ -0,0 +1,501 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Procedural scenario generator.
|
| 2 |
+
|
| 3 |
+
Composes biologically coherent ``Scenario`` objects from the curated
|
| 4 |
+
palette in ``bio_palette``, producing fully populated
|
| 5 |
+
``LatentBiologicalState`` instances that drive every simulator tool
|
| 6 |
+
(clustering, DE, pathway enrichment, trajectory, regulatory networks,
|
| 7 |
+
marker validation) with realistic intermediate outputs.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import logging
|
| 13 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
|
| 17 |
+
from models import TaskSpec
|
| 18 |
+
|
| 19 |
+
from server.simulator.latent_state import (
|
| 20 |
+
CellPopulation,
|
| 21 |
+
LatentBiologicalState,
|
| 22 |
+
TechnicalState,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
from .bio_palette import (
|
| 26 |
+
DISEASE_PROFILES,
|
| 27 |
+
HIDDEN_FAILURE_TEMPLATES,
|
| 28 |
+
PATHWAY_LIBRARY,
|
| 29 |
+
PERTURBATION_TEMPLATES,
|
| 30 |
+
REGULATORY_TEMPLATES,
|
| 31 |
+
TISSUE_CELL_TYPES,
|
| 32 |
+
TRAJECTORY_TEMPLATES,
|
| 33 |
+
CellTypeTemplate,
|
| 34 |
+
DiseaseProfile,
|
| 35 |
+
)
|
| 36 |
+
from .scenarios import Scenario
|
| 37 |
+
|
| 38 |
+
logger = logging.getLogger(__name__)
|
| 39 |
+
|
| 40 |
+
SCENARIO_TYPES = ("de", "trajectory", "perturbation", "biomarker")
|
| 41 |
+
|
| 42 |
+
_DIFFICULTY_PARAMS = {
|
| 43 |
+
"easy": {
|
| 44 |
+
"n_pops": (4, 5),
|
| 45 |
+
"de_scale": (1.2, 1.6),
|
| 46 |
+
"noise_dropout": (0.05, 0.10),
|
| 47 |
+
"noise_doublet": (0.03, 0.06),
|
| 48 |
+
"noise_ambient": (0.02, 0.05),
|
| 49 |
+
"noise_batch_strength": (0.05, 0.12),
|
| 50 |
+
"n_batches": (1, 2),
|
| 51 |
+
"budget_range": (70_000, 100_000),
|
| 52 |
+
"time_range": (100, 150),
|
| 53 |
+
"sample_quality": (0.85, 0.95),
|
| 54 |
+
"include_trajectory": False,
|
| 55 |
+
"include_perturbation": False,
|
| 56 |
+
"include_network": False,
|
| 57 |
+
"include_failure_conditions": False,
|
| 58 |
+
},
|
| 59 |
+
"medium": {
|
| 60 |
+
"n_pops": (5, 7),
|
| 61 |
+
"de_scale": (0.9, 1.3),
|
| 62 |
+
"noise_dropout": (0.08, 0.14),
|
| 63 |
+
"noise_doublet": (0.04, 0.08),
|
| 64 |
+
"noise_ambient": (0.03, 0.07),
|
| 65 |
+
"noise_batch_strength": (0.08, 0.18),
|
| 66 |
+
"n_batches": (1, 3),
|
| 67 |
+
"budget_range": (80_000, 120_000),
|
| 68 |
+
"time_range": (120, 180),
|
| 69 |
+
"sample_quality": (0.78, 0.92),
|
| 70 |
+
"include_trajectory": True,
|
| 71 |
+
"include_perturbation": False,
|
| 72 |
+
"include_network": True,
|
| 73 |
+
"include_failure_conditions": False,
|
| 74 |
+
},
|
| 75 |
+
"hard": {
|
| 76 |
+
"n_pops": (6, 8),
|
| 77 |
+
"de_scale": (0.6, 1.0),
|
| 78 |
+
"noise_dropout": (0.10, 0.20),
|
| 79 |
+
"noise_doublet": (0.06, 0.12),
|
| 80 |
+
"noise_ambient": (0.05, 0.10),
|
| 81 |
+
"noise_batch_strength": (0.12, 0.25),
|
| 82 |
+
"n_batches": (2, 4),
|
| 83 |
+
"budget_range": (90_000, 140_000),
|
| 84 |
+
"time_range": (140, 200),
|
| 85 |
+
"sample_quality": (0.65, 0.85),
|
| 86 |
+
"include_trajectory": True,
|
| 87 |
+
"include_perturbation": True,
|
| 88 |
+
"include_network": True,
|
| 89 |
+
"include_failure_conditions": True,
|
| 90 |
+
},
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def generate_scenario(
|
| 95 |
+
seed: int,
|
| 96 |
+
difficulty: str = "medium",
|
| 97 |
+
scenario_type: Optional[str] = None,
|
| 98 |
+
) -> Scenario:
|
| 99 |
+
"""Generate a single procedural scenario with complete latent state.
|
| 100 |
+
|
| 101 |
+
Parameters
|
| 102 |
+
----------
|
| 103 |
+
seed
|
| 104 |
+
RNG seed for reproducibility.
|
| 105 |
+
difficulty
|
| 106 |
+
One of ``"easy"``, ``"medium"``, ``"hard"``.
|
| 107 |
+
scenario_type
|
| 108 |
+
One of ``"de"``, ``"trajectory"``, ``"perturbation"``,
|
| 109 |
+
``"biomarker"``, or ``None`` for random selection.
|
| 110 |
+
"""
|
| 111 |
+
rng = np.random.default_rng(seed)
|
| 112 |
+
params = _DIFFICULTY_PARAMS[difficulty]
|
| 113 |
+
|
| 114 |
+
if scenario_type is None:
|
| 115 |
+
scenario_type = rng.choice(SCENARIO_TYPES)
|
| 116 |
+
|
| 117 |
+
disease_key = rng.choice(list(DISEASE_PROFILES.keys()))
|
| 118 |
+
disease = DISEASE_PROFILES[disease_key]
|
| 119 |
+
tissue = disease.tissue
|
| 120 |
+
|
| 121 |
+
cell_templates = TISSUE_CELL_TYPES.get(tissue, [])
|
| 122 |
+
if not cell_templates:
|
| 123 |
+
tissue = rng.choice(list(TISSUE_CELL_TYPES.keys()))
|
| 124 |
+
cell_templates = TISSUE_CELL_TYPES[tissue]
|
| 125 |
+
|
| 126 |
+
populations = _sample_populations(rng, cell_templates, disease, params)
|
| 127 |
+
de_genes = _build_de_genes(rng, disease, params)
|
| 128 |
+
pathways = _build_pathways(rng, disease)
|
| 129 |
+
markers = _derive_markers(rng, de_genes, disease)
|
| 130 |
+
mechanisms = list(disease.mechanism_templates)
|
| 131 |
+
n_cells = int(rng.integers(8_000, 22_000))
|
| 132 |
+
|
| 133 |
+
trajectory = None
|
| 134 |
+
if scenario_type == "trajectory" or (
|
| 135 |
+
params["include_trajectory"] and rng.random() < 0.4
|
| 136 |
+
):
|
| 137 |
+
trajectory = _build_trajectory(rng, tissue, populations)
|
| 138 |
+
|
| 139 |
+
reg_network: Dict[str, List[str]] = {}
|
| 140 |
+
if scenario_type == "trajectory" or (
|
| 141 |
+
params["include_network"] and rng.random() < 0.5
|
| 142 |
+
):
|
| 143 |
+
reg_network = _build_regulatory_network(rng, tissue, populations)
|
| 144 |
+
|
| 145 |
+
perturbation_effects: Dict[str, Dict[str, float]] = {}
|
| 146 |
+
if scenario_type == "perturbation" or (
|
| 147 |
+
params["include_perturbation"] and rng.random() < 0.5
|
| 148 |
+
):
|
| 149 |
+
perturbation_effects = _build_perturbation(rng, disease)
|
| 150 |
+
|
| 151 |
+
technical = _build_technical(rng, params)
|
| 152 |
+
|
| 153 |
+
hidden_failures: List[str] = []
|
| 154 |
+
if params["include_failure_conditions"] and rng.random() < 0.6:
|
| 155 |
+
n_failures = int(rng.integers(1, 3))
|
| 156 |
+
indices = rng.choice(
|
| 157 |
+
len(HIDDEN_FAILURE_TEMPLATES), size=min(n_failures, len(HIDDEN_FAILURE_TEMPLATES)), replace=False,
|
| 158 |
+
)
|
| 159 |
+
hidden_failures = [HIDDEN_FAILURE_TEMPLATES[i] for i in indices]
|
| 160 |
+
|
| 161 |
+
task = _build_task(rng, disease, tissue, scenario_type, params, perturbation_effects)
|
| 162 |
+
|
| 163 |
+
biology = LatentBiologicalState(
|
| 164 |
+
cell_populations=populations,
|
| 165 |
+
true_de_genes=de_genes,
|
| 166 |
+
true_pathways=pathways,
|
| 167 |
+
true_trajectory=trajectory,
|
| 168 |
+
true_regulatory_network=reg_network,
|
| 169 |
+
perturbation_effects=perturbation_effects,
|
| 170 |
+
true_markers=markers,
|
| 171 |
+
causal_mechanisms=mechanisms,
|
| 172 |
+
n_true_cells=n_cells,
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
name = f"proc_{disease.name}_{scenario_type}_{seed}"
|
| 176 |
+
|
| 177 |
+
tags = [scenario_type, "scRNA-seq", tissue, disease.name, difficulty]
|
| 178 |
+
|
| 179 |
+
return Scenario(
|
| 180 |
+
name=name,
|
| 181 |
+
task=task,
|
| 182 |
+
biology=biology,
|
| 183 |
+
technical=technical,
|
| 184 |
+
hidden_failure_conditions=hidden_failures,
|
| 185 |
+
difficulty=difficulty,
|
| 186 |
+
tags=tags,
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def generate_procedural_scenarios(
|
| 191 |
+
n: int = 20,
|
| 192 |
+
seed: int = 42,
|
| 193 |
+
) -> List[Scenario]:
|
| 194 |
+
"""Pre-generate a pool of procedural scenarios across difficulties."""
|
| 195 |
+
rng = np.random.default_rng(seed)
|
| 196 |
+
scenarios: List[Scenario] = []
|
| 197 |
+
difficulties = ["easy", "medium", "hard"]
|
| 198 |
+
|
| 199 |
+
for i in range(n):
|
| 200 |
+
diff = difficulties[i % len(difficulties)]
|
| 201 |
+
child_seed = int(rng.integers(0, 2**31))
|
| 202 |
+
scenario = generate_scenario(
|
| 203 |
+
seed=child_seed,
|
| 204 |
+
difficulty=diff,
|
| 205 |
+
scenario_type=None,
|
| 206 |
+
)
|
| 207 |
+
scenarios.append(scenario)
|
| 208 |
+
|
| 209 |
+
logger.info("Generated %d procedural scenarios.", len(scenarios))
|
| 210 |
+
return scenarios
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
# ── Internal builders ───────────────────────────────────────────────────────
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def _sample_populations(
|
| 217 |
+
rng: np.random.Generator,
|
| 218 |
+
templates: List[CellTypeTemplate],
|
| 219 |
+
disease: DiseaseProfile,
|
| 220 |
+
params: dict,
|
| 221 |
+
) -> List[CellPopulation]:
|
| 222 |
+
lo, hi = params["n_pops"]
|
| 223 |
+
n_pops = int(rng.integers(lo, hi + 1))
|
| 224 |
+
n_pops = min(n_pops, len(templates))
|
| 225 |
+
|
| 226 |
+
indices = rng.choice(len(templates), size=n_pops, replace=False)
|
| 227 |
+
selected = [templates[i] for i in sorted(indices)]
|
| 228 |
+
|
| 229 |
+
responding_names = set(disease.responding_cell_types)
|
| 230 |
+
|
| 231 |
+
populations: List[CellPopulation] = []
|
| 232 |
+
for tmpl in selected:
|
| 233 |
+
prop = float(rng.uniform(*tmpl.proportion_range))
|
| 234 |
+
state = rng.choice(tmpl.states)
|
| 235 |
+
|
| 236 |
+
condition_response: Dict[str, float] = {}
|
| 237 |
+
if tmpl.disease_responsive and tmpl.name in responding_names:
|
| 238 |
+
condition_response[disease.condition_name] = float(
|
| 239 |
+
rng.uniform(*tmpl.response_range)
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
populations.append(CellPopulation(
|
| 243 |
+
name=tmpl.name,
|
| 244 |
+
proportion=prop,
|
| 245 |
+
marker_genes=list(tmpl.marker_genes),
|
| 246 |
+
state=state,
|
| 247 |
+
condition_response=condition_response,
|
| 248 |
+
))
|
| 249 |
+
|
| 250 |
+
total = sum(p.proportion for p in populations)
|
| 251 |
+
if total > 0:
|
| 252 |
+
for p in populations:
|
| 253 |
+
p.proportion = round(p.proportion / total, 4)
|
| 254 |
+
|
| 255 |
+
return populations
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def _build_de_genes(
|
| 259 |
+
rng: np.random.Generator,
|
| 260 |
+
disease: DiseaseProfile,
|
| 261 |
+
params: dict,
|
| 262 |
+
) -> Dict[str, Dict[str, float]]:
|
| 263 |
+
comparison = f"{disease.condition_name}_vs_healthy"
|
| 264 |
+
scale_lo, scale_hi = params["de_scale"]
|
| 265 |
+
|
| 266 |
+
effects: Dict[str, float] = {}
|
| 267 |
+
for gene, (lo, hi) in disease.de_genes.items():
|
| 268 |
+
base = float(rng.uniform(lo, hi))
|
| 269 |
+
scale = float(rng.uniform(scale_lo, scale_hi))
|
| 270 |
+
if base > 0:
|
| 271 |
+
effects[gene] = round(base * scale, 3)
|
| 272 |
+
else:
|
| 273 |
+
effects[gene] = round(base * scale, 3)
|
| 274 |
+
|
| 275 |
+
return {comparison: effects}
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def _build_pathways(
|
| 279 |
+
rng: np.random.Generator,
|
| 280 |
+
disease: DiseaseProfile,
|
| 281 |
+
) -> Dict[str, float]:
|
| 282 |
+
pathways: Dict[str, float] = {}
|
| 283 |
+
for pw, (lo, hi) in disease.pathways.items():
|
| 284 |
+
pathways[pw] = round(float(rng.uniform(lo, hi)), 3)
|
| 285 |
+
return pathways
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def _derive_markers(
|
| 289 |
+
rng: np.random.Generator,
|
| 290 |
+
de_genes: Dict[str, Dict[str, float]],
|
| 291 |
+
disease: DiseaseProfile,
|
| 292 |
+
) -> List[str]:
|
| 293 |
+
markers = list(disease.markers)
|
| 294 |
+
|
| 295 |
+
all_effects: Dict[str, float] = {}
|
| 296 |
+
for effects in de_genes.values():
|
| 297 |
+
all_effects.update(effects)
|
| 298 |
+
|
| 299 |
+
for gene in markers:
|
| 300 |
+
if gene not in all_effects:
|
| 301 |
+
all_effects[gene] = float(rng.uniform(1.0, 2.5))
|
| 302 |
+
for comp_effects in de_genes.values():
|
| 303 |
+
comp_effects[gene] = all_effects[gene]
|
| 304 |
+
|
| 305 |
+
n_markers = min(len(markers), int(rng.integers(3, 7)))
|
| 306 |
+
return markers[:n_markers]
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def _build_trajectory(
|
| 310 |
+
rng: np.random.Generator,
|
| 311 |
+
tissue: str,
|
| 312 |
+
populations: List[CellPopulation],
|
| 313 |
+
) -> Optional[Dict[str, Any]]:
|
| 314 |
+
pop_names = {p.name for p in populations}
|
| 315 |
+
|
| 316 |
+
for tmpl in TRAJECTORY_TEMPLATES:
|
| 317 |
+
if tmpl.tissue == tissue:
|
| 318 |
+
valid_branches = [
|
| 319 |
+
branch for branch in tmpl.branches
|
| 320 |
+
if all(node in pop_names for node in branch)
|
| 321 |
+
]
|
| 322 |
+
if valid_branches:
|
| 323 |
+
return {
|
| 324 |
+
"root": tmpl.root_population,
|
| 325 |
+
"n_lineages": len(valid_branches),
|
| 326 |
+
"branching": len(valid_branches) > 1,
|
| 327 |
+
"branches": valid_branches,
|
| 328 |
+
}
|
| 329 |
+
|
| 330 |
+
if len(populations) >= 3:
|
| 331 |
+
root = populations[0].name
|
| 332 |
+
branches = [[root, p.name] for p in populations[1:]]
|
| 333 |
+
selected = branches[:int(rng.integers(2, min(4, len(branches)) + 1))]
|
| 334 |
+
return {
|
| 335 |
+
"root": root,
|
| 336 |
+
"n_lineages": len(selected),
|
| 337 |
+
"branching": len(selected) > 1,
|
| 338 |
+
"branches": selected,
|
| 339 |
+
}
|
| 340 |
+
|
| 341 |
+
return None
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def _build_regulatory_network(
|
| 345 |
+
rng: np.random.Generator,
|
| 346 |
+
tissue: str,
|
| 347 |
+
populations: List[CellPopulation],
|
| 348 |
+
) -> Dict[str, List[str]]:
|
| 349 |
+
all_genes = set()
|
| 350 |
+
for p in populations:
|
| 351 |
+
all_genes.update(p.marker_genes)
|
| 352 |
+
|
| 353 |
+
network: Dict[str, List[str]] = {}
|
| 354 |
+
|
| 355 |
+
tissue_to_programs = {
|
| 356 |
+
"bone_marrow": ["erythroid", "myeloid", "stem_cell"],
|
| 357 |
+
"thymus": ["lymphoid"],
|
| 358 |
+
"blood": ["lymphoid", "myeloid"],
|
| 359 |
+
"spleen": ["lymphoid"],
|
| 360 |
+
"brain": ["neuronal", "inflammatory"],
|
| 361 |
+
"heart": ["fibrotic", "inflammatory"],
|
| 362 |
+
"lung": ["fibrotic", "inflammatory"],
|
| 363 |
+
"liver": ["fibrotic", "inflammatory"],
|
| 364 |
+
"kidney": ["fibrotic", "inflammatory"],
|
| 365 |
+
"colon": ["inflammatory", "stem_cell"],
|
| 366 |
+
"pancreas": ["inflammatory"],
|
| 367 |
+
"skin": ["inflammatory"],
|
| 368 |
+
"breast": ["inflammatory"],
|
| 369 |
+
"synovium": ["inflammatory", "lymphoid"],
|
| 370 |
+
"aorta": ["inflammatory"],
|
| 371 |
+
}
|
| 372 |
+
|
| 373 |
+
programs = tissue_to_programs.get(tissue, ["inflammatory"])
|
| 374 |
+
for prog_name in programs:
|
| 375 |
+
prog = REGULATORY_TEMPLATES.get(prog_name, {})
|
| 376 |
+
for tf, targets in prog.items():
|
| 377 |
+
network[tf] = list(targets)
|
| 378 |
+
|
| 379 |
+
if not network:
|
| 380 |
+
for p in populations[:2]:
|
| 381 |
+
if len(p.marker_genes) >= 2:
|
| 382 |
+
tf = p.marker_genes[0]
|
| 383 |
+
network[tf] = p.marker_genes[1:]
|
| 384 |
+
|
| 385 |
+
return network
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def _build_perturbation(
|
| 389 |
+
rng: np.random.Generator,
|
| 390 |
+
disease: DiseaseProfile,
|
| 391 |
+
) -> Dict[str, Dict[str, float]]:
|
| 392 |
+
disease_pathways = set(disease.pathways.keys())
|
| 393 |
+
|
| 394 |
+
matching = [
|
| 395 |
+
(name, tmpl) for name, tmpl in PERTURBATION_TEMPLATES.items()
|
| 396 |
+
if tmpl.target_pathway in disease_pathways
|
| 397 |
+
]
|
| 398 |
+
|
| 399 |
+
if matching:
|
| 400 |
+
name, tmpl = matching[int(rng.integers(0, len(matching)))]
|
| 401 |
+
else:
|
| 402 |
+
name = rng.choice(list(PERTURBATION_TEMPLATES.keys()))
|
| 403 |
+
tmpl = PERTURBATION_TEMPLATES[name]
|
| 404 |
+
|
| 405 |
+
scaled: Dict[str, float] = {}
|
| 406 |
+
for gene, effect in tmpl.gene_effects.items():
|
| 407 |
+
scale = float(rng.uniform(0.7, 1.3))
|
| 408 |
+
scaled[gene] = round(effect * scale, 3)
|
| 409 |
+
|
| 410 |
+
return {name: scaled}
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def _build_technical(
|
| 414 |
+
rng: np.random.Generator,
|
| 415 |
+
params: dict,
|
| 416 |
+
) -> TechnicalState:
|
| 417 |
+
n_batches = int(rng.integers(*params["n_batches"]))
|
| 418 |
+
batch_effects: Dict[str, float] = {}
|
| 419 |
+
for i in range(max(1, n_batches)):
|
| 420 |
+
strength = float(rng.uniform(*params["noise_batch_strength"]))
|
| 421 |
+
batch_effects[f"batch_{i}"] = round(strength, 3)
|
| 422 |
+
|
| 423 |
+
return TechnicalState(
|
| 424 |
+
batch_effects=batch_effects,
|
| 425 |
+
dropout_rate=round(float(rng.uniform(*params["noise_dropout"])), 3),
|
| 426 |
+
doublet_rate=round(float(rng.uniform(*params["noise_doublet"])), 3),
|
| 427 |
+
ambient_rna_fraction=round(float(rng.uniform(*params["noise_ambient"])), 3),
|
| 428 |
+
sample_quality=round(float(rng.uniform(*params["sample_quality"])), 3),
|
| 429 |
+
)
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
def _build_task(
|
| 433 |
+
rng: np.random.Generator,
|
| 434 |
+
disease: DiseaseProfile,
|
| 435 |
+
tissue: str,
|
| 436 |
+
scenario_type: str,
|
| 437 |
+
params: dict,
|
| 438 |
+
perturbation_effects: Dict[str, Dict[str, float]],
|
| 439 |
+
) -> TaskSpec:
|
| 440 |
+
budget = float(rng.integers(*params["budget_range"]))
|
| 441 |
+
time_days = float(rng.integers(*params["time_range"]))
|
| 442 |
+
|
| 443 |
+
if scenario_type == "de":
|
| 444 |
+
problem = (
|
| 445 |
+
f"Identify differentially expressed genes between "
|
| 446 |
+
f"{disease.display_name} and healthy {tissue} tissue "
|
| 447 |
+
f"using single-cell RNA sequencing."
|
| 448 |
+
)
|
| 449 |
+
criteria = [
|
| 450 |
+
f"Identify DE genes between {disease.condition_name} and healthy",
|
| 451 |
+
"Validate at least one candidate marker",
|
| 452 |
+
]
|
| 453 |
+
elif scenario_type == "trajectory":
|
| 454 |
+
problem = (
|
| 455 |
+
f"Infer the developmental trajectory of cell populations "
|
| 456 |
+
f"in {tissue} tissue in the context of {disease.display_name}."
|
| 457 |
+
)
|
| 458 |
+
criteria = [
|
| 459 |
+
"Reconstruct branching lineage structure",
|
| 460 |
+
"Identify key transcription factors driving fate decisions",
|
| 461 |
+
]
|
| 462 |
+
elif scenario_type == "perturbation":
|
| 463 |
+
pert_name = next(iter(perturbation_effects), "treatment")
|
| 464 |
+
pert_tmpl = PERTURBATION_TEMPLATES.get(pert_name)
|
| 465 |
+
pert_desc = pert_tmpl.description if pert_tmpl else pert_name
|
| 466 |
+
problem = (
|
| 467 |
+
f"Determine the effect of {pert_desc} on cell states "
|
| 468 |
+
f"in {tissue} tissue affected by {disease.display_name}."
|
| 469 |
+
)
|
| 470 |
+
criteria = [
|
| 471 |
+
"Quantify shift in cell activation states",
|
| 472 |
+
f"Identify pathways modulated by {pert_name}",
|
| 473 |
+
"Propose validation strategy",
|
| 474 |
+
]
|
| 475 |
+
else:
|
| 476 |
+
top_marker = disease.markers[0] if disease.markers else "candidate"
|
| 477 |
+
problem = (
|
| 478 |
+
f"Validate candidate biomarker {top_marker} for "
|
| 479 |
+
f"{disease.display_name} in {tissue} tissue using "
|
| 480 |
+
f"single-cell RNA sequencing."
|
| 481 |
+
)
|
| 482 |
+
criteria = [
|
| 483 |
+
f"Validate {top_marker} as a disease marker",
|
| 484 |
+
"Confirm expression specificity across cell types",
|
| 485 |
+
]
|
| 486 |
+
|
| 487 |
+
conditions = ["healthy", disease.condition_name]
|
| 488 |
+
if scenario_type == "perturbation" and perturbation_effects:
|
| 489 |
+
pert_name = next(iter(perturbation_effects))
|
| 490 |
+
conditions = [f"untreated_{disease.condition_name}", f"{pert_name}_treated"]
|
| 491 |
+
|
| 492 |
+
return TaskSpec(
|
| 493 |
+
problem_statement=problem,
|
| 494 |
+
modality="scRNA-seq",
|
| 495 |
+
organism="human",
|
| 496 |
+
tissue=tissue,
|
| 497 |
+
conditions=conditions,
|
| 498 |
+
budget_limit=budget,
|
| 499 |
+
time_limit_days=time_days,
|
| 500 |
+
success_criteria=criteria,
|
| 501 |
+
)
|
server/tasks/scenarios.py
CHANGED
|
@@ -353,8 +353,8 @@ SCENARIO_LIBRARY: List[Scenario] = [
|
|
| 353 |
budget_limit=90_000.0,
|
| 354 |
time_limit_days=150.0,
|
| 355 |
prior_observations=[
|
| 356 |
-
"
|
| 357 |
-
"
|
| 358 |
],
|
| 359 |
success_criteria=[
|
| 360 |
"Validate SPP1 as a marker for pro-fibrotic macrophages",
|
|
@@ -452,3 +452,5 @@ SCENARIO_LIBRARY: List[Scenario] = [
|
|
| 452 |
),
|
| 453 |
),
|
| 454 |
]
|
|
|
|
|
|
|
|
|
| 353 |
budget_limit=90_000.0,
|
| 354 |
time_limit_days=150.0,
|
| 355 |
prior_observations=[
|
| 356 |
+
"A macrophage subpopulation shows elevated expression in IPF tissue relative to controls",
|
| 357 |
+
"Pro-fibrotic macrophage enrichment has been observed in fibrotic regions by spatial profiling",
|
| 358 |
],
|
| 359 |
success_criteria=[
|
| 360 |
"Validate SPP1 as a marker for pro-fibrotic macrophages",
|
|
|
|
| 452 |
),
|
| 453 |
),
|
| 454 |
]
|
| 455 |
+
|
| 456 |
+
|
tests/test_environment.py
CHANGED
|
@@ -56,6 +56,17 @@ class TestEnvironmentLifecycle:
|
|
| 56 |
assert obs.latest_output is not None
|
| 57 |
assert obs.latest_output.success is False
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
def test_conclusion_ends_episode(self):
|
| 60 |
env = BioExperimentEnvironment()
|
| 61 |
env.reset()
|
|
|
|
| 56 |
assert obs.latest_output is not None
|
| 57 |
assert obs.latest_output.success is False
|
| 58 |
|
| 59 |
+
def test_premature_followup_design_is_flagged(self):
|
| 60 |
+
env = BioExperimentEnvironment()
|
| 61 |
+
env.reset()
|
| 62 |
+
obs = env.step(ExperimentAction(
|
| 63 |
+
action_type=ActionType.DESIGN_FOLLOWUP,
|
| 64 |
+
parameters={"assay": "qPCR"},
|
| 65 |
+
))
|
| 66 |
+
assert obs.latest_output is not None
|
| 67 |
+
assert obs.latest_output.success is True
|
| 68 |
+
assert any("follow-up design" in msg.lower() for msg in obs.rule_violations)
|
| 69 |
+
|
| 70 |
def test_conclusion_ends_episode(self):
|
| 71 |
env = BioExperimentEnvironment()
|
| 72 |
env.reset()
|
tests/test_rewards.py
CHANGED
|
@@ -61,6 +61,29 @@ class TestStepReward:
|
|
| 61 |
)
|
| 62 |
assert rb.total < 0
|
| 63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
class TestTerminalReward:
|
| 66 |
def test_correct_conclusion_rewarded(self):
|
|
@@ -103,3 +126,42 @@ class TestTerminalReward:
|
|
| 103 |
]
|
| 104 |
rb = rc.terminal_reward(state, claims, [])
|
| 105 |
assert rb.components.get("overconfidence_penalty", 0) < 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
)
|
| 62 |
assert rb.total < 0
|
| 63 |
|
| 64 |
+
def test_premature_meta_action_gets_penalized(self):
|
| 65 |
+
rc = RewardComputer()
|
| 66 |
+
prev, nxt = _states(
|
| 67 |
+
prev_flags={"data_normalized": True},
|
| 68 |
+
next_flags={"followup_designed": True},
|
| 69 |
+
budget_used=2_000,
|
| 70 |
+
)
|
| 71 |
+
output = IntermediateOutput(
|
| 72 |
+
output_type=OutputType.FOLLOWUP_DESIGN,
|
| 73 |
+
step_index=2,
|
| 74 |
+
quality_score=1.0,
|
| 75 |
+
uncertainty=0.0,
|
| 76 |
+
)
|
| 77 |
+
rb = rc.step_reward(
|
| 78 |
+
ExperimentAction(action_type=ActionType.DESIGN_FOLLOWUP),
|
| 79 |
+
prev,
|
| 80 |
+
nxt,
|
| 81 |
+
output,
|
| 82 |
+
[],
|
| 83 |
+
[],
|
| 84 |
+
)
|
| 85 |
+
assert rb.components.get("premature_meta_action_penalty", 0.0) < 0.0
|
| 86 |
+
|
| 87 |
|
| 88 |
class TestTerminalReward:
|
| 89 |
def test_correct_conclusion_rewarded(self):
|
|
|
|
| 126 |
]
|
| 127 |
rb = rc.terminal_reward(state, claims, [])
|
| 128 |
assert rb.components.get("overconfidence_penalty", 0) < 0
|
| 129 |
+
|
| 130 |
+
def test_discovery_error_penalizes_wrong_markers_and_mechanisms(self):
|
| 131 |
+
rc = RewardComputer()
|
| 132 |
+
state = FullLatentState(
|
| 133 |
+
biology=LatentBiologicalState(
|
| 134 |
+
true_markers=["NPPA", "NPPB"],
|
| 135 |
+
causal_mechanisms=["TGF-beta-driven fibrosis"],
|
| 136 |
+
),
|
| 137 |
+
progress=ExperimentProgress(
|
| 138 |
+
samples_collected=True,
|
| 139 |
+
cells_sequenced=True,
|
| 140 |
+
qc_performed=True,
|
| 141 |
+
data_filtered=True,
|
| 142 |
+
data_normalized=True,
|
| 143 |
+
de_performed=True,
|
| 144 |
+
markers_discovered=True,
|
| 145 |
+
conclusion_reached=True,
|
| 146 |
+
),
|
| 147 |
+
resources=ResourceState(budget_total=100_000, budget_used=40_000),
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
aligned = rc.terminal_reward(
|
| 151 |
+
state,
|
| 152 |
+
[],
|
| 153 |
+
[],
|
| 154 |
+
discovered_markers=["NPPA", "NPPB"],
|
| 155 |
+
candidate_mechanisms=["TGF-beta-driven fibrosis"],
|
| 156 |
+
)
|
| 157 |
+
misaligned = rc.terminal_reward(
|
| 158 |
+
state,
|
| 159 |
+
[],
|
| 160 |
+
[],
|
| 161 |
+
discovered_markers=["WRONG1", "WRONG2"],
|
| 162 |
+
candidate_mechanisms=["unrelated inflammatory process"],
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
assert aligned.components["discovery_alignment"] > misaligned.components["discovery_alignment"]
|
| 166 |
+
assert aligned.components["discovery_error_penalty"] > misaligned.components["discovery_error_penalty"]
|
| 167 |
+
assert aligned.terminal > misaligned.terminal
|
tests/test_rules.py
CHANGED
|
@@ -66,6 +66,36 @@ class TestRedundancy:
|
|
| 66 |
assert not hard
|
| 67 |
assert any("redundant" in m.lower() for m in soft)
|
| 68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
class TestResourceConstraints:
|
| 71 |
def test_exhausted_budget_blocked(self):
|
|
|
|
| 66 |
assert not hard
|
| 67 |
assert any("redundant" in m.lower() for m in soft)
|
| 68 |
|
| 69 |
+
def test_repeated_followup_design_is_soft(self):
|
| 70 |
+
engine = RuleEngine()
|
| 71 |
+
violations = engine.check(
|
| 72 |
+
ExperimentAction(action_type=ActionType.DESIGN_FOLLOWUP),
|
| 73 |
+
_state(followup_designed=True, de_performed=True),
|
| 74 |
+
)
|
| 75 |
+
hard = engine.hard_violations(violations)
|
| 76 |
+
soft = engine.soft_violations(violations)
|
| 77 |
+
assert not hard
|
| 78 |
+
assert any("redundant" in m.lower() for m in soft)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class TestMetaActionTiming:
|
| 82 |
+
def test_followup_design_without_analysis_is_soft(self):
|
| 83 |
+
engine = RuleEngine()
|
| 84 |
+
violations = engine.check(
|
| 85 |
+
ExperimentAction(action_type=ActionType.DESIGN_FOLLOWUP),
|
| 86 |
+
_state(),
|
| 87 |
+
)
|
| 88 |
+
soft = engine.soft_violations(violations)
|
| 89 |
+
assert any("follow-up design" in m.lower() for m in soft)
|
| 90 |
+
|
| 91 |
+
def test_subagent_review_without_analysis_is_soft(self):
|
| 92 |
+
engine = RuleEngine()
|
| 93 |
+
violations = engine.check(
|
| 94 |
+
ExperimentAction(action_type=ActionType.REQUEST_SUBAGENT_REVIEW),
|
| 95 |
+
_state(),
|
| 96 |
+
)
|
| 97 |
+
soft = engine.soft_violations(violations)
|
| 98 |
+
assert any("subagent review" in m.lower() for m in soft)
|
| 99 |
|
| 100 |
class TestResourceConstraints:
|
| 101 |
def test_exhausted_budget_blocked(self):
|
tests/test_run_agent.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for run_agent parser and fallback helpers."""
|
| 2 |
+
|
| 3 |
+
from models import ActionType, ExperimentAction
|
| 4 |
+
from run_agent import fallback_action, parse_action
|
| 5 |
+
from server.hackathon_environment import BioExperimentEnvironment
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def test_parse_action_accepts_reasoning_variant():
|
| 9 |
+
action = parse_action(
|
| 10 |
+
'{"action_type":"run_qc","parameters":{},"Reasoning":"check quality","confidence":0.8}'
|
| 11 |
+
)
|
| 12 |
+
assert action is not None
|
| 13 |
+
assert action.action_type == ActionType.RUN_QC
|
| 14 |
+
assert action.justification == "check quality"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def test_parse_action_accepts_justifyement_typo():
|
| 18 |
+
action = parse_action(
|
| 19 |
+
'{"action_type":"collect_sample","parameters":{},"justifyement":"typo key","confidence":0.7}'
|
| 20 |
+
)
|
| 21 |
+
assert action is not None
|
| 22 |
+
assert action.action_type == ActionType.COLLECT_SAMPLE
|
| 23 |
+
assert action.justification == "typo key"
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def test_fallback_uses_observation_progress_not_step_index():
|
| 27 |
+
env = BioExperimentEnvironment(scenario_name="cardiac_disease_de", domain_randomise=False)
|
| 28 |
+
obs = env.reset(seed=0)
|
| 29 |
+
for action_type in (
|
| 30 |
+
ActionType.COLLECT_SAMPLE,
|
| 31 |
+
ActionType.PREPARE_LIBRARY,
|
| 32 |
+
ActionType.SEQUENCE_CELLS,
|
| 33 |
+
):
|
| 34 |
+
obs = env.step(ExperimentAction(action_type=action_type))
|
| 35 |
+
action = fallback_action(obs)
|
| 36 |
+
assert action.action_type == ActionType.RUN_QC
|
tests/test_training_script.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for GRPO training helpers."""
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
from models import ActionType
|
| 6 |
+
from training_script import (
|
| 7 |
+
INVALID_ACTION_PENALTY,
|
| 8 |
+
OpenEnvReward,
|
| 9 |
+
available_numeric_log_keys,
|
| 10 |
+
build_prompt_examples,
|
| 11 |
+
completion_to_text,
|
| 12 |
+
parse_action_completion,
|
| 13 |
+
save_training_plots,
|
| 14 |
+
select_metric_key,
|
| 15 |
+
select_reward_key,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def test_completion_to_text_from_chat_messages():
|
| 20 |
+
completion = [
|
| 21 |
+
{"role": "assistant", "content": '{"action_type":"collect_sample"}'}
|
| 22 |
+
]
|
| 23 |
+
assert completion_to_text(completion) == '{"action_type":"collect_sample"}'
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def test_parse_action_completion_roundtrip():
|
| 27 |
+
action = parse_action_completion(
|
| 28 |
+
'{"action_type":"run_qc","method":"scanpy.pp.calculate_qc_metrics",'
|
| 29 |
+
'"parameters":{"min_genes":200},"confidence":0.8}'
|
| 30 |
+
)
|
| 31 |
+
assert action is not None
|
| 32 |
+
assert action.action_type == ActionType.RUN_QC
|
| 33 |
+
assert action.method == "scanpy.pp.calculate_qc_metrics"
|
| 34 |
+
assert action.parameters["min_genes"] == 200
|
| 35 |
+
assert action.confidence == 0.8
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def test_parse_action_completion_accepts_reasoning_alias():
|
| 39 |
+
action = parse_action_completion(
|
| 40 |
+
'{"action_type":"run_qc","reasoning":"Measure quality before filtering."}'
|
| 41 |
+
)
|
| 42 |
+
assert action is not None
|
| 43 |
+
assert action.justification == "Measure quality before filtering."
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def test_build_prompt_examples_contains_reference_action():
|
| 47 |
+
examples = build_prompt_examples(
|
| 48 |
+
dataset_episodes=1,
|
| 49 |
+
rollout_steps=2,
|
| 50 |
+
collection_policy="heuristic",
|
| 51 |
+
scenario_names=["cardiac_disease_de"],
|
| 52 |
+
seed=0,
|
| 53 |
+
domain_randomise=False,
|
| 54 |
+
)
|
| 55 |
+
assert len(examples) == 2
|
| 56 |
+
assert examples[0]["scenario_name"] == "cardiac_disease_de"
|
| 57 |
+
assert '"action_type": "collect_sample"' in examples[0]["reference_action"]
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def test_openenv_reward_penalizes_invalid_completion():
|
| 61 |
+
reward_fn = OpenEnvReward(
|
| 62 |
+
reward_backend="local",
|
| 63 |
+
base_url="http://localhost:8000",
|
| 64 |
+
)
|
| 65 |
+
rewards = reward_fn(
|
| 66 |
+
completions=[[{"role": "assistant", "content": "not valid json"}]],
|
| 67 |
+
scenario_name=["cardiac_disease_de"],
|
| 68 |
+
history_actions=["[]"],
|
| 69 |
+
)
|
| 70 |
+
assert rewards == [INVALID_ACTION_PENALTY]
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def test_openenv_reward_scores_valid_completion_locally():
|
| 74 |
+
examples = build_prompt_examples(
|
| 75 |
+
dataset_episodes=1,
|
| 76 |
+
rollout_steps=1,
|
| 77 |
+
collection_policy="heuristic",
|
| 78 |
+
scenario_names=["cardiac_disease_de"],
|
| 79 |
+
seed=0,
|
| 80 |
+
domain_randomise=False,
|
| 81 |
+
)
|
| 82 |
+
reward_fn = OpenEnvReward(
|
| 83 |
+
reward_backend="local",
|
| 84 |
+
base_url="http://localhost:8000",
|
| 85 |
+
)
|
| 86 |
+
sample = examples[0]
|
| 87 |
+
rewards = reward_fn(
|
| 88 |
+
completions=[[{"role": "assistant", "content": sample["reference_action"]}]],
|
| 89 |
+
scenario_name=[sample["scenario_name"]],
|
| 90 |
+
history_actions=[sample["history_actions"]],
|
| 91 |
+
)
|
| 92 |
+
assert len(rewards) == 1
|
| 93 |
+
assert rewards[0] > 0.0
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def test_log_key_selection_prefers_reward_and_metric_keys():
|
| 97 |
+
log_history = [
|
| 98 |
+
{"step": 1, "loss": 1.2, "rewards/open_env_reward": 0.4, "objective/kl": 0.05},
|
| 99 |
+
{"step": 2, "loss": 1.0, "rewards/open_env_reward": 0.6, "objective/kl": 0.04},
|
| 100 |
+
]
|
| 101 |
+
assert available_numeric_log_keys(log_history) == [
|
| 102 |
+
"loss",
|
| 103 |
+
"objective/kl",
|
| 104 |
+
"rewards/open_env_reward",
|
| 105 |
+
]
|
| 106 |
+
reward_key = select_reward_key(log_history)
|
| 107 |
+
assert reward_key == "rewards/open_env_reward"
|
| 108 |
+
assert select_metric_key(log_history, reward_key=reward_key) == "objective/kl"
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def test_save_training_plots_writes_expected_files(tmp_path):
|
| 112 |
+
log_history = [
|
| 113 |
+
{"step": 1, "loss": 1.2, "reward": 0.4, "grad_norm": 0.8},
|
| 114 |
+
{"step": 2, "loss": 0.9, "reward": 0.7, "grad_norm": 0.5},
|
| 115 |
+
]
|
| 116 |
+
plot_paths = save_training_plots(log_history, tmp_path, metric_key="grad_norm")
|
| 117 |
+
|
| 118 |
+
assert set(plot_paths) == {"loss", "reward", "metric", "dashboard"}
|
| 119 |
+
for plot_path in plot_paths.values():
|
| 120 |
+
assert Path(plot_path).exists()
|
| 121 |
+
|
| 122 |
+
manifest_path = tmp_path / "training_plot_manifest.json"
|
| 123 |
+
assert manifest_path.exists()
|
training/__init__.py
CHANGED
|
@@ -1,9 +1,7 @@
|
|
| 1 |
from .evaluation import EvaluationSuite
|
| 2 |
-
from .gym_wrapper import BioExperimentGymEnv
|
| 3 |
from .trajectory import Trajectory, TrajectoryDataset
|
| 4 |
|
| 5 |
__all__ = [
|
| 6 |
-
"BioExperimentGymEnv",
|
| 7 |
"EvaluationSuite",
|
| 8 |
"PaperBenchmarkResult",
|
| 9 |
"Trajectory",
|
|
|
|
| 1 |
from .evaluation import EvaluationSuite
|
|
|
|
| 2 |
from .trajectory import Trajectory, TrajectoryDataset
|
| 3 |
|
| 4 |
__all__ = [
|
|
|
|
| 5 |
"EvaluationSuite",
|
| 6 |
"PaperBenchmarkResult",
|
| 7 |
"Trajectory",
|
training/evaluation.py
CHANGED
|
@@ -118,7 +118,7 @@ class EvaluationSuite:
|
|
| 118 |
for t in ds.trajectories:
|
| 119 |
violations = sum(
|
| 120 |
1 for s in t.steps
|
| 121 |
-
if
|
| 122 |
and s.observation.get("rule_violations") is not None
|
| 123 |
)
|
| 124 |
if violations == 0:
|
|
@@ -146,7 +146,8 @@ class EvaluationSuite:
|
|
| 146 |
at = s.action.get("action_type")
|
| 147 |
if at:
|
| 148 |
all_types.add(at)
|
| 149 |
-
|
|
|
|
| 150 |
|
| 151 |
@staticmethod
|
| 152 |
def _mean_conclusion_confidence(ds: TrajectoryDataset) -> float:
|
|
|
|
| 118 |
for t in ds.trajectories:
|
| 119 |
violations = sum(
|
| 120 |
1 for s in t.steps
|
| 121 |
+
if s.observation.get("rule_violations", []) != []
|
| 122 |
and s.observation.get("rule_violations") is not None
|
| 123 |
)
|
| 124 |
if violations == 0:
|
|
|
|
| 146 |
at = s.action.get("action_type")
|
| 147 |
if at:
|
| 148 |
all_types.add(at)
|
| 149 |
+
from models import ActionType
|
| 150 |
+
return len(all_types) / max(len(ActionType), 1)
|
| 151 |
|
| 152 |
@staticmethod
|
| 153 |
def _mean_conclusion_confidence(ds: TrajectoryDataset) -> float:
|
training/literature_benchmark.py
CHANGED
|
@@ -148,7 +148,6 @@ def run_paper_benchmark(
|
|
| 148 |
tool_call_spec=_tool_context(
|
| 149 |
obs.task,
|
| 150 |
libraries=["biopython"],
|
| 151 |
-
include_expected_findings=True,
|
| 152 |
),
|
| 153 |
)
|
| 154 |
)
|
|
@@ -353,19 +352,7 @@ def infer_conclusion_claims(obs: ExperimentObservation) -> List[ConclusionClaim]
|
|
| 353 |
evidence_steps=_evidence_steps(obs, {OutputType.NETWORK_RESULT}),
|
| 354 |
))
|
| 355 |
|
| 356 |
-
|
| 357 |
-
return claims
|
| 358 |
-
|
| 359 |
-
# Fallback: preserve the strongest expected findings verbatim if the
|
| 360 |
-
# heuristic extractors do not recover enough signal from the episode.
|
| 361 |
-
return [
|
| 362 |
-
ConclusionClaim(
|
| 363 |
-
claim=finding.finding,
|
| 364 |
-
confidence=0.65,
|
| 365 |
-
claim_type=finding.category,
|
| 366 |
-
)
|
| 367 |
-
for finding in obs.task.expected_findings[:3]
|
| 368 |
-
]
|
| 369 |
|
| 370 |
|
| 371 |
def compare_expected_findings(
|
|
@@ -444,11 +431,11 @@ def _default_comparison_name(task: TaskSpec) -> str:
|
|
| 444 |
|
| 445 |
|
| 446 |
def _preferred_marker(task: TaskSpec) -> str:
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
return "
|
| 452 |
|
| 453 |
|
| 454 |
def _latest_output_data(
|
|
|
|
| 148 |
tool_call_spec=_tool_context(
|
| 149 |
obs.task,
|
| 150 |
libraries=["biopython"],
|
|
|
|
| 151 |
),
|
| 152 |
)
|
| 153 |
)
|
|
|
|
| 352 |
evidence_steps=_evidence_steps(obs, {OutputType.NETWORK_RESULT}),
|
| 353 |
))
|
| 354 |
|
| 355 |
+
return claims
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 356 |
|
| 357 |
|
| 358 |
def compare_expected_findings(
|
|
|
|
| 431 |
|
| 432 |
|
| 433 |
def _preferred_marker(task: TaskSpec) -> str:
|
| 434 |
+
"""Derive a candidate marker from the problem statement, not expected findings."""
|
| 435 |
+
tokens = [t for t in TOKEN_RE.findall(task.problem_statement) if t.isupper() and len(t) >= 3]
|
| 436 |
+
if tokens:
|
| 437 |
+
return tokens[0]
|
| 438 |
+
return "unknown"
|
| 439 |
|
| 440 |
|
| 441 |
def _latest_output_data(
|
training/rollout_collection.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Collect trajectories with direct OpenEnv environment access."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import random
|
| 7 |
+
import uuid
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Dict, List, Sequence
|
| 10 |
+
|
| 11 |
+
from models import ActionType, ExperimentAction
|
| 12 |
+
from server.hackathon_environment import BioExperimentEnvironment
|
| 13 |
+
from training.evaluation import EvaluationSuite
|
| 14 |
+
from training.trajectory import Trajectory, TrajectoryDataset
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
HEURISTIC_SEQUENCE = [
|
| 18 |
+
ActionType.COLLECT_SAMPLE,
|
| 19 |
+
ActionType.PREPARE_LIBRARY,
|
| 20 |
+
ActionType.SEQUENCE_CELLS,
|
| 21 |
+
ActionType.RUN_QC,
|
| 22 |
+
ActionType.FILTER_DATA,
|
| 23 |
+
ActionType.NORMALIZE_DATA,
|
| 24 |
+
ActionType.CLUSTER_CELLS,
|
| 25 |
+
ActionType.TRAJECTORY_ANALYSIS,
|
| 26 |
+
ActionType.MARKER_SELECTION,
|
| 27 |
+
ActionType.SYNTHESIZE_CONCLUSION,
|
| 28 |
+
]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def parse_args() -> argparse.Namespace:
|
| 32 |
+
parser = argparse.ArgumentParser(
|
| 33 |
+
description="Run rollout episodes and persist trajectories."
|
| 34 |
+
)
|
| 35 |
+
parser.add_argument("--episodes", type=int, default=10, help="Number of episodes.")
|
| 36 |
+
parser.add_argument(
|
| 37 |
+
"--policy",
|
| 38 |
+
choices=["random", "heuristic"],
|
| 39 |
+
default="heuristic",
|
| 40 |
+
help="Policy to use for rollouts.",
|
| 41 |
+
)
|
| 42 |
+
parser.add_argument(
|
| 43 |
+
"--max-steps",
|
| 44 |
+
type=int,
|
| 45 |
+
default=None,
|
| 46 |
+
help="Optional hard cutoff per episode (defaults to env limit).",
|
| 47 |
+
)
|
| 48 |
+
parser.add_argument(
|
| 49 |
+
"--output-dir",
|
| 50 |
+
default="training/rollouts",
|
| 51 |
+
help="Directory for JSON trajectory outputs.",
|
| 52 |
+
)
|
| 53 |
+
parser.add_argument("--seed", type=int, default=None, help="RNG seed.")
|
| 54 |
+
return parser.parse_args()
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def heuristic_next_action(history: Sequence[ActionType], step_index: int) -> ActionType:
|
| 58 |
+
seen = set(history)
|
| 59 |
+
for action in HEURISTIC_SEQUENCE:
|
| 60 |
+
if action not in seen:
|
| 61 |
+
return action
|
| 62 |
+
if step_index >= 2 and ActionType.VALIDATE_MARKER not in seen:
|
| 63 |
+
return ActionType.VALIDATE_MARKER
|
| 64 |
+
if ActionType.SYNTHESIZE_CONCLUSION in seen:
|
| 65 |
+
return ActionType.SYNTHESIZE_CONCLUSION
|
| 66 |
+
return ActionType.SYNTHESIZE_CONCLUSION
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def pick_action(policy: str, step_index: int, history: Sequence[ActionType]) -> ActionType:
|
| 70 |
+
if policy == "random":
|
| 71 |
+
return random.choice(list(ActionType))
|
| 72 |
+
return heuristic_next_action(history, step_index)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def default_comparison_name(conditions: Sequence[str]) -> str:
|
| 76 |
+
normalized = {condition.lower() for condition in conditions}
|
| 77 |
+
if {"healthy", "ipf"} <= normalized:
|
| 78 |
+
return "IPF_vs_healthy"
|
| 79 |
+
if any("treated" in condition for condition in normalized) and any(
|
| 80 |
+
"untreated" in condition for condition in normalized
|
| 81 |
+
):
|
| 82 |
+
return "treated_vs_untreated"
|
| 83 |
+
if any("healthy" in condition for condition in normalized):
|
| 84 |
+
return "disease_vs_healthy"
|
| 85 |
+
return "disease_vs_healthy"
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def build_experiment_action(
|
| 89 |
+
action_type: ActionType,
|
| 90 |
+
discovered_markers: Sequence[str],
|
| 91 |
+
conditions: Sequence[str],
|
| 92 |
+
) -> ExperimentAction:
|
| 93 |
+
method = None
|
| 94 |
+
parameters: Dict[str, object] = {}
|
| 95 |
+
|
| 96 |
+
if action_type == ActionType.COLLECT_SAMPLE:
|
| 97 |
+
parameters = {"n_samples": 6}
|
| 98 |
+
elif action_type == ActionType.PREPARE_LIBRARY:
|
| 99 |
+
method = "10x_chromium"
|
| 100 |
+
elif action_type == ActionType.RUN_QC:
|
| 101 |
+
method = "scanpy.pp.calculate_qc_metrics"
|
| 102 |
+
elif action_type == ActionType.FILTER_DATA:
|
| 103 |
+
method = "scanpy.pp.filter_cells"
|
| 104 |
+
elif action_type == ActionType.NORMALIZE_DATA:
|
| 105 |
+
method = "scanpy.pp.normalize_total"
|
| 106 |
+
elif action_type == ActionType.CLUSTER_CELLS:
|
| 107 |
+
method = "scanpy.tl.leiden"
|
| 108 |
+
elif action_type == ActionType.DIFFERENTIAL_EXPRESSION:
|
| 109 |
+
method = "scanpy.tl.rank_genes_groups"
|
| 110 |
+
parameters = {"comparison": default_comparison_name(conditions)}
|
| 111 |
+
elif action_type == ActionType.TRAJECTORY_ANALYSIS:
|
| 112 |
+
method = "scanpy.tl.dpt"
|
| 113 |
+
elif action_type == ActionType.MARKER_SELECTION:
|
| 114 |
+
method = "scanpy.tl.rank_genes_groups"
|
| 115 |
+
elif action_type == ActionType.VALIDATE_MARKER:
|
| 116 |
+
method = "qPCR"
|
| 117 |
+
parameters = {"marker": discovered_markers[0] if discovered_markers else "SPP1"}
|
| 118 |
+
elif action_type == ActionType.SYNTHESIZE_CONCLUSION:
|
| 119 |
+
parameters = {"claims": []}
|
| 120 |
+
|
| 121 |
+
return ExperimentAction(
|
| 122 |
+
action_type=action_type,
|
| 123 |
+
method=method,
|
| 124 |
+
parameters=parameters,
|
| 125 |
+
confidence=0.75,
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def run_episode(
|
| 130 |
+
env: BioExperimentEnvironment,
|
| 131 |
+
episode_id: str,
|
| 132 |
+
policy: str,
|
| 133 |
+
max_steps: int | None = None,
|
| 134 |
+
) -> Trajectory:
|
| 135 |
+
structured_obs = env.reset()
|
| 136 |
+
traj = Trajectory(
|
| 137 |
+
episode_id=episode_id,
|
| 138 |
+
task=structured_obs.task.model_dump(),
|
| 139 |
+
metadata={
|
| 140 |
+
"task_problem": structured_obs.task.problem_statement,
|
| 141 |
+
"policy": policy,
|
| 142 |
+
},
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
done = structured_obs.done
|
| 146 |
+
step_num = 0
|
| 147 |
+
while not done:
|
| 148 |
+
if max_steps is not None and step_num >= max_steps:
|
| 149 |
+
break
|
| 150 |
+
|
| 151 |
+
history = [rec.action_type for rec in structured_obs.pipeline_history]
|
| 152 |
+
action_type = pick_action(policy, step_num, history)
|
| 153 |
+
experiment_action = build_experiment_action(
|
| 154 |
+
action_type=action_type,
|
| 155 |
+
discovered_markers=structured_obs.discovered_markers,
|
| 156 |
+
conditions=structured_obs.task.conditions,
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
structured_obs = env.step(experiment_action)
|
| 160 |
+
reward = structured_obs.reward
|
| 161 |
+
done = structured_obs.done
|
| 162 |
+
step_num += 1
|
| 163 |
+
|
| 164 |
+
traj.add_step(
|
| 165 |
+
action=experiment_action,
|
| 166 |
+
observation=structured_obs,
|
| 167 |
+
reward=reward,
|
| 168 |
+
done=done,
|
| 169 |
+
reward_breakdown=structured_obs.step_reward_breakdown,
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
print(
|
| 173 |
+
f" step={structured_obs.step_index:02d} "
|
| 174 |
+
f"action={action_type.value:>28} "
|
| 175 |
+
f"reward={reward:+.3f}"
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
return traj
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def main() -> None:
|
| 182 |
+
args = parse_args()
|
| 183 |
+
if args.seed is not None:
|
| 184 |
+
random.seed(args.seed)
|
| 185 |
+
|
| 186 |
+
out_dir = Path(args.output_dir)
|
| 187 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 188 |
+
|
| 189 |
+
env = BioExperimentEnvironment()
|
| 190 |
+
trajectories: List[Trajectory] = []
|
| 191 |
+
|
| 192 |
+
print(
|
| 193 |
+
f"Starting rollout collection: episodes={args.episodes}, policy={args.policy}"
|
| 194 |
+
)
|
| 195 |
+
for ep in range(args.episodes):
|
| 196 |
+
print(f"Episode {ep + 1}/{args.episodes}")
|
| 197 |
+
traj = run_episode(
|
| 198 |
+
env=env,
|
| 199 |
+
episode_id=str(uuid.uuid4()),
|
| 200 |
+
policy=args.policy,
|
| 201 |
+
max_steps=args.max_steps,
|
| 202 |
+
)
|
| 203 |
+
traj.save(out_dir / f"{traj.episode_id}.json")
|
| 204 |
+
trajectories.append(traj)
|
| 205 |
+
|
| 206 |
+
dataset = TrajectoryDataset(trajectories)
|
| 207 |
+
stats = EvaluationSuite.online_metrics(trajectories)
|
| 208 |
+
|
| 209 |
+
print("\nRun complete.")
|
| 210 |
+
print(f"Saved trajectories to: {out_dir}")
|
| 211 |
+
print("Online metrics:")
|
| 212 |
+
for metric in stats:
|
| 213 |
+
print(f" - {metric.name}: {metric.value:.4f}")
|
| 214 |
+
|
| 215 |
+
print(f"Summary: {dataset.summary()}")
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
if __name__ == "__main__":
|
| 219 |
+
main()
|
training_script.py
ADDED
|
@@ -0,0 +1,1250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Train a planner with TRL GRPO and OpenEnv rewards."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
import random
|
| 8 |
+
import re
|
| 9 |
+
from numbers import Real
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
| 12 |
+
|
| 13 |
+
from client import BioExperimentEnv
|
| 14 |
+
from models import (
|
| 15 |
+
ActionType,
|
| 16 |
+
ExperimentAction,
|
| 17 |
+
ExperimentObservation,
|
| 18 |
+
build_agent_observation_context,
|
| 19 |
+
build_agent_system_prompt,
|
| 20 |
+
)
|
| 21 |
+
from server.hackathon_environment import BioExperimentEnvironment
|
| 22 |
+
from server.tasks.scenarios import SCENARIO_LIBRARY
|
| 23 |
+
|
| 24 |
+
DEFAULT_MODEL_ID = "Qwen/Qwen3.5-0.8B"
|
| 25 |
+
DEFAULT_OUTPUT_DIR = "training/grpo-output"
|
| 26 |
+
DEFAULT_BASE_URL = "http://localhost:8000"
|
| 27 |
+
INVALID_ACTION_PENALTY = -2.0
|
| 28 |
+
ENVIRONMENT_ERROR_PENALTY = -4.0
|
| 29 |
+
|
| 30 |
+
SYSTEM_PROMPT = build_agent_system_prompt()
|
| 31 |
+
|
| 32 |
+
HEURISTIC_SEQUENCE = [
|
| 33 |
+
ActionType.COLLECT_SAMPLE,
|
| 34 |
+
ActionType.PREPARE_LIBRARY,
|
| 35 |
+
ActionType.SEQUENCE_CELLS,
|
| 36 |
+
ActionType.RUN_QC,
|
| 37 |
+
ActionType.FILTER_DATA,
|
| 38 |
+
ActionType.NORMALIZE_DATA,
|
| 39 |
+
ActionType.CLUSTER_CELLS,
|
| 40 |
+
ActionType.DIFFERENTIAL_EXPRESSION,
|
| 41 |
+
ActionType.PATHWAY_ENRICHMENT,
|
| 42 |
+
ActionType.MARKER_SELECTION,
|
| 43 |
+
ActionType.SYNTHESIZE_CONCLUSION,
|
| 44 |
+
]
|
| 45 |
+
|
| 46 |
+
VALID_ACTION_TYPES = {action.value for action in ActionType}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def compact_preview(value: Any, max_chars: int = 160) -> str:
|
| 50 |
+
try:
|
| 51 |
+
text = json.dumps(value, ensure_ascii=True, sort_keys=True)
|
| 52 |
+
except TypeError:
|
| 53 |
+
text = str(value)
|
| 54 |
+
text = re.sub(r"\s+", " ", text).strip()
|
| 55 |
+
if len(text) <= max_chars:
|
| 56 |
+
return text
|
| 57 |
+
return text[: max_chars - 3] + "..."
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _edit_distance(a: str, b: str) -> int:
|
| 61 |
+
if len(a) < len(b):
|
| 62 |
+
return _edit_distance(b, a)
|
| 63 |
+
if not b:
|
| 64 |
+
return len(a)
|
| 65 |
+
prev = list(range(len(b) + 1))
|
| 66 |
+
for i, ca in enumerate(a):
|
| 67 |
+
curr = [i + 1]
|
| 68 |
+
for j, cb in enumerate(b):
|
| 69 |
+
curr.append(min(prev[j + 1] + 1, curr[j] + 1, prev[j] + (ca != cb)))
|
| 70 |
+
prev = curr
|
| 71 |
+
return prev[-1]
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def get_payload_value(payload: Dict[str, Any], *names: str) -> Any:
|
| 75 |
+
for name in names:
|
| 76 |
+
if name in payload:
|
| 77 |
+
return payload[name]
|
| 78 |
+
|
| 79 |
+
lowered = {
|
| 80 |
+
str(key).lower(): value
|
| 81 |
+
for key, value in payload.items()
|
| 82 |
+
}
|
| 83 |
+
for name in names:
|
| 84 |
+
if name.lower() in lowered:
|
| 85 |
+
return lowered[name.lower()]
|
| 86 |
+
|
| 87 |
+
for key, value in lowered.items():
|
| 88 |
+
for name in names:
|
| 89 |
+
threshold = max(2, len(name) // 3)
|
| 90 |
+
if _edit_distance(key, name.lower()) <= threshold:
|
| 91 |
+
return value
|
| 92 |
+
return None
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def build_argument_parser() -> argparse.ArgumentParser:
|
| 96 |
+
parser = argparse.ArgumentParser(
|
| 97 |
+
description="Train a GRPO policy against the OpenEnv bio experiment environment."
|
| 98 |
+
)
|
| 99 |
+
parser.add_argument("--model-id", default=DEFAULT_MODEL_ID)
|
| 100 |
+
parser.add_argument("--output-dir", default=DEFAULT_OUTPUT_DIR)
|
| 101 |
+
parser.add_argument("--dataset-episodes", type=int, default=8)
|
| 102 |
+
parser.add_argument("--rollout-steps", type=int, default=6)
|
| 103 |
+
parser.add_argument(
|
| 104 |
+
"--collection-policy",
|
| 105 |
+
choices=["random", "heuristic"],
|
| 106 |
+
default="heuristic",
|
| 107 |
+
help="Policy used to build prompt states for GRPO training.",
|
| 108 |
+
)
|
| 109 |
+
parser.add_argument(
|
| 110 |
+
"--reward-backend",
|
| 111 |
+
choices=["local", "remote"],
|
| 112 |
+
default="local",
|
| 113 |
+
help="Use local in-process scoring or a live OpenEnv server.",
|
| 114 |
+
)
|
| 115 |
+
parser.add_argument(
|
| 116 |
+
"--base-url",
|
| 117 |
+
default=DEFAULT_BASE_URL,
|
| 118 |
+
help="Base URL for the OpenEnv server when reward-backend=remote.",
|
| 119 |
+
)
|
| 120 |
+
parser.add_argument(
|
| 121 |
+
"--scenario-name",
|
| 122 |
+
action="append",
|
| 123 |
+
default=None,
|
| 124 |
+
help="Repeatable scenario selector. Defaults to all curated scenarios.",
|
| 125 |
+
)
|
| 126 |
+
parser.add_argument(
|
| 127 |
+
"--domain-randomise",
|
| 128 |
+
action="store_true",
|
| 129 |
+
help="Enable domain randomisation while building prompts and local rewards.",
|
| 130 |
+
)
|
| 131 |
+
parser.add_argument("--num-generations", type=int, default=2)
|
| 132 |
+
parser.add_argument("--max-completion-length", type=int, default=220)
|
| 133 |
+
parser.add_argument("--max-prompt-length", type=int, default=768)
|
| 134 |
+
parser.add_argument("--per-device-train-batch-size", type=int, default=2)
|
| 135 |
+
parser.add_argument("--gradient-accumulation-steps", type=int, default=1)
|
| 136 |
+
parser.add_argument("--learning-rate", type=float, default=5e-6)
|
| 137 |
+
parser.add_argument("--num-train-epochs", type=float, default=1.0)
|
| 138 |
+
parser.add_argument("--logging-steps", type=int, default=1)
|
| 139 |
+
parser.add_argument("--save-steps", type=int, default=50)
|
| 140 |
+
parser.add_argument(
|
| 141 |
+
"--plot-metric-key",
|
| 142 |
+
default=None,
|
| 143 |
+
help="Optional extra metric key from trainer log history to plot.",
|
| 144 |
+
)
|
| 145 |
+
parser.add_argument("--seed", type=int, default=0)
|
| 146 |
+
parser.add_argument(
|
| 147 |
+
"--load-model-only",
|
| 148 |
+
action="store_true",
|
| 149 |
+
help="Download and load the selected model and tokenizer, then exit.",
|
| 150 |
+
)
|
| 151 |
+
parser.add_argument(
|
| 152 |
+
"--trust-remote-code",
|
| 153 |
+
action="store_true",
|
| 154 |
+
help="Pass trust_remote_code=True to model/tokenizer loading.",
|
| 155 |
+
)
|
| 156 |
+
parser.add_argument(
|
| 157 |
+
"--dry-run",
|
| 158 |
+
action="store_true",
|
| 159 |
+
help="Build the prompt dataset and smoke-test the reward function without training.",
|
| 160 |
+
)
|
| 161 |
+
parser.add_argument(
|
| 162 |
+
"--push-to-hub",
|
| 163 |
+
type=str,
|
| 164 |
+
default=None,
|
| 165 |
+
help="HuggingFace Hub repo id to push the trained model to (e.g. 'myuser/my-model').",
|
| 166 |
+
)
|
| 167 |
+
return parser
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
|
| 171 |
+
return build_argument_parser().parse_args(argv)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def make_training_args(**overrides: Any) -> argparse.Namespace:
|
| 175 |
+
"""Build an argparse-style namespace for notebooks and scripts."""
|
| 176 |
+
parser = build_argument_parser()
|
| 177 |
+
defaults = vars(parser.parse_args([]))
|
| 178 |
+
unknown = sorted(set(overrides) - set(defaults))
|
| 179 |
+
if unknown:
|
| 180 |
+
raise ValueError(f"Unknown training args: {', '.join(unknown)}")
|
| 181 |
+
defaults.update(overrides)
|
| 182 |
+
return argparse.Namespace(**defaults)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def format_observation(obs: ExperimentObservation) -> str:
|
| 186 |
+
parts = [
|
| 187 |
+
f"TASK: {obs.task.problem_statement}",
|
| 188 |
+
f"Organism: {obs.task.organism} | Tissue: {obs.task.tissue}",
|
| 189 |
+
f"Conditions: {', '.join(obs.task.conditions) or 'N/A'}",
|
| 190 |
+
(
|
| 191 |
+
"Step: "
|
| 192 |
+
f"{obs.step_index} | Budget: ${obs.resource_usage.budget_remaining:,.0f} "
|
| 193 |
+
f"| Time: {obs.resource_usage.time_remaining_days:.0f}d"
|
| 194 |
+
),
|
| 195 |
+
]
|
| 196 |
+
context = build_agent_observation_context(obs, max_tools=5, max_assays=2)
|
| 197 |
+
if context:
|
| 198 |
+
parts.append(context)
|
| 199 |
+
if obs.pipeline_history:
|
| 200 |
+
parts.append("History:")
|
| 201 |
+
for step in obs.pipeline_history[-5:]:
|
| 202 |
+
tag = "OK" if step.success else "FAIL"
|
| 203 |
+
line = f" [{tag}] {step.action_type.value}: {step.output_summary[:100]}"
|
| 204 |
+
if step.parameters:
|
| 205 |
+
line += f" | params={compact_preview(step.parameters, 120)}"
|
| 206 |
+
parts.append(line)
|
| 207 |
+
if obs.latest_output and obs.latest_output.data:
|
| 208 |
+
parts.append(
|
| 209 |
+
f"Latest output data: {compact_preview(obs.latest_output.data, 200)}"
|
| 210 |
+
)
|
| 211 |
+
if obs.rule_violations:
|
| 212 |
+
parts.append(f"Violations: {obs.rule_violations}")
|
| 213 |
+
if obs.discovered_markers:
|
| 214 |
+
parts.append(f"Markers: {obs.discovered_markers[:5]}")
|
| 215 |
+
if obs.candidate_mechanisms:
|
| 216 |
+
parts.append(f"Mechanisms: {obs.candidate_mechanisms[:5]}")
|
| 217 |
+
return "\n".join(parts)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def build_training_prompt(obs: ExperimentObservation) -> str:
|
| 221 |
+
return f"{SYSTEM_PROMPT}\n\n{format_observation(obs)}"
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def heuristic_next_action(history: Sequence[ActionType], step_index: int) -> ActionType:
|
| 225 |
+
seen = set(history)
|
| 226 |
+
for action in HEURISTIC_SEQUENCE:
|
| 227 |
+
if action not in seen:
|
| 228 |
+
return action
|
| 229 |
+
if step_index >= 2 and ActionType.VALIDATE_MARKER not in seen:
|
| 230 |
+
return ActionType.VALIDATE_MARKER
|
| 231 |
+
return ActionType.SYNTHESIZE_CONCLUSION
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def pick_action(policy: str, step_index: int, history: Sequence[ActionType]) -> ActionType:
|
| 235 |
+
if policy == "random":
|
| 236 |
+
return random.choice(list(ActionType))
|
| 237 |
+
return heuristic_next_action(history, step_index)
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def default_comparison_name(conditions: Sequence[str]) -> str:
|
| 241 |
+
normalized = {condition.lower() for condition in conditions}
|
| 242 |
+
if {"healthy", "ipf"} <= normalized:
|
| 243 |
+
return "IPF_vs_healthy"
|
| 244 |
+
if any("treated" in condition for condition in normalized) and any(
|
| 245 |
+
"untreated" in condition for condition in normalized
|
| 246 |
+
):
|
| 247 |
+
return "treated_vs_untreated"
|
| 248 |
+
return "disease_vs_healthy"
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def build_experiment_action(
|
| 252 |
+
action_type: ActionType,
|
| 253 |
+
discovered_markers: Sequence[str],
|
| 254 |
+
conditions: Sequence[str],
|
| 255 |
+
) -> ExperimentAction:
|
| 256 |
+
method = None
|
| 257 |
+
parameters: Dict[str, object] = {}
|
| 258 |
+
justification = f"Advance the experiment with {action_type.value}."
|
| 259 |
+
|
| 260 |
+
if action_type == ActionType.COLLECT_SAMPLE:
|
| 261 |
+
parameters = {"n_samples": 6}
|
| 262 |
+
justification = "Collect enough samples to start the experiment."
|
| 263 |
+
elif action_type == ActionType.PREPARE_LIBRARY:
|
| 264 |
+
method = "10x_chromium"
|
| 265 |
+
justification = "Prepare a single-cell library for sequencing."
|
| 266 |
+
elif action_type == ActionType.SEQUENCE_CELLS:
|
| 267 |
+
method = "NovaSeq"
|
| 268 |
+
justification = "Generate reads for downstream single-cell analysis."
|
| 269 |
+
elif action_type == ActionType.RUN_QC:
|
| 270 |
+
method = "scanpy.pp.calculate_qc_metrics"
|
| 271 |
+
justification = "Measure technical quality before filtering."
|
| 272 |
+
elif action_type == ActionType.FILTER_DATA:
|
| 273 |
+
method = "scanpy.pp.filter_cells"
|
| 274 |
+
justification = "Remove low-quality cells and technical artifacts."
|
| 275 |
+
elif action_type == ActionType.NORMALIZE_DATA:
|
| 276 |
+
method = "scanpy.pp.normalize_total"
|
| 277 |
+
justification = "Normalize counts for comparable expression profiles."
|
| 278 |
+
elif action_type == ActionType.CLUSTER_CELLS:
|
| 279 |
+
method = "scanpy.tl.leiden"
|
| 280 |
+
justification = "Resolve cell states before interpretation."
|
| 281 |
+
elif action_type == ActionType.DIFFERENTIAL_EXPRESSION:
|
| 282 |
+
method = "scanpy.tl.rank_genes_groups"
|
| 283 |
+
parameters = {"comparison": default_comparison_name(conditions)}
|
| 284 |
+
justification = "Identify genes associated with the phenotype of interest."
|
| 285 |
+
elif action_type == ActionType.TRAJECTORY_ANALYSIS:
|
| 286 |
+
method = "scanpy.tl.dpt"
|
| 287 |
+
justification = "Recover pseudotime and lineage structure."
|
| 288 |
+
elif action_type == ActionType.PATHWAY_ENRICHMENT:
|
| 289 |
+
method = "gseapy.prerank"
|
| 290 |
+
justification = "Translate gene-level changes into pathway programs."
|
| 291 |
+
elif action_type == ActionType.MARKER_SELECTION:
|
| 292 |
+
method = "scanpy.tl.rank_genes_groups"
|
| 293 |
+
justification = "Nominate marker genes for validation."
|
| 294 |
+
elif action_type == ActionType.VALIDATE_MARKER:
|
| 295 |
+
method = "qPCR"
|
| 296 |
+
parameters = {"marker": discovered_markers[0] if discovered_markers else "SPP1"}
|
| 297 |
+
justification = "Validate the strongest discovered marker."
|
| 298 |
+
elif action_type == ActionType.SYNTHESIZE_CONCLUSION:
|
| 299 |
+
top = list(discovered_markers[:5]) if discovered_markers else []
|
| 300 |
+
parameters = {
|
| 301 |
+
"claims": [{
|
| 302 |
+
"top_markers": top,
|
| 303 |
+
"causal_mechanisms": [],
|
| 304 |
+
"predicted_pathways": {},
|
| 305 |
+
"confidence": 0.6,
|
| 306 |
+
"claim_type": "correlational",
|
| 307 |
+
"claim": "",
|
| 308 |
+
}],
|
| 309 |
+
}
|
| 310 |
+
justification = "Summarize the current evidence into a conclusion."
|
| 311 |
+
|
| 312 |
+
return ExperimentAction(
|
| 313 |
+
action_type=action_type,
|
| 314 |
+
method=method,
|
| 315 |
+
parameters=parameters,
|
| 316 |
+
justification=justification,
|
| 317 |
+
confidence=0.75,
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def selected_scenarios(requested: Optional[Sequence[str]]) -> List[str]:
|
| 322 |
+
from server.tasks.procedural_generator import generate_procedural_scenarios
|
| 323 |
+
all_scenarios = list(SCENARIO_LIBRARY) + generate_procedural_scenarios(n=20, seed=42)
|
| 324 |
+
available = [scenario.name for scenario in all_scenarios]
|
| 325 |
+
if not requested:
|
| 326 |
+
return available
|
| 327 |
+
unknown = sorted(set(requested) - set(available))
|
| 328 |
+
if unknown:
|
| 329 |
+
raise ValueError(f"Unknown scenarios requested: {', '.join(unknown)}")
|
| 330 |
+
return list(requested)
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def action_completion_json(action: ExperimentAction) -> str:
|
| 334 |
+
payload = {
|
| 335 |
+
"action_type": action.action_type.value,
|
| 336 |
+
"method": action.method,
|
| 337 |
+
"parameters": action.parameters,
|
| 338 |
+
"justification": action.justification,
|
| 339 |
+
"confidence": action.confidence,
|
| 340 |
+
}
|
| 341 |
+
return json.dumps(payload)
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def build_prompt_examples(
|
| 345 |
+
*,
|
| 346 |
+
dataset_episodes: int,
|
| 347 |
+
rollout_steps: int,
|
| 348 |
+
collection_policy: str,
|
| 349 |
+
scenario_names: Sequence[str],
|
| 350 |
+
seed: int,
|
| 351 |
+
domain_randomise: bool,
|
| 352 |
+
) -> List[Dict[str, str]]:
|
| 353 |
+
rng = random.Random(seed)
|
| 354 |
+
examples: List[Dict[str, str]] = []
|
| 355 |
+
scenario_cycle = list(scenario_names)
|
| 356 |
+
rng.shuffle(scenario_cycle)
|
| 357 |
+
|
| 358 |
+
for episode_idx in range(dataset_episodes):
|
| 359 |
+
scenario_name = scenario_cycle[episode_idx % len(scenario_cycle)]
|
| 360 |
+
env = BioExperimentEnvironment(
|
| 361 |
+
scenario_name=scenario_name,
|
| 362 |
+
domain_randomise=domain_randomise,
|
| 363 |
+
)
|
| 364 |
+
obs = env.reset()
|
| 365 |
+
history_actions: List[ExperimentAction] = []
|
| 366 |
+
|
| 367 |
+
for step_idx in range(rollout_steps):
|
| 368 |
+
if obs.done:
|
| 369 |
+
break
|
| 370 |
+
|
| 371 |
+
next_action = build_experiment_action(
|
| 372 |
+
action_type=pick_action(
|
| 373 |
+
collection_policy,
|
| 374 |
+
step_idx,
|
| 375 |
+
[action.action_type for action in history_actions],
|
| 376 |
+
),
|
| 377 |
+
discovered_markers=obs.discovered_markers,
|
| 378 |
+
conditions=obs.task.conditions,
|
| 379 |
+
)
|
| 380 |
+
examples.append({
|
| 381 |
+
"prompt": build_training_prompt(obs),
|
| 382 |
+
"scenario_name": scenario_name,
|
| 383 |
+
"history_actions": json.dumps(
|
| 384 |
+
[action.model_dump() for action in history_actions]
|
| 385 |
+
),
|
| 386 |
+
"rng_seed": str(env._latent.rng_seed),
|
| 387 |
+
"reference_action": action_completion_json(next_action),
|
| 388 |
+
"problem_statement": obs.task.problem_statement,
|
| 389 |
+
})
|
| 390 |
+
|
| 391 |
+
history_actions.append(next_action)
|
| 392 |
+
obs = env.step(next_action)
|
| 393 |
+
|
| 394 |
+
return examples
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
def completion_to_text(completion: Any) -> str:
|
| 398 |
+
if isinstance(completion, str):
|
| 399 |
+
return completion.strip()
|
| 400 |
+
if isinstance(completion, dict):
|
| 401 |
+
return content_to_text(completion.get("content", ""))
|
| 402 |
+
if isinstance(completion, list):
|
| 403 |
+
for item in reversed(completion):
|
| 404 |
+
if isinstance(item, dict) and "content" in item:
|
| 405 |
+
text = content_to_text(item["content"])
|
| 406 |
+
if text:
|
| 407 |
+
return text
|
| 408 |
+
if isinstance(item, str) and item.strip():
|
| 409 |
+
return item.strip()
|
| 410 |
+
return str(completion).strip()
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def content_to_text(content: Any) -> str:
|
| 414 |
+
if isinstance(content, str):
|
| 415 |
+
return content.strip()
|
| 416 |
+
if isinstance(content, list):
|
| 417 |
+
parts: List[str] = []
|
| 418 |
+
for part in content:
|
| 419 |
+
if isinstance(part, str):
|
| 420 |
+
parts.append(part)
|
| 421 |
+
elif isinstance(part, dict):
|
| 422 |
+
if isinstance(part.get("text"), str):
|
| 423 |
+
parts.append(part["text"])
|
| 424 |
+
elif isinstance(part.get("content"), str):
|
| 425 |
+
parts.append(part["content"])
|
| 426 |
+
return "".join(parts).strip()
|
| 427 |
+
return str(content).strip()
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
def _repair_truncated_json(text: str) -> Optional[str]:
|
| 431 |
+
"""Try to repair JSON truncated mid-value (common with small LLMs)."""
|
| 432 |
+
s = text.strip()
|
| 433 |
+
if not s.startswith("{"):
|
| 434 |
+
return None
|
| 435 |
+
|
| 436 |
+
s = re.sub(r',\s*"[^"\n]*$', '', s)
|
| 437 |
+
s = re.sub(r',\s*"[^"\n]*"\s*:\s*$', '', s)
|
| 438 |
+
|
| 439 |
+
in_string = False
|
| 440 |
+
escape = False
|
| 441 |
+
for ch in s:
|
| 442 |
+
if escape:
|
| 443 |
+
escape = False
|
| 444 |
+
continue
|
| 445 |
+
if ch == "\\":
|
| 446 |
+
escape = True
|
| 447 |
+
continue
|
| 448 |
+
if ch == '"':
|
| 449 |
+
in_string = not in_string
|
| 450 |
+
|
| 451 |
+
if in_string:
|
| 452 |
+
s += '"'
|
| 453 |
+
|
| 454 |
+
open_braces = s.count("{") - s.count("}")
|
| 455 |
+
open_brackets = s.count("[") - s.count("]")
|
| 456 |
+
s += "]" * max(0, open_brackets)
|
| 457 |
+
s += "}" * max(0, open_braces)
|
| 458 |
+
|
| 459 |
+
try:
|
| 460 |
+
obj = json.loads(s)
|
| 461 |
+
if isinstance(obj, dict):
|
| 462 |
+
return s
|
| 463 |
+
except json.JSONDecodeError:
|
| 464 |
+
pass
|
| 465 |
+
|
| 466 |
+
s = re.sub(r',\s*([}\]])', r'\1', s)
|
| 467 |
+
try:
|
| 468 |
+
obj = json.loads(s)
|
| 469 |
+
if isinstance(obj, dict):
|
| 470 |
+
return s
|
| 471 |
+
except json.JSONDecodeError:
|
| 472 |
+
pass
|
| 473 |
+
return None
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
def _normalize_jsonish_text(text: str) -> str:
|
| 477 |
+
"""Normalize common near-JSON artifacts emitted by small local models."""
|
| 478 |
+
text = _strip_js_comments(text)
|
| 479 |
+
text = re.sub(r'(?<=:\s)\bNone\b', 'null', text)
|
| 480 |
+
text = re.sub(r'(?<=:\s)\bTrue\b', 'true', text)
|
| 481 |
+
text = re.sub(r'(?<=:\s)\bFalse\b', 'false', text)
|
| 482 |
+
text = re.sub(r'"([^"\n]+?):"\s*,', r'"\1": "",', text)
|
| 483 |
+
return text
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
def _strip_js_comments(text: str) -> str:
|
| 487 |
+
"""Remove // and /* */ comments that small LLMs inject into JSON."""
|
| 488 |
+
text = re.sub(r'//[^\n]*', '', text)
|
| 489 |
+
text = re.sub(r'/\*.*?\*/', '', text, flags=re.DOTALL)
|
| 490 |
+
return text
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
def extract_json_object(text: str) -> Optional[Dict[str, Any]]:
|
| 494 |
+
stripped = _normalize_jsonish_text(text).strip()
|
| 495 |
+
fence_prefix = "```"
|
| 496 |
+
if stripped.startswith(fence_prefix) and stripped.endswith(fence_prefix):
|
| 497 |
+
lines = stripped.splitlines()
|
| 498 |
+
if len(lines) >= 3:
|
| 499 |
+
stripped = "\n".join(lines[1:-1]).strip()
|
| 500 |
+
|
| 501 |
+
candidates: List[str] = [stripped]
|
| 502 |
+
start = stripped.find("{")
|
| 503 |
+
while start != -1:
|
| 504 |
+
depth = 0
|
| 505 |
+
for idx in range(start, len(stripped)):
|
| 506 |
+
char = stripped[idx]
|
| 507 |
+
if char == "{":
|
| 508 |
+
depth += 1
|
| 509 |
+
elif char == "}":
|
| 510 |
+
depth -= 1
|
| 511 |
+
if depth == 0:
|
| 512 |
+
candidates.append(stripped[start:idx + 1])
|
| 513 |
+
break
|
| 514 |
+
start = stripped.find("{", start + 1)
|
| 515 |
+
|
| 516 |
+
first_brace = stripped.find("{")
|
| 517 |
+
if first_brace != -1:
|
| 518 |
+
repaired = _repair_truncated_json(stripped[first_brace:])
|
| 519 |
+
if repaired is not None:
|
| 520 |
+
candidates.append(repaired)
|
| 521 |
+
|
| 522 |
+
candidates.sort(key=len, reverse=True)
|
| 523 |
+
|
| 524 |
+
for candidate in candidates:
|
| 525 |
+
try:
|
| 526 |
+
parsed = json.loads(candidate)
|
| 527 |
+
except json.JSONDecodeError:
|
| 528 |
+
continue
|
| 529 |
+
if isinstance(parsed, dict):
|
| 530 |
+
return parsed
|
| 531 |
+
|
| 532 |
+
return None
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
def normalize_optional_string(value: Any) -> Optional[str]:
|
| 536 |
+
if value is None or isinstance(value, bool):
|
| 537 |
+
return None
|
| 538 |
+
if isinstance(value, str):
|
| 539 |
+
value = value.strip()
|
| 540 |
+
return value or None
|
| 541 |
+
if isinstance(value, (int, float)):
|
| 542 |
+
return str(value)
|
| 543 |
+
return compact_preview(value, 80)
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
def parse_action_completion(text: str) -> Optional[ExperimentAction]:
|
| 547 |
+
payload = extract_json_object(text)
|
| 548 |
+
if payload is not None:
|
| 549 |
+
action_type = get_payload_value(payload, "action_type")
|
| 550 |
+
if action_type not in VALID_ACTION_TYPES:
|
| 551 |
+
return None
|
| 552 |
+
|
| 553 |
+
parameters = get_payload_value(payload, "parameters", "params") or {}
|
| 554 |
+
if not isinstance(parameters, dict):
|
| 555 |
+
parameters = {}
|
| 556 |
+
|
| 557 |
+
confidence = get_payload_value(payload, "confidence")
|
| 558 |
+
if confidence is None:
|
| 559 |
+
confidence = 0.5
|
| 560 |
+
try:
|
| 561 |
+
confidence = float(confidence)
|
| 562 |
+
except (TypeError, ValueError):
|
| 563 |
+
confidence = 0.5
|
| 564 |
+
|
| 565 |
+
justification = get_payload_value(
|
| 566 |
+
payload, "justification", "reasoning", "rationale", "reason"
|
| 567 |
+
)
|
| 568 |
+
if justification is not None and not isinstance(justification, str):
|
| 569 |
+
justification = compact_preview(justification, 200)
|
| 570 |
+
|
| 571 |
+
return ExperimentAction(
|
| 572 |
+
action_type=ActionType(action_type),
|
| 573 |
+
method=normalize_optional_string(get_payload_value(payload, "method")),
|
| 574 |
+
parameters=parameters,
|
| 575 |
+
justification=justification,
|
| 576 |
+
confidence=min(1.0, max(0.0, confidence)),
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
+
action_match = re.search(
|
| 580 |
+
r'["\']action_type["\']\s*:\s*["\']([^"\']+)',
|
| 581 |
+
text,
|
| 582 |
+
re.IGNORECASE,
|
| 583 |
+
)
|
| 584 |
+
if not action_match:
|
| 585 |
+
return None
|
| 586 |
+
|
| 587 |
+
action_type = action_match.group(1).strip()
|
| 588 |
+
if action_type not in VALID_ACTION_TYPES:
|
| 589 |
+
return None
|
| 590 |
+
|
| 591 |
+
method_match = re.search(
|
| 592 |
+
r'["\']method["\']\s*:\s*("((?:[^"\\]|\\.)*)"|null|none|true|false|-?\d+(?:\.\d+)?)',
|
| 593 |
+
text,
|
| 594 |
+
re.IGNORECASE,
|
| 595 |
+
)
|
| 596 |
+
confidence_match = re.search(
|
| 597 |
+
r'["\']confidence["\']\s*:\s*([0-9]*\.?[0-9]+)',
|
| 598 |
+
text,
|
| 599 |
+
re.IGNORECASE,
|
| 600 |
+
)
|
| 601 |
+
justification_match = re.search(
|
| 602 |
+
r'["\'](?:justif\w*|reasoning|rationale|reason)["\']\s*:\s*"((?:[^"\\]|\\.)*)',
|
| 603 |
+
text,
|
| 604 |
+
re.DOTALL | re.IGNORECASE,
|
| 605 |
+
)
|
| 606 |
+
|
| 607 |
+
confidence = 0.5
|
| 608 |
+
if confidence_match:
|
| 609 |
+
try:
|
| 610 |
+
confidence = float(confidence_match.group(1))
|
| 611 |
+
except ValueError:
|
| 612 |
+
confidence = 0.5
|
| 613 |
+
|
| 614 |
+
justification = None
|
| 615 |
+
if justification_match:
|
| 616 |
+
try:
|
| 617 |
+
justification = json.loads(f'"{justification_match.group(1)}"')
|
| 618 |
+
except json.JSONDecodeError:
|
| 619 |
+
justification = justification_match.group(1)
|
| 620 |
+
|
| 621 |
+
method = None
|
| 622 |
+
if method_match:
|
| 623 |
+
raw_method = method_match.group(1)
|
| 624 |
+
if raw_method.startswith('"') and raw_method.endswith('"'):
|
| 625 |
+
try:
|
| 626 |
+
method = json.loads(raw_method)
|
| 627 |
+
except json.JSONDecodeError:
|
| 628 |
+
method = raw_method.strip('"')
|
| 629 |
+
elif raw_method.lower() not in {"null", "none", "true", "false"}:
|
| 630 |
+
method = raw_method
|
| 631 |
+
|
| 632 |
+
return ExperimentAction(
|
| 633 |
+
action_type=ActionType(action_type),
|
| 634 |
+
method=normalize_optional_string(method),
|
| 635 |
+
parameters={},
|
| 636 |
+
justification=justification,
|
| 637 |
+
confidence=min(1.0, max(0.0, confidence)),
|
| 638 |
+
)
|
| 639 |
+
|
| 640 |
+
|
| 641 |
+
def decode_history_actions(history_actions: Optional[str]) -> List[ExperimentAction]:
|
| 642 |
+
if not history_actions:
|
| 643 |
+
return []
|
| 644 |
+
raw_actions = json.loads(history_actions)
|
| 645 |
+
return [
|
| 646 |
+
ExperimentAction(**action_payload)
|
| 647 |
+
for action_payload in raw_actions
|
| 648 |
+
if isinstance(action_payload, dict)
|
| 649 |
+
]
|
| 650 |
+
|
| 651 |
+
|
| 652 |
+
def normalise_column(values: Any, length: int) -> List[Any]:
|
| 653 |
+
if values is None:
|
| 654 |
+
return [None] * length
|
| 655 |
+
if isinstance(values, list):
|
| 656 |
+
if len(values) == length:
|
| 657 |
+
return values
|
| 658 |
+
if len(values) == 1:
|
| 659 |
+
return values * length
|
| 660 |
+
return values[:length] + [None] * max(0, length - len(values))
|
| 661 |
+
return [values] * length
|
| 662 |
+
|
| 663 |
+
|
| 664 |
+
class OpenEnvReward:
|
| 665 |
+
"""Reward function compatible with TRL GRPOTrainer."""
|
| 666 |
+
|
| 667 |
+
def __init__(
|
| 668 |
+
self,
|
| 669 |
+
*,
|
| 670 |
+
reward_backend: str,
|
| 671 |
+
base_url: str,
|
| 672 |
+
invalid_action_penalty: float = INVALID_ACTION_PENALTY,
|
| 673 |
+
environment_error_penalty: float = ENVIRONMENT_ERROR_PENALTY,
|
| 674 |
+
domain_randomise: bool = False,
|
| 675 |
+
) -> None:
|
| 676 |
+
self.__name__ = "openenv_reward"
|
| 677 |
+
self.reward_backend = reward_backend
|
| 678 |
+
self.base_url = base_url
|
| 679 |
+
self.invalid_action_penalty = invalid_action_penalty
|
| 680 |
+
self.environment_error_penalty = environment_error_penalty
|
| 681 |
+
self.domain_randomise = domain_randomise
|
| 682 |
+
|
| 683 |
+
def __call__(
|
| 684 |
+
self,
|
| 685 |
+
completions: List[Any],
|
| 686 |
+
scenario_name: Optional[List[str]] = None,
|
| 687 |
+
history_actions: Optional[List[str]] = None,
|
| 688 |
+
rng_seed: Optional[List[str]] = None,
|
| 689 |
+
**_: Any,
|
| 690 |
+
) -> List[float]:
|
| 691 |
+
scenario_names = normalise_column(scenario_name, len(completions))
|
| 692 |
+
history_columns = normalise_column(history_actions, len(completions))
|
| 693 |
+
seed_columns = normalise_column(rng_seed, len(completions))
|
| 694 |
+
rewards: List[float] = []
|
| 695 |
+
|
| 696 |
+
for completion, current_scenario, current_history, current_seed in zip(
|
| 697 |
+
completions,
|
| 698 |
+
scenario_names,
|
| 699 |
+
history_columns,
|
| 700 |
+
seed_columns,
|
| 701 |
+
):
|
| 702 |
+
action = parse_action_completion(completion_to_text(completion))
|
| 703 |
+
if action is None:
|
| 704 |
+
rewards.append(self.invalid_action_penalty)
|
| 705 |
+
continue
|
| 706 |
+
|
| 707 |
+
try:
|
| 708 |
+
if self.reward_backend == "remote":
|
| 709 |
+
reward = self._score_remote(action, current_scenario, current_history)
|
| 710 |
+
else:
|
| 711 |
+
reward = self._score_local(action, current_scenario, current_history, current_seed)
|
| 712 |
+
except Exception:
|
| 713 |
+
reward = self.environment_error_penalty
|
| 714 |
+
|
| 715 |
+
rewards.append(float(reward))
|
| 716 |
+
|
| 717 |
+
return rewards
|
| 718 |
+
|
| 719 |
+
def _score_local(
|
| 720 |
+
self,
|
| 721 |
+
action: ExperimentAction,
|
| 722 |
+
scenario_name: Optional[str],
|
| 723 |
+
history_actions: Optional[str],
|
| 724 |
+
rng_seed: Optional[str] = None,
|
| 725 |
+
) -> float:
|
| 726 |
+
env = BioExperimentEnvironment(
|
| 727 |
+
scenario_name=scenario_name,
|
| 728 |
+
domain_randomise=self.domain_randomise,
|
| 729 |
+
)
|
| 730 |
+
seed = int(rng_seed) if rng_seed else None
|
| 731 |
+
obs = env.reset(seed=seed)
|
| 732 |
+
for previous_action in decode_history_actions(history_actions):
|
| 733 |
+
obs = env.step(previous_action)
|
| 734 |
+
if obs.done:
|
| 735 |
+
return float(obs.reward)
|
| 736 |
+
obs = env.step(action)
|
| 737 |
+
return float(obs.reward)
|
| 738 |
+
|
| 739 |
+
def _score_remote(
|
| 740 |
+
self,
|
| 741 |
+
action: ExperimentAction,
|
| 742 |
+
scenario_name: Optional[str],
|
| 743 |
+
history_actions: Optional[str],
|
| 744 |
+
) -> float:
|
| 745 |
+
with BioExperimentEnv(base_url=self.base_url) as env:
|
| 746 |
+
# NOTE: scenario_name is accepted for API parity with _score_local
|
| 747 |
+
# but the OpenEnv HTTP protocol does not yet support passing it
|
| 748 |
+
# through reset(). The server will use its configured default.
|
| 749 |
+
result = env.reset()
|
| 750 |
+
for previous_action in decode_history_actions(history_actions):
|
| 751 |
+
result = env.step(previous_action)
|
| 752 |
+
if result.done:
|
| 753 |
+
return float(result.reward or 0.0)
|
| 754 |
+
result = env.step(action)
|
| 755 |
+
if result.reward is not None:
|
| 756 |
+
return float(result.reward)
|
| 757 |
+
return float(result.observation.reward)
|
| 758 |
+
|
| 759 |
+
|
| 760 |
+
def is_numeric_log_value(value: Any) -> bool:
|
| 761 |
+
return isinstance(value, Real) and not isinstance(value, bool)
|
| 762 |
+
|
| 763 |
+
|
| 764 |
+
def available_numeric_log_keys(log_history: Sequence[Dict[str, Any]]) -> List[str]:
|
| 765 |
+
keys = {
|
| 766 |
+
key
|
| 767 |
+
for entry in log_history
|
| 768 |
+
if isinstance(entry, dict)
|
| 769 |
+
for key, value in entry.items()
|
| 770 |
+
if key != "step" and is_numeric_log_value(value)
|
| 771 |
+
}
|
| 772 |
+
return sorted(keys)
|
| 773 |
+
|
| 774 |
+
|
| 775 |
+
def extract_log_series(
|
| 776 |
+
log_history: Sequence[Dict[str, Any]],
|
| 777 |
+
key: Optional[str],
|
| 778 |
+
) -> List[Tuple[float, float]]:
|
| 779 |
+
if not key:
|
| 780 |
+
return []
|
| 781 |
+
|
| 782 |
+
series: List[Tuple[float, float]] = []
|
| 783 |
+
synthetic_step = 0
|
| 784 |
+
for entry in log_history:
|
| 785 |
+
if not isinstance(entry, dict) or key not in entry:
|
| 786 |
+
continue
|
| 787 |
+
value = entry.get(key)
|
| 788 |
+
if not is_numeric_log_value(value):
|
| 789 |
+
continue
|
| 790 |
+
|
| 791 |
+
raw_step = entry.get("step")
|
| 792 |
+
if is_numeric_log_value(raw_step):
|
| 793 |
+
step = float(raw_step)
|
| 794 |
+
else:
|
| 795 |
+
synthetic_step += 1
|
| 796 |
+
step = float(synthetic_step)
|
| 797 |
+
|
| 798 |
+
series.append((step, float(value)))
|
| 799 |
+
|
| 800 |
+
return series
|
| 801 |
+
|
| 802 |
+
|
| 803 |
+
def select_reward_key(log_history: Sequence[Dict[str, Any]]) -> Optional[str]:
|
| 804 |
+
numeric_keys = available_numeric_log_keys(log_history)
|
| 805 |
+
reward_keys = [key for key in numeric_keys if "reward" in key.lower()]
|
| 806 |
+
if not reward_keys:
|
| 807 |
+
return None
|
| 808 |
+
|
| 809 |
+
preferred = [
|
| 810 |
+
"reward",
|
| 811 |
+
"mean_reward",
|
| 812 |
+
"reward_mean",
|
| 813 |
+
"rewards/open_env_reward",
|
| 814 |
+
]
|
| 815 |
+
lowered = {key.lower(): key for key in reward_keys}
|
| 816 |
+
for key in preferred:
|
| 817 |
+
if key in lowered:
|
| 818 |
+
return lowered[key]
|
| 819 |
+
|
| 820 |
+
reward_keys.sort(key=lambda key: ("/" in key, len(key), key))
|
| 821 |
+
return reward_keys[0]
|
| 822 |
+
|
| 823 |
+
|
| 824 |
+
def select_metric_key(
|
| 825 |
+
log_history: Sequence[Dict[str, Any]],
|
| 826 |
+
*,
|
| 827 |
+
reward_key: Optional[str],
|
| 828 |
+
requested_key: Optional[str] = None,
|
| 829 |
+
) -> Optional[str]:
|
| 830 |
+
numeric_keys = available_numeric_log_keys(log_history)
|
| 831 |
+
if requested_key:
|
| 832 |
+
if requested_key not in numeric_keys:
|
| 833 |
+
available = ", ".join(numeric_keys) or "none"
|
| 834 |
+
raise ValueError(
|
| 835 |
+
f"Requested plot metric '{requested_key}' was not logged. "
|
| 836 |
+
f"Available numeric keys: {available}"
|
| 837 |
+
)
|
| 838 |
+
return requested_key
|
| 839 |
+
|
| 840 |
+
excluded = {
|
| 841 |
+
"epoch",
|
| 842 |
+
"loss",
|
| 843 |
+
"learning_rate",
|
| 844 |
+
"step",
|
| 845 |
+
"total_flos",
|
| 846 |
+
"train_loss",
|
| 847 |
+
"train_runtime",
|
| 848 |
+
"train_samples_per_second",
|
| 849 |
+
"train_steps_per_second",
|
| 850 |
+
}
|
| 851 |
+
if reward_key:
|
| 852 |
+
excluded.add(reward_key)
|
| 853 |
+
|
| 854 |
+
preferred = [
|
| 855 |
+
"kl",
|
| 856 |
+
"objective/kl",
|
| 857 |
+
"completion_length",
|
| 858 |
+
"mean_completion_length",
|
| 859 |
+
"grad_norm",
|
| 860 |
+
"entropy",
|
| 861 |
+
"accuracy",
|
| 862 |
+
"learning_rate",
|
| 863 |
+
"epoch",
|
| 864 |
+
]
|
| 865 |
+
numeric_set = set(numeric_keys)
|
| 866 |
+
for key in preferred:
|
| 867 |
+
if key in numeric_set and key not in excluded:
|
| 868 |
+
return key
|
| 869 |
+
|
| 870 |
+
candidates = [
|
| 871 |
+
key for key in numeric_keys
|
| 872 |
+
if key not in excluded and "reward" not in key.lower()
|
| 873 |
+
]
|
| 874 |
+
if candidates:
|
| 875 |
+
return candidates[0]
|
| 876 |
+
|
| 877 |
+
for fallback in ("learning_rate", "epoch"):
|
| 878 |
+
if fallback in numeric_set:
|
| 879 |
+
return fallback
|
| 880 |
+
|
| 881 |
+
return None
|
| 882 |
+
|
| 883 |
+
|
| 884 |
+
def save_plot(
|
| 885 |
+
path: Path,
|
| 886 |
+
*,
|
| 887 |
+
series: Sequence[Tuple[float, float]],
|
| 888 |
+
title: str,
|
| 889 |
+
ylabel: str,
|
| 890 |
+
) -> None:
|
| 891 |
+
import matplotlib
|
| 892 |
+
|
| 893 |
+
matplotlib.use("Agg")
|
| 894 |
+
import matplotlib.pyplot as plt
|
| 895 |
+
|
| 896 |
+
fig, ax = plt.subplots(figsize=(8, 4.5))
|
| 897 |
+
if series:
|
| 898 |
+
x_values, y_values = zip(*series)
|
| 899 |
+
ax.plot(x_values, y_values, marker="o", linewidth=1.8)
|
| 900 |
+
else:
|
| 901 |
+
ax.text(
|
| 902 |
+
0.5,
|
| 903 |
+
0.5,
|
| 904 |
+
"No logged data available",
|
| 905 |
+
ha="center",
|
| 906 |
+
va="center",
|
| 907 |
+
transform=ax.transAxes,
|
| 908 |
+
)
|
| 909 |
+
ax.set_title(title)
|
| 910 |
+
ax.set_xlabel("Step")
|
| 911 |
+
ax.set_ylabel(ylabel)
|
| 912 |
+
ax.grid(True, alpha=0.3)
|
| 913 |
+
fig.tight_layout()
|
| 914 |
+
fig.savefig(path, dpi=150)
|
| 915 |
+
plt.close(fig)
|
| 916 |
+
|
| 917 |
+
|
| 918 |
+
def save_training_plots(
|
| 919 |
+
log_history: Sequence[Dict[str, Any]],
|
| 920 |
+
output_dir: str | Path,
|
| 921 |
+
metric_key: Optional[str] = None,
|
| 922 |
+
) -> Dict[str, str]:
|
| 923 |
+
import matplotlib
|
| 924 |
+
|
| 925 |
+
matplotlib.use("Agg")
|
| 926 |
+
import matplotlib.pyplot as plt
|
| 927 |
+
|
| 928 |
+
output_path = Path(output_dir)
|
| 929 |
+
output_path.mkdir(parents=True, exist_ok=True)
|
| 930 |
+
|
| 931 |
+
reward_key = select_reward_key(log_history)
|
| 932 |
+
selected_metric_key = select_metric_key(
|
| 933 |
+
log_history,
|
| 934 |
+
reward_key=reward_key,
|
| 935 |
+
requested_key=metric_key,
|
| 936 |
+
)
|
| 937 |
+
|
| 938 |
+
loss_series = extract_log_series(log_history, "loss")
|
| 939 |
+
reward_series = extract_log_series(log_history, reward_key)
|
| 940 |
+
metric_series = extract_log_series(log_history, selected_metric_key)
|
| 941 |
+
|
| 942 |
+
loss_path = output_path / "training_loss.png"
|
| 943 |
+
reward_path = output_path / "training_reward.png"
|
| 944 |
+
metric_path = output_path / "training_metric.png"
|
| 945 |
+
dashboard_path = output_path / "training_dashboard.png"
|
| 946 |
+
manifest_path = output_path / "training_plot_manifest.json"
|
| 947 |
+
|
| 948 |
+
save_plot(loss_path, series=loss_series, title="Training Loss", ylabel="Loss")
|
| 949 |
+
save_plot(
|
| 950 |
+
reward_path,
|
| 951 |
+
series=reward_series,
|
| 952 |
+
title=f"Training Reward ({reward_key or 'not logged'})",
|
| 953 |
+
ylabel="Reward",
|
| 954 |
+
)
|
| 955 |
+
save_plot(
|
| 956 |
+
metric_path,
|
| 957 |
+
series=metric_series,
|
| 958 |
+
title=f"Training Metric ({selected_metric_key or 'not logged'})",
|
| 959 |
+
ylabel=selected_metric_key or "Metric",
|
| 960 |
+
)
|
| 961 |
+
|
| 962 |
+
fig, axes = plt.subplots(3, 1, figsize=(10, 12))
|
| 963 |
+
plot_specs = [
|
| 964 |
+
(axes[0], loss_series, "Training Loss", "Loss"),
|
| 965 |
+
(axes[1], reward_series, f"Training Reward ({reward_key or 'not logged'})", "Reward"),
|
| 966 |
+
(
|
| 967 |
+
axes[2],
|
| 968 |
+
metric_series,
|
| 969 |
+
f"Training Metric ({selected_metric_key or 'not logged'})",
|
| 970 |
+
selected_metric_key or "Metric",
|
| 971 |
+
),
|
| 972 |
+
]
|
| 973 |
+
for axis, series, title, ylabel in plot_specs:
|
| 974 |
+
if series:
|
| 975 |
+
x_values, y_values = zip(*series)
|
| 976 |
+
axis.plot(x_values, y_values, marker="o", linewidth=1.8)
|
| 977 |
+
else:
|
| 978 |
+
axis.text(
|
| 979 |
+
0.5,
|
| 980 |
+
0.5,
|
| 981 |
+
"No logged data available",
|
| 982 |
+
ha="center",
|
| 983 |
+
va="center",
|
| 984 |
+
transform=axis.transAxes,
|
| 985 |
+
)
|
| 986 |
+
axis.set_title(title)
|
| 987 |
+
axis.set_xlabel("Step")
|
| 988 |
+
axis.set_ylabel(ylabel)
|
| 989 |
+
axis.grid(True, alpha=0.3)
|
| 990 |
+
fig.tight_layout()
|
| 991 |
+
fig.savefig(dashboard_path, dpi=150)
|
| 992 |
+
plt.close(fig)
|
| 993 |
+
|
| 994 |
+
manifest = {
|
| 995 |
+
"available_numeric_keys": available_numeric_log_keys(log_history),
|
| 996 |
+
"reward_key": reward_key,
|
| 997 |
+
"metric_key": selected_metric_key,
|
| 998 |
+
"plots": {
|
| 999 |
+
"loss": str(loss_path),
|
| 1000 |
+
"reward": str(reward_path),
|
| 1001 |
+
"metric": str(metric_path),
|
| 1002 |
+
"dashboard": str(dashboard_path),
|
| 1003 |
+
},
|
| 1004 |
+
}
|
| 1005 |
+
manifest_path.write_text(json.dumps(manifest, indent=2), encoding="utf-8")
|
| 1006 |
+
return manifest["plots"]
|
| 1007 |
+
|
| 1008 |
+
|
| 1009 |
+
def run_dry_run_preview(
|
| 1010 |
+
examples: Sequence[Dict[str, str]],
|
| 1011 |
+
reward_fn: OpenEnvReward,
|
| 1012 |
+
output_dir: str,
|
| 1013 |
+
) -> None:
|
| 1014 |
+
if not examples:
|
| 1015 |
+
raise ValueError("No training prompts were generated for the dry run.")
|
| 1016 |
+
|
| 1017 |
+
sample = examples[0]
|
| 1018 |
+
sample_reward = reward_fn(
|
| 1019 |
+
completions=[[{"role": "assistant", "content": sample["reference_action"]}]],
|
| 1020 |
+
scenario_name=[sample["scenario_name"]],
|
| 1021 |
+
history_actions=[sample["history_actions"]],
|
| 1022 |
+
)[0]
|
| 1023 |
+
|
| 1024 |
+
print(f"Built {len(examples)} prompt states.")
|
| 1025 |
+
print(f"Output directory: {Path(output_dir)}")
|
| 1026 |
+
print(f"Sample scenario: {sample['scenario_name']}")
|
| 1027 |
+
print(f"Sample reward for reference action: {sample_reward:+.3f}")
|
| 1028 |
+
print("\nSample prompt:\n")
|
| 1029 |
+
print(sample["prompt"])
|
| 1030 |
+
|
| 1031 |
+
|
| 1032 |
+
def resolve_torch_runtime() -> Dict[str, Any]:
|
| 1033 |
+
import torch
|
| 1034 |
+
|
| 1035 |
+
use_cuda = torch.cuda.is_available()
|
| 1036 |
+
bf16 = bool(getattr(torch.cuda, "is_bf16_supported", lambda: False)()) if use_cuda else False
|
| 1037 |
+
dtype = torch.bfloat16 if bf16 else (
|
| 1038 |
+
torch.float16 if use_cuda else torch.float32
|
| 1039 |
+
)
|
| 1040 |
+
return {
|
| 1041 |
+
"use_cuda": use_cuda,
|
| 1042 |
+
"device": "cuda:0" if use_cuda else "cpu",
|
| 1043 |
+
"dtype": dtype,
|
| 1044 |
+
"bf16": bf16,
|
| 1045 |
+
"fp16": use_cuda and not bf16,
|
| 1046 |
+
"device_name": torch.cuda.get_device_name(0) if use_cuda else "cpu",
|
| 1047 |
+
}
|
| 1048 |
+
|
| 1049 |
+
|
| 1050 |
+
def load_model_artifacts(
|
| 1051 |
+
model_id: str,
|
| 1052 |
+
*,
|
| 1053 |
+
trust_remote_code: bool,
|
| 1054 |
+
):
|
| 1055 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 1056 |
+
|
| 1057 |
+
runtime = resolve_torch_runtime()
|
| 1058 |
+
print(f"Loading tokenizer for {model_id} ...")
|
| 1059 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 1060 |
+
model_id,
|
| 1061 |
+
trust_remote_code=trust_remote_code,
|
| 1062 |
+
)
|
| 1063 |
+
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
|
| 1064 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 1065 |
+
|
| 1066 |
+
print(f"Loading model for {model_id} ...")
|
| 1067 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 1068 |
+
model_id,
|
| 1069 |
+
trust_remote_code=trust_remote_code,
|
| 1070 |
+
torch_dtype=runtime["dtype"],
|
| 1071 |
+
)
|
| 1072 |
+
if runtime["use_cuda"]:
|
| 1073 |
+
model = model.to(runtime["device"])
|
| 1074 |
+
else:
|
| 1075 |
+
model = model.to("cpu")
|
| 1076 |
+
return tokenizer, model
|
| 1077 |
+
|
| 1078 |
+
|
| 1079 |
+
def generate_action_with_model(
|
| 1080 |
+
model: Any,
|
| 1081 |
+
tokenizer: Any,
|
| 1082 |
+
prompt_or_observation: str | ExperimentObservation,
|
| 1083 |
+
*,
|
| 1084 |
+
max_new_tokens: int = 220,
|
| 1085 |
+
temperature: float = 0.2,
|
| 1086 |
+
top_p: float = 0.9,
|
| 1087 |
+
do_sample: bool = True,
|
| 1088 |
+
) -> Dict[str, Any]:
|
| 1089 |
+
import torch
|
| 1090 |
+
|
| 1091 |
+
if isinstance(prompt_or_observation, ExperimentObservation):
|
| 1092 |
+
prompt = build_training_prompt(prompt_or_observation)
|
| 1093 |
+
else:
|
| 1094 |
+
prompt = str(prompt_or_observation)
|
| 1095 |
+
|
| 1096 |
+
model_device = getattr(model, "device", None)
|
| 1097 |
+
if model_device is None:
|
| 1098 |
+
model_device = resolve_torch_runtime()["device"]
|
| 1099 |
+
|
| 1100 |
+
inputs = tokenizer(prompt, return_tensors="pt")
|
| 1101 |
+
inputs = {key: value.to(model_device) for key, value in inputs.items()}
|
| 1102 |
+
prompt_tokens = inputs["input_ids"].shape[1]
|
| 1103 |
+
|
| 1104 |
+
generation_kwargs = {
|
| 1105 |
+
"max_new_tokens": max_new_tokens,
|
| 1106 |
+
"do_sample": do_sample,
|
| 1107 |
+
"temperature": temperature,
|
| 1108 |
+
"top_p": top_p,
|
| 1109 |
+
"pad_token_id": tokenizer.pad_token_id,
|
| 1110 |
+
}
|
| 1111 |
+
with torch.no_grad():
|
| 1112 |
+
output_ids = model.generate(**inputs, **generation_kwargs)
|
| 1113 |
+
|
| 1114 |
+
new_tokens = output_ids[0][prompt_tokens:]
|
| 1115 |
+
response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
|
| 1116 |
+
action = parse_action_completion(response_text)
|
| 1117 |
+
return {
|
| 1118 |
+
"prompt": prompt,
|
| 1119 |
+
"response_text": response_text,
|
| 1120 |
+
"action": action,
|
| 1121 |
+
}
|
| 1122 |
+
|
| 1123 |
+
|
| 1124 |
+
def run_training(args: argparse.Namespace) -> Dict[str, Any]:
|
| 1125 |
+
random.seed(args.seed)
|
| 1126 |
+
runtime = resolve_torch_runtime()
|
| 1127 |
+
|
| 1128 |
+
if args.load_model_only:
|
| 1129 |
+
tokenizer, model = load_model_artifacts(
|
| 1130 |
+
args.model_id,
|
| 1131 |
+
trust_remote_code=args.trust_remote_code,
|
| 1132 |
+
)
|
| 1133 |
+
device = getattr(model, "device", "unknown")
|
| 1134 |
+
print(f"Model ready: {args.model_id}")
|
| 1135 |
+
print(f"Tokenizer vocab size: {len(tokenizer)}")
|
| 1136 |
+
print(f"Model device: {device}")
|
| 1137 |
+
print(f"Runtime device name: {runtime['device_name']}")
|
| 1138 |
+
return {
|
| 1139 |
+
"args": args,
|
| 1140 |
+
"runtime": runtime,
|
| 1141 |
+
"tokenizer": tokenizer,
|
| 1142 |
+
"model": model,
|
| 1143 |
+
}
|
| 1144 |
+
|
| 1145 |
+
scenario_names = selected_scenarios(args.scenario_name)
|
| 1146 |
+
examples = build_prompt_examples(
|
| 1147 |
+
dataset_episodes=args.dataset_episodes,
|
| 1148 |
+
rollout_steps=args.rollout_steps,
|
| 1149 |
+
collection_policy=args.collection_policy,
|
| 1150 |
+
scenario_names=scenario_names,
|
| 1151 |
+
seed=args.seed,
|
| 1152 |
+
domain_randomise=args.domain_randomise,
|
| 1153 |
+
)
|
| 1154 |
+
reward_fn = OpenEnvReward(
|
| 1155 |
+
reward_backend=args.reward_backend,
|
| 1156 |
+
base_url=args.base_url,
|
| 1157 |
+
domain_randomise=args.domain_randomise,
|
| 1158 |
+
)
|
| 1159 |
+
|
| 1160 |
+
if args.dry_run:
|
| 1161 |
+
run_dry_run_preview(examples, reward_fn, args.output_dir)
|
| 1162 |
+
return {
|
| 1163 |
+
"args": args,
|
| 1164 |
+
"runtime": runtime,
|
| 1165 |
+
"scenario_names": scenario_names,
|
| 1166 |
+
"examples": examples,
|
| 1167 |
+
"reward_fn": reward_fn,
|
| 1168 |
+
}
|
| 1169 |
+
|
| 1170 |
+
from datasets import Dataset
|
| 1171 |
+
from trl import GRPOConfig, GRPOTrainer
|
| 1172 |
+
|
| 1173 |
+
train_dataset = Dataset.from_list(examples)
|
| 1174 |
+
tokenizer, model = load_model_artifacts(
|
| 1175 |
+
args.model_id,
|
| 1176 |
+
trust_remote_code=args.trust_remote_code,
|
| 1177 |
+
)
|
| 1178 |
+
config = GRPOConfig(
|
| 1179 |
+
output_dir=args.output_dir,
|
| 1180 |
+
learning_rate=args.learning_rate,
|
| 1181 |
+
per_device_train_batch_size=args.per_device_train_batch_size,
|
| 1182 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 1183 |
+
num_generations=args.num_generations,
|
| 1184 |
+
max_completion_length=args.max_completion_length,
|
| 1185 |
+
num_train_epochs=args.num_train_epochs,
|
| 1186 |
+
logging_steps=args.logging_steps,
|
| 1187 |
+
save_steps=args.save_steps,
|
| 1188 |
+
bf16=runtime["bf16"],
|
| 1189 |
+
fp16=runtime["fp16"],
|
| 1190 |
+
report_to="none",
|
| 1191 |
+
remove_unused_columns=False,
|
| 1192 |
+
)
|
| 1193 |
+
|
| 1194 |
+
print(
|
| 1195 |
+
f"Training runtime: device={runtime['device']} "
|
| 1196 |
+
f"name={runtime['device_name']} "
|
| 1197 |
+
f"dtype={runtime['dtype']}"
|
| 1198 |
+
)
|
| 1199 |
+
|
| 1200 |
+
trainer = GRPOTrainer(
|
| 1201 |
+
model=model,
|
| 1202 |
+
reward_funcs=reward_fn,
|
| 1203 |
+
args=config,
|
| 1204 |
+
train_dataset=train_dataset,
|
| 1205 |
+
processing_class=tokenizer,
|
| 1206 |
+
)
|
| 1207 |
+
trainer.train()
|
| 1208 |
+
trainer.save_model(args.output_dir)
|
| 1209 |
+
tokenizer.save_pretrained(args.output_dir)
|
| 1210 |
+
if args.push_to_hub:
|
| 1211 |
+
from huggingface_hub import HfApi
|
| 1212 |
+
api = HfApi()
|
| 1213 |
+
api.create_repo(repo_id=args.push_to_hub, repo_type="model", exist_ok=True)
|
| 1214 |
+
print(f"Pushing model to HuggingFace Hub: {args.push_to_hub}")
|
| 1215 |
+
api.upload_folder(
|
| 1216 |
+
folder_path=args.output_dir,
|
| 1217 |
+
repo_id=args.push_to_hub,
|
| 1218 |
+
repo_type="model",
|
| 1219 |
+
create_pr=False,
|
| 1220 |
+
)
|
| 1221 |
+
print(f"Model pushed to https://huggingface.co/{args.push_to_hub}")
|
| 1222 |
+
plot_paths = save_training_plots(
|
| 1223 |
+
trainer.state.log_history,
|
| 1224 |
+
args.output_dir,
|
| 1225 |
+
metric_key=args.plot_metric_key,
|
| 1226 |
+
)
|
| 1227 |
+
print("Saved training plots:")
|
| 1228 |
+
for plot_name, plot_path in plot_paths.items():
|
| 1229 |
+
print(f" - {plot_name}: {plot_path}")
|
| 1230 |
+
|
| 1231 |
+
return {
|
| 1232 |
+
"args": args,
|
| 1233 |
+
"runtime": runtime,
|
| 1234 |
+
"scenario_names": scenario_names,
|
| 1235 |
+
"examples": examples,
|
| 1236 |
+
"reward_fn": reward_fn,
|
| 1237 |
+
"train_dataset": train_dataset,
|
| 1238 |
+
"tokenizer": tokenizer,
|
| 1239 |
+
"model": model,
|
| 1240 |
+
"trainer": trainer,
|
| 1241 |
+
"plot_paths": plot_paths,
|
| 1242 |
+
}
|
| 1243 |
+
|
| 1244 |
+
|
| 1245 |
+
def main() -> None:
|
| 1246 |
+
run_training(parse_args())
|
| 1247 |
+
|
| 1248 |
+
|
| 1249 |
+
if __name__ == "__main__":
|
| 1250 |
+
main()
|
uv.lock
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|