Ev3Dev committed 5c3cfae (verified · 1 parent: 4db0438): Upload folder using huggingface_hub
README.md CHANGED
@@ -21,7 +21,7 @@ The environment is designed as a partially observable Markov decision process (P
21
  - visible task metadata, resource usage, step history, and intermediate outputs
22
  - dense step-wise reward plus terminal reward for conclusion quality
23
 
24
- ## What "how it works" means here
25
 
26
  At a high level, each episode looks like this:
27
 
@@ -68,11 +68,15 @@ This separation is what makes the environment a POMDP rather than a fully observ
68
 
69
  Defines the contracts that all other modules use:
70
 
71
- - `ExperimentAction`: one structured step chosen by the agent
72
- - `ExperimentObservation`: what the agent can see after each step
73
- - `TaskSpec`: the problem statement, budget, time limit, assays, tools, and expected findings
74
- - `IntermediateOutput`: the simulated artifact returned by a step
75
- - `ConclusionClaim`: structured claims used for final synthesis
76
 
77
  The action vocabulary is intentionally broad enough to mix wet-lab, computational, and meta-planning actions.
78
 
@@ -80,27 +84,30 @@ The action vocabulary is intentionally broad enough to mix wet-lab, computationa
80
 
81
  This is where episodes come from.
82
 
83
- - `scenarios.py` defines a small library of curated biological scenarios
84
- - `generator.py` turns a scenario into a `(TaskSpec, FullLatentState)` pair
85
- - optional domain randomization perturbs budget, time, noise, batch effects, cell proportions, and effect sizes
86
 
87
- Right now the scenario library includes:
88
 
89
- - `cardiac_disease_de`: disease vs healthy differential expression in heart tissue
90
- - `hematopoiesis_trajectory`: developmental trajectory inference in bone marrow
91
- - `perturbation_immune`: treatment response under JAK inhibition
92
- - `biomarker_validation_lung`: follow-up validation of `SPP1` in IPF
93
 
94
  ### `server/simulator/`
95
 
96
  This is the simulator itself.
97
 
98
- - `latent_state.py` defines hidden biological, technical, progress, and resource state
99
- - `noise.py` centralizes stochasticity so episodes are reproducible from a seed
100
- - `output_generator.py` turns an action plus hidden state into a realistic `IntermediateOutput`
101
- - `transition.py` applies action costs, updates progress flags, propagates artifacts, and decides whether the episode is done
102
 
103
- The output generator does not simply echo the action. It conditions outputs on the hidden state, then injects realistic noise such as dropout, false positives, false negatives, and imperfect clustering.
104
 
105
  ### `server/rules/engine.py`
106
 
@@ -109,32 +116,41 @@ The rule engine enforces scientific and procedural constraints before each actio
109
  - hard violations block the action entirely
110
  - soft violations allow the action, but reduce output quality and add reward penalties
111
 
112
- Examples:
113
 
114
- - sequencing before library prep is a hard violation
115
- - running QC twice is a soft redundancy violation
116
- - making causal claims without enough evidence is a soft validity violation
 
117
 
118
  ### `server/rewards/reward.py`
119
 
120
  Rewards are decomposed rather than being a single opaque number.
121
 
122
- Per-step reward includes:
123
 
124
- - validity
125
- - ordering
126
- - information gain
127
- - efficiency
128
- - novelty
129
- - penalties
130
- - potential-based shaping
131
 
132
  Terminal reward adds:
133
 
134
- - pipeline completeness
135
- - calibration of conclusions against hidden truth
136
- - remaining budget and time efficiency
137
- - overconfidence penalties for strong but incorrect claims
138
 
139
  This makes the environment easier to debug, benchmark, and train against.
140
 
@@ -145,7 +161,7 @@ This is the orchestration layer that ties everything together.
145
  On `reset()` it:
146
 
147
  - seeds the noise model
148
- - generates a task and latent state
149
  - clears history, outputs, discoveries, conclusions, and cumulative reward
150
 
151
  On `step()` it:
@@ -171,12 +187,34 @@ Here is the concrete order of operations for `env.step(action)`:
171
  5. Otherwise deduct budget and time based on `ACTION_COSTS`.
172
  6. Update latent progress flags like `samples_collected`, `qc_performed`, or `de_performed`.
173
  7. Generate a structured simulated output for the chosen action.
174
- 8. If there were soft violations, degrade output quality and attach warnings.
175
  9. Propagate artifacts back into latent state, such as discovered DE genes or cluster names.
176
  10. Compute decomposed reward from state transition plus output quality.
177
  11. If the episode is ending, compute terminal reward from completeness and conclusion calibration.
178
  12. Return an observation that exposes the visible summary but not the hidden truth.
179
 
180
  ## Typical successful pipeline
181
 
182
  Most scenarios reward a sensible experiment order similar to:
@@ -193,12 +231,47 @@ Most scenarios reward a sensible experiment order similar to:
193
  `regulatory_network_inference`, `marker_selection`, `validate_marker`
194
  9. `synthesize_conclusion`
195
 
196
- The exact best sequence depends on the scenario. For example:
197
 
198
  - trajectory scenarios benefit from `trajectory_analysis` and regulatory inference
199
  - biomarker scenarios benefit from DE, marker selection, and validation
200
  - perturbation scenarios benefit from pathway-level interpretation
201
 
202
  ## Interfaces you can use
203
 
204
  ### 1. In-process environment
@@ -224,6 +297,10 @@ print(obs.latest_output.summary if obs.latest_output else "No output yet")
224
  print(obs.reward)
225
  ```
226
 
227
  ### 2. OpenEnv client/server mode
228
 
229
  Use the FastAPI app when you want to serve the environment over HTTP and WebSocket:
@@ -233,6 +310,16 @@ uv sync --extra dev
233
  uv run uvicorn server.app:app --reload
234
  ```
235
 
236
  Then connect with the client:
237
 
238
  ```python
@@ -247,40 +334,133 @@ with BioExperimentEnv(base_url="http://localhost:8000") as env:
247
 
248
  The environment class supports concurrent sessions, but the bundled server is currently configured with `max_concurrent_envs=1` in `server/app.py`.
249
 
250
- ### 3. Gymnasium wrapper
251
 
252
- Use `training/gym_wrapper.py` when you want a classic RL interface:
253
 
254
- ```python
255
- from training.gym_wrapper import BioExperimentGymEnv
256
-
257
- env = BioExperimentGymEnv()
258
- obs, info = env.reset()
259
- obs, reward, terminated, truncated, info = env.step({
260
- "action_type": 0,
261
- "confidence": 0.7,
262
- })
263
  ```
264
 
265
- This wrapper vectorizes the structured observation into arrays and reduces the action interface to:
266
 
267
- - a discrete action type index
268
- - a scalar confidence value
269
 
270
- ### 4. Benchmark and scripted agents
271
 
272
  - `training/literature_benchmark.py` runs paper-aligned action sequences and compares outcomes against curated expected findings
273
  - `run_agent.py` runs a local language model planner against the environment
274
  - `training/trajectory.py` stores trajectories for offline RL, imitation learning, replay, and evaluation
275
  - `training/evaluation.py` computes online, benchmark, expert-review, and fidelity-oriented metrics
276
 
277
- ## Episode termination
278
 
279
- An episode ends when one of the following happens:
280
 
281
- - the agent chooses `synthesize_conclusion`
282
- - resources are exhausted
283
- - the environment reaches `MAX_STEPS` which is currently `30`
284
 
285
  ## Why this is useful
286
 
@@ -299,31 +479,51 @@ That makes it suitable for:
299
  - literature-grounded evaluation
300
  - comparing structured policies against LLM-driven planners
301
 
302
- ## Minimal project map
303
 
304
  ```text
305
  .
306
- ├── client.py # OpenEnv client
307
  ├── models.py # Shared action / observation / task schemas
308
  ├── server/
309
- │ ├── app.py # FastAPI/OpenEnv server
 
310
  │ ├── hackathon_environment.py # Main environment orchestration
311
- │ ├── rewards/ # Reward model
312
- │ ├── rules/ # Constraint checking
313
- │ ├── simulator/ # Latent state, noise, outputs, transitions
314
- │ ├── tasks/ # Scenario library and task generation
315
  ├── training/
316
- │ ├── evaluation.py # Metrics
317
- │ ├── gym_wrapper.py # Gymnasium wrapper
318
  │ ├── literature_benchmark.py # Paper-backed benchmark flow
319
- │ └── trajectory.py # Trajectory serialization
320
- └── tests/ # Unit and integration tests
321
  ```
322
 
323
  ## Quick sanity check
324
 
325
- The current implementation was sanity-checked with:
326
-
327
  ```bash
328
  uv run pytest tests/test_environment.py tests/test_literature_benchmark.py -q
329
  ```
@@ -335,3 +535,9 @@ Those tests verify:
335
  - conclusion termination
336
  - literature-backed scenario selection
337
  - benchmark matching for curated expected findings
21
  - visible task metadata, resource usage, step history, and intermediate outputs
22
  - dense step-wise reward plus terminal reward for conclusion quality
23
 
24
+ ## How it works
25
 
26
  At a high level, each episode looks like this:
27
 
 
68
 
69
  Defines the contracts that all other modules use:
70
 
71
+ - `ActionType`: 21 discrete experiment steps, grouped into three frozensets — `WET_LAB_ACTIONS` (8), `COMPUTATIONAL_ACTIONS` (10), and `META_ACTIONS` (3)
72
+ - `SubagentType`: 9 sub-agent delegate roles (e.g. `wet_lab_planner`, `computational_analyst`, `causal_reasoning_agent`)
73
+ - `ExperimentAction`: one structured step chosen by the agent; fields include `action_type`, `method`, `parameters`, `justification`, `confidence` (clamped to `[0, 1]`), `invoked_subagent`, `tool_call_spec`, `input_targets`
74
+ - `ExperimentObservation`: what the agent can see after each step; includes `task`, `pipeline_history`, `resource_usage`, `latest_output`, `all_outputs`, `discovered_markers`, `candidate_mechanisms`, `conclusions`, `rule_violations`, `step_reward_breakdown`
75
+ - `TaskSpec`: the problem statement, organism, tissue, conditions, budget, time limit, assays, tools, paper references, and expected findings
76
+ - `IntermediateOutput`: the simulated artifact returned by a step; carries `output_type`, `success`, `quality_score`, `summary`, `data`, `uncertainty`, `warnings`, `artifacts_available`
77
+ - `ConclusionClaim`: structured claims used for final synthesis; carries `claim`, `evidence_steps`, `confidence`, `claim_type`, `supporting_data`
78
+ - `PipelineStepRecord`: compact observable record of one past step stored in history
79
+ - `ResourceUsage`: budget and time tracking visible to the agent
80
 
81
  The action vocabulary is intentionally broad enough to mix wet-lab, computational, and meta-planning actions.
82
 
 
84
 
85
  This is where episodes come from.
86
 
87
+ - `scenarios.py` defines a curated library of four biological scenarios as `Scenario` dataclass objects, each bundling a `TaskSpec`, a `LatentBiologicalState`, a `TechnicalState`, hidden failure conditions, and tags
88
+ - `generator.py` turns a scenario into a `(TaskSpec, FullLatentState)` pair via `TaskGenerator.generate()`; optional domain randomisation perturbs budget (±30%), time (±20%), technical noise, batch effects, cell proportions, and effect sizes
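The randomisation ranges above can be sketched as follows; the helper name `randomise_task` is illustrative, and only the ±30% / ±20% ranges and the single seeded generator come from the text:

```python
import numpy as np

def randomise_task(budget: float, time_days: float, seed: int):
    """Sketch of the documented perturbations: budget +/-30%, time +/-20%.
    All randomness flows through one seeded generator, so the same seed
    reproduces the same perturbed task."""
    rng = np.random.default_rng(seed)
    new_budget = budget * rng.uniform(0.7, 1.3)
    new_time = time_days * rng.uniform(0.8, 1.2)
    return new_budget, new_time
```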
 
89
 
90
+ The four scenarios are:
91
 
92
+ | Name | Difficulty | Tissue | Problem | Budget | Time |
93
+ |---|---|---|---|---|---|
94
+ | `cardiac_disease_de` | easy | heart | Differential expression between healthy and dilated cardiomyopathy cardiomyocytes | $80 K | 120 days |
95
+ | `hematopoiesis_trajectory` | medium | bone marrow | Infer HSC → mature lineage trajectory with three branches | $100 K | 150 days |
96
+ | `perturbation_immune` | hard | synovial fluid | JAK inhibitor effect on T-cell states in rheumatoid arthritis | $120 K | 180 days |
97
+ | `biomarker_validation_lung` | medium | lung | Validate SPP1 as biomarker for pro-fibrotic macrophages in IPF | $90 K | 150 days |
98
+
99
+ Each scenario carries paper references with DOIs, true DE genes with log2FC values, true pathway activities, true regulatory networks, and ground-truth causal mechanisms used for terminal reward calibration.
100
 
101
  ### `server/simulator/`
102
 
103
  This is the simulator itself.
104
 
105
+ - `latent_state.py` defines `FullLatentState`, the root aggregate of all hidden state. Key sub-structures are `LatentBiologicalState` (true DE genes, pathways, gene programs, trajectory, regulatory network, markers, causal mechanisms), `TechnicalState` (dropout, doublets, ambient RNA, sample quality), `ExperimentProgress` (18 boolean milestone flags plus counts), and `ResourceState` (internal budget and time tracking with exhaustion properties)
106
+ - `noise.py` centralises stochasticity in `NoiseModel`. All randomness flows through a single seeded `numpy.Generator`. Methods include `add_expression_noise`, `sample_effect_sizes`, `sample_p_values`, `generate_false_positives`, `generate_false_negatives`, `quality_degradation`, `sample_qc_metric`, `sample_cluster_count`, `shuffle_ranking`, and `coin_flip`
107
+ - `output_generator.py` turns an action plus hidden state into a realistic `IntermediateOutput`. Every action type has a dedicated handler conditioned on the latent state; noise is then injected — dropout in expression data, false positives and false negatives in DE and marker results, over/under-clustering, and pathway contamination
108
+ - `transition.py` applies action costs from `ACTION_COSTS`, updates progress flags, calls the output generator, degrades quality on soft violations, propagates discovered DE genes and cluster names back into latent state, and decides whether the episode is done
109
 
110
+ The output generator does not simply echo the action. It conditions outputs on the hidden state, then injects realistic noise.
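The seeded-noise idea can be sketched as below. `NoiseModelSketch` is a simplified illustration, not the real `NoiseModel`, though the method names (`coin_flip`, `generate_false_positives`) follow the list above:

```python
import numpy as np

class NoiseModelSketch:
    """Sketch of centralised stochasticity: one seeded numpy Generator
    drives everything, so the same seed reproduces the same episode noise."""

    def __init__(self, seed: int):
        self.rng = np.random.default_rng(seed)

    def coin_flip(self, p: float) -> bool:
        return bool(self.rng.random() < p)

    def generate_false_positives(self, true_genes, candidate_pool, rate=0.1):
        # Pad the true DE hits with spurious genes drawn from the rest
        # of the pool, at a configurable false-positive rate.
        extras = [g for g in candidate_pool if g not in true_genes]
        n_fp = min(int(len(true_genes) * rate), len(extras))
        fp = list(self.rng.choice(extras, size=n_fp, replace=False)) if n_fp else []
        return list(true_genes) + fp
```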
111
 
112
  ### `server/rules/engine.py`
113
 
 
116
  - hard violations block the action entirely
117
  - soft violations allow the action, but reduce output quality and add reward penalties
118
 
119
+ The four rule families are:
120
 
121
+ 1. **Prerequisites (HARD)** — each computational step requires the appropriate upstream milestone flag. For example: `normalize_data` requires `data_filtered`, `differential_expression` requires `data_normalized`, `validate_marker` requires `markers_discovered`
122
+ 2. **Resource constraints (HARD/SOFT)** — budget or time exhausted is a hard block; action cost exceeding remaining budget (when budget > 0) is a soft warning
123
+ 3. **Redundancy (SOFT)** — repeating an already-completed step such as `run_qc` or `normalize_data`
124
+ 4. **Causal validity (SOFT)** — synthesizing conclusions without prior DE or clustering; making causal claims without validation evidence; pathway enrichment before DE
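A sketch of how such checks might look. `check_action` and its return shape are hypothetical; the prerequisite pairs and rule families come from the list above:

```python
def check_action(action_type, progress, completed, budget):
    """Sketch of the documented rule families.
    Returns (hard_violations, soft_violations)."""
    hard, soft = [], []
    # Family 1: hard prerequisite flags (example pairs from the README).
    prerequisites = {
        "normalize_data": "data_filtered",
        "differential_expression": "data_normalized",
        "validate_marker": "markers_discovered",
        "sequence_cells": "library_prepared",
    }
    flag = prerequisites.get(action_type)
    if flag and not progress.get(flag, False):
        hard.append(f"{action_type} requires {flag}")
    # Family 2: exhausted resources are a hard block.
    if budget <= 0:
        hard.append("budget exhausted")
    # Family 3: repeating a completed step is a soft redundancy violation.
    if action_type in completed and action_type in {"run_qc", "normalize_data"}:
        soft.append(f"redundant repeat of {action_type}")
    return hard, soft
```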
125
 
126
  ### `server/rewards/reward.py`
127
 
128
  Rewards are decomposed rather than being a single opaque number.
129
 
130
+ Per-step reward formula:
131
 
132
+ ```
133
+ R_t = r_validity + r_ordering + r_info_gain + r_efficiency + r_novelty + r_penalty + γ[φ(s_{t+1}) − φ(s_t)]
134
+ ```
135
+
136
+ | Component | Weight | Description |
137
+ |---|---|---|
138
+ | `validity` | 0.3 | `1.0` if output succeeded, `−1.0` if hard violation |
139
+ | `ordering` | 0.2 | `1.0` if natural next step, `0.3` otherwise |
140
+ | `info_gain` | 0.4 | `quality_score × (1 − uncertainty)` |
141
+ | `efficiency` | 0.3 | `max(0, 1 − 5 × budget_fraction_used)` |
142
+ | `novelty` | +0.1 | Bonus when no soft violations |
143
+ | `penalty` | −0.15/violation | Per soft violation |
144
+ | `shaping` | γ = 0.99 | Potential-based over 12 progress milestones |
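Assuming the table weights multiply the listed component values, the per-step formula can be sketched as:

```python
def step_reward(success, natural_order, quality, uncertainty,
                budget_fraction_used, n_soft_violations,
                phi_next, phi_prev, gamma=0.99):
    """Sketch of R_t using the component weights from the table above
    (the weight-times-value interpretation is an assumption)."""
    r_validity = 0.3 * (1.0 if success else -1.0)
    r_ordering = 0.2 * (1.0 if natural_order else 0.3)
    r_info_gain = 0.4 * quality * (1.0 - uncertainty)
    r_efficiency = 0.3 * max(0.0, 1.0 - 5.0 * budget_fraction_used)
    r_novelty = 0.1 if n_soft_violations == 0 else 0.0
    r_penalty = -0.15 * n_soft_violations
    shaping = gamma * (phi_next - phi_prev)  # potential-based term
    return (r_validity + r_ordering + r_info_gain + r_efficiency
            + r_novelty + r_penalty + shaping)
```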
145
 
146
  Terminal reward adds:
147
 
148
+ | Component | Weight | Description |
149
+ |---|---|---|
150
+ | Pipeline completeness | 3.0 | Fraction of 7 core milestones completed |
151
+ | Calibration | 4.0 | How well conclusions match hidden markers and mechanisms |
152
+ | Budget + time efficiency | 1.0 | Average fraction of budget and time remaining |
153
+ | Overconfidence penalty | −0.5/claim | For high-confidence claims (`> 0.8`) that are wrong |

154
 
155
  This makes the environment easier to debug, benchmark, and train against.
156
 
 
161
  On `reset()` it:
162
 
163
  - seeds the noise model
164
+ - generates a task and latent state via `TaskGenerator`
165
  - clears history, outputs, discoveries, conclusions, and cumulative reward
166
 
167
  On `step()` it:
 
187
  5. Otherwise deduct budget and time based on `ACTION_COSTS`.
188
  6. Update latent progress flags like `samples_collected`, `qc_performed`, or `de_performed`.
189
  7. Generate a structured simulated output for the chosen action.
190
+ 8. If there were soft violations, degrade output quality (×0.5) and attach warnings.
191
  9. Propagate artifacts back into latent state, such as discovered DE genes or cluster names.
192
  10. Compute decomposed reward from state transition plus output quality.
193
  11. If the episode is ending, compute terminal reward from completeness and conclusion calibration.
194
  12. Return an observation that exposes the visible summary but not the hidden truth.
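The steps above can be condensed into a control-flow skeleton. The collaborator functions are injected stubs, not the real implementations; the ×0.5 quality degradation on soft violations is from the documented step 8:

```python
def step_sketch(action, state, check_rules, costs, generate_output, compute_reward):
    """Skeleton of the documented step() flow; collaborators are injected
    so the control flow itself is what this sketch demonstrates."""
    hard, soft = check_rules(action, state)
    if hard:  # hard violations block the action entirely
        return {"blocked": True, "violations": hard, "reward": -0.3}  # validity weight x -1
    budget_cost, time_cost = costs[action]
    state["budget"] -= budget_cost          # deduct resources
    state["time"] -= time_cost
    output = generate_output(action, state)  # simulated artifact
    if soft:                                 # soft violations degrade quality
        output["quality"] *= 0.5
        output["warnings"] = soft
    reward = compute_reward(state, output)
    done = (state["budget"] <= 0 or state["time"] <= 0
            or action == "synthesize_conclusion")
    return {"blocked": False, "output": output, "reward": reward, "done": done}

# Demo with stub collaborators: a soft redundancy violation halves quality.
demo = step_sketch(
    "run_qc", {"budget": 1000, "time": 10},
    lambda a, s: ([], ["redundant run_qc"]),
    {"run_qc": (100, 0.5)},
    lambda a, s: {"quality": 1.0},
    lambda s, o: o["quality"],
)
```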
195
 
196
+ ## Action costs
197
+
198
+ Each action deducts from the episode's budget and time. Computational steps also accrue compute hours.
199
+
200
+ | Action | Budget | Time (days) |
201
+ |---|---|---|
202
+ | `sequence_cells` | $15,000 | 5 |
203
+ | `prepare_library` | $8,000 | 3 |
204
+ | `collect_sample` | $5,000 | 7 |
205
+ | `validate_marker` | $5,000 | 14 |
206
+ | `culture_cells` | $3,000 | 14 |
207
+ | `perturb_gene` | $2,000 | 3 |
208
+ | `perturb_compound` | $1,000 | 2 |
209
+ | `select_cohort` | $500 | 1 |
210
+ | `run_qc` | $100 | 0.5 |
211
+ | `integrate_batches` | $300 | 1 |
212
+ | `regulatory_network_inference` | $200 | 1 |
213
+ | `cluster_cells` | $150 | 0.5 |
214
+ | `differential_expression`, `trajectory_analysis`, `pathway_enrichment` | $100–200 | 0.5–1 |
215
+ | `filter_data`, `normalize_data`, `marker_selection` | $50–100 | 0.25–0.5 |
216
+ | `synthesize_conclusion`, `design_followup_experiment`, `request_subagent_review` | $0 | 0.25–0.5 |
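A few rows of the table, transcribed into the dictionary-lookup form the text implies. `ACTION_COSTS_SKETCH` and `apply_cost` are illustrative names, not the real `ACTION_COSTS` contents:

```python
# A handful of entries from the cost table above: (budget in USD, time in days).
ACTION_COSTS_SKETCH = {
    "sequence_cells": (15_000, 5.0),
    "prepare_library": (8_000, 3.0),
    "collect_sample": (5_000, 7.0),
    "run_qc": (100, 0.5),
    "synthesize_conclusion": (0, 0.25),
}

def apply_cost(budget, time_days, action):
    """Deduct an action's cost from the remaining budget and time."""
    cost, days = ACTION_COSTS_SKETCH[action]
    return budget - cost, time_days - days
```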
217
+
218
  ## Typical successful pipeline
219
 
220
  Most scenarios reward a sensible experiment order similar to:
 
231
  `regulatory_network_inference`, `marker_selection`, `validate_marker`
232
  9. `synthesize_conclusion`
233
 
234
+ The exact best sequence depends on the scenario:
235
 
236
  - trajectory scenarios benefit from `trajectory_analysis` and regulatory inference
237
  - biomarker scenarios benefit from DE, marker selection, and validation
238
  - perturbation scenarios benefit from pathway-level interpretation
239
 
240
+ ## Episode termination
241
+
242
+ An episode ends when one of the following happens:
243
+
244
+ - the agent chooses `synthesize_conclusion`
245
+ - resources are exhausted
246
+ - the environment reaches `MAX_STEPS` which is currently `30`
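The three conditions can be sketched as a single predicate (hypothetical helper; only the conditions themselves and `MAX_STEPS = 30` come from the text):

```python
MAX_STEPS = 30  # documented cap

def episode_done(action_type, budget, time_days, step_count):
    """Sketch of the three documented termination conditions."""
    return (
        action_type == "synthesize_conclusion"  # agent concludes
        or budget <= 0 or time_days <= 0        # resources exhausted
        or step_count >= MAX_STEPS              # step cap reached
    )
```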
247
+
248
+ ## Installation
249
+
250
+ Dependencies are managed with `uv`. The package requires Python ≥ 3.10.
251
+
252
+ ```bash
253
+ # Core environment only
254
+ uv sync
255
+
256
+ # With dev/test tools
257
+ uv sync --extra dev
258
+
259
+ # With training dependencies (TRL, transformers, torch)
260
+ uv sync --extra train
261
+
262
+ # With bioinformatics extras (scanpy, biopython, gseapy)
263
+ uv sync --extra bio
264
+ ```
265
+
266
+ Key dependency groups from `pyproject.toml`:
267
+
268
+ | Group | Key packages |
269
+ |---|---|
270
+ | core | `openenv-core[core]>=0.2.0`, `numpy`, `scipy`, `pydantic>=2.0` |
271
+ | train | `trl>=0.29`, `transformers>=5.3`, `accelerate`, `datasets`, `torch`, `matplotlib` |
272
+ | bio | `scanpy`, `biopython`, `gseapy` |
273
+ | dev | `pytest`, `pytest-cov` |
274
+
275
  ## Interfaces you can use
276
 
277
  ### 1. In-process environment
 
297
  print(obs.reward)
298
  ```
299
 
300
+ The constructor accepts:
301
+ - `scenario_name: Optional[str]` — pin to a specific scenario; `None` picks randomly each episode
302
+ - `domain_randomise: bool = True` — perturbs scenario parameters for generalization
303
+
304
  ### 2. OpenEnv client/server mode
305
 
306
  Use the FastAPI app when you want to serve the environment over HTTP and WebSocket:
 
310
  uv run uvicorn server.app:app --reload
311
  ```
312
 
313
+ The server exposes five endpoints:
314
+
315
+ | Method | Path | Description |
316
+ |---|---|---|
317
+ | `POST` | `/reset` | Start a new episode |
318
+ | `POST` | `/step` | Execute one action |
319
+ | `GET` | `/state` | Current environment state |
320
+ | `GET` | `/schema` | Action/observation JSON schemas |
321
+ | `WS` | `/ws` | WebSocket for persistent sessions |
322
+
323
  Then connect with the client:
324
 
325
  ```python
 
334
 
335
  The environment class supports concurrent sessions, but the bundled server is currently configured with `max_concurrent_envs=1` in `server/app.py`.
336
 
337
+ ### 3. Running a local agent
338
 
339
+ `run_agent.py` runs a single interactive episode using a local Hugging Face model:
340
 
341
+ ```bash
342
+ uv run python run_agent.py
343
  ```
344
 
345
+ Configuration is via environment variables:
346
+
347
+ | Variable | Default | Description |
348
+ |---|---|---|
349
+ | `RUN_AGENT_USE_PIPELINE` | `0` | Use HF `pipeline()` path instead of direct generate |
350
+ | `RUN_AGENT_MAX_EPISODE_STEPS` | `12` | Maximum number of planning steps |
351
+
352
+ The local model defaults to `Qwen/Qwen3.5-0.8B` with sampling parameters `temperature=0.7`, `top_p=0.8`, `top_k=20`, `repetition_penalty=1.3`. The episode runs up to `MAX_EPISODE_STEPS = 12` steps. When action parsing fails, the script falls back to an observation-aware action that respects prerequisites.
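The parse-with-fallback behaviour can be sketched as follows; `parse_action` is a hypothetical helper, not the script's actual function:

```python
import json

def parse_action(raw_text, fallback_action):
    """Sketch of parse-with-fallback: try to read the model's JSON action;
    on any parse failure, return an observation-aware fallback instead."""
    try:
        action = json.loads(raw_text)
        if "action_type" not in action:
            raise ValueError("missing action_type")
        return action
    except (json.JSONDecodeError, ValueError, TypeError):
        return fallback_action
```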
353
+
354
+ PowerShell note: older PowerShell versions do not support `&&`. Run commands from the target directory directly, or use `;` as the command separator.
355
+
356
+ Windows runtime warnings:
357
+ - If you see HuggingFace symlink-cache warnings, functionality is unaffected; optionally set `HF_HUB_DISABLE_SYMLINKS_WARNING=1`.
358
+ - If you see flash attention / causal-conv fallback warnings, execution continues with a slower PyTorch path.
359
 
360
+ ### 4. GRPO training
361
+
362
+ `training_script.py` follows the TRL GRPO pattern and uses OpenEnv rewards to score generated action JSON against this environment.
363
+
364
+ ```bash
365
+ uv sync --extra train
366
+ uv run python training_script.py --dry-run
367
+ uv run python training_script.py --model-id Qwen/Qwen3.5-0.8B
368
+ ```
369
 
370
+ Key arguments:
371
+
372
+ | Argument | Default | Description |
373
+ |---|---|---|
374
+ | `--model-id` | `Qwen/Qwen2.5-7B-Instruct` | Base model to fine-tune |
375
+ | `--output-dir` | `training/grpo-output` | Save directory |
376
+ | `--dataset-episodes` | `8` | Rollout episodes for prompt dataset |
377
+ | `--rollout-steps` | `6` | Steps per episode during collection |
378
+ | `--collection-policy` | `heuristic` | `random` or `heuristic` |
379
+ | `--reward-backend` | `local` | `local` (in-process) or `remote` (live server) |
380
+ | `--base-url` | `http://localhost:8000` | Server URL for remote backend |
381
+ | `--scenario-name` | all | Repeatable; restricts which scenarios are used |
382
+ | `--domain-randomise` | off | Enable domain randomisation |
383
+ | `--num-generations` | `4` | GRPO generations per prompt |
384
+ | `--max-completion-length` | `220` | Max tokens for model completions |
385
+ | `--max-prompt-length` | `768` | Max tokens for prompts |
386
+ | `--learning-rate` | `5e-6` | AdamW learning rate |
387
+ | `--dry-run` | off | Build data and test reward without training |
388
+
389
+ By default the reward function reconstructs prompt states locally so the prompt and reward stay aligned. Switch to a live server-backed reward loop with `--reward-backend remote --base-url http://localhost:8000`.
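A local reward backend can be sketched in the callable shape TRL's GRPO trainer accepts, returning one float per completion. The validity scoring below is purely illustrative; the real script scores generated actions against the environment:

```python
import json

def env_reward(completions, **kwargs):
    """Sketch of a GRPO reward function: score each generated completion,
    here by whether it parses as an action JSON with a known action_type."""
    valid_actions = {"run_qc", "normalize_data", "differential_expression"}
    rewards = []
    for text in completions:
        try:
            action = json.loads(text)
            rewards.append(1.0 if action.get("action_type") in valid_actions else 0.2)
        except (json.JSONDecodeError, AttributeError, TypeError):
            rewards.append(-1.0)  # unparseable completions are penalised
    return rewards
```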
390
+
391
+ After training, the script saves plots to the output directory:
392
+
393
+ - `training_loss.png`
394
+ - `training_reward.png`
395
+ - `training_metric.png`
396
+ - `training_dashboard.png`
397
+ - `training_plot_manifest.json`
398
+
399
+ Use `--plot-metric-key <logged_key>` to force a specific extra metric on the third chart; otherwise the script auto-selects a useful logged metric such as KL or gradient norm.
400
+
401
+ ### 5. Rollout collection
402
+
403
+ `training/rollout_collection.py` collects direct environment rollouts into trajectory files:
404
+
405
+ ```bash
406
+ uv run python -m training.rollout_collection
407
+ ```
408
+
409
+ This runs N episodes with a `random` or `heuristic` policy, saves JSON trajectories, and prints evaluation metrics.
410
+
411
+ ### 6. Benchmark and scripted agents
412
 
413
  - `training/literature_benchmark.py` runs paper-aligned action sequences and compares outcomes against curated expected findings
414
+ - `training/rollout_collection.py` collects direct environment rollouts into trajectory files
415
+ - `training_script.py` trains a GRPO policy with OpenEnv reward calls
416
  - `run_agent.py` runs a local language model planner against the environment
417
  - `training/trajectory.py` stores trajectories for offline RL, imitation learning, replay, and evaluation
418
  - `training/evaluation.py` computes online, benchmark, expert-review, and fidelity-oriented metrics
419
 
420
+ ## Training utilities
421
 
422
+ ### `training/trajectory.py`
423
 
424
+ Provides `TrajectoryStep`, `Trajectory`, and `TrajectoryDataset` for episode serialization.
425
+
426
+ - `TrajectoryStep` stores `action`, `observation`, `reward`, `done`, `reward_breakdown`, and an optional `latent_snapshot`
427
+ - `Trajectory` accumulates steps with `add_step()`, computes `total_reward`, and exposes `save(path)` / `load(path)`
428
+ - `TrajectoryDataset` wraps a list of trajectories with `filter_successful()`, `save_dir()`, `load_dir()`, and `summary()` (n, success_rate, mean_reward, mean_length, max/min reward)
429
+
430
+ ### `training/evaluation.py`
431
+
432
+ `EvaluationSuite` is a stateless class with four families of `@staticmethod` methods:
433
+
434
+ | Family | Method | Metrics |
435
+ |---|---|---|
436
+ | Online RL | `online_metrics(trajectories)` | `mean_return`, `median_return`, `std_return`, `mean_episode_length`, `success_rate` |
437
+ | Offline benchmark | `benchmark_metrics(dataset)` | `pipeline_validity_rate`, `ordering_score`, `action_diversity`, `mean_conclusion_confidence` |
438
+ | Expert review | `expert_review_metrics(...)` | Placeholder; averages provided scores |
439
+ | Simulator fidelity | `simulator_fidelity_metrics(sim, real)` | `reward_distribution_gap` |
440
+
441
+ ### `training/literature_benchmark.py`
442
+
443
+ `run_paper_benchmark(problem_statement, scenario_name, domain_randomise)` runs a paper-aligned action pipeline and scores against `expected_findings` using keyword matching. Returns a `PaperBenchmarkResult` with `match_ratio`.
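The keyword-matching idea can be sketched as follows; `match_ratio` here is a hypothetical reduction of the scoring, where the real benchmark compares structured outcomes against curated `expected_findings`:

```python
def match_ratio(expected_findings, conclusion_text):
    """Sketch of keyword-based benchmark scoring: the fraction of expected
    findings whose keywords all appear in the final conclusion text."""
    text = conclusion_text.lower()
    hits = sum(
        1 for keywords in expected_findings
        if all(k.lower() in text for k in keywords)
    )
    return hits / len(expected_findings) if expected_findings else 0.0
```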
444
+
445
+ ## Docker deployment
446
+
447
+ The server ships with a `server/Dockerfile`. It uses a multi-stage build based on `openenv-base`, installs dependencies via `uv`, and starts `uvicorn server.app:app` on port 8000.
448
+
449
+ ```bash
450
+ docker build -f server/Dockerfile -t bio-experiment-env .
451
+ docker run -p 8000:8000 bio-experiment-env
452
+ ```
453
+
454
+ The `openenv.yaml` file configures the deployment for the OpenEnv platform:
455
+
456
+ ```yaml
457
+ spec_version: 1
458
+ name: hackathon
459
+ type: space
460
+ runtime: fastapi
461
+ app: server.app:app
462
+ port: 8000
463
+ ```
464
 
465
  ## Why this is useful
466
 
 
479
  - literature-grounded evaluation
480
  - comparing structured policies against LLM-driven planners
481
 
482
+ ## Project map
483
 
484
  ```text
485
  .
486
+ ├── client.py # OpenEnv HTTP/WebSocket client
487
  ├── models.py # Shared action / observation / task schemas
488
+ ├── openenv.yaml # OpenEnv platform deployment config
489
+ ├── pyproject.toml # Package metadata and dependency groups
490
+ ├── run_agent.py # Single-episode interactive agent runner
491
  ├── server/
492
+ │ ├── app.py # FastAPI/OpenEnv server entry point
493
+ │ ├── Dockerfile # Multi-stage Docker build
494
  │ ├── hackathon_environment.py # Main environment orchestration
495
+ │ ├── requirements.txt # Minimal server dependencies
496
+ │ ├── rewards/
497
+ │ │ └── reward.py # Decomposed reward model
498
+ │ ├── rules/
499
+ │ │ └── engine.py # Biological constraint checking
500
+ │ ├── simulator/
501
+ │ │ ├── latent_state.py # Hidden biological, technical, progress, resource state
502
+ │ │ ├── noise.py # Seeded stochastic noise model
503
+ │ │ ├── output_generator.py # Per-action simulated output generation
504
+ │ │ └── transition.py # State transition engine and ACTION_COSTS table
505
+ │ ├── subagents/ # Placeholder for future sub-agent integration
506
+ │ └── tasks/
507
+ │ ├── generator.py # TaskGenerator with domain randomisation
508
+ │ └── scenarios.py # SCENARIO_LIBRARY with 4 curated scenarios
509
  ├── training/
510
+ │ ├── evaluation.py # EvaluationSuite metrics
 
511
  │ ├── literature_benchmark.py # Paper-backed benchmark flow
512
+ │ ├── rollout_collection.py # Direct rollout collection helper
513
+ │ └── trajectory.py # Trajectory serialization and dataset utilities
514
+ ├── training_script.py # TRL GRPO training entry point
515
+ └── tests/
516
+ ├── test_environment.py
517
+ ├── test_literature_benchmark.py
518
+ ├── test_models.py
519
+ ├── test_rewards.py
520
+ ├── test_rules.py
521
+ ├── test_simulator.py
522
+ └── test_training_script.py
523
  ```
524
 
525
  ## Quick sanity check
526
 
527
  ```bash
528
  uv run pytest tests/test_environment.py tests/test_literature_benchmark.py -q
529
  ```
 
535
  - conclusion termination
536
  - literature-backed scenario selection
537
  - benchmark matching for curated expected findings
538
+
539
+ Run the full suite with coverage:
540
+
541
+ ```bash
542
+ uv run pytest tests/ --cov -q
543
+ ```
__init__.py CHANGED
@@ -1,48 +1,75 @@
1
  try: # pragma: no cover - package import path
2
  from .client import BioExperimentEnv
3
  from .models import (
 
4
  ActionType,
 
5
  ConclusionClaim,
6
  ExpectedFinding,
7
  ExperimentAction,
8
  ExperimentObservation,
9
  IntermediateOutput,
 
 
10
  OutputType,
11
  PaperReference,
12
  PipelineStepRecord,
13
  ResourceUsage,
14
  SubagentType,
 
15
  TaskSpec,
 
16
  )
17
  except ImportError: # pragma: no cover - direct module import path
18
  from client import BioExperimentEnv
19
  from models import (
 
20
  ActionType,
 
21
  ConclusionClaim,
22
  ExpectedFinding,
23
  ExperimentAction,
24
  ExperimentObservation,
25
  IntermediateOutput,
 
 
26
  OutputType,
27
  PaperReference,
28
  PipelineStepRecord,
29
  ResourceUsage,
30
  SubagentType,
 
31
  TaskSpec,
32
  )
33
 
34
  __all__ = [
 
35
  "ActionType",
 
36
  "BioExperimentEnv",
37
  "ConclusionClaim",
38
  "ExpectedFinding",
39
  "ExperimentAction",
40
  "ExperimentObservation",
41
  "IntermediateOutput",
 
 
42
  "OutputType",
43
  "PaperReference",
44
  "PipelineStepRecord",
45
  "ResourceUsage",
46
  "SubagentType",
 
47
  "TaskSpec",
48
  ]
 
1
  try: # pragma: no cover - package import path
2
  from .client import BioExperimentEnv
3
  from .models import (
4
+ ASSAY_REGISTRY,
5
  ActionType,
6
+ AssaySpec,
7
  ConclusionClaim,
8
  ExpectedFinding,
9
  ExperimentAction,
10
  ExperimentObservation,
11
  IntermediateOutput,
12
+ MODALITY_REGISTRY,
13
+ ModalitySpec,
14
  OutputType,
15
  PaperReference,
16
  PipelineStepRecord,
17
  ResourceUsage,
18
  SubagentType,
19
+ TOOL_REGISTRY,
20
  TaskSpec,
21
+ ToolSpec,
22
+ assays_for_modality,
23
+ tools_by_category,
24
+ tools_for_modality,
25
  )
26
  except ImportError: # pragma: no cover - direct module import path
27
  from client import BioExperimentEnv
28
  from models import (
29
+ ASSAY_REGISTRY,
30
  ActionType,
31
+ AssaySpec,
32
  ConclusionClaim,
33
  ExpectedFinding,
34
  ExperimentAction,
35
  ExperimentObservation,
36
  IntermediateOutput,
37
+ MODALITY_REGISTRY,
38
+ ModalitySpec,
39
  OutputType,
40
  PaperReference,
41
  PipelineStepRecord,
42
  ResourceUsage,
43
  SubagentType,
44
+ TOOL_REGISTRY,
45
  TaskSpec,
46
+ ToolSpec,
47
+ assays_for_modality,
48
+ tools_by_category,
49
+ tools_for_modality,
50
  )
51
 
52
  __all__ = [
53
+ "ASSAY_REGISTRY",
54
  "ActionType",
55
+ "AssaySpec",
56
  "BioExperimentEnv",
57
  "ConclusionClaim",
58
  "ExpectedFinding",
59
  "ExperimentAction",
60
  "ExperimentObservation",
61
  "IntermediateOutput",
62
+ "MODALITY_REGISTRY",
63
+ "ModalitySpec",
64
  "OutputType",
65
  "PaperReference",
66
  "PipelineStepRecord",
67
  "ResourceUsage",
68
  "SubagentType",
69
+ "TOOL_REGISTRY",
70
  "TaskSpec",
71
+ "ToolSpec",
72
+ "assays_for_modality",
73
+ "tools_by_category",
74
+ "tools_for_modality",
75
  ]
_dashboard_state.json ADDED
@@ -0,0 +1,207 @@
+ {
+   "timestamp": 1772961529.009918,
+   "step": 4,
+   "episode_done": false,
+   "cumulative_reward": 1.8781281693998408,
+   "gen_time_s": 15.02,
+   "model_response_raw": "{\n\"action_type\": \"prepare_library\",\n\"method\": \"\",\n\"parameters\": {\n\"cutoff_value\": \"\"\n},\n\"Justification\": \"The error 'Cannot sequence without library preparation' indicates that while sample collection was completed successfully via dry-polymerization technology ('collectible'), mature cDNA libraries have yet to be constructed from those frozen tissues. Library prep requires reverse transcription followed by amplification/cleanup using polymerases/ligation enzymes. Since all other required conditions met but materials remain unprocessed, preparing high-throughput Illumina libraries first allows",
+   "model_thinking": "",
+   "thinking_enabled": false,
+   "task": {
+     "problem_statement": "Determine the effect of anti-TNF biologic therapy on cell states in synovium tissue affected by rheumatoid arthritis.",
+     "organism": "human",
+     "tissue": "synovium",
+     "modality": "scRNA-seq",
+     "conditions": [
+       "untreated_rheumatoid_arthritis",
+       "anti_TNF_treated"
+     ],
+     "budget_limit": 80340.57829173621,
+     "time_limit_days": 140.67464277202689
+   },
+   "resources": {
+     "budget_used": 13500.0,
+     "budget_remaining": 66840.58,
+     "time_used_days": 11.0,
+     "time_remaining_days": 129.7,
+     "samples_consumed": 0,
+     "compute_hours_used": 0.0
+   },
+   "pipeline_history": [
+     {
+       "step_index": 1,
+       "action_type": "collect_sample",
+       "method": null,
+       "output_summary": "Collected 6 samples (quality=0.54)",
+       "success": true,
+       "quality_score": 0.541,
+       "resource_cost": 5000.0,
+       "time_cost_days": 7.0
+     },
+     {
+       "step_index": 2,
+       "action_type": "select_cohort",
+       "method": null,
+       "output_summary": "Selected cohort of 4 samples with criteria {}",
+       "success": true,
+       "quality_score": 1.0,
+       "resource_cost": 500.0,
+       "time_cost_days": 1.0
+     },
+     {
+       "step_index": 3,
+       "action_type": "sequence_cells",
+       "method": null,
+       "output_summary": "Action blocked: Cannot sequence without library preparation",
+       "success": false,
+       "quality_score": 1.0,
+       "resource_cost": 15000.0,
+       "time_cost_days": 5.0
+     },
+     {
+       "step_index": 4,
+       "action_type": "prepare_library",
+       "method": null,
+       "output_summary": "Library prepared (complexity=0.72)",
+       "success": true,
+       "quality_score": 0.718,
+       "resource_cost": 8000.0,
+       "time_cost_days": 3.0
+     }
+   ],
+   "current_action": {
+     "action_type": "prepare_library",
+     "method": null,
+     "parameters": {
+       "cutoff_value": ""
+     },
+     "justification": "The error 'Cannot sequence without library preparation' indicates that while sample collection was completed successfully via dry-polymerization technology ('collectible'), mature cDNA libraries have yet to be constructed from those frozen tissues. Library prep requires reverse transcription followed by amplification/cleanup using polymerases/ligation enzymes. Since all other required conditions met but materials remain unprocessed, preparing high-throughput Illumina libraries first allows conversion of input DNA/RNA directly into sequencable formats.",
+     "confidence": 0.95
+   },
+   "latest_output": {
+     "summary": "Library prepared (complexity=0.72)",
+     "success": true,
+     "quality_score": 0.718,
+     "uncertainty": 0.0,
+     "warnings": [],
+     "data_preview": "{\"library_complexity\": 0.7177962559040082, \"method\": \"10x_chromium\"}"
+   },
+   "discovered_markers": [],
+   "candidate_mechanisms": [],
+   "rule_violations": [],
+   "uncertainty_summary": {
+     "avg_uncertainty": 0.0,
+     "avg_quality": 0.815
+   },
+   "reward_breakdown": {},
+   "latent": {
+     "cell_populations": [
+       {
+         "name": "fibroblast",
+         "proportion": 0.445,
+         "marker_genes": [
+           "COL1A1",
+           "FAP",
+           "THY1"
+         ],
+         "state": "activated"
+       },
+       {
+         "name": "CD4_T_cell",
+         "proportion": 0.179,
+         "marker_genes": [
+           "CD3D",
+           "CD4",
+           "IL7R"
+         ],
+         "state": "quiescent"
+       },
+       {
+         "name": "CD8_T_cell",
+         "proportion": 0.139,
+         "marker_genes": [
+           "CD3D",
+           "CD8A",
+           "GZMB"
+         ],
+         "state": "activated"
+       },
+       {
+         "name": "B_cell",
+         "proportion": 0.142,
+         "marker_genes": [
+           "CD19",
+           "MS4A1",
+           "CD79A"
+         ],
+         "state": "quiescent"
+       },
+       {
+         "name": "endothelial",
+         "proportion": 0.096,
+         "marker_genes": [
+           "PECAM1",
+           "VWF"
+         ],
+         "state": "quiescent"
+       }
+     ],
+     "true_markers": [
+       "TNF",
+       "IL6",
+       "MMP3",
+       "CXCL13"
+     ],
+     "causal_mechanisms": [
+       "TNF/NF-kB-driven synovial inflammation",
+       "Th17-mediated cartilage destruction via MMPs"
+     ],
+     "true_pathways": {
+       "JAK_STAT_signalling": 0.785,
+       "TNF_signalling": 0.723,
+       "Th17_differentiation": 0.633,
+       "NF_kB_signalling": 0.826,
+       "matrix_metalloproteinase_activity": 0.847
+     },
+     "true_de_genes_count": 9,
+     "true_regulatory_network_size": 16,
+     "confounders": {},
+     "n_true_cells": 15873,
+     "technical": {
+       "ambient_rna_fraction": 0.037873267501661645,
+       "doublet_rate": 0.03797665930677535,
+       "dropout_rate": 0.14738025069803395,
+       "sample_quality": 0.9068064354870293,
+       "library_complexity": 0.8,
+       "capture_efficiency": 0.6
+     },
+     "progress": {
+       "samples_collected": true,
+       "cohort_selected": true,
+       "cells_cultured": false,
+       "library_prepared": true,
+       "perturbation_applied": false,
+       "cells_sequenced": false,
+       "qc_performed": false,
+       "data_filtered": false,
+       "data_normalized": false,
+       "batches_integrated": false,
+       "cells_clustered": false,
+       "de_performed": false,
+       "trajectories_inferred": false,
+       "pathways_analyzed": false,
+       "networks_inferred": false,
+       "markers_discovered": false,
+       "markers_validated": false,
+       "followup_designed": false,
+       "subagent_review_requested": false,
+       "conclusion_reached": false,
+       "n_cells_sequenced": null,
+       "n_cells_after_filter": null,
+       "n_clusters_found": null,
+       "n_de_genes_found": null,
+       "n_markers_found": null
+     },
+     "hidden_failure_conditions": []
+   }
+ }
dashboard.html ADDED
@@ -0,0 +1,522 @@
+ <!DOCTYPE html>
+ <html lang="en">
+ <head>
+ <meta charset="utf-8" />
+ <meta name="viewport" content="width=device-width, initial-scale=1" />
+ <title>Bio-Experiment Agent Dashboard</title>
+ <link rel="preconnect" href="https://fonts.googleapis.com" />
+ <link href="https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;600&family=DM+Sans:wght@400;500;700&display=swap" rel="stylesheet" />
+ <style>
+ :root {
+ --bg: #0c0e14;
+ --surface: #151822;
+ --surface2: #1c2030;
+ --border: #2a2f42;
+ --text: #e2e4ea;
+ --text-dim: #8b90a5;
+ --accent: #5ce0d8;
+ --accent2: #7c6cf0;
+ --green: #4ade80;
+ --red: #f87171;
+ --amber: #fbbf24;
+ --blue: #60a5fa;
+ --pink: #f472b6;
+ }
+ *, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }
+ body { background: var(--bg); color: var(--text); font-family: 'DM Sans', system-ui, sans-serif; line-height: 1.5; min-height: 100vh; }
+ .mono { font-family: 'JetBrains Mono', monospace; }
+
+ .header { display: flex; align-items: center; justify-content: space-between; padding: 14px 28px; border-bottom: 1px solid var(--border); background: var(--surface); }
+ .header h1 { font-size: 18px; font-weight: 700; letter-spacing: -.3px; }
+ .header h1 span { color: var(--accent); }
+ .header-right { display: flex; align-items: center; gap: 10px; }
+ .status-pill { font-size: 12px; padding: 4px 14px; border-radius: 20px; font-weight: 600; text-transform: uppercase; letter-spacing: .5px; }
+ .status-pill.live { background: rgba(76,222,128,.15); color: var(--green); }
+ .status-pill.done { background: rgba(248,113,113,.15); color: var(--red); }
+ .status-pill.waiting { background: rgba(139,144,165,.15); color: var(--text-dim); }
+
+ .btn { padding: 6px 16px; border-radius: 8px; border: 1px solid var(--border); background: var(--surface2); color: var(--text); font-size: 12px; font-weight: 600; cursor: pointer; transition: all .15s; }
+ .btn:hover { border-color: var(--accent); color: var(--accent); }
+ .btn.primary { background: rgba(92,224,216,.12); border-color: var(--accent); color: var(--accent); }
+ .btn.primary:hover { background: rgba(92,224,216,.25); }
+ .btn.danger { border-color: var(--red); color: var(--red); }
+ .btn.danger:hover { background: rgba(248,113,113,.12); }
+
+ .grid { display: grid; grid-template-columns: 1fr 1fr 1fr; gap: 16px; padding: 20px 28px; max-width: 1600px; }
+ @media (max-width: 1100px) { .grid { grid-template-columns: 1fr 1fr; } }
+ @media (max-width: 700px) { .grid { grid-template-columns: 1fr; } }
+
+ .card { background: var(--surface); border: 1px solid var(--border); border-radius: 12px; padding: 18px 20px; overflow: hidden; }
+ .card h2 { font-size: 11px; font-weight: 600; text-transform: uppercase; letter-spacing: 1px; color: var(--text-dim); margin-bottom: 12px; }
+ .card.span2 { grid-column: span 2; }
+ .card.span3 { grid-column: span 3; }
+ @media (max-width: 700px) { .card.span2, .card.span3 { grid-column: span 1; } }
+
+ .gauge-row { display: flex; gap: 14px; flex-wrap: wrap; }
+ .gauge { flex: 1; min-width: 130px; background: var(--surface2); border-radius: 10px; padding: 14px; }
+ .gauge-label { font-size: 11px; color: var(--text-dim); margin-bottom: 6px; text-transform: uppercase; letter-spacing: .5px; }
+ .gauge-value { font-size: 22px; font-weight: 700; }
+ .gauge-bar { height: 5px; border-radius: 3px; background: var(--border); margin-top: 8px; overflow: hidden; }
+ .gauge-bar-fill { height: 100%; border-radius: 3px; transition: width .6s ease; }
+
+ .timeline { position: relative; padding-left: 20px; }
+ .timeline::before { content: ''; position: absolute; left: 6px; top: 0; bottom: 0; width: 2px; background: var(--border); }
+ .timeline-item { position: relative; margin-bottom: 14px; padding-left: 18px; }
+ .timeline-item::before { content: ''; position: absolute; left: -18px; top: 6px; width: 10px; height: 10px; border-radius: 50%; border: 2px solid var(--accent); background: var(--bg); }
+ .timeline-item.fail::before { border-color: var(--red); }
+ .tl-action { font-weight: 600; font-size: 14px; }
+ .tl-meta { font-size: 12px; color: var(--text-dim); margin-top: 2px; }
+
+ .mini-table { width: 100%; font-size: 13px; border-collapse: collapse; }
+ .mini-table td { padding: 5px 8px; border-bottom: 1px solid var(--border); vertical-align: top; }
+ .mini-table td:first-child { color: var(--text-dim); white-space: nowrap; width: 40%; }
+
+ .tag-list { display: flex; flex-wrap: wrap; gap: 6px; }
+ .tag { font-size: 12px; padding: 3px 10px; border-radius: 6px; background: var(--surface2); border: 1px solid var(--border); font-family: 'JetBrains Mono', monospace; }
+ .tag.green { border-color: rgba(76,222,128,.3); color: var(--green); }
+ .tag.pink { border-color: rgba(244,114,182,.3); color: var(--pink); }
+ .tag.amber { border-color: rgba(251,191,36,.3); color: var(--amber); }
+ .tag.red { border-color: rgba(248,113,113,.3); color: var(--red); }
+ .tag.match { background: rgba(76,222,128,.15); }
+ .tag.miss { background: rgba(248,113,113,.08); }
+
+ .code-block { background: var(--surface2); border: 1px solid var(--border); border-radius: 8px; padding: 12px 14px; font-family: 'JetBrains Mono', monospace; font-size: 12px; white-space: pre-wrap; word-break: break-all; max-height: 220px; overflow-y: auto; color: var(--text-dim); line-height: 1.6; }
+
+ .progress-grid { display: grid; grid-template-columns: repeat(auto-fill, minmax(150px, 1fr)); gap: 6px; }
+ .progress-item { display: flex; align-items: center; gap: 6px; font-size: 12px; }
+ .dot { width: 8px; height: 8px; border-radius: 50%; flex-shrink: 0; background: var(--border); }
+ .dot.done { background: var(--green); }
+
+ .pop-bar-container { margin-bottom: 10px; }
+ .pop-bar-label { font-size: 12px; margin-bottom: 3px; display: flex; justify-content: space-between; }
+ .pop-bar { height: 14px; border-radius: 4px; background: var(--surface2); overflow: hidden; }
+ .pop-bar-fill { height: 100%; border-radius: 4px; }
+
+ #reward-chart { width: 100%; height: 120px; }
+ ::-webkit-scrollbar { width: 6px; }
+ ::-webkit-scrollbar-track { background: transparent; }
+ ::-webkit-scrollbar-thumb { background: var(--border); border-radius: 3px; }
+
+ .conclusion-card { background: var(--surface2); border: 1px solid var(--border); border-radius: 10px; padding: 14px 16px; margin-bottom: 12px; }
+ .conclusion-card .cc-header { display: flex; justify-content: space-between; align-items: center; margin-bottom: 8px; }
+ .cc-type { font-size: 11px; padding: 2px 10px; border-radius: 4px; font-weight: 600; text-transform: uppercase; letter-spacing: .5px; }
+ .cc-type.causal { background: rgba(244,114,182,.15); color: var(--pink); }
+ .cc-type.correlative { background: rgba(96,165,250,.15); color: var(--blue); }
+ .cc-type.descriptive { background: rgba(139,144,165,.15); color: var(--text-dim); }
+ .cc-conf { font-family: 'JetBrains Mono', monospace; font-size: 13px; font-weight: 600; }
+ .cc-claim { font-size: 14px; margin-bottom: 8px; line-height: 1.5; }
+ .cc-section-label { font-size: 10px; color: var(--text-dim); text-transform: uppercase; letter-spacing: .5px; margin-bottom: 3px; margin-top: 8px; }
+
+ /* ── control panel ────────────────────────────── */
+ .control-panel { background: var(--surface); border: 1px solid var(--border); border-radius: 12px; margin: 20px 28px 0; padding: 18px 20px; }
+ .control-panel summary { cursor: pointer; font-size: 13px; font-weight: 600; color: var(--accent); }
+ .control-panel[open] summary { margin-bottom: 14px; }
+ .form-row { display: flex; gap: 12px; margin-bottom: 10px; flex-wrap: wrap; align-items: end; }
+ .form-field { display: flex; flex-direction: column; gap: 4px; }
+ .form-field label { font-size: 11px; color: var(--text-dim); text-transform: uppercase; letter-spacing: .5px; }
+ .form-field input, .form-field textarea, .form-field select {
+ background: var(--surface2); border: 1px solid var(--border); border-radius: 6px;
+ color: var(--text); padding: 7px 10px; font-size: 13px; font-family: inherit; outline: none;
+ }
+ .form-field input:focus, .form-field textarea:focus, .form-field select:focus { border-color: var(--accent); }
+ .form-field textarea { min-height: 60px; resize: vertical; }
+
+ /* ── final report ─────────────────────────────── */
+ .report-overlay { display: none; position: fixed; inset: 0; z-index: 100; background: rgba(12,14,20,.85); backdrop-filter: blur(6px); overflow-y: auto; padding: 40px 20px; }
+ .report-overlay.visible { display: flex; justify-content: center; align-items: flex-start; }
+ .report-card { background: var(--surface); border: 1px solid var(--border); border-radius: 16px; padding: 32px 36px; max-width: 900px; width: 100%; }
+ .report-card h2 { font-size: 22px; font-weight: 700; margin-bottom: 4px; color: var(--text); text-transform: none; letter-spacing: normal; }
+ .report-card .subtitle { font-size: 13px; color: var(--text-dim); margin-bottom: 20px; }
+ .report-section { margin-bottom: 20px; }
+ .report-section h3 { font-size: 12px; color: var(--accent); text-transform: uppercase; letter-spacing: 1px; margin-bottom: 8px; }
+ .comparison-row { display: flex; gap: 20px; margin-bottom: 16px; }
+ .comparison-col { flex: 1; }
+ .comparison-col h4 { font-size: 11px; color: var(--text-dim); text-transform: uppercase; margin-bottom: 6px; }
+
+ .pulse { animation: pulse 1.5s ease-in-out infinite; }
+ @keyframes pulse { 0%,100% { opacity: 1; } 50% { opacity: .5; } }
+ </style>
+ </head>
+ <body>
+
+ <div class="header">
+ <h1><span>BioExp</span> Agent Dashboard</h1>
+ <div class="header-right">
+ <span id="thinking-badge" class="mono" style="font-size:11px;color:var(--accent2);display:none">REASONING ON</span>
+ <span id="step-label" class="mono" style="font-size:13px;color:var(--text-dim)">Step 0</span>
+ <span id="status-pill" class="status-pill waiting">Waiting</span>
+ <button class="btn primary" onclick="doRestart()">Restart</button>
+ <button class="btn" onclick="showReport()">Report</button>
+ </div>
+ </div>
+
+ <!-- Control Panel (collapsible) -->
+ <details class="control-panel" id="control-panel">
+ <summary>New Task / Custom Ground Truth</summary>
+ <div class="form-row">
+ <div class="form-field" style="flex:2">
+ <label>Scenario (leave blank for random)</label>
+ <select id="f-scenario"><option value="">— random —</option></select>
+ </div>
+ <div class="form-field" style="flex:1">
+ <label>True Markers (comma-separated)</label>
+ <input id="f-markers" placeholder="e.g. MYH7, TNNT2, ACTA1" />
+ </div>
+ <div class="form-field" style="flex:1">
+ <label>Causal Mechanisms (comma-separated)</label>
+ <input id="f-mechanisms" placeholder="e.g. sarcomere dysfunction" />
+ </div>
+ </div>
+ <div class="form-row">
+ <div class="form-field" style="flex:2">
+ <label>True Pathways (name:score, comma-sep)</label>
+ <input id="f-pathways" placeholder="e.g. Wnt_signaling:0.8, MAPK:0.6" />
+ </div>
+ <div class="form-field">
+ <button class="btn primary" onclick="doCustomRun()">Run with Ground Truth</button>
+ </div>
+ </div>
+ </details>
+
+ <div class="grid">
+ <div class="card span2" id="card-task">
+ <h2>Task</h2>
+ <div id="task-statement" style="font-size:15px;font-weight:500;margin-bottom:8px;">—</div>
+ <div id="task-meta" style="font-size:13px;color:var(--text-dim)"></div>
+ </div>
+
+ <div class="card">
+ <h2>Reward</h2>
+ <div id="reward-value" class="mono" style="font-size:32px;font-weight:700;margin-bottom:6px;">0.000</div>
+ <canvas id="reward-chart"></canvas>
+ </div>
+
+ <div class="card span3"><h2>Resources</h2><div class="gauge-row" id="gauges"></div></div>
+
+ <div class="card span2" style="max-height:460px;overflow-y:auto">
+ <h2>Pipeline History <span style="color:var(--accent);font-size:10px">OBSERVABLE</span></h2>
+ <div class="timeline" id="timeline"></div>
+ </div>
+
+ <div class="card">
+ <h2>Current Action</h2>
+ <table class="mini-table" id="action-table"><tbody></tbody></table>
+ <h2 style="margin-top:14px;display:none" id="thinking-header">Model Reasoning</h2>
+ <div class="code-block" id="model-thinking" style="display:none;border-color:rgba(124,108,240,.2);max-height:140px;margin-bottom:10px">—</div>
+ <h2 style="margin-top:10px">Model Raw Output</h2>
+ <div class="code-block" id="model-response">—</div>
+ </div>
+
+ <div class="card">
+ <h2>Discovered Markers <span style="color:var(--accent);font-size:10px">OBSERVABLE</span></h2>
+ <div class="tag-list" id="markers-list"><span class="tag" style="color:var(--text-dim)">none yet</span></div>
+ <h2 style="margin-top:14px">Candidate Mechanisms</h2>
+ <div class="tag-list" id="mechanisms-list"><span class="tag" style="color:var(--text-dim)">none yet</span></div>
+ </div>
+
+ <div class="card">
+ <h2>Rule Violations</h2>
+ <div id="violations" style="font-size:13px;color:var(--text-dim)">None</div>
+ <h2 style="margin-top:14px">Uncertainty Summary</h2>
+ <table class="mini-table" id="uncertainty-table"><tbody></tbody></table>
+ <h2 style="margin-top:14px">Reward Breakdown</h2>
+ <table class="mini-table" id="reward-breakdown-table"><tbody></tbody></table>
+ </div>
+
+ <div class="card">
+ <h2>Latest Output</h2>
+ <table class="mini-table" id="output-table"><tbody></tbody></table>
+ <div class="code-block" id="output-data" style="margin-top:10px;max-height:140px">—</div>
+ </div>
+
+ <div class="card span3" id="card-conclusions" style="display:none;border-color:rgba(76,222,128,.25)">
+ <h2 style="color:var(--green)">Synthesized Conclusions</h2>
+ <div id="conclusions-list"></div>
+ </div>
+
+ <!-- Ground Truth Comparison (shown when episode done + has conclusions) -->
+ <div class="card span3" id="card-gt-comparison" style="display:none;border-color:rgba(251,191,36,.25)">
+ <h2 style="color:var(--amber)">Ground Truth Comparison</h2>
+ <div class="comparison-row">
+ <div class="comparison-col">
+ <h4>Agent's Markers</h4>
+ <div class="tag-list" id="gt-agent-markers"></div>
+ </div>
+ <div class="comparison-col">
+ <h4>True Markers</h4>
+ <div class="tag-list" id="gt-true-markers"></div>
+ </div>
+ </div>
+ <div class="comparison-row">
+ <div class="comparison-col">
+ <h4>Agent's Mechanisms</h4>
+ <div class="tag-list" id="gt-agent-mechs"></div>
+ </div>
+ <div class="comparison-col">
+ <h4>True Mechanisms</h4>
+ <div class="tag-list" id="gt-true-mechs"></div>
+ </div>
+ </div>
+ <div id="gt-score" style="margin-top:8px;font-size:14px;font-weight:600"></div>
+ </div>
+
+ <div class="card" style="border-color:rgba(124,108,240,.25)">
+ <h2 style="color:var(--accent2)">Cell Populations <span style="font-size:10px">HIDDEN</span></h2>
+ <div id="populations"></div>
+ </div>
+ <div class="card" style="border-color:rgba(124,108,240,.25)">
+ <h2 style="color:var(--accent2)">Ground Truth <span style="font-size:10px">HIDDEN</span></h2>
+ <div style="margin-bottom:8px"><span style="font-size:11px;color:var(--text-dim);text-transform:uppercase">True Markers</span><div class="tag-list" id="true-markers" style="margin-top:4px"></div></div>
+ <div style="margin-bottom:8px"><span style="font-size:11px;color:var(--text-dim);text-transform:uppercase">Causal Mechanisms</span><div class="tag-list" id="true-mechanisms" style="margin-top:4px"></div></div>
+ <div><span style="font-size:11px;color:var(--text-dim);text-transform:uppercase">Top Pathways</span><table class="mini-table" id="pathways-table" style="margin-top:4px"><tbody></tbody></table></div>
+ </div>
+ <div class="card" style="border-color:rgba(124,108,240,.25)">
+ <h2 style="color:var(--accent2)">Technical State <span style="font-size:10px">HIDDEN</span></h2>
+ <table class="mini-table" id="technical-table"><tbody></tbody></table>
+ <h2 style="margin-top:14px;color:var(--accent2)">Failure Conditions <span style="font-size:10px">HIDDEN</span></h2>
+ <div class="tag-list" id="failure-conditions"></div>
+ </div>
+ <div class="card span3" style="border-color:rgba(124,108,240,.25)">
+ <h2 style="color:var(--accent2)">Experiment Progress <span style="font-size:10px">HIDDEN</span></h2>
+ <div class="progress-grid" id="progress-grid"></div>
+ </div>
+ </div>
+
+ <!-- Final Report Overlay -->
+ <div class="report-overlay" id="report-overlay" onclick="if(event.target===this)hideReport()">
+ <div class="report-card" id="report-content"></div>
+ </div>
+
+ <script>
+ const POLL_MS = 1200;
+ const POP_COLORS = ['#5ce0d8','#7c6cf0','#f472b6','#60a5fa','#fbbf24','#4ade80','#f87171','#c084fc','#fb923c','#38bdf8'];
+ let rewardHistory = [];
+ let lastTimestamp = 0;
+ let latestState = null;
+
+ function $(id) { return document.getElementById(id); }
+ function setHTML(id, html) { $(id).innerHTML = html; }
+ function tagsHTML(arr, cls) {
+ if (!arr || !arr.length) return '<span class="tag" style="color:var(--text-dim)">—</span>';
+ return arr.map(t => `<span class="tag ${cls||''}">${esc(t)}</span>`).join('');
+ }
+ function esc(s) { if (s == null) return '—'; const d = document.createElement('div'); d.textContent = String(s); return d.innerHTML; }
+ function pct(used, total) { if (!total) return 0; return Math.min(100, Math.max(0, (used / total) * 100)); }
+ function gaugeColor(p) { return p < 50 ? 'var(--green)' : p < 80 ? 'var(--amber)' : 'var(--red)'; }
+ function fmt(n) { if (n == null) return '0'; return Number(n).toLocaleString('en-US', { maximumFractionDigits: 0 }); }
+ function gauge(label, value, pctVal, inv) {
+ let bar = '';
+ if (pctVal != null) { const c = inv ? gaugeColor(100-pctVal) : gaugeColor(pctVal); bar = `<div class="gauge-bar"><div class="gauge-bar-fill" style="width:${pctVal.toFixed(1)}%;background:${c}"></div></div>`; }
+ return `<div class="gauge"><div class="gauge-label">${label}</div><div class="gauge-value mono">${value}</div>${bar}</div>`;
+ }
+ function miniRows(obj) { return Object.entries(obj).map(([k,v]) => `<tr><td>${esc(k)}</td><td>${esc(v)}</td></tr>`).join(''); }
+
+ function drawRewardChart(canvas, data) {
+ const ctx = canvas.getContext('2d'); const W = canvas.width = canvas.offsetWidth * 2; const H = canvas.height = canvas.offsetHeight * 2;
+ ctx.clearRect(0, 0, W, H); if (data.length < 2) return;
+ const vals = data.map(d => d.v); const minV = Math.min(0, ...vals); const maxV = Math.max(0.1, ...vals); const range = maxV - minV || 1; const pad = 8;
+ ctx.strokeStyle = 'rgba(92,224,216,.4)'; ctx.lineWidth = 2; ctx.beginPath();
+ const yZ = H - pad - ((0 - minV) / range) * (H - 2*pad); ctx.moveTo(pad, yZ); ctx.lineTo(W-pad, yZ); ctx.stroke();
+ ctx.strokeStyle = '#5ce0d8'; ctx.lineWidth = 3; ctx.beginPath();
+ data.forEach((d,i) => { const x = pad+(i/(data.length-1))*(W-2*pad); const y = H-pad-((d.v-minV)/range)*(H-2*pad); i===0?ctx.moveTo(x,y):ctx.lineTo(x,y); }); ctx.stroke();
+ data.forEach((d,i) => { const x = pad+(i/(data.length-1))*(W-2*pad); const y = H-pad-((d.v-minV)/range)*(H-2*pad); ctx.fillStyle = d.v>=0?'#4ade80':'#f87171'; ctx.beginPath(); ctx.arc(x,y,5,0,Math.PI*2); ctx.fill(); });
+ }
+
+ function comparedTags(agentArr, trueArr, cls) {
+ if (!agentArr || !agentArr.length) return '<span class="tag" style="color:var(--text-dim)">—</span>';
+ const trueSet = new Set((trueArr||[]).map(t => t.toUpperCase()));
+ return agentArr.map(t => {
+ const hit = trueSet.has(t.toUpperCase());
+ return `<span class="tag ${cls} ${hit?'match':'miss'}">${esc(t)} ${hit?'✓':'✗'}</span>`;
+ }).join('');
+ }
+
+ // ── API actions ──
+ async function doRestart() {
+ rewardHistory = []; lastTimestamp = 0;
+ await fetch('/api/restart', { method: 'POST' });
+ }
+
+ async function doCustomRun() {
+ const scenario = $('f-scenario').value || undefined;
+ const markers = $('f-markers').value.split(',').map(s=>s.trim()).filter(Boolean);
+ const mechs = $('f-mechanisms').value.split(',').map(s=>s.trim()).filter(Boolean);
+ const pwRaw = $('f-pathways').value.split(',').map(s=>s.trim()).filter(Boolean);
+ const pathways = {};
+ pwRaw.forEach(p => { const [k,v] = p.split(':'); if (k && v) pathways[k.trim()] = parseFloat(v); });
+ const gt = {};
+ if (markers.length) gt.true_markers = markers;
+ if (mechs.length) gt.causal_mechanisms = mechs;
+ if (Object.keys(pathways).length) gt.true_pathways = pathways;
+ rewardHistory = []; lastTimestamp = 0;
+ await fetch('/api/run', { method: 'POST', headers: {'Content-Type':'application/json'}, body: JSON.stringify({ scenario_name: scenario, ground_truth: Object.keys(gt).length ? gt : undefined }) });
+ }
+
+ function showReport() {
+ const s = latestState; if (!s) return;
+ const rc = $('report-content');
+ const t = s.task || {};
+ const lat = s.latent || {};
+ const conc = s.conclusions || [];
+ const trueM = lat.true_markers || [];
+ const trueMech = lat.causal_mechanisms || [];
+ const agentM = s.discovered_markers || [];
+ const markerHits = agentM.filter(m => trueM.some(t => t.toUpperCase() === m.toUpperCase()));
+ const r = s.resources || {};
+
+ let html = `<h2>Experiment Report</h2>
+ <div class="subtitle">${esc(t.problem_statement)}</div>
+ <div class="report-section"><h3>Summary</h3>
+ <table class="mini-table"><tbody>
+ <tr><td>Status</td><td>${s.episode_done ? 'Completed' : 'In Progress'}</td></tr>
+ <tr><td>Steps</td><td>${s.step}</td></tr>
+ <tr><td>Cumulative Reward</td><td style="color:${(s.cumulative_reward||0)>=0?'var(--green)':'var(--red)'}">${((s.cumulative_reward||0)>=0?'+':'')}${(s.cumulative_reward||0).toFixed(3)}</td></tr>
+ <tr><td>Budget Used</td><td>$${fmt(r.budget_used)} / $${fmt((r.budget_used||0)+(r.budget_remaining||0))}</td></tr>
+ <tr><td>Time Used</td><td>${(r.time_used_days||0).toFixed(0)}d / ${((r.time_used_days||0)+(r.time_remaining_days||0)).toFixed(0)}d</td></tr>
+ <tr><td>Markers Found</td><td>${agentM.length} (${markerHits.length} match ground truth)</td></tr>
+ </tbody></table>
+ </div>`;
+
+ if (conc.length) {
+ html += `<div class="report-section"><h3>Conclusions</h3>`;
+ conc.forEach(c => {
+ html += `<div class="conclusion-card"><div class="cc-header"><span class="cc-type ${(c.claim_type||'').toLowerCase()}">${esc(c.claim_type)}</span><span class="cc-conf" style="color:${c.confidence>=.7?'var(--green)':c.confidence>=.4?'var(--amber)':'var(--red)'}">${((c.confidence||0)*100).toFixed(0)}%</span></div>`;
+ if (c.claim) html += `<div class="cc-claim">${esc(c.claim)}</div>`;
+ if (c.top_markers?.length) html += `<div class="cc-section-label">Top Markers</div><div class="tag-list">${c.top_markers.map(m=>`<span class="tag green">${esc(m)}</span>`).join('')}</div>`;
+ if (c.causal_mechanisms?.length) html += `<div class="cc-section-label">Causal Mechanisms</div><div class="tag-list">${c.causal_mechanisms.map(m=>`<span class="tag pink">${esc(m)}</span>`).join('')}</div>`;
+ if (c.predicted_pathways && Object.keys(c.predicted_pathways).length) html += `<div class="cc-section-label">Predicted Pathways</div><table class="mini-table"><tbody>${Object.entries(c.predicted_pathways).map(([k,v])=>`<tr><td>${esc(k)}</td><td>${Number(v).toFixed(3)}</td></tr>`).join('')}</tbody></table>`;
+ html += `</div>`;
+ });
+ html += `</div>`;
+ }
+
+ html += `<div class="report-section"><h3>Ground Truth Comparison</h3>
+ <div class="comparison-row"><div class="comparison-col"><h4>Agent's Markers</h4><div class="tag-list">${comparedTags(agentM, trueM, 'green')}</div></div>
+ <div class="comparison-col"><h4>True Markers</h4><div class="tag-list">${tagsHTML(trueM,'green')}</div></div></div>
396
+ <div class="comparison-row"><div class="comparison-col"><h4>Agent's Mechanisms</h4><div class="tag-list">${comparedTags(s.candidate_mechanisms, trueMech, 'pink')}</div></div>
397
+ <div class="comparison-col"><h4>True Mechanisms</h4><div class="tag-list">${tagsHTML(trueMech,'pink')}</div></div></div>
398
+ </div>`;
399
+
400
+ const hist = s.pipeline_history || [];
401
+ if (hist.length) {
402
+ html += `<div class="report-section"><h3>Pipeline Steps</h3><table class="mini-table"><tbody>`;
403
+ hist.forEach(h => { html += `<tr><td>${h.success?'✓':'✗'} ${esc(h.action_type)}</td><td>${esc(h.output_summary)} · q=${h.quality_score}</td></tr>`; });
404
+ html += `</tbody></table></div>`;
405
+ }
406
+
407
+ html += `<div style="margin-top:20px;text-align:right"><button class="btn" onclick="hideReport()">Close</button> <button class="btn primary" onclick="doRestart();hideReport()">New Run</button></div>`;
408
+ rc.innerHTML = html;
409
+ $('report-overlay').classList.add('visible');
410
+ }
411
+ function hideReport() { $('report-overlay').classList.remove('visible'); }
412
+
413
+ function renderState(s) {
414
+ latestState = s;
415
+ if (s.error) { $('status-pill').className='status-pill waiting'; $('status-pill').textContent='Waiting'; $('task-statement').textContent=s.error; return; }
416
+
417
+ const pill = $('status-pill');
418
+ if (s.episode_done) { pill.className='status-pill done'; pill.textContent='Done'; } else { pill.className='status-pill live'; pill.textContent='Live'; }
419
+ $('step-label').textContent = `Step ${s.step}`;
420
+
421
+ if (s.thinking_enabled) { $('thinking-badge').style.display = ''; } else { $('thinking-badge').style.display = 'none'; }
422
+
423
+ const t = s.task || {};
424
+ $('task-statement').textContent = t.problem_statement || '—';
425
+ $('task-meta').innerHTML = [t.organism, t.tissue, t.modality, t.conditions ? t.conditions.join(' vs ') : null].filter(Boolean).map(v => `<span class="tag">${esc(v)}</span>`).join(' ');
426
+
427
+ const cum = s.cumulative_reward || 0;
428
+ $('reward-value').textContent = (cum >= 0 ? '+' : '') + cum.toFixed(3);
429
+ $('reward-value').style.color = cum >= 0 ? 'var(--green)' : 'var(--red)';
430
+ if (s.timestamp !== lastTimestamp && s.step > 0) { rewardHistory.push({ step: s.step, v: cum }); lastTimestamp = s.timestamp; }
431
+ drawRewardChart($('reward-chart'), rewardHistory);
432
+
433
+ const r = s.resources || {};
434
+ const bT = (r.budget_used||0)+(r.budget_remaining||0), tT = (r.time_used_days||0)+(r.time_remaining_days||0);
435
+ const bP = pct(r.budget_used, bT), tP = pct(r.time_used_days, tT);
436
+ $('gauges').innerHTML = [gauge('Budget Used',`$${fmt(r.budget_used)}`,bP), gauge('Budget Left',`$${fmt(r.budget_remaining)}`,100-bP,true), gauge('Time Used',`${(r.time_used_days||0).toFixed(0)}d`,tP), gauge('Time Left',`${(r.time_remaining_days||0).toFixed(0)}d`,100-tP,true), gauge('Samples',String(r.samples_consumed||0),null), gauge('Compute',`${(r.compute_hours_used||0).toFixed(1)}h`,null)].join('');
437
+
438
+ const hist = s.pipeline_history || [];
439
+ $('timeline').innerHTML = hist.length ? hist.map(h => `<div class="timeline-item ${!h.success?'fail':''}"><div class="tl-action">${esc(h.action_type)}${h.method?` <span style="color:var(--text-dim);font-weight:400;font-size:12px">${esc(h.method)}</span>`:''}</div><div class="tl-meta">${h.success?'✓':'✗'} ${esc(h.output_summary)} · q=${h.quality_score} · $${fmt(h.resource_cost)} · ${h.time_cost_days}d</div></div>`).join('') : '<div style="color:var(--text-dim);font-size:13px">No steps yet</div>';
440
+
441
+ const a = s.current_action;
442
+ if (a) { $('action-table').querySelector('tbody').innerHTML = miniRows({'Type':a.action_type,'Method':a.method||'—','Confidence':a.confidence?.toFixed(2),'Justification':a.justification||'—','Fallback?':s.used_fallback?'YES':'no'}); }
443
+
444
+ if (s.model_thinking) { $('model-thinking').style.display=''; $('model-thinking').textContent = s.model_thinking; } else { $('model-thinking').style.display='none'; }
445
+ $('model-response').textContent = s.model_response_raw || '—';
446
+
447
+ setHTML('markers-list', tagsHTML(s.discovered_markers, 'green'));
448
+ setHTML('mechanisms-list', tagsHTML(s.candidate_mechanisms, 'pink'));
449
+
450
+ const v = s.rule_violations || [];
451
+ $('violations').innerHTML = v.length ? v.map(x=>`<div class="tag red" style="margin-bottom:4px">${esc(x)}</div>`).join('') : '<span style="color:var(--text-dim)">None</span>';
452
+ $('uncertainty-table').querySelector('tbody').innerHTML = miniRows(s.uncertainty_summary || {});
453
+ const rb = s.reward_breakdown || {};
454
+ $('reward-breakdown-table').querySelector('tbody').innerHTML = miniRows(Object.fromEntries(Object.entries(rb).map(([k,v])=>[k,(v>=0?'+':'')+v.toFixed(4)])));
455
+
456
+ const lo = s.latest_output;
457
+ if (lo) { $('output-table').querySelector('tbody').innerHTML = miniRows({'Summary':lo.summary,'Success':lo.success?'✓':'✗','Quality':lo.quality_score,'Uncertainty':lo.uncertainty,'Warnings':(lo.warnings||[]).join('; ')||'—'}); $('output-data').textContent = lo.data_preview||'—'; }
458
+
459
+ const conc = s.conclusions || [];
460
+ if (conc.length) {
461
+ $('card-conclusions').style.display = '';
462
+ $('conclusions-list').innerHTML = conc.map(c => {
463
+ const confColor = c.confidence>=.7?'var(--green)':c.confidence>=.4?'var(--amber)':'var(--red)';
464
+ let h = `<div class="conclusion-card"><div class="cc-header"><span class="cc-type ${(c.claim_type||'').toLowerCase()}">${esc(c.claim_type||'unknown')}</span><span class="cc-conf" style="color:${confColor}">${((c.confidence||0)*100).toFixed(0)}%</span></div>`;
465
+ if (c.claim) h += `<div class="cc-claim">${esc(c.claim)}</div>`;
466
+ if (c.top_markers?.length) h += `<div class="cc-section-label">Top Markers</div><div class="tag-list">${c.top_markers.map(m=>`<span class="tag green">${esc(m)}</span>`).join('')}</div>`;
467
+ if (c.causal_mechanisms?.length) h += `<div class="cc-section-label">Causal Mechanisms</div><div class="tag-list">${c.causal_mechanisms.map(m=>`<span class="tag pink">${esc(m)}</span>`).join('')}</div>`;
468
+ if (c.predicted_pathways && Object.keys(c.predicted_pathways).length) h += `<div class="cc-section-label">Predicted Pathways</div><table class="mini-table"><tbody>${Object.entries(c.predicted_pathways).map(([k,v])=>`<tr><td>${esc(k)}</td><td>${Number(v).toFixed(3)}</td></tr>`).join('')}</tbody></table>`;
469
+ return h + '</div>';
470
+ }).join('');
471
+ } else { $('card-conclusions').style.display = 'none'; }
472
+
473
+ // Ground truth comparison (visible when done or has conclusions)
474
+ const lat = s.latent;
475
+ if ((s.episode_done || conc.length) && lat) {
476
+ $('card-gt-comparison').style.display = '';
477
+ setHTML('gt-agent-markers', comparedTags(s.discovered_markers, lat.true_markers, 'green'));
478
+ setHTML('gt-true-markers', tagsHTML(lat.true_markers, 'green'));
479
+ setHTML('gt-agent-mechs', comparedTags(s.candidate_mechanisms, lat.causal_mechanisms, 'pink'));
480
+ setHTML('gt-true-mechs', tagsHTML(lat.causal_mechanisms, 'pink'));
481
+ const hits = (s.discovered_markers||[]).filter(m => (lat.true_markers||[]).some(t => t.toUpperCase()===m.toUpperCase()));
482
+ $('gt-score').innerHTML = `Marker accuracy: <span style="color:var(--accent)">${hits.length}</span> / ${(lat.true_markers||[]).length} true markers recovered`;
483
+ } else { $('card-gt-comparison').style.display = 'none'; }
484
+
485
+ if (!lat) return;
486
+ const pops = lat.cell_populations || [];
487
+ $('populations').innerHTML = pops.map((p,i) => { const c = POP_COLORS[i%POP_COLORS.length]; const w = (p.proportion*100).toFixed(1); return `<div class="pop-bar-container"><div class="pop-bar-label"><span>${esc(p.name)} <span style="color:var(--text-dim);font-size:11px">${p.state}</span></span><span class="mono" style="font-size:12px">${w}%</span></div><div class="pop-bar"><div class="pop-bar-fill" style="width:${w}%;background:${c}"></div></div><div class="tag-list" style="margin-top:3px">${p.marker_genes.map(g=>`<span class="tag" style="font-size:11px">${esc(g)}</span>`).join('')}</div></div>`; }).join('') || '<span style="color:var(--text-dim)">—</span>';
488
+
489
+ setHTML('true-markers', tagsHTML(lat.true_markers, 'green'));
490
+ setHTML('true-mechanisms', tagsHTML(lat.causal_mechanisms, 'pink'));
491
+ const pw = lat.true_pathways || {};
492
+ $('pathways-table').querySelector('tbody').innerHTML = miniRows(Object.fromEntries(Object.entries(pw).slice(0,10).map(([k,v])=>[k,v.toFixed(3)])));
493
+ $('technical-table').querySelector('tbody').innerHTML = miniRows(lat.technical || {});
494
+ setHTML('failure-conditions', tagsHTML(lat.hidden_failure_conditions, 'red'));
495
+ const prog = lat.progress || {};
496
+ const bK = Object.entries(prog).filter(([,v])=>typeof v==='boolean'), nK = Object.entries(prog).filter(([,v])=>typeof v!=='boolean');
497
+ $('progress-grid').innerHTML = bK.map(([k,v])=>`<div class="progress-item"><div class="dot ${v?'done':''}"></div>${k.replace(/_/g,' ')}</div>`).join('') + nK.map(([k,v])=>`<div class="progress-item" style="color:var(--accent)"><span class="mono" style="font-size:11px;margin-right:4px">${v??'—'}</span>${k.replace(/_/g,' ')}</div>`).join('');
498
+
499
+ if (s.episode_done && !reportShownForTimestamp && s.timestamp) { reportShownForTimestamp = s.timestamp; setTimeout(showReport, 800); }
500
+ }
501
+
502
+ let reportShownForTimestamp = null;
503
+
504
+ async function loadScenarios() {
505
+ try {
506
+ const res = await fetch('/api/scenarios');
507
+ const data = await res.json();
508
+ const sel = $('f-scenario');
509
+ (data.scenarios || []).forEach(n => { const o = document.createElement('option'); o.value = n; o.textContent = n; sel.appendChild(o); });
510
+ } catch(e) {}
511
+ }
512
+
513
+ async function poll() {
514
+ try { const res = await fetch('/api/state',{cache:'no-store'}); const data = await res.json(); renderState(data); } catch(e) {}
515
+ setTimeout(poll, POLL_MS);
516
+ }
517
+
518
+ loadScenarios();
519
+ poll();
520
+ </script>
521
+ </body>
522
+ </html>
dashboard.py ADDED
@@ -0,0 +1,129 @@
+ """Lightweight dashboard server for the bio-experiment agent.
+
+ No external dependencies — uses only the Python standard library.
+
+ Usage:
+ python dashboard.py # serves on http://localhost:8050
+ python dashboard.py --port 9000
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import json
+ from http.server import HTTPServer, SimpleHTTPRequestHandler
+ from pathlib import Path
+
+ ROOT = Path(__file__).parent
+ STATE_FILE = ROOT / "_dashboard_state.json"
+ CMD_FILE = ROOT / "_dashboard_cmd.json"
+ DASHBOARD_HTML = ROOT / "dashboard.html"
+
+
+ class DashboardHandler(SimpleHTTPRequestHandler):
+ def do_GET(self):
+ if self.path == "/" or self.path == "/index.html":
+ self._serve_file(DASHBOARD_HTML, "text/html")
+ elif self.path == "/api/state":
+ self._serve_state()
+ elif self.path == "/api/scenarios":
+ self._serve_scenarios()
+ else:
+ self.send_error(404)
+
+ def do_POST(self):
+ if self.path == "/api/restart":
+ self._handle_command({"action": "restart"})
+ elif self.path == "/api/run":
+ body = self._read_body()
+ if body is None:
+ return
+ body["action"] = "restart"
+ self._handle_command(body)
+ else:
+ self.send_error(404)
+
+ def do_OPTIONS(self):
+ self.send_response(204)
+ self.send_header("Access-Control-Allow-Origin", "*")
+ self.send_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
+ self.send_header("Access-Control-Allow-Headers", "Content-Type")
+ self.end_headers()
+
+ def _read_body(self):
+ length = int(self.headers.get("Content-Length", 0))
+ if length == 0:
+ return {}
+ raw = self.rfile.read(length)
+ try:
+ return json.loads(raw)
+ except json.JSONDecodeError:
+ self._json_response(400, {"error": "Invalid JSON"})
+ return None
+
+ def _handle_command(self, cmd: dict):
+ CMD_FILE.write_text(json.dumps(cmd), encoding="utf-8")
+ self._json_response(200, {"ok": True, "command": cmd.get("action")})
+
+ def _serve_state(self):
+ self.send_response(200)
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
+ self.send_header("Cache-Control", "no-cache")
+ self.end_headers()
+ try:
+ data = STATE_FILE.read_bytes()
+ except FileNotFoundError:
+ data = b'{"error": "No state file yet. Run run_agent.py to start an episode."}'
+ self.wfile.write(data)
+
+ def _serve_scenarios(self):
+ try:
+ from server.tasks.scenarios import SCENARIO_LIBRARY
+ names = [s.name for s in SCENARIO_LIBRARY]
+ except Exception:
+ names = []
+ self._json_response(200, {"scenarios": names})
+
+ def _serve_file(self, path: Path, content_type: str):
+ try:
+ body = path.read_bytes()
+ except FileNotFoundError:
+ self.send_error(404, f"{path.name} not found")
+ return
+ self.send_response(200)
+ self.send_header("Content-Type", content_type)
+ self.send_header("Content-Length", str(len(body)))
+ self.end_headers()
+ self.wfile.write(body)
+
+ def _json_response(self, code: int, obj: dict):
+ body = json.dumps(obj).encode()
+ self.send_response(code)
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
+ self.send_header("Content-Length", str(len(body)))
+ self.end_headers()
+ self.wfile.write(body)
+
+ def log_message(self, format, *args):
+ pass
+
+
+ def main():
+ parser = argparse.ArgumentParser(description="Bio-experiment dashboard server")
+ parser.add_argument("--port", type=int, default=8050)
+ args = parser.parse_args()
+
+ server = HTTPServer(("0.0.0.0", args.port), DashboardHandler)
+ print(f"Dashboard running at http://localhost:{args.port}")
+ print("Waiting for agent state from run_agent.py ...")
+ try:
+ server.serve_forever()
+ except KeyboardInterrupt:
+ print("\nShutting down.")
+ server.server_close()
+
+
+ if __name__ == "__main__":
+ main()
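Note that `dashboard.py` never talks to the agent process directly: the server and the agent coordinate only through the two JSON files `_dashboard_state.json` and `_dashboard_cmd.json` next to the module. A minimal sketch of that handshake is below; it uses temporary paths instead of the real `ROOT`, and the agent-side polling loop in `run_agent.py` is not part of this diff, so its behavior here is an assumption.

```python
# Sketch of the file-based handshake dashboard.py relies on. Paths are
# temporary stand-ins; the real module uses files next to dashboard.py.
import json
import tempfile
from pathlib import Path

root = Path(tempfile.mkdtemp())
state_file = root / "_dashboard_state.json"
cmd_file = root / "_dashboard_cmd.json"

# Agent side: publish a state snapshot each step; /api/state serves it verbatim.
state_file.write_text(json.dumps({"step": 3, "episode_done": False}), encoding="utf-8")

# Dashboard side: a POST to /api/run writes a command file for the agent
# (the JS client sends scenario_name and optional ground_truth overrides).
cmd_file.write_text(json.dumps({"action": "restart", "scenario_name": "cardiac_disease_de"}), encoding="utf-8")

# Agent side (assumed): pick up the command, act on it, then clear it.
cmd = json.loads(cmd_file.read_text(encoding="utf-8"))
cmd_file.unlink()
print(cmd["action"])  # restart
```

The file-based design keeps the dashboard dependency-free and crash-isolated: either process can restart without the other noticing anything worse than a stale snapshot.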
demo.html ADDED
@@ -0,0 +1,1639 @@
+ <!DOCTYPE html>
+ <html lang="en">
+ <head>
+ <meta charset="UTF-8">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
+ <title>BioEnv</title>
+ <style>
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700;800&family=JetBrains+Mono:wght@400;500;600&display=swap');
+
+ :root {
+ --bg: #07090d;
+ --bg-surface: #0c0f16;
+ --bg-raised: #111827;
+ --bg-hover: #1a2235;
+ --border: #1e293b;
+ --border-active: #334155;
+ --text: #e2e8f0;
+ --text-dim: #94a3b8;
+ --text-muted: #475569;
+ --accent: #38bdf8;
+ --accent-dim: rgba(56,189,248,0.12);
+ --green: #34d399;
+ --green-dim: rgba(52,211,153,0.10);
+ --amber: #fbbf24;
+ --amber-dim: rgba(251,191,36,0.10);
+ --red: #f87171;
+ --red-dim: rgba(248,113,113,0.10);
+ --cyan: #22d3ee;
+ --cyan-dim: rgba(34,211,238,0.10);
+ --pink: #f472b6;
+ --purple: #a78bfa;
+ }
+
+ * { margin: 0; padding: 0; box-sizing: border-box; }
+ html, body { height: 100%; overflow: hidden; }
+
+ body {
+ font-family: 'Inter', -apple-system, sans-serif;
+ background: var(--bg);
+ color: var(--text);
+ display: flex;
+ flex-direction: column;
+ }
+
+ /* ---- Top Bar ---- */
+ .topbar {
+ height: 48px;
+ min-height: 48px;
+ background: var(--bg-surface);
+ border-bottom: 1px solid var(--border);
+ display: flex;
+ align-items: center;
+ padding: 0 20px;
+ gap: 16px;
+ z-index: 10;
+ }
+ .topbar-logo {
+ font-size: 15px;
+ font-weight: 800;
+ letter-spacing: -0.5px;
+ background: linear-gradient(135deg, #38bdf8, #22d3ee);
+ -webkit-background-clip: text;
+ -webkit-text-fill-color: transparent;
+ }
+ .topbar-sep { width: 1px; height: 20px; background: var(--border); }
+ .topbar-env {
+ font-size: 12px;
+ color: var(--text-dim);
+ font-family: 'JetBrains Mono', monospace;
+ }
+ .topbar-status {
+ display: flex;
+ align-items: center;
+ gap: 6px;
+ margin-left: auto;
+ font-size: 12px;
+ color: var(--text-dim);
+ }
+ .status-dot {
+ width: 7px; height: 7px;
+ border-radius: 50%;
+ background: var(--text-muted);
+ }
+ .status-dot.live {
+ background: var(--green);
+ box-shadow: 0 0 8px var(--green);
+ animation: pulse 2s infinite;
+ }
+ @keyframes pulse {
+ 0%, 100% { opacity: 1; }
+ 50% { opacity: 0.5; }
+ }
+ .topbar-btn {
+ font-size: 12px;
+ font-weight: 600;
+ padding: 6px 14px;
+ border-radius: 6px;
+ border: none;
+ cursor: pointer;
+ transition: all 0.15s;
+ font-family: inherit;
+ }
+ .btn-primary { background: var(--accent); color: #07090d; font-weight: 700; }
+ .btn-primary:hover { background: #7dd3fc; }
+ .btn-primary:disabled { opacity: 0.4; cursor: not-allowed; }
+ .btn-ghost {
+ background: transparent;
+ color: var(--text-dim);
+ border: 1px solid var(--border);
+ }
+ .btn-ghost:hover { background: var(--bg-hover); color: var(--text); }
+
+ /* ---- Main Layout ---- */
+ .main {
+ flex: 1;
+ display: grid;
+ grid-template-columns: 260px 1fr 340px;
+ overflow: hidden;
+ }
+
+ /* ---- Left Sidebar ---- */
+ .sidebar {
+ background: var(--bg-surface);
+ border-right: 1px solid var(--border);
+ display: flex;
+ flex-direction: column;
+ overflow-y: auto;
+ }
+ .sidebar-section {
+ padding: 16px;
+ border-bottom: 1px solid var(--border);
+ }
+ .sidebar-heading {
+ font-size: 10px;
+ font-weight: 600;
+ text-transform: uppercase;
+ letter-spacing: 1.5px;
+ color: var(--text-muted);
+ margin-bottom: 10px;
+ }
+ .scenario-list { display: flex; flex-direction: column; gap: 4px; }
+ .scenario-opt {
+ display: flex;
+ align-items: center;
+ gap: 10px;
+ padding: 8px 10px;
+ border-radius: 6px;
+ cursor: pointer;
+ transition: all 0.15s;
+ border: 1px solid transparent;
+ }
+ .scenario-opt:hover { background: var(--bg-hover); }
+ .scenario-opt.active {
+ background: var(--accent-dim);
+ border-color: rgba(56,189,248,0.2);
+ }
+ .scenario-opt .sc-dot { width: 8px; height: 8px; border-radius: 50%; flex-shrink: 0; }
+ .scenario-opt .sc-name {
+ font-size: 12px; font-weight: 500; flex: 1;
+ white-space: nowrap; overflow: hidden; text-overflow: ellipsis;
+ }
+ .scenario-opt .sc-diff {
+ font-size: 10px; font-weight: 600;
+ text-transform: uppercase; letter-spacing: 0.5px;
+ }
+ .gauge { margin-bottom: 14px; }
+ .gauge:last-child { margin-bottom: 0; }
+ .gauge-header {
+ display: flex; justify-content: space-between;
+ align-items: baseline; margin-bottom: 6px;
+ }
+ .gauge-label { font-size: 12px; color: var(--text-dim); font-weight: 500; }
+ .gauge-value {
+ font-size: 12px; font-weight: 600;
+ font-family: 'JetBrains Mono', monospace;
+ }
+ .gauge-track {
+ height: 4px; background: var(--bg-hover);
+ border-radius: 4px; overflow: hidden;
+ }
+ .gauge-fill {
+ height: 100%; border-radius: 4px;
+ transition: width 0.8s cubic-bezier(0.4,0,0.2,1);
+ }
+ .pipeline-steps { display: flex; flex-direction: column; gap: 2px; }
+ .pipe-step {
+ display: flex; align-items: center; gap: 8px;
+ padding: 5px 8px; border-radius: 4px;
+ font-size: 11px; font-family: 'JetBrains Mono', monospace;
+ color: var(--text-muted);
+ opacity: 0; transform: translateX(-8px);
+ transition: all 0.3s ease;
+ }
+ .pipe-step.visible { opacity: 1; transform: translateX(0); }
+ .pipe-step.active { color: var(--text); background: var(--accent-dim); }
+ .pipe-step.done { color: var(--text-dim); }
+ .pipe-step .step-icon {
+ width: 16px; height: 16px; border-radius: 50%;
+ border: 1.5px solid var(--text-muted);
+ display: flex; align-items: center; justify-content: center;
+ font-size: 8px; flex-shrink: 0; transition: all 0.3s;
+ }
+ .pipe-step.done .step-icon {
+ background: var(--green-dim); border-color: var(--green); color: var(--green);
+ }
+ .pipe-step.active .step-icon {
+ border-color: var(--accent); background: var(--accent-dim);
+ color: var(--accent); animation: pulse 1.5s infinite;
+ }
+
+ /* ---- Center: Lab + Terminal ---- */
+ .center {
+ display: flex;
+ flex-direction: column;
+ overflow: hidden;
+ background: var(--bg);
+ }
+
+ /* Lab canvas */
+ .lab-panel {
+ height: 300px;
+ min-height: 300px;
+ background: var(--bg-surface);
+ border-bottom: 1px solid var(--border);
+ position: relative;
+ overflow: hidden;
+ }
+ .lab-panel canvas {
+ display: block;
+ width: 100%;
+ height: 100%;
+ }
+ .lab-label {
+ position: absolute;
+ top: 8px;
+ left: 12px;
+ font-size: 10px;
+ font-weight: 600;
+ text-transform: uppercase;
+ letter-spacing: 1.5px;
+ color: var(--text-muted);
+ z-index: 2;
+ pointer-events: none;
+ }
+ .lab-action-label {
+ position: absolute;
+ bottom: 10px;
+ left: 50%;
+ transform: translateX(-50%);
+ font-size: 11px;
+ font-family: 'JetBrains Mono', monospace;
+ color: var(--text-dim);
+ background: rgba(12,15,22,0.85);
+ padding: 4px 14px;
+ border-radius: 100px;
+ border: 1px solid var(--border);
+ z-index: 2;
+ pointer-events: none;
+ opacity: 0;
+ transition: opacity 0.3s;
+ }
+ .lab-action-label.visible { opacity: 1; }
+
+ .center-header {
+ height: 36px;
+ min-height: 36px;
+ display: flex;
+ align-items: center;
+ padding: 0 16px;
+ background: var(--bg-surface);
+ border-bottom: 1px solid var(--border);
+ gap: 8px;
+ }
+ .tab {
+ font-size: 11px; font-weight: 500;
+ padding: 4px 12px; border-radius: 4px;
+ color: var(--text-dim); cursor: pointer;
+ transition: all 0.15s;
+ }
+ .tab.active { color: var(--text); background: var(--bg-hover); }
+ .tab:hover { color: var(--text); }
+
+ .terminal {
+ flex: 1;
+ overflow-y: auto;
+ padding: 16px 20px;
+ font-family: 'JetBrains Mono', monospace;
+ font-size: 12.5px;
+ line-height: 1.9;
+ scrollbar-width: thin;
+ scrollbar-color: var(--border) transparent;
+ }
+ .terminal::-webkit-scrollbar { width: 6px; }
+ .terminal::-webkit-scrollbar-track { background: transparent; }
+ .terminal::-webkit-scrollbar-thumb { background: var(--border); border-radius: 3px; }
+
+ .t-line {
+ white-space: pre-wrap;
+ opacity: 0;
+ animation: lineIn 0.25s ease forwards;
+ }
+ @keyframes lineIn {
+ from { opacity: 0; transform: translateY(4px); }
+ to { opacity: 1; transform: translateY(0); }
+ }
+ .t-prompt { color: var(--green); }
+ .t-cmd { color: var(--text); }
+ .t-dim { color: var(--text-muted); }
+ .t-label { color: var(--accent); }
+ .t-str { color: var(--amber); }
+ .t-kw { color: var(--pink); }
+ .t-fn { color: var(--cyan); }
+ .t-num { color: var(--purple); }
+ .t-ok { color: var(--green); }
+ .t-warn { color: var(--amber); }
+ .t-err { color: var(--red); }
+ .t-sub { color: var(--text-dim); }
+
+ /* ---- Right Panel ---- */
+ .right {
+ background: var(--bg-surface);
+ border-left: 1px solid var(--border);
+ display: flex;
+ flex-direction: column;
+ overflow-y: auto;
+ scrollbar-width: thin;
+ scrollbar-color: var(--border) transparent;
+ }
+ .panel-section {
+ padding: 16px;
+ border-bottom: 1px solid var(--border);
+ }
+ .panel-heading {
+ font-size: 10px; font-weight: 600;
+ text-transform: uppercase; letter-spacing: 1.5px;
+ color: var(--text-muted); margin-bottom: 12px;
+ display: flex; align-items: center; justify-content: space-between;
+ }
+ .reward-row {
+ display: flex; align-items: center; gap: 10px; margin-bottom: 8px;
+ }
+ .reward-row:last-child { margin-bottom: 0; }
+ .rw-label {
+ font-size: 11px; font-weight: 500; width: 80px;
+ color: var(--text-dim); text-align: right;
+ }
+ .rw-track {
+ flex: 1; height: 18px;
+ background: rgba(255,255,255,0.03);
+ border-radius: 4px; overflow: hidden; position: relative;
+ }
+ .rw-fill {
+ height: 100%; border-radius: 4px; width: 0%;
+ transition: width 0.6s cubic-bezier(0.4,0,0.2,1);
+ display: flex; align-items: center; justify-content: flex-end;
+ padding-right: 6px; font-size: 10px; font-weight: 600;
+ font-family: 'JetBrains Mono', monospace;
+ color: rgba(255,255,255,0.85); min-width: fit-content;
+ }
+ .rw-fill.validity { background: linear-gradient(90deg, rgba(52,211,153,0.5), rgba(52,211,153,0.85)); }
+ .rw-fill.ordering { background: linear-gradient(90deg, rgba(34,211,238,0.5), rgba(34,211,238,0.85)); }
+ .rw-fill.info_gain { background: linear-gradient(90deg, rgba(56,189,248,0.5), rgba(56,189,248,0.85)); }
+ .rw-fill.efficiency { background: linear-gradient(90deg, rgba(251,191,36,0.5), rgba(251,191,36,0.85)); }
+ .rw-fill.novelty { background: linear-gradient(90deg, rgba(167,139,250,0.5), rgba(167,139,250,0.85)); }
+ .rw-fill.penalty { background: linear-gradient(90deg, rgba(248,113,113,0.5), rgba(248,113,113,0.85)); }
+ .cumulative-row {
+ display: flex; align-items: baseline; justify-content: space-between;
+ margin-top: 12px; padding-top: 12px; border-top: 1px solid var(--border);
+ }
+ .cum-label { font-size: 11px; color: var(--text-dim); }
+ .cum-value {
+ font-size: 20px; font-weight: 700;
+ font-family: 'JetBrains Mono', monospace; color: var(--green);
+ }
+ .discovery-list { display: flex; flex-direction: column; gap: 6px; }
+ .discovery {
+ display: flex; align-items: flex-start; gap: 8px;
+ padding: 8px 10px; background: var(--bg-raised);
+ border-radius: 6px; border: 1px solid var(--border);
+ opacity: 0; transform: scale(0.95); transition: all 0.3s ease;
+ }
+ .discovery.visible { opacity: 1; transform: scale(1); }
+ .disc-icon {
+ width: 20px; height: 20px; border-radius: 4px;
+ display: flex; align-items: center; justify-content: center;
+ font-size: 10px; flex-shrink: 0; margin-top: 1px;
+ }
+ .disc-body { flex: 1; }
+ .disc-title { font-size: 11px; font-weight: 600; }
+ .disc-detail {
+ font-size: 10px; color: var(--text-dim); margin-top: 2px;
+ font-family: 'JetBrains Mono', monospace;
+ }
+ .empty-state {
+ font-size: 11px; color: var(--text-muted);
+ font-style: italic; padding: 8px 0;
+ }
+ .step-reward-mini {
+ display: flex; align-items: center; justify-content: space-between;
+ padding: 6px 10px; background: var(--bg-raised);
+ border-radius: 6px; margin-bottom: 4px;
+ font-size: 11px; font-family: 'JetBrains Mono', monospace;
+ opacity: 0; transition: all 0.3s;
+ }
+ .step-reward-mini.visible { opacity: 1; }
+ .step-reward-mini .srm-name { color: var(--text-dim); }
+ .step-reward-mini .srm-val { font-weight: 600; }
+ .step-reward-mini .srm-val.pos { color: var(--green); }
+ .step-reward-mini .srm-val.neg { color: var(--red); }
+ </style>
+ </head>
+ <body>
+
+ <!-- Top Bar -->
+ <div class="topbar">
+ <div class="topbar-logo">BioEnv</div>
+ <div class="topbar-sep"></div>
+ <div class="topbar-env">biomarker_validation_lung</div>
+ <div class="topbar-status">
+ <div class="status-dot" id="statusDot"></div>
+ <span id="statusText">Ready</span>
+ </div>
+ <button class="topbar-btn btn-ghost" id="resetBtn" onclick="resetDemo()">Reset</button>
424
+ <button class="topbar-btn btn-primary" id="runBtn" onclick="startDemo()">Run Episode</button>
425
+ </div>
426
+
427
+ <div class="main">
428
+ <!-- Left Sidebar -->
429
+ <div class="sidebar">
430
+ <div class="sidebar-section">
431
+ <div class="sidebar-heading">Scenario</div>
432
+ <div class="scenario-list">
433
+ <div class="scenario-opt" onclick="selectScenario(this)">
434
+ <div class="sc-dot" style="background: var(--green);"></div>
435
+ <span class="sc-name">Cardiac Disease DE</span>
436
+ <span class="sc-diff" style="color: var(--green);">Easy</span>
437
+ </div>
438
+ <div class="scenario-opt" onclick="selectScenario(this)">
439
+ <div class="sc-dot" style="background: var(--amber);"></div>
440
+ <span class="sc-name">Hematopoiesis Trajectory</span>
441
+ <span class="sc-diff" style="color: var(--amber);">Med</span>
442
+ </div>
443
+ <div class="scenario-opt" onclick="selectScenario(this)">
444
+ <div class="sc-dot" style="background: var(--amber);"></div>
445
+ <span class="sc-name">Perturbation Immune</span>
446
+ <span class="sc-diff" style="color: var(--amber);">Med</span>
447
+ </div>
448
+ <div class="scenario-opt active" onclick="selectScenario(this)">
449
+ <div class="sc-dot" style="background: var(--red);"></div>
450
+ <span class="sc-name">Biomarker Validation (Lung)</span>
451
+ <span class="sc-diff" style="color: var(--red);">Hard</span>
452
+ </div>
453
+ </div>
454
+ </div>
455
+ <div class="sidebar-section">
456
+ <div class="sidebar-heading">Environment State</div>
457
+ <div class="gauge">
458
+ <div class="gauge-header">
459
+ <span class="gauge-label">Budget</span>
460
+ <span class="gauge-value" id="budgetVal">$100,000</span>
461
+ </div>
462
+ <div class="gauge-track"><div class="gauge-fill" id="budgetFill" style="width:100%;background:var(--green);"></div></div>
463
+ </div>
464
+ <div class="gauge">
465
+ <div class="gauge-header">
466
+ <span class="gauge-label">Time</span>
467
+ <span class="gauge-value" id="timeVal">180 / 180 days</span>
468
+ </div>
469
+ <div class="gauge-track"><div class="gauge-fill" id="timeFill" style="width:100%;background:var(--cyan);"></div></div>
470
+ </div>
471
+ <div class="gauge">
472
+ <div class="gauge-header">
473
+ <span class="gauge-label">Steps</span>
474
+ <span class="gauge-value" id="stepVal">0 / 30</span>
475
+ </div>
476
+ <div class="gauge-track"><div class="gauge-fill" id="stepFill" style="width:0%;background:var(--accent);"></div></div>
477
+ </div>
478
+ </div>
479
+ <div class="sidebar-section" style="flex:1;overflow-y:auto;">
480
+ <div class="sidebar-heading">Pipeline</div>
481
+ <div class="pipeline-steps" id="pipelineSteps"></div>
482
+ </div>
483
+ </div>
484
+
485
+ <!-- Center: Lab + Terminal -->
486
+ <div class="center">
487
+ <div class="lab-panel">
488
+ <div class="lab-label">Virtual Lab</div>
489
+ <div class="lab-action-label" id="labActionLabel"></div>
490
+ <canvas id="labCanvas"></canvas>
491
+ </div>
492
+ <div class="center-header">
493
+ <div class="tab active">Agent Log</div>
494
+ <div class="tab">Raw JSON</div>
495
+ </div>
496
+ <div class="terminal" id="terminal"></div>
497
+ </div>
498
+
499
+ <!-- Right Panel -->
500
+ <div class="right">
501
+ <div class="panel-section">
502
+ <div class="panel-heading">
503
+ Step Reward
504
+ <span id="stepRewardLabel" style="font-family:'JetBrains Mono',monospace;font-size:11px;color:var(--text-dim);">--</span>
505
+ </div>
506
+ <div id="rewardBars">
507
+ <div class="reward-row"><span class="rw-label">Validity</span><div class="rw-track"><div class="rw-fill validity" id="rw-validity"></div></div></div>
508
+ <div class="reward-row"><span class="rw-label">Ordering</span><div class="rw-track"><div class="rw-fill ordering" id="rw-ordering"></div></div></div>
509
+ <div class="reward-row"><span class="rw-label">Info Gain</span><div class="rw-track"><div class="rw-fill info_gain" id="rw-info_gain"></div></div></div>
510
+ <div class="reward-row"><span class="rw-label">Efficiency</span><div class="rw-track"><div class="rw-fill efficiency" id="rw-efficiency"></div></div></div>
511
+ <div class="reward-row"><span class="rw-label">Novelty</span><div class="rw-track"><div class="rw-fill novelty" id="rw-novelty"></div></div></div>
512
+ <div class="reward-row"><span class="rw-label">Penalty</span><div class="rw-track"><div class="rw-fill penalty" id="rw-penalty"></div></div></div>
513
+ </div>
514
+ <div class="cumulative-row">
515
+ <span class="cum-label">Cumulative Reward</span>
516
+ <span class="cum-value" id="cumReward">0.00</span>
517
+ </div>
518
+ </div>
519
+ <div class="panel-section">
520
+ <div class="panel-heading">Reward History</div>
521
+ <div id="rewardHistory"><div class="empty-state">No steps yet</div></div>
522
+ </div>
523
+ <div class="panel-section">
524
+ <div class="panel-heading">Discoveries</div>
525
+ <div class="discovery-list" id="discoveries"><div class="empty-state">No discoveries yet</div></div>
526
+ </div>
527
+ <div class="panel-section">
528
+ <div class="panel-heading">Violations</div>
529
+ <div id="violations"><div class="empty-state">No violations</div></div>
530
+ </div>
531
+ </div>
532
+ </div>
533
+
534
+ <script>
535
+ // =====================================================
536
+ // VIRTUAL LAB - Canvas rendering
537
+ // =====================================================
538
+ const labCanvas = document.getElementById('labCanvas');
539
+ const ctx = labCanvas.getContext('2d');
540
+ let labW, labH, dpr;
541
+
542
+ function resizeLab() {
543
+ const rect = labCanvas.parentElement.getBoundingClientRect();
544
+ dpr = window.devicePixelRatio || 1;
545
+ labW = rect.width;
546
+ labH = rect.height;
547
+ labCanvas.width = labW * dpr;
548
+ labCanvas.height = labH * dpr;
549
+ ctx.setTransform(dpr, 0, 0, dpr, 0, 0);
550
+ }
551
+ resizeLab();
552
+ window.addEventListener('resize', () => { resizeLab(); });
553
+
554
+ // Lab stations (positions as fractions of canvas, converted in draw)
555
+ const STATIONS = {
556
+ idle: { fx: 0.06, fy: 0.55, label: 'ENTRANCE', icon: 'door', color: '#475569' },
557
+ sample: { fx: 0.20, fy: 0.35, label: 'SAMPLE BENCH', icon: 'bench', color: '#34d399' },
558
+ cohort: { fx: 0.20, fy: 0.75, label: 'COHORT SELECT', icon: 'people', color: '#34d399' },
559
+ prep: { fx: 0.38, fy: 0.35, label: 'LIBRARY PREP', icon: 'flask', color: '#2dd4bf' },
560
+ sequencer: { fx: 0.38, fy: 0.75, label: 'SEQUENCER', icon: 'machine', color: '#22d3ee' },
561
+ computer: { fx: 0.62, fy: 0.50, label: 'COMPUTE', icon: 'screen', color: '#38bdf8' },
562
+ whiteboard: { fx: 0.84, fy: 0.45, label: 'SYNTHESIS', icon: 'board', color: '#a78bfa' },
563
+ };
564
+
565
+ // Map actions to stations
566
+ const ACTION_STATION = {
567
+ collect_sample: 'sample',
568
+ select_cohort: 'cohort',
569
+ prepare_library: 'prep',
570
+ sequence_cells: 'sequencer',
571
+ run_qc: 'computer',
572
+ normalize_data: 'computer',
573
+ cluster_cells: 'computer',
574
+ differential_expression: 'computer',
575
+ pathway_enrichment: 'computer',
576
+ marker_selection: 'computer',
577
+ validate_marker: 'computer',
578
+ synthesize_conclusion: 'whiteboard',
579
+ };
580
+
581
+ // Agent state
582
+ let agent = { x: 0, y: 0, targetX: 0, targetY: 0, station: 'idle', working: false };
583
+ let agentTrail = [];
584
+ let workingTick = 0;
585
+ let terminalLines = []; // fake terminal on computer screen
586
+ let activeStationKey = null;
587
+ let particlesLab = [];
588
+
589
+ function stationPos(key) {
590
+ const s = STATIONS[key];
591
+ return { x: s.fx * labW, y: s.fy * labH };
592
+ }
593
+
594
+ function initAgent() {
595
+ const p = stationPos('idle');
596
+ agent.x = p.x; agent.y = p.y;
597
+ agent.targetX = p.x; agent.targetY = p.y;
598
+ agent.station = 'idle';
599
+ agent.working = false;
600
+ agent.facing = 1;
601
+ agentTrail = [];
602
+ terminalLines = [];
603
+ activeStationKey = null;
604
+ particlesLab = [];
605
+ }
606
+ initAgent();
607
+
608
+ function moveAgentTo(stationKey) {
609
+ const p = stationPos(stationKey);
610
+ agent.targetX = p.x;
611
+ agent.targetY = p.y;
612
+ agent.station = stationKey;
613
+ agent.working = false;
614
+ activeStationKey = stationKey;
615
+ }
616
+
617
+ function setAgentWorking(actionName) {
618
+ agent.working = true;
619
+ workingTick = 0;
620
+ // If at computer, set up terminal lines
621
+ if (agent.station === 'computer') {
622
+ terminalLines = [];
623
+ typeComputerLines(actionName);
624
+ }
625
+ }
626
+
627
+ const COMP_COMMANDS = {
628
+ run_qc: ['$ scanpy.pp.filter_cells()', ' filtering 11847 cells...', ' 10234 passed QC', ' doublet rate: 3.2%'],
629
+ normalize_data: ['$ scran.normalize(adata)', ' computing size factors...', ' log1p transform', ' HVGs: 3000 selected'],
630
+ cluster_cells: ['$ sc.tl.leiden(adata, 0.8)', ' building kNN graph...', ' optimizing modularity', ' 14 clusters found'],
631
+ differential_expression: ['$ DESeq2.run(IPF, Ctrl)', ' fitting GLM...', ' 1847 DE genes', ' SPP1 log2FC=3.42 ***'],
632
+ pathway_enrichment: ['$ gseapy.enrich(de_genes)', ' KEGG + Reactome...', ' ECM-receptor p=4.2e-12', ' TGF-beta p=1.8e-09'],
633
+ marker_selection: ['$ rank_markers(candidates)', ' SPP1 AUROC: 0.94', ' MMP7 AUROC: 0.87', ' COL1A1 AUROC: 0.81'],
634
+ validate_marker: ['$ cross_validate("SPP1")', ' fold 1: 0.93', ' fold 2: 0.89', ' mean AUROC: 0.91 OK'],
635
+ };
636
+
637
+ async function typeComputerLines(actionName) {
638
+ const lines = COMP_COMMANDS[actionName] || ['$ processing...', ' computing...', ' done'];
639
+ for (let i = 0; i < lines.length; i++) {
640
+ await wait(250);
641
+ terminalLines.push(lines[i]);
642
+ if (terminalLines.length > 5) terminalLines.shift();
643
+ }
644
+ }
645
+
646
+ // Particles burst
647
+ function spawnParticles(x, y, color, count = 8) {
648
+ for (let i = 0; i < count; i++) {
649
+ const angle = (Math.PI * 2 / count) * i + Math.random() * 0.5;
650
+ particlesLab.push({
651
+ x, y,
652
+ vx: Math.cos(angle) * (1.5 + Math.random() * 2),
653
+ vy: Math.sin(angle) * (1.5 + Math.random() * 2),
654
+ life: 1,
655
+ color,
656
+ size: 2 + Math.random() * 2,
657
+ });
658
+ }
659
+ }
660
+
661
+ // ---- Draw loop ----
662
+ let frameCount = 0;
663
+ const FLOOR_COLOR = '#0f1520';
664
+ const WALL_COLOR = '#1a2332';
665
+ const FLOOR_TILE_A = '#0d1219';
666
+ const FLOOR_TILE_B = '#10161f';
667
+
668
+ function drawLab() {
669
+ frameCount++;
670
+ ctx.clearRect(0, 0, labW, labH);
671
+
672
+ // Floor - checkerboard tiles
673
+ const tileSize = 24;
674
+ for (let ty = 0; ty < labH; ty += tileSize) {
675
+ for (let tx = 0; tx < labW; tx += tileSize) {
676
+ const checker = ((Math.floor(tx / tileSize) + Math.floor(ty / tileSize)) % 2 === 0);
677
+ ctx.fillStyle = checker ? FLOOR_TILE_A : FLOOR_TILE_B;
678
+ ctx.fillRect(tx, ty, tileSize, tileSize);
679
+ }
680
+ }
681
+
682
+ // Walls - top and bottom border
683
+ ctx.fillStyle = WALL_COLOR;
684
+ ctx.fillRect(0, 0, labW, 18);
685
+ ctx.fillRect(0, labH - 8, labW, 8);
686
+ ctx.strokeStyle = '#253040';
687
+ ctx.lineWidth = 1;
688
+ ctx.beginPath(); ctx.moveTo(0, 18); ctx.lineTo(labW, 18); ctx.stroke();
689
+
690
+ // Draw equipment at each station (behind the person)
691
+ for (const [key, s] of Object.entries(STATIONS)) {
692
+ const pos = stationPos(key);
693
+ const isActive = key === activeStationKey;
694
+ drawEquipment(key, pos.x, pos.y, s.color, isActive);
695
+ }
696
+
697
+ // Draw walking path (subtle floor markings)
698
+ ctx.strokeStyle = 'rgba(56,189,248,0.06)';
699
+ ctx.lineWidth = 16;
700
+ ctx.lineCap = 'round';
701
+ ctx.lineJoin = 'round';
702
+ const pathOrder = ['idle','sample','prep','computer','whiteboard'];
703
+ ctx.beginPath();
704
+ const p0 = stationPos(pathOrder[0]);
705
+ ctx.moveTo(p0.x, p0.y + 10);
706
+ for (let i = 1; i < pathOrder.length; i++) {
707
+ const p = stationPos(pathOrder[i]);
708
+ ctx.lineTo(p.x, p.y + 10);
709
+ }
710
+ ctx.stroke();
711
+ // Lower path
712
+ ctx.beginPath();
713
+ const pl0 = stationPos('idle');
714
+ ctx.moveTo(pl0.x, pl0.y + 10);
715
+ const pl1 = stationPos('cohort');
716
+ ctx.lineTo(pl1.x, pl1.y + 10);
717
+ const pl2 = stationPos('sequencer');
718
+ ctx.lineTo(pl2.x, pl2.y + 10);
719
+ const pl3 = stationPos('computer');
720
+ ctx.lineTo(pl3.x, pl3.y + 10);
721
+ ctx.stroke();
722
+ ctx.lineCap = 'butt';
723
+
724
+ // Floating terminal popup at computer
725
+ if (agent.station === 'computer' && agent.working && terminalLines.length > 0) {
726
+ const cp = stationPos('computer');
727
+ const sx = cp.x + 55, sy = cp.y - 65;
728
+ const sw = 170, sh = 95;
729
+
730
+ // Shadow
731
+ ctx.fillStyle = 'rgba(0,0,0,0.4)';
732
+ roundRect(ctx, sx + 3, sy + 3, sw, sh, 6);
733
+ ctx.fill();
734
+
735
+ ctx.fillStyle = 'rgba(7,9,13,0.97)';
736
+ ctx.strokeStyle = 'rgba(56,189,248,0.3)';
737
+ ctx.lineWidth = 1;
738
+ roundRect(ctx, sx, sy, sw, sh, 6);
739
+ ctx.fill(); ctx.stroke();
740
+
741
+ // Title bar
742
+ ctx.fillStyle = 'rgba(30,41,59,0.5)';
743
+ ctx.fillRect(sx + 1, sy + 1, sw - 2, 14);
744
+ ctx.fillStyle = '#475569';
745
+ ctx.font = '500 7px Inter, sans-serif';
746
+ ctx.textAlign = 'left';
747
+ ctx.fillText('terminal', sx + 6, sy + 10);
748
+ // dots
749
+ ctx.fillStyle = '#f87171'; ctx.beginPath(); ctx.arc(sx + sw - 28, sy + 7, 3, 0, Math.PI*2); ctx.fill();
750
+ ctx.fillStyle = '#fbbf24'; ctx.beginPath(); ctx.arc(sx + sw - 18, sy + 7, 3, 0, Math.PI*2); ctx.fill();
751
+ ctx.fillStyle = '#34d399'; ctx.beginPath(); ctx.arc(sx + sw - 8, sy + 7, 3, 0, Math.PI*2); ctx.fill();
752
+
753
+ ctx.font = '500 9px JetBrains Mono, monospace';
754
+ const startY = sy + 28;
755
+ for (let i = 0; i < terminalLines.length; i++) {
756
+ const line = terminalLines[i];
757
+ ctx.fillStyle = line.startsWith('$') ? '#34d399' : line.includes('***') || line.includes('OK') ? '#34d399' : '#94a3b8';
758
+ ctx.fillText(terminalLines[i].substring(0, 24), sx + 8, startY + i * 14);
759
+ }
760
+ if (frameCount % 60 < 30) {
761
+ ctx.fillStyle = '#34d399';
762
+ ctx.fillRect(sx + 8, startY + terminalLines.length * 14 - 8, 6, 11);
763
+ }
764
+ }
765
+
766
+ // Whiteboard popup
767
+ if (agent.station === 'whiteboard' && agent.working) {
768
+ const wp = stationPos('whiteboard');
769
+ const bx = wp.x - 60, by = wp.y - 75;
770
+ const bw = 120, bh = 72;
771
+ ctx.fillStyle = 'rgba(0,0,0,0.3)';
772
+ roundRect(ctx, bx + 3, by + 3, bw, bh, 6);
773
+ ctx.fill();
774
+ ctx.fillStyle = 'rgba(17,24,39,0.95)';
775
+ ctx.strokeStyle = 'rgba(167,139,250,0.3)';
776
+ ctx.lineWidth = 1;
777
+ roundRect(ctx, bx, by, bw, bh, 6);
778
+ ctx.fill(); ctx.stroke();
779
+ ctx.font = '600 8px JetBrains Mono, monospace';
780
+ ctx.textAlign = 'left';
781
+ ctx.fillStyle = '#a78bfa';
782
+ ctx.fillText('CONCLUSION', bx + 8, by + 14);
783
+ ctx.font = '400 7.5px JetBrains Mono, monospace';
784
+ const synthLines = ['SPP1 validated', 'AUROC = 0.91', 'Confidence: 0.85', 'Match: 4/5'];
785
+ for (let i = 0; i < synthLines.length; i++) {
786
+ ctx.fillStyle = i === 0 ? '#34d399' : '#94a3b8';
787
+ ctx.fillText(synthLines[i], bx + 8, by + 28 + i * 12);
788
+ }
789
+ }
790
+
791
+ // Activity text above active station
792
+ if (agent.working && activeStationKey && activeStationKey !== 'idle') {
793
+ const sp = stationPos(activeStationKey);
794
+ const actTexts = {
795
+ sample: 'collecting tissue...', cohort: 'selecting cohort...',
796
+ prep: 'preparing library...', sequencer: 'sequencing...',
797
+ computer: 'computing...', whiteboard: 'synthesizing...',
798
+ };
799
+ ctx.fillStyle = STATIONS[activeStationKey].color;
800
+ ctx.font = '500 9px JetBrains Mono, monospace';
801
+ ctx.textAlign = 'center';
802
+ ctx.globalAlpha = 0.5 + 0.3 * Math.sin(frameCount * 0.06);
803
+ const yOff = ['sample','prep'].includes(activeStationKey) ? -55 : -50;
804
+ ctx.fillText(actTexts[activeStationKey] || 'working...', sp.x, sp.y + yOff);
805
+ ctx.globalAlpha = 1;
806
+ }
807
+
808
+ // Move agent smoothly
809
+ const dx = agent.targetX - agent.x;
810
+ const dy = agent.targetY - agent.y;
811
+ const dist = Math.sqrt(dx * dx + dy * dy);
812
+ const isWalking = dist > 2;
813
+ if (isWalking) {
814
+ const speed = 0.05;
815
+ agent.x += dx * speed;
816
+ agent.y += dy * speed;
817
+ agent.facing = dx > 0 ? 1 : dx < -0.5 ? -1 : agent.facing;
818
+ }
819
+
820
+ // Draw person
821
+ drawPerson(agent.x, agent.y, isWalking, agent.working, agent.facing || 1);
822
+
823
+ // Particles
824
+ for (let i = particlesLab.length - 1; i >= 0; i--) {
825
+ const p = particlesLab[i];
826
+ p.x += p.vx; p.y += p.vy;
827
+ p.vx *= 0.95; p.vy *= 0.95;
828
+ p.life -= 0.02;
829
+ if (p.life <= 0) { particlesLab.splice(i, 1); continue; }
830
+ ctx.globalAlpha = p.life * 0.6;
831
+ ctx.fillStyle = p.color;
832
+ ctx.beginPath();
833
+ ctx.arc(p.x, p.y, p.size * p.life, 0, Math.PI * 2);
834
+ ctx.fill();
835
+ }
836
+ ctx.globalAlpha = 1;
837
+
838
+ // Station labels
839
+ for (const [key, s] of Object.entries(STATIONS)) {
840
+ if (key === 'idle') continue;
841
+ const pos = stationPos(key);
842
+ const isActive = key === activeStationKey;
843
+ ctx.fillStyle = isActive ? s.color : '#334155';
844
+ ctx.font = `600 ${isActive ? 9 : 8}px Inter, sans-serif`;
845
+ ctx.textAlign = 'center';
846
+ const ly = key === 'cohort' || key === 'sequencer' ? pos.y + 45 : pos.y + 42;
847
+ ctx.fillText(s.label, pos.x, ly);
848
+ }
849
+
850
+ requestAnimationFrame(drawLab);
851
+ }
852
+
853
+ // ---- Draw person (lab coat researcher) ----
854
+ function drawPerson(x, y, walking, working, facing) {
855
+ const f = facing;
856
+ const t = frameCount;
857
+ // Walking cycle
858
+ const walkCycle = walking ? Math.sin(t * 0.15) : 0;
859
+ const bobY = walking ? Math.abs(Math.sin(t * 0.15)) * 2 : 0;
860
+ // Working arm animation
861
+ const workArm = working ? Math.sin(t * 0.08) * 0.3 : 0;
862
+
863
+ const py = y - bobY; // feet position base
864
+
865
+ ctx.save();
866
+ ctx.translate(x, py);
867
+
868
+ // Shadow
869
+ ctx.fillStyle = 'rgba(0,0,0,0.25)';
870
+ ctx.beginPath();
871
+ ctx.ellipse(0, 12, 10, 4, 0, 0, Math.PI * 2);
872
+ ctx.fill();
873
+
874
+ // Legs
875
+ const legSpread = walking ? walkCycle * 5 : 0;
876
+ ctx.strokeStyle = '#1e3a5f';
877
+ ctx.lineWidth = 3;
878
+ ctx.lineCap = 'round';
879
+ // Left leg
880
+ ctx.beginPath();
881
+ ctx.moveTo(-3, 4);
882
+ ctx.lineTo(-3 + legSpread, 12);
883
+ ctx.stroke();
884
+ // Right leg
885
+ ctx.beginPath();
886
+ ctx.moveTo(3, 4);
887
+ ctx.lineTo(3 - legSpread, 12);
888
+ ctx.stroke();
889
+ // Shoes
890
+ ctx.fillStyle = '#1e293b';
891
+ ctx.beginPath(); ctx.arc(-3 + legSpread, 12, 2.5, 0, Math.PI * 2); ctx.fill();
892
+ ctx.beginPath(); ctx.arc(3 - legSpread, 12, 2.5, 0, Math.PI * 2); ctx.fill();
893
+
894
+ // Body / lab coat
895
+ ctx.fillStyle = '#e2e8f0'; // white lab coat
896
+ ctx.beginPath();
897
+ ctx.moveTo(-7, -4);
898
+ ctx.lineTo(-6, 6);
899
+ ctx.lineTo(6, 6);
900
+ ctx.lineTo(7, -4);
901
+ ctx.quadraticCurveTo(7, -10, 0, -10);
902
+ ctx.quadraticCurveTo(-7, -10, -7, -4);
903
+ ctx.fill();
904
+ // Coat outline
905
+ ctx.strokeStyle = '#94a3b8';
906
+ ctx.lineWidth = 0.5;
907
+ ctx.stroke();
908
+ // Coat split at bottom
909
+ ctx.beginPath();
910
+ ctx.moveTo(0, 1);
911
+ ctx.lineTo(0, 6);
912
+ ctx.strokeStyle = '#cbd5e1';
913
+ ctx.lineWidth = 0.5;
914
+ ctx.stroke();
915
+ // Pocket
916
+ ctx.strokeStyle = '#94a3b8';
917
+ ctx.lineWidth = 0.5;
918
+ ctx.strokeRect(f > 0 ? 1 : -5, -1, 4, 3);
919
+
920
+ // Arms
921
+ ctx.strokeStyle = '#e2e8f0';
922
+ ctx.lineWidth = 3.5;
923
+ ctx.lineCap = 'round';
924
+ // Back arm
925
+ const backArmSwing = walking ? -walkCycle * 4 : 0;
926
+ ctx.beginPath();
927
+ ctx.moveTo(-f * 6, -6);
928
+ ctx.lineTo(-f * 6 + backArmSwing, 2);
929
+ ctx.stroke();
930
+ // Front arm (active arm)
931
+ if (working) {
932
+ // Arm reaching forward/up for work
933
+ ctx.beginPath();
934
+ ctx.moveTo(f * 6, -6);
935
+ ctx.lineTo(f * 10 + workArm * 5, -8 + workArm * 3);
936
+ ctx.stroke();
937
+ // Hand/tool
938
+ ctx.fillStyle = '#fde68a';
939
+ ctx.beginPath();
940
+ ctx.arc(f * 10 + workArm * 5, -8 + workArm * 3, 2, 0, Math.PI * 2);
941
+ ctx.fill();
942
+ } else {
943
+ const frontArmSwing = walking ? walkCycle * 4 : 0;
944
+ ctx.beginPath();
945
+ ctx.moveTo(f * 6, -6);
946
+ ctx.lineTo(f * 6 + frontArmSwing, 2);
947
+ ctx.stroke();
948
+ }
949
+ // Skin for hands
950
+ ctx.fillStyle = '#fde68a';
951
+ ctx.beginPath(); ctx.arc(-f * 6 + backArmSwing, 2, 1.8, 0, Math.PI * 2); ctx.fill();
952
+ if (!working) {
953
+ const fs = walking ? walkCycle * 4 : 0;
954
+ ctx.beginPath(); ctx.arc(f * 6 + fs, 2, 1.8, 0, Math.PI * 2); ctx.fill();
955
+ }
956
+
957
+ // Head
958
+ ctx.fillStyle = '#fde68a'; // skin
959
+ ctx.beginPath();
960
+ ctx.arc(0, -15, 7, 0, Math.PI * 2);
961
+ ctx.fill();
962
+ // Hair
963
+ ctx.fillStyle = '#1e293b';
964
+ ctx.beginPath();
965
+ ctx.arc(0, -17, 7, Math.PI, 0);
966
+ ctx.fill();
967
+ // Face details
968
+ ctx.fillStyle = '#1e293b';
969
+ // Eyes
970
+ ctx.beginPath();
971
+ ctx.arc(f * 2.5, -15.5, 1, 0, Math.PI * 2);
972
+ ctx.fill();
973
+ ctx.beginPath();
974
+ ctx.arc(f * -1.5, -15.5, 1, 0, Math.PI * 2);
975
+ ctx.fill();
976
+ // Glasses
977
+ ctx.strokeStyle = '#475569';
978
+ ctx.lineWidth = 0.7;
979
+ ctx.beginPath();
980
+ ctx.arc(f * 2.5, -15.5, 2.5, 0, Math.PI * 2);
981
+ ctx.stroke();
982
+ ctx.beginPath();
983
+ ctx.arc(f * -1.5, -15.5, 2.5, 0, Math.PI * 2);
984
+ ctx.stroke();
985
+ ctx.beginPath();
986
+ ctx.moveTo(f * 0.5, -15.5);
987
+ ctx.lineTo(f * -0.5, -15.5);
988
+ ctx.stroke();
989
+ // Mouth
990
+ if (working) {
991
+ ctx.fillStyle = '#1e293b';
992
+ ctx.beginPath();
993
+ ctx.arc(f * 0.5, -12.5, 1, 0, Math.PI);
994
+ ctx.fill();
995
+ }
996
+
997
+ // ID Badge
998
+ ctx.fillStyle = '#38bdf8';
999
+ ctx.fillRect(f > 0 ? -6 : 2, -3, 4, 5);
1000
+ ctx.fillStyle = '#fff';
1001
+ ctx.font = 'bold 3px Inter, sans-serif';
1002
+ ctx.textAlign = 'center';
1003
+ ctx.fillText('AI', f > 0 ? -4 : 4, 0.5);
1004
+
1005
+ ctx.restore();
1006
+ }
1007
+
1008
+ // ---- Draw lab equipment ----
1009
+ function drawEquipment(stationKey, cx, cy, color, active) {
1010
+ ctx.save();
1011
+
1012
+ switch (stationKey) {
1013
+ case 'idle':
1014
+ // Door frame
1015
+ ctx.strokeStyle = '#334155';
1016
+ ctx.lineWidth = 2;
1017
+ ctx.strokeRect(cx - 12, cy - 30, 24, 40);
1018
+ ctx.fillStyle = '#1a2332';
1019
+ ctx.fillRect(cx - 10, cy - 28, 20, 36);
1020
+ ctx.fillStyle = '#475569';
1021
+ ctx.beginPath(); ctx.arc(cx + 6, cy - 10, 2, 0, Math.PI * 2); ctx.fill();
1022
+ break;
1023
+
1024
+ case 'sample':
1025
+ // Lab bench with sample tubes
1026
+ // Bench surface
1027
+ ctx.fillStyle = '#1a2332';
1028
+ ctx.fillRect(cx - 30, cy - 8, 60, 6);
1029
+ // Bench legs
1030
+ ctx.fillStyle = '#253040';
1031
+ ctx.fillRect(cx - 28, cy - 2, 4, 20);
1032
+ ctx.fillRect(cx + 24, cy - 2, 4, 20);
1033
+ // Tube rack
1034
+ ctx.fillStyle = '#253040';
1035
+ ctx.fillRect(cx - 18, cy - 18, 36, 10);
1036
+ // Test tubes
1037
+ const tubeColors = ['#34d399', '#22d3ee', '#fbbf24', '#f472b6', '#34d399', '#22d3ee'];
1038
+ for (let i = 0; i < 6; i++) {
1039
+ const tx = cx - 14 + i * 6;
1040
+ ctx.fillStyle = active ? tubeColors[i] : '#334155';
1041
+ ctx.globalAlpha = active ? 0.7 : 0.4;
1042
+ ctx.fillRect(tx, cy - 28, 4, 12);
1043
+ // Tube caps
1044
+ ctx.globalAlpha = 1;
1045
+ ctx.fillStyle = active ? tubeColors[i] : '#475569';
1046
+ ctx.fillRect(tx - 0.5, cy - 29, 5, 2);
1047
+ }
1048
+ ctx.globalAlpha = 1;
1049
+ // Pipette if active
1050
+ if (active) {
1051
+ const pipY = cy - 32 + Math.sin(frameCount * 0.08) * 4;
1052
+ ctx.strokeStyle = '#94a3b8';
1053
+ ctx.lineWidth = 2;
1054
+ ctx.beginPath();
1055
+ ctx.moveTo(cx + 5, pipY);
1056
+ ctx.lineTo(cx + 5, pipY - 14);
1057
+ ctx.stroke();
1058
+ ctx.fillStyle = '#64748b';
1059
+ ctx.fillRect(cx + 3, pipY - 18, 5, 6);
1060
+ // Droplet
1061
+ if (frameCount % 60 < 20) {
1062
+ ctx.fillStyle = '#34d399';
1063
+ ctx.globalAlpha = 0.6;
1064
+ ctx.beginPath();
1065
+ ctx.arc(cx + 5, pipY + 3, 1.5, 0, Math.PI * 2);
1066
+ ctx.fill();
1067
+ ctx.globalAlpha = 1;
1068
+ }
1069
+ }
1070
+ break;
1071
+
1072
+ case 'cohort':
1073
+ // Filing cabinet / patient records
1074
+ ctx.fillStyle = '#1a2332';
1075
+ ctx.fillRect(cx - 20, cy - 22, 40, 40);
1076
+ ctx.strokeStyle = '#253040';
1077
+ ctx.lineWidth = 1;
1078
+ for (let i = 0; i < 3; i++) {
1079
+ const dy = cy - 18 + i * 13;
1080
+ ctx.strokeRect(cx - 18, dy, 36, 11);
1081
+ ctx.fillStyle = active ? '#475569' : '#253040';
1082
+ ctx.fillRect(cx - 4, dy + 4, 8, 3);
1083
+ }
1084
+ // Clipboard
1085
+ ctx.fillStyle = '#253040';
1086
+ ctx.fillRect(cx + 24, cy - 16, 14, 20);
1087
+ ctx.strokeStyle = '#475569';
1088
+ ctx.lineWidth = 0.5;
1089
+ for (let i = 0; i < 4; i++) {
1090
+ ctx.beginPath();
1091
+ ctx.moveTo(cx + 27, cy - 12 + i * 4);
1092
+ ctx.lineTo(cx + 35, cy - 12 + i * 4);
1093
+ ctx.stroke();
1094
+ }
1095
+ if (active) {
1096
+ ctx.fillStyle = color;
1097
+ ctx.globalAlpha = 0.5;
1098
+ ctx.beginPath(); ctx.arc(cx + 31, cy - 14, 2, 0, Math.PI * 2); ctx.fill();
1099
+ ctx.globalAlpha = 1;
1100
+ }
1101
+ break;
1102
+
1103
+ case 'prep':
1104
+ // Library prep station - PCR machine + bench
1105
+ // Bench
1106
+ ctx.fillStyle = '#1a2332';
1107
+ ctx.fillRect(cx - 28, cy - 6, 56, 6);
1108
+ ctx.fillStyle = '#253040';
1109
+ ctx.fillRect(cx - 26, cy, 4, 18);
1110
+ ctx.fillRect(cx + 22, cy, 4, 18);
1111
+ // PCR/thermocycler machine
1112
+ ctx.fillStyle = active ? '#192535' : '#172030';
1113
+ ctx.strokeStyle = active ? color : '#253040';
1114
+ ctx.lineWidth = 1;
1115
+ roundRect(ctx, cx - 18, cy - 26, 36, 20, 3);
1116
+ ctx.fill(); ctx.stroke();
1117
+ // Display on machine
1118
+ ctx.fillStyle = active ? 'rgba(45,212,191,0.15)' : 'rgba(30,41,59,0.3)';
1119
+ ctx.fillRect(cx - 14, cy - 22, 16, 8);
1120
+ if (active) {
1121
+ ctx.fillStyle = color;
1122
+ ctx.font = '500 6px JetBrains Mono, monospace';
1123
+ ctx.textAlign = 'left';
1124
+ ctx.fillText('72.0°C', cx - 12, cy - 16);
1125
+ // LED
1126
+ ctx.fillStyle = color;
1127
+ ctx.beginPath(); ctx.arc(cx + 12, cy - 18, 2, 0, Math.PI * 2); ctx.fill();
1128
+ }
1129
+ // Microplate
1130
+ ctx.fillStyle = '#1e293b';
1131
+ ctx.fillRect(cx - 20, cy - 3, 18, 12);
1132
+ ctx.strokeStyle = '#334155';
1133
+ ctx.lineWidth = 0.3;
1134
+ for (let r = 0; r < 3; r++) {
1135
+ for (let c = 0; c < 4; c++) {
1136
+ ctx.beginPath();
1137
+ ctx.arc(cx - 17 + c * 4.5, cy + 1 + r * 3.5, 1.2, 0, Math.PI * 2);
1138
+ ctx.stroke();
1139
+ }
1140
+ }
1141
+ break;
1142
+
1143
+ case 'sequencer':
1144
+ // Big sequencing machine (NovaSeq-like)
1145
+ // Machine body
1146
+ ctx.fillStyle = '#172030';
1147
+ ctx.strokeStyle = active ? color : '#253040';
1148
+ ctx.lineWidth = active ? 1.5 : 1;
1149
+ roundRect(ctx, cx - 24, cy - 28, 48, 44, 4);
1150
+ ctx.fill(); ctx.stroke();
1151
+ // Front panel / screen
1152
+ ctx.fillStyle = active ? 'rgba(34,211,238,0.1)' : 'rgba(30,41,59,0.3)';
1153
+ roundRect(ctx, cx - 18, cy - 22, 36, 18, 2);
1154
+ ctx.fill();
1155
+ if (active) {
1156
+ // Progress bar on screen
1157
+ ctx.fillStyle = 'rgba(34,211,238,0.2)';
1158
+ ctx.fillRect(cx - 14, cy - 12, 28, 4);
1159
+ const progress = (frameCount % 120) / 120;
1160
+ ctx.fillStyle = color;
1161
+ ctx.fillRect(cx - 14, cy - 12, 28 * progress, 4);
1162
+ ctx.fillStyle = color;
1163
+ ctx.font = '500 6px JetBrains Mono, monospace';
1164
+ ctx.textAlign = 'center';
1165
+ ctx.fillText('SEQUENCING', cx, cy - 16);
1166
+ }
1167
+ // Slot
1168
+ ctx.fillStyle = '#0f1520';
1169
+ ctx.fillRect(cx - 10, cy, 20, 4);
1170
+ // Status LEDs
1171
+ ctx.fillStyle = active ? '#34d399' : '#334155';
1172
+ ctx.beginPath(); ctx.arc(cx - 14, cy + 10, 2, 0, Math.PI * 2); ctx.fill();
1173
+ if (active && frameCount % 30 < 15) {
1174
+ ctx.fillStyle = '#fbbf24';
1175
+ } else {
1176
+ ctx.fillStyle = '#334155';
1177
+ }
1178
+ ctx.beginPath(); ctx.arc(cx - 8, cy + 10, 2, 0, Math.PI * 2); ctx.fill();
1179
+ break;
1180
+
1181
+ case 'computer':
1182
+ // Computer desk with dual monitors
1183
+ // Desk
1184
+ ctx.fillStyle = '#1a2332';
1185
+ ctx.fillRect(cx - 36, cy + 2, 72, 5);
1186
+ ctx.fillStyle = '#253040';
1187
+ ctx.fillRect(cx - 32, cy + 7, 4, 16);
1188
+ ctx.fillRect(cx + 28, cy + 7, 4, 16);
1189
+ // Chair
1190
+ ctx.fillStyle = '#1e293b';
1191
+ ctx.beginPath();
1192
+ ctx.arc(cx, cy + 28, 8, 0, Math.PI * 2);
1193
+ ctx.fill();
1194
+ ctx.fillStyle = '#253040';
1195
+ ctx.fillRect(cx - 1, cy + 20, 2, 8);
1196
+ // Monitor 1 (main)
1197
+ ctx.fillStyle = active ? '#0c1219' : '#131c28';
1198
+ ctx.strokeStyle = active ? 'rgba(56,189,248,0.4)' : '#253040';
1199
+ ctx.lineWidth = 1;
1200
+ roundRect(ctx, cx - 30, cy - 28, 32, 24, 2);
1201
+ ctx.fill(); ctx.stroke();
1202
+ // Monitor stand
1203
+ ctx.fillStyle = '#334155';
1204
+ ctx.fillRect(cx - 16, cy - 4, 4, 6);
1205
+ ctx.fillRect(cx - 20, cy + 1, 12, 2);
1206
+ // Monitor 2
1207
+ ctx.fillStyle = active ? '#0c1219' : '#131c28';
1208
+ ctx.strokeStyle = active ? 'rgba(56,189,248,0.3)' : '#253040';
1209
+ roundRect(ctx, cx + 2, cy - 24, 26, 20, 2);
1210
+ ctx.fill(); ctx.stroke();
1211
+ ctx.fillStyle = '#334155';
1212
+ ctx.fillRect(cx + 13, cy - 4, 4, 6);
1213
+ ctx.fillRect(cx + 9, cy + 1, 12, 2);
1214
+ // Screen content
1215
+ if (active) {
1216
+ ctx.fillStyle = 'rgba(56,189,248,0.08)';
1217
+ ctx.fillRect(cx - 28, cy - 26, 28, 20);
1218
+ // Code lines
1219
+ for (let i = 0; i < 5; i++) {
1220
+ ctx.fillStyle = `rgba(56,189,248,${0.15 + i * 0.06})`;
1221
+ const w = 8 + Math.sin(i * 2.3 + frameCount * 0.02) * 6;
1222
+ ctx.fillRect(cx - 26, cy - 24 + i * 4, w, 2);
1223
+ }
1224
+ // Second screen - graph
1225
+ ctx.fillStyle = 'rgba(56,189,248,0.06)';
1226
+ ctx.fillRect(cx + 4, cy - 22, 22, 16);
1227
+ ctx.strokeStyle = 'rgba(34,211,238,0.3)';
1228
+ ctx.lineWidth = 1;
1229
+ ctx.beginPath();
1230
+ ctx.moveTo(cx + 6, cy - 8);
1231
+ for (let i = 0; i < 8; i++) {
1232
+ ctx.lineTo(cx + 6 + i * 2.5, cy - 10 - Math.sin(i * 0.8 + frameCount * 0.03) * 5);
1233
+ }
1234
+ ctx.stroke();
1235
+ }
1236
+ // Keyboard
1237
+ ctx.fillStyle = '#1e293b';
1238
+ ctx.fillRect(cx - 14, cy + 4, 28, 6);
1239
+ // Typing effect
1240
+ if (active && agent.working) {
1241
+ const keyX = cx - 12 + (frameCount % 20) * 1.2;
1242
+ ctx.fillStyle = 'rgba(56,189,248,0.4)';
1243
+ ctx.fillRect(keyX, cy + 5, 3, 4);
1244
+ }
1245
+ break;
1246
+
1247
+ case 'whiteboard':
1248
+ // Whiteboard on wall + standing desk
1249
+ // Board on wall
1250
+ ctx.fillStyle = '#1e293b';
1251
+ ctx.strokeStyle = '#334155';
1252
+ ctx.lineWidth = 1;
1253
+ ctx.fillRect(cx - 28, cy - 34, 56, 32);
1254
+ ctx.strokeRect(cx - 28, cy - 34, 56, 32);
1255
+ // Board content
1256
+ if (active) {
1257
+ ctx.fillStyle = 'rgba(167,139,250,0.1)';
1258
+ ctx.fillRect(cx - 26, cy - 32, 52, 28);
1259
+ // Diagram elements
1260
+ ctx.strokeStyle = 'rgba(167,139,250,0.4)';
1261
+ ctx.lineWidth = 0.8;
1262
+ // Boxes
1263
+ ctx.strokeRect(cx - 20, cy - 28, 14, 8);
1264
+ ctx.strokeRect(cx + 6, cy - 28, 14, 8);
1265
+ ctx.strokeRect(cx - 8, cy - 16, 16, 8);
1266
+ // Arrows
1267
+ ctx.beginPath();
1268
+ ctx.moveTo(cx - 6, cy - 24); ctx.lineTo(cx + 6, cy - 24); ctx.stroke();
1269
+ ctx.beginPath();
1270
+ ctx.moveTo(cx, cy - 20); ctx.lineTo(cx, cy - 16); ctx.stroke();
1271
+ // Checkmark
1272
+ ctx.strokeStyle = '#34d399';
1273
+ ctx.lineWidth = 1.5;
1274
+ ctx.beginPath();
1275
+ ctx.moveTo(cx - 4, cy - 12);
1276
+ ctx.lineTo(cx - 1, cy - 9);
1277
+ ctx.lineTo(cx + 5, cy - 15);
1278
+ ctx.stroke();
1279
+ } else {
1280
+ // Faint lines
1281
+ ctx.strokeStyle = '#253040';
1282
+ ctx.lineWidth = 0.5;
1283
+ for (let i = 0; i < 4; i++) {
1284
+ ctx.beginPath();
1285
+ ctx.moveTo(cx - 22, cy - 28 + i * 7);
1286
+ ctx.lineTo(cx + 22, cy - 28 + i * 7);
1287
+ ctx.stroke();
1288
+ }
1289
+ }
1290
+ // Standing desk
1291
+ ctx.fillStyle = '#1a2332';
1292
+ ctx.fillRect(cx - 16, cy + 2, 32, 4);
1293
+ ctx.fillStyle = '#253040';
1294
+ ctx.fillRect(cx - 2, cy + 6, 4, 14);
1295
+ break;
1296
+ }
1297
+
1298
+ ctx.restore();
1299
+ }
1300
+
1301
+ function roundRect(ctx, x, y, w, h, r) {
1302
+ ctx.beginPath();
1303
+ ctx.moveTo(x + r, y);
1304
+ ctx.lineTo(x + w - r, y);
1305
+ ctx.quadraticCurveTo(x + w, y, x + w, y + r);
1306
+ ctx.lineTo(x + w, y + h - r);
1307
+ ctx.quadraticCurveTo(x + w, y + h, x + w - r, y + h);
1308
+ ctx.lineTo(x + r, y + h);
1309
+ ctx.quadraticCurveTo(x, y + h, x, y + h - r);
1310
+ ctx.lineTo(x, y + r);
1311
+ ctx.quadraticCurveTo(x, y, x + r, y);
1312
+ ctx.closePath();
1313
+ }
1314
+
1315
+ drawLab();
1316
+
1317
+ // =====================================================
1318
+ // EPISODE DATA + APP LOGIC
1319
+ // =====================================================
1320
+ const EPISODE = [
1321
+ {
1322
+ action: 'collect_sample', params: 'n_samples=8, tissue="lung"', category: 'wet',
1323
+ budget: 92400, budgetPct: 92.4, time: 165, timePct: 91.7,
1324
+ output: ['Collected 8 lung tissue samples (4 IPF, 4 control)','Tissue quality: excellent | Storage: -80C'],
1325
+ reward: { validity: 0.90, ordering: 1.00, info_gain: 0.10, efficiency: 0.72, novelty: 1.00, penalty: 0.0 },
1326
+ total: 0.45,
1327
+ },
1328
+ {
1329
+ action: 'select_cohort', params: 'criteria="age_matched, sex_balanced"', category: 'wet',
1330
+ budget: 91800, budgetPct: 91.8, time: 162, timePct: 90.0,
1331
+ output: ['Cohort selected: 4 IPF patients (2M/2F, age 58-67)','Controls matched: 4 healthy donors (2M/2F, age 55-65)'],
1332
+ reward: { validity: 0.85, ordering: 0.90, info_gain: 0.15, efficiency: 0.80, novelty: 0.90, penalty: 0.0 },
1333
+ total: 0.38,
1334
+ },
1335
+ {
1336
+ action: 'prepare_library', params: 'protocol="10x_chromium_v3"', category: 'wet',
1337
+ budget: 84200, budgetPct: 84.2, time: 155, timePct: 86.1,
1338
+ output: ['Library prep complete using 10x Chromium v3','Estimated cell capture: ~12,000 cells','cDNA yield: 42ng (good)'],
1339
+ reward: { validity: 0.95, ordering: 1.00, info_gain: 0.20, efficiency: 0.70, novelty: 0.95, penalty: 0.0 },
1340
+ total: 0.52,
1341
+ },
1342
+ {
1343
+ action: 'sequence_cells', params: 'depth="standard", platform="NovaSeq"', category: 'wet',
1344
+ budget: 68500, budgetPct: 68.5, time: 142, timePct: 78.9,
1345
+ output: ['11,847 cells sequenced | 22,438 genes detected','Median reads/cell: 45,200 | Median genes/cell: 3,842','Sequencing saturation: 78.3%'],
1346
+ reward: { validity: 0.95, ordering: 1.00, info_gain: 0.55, efficiency: 0.60, novelty: 0.90, penalty: 0.0 },
1347
+ total: 0.68,
1348
+ },
1349
+ {
1350
+ action: 'run_qc', params: 'tool="scanpy", min_genes=200', category: 'comp',
1351
+ budget: 68100, budgetPct: 68.1, time: 141, timePct: 78.3,
1352
+ output: ['QC complete: 10,234 / 11,847 cells passed (86.4%)','Removed: 382 doublets (3.2%), 1,231 low-quality cells','Mitochondrial threshold: 20% (flagged 847 cells)'],
1353
+ reward: { validity: 0.95, ordering: 1.00, info_gain: 0.35, efficiency: 0.85, novelty: 0.80, penalty: 0.0 },
1354
+ total: 0.55,
1355
+ },
1356
+ {
1357
+ action: 'normalize_data', params: 'method="scran", log_transform=true', category: 'comp',
1358
+ budget: 67900, budgetPct: 67.9, time: 140, timePct: 77.8,
1359
+ output: ['Size-factor normalization (scran) applied','Log1p transform complete | HVG selection: 3,000 genes'],
1360
+ reward: { validity: 0.90, ordering: 1.00, info_gain: 0.25, efficiency: 0.90, novelty: 0.70, penalty: 0.0 },
1361
+ total: 0.42,
1362
+ },
1363
+ {
1364
+ action: 'cluster_cells', params: 'algorithm="leiden", resolution=0.8', category: 'comp',
1365
+ budget: 67500, budgetPct: 67.5, time: 139, timePct: 77.2,
1366
+ output: ['Leiden clustering: 14 clusters identified','AT1 (8.2%), AT2 (12.1%), Fibroblast (15.7%), Macrophage (18.3%)','Endothelial (9.4%), Basal (6.1%), Ciliated (5.8%), NK/T (7.2%)','Smooth Muscle (4.1%), Mast (2.9%), B cell (3.4%), pDC (2.0%)','Mesothelial (2.6%), Aberrant Basaloid (2.2%)'],
1367
+ reward: { validity: 0.95, ordering: 1.00, info_gain: 0.65, efficiency: 0.85, novelty: 0.85, penalty: 0.0 },
1368
+ total: 0.72,
1369
+ discovery: { title: '14 cell populations identified', detail: 'Including Aberrant Basaloid cells (IPF-associated)', color: 'var(--cyan)', bg: 'var(--cyan-dim)' },
1370
+ },
1371
+ {
1372
+ action: 'differential_expression', params: 'method="DESeq2", contrast="IPF_vs_Ctrl"', category: 'comp',
1373
+ budget: 67000, budgetPct: 67.0, time: 137, timePct: 76.1,
1374
+ output: ['1,847 DE genes (|log2FC| > 1, padj < 0.05)','Top upregulated in IPF:',' SPP1 log2FC=3.42 padj=1.2e-18',' MMP7 log2FC=2.89 padj=3.4e-15',' COL1A1 log2FC=2.67 padj=8.7e-14',' TGFB1 log2FC=1.95 padj=2.1e-09','Top downregulated: AGER (-3.1), SFTPC (-2.8), HOPX (-2.3)'],
1375
+ reward: { validity: 0.95, ordering: 1.00, info_gain: 0.78, efficiency: 0.80, novelty: 0.88, penalty: 0.0 },
1376
+ total: 0.82,
1377
+ discovery: { title: 'SPP1 strongly upregulated in IPF', detail: 'log2FC=3.42, padj=1.2e-18', color: 'var(--pink)', bg: 'rgba(244,114,182,0.10)' },
1378
+ },
1379
+ {
1380
+ action: 'pathway_enrichment', params: 'tool="gseapy", gene_sets="KEGG,Reactome"', category: 'comp',
1381
+ budget: 66600, budgetPct: 66.6, time: 136, timePct: 75.6,
1382
+ output: ['Top enriched pathways (IPF vs Control):',' ECM-receptor interaction padj=4.2e-12',' TGF-beta signaling padj=1.8e-09',' PI3K-Akt signaling padj=3.1e-07',' Focal adhesion padj=8.9e-07','SPP1 participates in 3/4 top pathways'],
1383
+ reward: { validity: 0.90, ordering: 1.00, info_gain: 0.60, efficiency: 0.85, novelty: 0.75, penalty: 0.0 },
1384
+ total: 0.58,
1385
+ discovery: { title: 'SPP1 in ECM/TGF-beta/PI3K pathways', detail: 'Core fibrosis signaling axis confirmed', color: 'var(--purple)', bg: 'rgba(167,139,250,0.10)' },
1386
+ },
1387
+ {
1388
+ action: 'marker_selection', params: 'candidates=["SPP1","MMP7","COL1A1"]', category: 'comp',
1389
+ budget: 66200, budgetPct: 66.2, time: 135, timePct: 75.0,
1390
+ output: ['Marker ranking by discriminative power:',' 1. SPP1 - AUROC: 0.94, specificity: 0.89',' 2. MMP7 - AUROC: 0.87, specificity: 0.82',' 3. COL1A1 - AUROC: 0.81, specificity: 0.76','SPP1 selected as primary biomarker candidate'],
1391
+ reward: { validity: 0.90, ordering: 1.00, info_gain: 0.50, efficiency: 0.88, novelty: 0.70, penalty: 0.0 },
1392
+ total: 0.55,
1393
+ },
1394
+ {
1395
+ action: 'validate_marker', params: 'gene="SPP1", method="cross_validation"', category: 'comp',
1396
+ budget: 65200, budgetPct: 65.2, time: 130, timePct: 72.2,
1397
+ output: ['SPP1 Biomarker Validation Report:',' 5-fold CV AUROC: 0.91 (+/- 0.03)',' Sensitivity: 0.88',' Specificity: 0.87',' Positive LR: 6.77',' Expression in Aberrant Basaloid: 94.2% of cells',' Status: VALIDATED as IPF biomarker'],
1398
+ reward: { validity: 0.95, ordering: 1.00, info_gain: 0.72, efficiency: 0.82, novelty: 0.85, penalty: 0.0 },
1399
+ total: 0.76,
1400
+ discovery: { title: 'SPP1 validated as IPF biomarker', detail: 'AUROC=0.91, specificity=0.87', color: 'var(--green)', bg: 'var(--green-dim)' },
1401
+ },
1402
+ {
1403
+ action: 'synthesize_conclusion', params: 'confidence=0.85', category: 'meta',
1404
+ budget: 65000, budgetPct: 65.0, time: 129, timePct: 71.7,
1405
+ output: ['CONCLUSION (confidence: 0.85):','','SPP1 is a validated biomarker for IPF with strong','discriminative power (AUROC=0.91). It is upregulated','3.42-fold in IPF lungs, concentrated in Aberrant Basaloid','cells (94.2%), and participates in ECM-receptor, TGF-beta,','and PI3K-Akt signaling pathways.','','Literature match: 4/5 expected findings confirmed','Calibration: Well-calibrated (no overconfidence penalty)'],
1406
+ reward: { validity: 1.00, ordering: 1.00, info_gain: 0.40, efficiency: 0.90, novelty: 0.50, penalty: 0.0 },
1407
+ total: 0.91, terminal: true,
1408
+ },
1409
+ ];
1410
+
1411
+ // State
1412
+ let running = false;
1413
+ let cumReward = 0;
1414
+
1415
+ // DOM refs
1416
+ const terminalEl = document.getElementById('terminal');
1417
+ const statusDot = document.getElementById('statusDot');
1418
+ const statusText = document.getElementById('statusText');
1419
+ const runBtn = document.getElementById('runBtn');
1420
+ const labActionLabel = document.getElementById('labActionLabel');
1421
+
1422
+ // Helpers
1423
+ function addLine(html) {
1424
+ const div = document.createElement('div');
1425
+ div.className = 't-line';
1426
+ div.innerHTML = html || '&nbsp;';
1427
+ terminalEl.appendChild(div);
1428
+ terminalEl.scrollTop = terminalEl.scrollHeight;
1429
+ }
1430
+
1431
+ function setGauge(id, value, pct, color) {
1432
+ document.getElementById(id + 'Val').textContent = value;
1433
+ const fill = document.getElementById(id + 'Fill');
1434
+ fill.style.width = pct + '%';
1435
+ if (color) fill.style.background = color;
1436
+ }
1437
+
1438
+ function setRewardBars(r) {
1439
+ for (const key of ['validity','ordering','info_gain','efficiency','novelty','penalty']) {
1440
+ const el = document.getElementById('rw-' + key);
1441
+ el.style.width = (r[key] * 100) + '%';
1442
+ el.textContent = r[key] > 0.01 ? r[key].toFixed(2) : '';
1443
+ }
1444
+ }
1445
+
1446
+ function clearRewardBars() {
1447
+ for (const key of ['validity','ordering','info_gain','efficiency','novelty','penalty']) {
1448
+ const el = document.getElementById('rw-' + key);
1449
+ el.style.width = '0%';
1450
+ el.textContent = '';
1451
+ }
1452
+ }
1453
+
1454
+ function addPipeStep(step, index) {
1455
+ const el = document.createElement('div');
1456
+ el.className = 'pipe-step';
1457
+ el.id = 'pipe-' + index;
1458
+ const catColor = step.category === 'wet' ? 'var(--green)' : step.category === 'comp' ? 'var(--accent)' : 'var(--pink)';
1459
+ el.innerHTML = `<div class="step-icon" style="color:${catColor};border-color:${catColor};">${index + 1}</div><span>${step.action}</span>`;
1460
+ document.getElementById('pipelineSteps').appendChild(el);
1461
+ requestAnimationFrame(() => el.classList.add('visible'));
1462
+ return el;
1463
+ }
1464
+
1465
+ function addDiscovery(d) {
1466
+ const c = document.getElementById('discoveries');
1467
+ if (c.querySelector('.empty-state')) c.innerHTML = '';
1468
+ const el = document.createElement('div');
1469
+ el.className = 'discovery';
1470
+ el.innerHTML = `<div class="disc-icon" style="background:${d.bg};color:${d.color};">&#9670;</div><div class="disc-body"><div class="disc-title">${d.title}</div><div class="disc-detail">${d.detail}</div></div>`;
1471
+ c.appendChild(el);
1472
+ requestAnimationFrame(() => el.classList.add('visible'));
1473
+ }
1474
+
1475
+ function addRewardHistory(step, index) {
1476
+ const c = document.getElementById('rewardHistory');
1477
+ if (c.querySelector('.empty-state')) c.innerHTML = '';
1478
+ const el = document.createElement('div');
1479
+ el.className = 'step-reward-mini';
1480
+ el.innerHTML = `<span class="srm-name">${index + 1}. ${step.action}</span><span class="srm-val ${step.total >= 0 ? 'pos' : 'neg'}">${step.total >= 0 ? '+' : ''}${step.total.toFixed(2)}</span>`;
1481
+ c.appendChild(el);
1482
+ requestAnimationFrame(() => el.classList.add('visible'));
1483
+ }
1484
+
1485
+ function selectScenario(el) {
1486
+ if (running) return;
1487
+ document.querySelectorAll('.scenario-opt').forEach(e => e.classList.remove('active'));
1488
+ el.classList.add('active');
1489
+ }
1490
+
1491
+ function wait(ms) { return new Promise(r => setTimeout(r, ms)); }
1492
+
1493
+ // ---- Run ----
1494
+ async function startDemo() {
1495
+ if (running) return;
1496
+ running = true;
1497
+ runBtn.disabled = true;
1498
+ runBtn.textContent = 'Running...';
1499
+ statusDot.classList.add('live');
1500
+ statusText.textContent = 'Running';
1501
+ terminalEl.innerHTML = '';
1502
+ cumReward = 0;
1503
+ document.getElementById('pipelineSteps').innerHTML = '';
1504
+ document.getElementById('discoveries').innerHTML = '<div class="empty-state">No discoveries yet</div>';
1505
+ document.getElementById('rewardHistory').innerHTML = '<div class="empty-state">No steps yet</div>';
1506
+ document.getElementById('violations').innerHTML = '<div class="empty-state">No violations</div>';
1507
+ clearRewardBars();
1508
+ document.getElementById('cumReward').textContent = '0.00';
1509
+ document.getElementById('stepRewardLabel').textContent = '--';
1510
+ initAgent();
1511
+
1512
+ addLine('<span class="t-label">[BioEnv]</span> <span class="t-dim">Initializing environment...</span>');
1513
+ await wait(500);
1514
+ addLine('<span class="t-label">[BioEnv]</span> Scenario: <span class="t-str">biomarker_validation_lung</span> (Hard)');
1515
+ await wait(200);
1516
+ addLine('<span class="t-label">[BioEnv]</span> Organism: <span class="t-str">Homo sapiens</span> | Tissue: <span class="t-str">Lung</span>');
1517
+ await wait(200);
1518
+ addLine('<span class="t-label">[BioEnv]</span> Budget: <span class="t-num">$100,000</span> | Time: <span class="t-num">180 days</span> | Max steps: <span class="t-num">30</span>');
1519
+ await wait(200);
1520
+ addLine('<span class="t-label">[BioEnv]</span> Task: Validate <span class="t-kw">SPP1</span> as biomarker for idiopathic pulmonary fibrosis');
1521
+ await wait(400);
1522
+ addLine('');
1523
+
1524
+ for (let i = 0; i < EPISODE.length; i++) {
1525
+ await runStep(i);
1526
+ await wait(500);
1527
+ }
1528
+
1529
+ // Done
1530
+ moveAgentTo('idle');
1531
+ labActionLabel.classList.remove('visible');
1532
+ addLine('');
1533
+ addLine('<span class="t-label">[BioEnv]</span> <span class="t-ok">Episode complete!</span>');
1534
+ addLine('<span class="t-label">[BioEnv]</span> Total reward: <span class="t-ok">+' + cumReward.toFixed(2) + '</span> | Steps: <span class="t-num">' + EPISODE.length + '</span> | Budget remaining: <span class="t-num">$65,000</span>');
1535
+ addLine('<span class="t-label">[BioEnv]</span> Literature match: <span class="t-ok">4/5 expected findings confirmed</span>');
1536
+ addLine('<span class="t-label">[BioEnv]</span> Calibration: <span class="t-ok">Well-calibrated</span> (no overconfidence penalty)');
1537
+
1538
+ statusDot.classList.remove('live');
1539
+ statusText.textContent = 'Complete';
1540
+ runBtn.textContent = 'Run Episode';
1541
+ runBtn.disabled = false;
1542
+ running = false;
1543
+ }
1544
+
1545
+ async function runStep(i) {
1546
+ const step = EPISODE[i];
1547
+ const station = ACTION_STATION[step.action] || 'computer';
1548
+
1549
+ // Move agent in lab
1550
+ moveAgentTo(station);
1551
+ labActionLabel.textContent = step.action + '()';
1552
+ labActionLabel.classList.add('visible');
1553
+ await wait(800); // wait for agent to travel
1554
+
1555
+ // Start working animation
1556
+ setAgentWorking(step.action);
1557
+ spawnParticles(agent.targetX, agent.targetY, STATIONS[station].color);
1558
+
1559
+ // Pipeline sidebar
1560
+ const pipeEl = addPipeStep(step, i);
1561
+ if (i > 0) {
1562
+ const prev = document.getElementById('pipe-' + (i - 1));
1563
+ prev.classList.remove('active');
1564
+ prev.classList.add('done');
1565
+ prev.querySelector('.step-icon').innerHTML = '&#10003;';
1566
+ }
1567
+ pipeEl.classList.add('active');
1568
+
1569
+ // Gauges
1570
+ setGauge('budget', '$' + step.budget.toLocaleString(), step.budgetPct,
1571
+ step.budgetPct > 50 ? 'var(--green)' : step.budgetPct > 25 ? 'var(--amber)' : 'var(--red)');
1572
+ setGauge('time', step.time + ' / 180 days', step.timePct, 'var(--cyan)');
1573
+ setGauge('step', (i + 1) + ' / 30', ((i + 1) / 30 * 100), 'var(--accent)');
1574
+
1575
+ // Terminal output
1576
+ const catTag = step.category === 'wet' ? '<span class="t-ok">WET</span>'
1577
+ : step.category === 'comp' ? '<span class="t-label">CMP</span>'
1578
+ : '<span class="t-kw">META</span>';
1579
+ addLine(`<span class="t-dim">Step ${i + 1}</span> ${catTag} <span class="t-fn">${step.action}</span>(<span class="t-str">${step.params}</span>)`);
1580
+ await wait(300);
1581
+
1582
+ for (const line of step.output) {
1583
+ addLine(' <span class="t-sub">' + line + '</span>');
1584
+ await wait(80);
1585
+ }
1586
+
1587
+ // Reward
1588
+ cumReward += step.total;
1589
+ document.getElementById('stepRewardLabel').textContent = 'Step ' + (i + 1) + ': ' + step.action;
1590
+ setRewardBars(step.reward);
1591
+ document.getElementById('cumReward').textContent = cumReward.toFixed(2);
1592
+ addRewardHistory(step, i);
1593
+
1594
+ const rewardStr = step.total >= 0
1595
+ ? '<span class="t-ok">+' + step.total.toFixed(2) + '</span>'
1596
+ : '<span class="t-err">' + step.total.toFixed(2) + '</span>';
1597
+ addLine(` <span class="t-dim">reward: ${rewardStr} <span class="t-dim">(cumulative: ${cumReward.toFixed(2)})</span></span>`);
1598
+ addLine('');
1599
+
1600
+ if (step.discovery) addDiscovery(step.discovery);
1601
+
1602
+ // Done working
1603
+ agent.working = false;
1604
+ spawnParticles(agent.targetX, agent.targetY, '#34d399', 6);
1605
+
1606
+ if (step.terminal) {
1607
+ pipeEl.classList.remove('active');
1608
+ pipeEl.classList.add('done');
1609
+ pipeEl.querySelector('.step-icon').innerHTML = '&#10003;';
1610
+ }
1611
+ }
1612
+
1613
+ function resetDemo() {
1614
+ if (running) return;
1615
+ terminalEl.innerHTML = '';
1616
+ cumReward = 0;
1617
+ document.getElementById('pipelineSteps').innerHTML = '';
1618
+ document.getElementById('discoveries').innerHTML = '<div class="empty-state">No discoveries yet</div>';
1619
+ document.getElementById('rewardHistory').innerHTML = '<div class="empty-state">No steps yet</div>';
1620
+ document.getElementById('violations').innerHTML = '<div class="empty-state">No violations</div>';
1621
+ clearRewardBars();
1622
+ document.getElementById('cumReward').textContent = '0.00';
1623
+ document.getElementById('stepRewardLabel').textContent = '--';
1624
+ setGauge('budget', '$100,000', 100, 'var(--green)');
1625
+ setGauge('time', '180 / 180 days', 100, 'var(--cyan)');
1626
+ setGauge('step', '0 / 30', 0, 'var(--accent)');
1627
+ statusDot.classList.remove('live');
1628
+ statusText.textContent = 'Ready';
1629
+ labActionLabel.classList.remove('visible');
1630
+ initAgent();
1631
+ addLine('<span class="t-dim">Environment reset. Click "Run Episode" to start.</span>');
1632
+ }
1633
+
1634
+ // Init
1635
+ addLine('<span class="t-dim">BioEnv v1.0 | biomarker_validation_lung</span>');
1636
+ addLine('<span class="t-dim">Click "Run Episode" to start the demo.</span>');
1637
+ </script>
1638
+ </body>
1639
+ </html>
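The demo script above tallies a cumulative reward by summing each step's `total`. A minimal Python sketch of that bookkeeping, using the per-step totals copied from the `EPISODE` table in the demo (demo data, not results from a real run):

```python
# Per-step totals mirrored from the demo's EPISODE array.
EPISODE_TOTALS = [
    ("collect_sample", 0.45), ("select_cohort", 0.38), ("prepare_library", 0.52),
    ("sequence_cells", 0.68), ("run_qc", 0.55), ("normalize_data", 0.42),
    ("cluster_cells", 0.72), ("differential_expression", 0.82),
    ("pathway_enrichment", 0.58), ("marker_selection", 0.55),
    ("validate_marker", 0.76), ("synthesize_conclusion", 0.91),
]

def cumulative_rewards(totals):
    """Running cumulative reward after each step, rounded as the UI displays it."""
    out, running = [], 0.0
    for _, r in totals:
        running += r
        out.append(round(running, 2))
    return out

print(cumulative_rewards(EPISODE_TOTALS)[-1])  # final cumulative reward: 7.34
```

This is the same arithmetic the `runStep` loop performs in `cumReward += step.total` before writing `cumReward.toFixed(2)` to the panel.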
eval_compare.py ADDED
@@ -0,0 +1,174 @@
1
+ """Compare base vs trained model on the same prompts."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import json
7
+ import random
8
+ from typing import Dict, List
9
+
10
+ import torch
11
+ from transformers import AutoModelForCausalLM, AutoTokenizer
12
+
13
+ from training_script import (
14
+ SYSTEM_PROMPT,
15
+ OpenEnvReward,
16
+ build_prompt_examples,
17
+ completion_to_text,
18
+ parse_action_completion,
19
+ selected_scenarios,
20
+ )
21
+
22
+
23
+ def generate_completions(
24
+ model,
25
+ tokenizer,
26
+ prompts: List[str],
27
+ max_new_tokens: int = 220,
28
+ ) -> List[str]:
29
+ completions = []
30
+ for prompt in prompts:
31
+ messages = [
32
+ {"role": "system", "content": SYSTEM_PROMPT},
33
+ {"role": "user", "content": prompt},
34
+ ]
35
+ input_text = tokenizer.apply_chat_template(
36
+ messages, tokenize=False, add_generation_prompt=True
37
+ )
38
+ inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
39
+ with torch.no_grad():
40
+ output = model.generate(
41
+ **inputs,
42
+ max_new_tokens=max_new_tokens,
43
+ do_sample=True,
44
+ temperature=0.7,
45
+ top_p=0.9,
46
+ )
47
+ generated = output[0][inputs["input_ids"].shape[1]:]
48
+ completions.append(tokenizer.decode(generated, skip_special_tokens=True))
49
+ return completions
50
+
51
+
52
+ def evaluate_model(
53
+ model,
54
+ tokenizer,
55
+ examples: List[Dict[str, str]],
56
+ reward_fn: OpenEnvReward,
57
+ label: str,
58
+ ) -> Dict[str, float]:
59
+ prompts = [ex["prompt"] for ex in examples]
60
+ completions = generate_completions(model, tokenizer, prompts)
61
+
62
+ rewards = []
63
+ valid_actions = 0
64
+ for comp, ex in zip(completions, examples):
65
+ reward = reward_fn(
66
+ completions=[comp],
67
+ scenario_name=[ex.get("scenario_name")],
68
+ history_actions=[ex.get("history_actions")],
69
+ )[0]
70
+ rewards.append(reward)
71
+ if parse_action_completion(comp) is not None:
72
+ valid_actions += 1
73
+
74
+ avg_reward = sum(rewards) / len(rewards) if rewards else 0
75
+ valid_pct = valid_actions / len(completions) * 100 if completions else 0
76
+
77
+ print(f"\n{'='*50}")
78
+ print(f" {label}")
79
+ print(f"{'='*50}")
80
+ print(f" Samples: {len(completions)}")
81
+ print(f" Avg reward: {avg_reward:.4f}")
82
+ print(f" Min reward: {min(rewards):.4f}")
83
+ print(f" Max reward: {max(rewards):.4f}")
84
+ print(f" Valid actions: {valid_actions}/{len(completions)} ({valid_pct:.1f}%)")
85
+ print()
86
+
87
+ # Show a few example completions
88
+ for i, (comp, r) in enumerate(zip(completions[:3], rewards[:3])):
89
+ print(f" Example {i+1} (reward={r:.2f}):")
90
+ print(f" {comp[:200]}")
91
+ print()
92
+
93
+ return {"avg_reward": avg_reward, "valid_pct": valid_pct, "rewards": rewards}
94
+
95
+
96
+ def main():
97
+ parser = argparse.ArgumentParser(description="Compare base vs trained model")
98
+ parser.add_argument("--base-model", default="Qwen/Qwen3.5-0.8B",
99
+ help="Base model ID from HuggingFace")
100
+ parser.add_argument("--trained-model", default="./grpo-output",
101
+ help="Path to trained model (local dir or HF repo)")
102
+ parser.add_argument("--num-samples", type=int, default=16,
103
+ help="Number of eval prompts")
104
+ parser.add_argument("--seed", type=int, default=42)
105
+ parser.add_argument("--trust-remote-code", action="store_true")
106
+ args = parser.parse_args()
107
+
108
+ random.seed(args.seed)
109
+
110
+ # Build eval prompts
111
+ scenarios = selected_scenarios(None)
112
+ examples = build_prompt_examples(
113
+ dataset_episodes=args.num_samples,
114
+ rollout_steps=1, # one prompt per episode
115
+ collection_policy="heuristic",
116
+ scenario_names=scenarios,
117
+ seed=args.seed,
118
+ domain_randomise=False,
119
+ )
120
+ print(f"Built {len(examples)} eval prompts across {len(scenarios)} scenarios")
121
+
122
+ reward_fn = OpenEnvReward(reward_backend="local", base_url="")
123
+
124
+ # Evaluate base model
125
+ print(f"\nLoading base model: {args.base_model}")
126
+ base_tokenizer = AutoTokenizer.from_pretrained(
127
+ args.base_model, trust_remote_code=args.trust_remote_code
128
+ )
129
+ if base_tokenizer.pad_token is None:
130
+ base_tokenizer.pad_token = base_tokenizer.eos_token
131
+ base_model = AutoModelForCausalLM.from_pretrained(
132
+ args.base_model,
133
+ trust_remote_code=args.trust_remote_code,
134
+ torch_dtype=torch.bfloat16,
135
+ device_map="auto",
136
+ )
137
+ base_results = evaluate_model(
138
+ base_model, base_tokenizer, examples, reward_fn, "BASE MODEL"
139
+ )
140
+ del base_model
141
+ torch.cuda.empty_cache()
142
+
143
+ # Evaluate trained model
144
+ print(f"\nLoading trained model: {args.trained_model}")
145
+ trained_tokenizer = AutoTokenizer.from_pretrained(
146
+ args.trained_model, trust_remote_code=args.trust_remote_code
147
+ )
148
+ if trained_tokenizer.pad_token is None:
149
+ trained_tokenizer.pad_token = trained_tokenizer.eos_token
150
+ trained_model = AutoModelForCausalLM.from_pretrained(
151
+ args.trained_model,
152
+ trust_remote_code=args.trust_remote_code,
153
+ torch_dtype=torch.bfloat16,
154
+ device_map="auto",
155
+ )
156
+ trained_results = evaluate_model(
157
+ trained_model, trained_tokenizer, examples, reward_fn, "TRAINED MODEL"
158
+ )
159
+
160
+ # Summary
161
+ delta = trained_results["avg_reward"] - base_results["avg_reward"]
162
+ print(f"{'='*50}")
163
+ print(f" COMPARISON SUMMARY")
164
+ print(f"{'='*50}")
165
+ print(f" Base avg reward: {base_results['avg_reward']:.4f}")
166
+ print(f" Trained avg reward: {trained_results['avg_reward']:.4f}")
167
+ print(f" Delta: {delta:+.4f}")
168
+ print(f" Base valid actions: {base_results['valid_pct']:.1f}%")
169
+ print(f" Trained valid actions: {trained_results['valid_pct']:.1f}%")
170
+ print()
171
+
172
+
173
+ if __name__ == "__main__":
174
+ main()
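The comparison summary at the end of `eval_compare.py` reduces to simple arithmetic over the two reward lists. A self-contained sketch of that reduction (the reward values here are illustrative placeholders, not real evaluation results):

```python
def summarize(base_rewards, trained_rewards):
    # Mirror the script's summary block: mean reward per model plus the delta.
    base_avg = sum(base_rewards) / len(base_rewards)
    trained_avg = sum(trained_rewards) / len(trained_rewards)
    return {
        "base_avg": round(base_avg, 4),
        "trained_avg": round(trained_avg, 4),
        "delta": round(trained_avg - base_avg, 4),
    }

# Illustrative only: in the script these lists come from OpenEnvReward
# applied to generated completions.
print(summarize([0.21, 0.35, 0.28], [0.52, 0.61, 0.47]))
```

Because both models are scored on the identical prompt set (same seed, same scenarios), the delta isolates the effect of training rather than prompt variance.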
models.py CHANGED
@@ -73,6 +73,1025 @@ META_ACTIONS = frozenset({
73
  })
74
 
75
 
76
  class SubagentType(str, Enum):
77
  WET_LAB_PLANNER = "wet_lab_planner"
78
  COMPUTATIONAL_ANALYST = "computational_analyst"
@@ -96,29 +1115,61 @@ class ExperimentAction(Action):
96
  """
97
 
98
  action_type: ActionType = Field(
99
- ..., description="Discrete experiment or analysis step type"
 
 
 
 
 
100
  )
101
  input_targets: List[str] = Field(
102
  default_factory=list,
103
- description="References to prior outputs, samples, or artifacts",
104
  )
105
  method: Optional[str] = Field(
106
- None, description="Specific method or tool (e.g. 'Seurat', 'CellRanger')"
107
  )
108
  parameters: Dict[str, Any] = Field(
109
- default_factory=dict, description="Method-specific parameters"
110
  )
111
  expected_output_type: Optional[str] = Field(
112
- None, description="What the agent expects this step to produce"
113
  )
114
  justification: Optional[str] = Field(
115
- None, description="Scientific rationale for this step"
116
  )
117
  invoked_subagent: Optional[SubagentType] = Field(
118
  None, description="Sub-agent to delegate to, if any"
119
  )
120
  tool_call_spec: Optional[Dict[str, Any]] = Field(
121
- None, description="Structured tool invocation specification"
122
  )
123
  confidence: float = Field(
124
  0.5, ge=0.0, le=1.0, description="Agent confidence in this step"
@@ -216,14 +1267,22 @@ class TaskSpec(BaseModel):
216
  organism: str = "human"
217
  tissue: str = "blood"
218
  conditions: List[str] = Field(default_factory=list)
219
- available_assays: List[str] = Field(default_factory=lambda: [
220
- "10x_chromium", "smart-seq2", "bulk_rna_seq",
221
- "atac-seq", "cite-seq", "spatial_transcriptomics",
222
- ])
223
- available_tools: List[str] = Field(default_factory=lambda: [
224
- "CellRanger", "Seurat", "Scanpy", "DESeq2", "GSEA",
225
- "Monocle", "scVelo", "CellChat", "SCENIC",
226
- ])
 
  budget_limit: float = 100_000.0
228
  time_limit_days: float = 180.0
229
  prior_observations: List[str] = Field(default_factory=list)
@@ -234,7 +1293,10 @@ class TaskSpec(BaseModel):
234
 
235
 
236
  class ConclusionClaim(BaseModel):
237
- claim: str
238
  evidence_steps: List[int] = Field(default_factory=list)
239
  confidence: float = Field(0.5, ge=0.0, le=1.0)
240
  claim_type: str = "correlational"
@@ -254,9 +1316,26 @@ class ExperimentObservation(Observation):
254
  task: TaskSpec = Field(default_factory=TaskSpec)
255
  step_index: int = 0
256
  pipeline_history: List[PipelineStepRecord] = Field(default_factory=list)
257
- available_assays: List[str] = Field(default_factory=list)
258
- available_tools: List[str] = Field(default_factory=list)
259
- resource_usage: ResourceUsage = Field(default_factory=ResourceUsage)
260
  latest_output: Optional[IntermediateOutput] = None
261
  all_outputs: List[IntermediateOutput] = Field(default_factory=list)
262
  discovered_markers: List[str] = Field(default_factory=list)
@@ -266,3 +1345,313 @@ class ExperimentObservation(Observation):
266
  conclusions: List[ConclusionClaim] = Field(default_factory=list)
267
  rule_violations: List[str] = Field(default_factory=list)
268
  step_reward_breakdown: Dict[str, float] = Field(default_factory=dict)
 })
 
 
+# ── Tool, Assay & Modality Registries ──────────────────────────────────────
+
+
+class ToolCategory(str, Enum):
+    ALIGNMENT = "alignment"
+    PREPROCESSING = "preprocessing"
+    NORMALIZATION = "normalization"
+    DIMENSIONALITY_REDUCTION = "dimensionality_reduction"
+    CLUSTERING = "clustering"
+    DIFFERENTIAL_EXPRESSION = "differential_expression"
+    TRAJECTORY = "trajectory"
+    GENE_REGULATORY_NETWORK = "gene_regulatory_network"
+    CELL_COMMUNICATION = "cell_communication"
+    SPATIAL = "spatial"
+    MULTIMODAL_INTEGRATION = "multimodal_integration"
+    GENE_SET_ANALYSIS = "gene_set_analysis"
+    VARIANT_CALLING = "variant_calling"
+    PEAK_CALLING = "peak_calling"
+    IMPUTATION = "imputation"
+    BATCH_CORRECTION = "batch_correction"
+    CELL_TYPE_ANNOTATION = "cell_type_annotation"
+    SIMULATION = "simulation"
+    VISUALIZATION = "visualization"
+    QUALITY_CONTROL = "quality_control"
+    PERTURBATION_ANALYSIS = "perturbation_analysis"
+
+
+class ToolSpec(BaseModel):
+    """Registry entry describing a bioinformatics tool."""
+
+    name: str
+    category: ToolCategory
+    modalities: List[str] = Field(default_factory=list)
+    description: str = ""
+    input_types: List[str] = Field(default_factory=list)
+    output_types: List[str] = Field(default_factory=list)
+    typical_runtime_hours: float = 0.1
+    typical_cost_usd: float = 0.0
+    requires_gpu: bool = False
+    open_source: bool = True
+
+
+TOOL_REGISTRY: Dict[str, ToolSpec] = {
+    # ── Alignment & quantification ──
+    "CellRanger": ToolSpec(
+        name="CellRanger",
+        category=ToolCategory.ALIGNMENT,
+        modalities=["scRNA-seq", "scATAC-seq", "CITE-seq", "scMultiome"],
+        description="10x Genomics pipeline for alignment, barcode demux, and counting",
+        input_types=["fastq"],
+        output_types=["count_matrix", "bam"],
+        typical_runtime_hours=4.0,
+        open_source=False,
+    ),
+    "STARsolo": ToolSpec(
+        name="STARsolo",
+        category=ToolCategory.ALIGNMENT,
+        modalities=["scRNA-seq", "scATAC-seq"],
+        description="Drop-seq / 10x-compatible aligner built into STAR",
+        input_types=["fastq"],
+        output_types=["count_matrix", "bam"],
+        typical_runtime_hours=3.0,
+    ),
+    "kallisto_bustools": ToolSpec(
+        name="kallisto_bustools",
+        category=ToolCategory.ALIGNMENT,
+        modalities=["scRNA-seq"],
+        description="Pseudoalignment-based lightweight quantification",
+        input_types=["fastq"],
+        output_types=["count_matrix"],
+        typical_runtime_hours=1.0,
+    ),
+    "Salmon_alevin": ToolSpec(
+        name="Salmon_alevin",
+        category=ToolCategory.ALIGNMENT,
+        modalities=["scRNA-seq"],
+        description="Quasi-mapping quantification for single-cell RNA-seq",
+        input_types=["fastq"],
+        output_types=["count_matrix"],
+        typical_runtime_hours=1.5,
+    ),
+    "spaceranger": ToolSpec(
+        name="spaceranger",
+        category=ToolCategory.ALIGNMENT,
+        modalities=["spatial_transcriptomics"],
+        description="10x Visium spatial alignment and quantification",
+        input_types=["fastq", "image"],
+        output_types=["count_matrix", "spatial_coords"],
+        typical_runtime_hours=3.0,
+        open_source=False,
+    ),
+    # ── Preprocessing / analysis frameworks ──
+    "Scanpy": ToolSpec(
+        name="Scanpy",
+        category=ToolCategory.PREPROCESSING,
+        modalities=["scRNA-seq", "scATAC-seq", "spatial_transcriptomics"],
+        description="Python single-cell analysis framework",
+        input_types=["count_matrix", "h5ad"],
+        output_types=["h5ad", "embedding", "cluster_result"],
+        typical_runtime_hours=0.5,
+    ),
+    "Seurat": ToolSpec(
+        name="Seurat",
+        category=ToolCategory.PREPROCESSING,
+        modalities=["scRNA-seq", "CITE-seq", "spatial_transcriptomics", "scATAC-seq"],
+        description="R single-cell analysis toolkit with multimodal support",
+        input_types=["count_matrix", "h5seurat"],
+        output_types=["h5seurat", "embedding", "cluster_result"],
+        typical_runtime_hours=0.5,
+    ),
+    "Bioconductor_SingleCellExperiment": ToolSpec(
+        name="Bioconductor_SingleCellExperiment",
+        category=ToolCategory.PREPROCESSING,
+        modalities=["scRNA-seq"],
+        description="R/Bioconductor framework for single-cell experiments",
+        input_types=["count_matrix"],
+        output_types=["sce_object"],
+        typical_runtime_hours=0.3,
+    ),
+    # ── Normalization ──
+    "scran": ToolSpec(
+        name="scran",
+        category=ToolCategory.NORMALIZATION,
+        modalities=["scRNA-seq"],
+        description="Pool-based size-factor normalization",
+        input_types=["count_matrix"],
+        output_types=["normalized_matrix"],
+    ),
+    "sctransform": ToolSpec(
+        name="sctransform",
+        category=ToolCategory.NORMALIZATION,
+        modalities=["scRNA-seq"],
+        description="Variance-stabilizing transformation via regularized NB regression",
+        input_types=["count_matrix"],
+        output_types=["normalized_matrix"],
+    ),
+    # ── Dimensionality reduction ──
+    "scVI": ToolSpec(
+        name="scVI",
+        category=ToolCategory.DIMENSIONALITY_REDUCTION,
+        modalities=["scRNA-seq", "CITE-seq", "scATAC-seq"],
+        description="Deep generative model for scRNA-seq (variational inference)",
+        input_types=["count_matrix"],
+        output_types=["latent_embedding"],
+        requires_gpu=True,
+    ),
+    "UMAP": ToolSpec(
+        name="UMAP",
+        category=ToolCategory.DIMENSIONALITY_REDUCTION,
+        modalities=["scRNA-seq", "scATAC-seq", "CITE-seq", "spatial_transcriptomics"],
+        description="Uniform manifold approximation for 2D/3D visualization",
+        input_types=["pca_embedding", "latent_embedding"],
+        output_types=["2d_embedding"],
+    ),
+    # ── Clustering ──
+    "Leiden": ToolSpec(
+        name="Leiden",
+        category=ToolCategory.CLUSTERING,
+        modalities=["scRNA-seq", "scATAC-seq", "CITE-seq"],
+        description="Community detection via the Leiden algorithm",
+        input_types=["knn_graph"],
+        output_types=["cluster_result"],
+    ),
+    "Louvain": ToolSpec(
+        name="Louvain",
+        category=ToolCategory.CLUSTERING,
+        modalities=["scRNA-seq", "scATAC-seq"],
+        description="Community detection via Louvain modularity optimization",
+        input_types=["knn_graph"],
+        output_types=["cluster_result"],
+    ),
+    # ── Differential expression ──
+    "DESeq2": ToolSpec(
+        name="DESeq2",
+        category=ToolCategory.DIFFERENTIAL_EXPRESSION,
+        modalities=["bulk_rna_seq", "scRNA-seq"],
+        description="Negative binomial GLM-based differential expression",
+        input_types=["count_matrix"],
+        output_types=["de_result"],
+    ),
+    "MAST": ToolSpec(
+        name="MAST",
+        category=ToolCategory.DIFFERENTIAL_EXPRESSION,
+        modalities=["scRNA-seq"],
+        description="Two-part hurdle model for scRNA-seq DE testing",
+        input_types=["count_matrix"],
+        output_types=["de_result"],
+    ),
+    "edgeR": ToolSpec(
+        name="edgeR",
+        category=ToolCategory.DIFFERENTIAL_EXPRESSION,
+        modalities=["bulk_rna_seq", "scRNA-seq"],
+        description="Empirical Bayes quasi-likelihood DE testing",
+        input_types=["count_matrix"],
+        output_types=["de_result"],
+    ),
+    "Wilcoxon": ToolSpec(
+        name="Wilcoxon",
+        category=ToolCategory.DIFFERENTIAL_EXPRESSION,
+        modalities=["scRNA-seq"],
+        description="Rank-sum test for marker gene detection",
+        input_types=["count_matrix"],
+        output_types=["de_result"],
+    ),
+ # ── Trajectory & RNA velocity ──
281
+ "Monocle3": ToolSpec(
282
+ name="Monocle3",
283
+ category=ToolCategory.TRAJECTORY,
284
+ modalities=["scRNA-seq"],
285
+ description="Reversed graph embedding for pseudotime trajectories",
286
+ input_types=["count_matrix", "embedding"],
287
+ output_types=["trajectory_result", "pseudotime"],
288
+ ),
289
+ "scVelo": ToolSpec(
290
+ name="scVelo",
291
+ category=ToolCategory.TRAJECTORY,
292
+ modalities=["scRNA-seq"],
293
+ description="RNA velocity estimation via spliced/unspliced dynamics",
294
+ input_types=["count_matrix"],
295
+ output_types=["velocity_result"],
296
+ ),
297
+ "CellRank": ToolSpec(
298
+ name="CellRank",
299
+ category=ToolCategory.TRAJECTORY,
300
+ modalities=["scRNA-seq"],
301
+ description="Fate probability estimation combining velocity and transcriptomics",
302
+ input_types=["velocity_result", "count_matrix"],
303
+ output_types=["fate_probabilities"],
304
+ ),
305
+ "Slingshot": ToolSpec(
306
+ name="Slingshot",
307
+ category=ToolCategory.TRAJECTORY,
308
+ modalities=["scRNA-seq"],
309
+ description="Minimum spanning tree-based trajectory inference",
310
+ input_types=["embedding", "cluster_result"],
311
+ output_types=["trajectory_result", "pseudotime"],
312
+ ),
313
+ "PAGA": ToolSpec(
314
+ name="PAGA",
315
+ category=ToolCategory.TRAJECTORY,
316
+ modalities=["scRNA-seq"],
317
+ description="Partition-based graph abstraction for topology estimation",
318
+ input_types=["knn_graph", "cluster_result"],
319
+ output_types=["trajectory_result"],
320
+ ),
321
+ # ── Gene regulatory networks ──
322
+ "SCENIC": ToolSpec(
323
+ name="SCENIC",
324
+ category=ToolCategory.GENE_REGULATORY_NETWORK,
325
+ modalities=["scRNA-seq"],
326
+ description="Single-cell regulatory network inference and clustering",
327
+ input_types=["count_matrix"],
328
+ output_types=["regulon_result", "network_result"],
329
+ typical_runtime_hours=6.0,
330
+ ),
331
+ "CellOracle": ToolSpec(
332
+ name="CellOracle",
333
+ category=ToolCategory.GENE_REGULATORY_NETWORK,
334
+ modalities=["scRNA-seq", "scATAC-seq", "scMultiome"],
335
+ description="GRN-based in-silico perturbation prediction",
336
+ input_types=["count_matrix", "peak_matrix"],
337
+ output_types=["network_result", "perturbation_prediction"],
338
+ typical_runtime_hours=4.0,
339
+ ),
+    # ── Cell-cell communication ──
+    "CellChat": ToolSpec(
+        name="CellChat",
+        category=ToolCategory.CELL_COMMUNICATION,
+        modalities=["scRNA-seq", "spatial_transcriptomics"],
+        description="Ligand-receptor interaction inference with communication patterns",
+        input_types=["count_matrix", "cluster_result"],
+        output_types=["communication_result"],
+    ),
+    "NicheNet": ToolSpec(
+        name="NicheNet",
+        category=ToolCategory.CELL_COMMUNICATION,
+        modalities=["scRNA-seq"],
+        description="Ligand-target link prediction using prior knowledge",
+        input_types=["count_matrix", "de_result"],
+        output_types=["communication_result"],
+    ),
+    "LIANA": ToolSpec(
+        name="LIANA",
+        category=ToolCategory.CELL_COMMUNICATION,
+        modalities=["scRNA-seq", "spatial_transcriptomics"],
+        description="Framework unifying multiple ligand-receptor methods",
+        input_types=["count_matrix", "cluster_result"],
+        output_types=["communication_result"],
+    ),
+    # ── Spatial analysis ──
+    "squidpy": ToolSpec(
+        name="squidpy",
+        category=ToolCategory.SPATIAL,
+        modalities=["spatial_transcriptomics"],
+        description="Spatial omics analysis (neighborhood, co-occurrence, image features)",
+        input_types=["count_matrix", "spatial_coords"],
+        output_types=["spatial_result"],
+    ),
+    "cell2location": ToolSpec(
+        name="cell2location",
+        category=ToolCategory.SPATIAL,
+        modalities=["spatial_transcriptomics"],
+        description="Spatial deconvolution mapping cell types to tissue locations",
+        input_types=["count_matrix", "spatial_coords", "reference_h5ad"],
+        output_types=["deconvolution_result"],
+        requires_gpu=True,
+    ),
+    "BANKSY": ToolSpec(
+        name="BANKSY",
+        category=ToolCategory.SPATIAL,
+        modalities=["spatial_transcriptomics"],
+        description="Spatially-aware clustering combining cell and neighbor features",
+        input_types=["count_matrix", "spatial_coords"],
+        output_types=["cluster_result"],
+    ),
+    # ── Multimodal integration ──
+    "Harmony": ToolSpec(
+        name="Harmony",
+        category=ToolCategory.BATCH_CORRECTION,
+        modalities=["scRNA-seq", "scATAC-seq", "CITE-seq"],
+        description="Fast iterative batch correction on PCA embeddings",
+        input_types=["pca_embedding"],
+        output_types=["corrected_embedding"],
+    ),
+    "scanorama": ToolSpec(
+        name="scanorama",
+        category=ToolCategory.BATCH_CORRECTION,
+        modalities=["scRNA-seq"],
+        description="Panoramic stitching of scRNA-seq batches",
+        input_types=["count_matrix"],
+        output_types=["corrected_embedding", "corrected_matrix"],
+    ),
+    "BBKNN": ToolSpec(
+        name="BBKNN",
+        category=ToolCategory.BATCH_CORRECTION,
+        modalities=["scRNA-seq"],
+        description="Batch-balanced KNN graph construction",
+        input_types=["pca_embedding"],
+        output_types=["knn_graph"],
+    ),
+    "WNN": ToolSpec(
+        name="WNN",
+        category=ToolCategory.MULTIMODAL_INTEGRATION,
+        modalities=["CITE-seq", "scMultiome"],
+        description="Weighted nearest neighbors for multimodal integration (Seurat v4+)",
+        input_types=["rna_embedding", "protein_embedding"],
+        output_types=["multimodal_embedding"],
+    ),
+    "MOFA+": ToolSpec(
+        name="MOFA+",
+        category=ToolCategory.MULTIMODAL_INTEGRATION,
+        modalities=["scMultiome", "CITE-seq"],
+        description="Multi-omics factor analysis for unsupervised integration",
+        input_types=["count_matrix", "peak_matrix"],
+        output_types=["factor_result"],
+    ),
+    "ArchR": ToolSpec(
+        name="ArchR",
+        category=ToolCategory.PREPROCESSING,
+        modalities=["scATAC-seq", "scMultiome"],
+        description="Full-featured scATAC-seq analysis framework in R",
+        input_types=["fragments", "bam"],
+        output_types=["peak_matrix", "gene_activity_matrix"],
+        typical_runtime_hours=2.0,
+    ),
+    "Signac": ToolSpec(
+        name="Signac",
+        category=ToolCategory.PREPROCESSING,
+        modalities=["scATAC-seq", "scMultiome"],
+        description="Seurat extension for chromatin accessibility analysis",
+        input_types=["fragments", "peak_matrix"],
+        output_types=["peak_matrix", "motif_result"],
+    ),
+    "chromVAR": ToolSpec(
+        name="chromVAR",
+        category=ToolCategory.PEAK_CALLING,
+        modalities=["scATAC-seq", "scMultiome"],
+        description="TF motif accessibility deviation scoring",
+        input_types=["peak_matrix"],
+        output_types=["motif_deviation_scores"],
+    ),
+    # ── Gene set / pathway analysis ──
+    "GSEA": ToolSpec(
+        name="GSEA",
+        category=ToolCategory.GENE_SET_ANALYSIS,
+        modalities=["bulk_rna_seq", "scRNA-seq"],
+        description="Gene Set Enrichment Analysis (preranked or phenotype-based)",
+        input_types=["de_result", "ranked_gene_list"],
+        output_types=["pathway_result"],
+    ),
+    "clusterProfiler": ToolSpec(
+        name="clusterProfiler",
+        category=ToolCategory.GENE_SET_ANALYSIS,
+        modalities=["bulk_rna_seq", "scRNA-seq"],
+        description="ORA & GSEA with GO, KEGG, Reactome, and custom gene sets",
+        input_types=["de_result", "gene_list"],
+        output_types=["pathway_result"],
+    ),
+    "decoupleR": ToolSpec(
+        name="decoupleR",
+        category=ToolCategory.GENE_SET_ANALYSIS,
+        modalities=["scRNA-seq", "bulk_rna_seq", "spatial_transcriptomics"],
+        description="Unified framework for functional activity inference (TF, pathway)",
+        input_types=["count_matrix", "de_result"],
+        output_types=["activity_scores"],
+    ),
+    # ── Cell type annotation ──
+    "celltypist": ToolSpec(
+        name="celltypist",
+        category=ToolCategory.CELL_TYPE_ANNOTATION,
+        modalities=["scRNA-seq"],
+        description="Automated cell type classification with pre-trained models",
+        input_types=["count_matrix"],
+        output_types=["annotation_result"],
+    ),
+    "SingleR": ToolSpec(
+        name="SingleR",
+        category=ToolCategory.CELL_TYPE_ANNOTATION,
+        modalities=["scRNA-seq"],
+        description="Reference-based cell type annotation using correlation",
+        input_types=["count_matrix", "reference_dataset"],
+        output_types=["annotation_result"],
+    ),
+    "scArches": ToolSpec(
+        name="scArches",
+        category=ToolCategory.CELL_TYPE_ANNOTATION,
+        modalities=["scRNA-seq", "scATAC-seq", "CITE-seq"],
+        description="Reference mapping and label transfer via deep learning",
+        input_types=["count_matrix", "reference_model"],
+        output_types=["annotation_result", "latent_embedding"],
+        requires_gpu=True,
+    ),
+    # ── Imputation ──
+    "MAGIC": ToolSpec(
+        name="MAGIC",
+        category=ToolCategory.IMPUTATION,
+        modalities=["scRNA-seq"],
+        description="Markov affinity-based graph imputation of dropout zeros",
+        input_types=["count_matrix"],
+        output_types=["imputed_matrix"],
+    ),
+    # ── Perturbation analysis ──
+    "MILO": ToolSpec(
+        name="MILO",
+        category=ToolCategory.PERTURBATION_ANALYSIS,
+        modalities=["scRNA-seq"],
+        description="Differential abundance testing on KNN graph neighborhoods",
+        input_types=["count_matrix", "knn_graph"],
+        output_types=["da_result"],
+    ),
+    "Mixscape": ToolSpec(
+        name="Mixscape",
+        category=ToolCategory.PERTURBATION_ANALYSIS,
+        modalities=["Perturb-seq", "CROP-seq"],
+        description="Seurat extension for CRISPR screen perturbation analysis",
+        input_types=["count_matrix", "guide_assignments"],
+        output_types=["perturbation_result"],
+    ),
+    "MIMOSCA": ToolSpec(
+        name="MIMOSCA",
+        category=ToolCategory.PERTURBATION_ANALYSIS,
+        modalities=["Perturb-seq", "CROP-seq"],
+        description="Multi-input multi-output single-cell analysis for screens",
+        input_types=["count_matrix", "guide_assignments"],
+        output_types=["perturbation_result"],
+    ),
+    # ── Quality control ──
+    "scrublet": ToolSpec(
+        name="scrublet",
+        category=ToolCategory.QUALITY_CONTROL,
+        modalities=["scRNA-seq"],
+        description="Computational doublet detection via synthetic doublets",
+        input_types=["count_matrix"],
+        output_types=["doublet_scores"],
+    ),
+    "DoubletFinder": ToolSpec(
+        name="DoubletFinder",
+        category=ToolCategory.QUALITY_CONTROL,
+        modalities=["scRNA-seq"],
+        description="Artificial nearest-neighbor doublet detection",
+        input_types=["count_matrix"],
+        output_types=["doublet_scores"],
+    ),
+    "SoupX": ToolSpec(
+        name="SoupX",
+        category=ToolCategory.QUALITY_CONTROL,
+        modalities=["scRNA-seq"],
+        description="Ambient RNA contamination estimation and removal",
+        input_types=["count_matrix", "raw_count_matrix"],
+        output_types=["corrected_matrix"],
+    ),
+    "DecontX": ToolSpec(
+        name="DecontX",
+        category=ToolCategory.QUALITY_CONTROL,
+        modalities=["scRNA-seq"],
+        description="Bayesian ambient RNA decontamination",
+        input_types=["count_matrix"],
+        output_types=["corrected_matrix"],
+    ),
+    # ── Simulation ──
+    "Splatter": ToolSpec(
+        name="Splatter",
+        category=ToolCategory.SIMULATION,
+        modalities=["scRNA-seq"],
+        description="Flexible scRNA-seq data simulation framework",
+        input_types=["simulation_params"],
+        output_types=["simulated_count_matrix"],
+    ),
+}
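A note for readers of this diff: `TOOL_REGISTRY` is intended to be filtered by category, modality, and resource flags. A minimal standalone sketch of that lookup pattern (plain dataclasses instead of the pydantic models above, and only two illustrative entries rather than the full registry):

```python
from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class ToolSpec:
    # Trimmed-down stand-in for the pydantic ToolSpec in this diff
    name: str
    category: str
    modalities: List[str] = field(default_factory=list)
    requires_gpu: bool = False

TOOL_REGISTRY: Dict[str, ToolSpec] = {
    "DESeq2": ToolSpec("DESeq2", "differential_expression", ["bulk_rna_seq", "scRNA-seq"]),
    "scVI": ToolSpec("scVI", "dimensionality_reduction", ["scRNA-seq"], requires_gpu=True),
}

def tools_for(category: str, modality: str, gpu_available: bool = True) -> List[str]:
    """Names of registered tools matching a category/modality, honoring GPU limits."""
    return [
        spec.name
        for spec in TOOL_REGISTRY.values()
        if spec.category == category
        and modality in spec.modalities
        and (gpu_available or not spec.requires_gpu)
    ]

print(tools_for("differential_expression", "scRNA-seq"))  # → ['DESeq2']
```

The `tools_for` helper is hypothetical (it does not appear in this diff); it only illustrates how a simulator or agent could narrow the registry to tools that are valid for the current step.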
+
+
+class Modality(str, Enum):
+    SCRNA_SEQ = "scRNA-seq"
+    SCATAC_SEQ = "scATAC-seq"
+    CITE_SEQ = "CITE-seq"
+    SPATIAL_TRANSCRIPTOMICS = "spatial_transcriptomics"
+    BULK_RNA_SEQ = "bulk_rna_seq"
+    SCRNA_MULTIOME = "scMultiome"
+    PERTURB_SEQ = "Perturb-seq"
+    CROP_SEQ = "CROP-seq"
+    SMART_SEQ2 = "Smart-seq2"
+    SLIDE_SEQ = "Slide-seq"
+    MERFISH = "MERFISH"
+    SEQFISH = "seqFISH"
+    PATCH_SEQ = "Patch-seq"
+    SHARE_SEQ = "SHARE-seq"
+    SNARE_SEQ = "SNARE-seq"
+    SC_HI_C = "scHi-C"
+    SCBS_SEQ = "scBS-seq"
+    SCNMT_SEQ = "scNMT-seq"
+
+
+class ModalitySpec(BaseModel):
+    """Registry entry for a single-cell or bulk assay modality."""
+
+    name: str
+    modality: Modality
+    measurement: str = ""
+    resolution: str = "single-cell"
+    multiplexable: bool = False
+    typical_cells: str = "1k-20k"
+    typical_cost_per_sample_usd: float = 5000.0
+    compatible_tools: List[str] = Field(default_factory=list)
+    description: str = ""
+
+
+MODALITY_REGISTRY: Dict[str, ModalitySpec] = {
+    "scRNA-seq": ModalitySpec(
+        name="scRNA-seq",
+        modality=Modality.SCRNA_SEQ,
+        measurement="mRNA transcripts",
+        typical_cells="5k-20k",
+        typical_cost_per_sample_usd=5000.0,
+        compatible_tools=[
+            "CellRanger", "STARsolo", "kallisto_bustools", "Scanpy", "Seurat",
+            "scVI", "Leiden", "DESeq2", "MAST", "Monocle3", "scVelo", "SCENIC",
+            "CellChat", "GSEA", "celltypist", "scrublet",
+        ],
+        description="Droplet-based single-cell RNA sequencing (e.g. 10x Chromium)",
+    ),
+    "scATAC-seq": ModalitySpec(
+        name="scATAC-seq",
+        modality=Modality.SCATAC_SEQ,
+        measurement="open chromatin regions",
+        typical_cells="5k-15k",
+        typical_cost_per_sample_usd=6000.0,
+        compatible_tools=[
+            "CellRanger", "ArchR", "Signac", "chromVAR", "Scanpy", "Leiden",
+        ],
+        description="Single-cell Assay for Transposase-Accessible Chromatin",
+    ),
+    "CITE-seq": ModalitySpec(
+        name="CITE-seq",
+        modality=Modality.CITE_SEQ,
+        measurement="mRNA + surface proteins (ADT)",
+        multiplexable=True,
+        typical_cells="5k-20k",
+        typical_cost_per_sample_usd=8000.0,
+        compatible_tools=[
+            "CellRanger", "Seurat", "WNN", "MOFA+", "Scanpy", "Leiden",
+        ],
+        description="Cellular Indexing of Transcriptomes and Epitopes by Sequencing",
+    ),
+    "spatial_transcriptomics": ModalitySpec(
+        name="spatial_transcriptomics",
+        modality=Modality.SPATIAL_TRANSCRIPTOMICS,
+        measurement="spatially resolved transcripts",
+        resolution="spot (55µm) or subcellular",
+        typical_cells="1k-10k spots",
+        typical_cost_per_sample_usd=7000.0,
+        compatible_tools=[
+            "spaceranger", "squidpy", "cell2location", "BANKSY", "Scanpy", "Seurat",
+        ],
+        description="Spatially resolved transcriptomics (Visium, MERFISH, Slide-seq, etc.)",
+    ),
+    "bulk_rna_seq": ModalitySpec(
+        name="bulk_rna_seq",
+        modality=Modality.BULK_RNA_SEQ,
+        measurement="aggregate mRNA across cells",
+        resolution="bulk",
+        typical_cells="N/A",
+        typical_cost_per_sample_usd=500.0,
+        compatible_tools=["DESeq2", "edgeR", "GSEA", "clusterProfiler"],
+        description="Standard bulk RNA sequencing",
+    ),
+    "scMultiome": ModalitySpec(
+        name="scMultiome",
+        modality=Modality.SCRNA_MULTIOME,
+        measurement="mRNA + open chromatin (joint)",
+        typical_cells="5k-15k",
+        typical_cost_per_sample_usd=10000.0,
+        compatible_tools=[
+            "CellRanger", "ArchR", "Signac", "Seurat", "MOFA+", "CellOracle",
+        ],
+        description="10x Multiome (joint scRNA + scATAC from same cell)",
+    ),
+    "Perturb-seq": ModalitySpec(
+        name="Perturb-seq",
+        modality=Modality.PERTURB_SEQ,
+        measurement="mRNA + CRISPR guide assignment",
+        multiplexable=True,
+        typical_cells="10k-100k",
+        typical_cost_per_sample_usd=15000.0,
+        compatible_tools=[
+            "CellRanger", "Scanpy", "Seurat", "Mixscape", "MIMOSCA",
+        ],
+        description="Pooled CRISPR screens with single-cell RNA readout",
+    ),
+    "CROP-seq": ModalitySpec(
+        name="CROP-seq",
+        modality=Modality.CROP_SEQ,
+        measurement="mRNA + CRISPR guide assignment",
+        multiplexable=True,
+        typical_cells="10k-50k",
+        typical_cost_per_sample_usd=12000.0,
+        compatible_tools=[
+            "CellRanger", "Scanpy", "Seurat", "Mixscape", "MIMOSCA",
+        ],
+        description="CRISPR dropout screen with single-cell RNA readout",
+    ),
+    "Smart-seq2": ModalitySpec(
+        name="Smart-seq2",
+        modality=Modality.SMART_SEQ2,
+        measurement="full-length mRNA transcripts",
+        typical_cells="100-1000",
+        typical_cost_per_sample_usd=10000.0,
+        compatible_tools=["Scanpy", "Seurat", "DESeq2", "MAST", "Monocle3"],
+        description="Plate-based full-length scRNA-seq with high sensitivity",
+    ),
+    "MERFISH": ModalitySpec(
+        name="MERFISH",
+        modality=Modality.MERFISH,
+        measurement="in situ mRNA (imaging-based)",
+        resolution="subcellular",
+        typical_cells="10k-1M",
+        typical_cost_per_sample_usd=20000.0,
+        compatible_tools=["squidpy", "Scanpy", "BANKSY"],
+        description="Multiplexed Error-Robust FISH for spatial transcriptomics",
+    ),
+    "Slide-seq": ModalitySpec(
+        name="Slide-seq",
+        modality=Modality.SLIDE_SEQ,
+        measurement="spatially resolved mRNA (bead array)",
+        resolution="10µm",
+        typical_cells="10k-50k beads",
+        typical_cost_per_sample_usd=8000.0,
+        compatible_tools=["squidpy", "cell2location", "Scanpy"],
+        description="Near-cellular spatial transcriptomics on bead arrays",
+    ),
+    "Patch-seq": ModalitySpec(
+        name="Patch-seq",
+        modality=Modality.PATCH_SEQ,
+        measurement="mRNA + electrophysiology + morphology",
+        typical_cells="10-500",
+        typical_cost_per_sample_usd=50000.0,
+        compatible_tools=["Scanpy", "Seurat"],
+        description="Combined patch-clamp electrophysiology and scRNA-seq",
+    ),
+    "scHi-C": ModalitySpec(
+        name="scHi-C",
+        modality=Modality.SC_HI_C,
+        measurement="3D chromatin contacts",
+        typical_cells="1k-10k",
+        typical_cost_per_sample_usd=15000.0,
+        compatible_tools=["Scanpy"],
+        description="Single-cell chromosome conformation capture",
+    ),
+    "scBS-seq": ModalitySpec(
+        name="scBS-seq",
+        modality=Modality.SCBS_SEQ,
+        measurement="DNA methylation (CpG)",
+        typical_cells="100-5k",
+        typical_cost_per_sample_usd=12000.0,
+        compatible_tools=["Scanpy"],
+        description="Single-cell bisulfite sequencing for DNA methylation",
+    ),
+    "scNMT-seq": ModalitySpec(
+        name="scNMT-seq",
+        modality=Modality.SCNMT_SEQ,
+        measurement="nucleosome + methylation + transcription (joint)",
+        typical_cells="100-1k",
+        typical_cost_per_sample_usd=25000.0,
+        compatible_tools=["MOFA+", "Scanpy"],
+        description="Joint single-cell nucleosome, methylation, and transcription",
+    ),
+}
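Because `compatible_tools` entries are plain strings, the tool and modality registries can silently drift apart. A small consistency check of the kind the module could run at import time, sketched standalone here with toy stand-in registries (the `dangling_tool_refs` helper is hypothetical, not part of this diff):

```python
from typing import Dict, List, Set

# Toy stand-ins for TOOL_REGISTRY keys and MODALITY_REGISTRY.compatible_tools
TOOL_NAMES: Set[str] = {"Scanpy", "Seurat", "DESeq2"}
MODALITY_COMPAT: Dict[str, List[str]] = {
    "scRNA-seq": ["Scanpy", "Seurat", "DESeq2"],
    "bulk_rna_seq": ["DESeq2", "edgeR"],  # "edgeR" deliberately unregistered here
}

def dangling_tool_refs(
    compat: Dict[str, List[str]], tools: Set[str]
) -> Dict[str, List[str]]:
    """Map each modality to compatible_tools entries missing from the tool registry."""
    missing: Dict[str, List[str]] = {}
    for modality, tool_list in compat.items():
        unknown = [t for t in tool_list if t not in tools]
        if unknown:
            missing[modality] = unknown
    return missing

print(dangling_tool_refs(MODALITY_COMPAT, TOOL_NAMES))  # → {'bulk_rna_seq': ['edgeR']}
```

An empty result means every cross-registry reference resolves; anything else points at a typo or a tool that was removed without updating the modality entries.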
+
+
+class AssayCategory(str, Enum):
+    SEQUENCING = "sequencing"
+    IMAGING = "imaging"
+    PERTURBATION = "perturbation"
+    FUNCTIONAL = "functional"
+    EPIGENOMICS = "epigenomics"
+    PROTEOMICS = "proteomics"
+    METABOLOMICS = "metabolomics"
+
+
+class AssaySpec(BaseModel):
+    """Registry entry for a laboratory assay or protocol."""
+
+    name: str
+    category: AssayCategory
+    modalities: List[str] = Field(default_factory=list)
+    description: str = ""
+    typical_duration_days: float = 1.0
+    typical_cost_usd: float = 1000.0
+    requires_live_cells: bool = False
+    requires_fresh_tissue: bool = False
+    throughput: str = "medium"
+    outputs: List[str] = Field(default_factory=list)
+
+
+ASSAY_REGISTRY: Dict[str, AssaySpec] = {
+    "10x_chromium": AssaySpec(
+        name="10x_chromium",
+        category=AssayCategory.SEQUENCING,
+        modalities=["scRNA-seq", "scATAC-seq", "CITE-seq", "scMultiome"],
+        description="10x Genomics Chromium droplet-based single-cell partitioning",
+        typical_duration_days=2.0,
+        typical_cost_usd=5000.0,
+        requires_live_cells=True,
+        throughput="high (500-20k cells)",
+        outputs=["fastq", "count_matrix"],
+    ),
+    "smart-seq2": AssaySpec(
+        name="smart-seq2",
+        category=AssayCategory.SEQUENCING,
+        modalities=["Smart-seq2"],
+        description="Plate-based full-length cDNA scRNA-seq",
+        typical_duration_days=3.0,
+        typical_cost_usd=10000.0,
+        requires_live_cells=True,
+        throughput="low (96-384 cells)",
+        outputs=["fastq", "count_matrix"],
+    ),
+    "smart-seq3": AssaySpec(
+        name="smart-seq3",
+        category=AssayCategory.SEQUENCING,
+        modalities=["Smart-seq2"],
+        description="Improved full-length scRNA-seq with UMIs",
+        typical_duration_days=3.0,
+        typical_cost_usd=10000.0,
+        requires_live_cells=True,
+        throughput="low (96-384 cells)",
+        outputs=["fastq", "count_matrix"],
+    ),
+    "bulk_rna_seq": AssaySpec(
+        name="bulk_rna_seq",
+        category=AssayCategory.SEQUENCING,
+        modalities=["bulk_rna_seq"],
+        description="Standard bulk RNA sequencing with poly-A or ribo-depletion",
+        typical_duration_days=3.0,
+        typical_cost_usd=500.0,
+        throughput="high",
+        outputs=["fastq", "count_matrix"],
+    ),
+    "atac-seq": AssaySpec(
+        name="atac-seq",
+        category=AssayCategory.EPIGENOMICS,
+        modalities=["scATAC-seq"],
+        description="Assay for Transposase-Accessible Chromatin using sequencing",
+        typical_duration_days=2.0,
+        typical_cost_usd=6000.0,
+        requires_live_cells=True,
+        outputs=["fastq", "fragments", "peak_matrix"],
+    ),
+    "cite-seq": AssaySpec(
+        name="cite-seq",
+        category=AssayCategory.PROTEOMICS,
+        modalities=["CITE-seq"],
+        description="Simultaneous RNA + surface protein via DNA-barcoded antibodies",
+        typical_duration_days=2.0,
+        typical_cost_usd=8000.0,
+        requires_live_cells=True,
+        throughput="high (5k-20k cells)",
+        outputs=["fastq", "count_matrix", "adt_matrix"],
+    ),
+    "10x_multiome": AssaySpec(
+        name="10x_multiome",
+        category=AssayCategory.SEQUENCING,
+        modalities=["scMultiome"],
+        description="Joint scRNA-seq + scATAC-seq from the same cell",
+        typical_duration_days=2.0,
+        typical_cost_usd=10000.0,
+        requires_live_cells=True,
+        throughput="high (5k-15k cells)",
+        outputs=["fastq", "count_matrix", "fragments"],
+    ),
+    "visium": AssaySpec(
+        name="visium",
+        category=AssayCategory.SEQUENCING,
+        modalities=["spatial_transcriptomics"],
+        description="10x Visium spatially barcoded capture on tissue sections",
+        typical_duration_days=3.0,
+        typical_cost_usd=7000.0,
+        requires_fresh_tissue=True,
+        throughput="medium (1k-5k spots)",
+        outputs=["fastq", "count_matrix", "spatial_coords", "image"],
+    ),
+    "visium_hd": AssaySpec(
+        name="visium_hd",
+        category=AssayCategory.SEQUENCING,
+        modalities=["spatial_transcriptomics"],
+        description="High-definition Visium with 2µm bin resolution",
+        typical_duration_days=3.0,
+        typical_cost_usd=10000.0,
+        requires_fresh_tissue=True,
+        throughput="high",
+        outputs=["fastq", "count_matrix", "spatial_coords", "image"],
+    ),
+    "merfish": AssaySpec(
+        name="merfish",
+        category=AssayCategory.IMAGING,
+        modalities=["MERFISH"],
+        description="Multiplexed Error-Robust FISH imaging-based spatial",
+        typical_duration_days=5.0,
+        typical_cost_usd=20000.0,
+        requires_fresh_tissue=True,
+        throughput="high (100-1000 genes, millions of transcripts)",
+        outputs=["transcript_coords", "cell_segmentation"],
+    ),
+    "seqfish_plus": AssaySpec(
+        name="seqfish_plus",
+        category=AssayCategory.IMAGING,
+        modalities=["seqFISH"],
+        description="Sequential FISH for imaging-based spatial transcriptomics",
+        typical_duration_days=5.0,
+        typical_cost_usd=15000.0,
+        requires_fresh_tissue=True,
+        outputs=["transcript_coords"],
+    ),
+    "slide-seq": AssaySpec(
+        name="slide-seq",
+        category=AssayCategory.SEQUENCING,
+        modalities=["Slide-seq"],
+        description="Near-cellular spatial transcriptomics on bead arrays",
+        typical_duration_days=3.0,
+        typical_cost_usd=8000.0,
+        requires_fresh_tissue=True,
+        outputs=["count_matrix", "spatial_coords"],
+    ),
+    "perturb-seq": AssaySpec(
+        name="perturb-seq",
+        category=AssayCategory.PERTURBATION,
+        modalities=["Perturb-seq"],
+        description="Pooled CRISPR screen + scRNA-seq readout",
+        typical_duration_days=14.0,
+        typical_cost_usd=15000.0,
+        requires_live_cells=True,
+        throughput="high (10k-100k cells)",
+        outputs=["fastq", "count_matrix", "guide_assignments"],
+    ),
+    "crop-seq": AssaySpec(
+        name="crop-seq",
+        category=AssayCategory.PERTURBATION,
+        modalities=["CROP-seq"],
+        description="CRISPR dropout screening with scRNA-seq readout",
+        typical_duration_days=14.0,
+        typical_cost_usd=12000.0,
+        requires_live_cells=True,
+        throughput="high (10k-50k cells)",
+        outputs=["fastq", "count_matrix", "guide_assignments"],
+    ),
+    "patch-seq": AssaySpec(
+        name="patch-seq",
+        category=AssayCategory.FUNCTIONAL,
+        modalities=["Patch-seq"],
+        description="Patch-clamp electrophysiology + scRNA-seq on same neuron",
+        typical_duration_days=7.0,
+        typical_cost_usd=50000.0,
+        requires_live_cells=True,
+        throughput="very low (10-100 cells)",
+        outputs=["fastq", "count_matrix", "ephys_trace", "morphology"],
+    ),
+    "sc_hi_c": AssaySpec(
+        name="sc_hi_c",
973
+ category=AssayCategory.EPIGENOMICS,
974
+ modalities=["scHi-C"],
975
+ description="Single-cell chromosome conformation capture",
976
+ typical_duration_days=5.0,
977
+ typical_cost_usd=15000.0,
978
+ outputs=["contact_matrix"],
979
+ ),
980
+ "sc_bisulfite": AssaySpec(
981
+ name="sc_bisulfite",
982
+ category=AssayCategory.EPIGENOMICS,
983
+ modalities=["scBS-seq"],
984
+ description="Single-cell bisulfite sequencing for DNA methylation profiling",
985
+ typical_duration_days=5.0,
986
+ typical_cost_usd=12000.0,
987
+ outputs=["methylation_matrix"],
988
+ ),
989
+ "sc_nmt_seq": AssaySpec(
990
+ name="sc_nmt_seq",
991
+ category=AssayCategory.EPIGENOMICS,
992
+ modalities=["scNMT-seq"],
993
+ description="Joint nucleosome occupancy, methylation, and transcription",
994
+ typical_duration_days=7.0,
995
+ typical_cost_usd=25000.0,
996
+ requires_live_cells=True,
997
+ throughput="low (100-1k cells)",
998
+ outputs=["count_matrix", "methylation_matrix", "accessibility_matrix"],
999
+ ),
1000
+ "flow_cytometry": AssaySpec(
1001
+ name="flow_cytometry",
1002
+ category=AssayCategory.FUNCTIONAL,
1003
+ modalities=[],
1004
+ description="Fluorescence-based cell sorting and phenotyping",
1005
+ typical_duration_days=1.0,
1006
+ typical_cost_usd=500.0,
1007
+ requires_live_cells=True,
1008
+ throughput="very high (millions of cells)",
1009
+ outputs=["cell_counts", "sorted_cells"],
1010
+ ),
1011
+ "mass_cytometry_CyTOF": AssaySpec(
1012
+ name="mass_cytometry_CyTOF",
1013
+ category=AssayCategory.PROTEOMICS,
1014
+ modalities=[],
1015
+ description="Mass-tag cytometry for 40+ protein markers per cell",
1016
+ typical_duration_days=2.0,
1017
+ typical_cost_usd=3000.0,
1018
+ requires_live_cells=True,
1019
+ throughput="high (100k-1M cells)",
1020
+ outputs=["protein_expression_matrix"],
1021
+ ),
1022
+ "western_blot": AssaySpec(
1023
+ name="western_blot",
1024
+ category=AssayCategory.PROTEOMICS,
1025
+ modalities=[],
1026
+ description="Protein detection and semi-quantification by size separation",
1027
+ typical_duration_days=2.0,
1028
+ typical_cost_usd=200.0,
1029
+ outputs=["band_image", "relative_quantification"],
1030
+ ),
1031
+ "qPCR": AssaySpec(
1032
+ name="qPCR",
1033
+ category=AssayCategory.FUNCTIONAL,
1034
+ modalities=[],
1035
+ description="Quantitative PCR for targeted gene expression validation",
1036
+ typical_duration_days=1.0,
1037
+ typical_cost_usd=100.0,
1038
+ throughput="low (target genes)",
1039
+ outputs=["ct_values", "fold_change"],
1040
+ ),
1041
+ "immunofluorescence": AssaySpec(
1042
+ name="immunofluorescence",
1043
+ category=AssayCategory.IMAGING,
1044
+ modalities=[],
1045
+ description="Antibody-based fluorescence imaging of proteins in situ",
1046
+ typical_duration_days=2.0,
1047
+ typical_cost_usd=500.0,
1048
+ outputs=["fluorescence_image"],
1049
+ ),
1050
+ "elisa": AssaySpec(
1051
+ name="elisa",
1052
+ category=AssayCategory.PROTEOMICS,
1053
+ modalities=[],
1054
+ description="Enzyme-linked immunosorbent assay for secreted protein quantification",
1055
+ typical_duration_days=1.0,
1056
+ typical_cost_usd=300.0,
1057
+ throughput="medium (96-384 well)",
1058
+ outputs=["protein_concentration"],
1059
+ ),
1060
+ "cell_viability_assay": AssaySpec(
1061
+ name="cell_viability_assay",
1062
+ category=AssayCategory.FUNCTIONAL,
1063
+ modalities=[],
1064
+ description="MTT/CellTiter-Glo viability and proliferation measurement",
1065
+ typical_duration_days=1.0,
1066
+ typical_cost_usd=200.0,
1067
+ requires_live_cells=True,
1068
+ throughput="high (96-384 well)",
1069
+ outputs=["viability_scores"],
1070
+ ),
1071
+ }
1072
+
1073
+
1074
+ # ── Registry helper functions ──────────────────────────────────────────────
1075
+
1076
+
1077
+ def tools_for_modality(modality: str) -> List[ToolSpec]:
1078
+ """Return all registered tools compatible with a given modality."""
1079
+ return [t for t in TOOL_REGISTRY.values() if modality in t.modalities]
1080
+
1081
+
1082
+ def assays_for_modality(modality: str) -> List[AssaySpec]:
1083
+ """Return all registered assays that produce a given modality."""
1084
+ return [a for a in ASSAY_REGISTRY.values() if modality in a.modalities]
1085
+
1086
+
1087
+ def tools_by_category(category: ToolCategory) -> List[ToolSpec]:
1088
+ """Return all registered tools in a given category."""
1089
+ return [t for t in TOOL_REGISTRY.values() if t.category == category]
1090
+
1091
+
1092
+ # ── Sub-agents ───────────────────────────────────────────────────���─────────
1093
+
1094
+
1095
  class SubagentType(str, Enum):
1096
  WET_LAB_PLANNER = "wet_lab_planner"
1097
  COMPUTATIONAL_ANALYST = "computational_analyst"
 
     """

     action_type: ActionType = Field(
+        ...,
+        description=(
+            "Discrete simulator step type. The environment enforces scientific "
+            "prerequisites between steps, so actions should follow a valid "
+            "pipeline order."
+        ),
     )
     input_targets: List[str] = Field(
         default_factory=list,
+        description=(
+            "Optional references to prior samples, outputs, or artifacts that "
+            "this step consumes."
+        ),
     )
     method: Optional[str] = Field(
+        None,
+        description=(
+            "Optional named tool or protocol (for example 'Seurat' or "
+            "'CellRanger'). Prefer methods compatible with the current "
+            "modality and available tool list because tool choice can change "
+            "runtime, cost, and scientific fit."
+        ),
     )
     parameters: Dict[str, Any] = Field(
+        default_factory=dict,
+        description=(
+            "Action-specific settings such as comparison labels, perturbation "
+            "targets, or analysis options. Use only parameters that materially "
+            "change the scientific step."
+        ),
     )
     expected_output_type: Optional[str] = Field(
+        None,
+        description=(
+            "Optional expected artifact or summary that should result from the "
+            "step, such as a count matrix, QC report, DE table, or validation "
+            "result."
+        ),
     )
     justification: Optional[str] = Field(
+        None,
+        description=(
+            "Short scientific rationale explaining why this is the right next "
+            "step in the current environment state."
+        ),
     )
     invoked_subagent: Optional[SubagentType] = Field(
         None, description="Sub-agent to delegate to, if any"
     )
     tool_call_spec: Optional[Dict[str, Any]] = Field(
+        None,
+        description=(
+            "Optional structured tool invocation payload when the action needs "
+            "a more explicit tool execution plan."
+        ),
     )
     confidence: float = Field(
         0.5, ge=0.0, le=1.0, description="Agent confidence in this step"

  organism: str = "human"
1268
  tissue: str = "blood"
1269
  conditions: List[str] = Field(default_factory=list)
1270
+ available_assays: List[str] = Field(
1271
+ default_factory=lambda: list(ASSAY_REGISTRY.keys()),
1272
+ description=(
1273
+ "Assays that are scientifically compatible with this task's "
1274
+ "modality. These are the relevant assay choices for the episode, "
1275
+ "not an unrestricted catalog."
1276
+ ),
1277
+ )
1278
+ available_tools: List[str] = Field(
1279
+ default_factory=lambda: list(TOOL_REGISTRY.keys()),
1280
+ description=(
1281
+ "Tools filtered to those compatible with the current task "
1282
+ "modality. The agent should treat this list as the preferred tool "
1283
+ "set for the episode."
1284
+ ),
1285
+ )
1286
  budget_limit: float = 100_000.0
1287
  time_limit_days: float = 180.0
1288
  prior_observations: List[str] = Field(default_factory=list)
 


 class ConclusionClaim(BaseModel):
+    claim: str = ""
+    top_markers: List[str] = Field(default_factory=list)
+    causal_mechanisms: List[str] = Field(default_factory=list)
+    predicted_pathways: Dict[str, float] = Field(default_factory=dict)
     evidence_steps: List[int] = Field(default_factory=list)
     confidence: float = Field(0.5, ge=0.0, le=1.0)
     claim_type: str = "correlational"

     task: TaskSpec = Field(default_factory=TaskSpec)
     step_index: int = 0
     pipeline_history: List[PipelineStepRecord] = Field(default_factory=list)
+    available_assays: List[str] = Field(
+        default_factory=list,
+        description=(
+            "Episode-specific assay choices already filtered to the current "
+            "modality and task context."
+        ),
+    )
+    available_tools: List[str] = Field(
+        default_factory=list,
+        description=(
+            "Episode-specific compatible tools. These are the methods the "
+            "agent should prefer instead of inventing incompatible tools."
+        ),
+    )
+    resource_usage: ResourceUsage = Field(
+        default_factory=ResourceUsage,
+        description=(
+            "Running budget, time, and compute usage after previous actions."
+        ),
+    )
     latest_output: Optional[IntermediateOutput] = None
     all_outputs: List[IntermediateOutput] = Field(default_factory=list)
     discovered_markers: List[str] = Field(default_factory=list)

     conclusions: List[ConclusionClaim] = Field(default_factory=list)
     rule_violations: List[str] = Field(default_factory=list)
     step_reward_breakdown: Dict[str, float] = Field(default_factory=dict)
+
+
+AGENT_ACTION_GUIDANCE: Dict[ActionType, str] = {
+    ActionType.COLLECT_SAMPLE: (
+        "Wet-lab entry point. One successful collection usually provides enough "
+        "material to continue unless the output shows poor yield or quality."
+    ),
+    ActionType.SELECT_COHORT: (
+        "Use when subject stratification is part of the scientific question "
+        "before downstream experimental work."
+    ),
+    ActionType.PREPARE_LIBRARY: (
+        "Requires collected samples and converts biological material into "
+        "sequence-ready libraries."
+    ),
+    ActionType.CULTURE_CELLS: (
+        "Requires collected samples and adds substantial time; use only when "
+        "live-cell expansion or later perturbation is needed."
+    ),
+    ActionType.PERTURB_GENE: (
+        "Requires samples. Use for causal tests, not as a default discovery "
+        "step."
+    ),
+    ActionType.PERTURB_COMPOUND: (
+        "Requires samples. Best for mechanistic follow-up or treatment "
+        "response questions."
+    ),
+    ActionType.SEQUENCE_CELLS: (
+        "Requires prepared libraries and produces the raw sequencing-derived "
+        "artifacts used by downstream QC and analysis."
+    ),
+    ActionType.RUN_QC: (
+        "Requires sequencing and returns summarized quality metrics such as "
+        "doublets, mitochondrial fraction, and ambient RNA."
+    ),
+    ActionType.FILTER_DATA: (
+        "Requires QC and removes poor-quality cells, changing downstream cell "
+        "counts and data retention."
+    ),
+    ActionType.NORMALIZE_DATA: (
+        "Requires filtered data and unlocks clustering, differential "
+        "expression, trajectory, and network analyses."
+    ),
+    ActionType.INTEGRATE_BATCHES: (
+        "Requires normalized data. Use when batch effects are likely to "
+        "confound interpretation; it is not always necessary."
+    ),
+    ActionType.CLUSTER_CELLS: (
+        "Requires normalized data and identifies cell populations or states "
+        "for downstream interpretation."
+    ),
+    ActionType.DIFFERENTIAL_EXPRESSION: (
+        "Requires normalized data and is the main route to candidate genes "
+        "for pathway analysis and marker selection."
+    ),
+    ActionType.TRAJECTORY_ANALYSIS: (
+        "Requires normalized data and is most useful when lineage progression "
+        "or pseudotime is central to the task."
+    ),
+    ActionType.PATHWAY_ENRICHMENT: (
+        "Requires differential expression. Results are less reliable without a "
+        "strong DE gene list."
+    ),
+    ActionType.REGULATORY_NETWORK_INFERENCE: (
+        "Requires normalized data and is most helpful once cell states or "
+        "trajectories are already characterized."
+    ),
+    ActionType.MARKER_SELECTION: (
+        "Requires differential expression and turns candidate genes into a "
+        "short list for validation."
+    ),
+    ActionType.VALIDATE_MARKER: (
+        "Requires discovered markers and is an expensive wet-lab confirmation "
+        "step that should follow strong computational evidence."
+    ),
+    ActionType.DESIGN_FOLLOWUP: (
+        "Use to propose targeted next experiments once remaining uncertainty "
+        "is clear."
+    ),
+    ActionType.REQUEST_SUBAGENT_REVIEW: (
+        "Use for critique or planning support, not as a substitute for "
+        "missing experimental evidence."
+    ),
+    ActionType.SYNTHESIZE_CONCLUSION: (
+        "Use once the evidence is sufficient. Do not spend budget on redundant "
+        "steps just because more actions are possible."
+    ),
+}
+
+AGENT_ENVIRONMENT_RULES: List[str] = [
+    (
+        "Each successful action already returns summarized scientific evidence, "
+        "so repeated sampling or repeated analysis is not the default."
+    ),
+    (
+        "Repeat a step only when the task demands it or when prior outputs show "
+        "poor quality, insufficient yield, unresolved batch effects, or another "
+        "clear failure mode."
+    ),
+    (
+        "The available tool and assay lists are already filtered to the current "
+        "task modality, so prefer them over inventing incompatible methods."
+    ),
+    (
+        "Hard scientific prerequisites are enforced by the environment, so "
+        "invalid pipeline orderings will be blocked."
+    ),
+]
+
+_TOOL_CATEGORY_AGENT_NOTES: Dict[ToolCategory, str] = {
+    ToolCategory.ALIGNMENT: (
+        "Best immediately after sequencing to turn FASTQ-like inputs into "
+        "count-style matrices for downstream analysis."
+    ),
+    ToolCategory.PREPROCESSING: (
+        "Useful for general single-cell data handling before specialized "
+        "downstream analyses."
+    ),
+    ToolCategory.NORMALIZATION: (
+        "Applies after filtering to produce normalized matrices for downstream "
+        "modeling."
+    ),
+    ToolCategory.DIMENSIONALITY_REDUCTION: (
+        "Builds latent embeddings that support clustering or trajectory work."
+    ),
+    ToolCategory.CLUSTERING: (
+        "Best once data are normalized and the goal is to resolve cell states "
+        "or populations."
+    ),
+    ToolCategory.DIFFERENTIAL_EXPRESSION: (
+        "Tests contrasts and produces ranked genes for biological "
+        "interpretation."
+    ),
+    ToolCategory.TRAJECTORY: (
+        "Useful when the task asks about developmental progression, state "
+        "transitions, or pseudotime."
+    ),
+    ToolCategory.GENE_REGULATORY_NETWORK: (
+        "Most useful after normalized data and some cell-state structure are "
+        "already established."
+    ),
+    ToolCategory.GENE_SET_ANALYSIS: (
+        "Best after differential expression to interpret gene lists at the "
+        "pathway level."
+    ),
+    ToolCategory.BATCH_CORRECTION: (
+        "Use when batch effects would confound interpretation; unnecessary use "
+        "adds extra steps."
+    ),
+    ToolCategory.MULTIMODAL_INTEGRATION: (
+        "Useful only when combining modalities or batches is part of the "
+        "scientific question."
+    ),
+    ToolCategory.QUALITY_CONTROL: (
+        "Helps identify low-quality cells or technical artifacts before "
+        "filtering."
+    ),
+    ToolCategory.CELL_TYPE_ANNOTATION: (
+        "Best after clustering when assigning biological identities to groups."
+    ),
+    ToolCategory.PERTURBATION_ANALYSIS: (
+        "Use when perturbations were actually applied and the goal is to model "
+        "their transcriptional effects."
+    ),
+    ToolCategory.SPATIAL: (
+        "Only useful when the modality includes spatial coordinates or tissue "
+        "context."
+    ),
+}
+
+
+def _format_currency(value: float) -> str:
+    return f"${value:,.0f}"
+
+
+def _format_runtime_hours(hours: float) -> str:
+    if hours < 1.0:
+        return f"{int(round(hours * 60))}m"
+    if float(hours).is_integer():
+        return f"{int(hours)}h"
+    return f"{hours:.1f}h"
+
+
+def describe_tool_for_agent(tool_name: str) -> str:
+    """Return a compact environment-aware tool description for prompts."""
+    tool = TOOL_REGISTRY.get(tool_name)
+    if tool is None:
+        return tool_name
+
+    parts = [f"{tool.name}: {tool.description}."]
+    if tool.input_types or tool.output_types:
+        inputs = ", ".join(tool.input_types) or "upstream artifacts"
+        outputs = ", ".join(tool.output_types) or "analysis artifacts"
+        parts.append(f"Consumes {inputs}; yields {outputs}.")
+
+    category_note = _TOOL_CATEGORY_AGENT_NOTES.get(tool.category)
+    if category_note:
+        parts.append(category_note)
+
+    resource_bits: List[str] = []
+    if tool.typical_cost_usd > 0:
+        resource_bits.append(_format_currency(tool.typical_cost_usd))
+    if tool.typical_runtime_hours > 0:
+        resource_bits.append(_format_runtime_hours(tool.typical_runtime_hours))
+    if tool.requires_gpu:
+        resource_bits.append("GPU")
+    if resource_bits:
+        parts.append(f"Typical resources: {', '.join(resource_bits)}.")
+
+    return " ".join(parts)
+
+
+def describe_assay_for_agent(assay_name: str) -> str:
+    """Return a compact environment-aware assay description for prompts."""
+    assay = ASSAY_REGISTRY.get(assay_name)
+    if assay is None:
+        return assay_name
+
+    parts = [f"{assay.name}: {assay.description}."]
+    if assay.outputs:
+        parts.append(f"Produces {', '.join(assay.outputs)}.")
+
+    requirements: List[str] = []
+    if assay.requires_live_cells:
+        requirements.append("live cells")
+    if assay.requires_fresh_tissue:
+        requirements.append("fresh tissue")
+    if requirements:
+        parts.append(f"Requires {' and '.join(requirements)}.")
+
+    parts.append(
+        "Typical resources: "
+        f"{_format_currency(assay.typical_cost_usd)}, "
+        f"{assay.typical_duration_days:.1f}d."
+    )
+    return " ".join(parts)
+
+
+def build_agent_system_prompt() -> str:
+    """Build the shared agent system prompt for training and inference."""
+    lines = [
+        "You are an expert biologist planning a single-cell experiment pipeline.",
+        "",
+        "At each turn you see the experiment state and must pick the next scientifically justified step.",
+        "",
+        "Environment-specific reasoning rules:",
+    ]
+    lines.extend(f" - {rule}" for rule in AGENT_ENVIRONMENT_RULES)
+    lines.append("")
+    lines.append("Action guidance:")
+    lines.extend(
+        f" - {action_type.value}: {AGENT_ACTION_GUIDANCE[action_type]}"
+        for action_type in ActionType
+    )
+    lines.extend([
+        "",
+        "Respond with ONLY valid JSON, nothing else:",
+        '{"action_type": "...", "method": null, "parameters": {}, "justification": "...", "confidence": 0.8}',
+        "",
+        "For synthesize_conclusion, use structured claims:",
+        '{"action_type": "synthesize_conclusion", "parameters": {"claims": [{"top_markers": ["GENE1", "GENE2"], "causal_mechanisms": ["mechanism description"], "predicted_pathways": {"pathway_name": 0.8}, "confidence": 0.8, "claim_type": "causal", "claim": "optional free text"}]}, "justification": "...", "confidence": 0.8}',
+    ])
+    return "\n".join(lines)
+
+
+def build_agent_observation_context(
+    obs: ExperimentObservation,
+    *,
+    max_tools: int = 6,
+    max_assays: int = 3,
+) -> str:
+    """Summarize modality-specific tool and assay context for the agent."""
+    sections: List[str] = []
+
+    modality_spec = MODALITY_REGISTRY.get(obs.task.modality)
+    if modality_spec is not None:
+        sections.append(
+            "Modality context: "
+            f"{modality_spec.name} measures {modality_spec.measurement} at "
+            f"{modality_spec.resolution} resolution; typical scale "
+            f"{modality_spec.typical_cells}."
+        )
+    else:
+        sections.append(f"Modality context: {obs.task.modality}.")
+
+    tool_names = list(dict.fromkeys(obs.available_tools or obs.task.available_tools))
+    if tool_names:
+        sections.append("Available tools (already filtered to this modality):")
+        for tool_name in tool_names[:max_tools]:
+            sections.append(f" - {describe_tool_for_agent(tool_name)}")
+        if len(tool_names) > max_tools:
+            remainder = ", ".join(tool_names[max_tools:max_tools + 6])
+            sections.append(
+                " - Additional compatible tools not shown in full: "
+                f"{remainder}"
+            )
+
+    assay_names = list(dict.fromkeys(obs.available_assays or obs.task.available_assays))
+    if assay_names:
+        sections.append("Available assays:")
+        for assay_name in assay_names[:max_assays]:
+            sections.append(f" - {describe_assay_for_agent(assay_name)}")
+        if len(assay_names) > max_assays:
+            remainder = ", ".join(assay_names[max_assays:max_assays + 4])
+            sections.append(
+                " - Additional compatible assays not shown in full: "
+                f"{remainder}"
+            )
+
+    return "\n".join(sections)
my_env/README.md ADDED
@@ -0,0 +1,255 @@
+---
+title: My Env Environment Server
+emoji: 🎤
+colorFrom: gray
+colorTo: indigo
+sdk: docker
+pinned: false
+app_port: 8000
+base_path: /web
+tags:
+  - openenv
+---
+
+# My Env Environment
+
+A simple test environment that echoes back messages. Perfect for testing the env APIs as well as demonstrating environment usage patterns.
+
+## Quick Start
+
+The simplest way to use the My Env environment is through the `MyEnv` class:
+
+```python
+from my_env import MyAction, MyEnv
+
+try:
+    # Create environment from Docker image
+    my_envenv = MyEnv.from_docker_image("my_env-env:latest")
+
+    # Reset
+    result = my_envenv.reset()
+    print(f"Reset: {result.observation.echoed_message}")
+
+    # Send multiple messages
+    messages = ["Hello, World!", "Testing echo", "Final message"]
+
+    for msg in messages:
+        result = my_envenv.step(MyAction(message=msg))
+        print(f"Sent: '{msg}'")
+        print(f"  → Echoed: '{result.observation.echoed_message}'")
+        print(f"  → Length: {result.observation.message_length}")
+        print(f"  → Reward: {result.reward}")
+
+finally:
+    # Always clean up
+    my_envenv.close()
+```
+
+That's it! The `MyEnv.from_docker_image()` method handles:
+- Starting the Docker container
+- Waiting for the server to be ready
+- Connecting to the environment
+- Container cleanup when you call `close()`
+
+## Building the Docker Image
+
+Before using the environment, you need to build the Docker image:
+
+```bash
+# From project root
+docker build -t my_env-env:latest -f server/Dockerfile .
+```
+
+## Deploying to Hugging Face Spaces
+
+You can easily deploy your OpenEnv environment to Hugging Face Spaces using the `openenv push` command:
+
+```bash
+# From the environment directory (where openenv.yaml is located)
+openenv push
+
+# Or specify options
+openenv push --namespace my-org --private
+```
+
+The `openenv push` command will:
+1. Validate that the directory is an OpenEnv environment (checks for `openenv.yaml`)
+2. Prepare a custom build for a Hugging Face Docker Space (enables the web interface)
+3. Upload to Hugging Face (ensuring you're logged in)
+
+### Prerequisites
+
+- Authenticate with Hugging Face: The command will prompt for login if not already authenticated
+
+### Options
+
+- `--directory`, `-d`: Directory containing the OpenEnv environment (defaults to current directory)
+- `--repo-id`, `-r`: Repository ID in format 'username/repo-name' (defaults to 'username/env-name' from openenv.yaml)
+- `--base-image`, `-b`: Base Docker image to use (overrides Dockerfile FROM)
+- `--private`: Deploy the space as private (default: public)
+
+### Examples
+
+```bash
+# Push to your personal namespace (defaults to username/env-name from openenv.yaml)
+openenv push
+
+# Push to a specific repository
+openenv push --repo-id my-org/my-env
+
+# Push with a custom base image
+openenv push --base-image ghcr.io/meta-pytorch/openenv-base:latest
+
+# Push as a private space
+openenv push --private
+
+# Combine options
+openenv push --repo-id my-org/my-env --base-image custom-base:latest --private
+```
+
+After deployment, your space will be available at:
+`https://huggingface.co/spaces/<repo-id>`
+
+The deployed space includes:
+- **Web Interface** at `/web` - Interactive UI for exploring the environment
+- **API Documentation** at `/docs` - Full OpenAPI/Swagger interface
+- **Health Check** at `/health` - Container health monitoring
+- **WebSocket** at `/ws` - Persistent session endpoint for low-latency interactions
118
+
119
+ ## Environment Details
120
+
121
+ ### Action
122
+ **MyAction**: Contains a single field
123
+ - `message` (str) - The message to echo back
124
+
125
+ ### Observation
126
+ **MyObservation**: Contains the echo response and metadata
127
+ - `echoed_message` (str) - The message echoed back
128
+ - `message_length` (int) - Length of the message
129
+ - `reward` (float) - Reward based on message length (length × 0.1)
130
+ - `done` (bool) - Always False for echo environment
131
+ - `metadata` (dict) - Additional info like step count
132
+
133
+ ### Reward
134
+ The reward is calculated as: `message_length × 0.1`
135
+ - "Hi" → reward: 0.2
136
+ - "Hello, World!" → reward: 1.3
137
+ - Empty message → reward: 0.0
138
+
139
+ ## Advanced Usage
140
+
141
+ ### Connecting to an Existing Server
142
+
143
+ If you already have a My Env environment server running, you can connect directly:
144
+
145
+ ```python
146
+ from my_env import MyEnv
147
+
148
+ # Connect to existing server
149
+ my_envenv = MyEnv(base_url="<ENV_HTTP_URL_HERE>")
150
+
151
+ # Use as normal
152
+ result = my_envenv.reset()
153
+ result = my_envenv.step(MyAction(message="Hello!"))
154
+ ```
155
+
156
+ Note: When connecting to an existing server, `my_envenv.close()` will NOT stop the server.
157
+
158
+ ### Using the Context Manager
159
+
160
+ The client supports context manager usage for automatic connection management:
161
+
162
+ ```python
163
+ from my_env import MyAction, MyEnv
164
+
165
+ # Connect with context manager (auto-connects and closes)
166
+ with MyEnv(base_url="http://localhost:8000") as env:
167
+ result = env.reset()
168
+ print(f"Reset: {result.observation.echoed_message}")
169
+ # Multiple steps with low latency
170
+ for msg in ["Hello", "World", "!"]:
171
+ result = env.step(MyAction(message=msg))
172
+ print(f"Echoed: {result.observation.echoed_message}")
173
+ ```
174
+
175
+ The client uses WebSocket connections for:
176
+ - **Lower latency**: No HTTP connection overhead per request
177
+ - **Persistent session**: Server maintains your environment state
178
+ - **Efficient for episodes**: Better for many sequential steps
179
+
180
+ ### Concurrent WebSocket Sessions
181
+
182
+ The server supports multiple concurrent WebSocket connections. To enable this,
183
+ modify `server/app.py` to use factory mode:
184
+
185
+ ```python
186
+ # In server/app.py - use factory mode for concurrent sessions
187
+ app = create_app(
188
+ MyEnvironment, # Pass class, not instance
189
+ MyAction,
190
+ MyObservation,
191
+ max_concurrent_envs=4, # Allow 4 concurrent sessions
192
+ )
193
+ ```
194
+
195
+ Then multiple clients can connect simultaneously:
196
+
197
+ ```python
198
+ from my_env import MyAction, MyEnv
199
+ from concurrent.futures import ThreadPoolExecutor
200
+
201
+ def run_episode(client_id: int):
202
+ with MyEnv(base_url="http://localhost:8000") as env:
203
+ result = env.reset()
204
+ for i in range(10):
205
+ result = env.step(MyAction(message=f"Client {client_id}, step {i}"))
206
+ return client_id, result.observation.message_length
207
+
208
+ # Run 4 episodes concurrently
209
+ with ThreadPoolExecutor(max_workers=4) as executor:
210
+ results = list(executor.map(run_episode, range(4)))
211
+ ```
212
+
213
+ ## Development & Testing
214
+
215
+ ### Direct Environment Testing
216
+
217
+ Test the environment logic directly without starting the HTTP server:
218
+
219
+ ```bash
220
+ # From the server directory
221
+ python3 server/my_env_environment.py
222
+ ```
223
+
224
+ This verifies that:
225
+ - Environment resets correctly
226
+ - Step executes actions properly
227
+ - State tracking works
228
+ - Rewards are calculated correctly
229
+
230
+ ### Running Locally
231
+
232
+ Run the server locally for development:
233
+
234
+ ```bash
235
+ uvicorn server.app:app --reload
236
+ ```
237
+
238
+ ## Project Structure
239
+
240
+ ```
241
+ my_env/
242
+ ├── .dockerignore # Docker build exclusions
243
+ ├── __init__.py # Module exports
244
+ ├── README.md # This file
245
+ ├── openenv.yaml # OpenEnv manifest
246
+ ├── pyproject.toml # Project metadata and dependencies
247
+ ├── uv.lock # Locked dependencies (generated)
248
+ ├── client.py # MyEnv client
249
+ ├── models.py # Action and Observation models
250
+ └── server/
251
+ ├── __init__.py # Server module exports
252
+ ├── my_env_environment.py # Core environment logic
253
+ ├── app.py # FastAPI application (HTTP + WebSocket endpoints)
254
+ └── Dockerfile # Container image definition
255
+ ```
my_env/__init__.py ADDED
@@ -0,0 +1,16 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """My Env Environment."""
+
+ from .client import MyEnv
+ from .models import MyAction, MyObservation
+
+ __all__ = [
+     "MyAction",
+     "MyObservation",
+     "MyEnv",
+ ]
my_env/client.py ADDED
@@ -0,0 +1,99 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """My Env Environment Client."""
+
+ from typing import Dict
+
+ from openenv.core.client_types import StepResult
+ from openenv.core.env_server.types import State
+ from openenv.core import EnvClient
+
+ from .models import MyAction, MyObservation
+
+
+ class MyEnv(
+     EnvClient[MyAction, MyObservation]
+ ):
+     """
+     Client for the My Env Environment.
+
+     This client maintains a persistent WebSocket connection to the environment server,
+     enabling efficient multi-step interactions with lower latency.
+     Each client instance has its own dedicated environment session on the server.
+
+     Example:
+         >>> # Connect to a running server
+         >>> with MyEnv(base_url="http://localhost:8000") as client:
+         ...     result = client.reset()
+         ...     print(result.observation.echoed_message)
+         ...
+         ...     result = client.step(MyAction(message="Hello!"))
+         ...     print(result.observation.echoed_message)
+
+     Example with Docker:
+         >>> # Automatically start container and connect
+         >>> client = MyEnv.from_docker_image("my_env-env:latest")
+         >>> try:
+         ...     result = client.reset()
+         ...     result = client.step(MyAction(message="Test"))
+         ... finally:
+         ...     client.close()
+     """
+
+     def _step_payload(self, action: MyAction) -> Dict:
+         """
+         Convert MyAction to JSON payload for step message.
+
+         Args:
+             action: MyAction instance
+
+         Returns:
+             Dictionary representation suitable for JSON encoding
+         """
+         return {
+             "message": action.message,
+         }
+
+     def _parse_result(self, payload: Dict) -> StepResult[MyObservation]:
+         """
+         Parse server response into StepResult[MyObservation].
+
+         Args:
+             payload: JSON response data from server
+
+         Returns:
+             StepResult with MyObservation
+         """
+         obs_data = payload.get("observation", {})
+         observation = MyObservation(
+             echoed_message=obs_data.get("echoed_message", ""),
+             message_length=obs_data.get("message_length", 0),
+             done=payload.get("done", False),
+             reward=payload.get("reward"),
+             metadata=obs_data.get("metadata", {}),
+         )
+
+         return StepResult(
+             observation=observation,
+             reward=payload.get("reward"),
+             done=payload.get("done", False),
+         )
+
+     def _parse_state(self, payload: Dict) -> State:
+         """
+         Parse server response into State object.
+
+         Args:
+             payload: JSON response from state request
+
+         Returns:
+             State object with episode_id and step_count
+         """
+         return State(
+             episode_id=payload.get("episode_id"),
+             step_count=payload.get("step_count", 0),
+         )
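The `_step_payload` / `_parse_result` pair above defines a simple JSON wire format. A stdlib-only sketch of the round trip (the dict shapes mirror the methods above; the variable names and the sample reply values are illustrative):

```python
import json

# What _step_payload would send for MyAction(message="Hello")
step_payload = {"message": "Hello"}

# A server reply as it would arrive over the socket, serialized as JSON
raw_reply = json.dumps({
    "observation": {"echoed_message": "Hello", "message_length": 5, "metadata": {}},
    "reward": 0.5,
    "done": False,
})

# What _parse_result would extract, using .get() defaults for missing keys
payload = json.loads(raw_reply)
obs = payload.get("observation", {})
result = {
    "echoed_message": obs.get("echoed_message", ""),
    "message_length": obs.get("message_length", 0),
    "reward": payload.get("reward"),
    "done": payload.get("done", False),
}
assert result["echoed_message"] == "Hello" and result["reward"] == 0.5
```

The `.get()` defaults matter: a reply missing `observation` still parses into a valid, empty observation instead of raising.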
my_env/models.py ADDED
@@ -0,0 +1,28 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ Data models for the My Env Environment.
+
+ The my_env environment is a simple test environment that echoes back messages.
+ """
+
+ from pydantic import Field
+
+ from openenv.core.env_server.types import Action, Observation
+
+
+ class MyAction(Action):
+     """Action for the My Env environment - just a message to echo."""
+
+     message: str = Field(..., description="Message to echo back")
+
+
+ class MyObservation(Observation):
+     """Observation from the My Env environment - the echoed message."""
+
+     echoed_message: str = Field(default="", description="The echoed message")
+     message_length: int = Field(default=0, description="Length of the echoed message")
my_env/openenv.yaml ADDED
@@ -0,0 +1,7 @@
+ spec_version: 1
+ name: my_env
+ type: space
+ runtime: fastapi
+ app: server.app:app
+ port: 8000
+
my_env/pyproject.toml ADDED
@@ -0,0 +1,45 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ [build-system]
+ requires = ["setuptools>=45", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "openenv-my_env"
+ version = "0.1.0"
+ description = "My Env environment for OpenEnv"
+ requires-python = ">=3.10"
+ dependencies = [
+     # Core OpenEnv runtime (provides FastAPI server + HTTP client types)
+     # install from github:
+     # "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
+     "openenv-core[core]>=0.2.0",
+     # Environment-specific dependencies
+     # Add all dependencies needed for your environment here
+     # Examples:
+     # "numpy>=1.19.0",
+     # "torch>=2.0.0",
+     # "gymnasium>=0.29.0",
+     # "openspiel>=1.0.0",
+     # "smolagents>=1.22.0,<2",
+ ]
+
+ [project.optional-dependencies]
+ dev = [
+     "pytest>=8.0.0",
+     "pytest-cov>=4.0.0",
+ ]
+
+ [project.scripts]
+ # Server entry point - enables running via: uv run --project . server
+ # or: python -m my_env.server.app
+ server = "my_env.server.app:main"
+
+ [tool.setuptools]
+ include-package-data = true
+ packages = ["my_env", "my_env.server"]
+ package-dir = { "my_env" = ".", "my_env.server" = "server" }
my_env/server/Dockerfile ADDED
@@ -0,0 +1,80 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # Multi-stage build using openenv-base
+ # This Dockerfile is flexible and works for both:
+ #   - In-repo environments (with local OpenEnv sources)
+ #   - Standalone environments (with openenv from PyPI/Git)
+ # The build script (openenv build) handles context detection and sets appropriate build args.
+
+ ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
+ FROM ${BASE_IMAGE} AS builder
+
+ WORKDIR /app
+
+ # Ensure git is available (required for installing dependencies from VCS)
+ RUN apt-get update && \
+     apt-get install -y --no-install-recommends git && \
+     rm -rf /var/lib/apt/lists/*
+
+ # Build argument to control whether we're building standalone or in-repo
+ ARG BUILD_MODE=in-repo
+ ARG ENV_NAME=my_env
+
+ # Copy environment code (always at root of build context)
+ COPY . /app/env
+
+ # For in-repo builds, openenv is already vendored in the build context
+ # For standalone builds, openenv will be installed via pyproject.toml
+ WORKDIR /app/env
+
+ # Ensure uv is available (for local builds where base image lacks it)
+ RUN if ! command -v uv >/dev/null 2>&1; then \
+         curl -LsSf https://astral.sh/uv/install.sh | sh && \
+         mv /root/.local/bin/uv /usr/local/bin/uv && \
+         mv /root/.local/bin/uvx /usr/local/bin/uvx; \
+     fi
+
+ # Install dependencies using uv sync
+ # If uv.lock exists, use it; otherwise resolve on the fly
+ RUN --mount=type=cache,target=/root/.cache/uv \
+     if [ -f uv.lock ]; then \
+         uv sync --frozen --no-install-project --no-editable; \
+     else \
+         uv sync --no-install-project --no-editable; \
+     fi
+
+ RUN --mount=type=cache,target=/root/.cache/uv \
+     if [ -f uv.lock ]; then \
+         uv sync --frozen --no-editable; \
+     else \
+         uv sync --no-editable; \
+     fi
+
+ # Final runtime stage
+ FROM ${BASE_IMAGE}
+
+ WORKDIR /app
+
+ # Copy the virtual environment from builder
+ COPY --from=builder /app/env/.venv /app/.venv
+
+ # Copy the environment code
+ COPY --from=builder /app/env /app/env
+
+ # Set PATH to use the virtual environment
+ ENV PATH="/app/.venv/bin:$PATH"
+
+ # Set PYTHONPATH so imports work correctly
+ ENV PYTHONPATH="/app/env:$PYTHONPATH"
+
+ # Health check
+ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
+     CMD curl -f http://localhost:8000/health || exit 1
+
+ # Run the FastAPI server
+ # The module path is constructed to work with the /app/env structure
+ CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
my_env/server/__init__.py ADDED
@@ -0,0 +1,11 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """My Env environment server components."""
+
+ from .my_env_environment import MyEnvironment
+
+ __all__ = ["MyEnvironment"]
my_env/server/app.py ADDED
@@ -0,0 +1,81 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ FastAPI application for the My Env Environment.
+
+ This module creates an HTTP server that exposes the MyEnvironment
+ over HTTP and WebSocket endpoints, compatible with EnvClient.
+
+ Endpoints:
+     - POST /reset: Reset the environment
+     - POST /step: Execute an action
+     - GET /state: Get current environment state
+     - GET /schema: Get action/observation schemas
+     - WS /ws: WebSocket endpoint for persistent sessions
+
+ Usage:
+     # Development (with auto-reload):
+     uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
+
+     # Production:
+     uvicorn server.app:app --host 0.0.0.0 --port 8000 --workers 4
+
+     # Or run directly:
+     python -m server.app
+ """
+
+ try:
+     from openenv.core.env_server.http_server import create_app
+ except Exception as e:  # pragma: no cover
+     raise ImportError(
+         "openenv is required for the web interface. Install dependencies with 'uv sync'."
+     ) from e
+
+ # Import from local models.py (PYTHONPATH includes /app/env in Docker)
+ from models import MyAction, MyObservation
+
+ from .my_env_environment import MyEnvironment
+
+
+ # Create the app with web interface and README integration
+ app = create_app(
+     MyEnvironment,
+     MyAction,
+     MyObservation,
+     env_name="my_env",
+     max_concurrent_envs=1,  # increase this number to allow more concurrent WebSocket sessions
+ )
+
+
+ def main(host: str = "0.0.0.0", port: int = 8000):
+     """
+     Entry point for direct execution via uv run or python -m.
+
+     This function enables running the server without Docker:
+         uv run --project . server
+         uv run --project . server --port 8001
+         python -m my_env.server.app
+
+     Args:
+         host: Host address to bind to (default: "0.0.0.0")
+         port: Port number to listen on (default: 8000)
+
+     For production deployments, consider using uvicorn directly with
+     multiple workers:
+         uvicorn my_env.server.app:app --workers 4
+     """
+     import uvicorn
+
+     uvicorn.run(app, host=host, port=port)
+
+
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--port", type=int, default=8000)
+     args = parser.parse_args()
+     main(port=args.port)
my_env/server/my_env_environment.py ADDED
@@ -0,0 +1,101 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ My Env Environment Implementation.
+
+ A simple test environment that echoes back messages sent to it.
+ Perfect for testing HTTP server infrastructure.
+ """
+
+ from uuid import uuid4
+
+ from openenv.core.env_server.interfaces import Environment
+ from openenv.core.env_server.types import State
+
+ from models import MyAction, MyObservation
+
+
+ class MyEnvironment(Environment):
+     """
+     A simple echo environment that echoes back messages.
+
+     This environment is designed for testing the HTTP server infrastructure.
+     It maintains minimal state and simply echoes back whatever message it receives.
+
+     Example:
+         >>> env = MyEnvironment()
+         >>> obs = env.reset()
+         >>> print(obs.echoed_message)  # "My Env environment ready!"
+         >>>
+         >>> obs = env.step(MyAction(message="Hello"))
+         >>> print(obs.echoed_message)  # "Hello"
+         >>> print(obs.message_length)  # 5
+     """
+
+     # Enable concurrent WebSocket sessions.
+     # Set to True if your environment isolates state between instances.
+     # When True, multiple WebSocket clients can connect simultaneously, each
+     # getting their own environment instance (when using factory mode in app.py).
+     SUPPORTS_CONCURRENT_SESSIONS: bool = True
+
+     def __init__(self):
+         """Initialize the my_env environment."""
+         self._state = State(episode_id=str(uuid4()), step_count=0)
+         self._reset_count = 0
+
+     def reset(self) -> MyObservation:
+         """
+         Reset the environment.
+
+         Returns:
+             MyObservation with a ready message
+         """
+         self._state = State(episode_id=str(uuid4()), step_count=0)
+         self._reset_count += 1
+
+         return MyObservation(
+             echoed_message="My Env environment ready!",
+             message_length=0,
+             done=False,
+             reward=0.0,
+         )
+
+     def step(self, action: MyAction) -> MyObservation:  # type: ignore[override]
+         """
+         Execute a step in the environment by echoing the message.
+
+         Args:
+             action: MyAction containing the message to echo
+
+         Returns:
+             MyObservation with the echoed message and its length
+         """
+         self._state.step_count += 1
+
+         message = action.message
+         length = len(message)
+
+         # Simple reward: longer messages get higher rewards
+         reward = length * 0.1
+
+         return MyObservation(
+             echoed_message=message,
+             message_length=length,
+             done=False,
+             reward=reward,
+             metadata={"original_message": message, "step": self._state.step_count},
+         )
+
+     @property
+     def state(self) -> State:
+         """
+         Get the current environment state.
+
+         Returns:
+             Current State with episode_id and step_count
+         """
+         return self._state
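All mutable state above lives on `self`, so two instances never share data; that is what makes `SUPPORTS_CONCURRENT_SESSIONS = True` safe under factory mode. A stand-in illustration of that isolation (stdlib only; the `Counter` class is hypothetical, not part of the repo):

```python
class Counter:
    """Stand-in with the same per-instance state pattern as MyEnvironment."""

    def __init__(self):
        self.step_count = 0  # per-instance, like self._state

    def step(self) -> int:
        self.step_count += 1
        return self.step_count


# Two "sessions": each gets its own instance, so steps don't interfere
a, b = Counter(), Counter()
a.step()
a.step()
b.step()
assert (a.step_count, b.step_count) == (2, 1)
```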
my_env/server/requirements.txt ADDED
@@ -0,0 +1,6 @@
+ openenv[core]>=0.2.0
+ fastapi>=0.115.0
+ uvicorn>=0.24.0
+
+
+
pyproject.toml CHANGED
@@ -21,9 +21,6 @@ dependencies = [
  ]

  [project.optional-dependencies]
- train = [
-     "gymnasium>=0.29.0",
- ]
  bio = [
      "biopython>=1.84",
      "gseapy>=1.1.3",
@@ -32,7 +29,16 @@ bio = [
  dev = [
      "pytest>=8.0.0",
      "pytest-cov>=4.0.0",
-     "gymnasium>=0.29.0",
+ ]
+ train = [
+     "accelerate>=1.13.0",
+     "bitsandbytes>=0.45.0",
+     "datasets>=4.6.1",
+     "matplotlib>=3.10.8",
+     "peft>=0.15.0",
+     "torch>=2.10.0",
+     "transformers>=5.3.0",
+     "trl>=0.29.0",
  ]

  [project.scripts]
run_agent.py CHANGED
@@ -1,292 +1,978 @@
- """Run the bio-experiment environment with Qwen3.5-2B as the planning agent."""
-
- from __future__ import annotations
-
- import json
- import re
- import sys
- import time
- from typing import Any, Dict, List, Optional
-
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
-
- from models import ActionType, ExperimentAction, ExperimentObservation
- from server.hackathon_environment import BioExperimentEnvironment
-
- MODEL_ID = "Qwen/Qwen3.5-0.8B"
- MAX_EPISODE_STEPS = 12
- PIPELINE_TASK = "image-text-to-text"
- USE_PIPELINE = True
-
- ACTION_TYPES = [a.value for a in ActionType]
-
- SYSTEM_PROMPT = """\
- You are an expert biologist planning a single-cell experiment pipeline.
-
- At each turn you see the experiment state and must pick the next step.
-
- Action types (in typical order):
- collect_sample, prepare_library, sequence_cells, run_qc, filter_data,
- normalize_data, cluster_cells, differential_expression,
- pathway_enrichment, marker_selection, validate_marker, synthesize_conclusion
-
- Other actions: select_cohort, culture_cells, perturb_gene, perturb_compound,
- integrate_batches, trajectory_analysis, regulatory_network_inference,
- design_followup_experiment, request_subagent_review
-
- Respond with ONLY valid JSON, nothing else:
- {"action_type": "...", "method": null, "parameters": {}, "justification": "...", "confidence": 0.8}
- """
-
-
- def format_observation(obs: ExperimentObservation) -> str:
-     parts = [
-         f"TASK: {obs.task.problem_statement}",
-         f"Organism: {obs.task.organism} | Tissue: {obs.task.tissue}",
-         f"Conditions: {', '.join(obs.task.conditions) or 'N/A'}",
-         f"Step: {obs.step_index} | Budget: ${obs.resource_usage.budget_remaining:,.0f} | Time: {obs.resource_usage.time_remaining_days:.0f}d",
-     ]
-     if obs.pipeline_history:
-         last5 = obs.pipeline_history[-5:]
-         parts.append("History:")
-         for h in last5:
-             tag = "OK" if h.success else "FAIL"
-             parts.append(f"  [{tag}] {h.action_type.value}: {h.output_summary[:80]}")
-     if obs.rule_violations:
-         parts.append(f"VIOLATIONS: {obs.rule_violations}")
-     if obs.discovered_markers:
-         parts.append(f"Markers: {obs.discovered_markers[:5]}")
-     return "\n".join(parts)
-
-
- def parse_action(text: str) -> Optional[ExperimentAction]:
-     match = re.search(r"\{[^{}]*\}", text, re.DOTALL)
-     if not match:
-         return None
-     try:
-         d = json.loads(match.group())
-     except json.JSONDecodeError:
-         return None
-
-     action_type = d.get("action_type")
-     if action_type not in ACTION_TYPES:
-         return None
-
-     return ExperimentAction(
-         action_type=ActionType(action_type),
-         method=d.get("method"),
-         parameters=d.get("parameters") or {},
-         justification=d.get("justification"),
-         confidence=min(1.0, max(0.0, float(d.get("confidence", 0.5)))),
-     )
-
-
- FALLBACK_SEQUENCE = [
-     ActionType.COLLECT_SAMPLE,
-     ActionType.PREPARE_LIBRARY,
-     ActionType.SEQUENCE_CELLS,
-     ActionType.RUN_QC,
-     ActionType.FILTER_DATA,
-     ActionType.NORMALIZE_DATA,
-     ActionType.CLUSTER_CELLS,
-     ActionType.DIFFERENTIAL_EXPRESSION,
-     ActionType.PATHWAY_ENRICHMENT,
-     ActionType.MARKER_SELECTION,
-     ActionType.SYNTHESIZE_CONCLUSION,
- ]
-
-
- def fallback_action(step: int) -> ExperimentAction:
-     idx = min(step, len(FALLBACK_SEQUENCE) - 1)
-     return ExperimentAction(
-         action_type=FALLBACK_SEQUENCE[idx],
-         justification="fallback",
-         confidence=0.3,
-     )
-
-
- def log(msg: str) -> None:
-     print(msg, flush=True)
-
-
- def build_observation_prompt(obs: ExperimentObservation) -> str:
-     return format_observation(obs)
-
-
- def run_with_pipeline(pipe, prompt: str) -> str:
-     attempts = [
-         {"text": prompt},
-         {"text": prompt, "image": None},
-         {"image": prompt},
-     ]
-
-     for payload in attempts:
-         try:
-             result = pipe(payload, max_new_tokens=220)
-             if isinstance(result, list) and result:
-                 result = result[0]
-             if isinstance(result, dict):
-                 text = result.get("generated_text") or result.get("text") or result.get("answer")
-             elif isinstance(result, str):
-                 text = result
-             else:
-                 text = ""
-             if isinstance(text, str) and text.strip():
-                 return text.strip()
-         except Exception:
-             continue
-
-     return ""
-
-
- def main():
-     tokenizer = None
-     model = None
-     eos_ids: List[int] = []
-     active_pipeline = None
-
-     if USE_PIPELINE:
-         log(f"Loading pipeline ({PIPELINE_TASK}) for {MODEL_ID} ...")
-         try:
-             active_pipeline = pipeline(
-                 PIPELINE_TASK,
-                 model=MODEL_ID,
-                 trust_remote_code=True,
-                 torch_dtype=torch.bfloat16,
-             )
-             log("Pipeline loaded.")
-         except Exception as exc:
-             log(f"Pipeline load failed ({exc}), falling back to tokenizer+model.")
-
-     if active_pipeline is None:
-         log(f"Loading tokenizer for {MODEL_ID} ...")
-         tokenizer = AutoTokenizer.from_pretrained(
-             MODEL_ID, trust_remote_code=True,
-         )
-         log("Tokenizer loaded. Loading model (this downloads ~4 GB on first run) ...")
-
-         model = AutoModelForCausalLM.from_pretrained(
-             MODEL_ID,
-             torch_dtype=torch.bfloat16,
-             device_map="auto",
-             trust_remote_code=True,
-         )
-         log(f"Model loaded. Device: {model.device}")
-
-         if tokenizer.eos_token_id is not None:
-             eos_ids.append(tokenizer.eos_token_id)
-         extra = tokenizer.convert_tokens_to_ids(["<|im_end|>", "<|endoftext|>"])
-         for tid in extra:
-             if isinstance(tid, int) and tid not in eos_ids:
-                 eos_ids.append(tid)
-         log(f"EOS token ids: {eos_ids}")
-
-     env = BioExperimentEnvironment()
-     obs = env.reset()
-
-     log("\n" + "=" * 70)
-     log(f"TASK: {obs.task.problem_statement}")
-     log(f"Conditions: {obs.task.conditions}")
-     log(f"Budget: ${obs.task.budget_limit:,.0f} | Time: {obs.task.time_limit_days:.0f} days")
-     log("=" * 70)
-
-     cumulative_reward = 0.0
-
-     for step in range(MAX_EPISODE_STEPS):
-         user_msg = build_observation_prompt(obs)
-
-         messages = [
-             {"role": "system", "content": SYSTEM_PROMPT},
-             {"role": "user", "content": user_msg},
-         ]
-
-         if tokenizer is None:
-             # Pipeline path usually ignores chat templates.
-             prompt = f"{SYSTEM_PROMPT}\n\n{user_msg}"
-         else:
-             try:
-                 prompt = tokenizer.apply_chat_template(
-                     messages,
-                     tokenize=False,
-                     add_generation_prompt=True,
-                     enable_thinking=False,
-                 )
-             except TypeError:
-                 prompt = tokenizer.apply_chat_template(
-                     messages,
-                     tokenize=False,
-                     add_generation_prompt=True,
-                 )
-
-         t0 = time.time()
-         if active_pipeline is not None:
-             response = run_with_pipeline(active_pipeline, prompt)
-             if not response:
-                 response = format_observation(obs)
-         else:
-             assert tokenizer is not None and model is not None
-             inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-             n_input = inputs["input_ids"].shape[1]
-             with torch.no_grad():
-                 output_ids = model.generate(
-                     **inputs,
-                     max_new_tokens=200,
-                     do_sample=True,
-                     temperature=0.7,
-                     top_p=0.8,
-                     top_k=20,
-                     repetition_penalty=1.3,
-                     eos_token_id=eos_ids if eos_ids else None,
-                 )
-             new_tokens = output_ids[0][n_input:]
-             response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
-         gen_time = time.time() - t0
-
-         action = parse_action(response)
-         used_fallback = False
-         if action is None:
-             log(f"\n  [!] Parse failed, using fallback. Raw: {response[:150]}")
-             action = fallback_action(step)
-             used_fallback = True
-
-         tag = " [FALLBACK]" if used_fallback else ""
-         log(f"\nStep {step + 1}: {action.action_type.value}{tag} ({gen_time:.1f}s)")
-         if action.justification:
-             log(f"  Rationale: {action.justification}")
-
-         obs = env.step(action)
-
-         if obs.latest_output:
-             lo = obs.latest_output
-             status = "OK" if lo.success else "FAIL"
-             log(f"  [{status}] {lo.summary}")
-             if lo.warnings:
-                 log(f"  Warnings: {lo.warnings}")
-
-         step_reward = obs.reward
-         cumulative_reward += step_reward
-         log(f"  Reward: {step_reward:+.3f} (cum: {cumulative_reward:+.3f})")
-         log(f"  Budget: ${obs.resource_usage.budget_remaining:,.0f} | Time: {obs.resource_usage.time_remaining_days:.0f}d")
-
-         if obs.rule_violations:
-             log(f"  Violations: {obs.rule_violations}")
-
-         if obs.done:
-             break
-
-     log(f"\n{'=' * 70}")
-     log("EPISODE COMPLETE" if obs.done else f"MAX STEPS ({MAX_EPISODE_STEPS})")
-     log(f"  Steps: {obs.step_index}")
-     log(f"  Total reward: {cumulative_reward:+.3f}")
-     log(f"  Budget used: ${obs.resource_usage.budget_used:,.0f}")
-     log(f"  Time used: {obs.resource_usage.time_used_days:.0f} days")
-     if obs.conclusions:
-         log("  Conclusions:")
-         for c in obs.conclusions:
-             log(f"    [{c.claim_type}, conf={c.confidence:.2f}] {c.claim}")
-     log("=" * 70)
-
-
- if __name__ == "__main__":
-     main()
+ """Run the bio-experiment environment with Qwen3.5-2B as the planning agent."""
+
+ from __future__ import annotations
+
+ import json
+ import os
+ import re
+ import time
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional
+
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+
+ from models import (
+     ActionType,
+     ExperimentAction,
+     ExperimentObservation,
+     build_agent_observation_context,
+     build_agent_system_prompt,
+ )
+ from server.hackathon_environment import BioExperimentEnvironment
+
+ DASHBOARD_STATE_PATH = Path(__file__).parent / "_dashboard_state.json"
+ DASHBOARD_CMD_PATH = Path(__file__).parent / "_dashboard_cmd.json"
+
+ USE_PIPELINE = os.getenv("RUN_AGENT_USE_PIPELINE", "0").strip().lower() not in {"0", "false", "off"}
+
+
+ def _parse_thinking_flag() -> bool:
+     import sys
+
+     if "--no-thinking" in sys.argv:
+         return False
+     if "--thinking" in sys.argv:
+         return True
+     return os.getenv("RUN_AGENT_ENABLE_THINKING", "1").strip().lower() not in {"0", "false", "off"}
+
+
+ ENABLE_THINKING = _parse_thinking_flag()
+
+ MODEL_ID = "Qwen/Qwen3.5-2B"
+ MAX_EPISODE_STEPS = int(os.getenv("RUN_AGENT_MAX_EPISODE_STEPS", "20"))
+ PIPELINE_TASK = "text-generation"
+
+ ACTION_TYPES = [a.value for a in ActionType]
+ ACTION_TYPE_ALIASES = {
+     "collect_samples": ActionType.COLLECT_SAMPLE.value,
+     "collect_sample_from_bone_marrow": ActionType.COLLECT_SAMPLE.value,
+     "collect_samples_from_bone_marrow": ActionType.COLLECT_SAMPLE.value,
+     "prepare_sc_library": ActionType.PREPARE_LIBRARY.value,
+     "sequence_single_cells": ActionType.SEQUENCE_CELLS.value,
+     "qc": ActionType.RUN_QC.value,
+     "run_quality_control": ActionType.RUN_QC.value,
+     "cluster": ActionType.CLUSTER_CELLS.value,
+     "de_analysis": ActionType.DIFFERENTIAL_EXPRESSION.value,
+     "differential_expression_analysis": ActionType.DIFFERENTIAL_EXPRESSION.value,
+     "trajectory_inference": ActionType.TRAJECTORY_ANALYSIS.value,
+     "infer_trajectory": ActionType.TRAJECTORY_ANALYSIS.value,
+     "network_inference": ActionType.REGULATORY_NETWORK_INFERENCE.value,
+     "select_markers": ActionType.MARKER_SELECTION.value,
+     "final_conclusion": ActionType.SYNTHESIZE_CONCLUSION.value,
+ }
+
+ SYSTEM_PROMPT = build_agent_system_prompt()
+
+ STANDARD_PIPELINE_ORDER = [
+     ActionType.COLLECT_SAMPLE,
+     ActionType.SELECT_COHORT,
+     ActionType.PREPARE_LIBRARY,
+     ActionType.SEQUENCE_CELLS,
+     ActionType.RUN_QC,
+     ActionType.FILTER_DATA,
+     ActionType.NORMALIZE_DATA,
+     ActionType.INTEGRATE_BATCHES,
+     ActionType.CLUSTER_CELLS,
+     ActionType.DIFFERENTIAL_EXPRESSION,
+     ActionType.PATHWAY_ENRICHMENT,
+     ActionType.MARKER_SELECTION,
+     ActionType.TRAJECTORY_ANALYSIS,
+     ActionType.REGULATORY_NETWORK_INFERENCE,
+     ActionType.SYNTHESIZE_CONCLUSION,
+ ]
+
+ MODEL_RESPONSE_PREVIEW_CHARS = int(
+     os.getenv("RUN_AGENT_MODEL_RESPONSE_PREVIEW_CHARS", "240")
+ )
+
+
+ def compact_preview(value: Any, max_chars: int = 160) -> str:
+     try:
+         text = json.dumps(value, ensure_ascii=True, sort_keys=True)
+     except TypeError:
+         text = str(value)
+     text = re.sub(r"\s+", " ", text).strip()
+     if len(text) <= max_chars:
+         return text
+     return text[: max_chars - 3] + "..."
+
+
+ def format_observation(obs: ExperimentObservation) -> str:
+     parts = [
+         f"TASK: {obs.task.problem_statement}",
+         f"Organism: {obs.task.organism} | Tissue: {obs.task.tissue}",
+         f"Conditions: {', '.join(obs.task.conditions) or 'N/A'}",
+         f"Step: {obs.step_index} | Budget: ${obs.resource_usage.budget_remaining:,.0f} | Time: {obs.resource_usage.time_remaining_days:.0f}d",
+     ]
+     context = build_agent_observation_context(obs, max_tools=5, max_assays=2)
+     if context:
+         parts.append(context)
+     if obs.pipeline_history:
+         last5 = obs.pipeline_history[-5:]
+         parts.append("Recent history:")
+         for h in last5:
+             tag = "OK" if h.success else "FAIL"
+             line = f"  [{tag}] {h.action_type.value}"
+             if h.method:
+                 line += f" ({h.method})"
+             line += f": {h.output_summary[:80]}"
+             parts.append(line)
+
+     completed = {h.action_type for h in obs.pipeline_history if h.success}
+     if completed:
+         parts.append(f"Completed steps (do NOT repeat): {', '.join(sorted(a.value for a in completed))}")
+     remaining = [a.value for a in STANDARD_PIPELINE_ORDER if a not in completed]
+     if remaining:
+         parts.append(f"Remaining steps (choose one): {', '.join(remaining)}")
+
+     if obs.latest_output and obs.latest_output.data:
+         parts.append(
+             f"Latest data: {compact_preview(obs.latest_output.data, 200)}"
+         )
+     if obs.rule_violations:
+         parts.append(f"VIOLATIONS: {obs.rule_violations}")
+     if obs.discovered_markers:
+         parts.append(f"Markers found so far: {obs.discovered_markers[:5]}")
+
+     parts.append(
+         'Output ONLY a single JSON object with these exact keys, no comments, no extra text:\n'
+         '{"action_type": "<one of the remaining steps>", "method": null, "parameters": {}, "justification": "<why>", "confidence": 0.8}'
+     )
+     return "\n".join(parts)
+
+
+ def _repair_truncated_json(text: str) -> Optional[str]:
+     """Try to repair JSON truncated mid-value (common with small LLMs)."""
+     s = text.strip()
+     if not s.startswith("{"):
+         return None
147
+
148
+ # Drop dangling partial keys or empty key/value stubs at the tail.
149
+ s = re.sub(r',\s*"[^"\n]*$', '', s)
150
+ s = re.sub(r',\s*"[^"\n]*"\s*:\s*$', '', s)
151
+
152
+ in_string = False
153
+ escape = False
154
+ for ch in s:
155
+ if escape:
156
+ escape = False
157
+ continue
158
+ if ch == "\\":
159
+ escape = True
160
+ continue
161
+ if ch == '"':
162
+ in_string = not in_string
163
+
164
+ if in_string:
165
+ s += '"'
166
+
167
+ open_braces = s.count("{") - s.count("}")
168
+ open_brackets = s.count("[") - s.count("]")
169
+ s += "]" * max(0, open_brackets)
170
+ s += "}" * max(0, open_braces)
171
+
172
+ try:
173
+ obj = json.loads(s)
174
+ if isinstance(obj, dict):
175
+ return s
176
+ except json.JSONDecodeError:
177
+ pass
178
+
179
+ s = re.sub(r',\s*([}\]])', r'\1', s)
180
+ try:
181
+ obj = json.loads(s)
182
+ if isinstance(obj, dict):
183
+ return s
184
+ except json.JSONDecodeError:
185
+ pass
186
+ return None
187
+
188
+
189
+ def _normalize_jsonish_text(text: str) -> str:
190
+ """Normalize common near-JSON artifacts emitted by small local models."""
191
+ text = _strip_js_comments(text)
192
+ text = re.sub(r'(?<=:\s)\bNone\b', 'null', text)
193
+ text = re.sub(r'(?<=:\s)\bTrue\b', 'true', text)
194
+ text = re.sub(r'(?<=:\s)\bFalse\b', 'false', text)
195
+ text = re.sub(r'"([^"\n]+?):"\s*,', r'"\1": "",', text)
196
+ return text
197
+
198
+
199
+ def _strip_js_comments(text: str) -> str:
200
+ """Remove // and /* */ comments that small LLMs inject into JSON."""
201
+ text = re.sub(r'//[^\n]*', '', text)
202
+ text = re.sub(r'/\*.*?\*/', '', text, flags=re.DOTALL)
203
+ return text
204
+
205
+
206
+ def extract_json_object(text: str) -> Optional[Dict[str, Any]]:
207
+ stripped = _normalize_jsonish_text(text).strip()
208
+ fence_prefix = "```"
209
+ if stripped.startswith(fence_prefix) and stripped.endswith(fence_prefix):
210
+ lines = stripped.splitlines()
211
+ if len(lines) >= 3:
212
+ stripped = "\n".join(lines[1:-1]).strip()
213
+
214
+ candidates: List[str] = [stripped]
215
+ start = stripped.find("{")
216
+ while start != -1:
217
+ depth = 0
218
+ for idx in range(start, len(stripped)):
219
+ char = stripped[idx]
220
+ if char == "{":
221
+ depth += 1
222
+ elif char == "}":
223
+ depth -= 1
224
+ if depth == 0:
225
+ candidates.append(stripped[start:idx + 1])
226
+ break
227
+ start = stripped.find("{", start + 1)
228
+
229
+ first_brace = stripped.find("{")
230
+ if first_brace != -1:
231
+ repaired = _repair_truncated_json(stripped[first_brace:])
232
+ if repaired is not None:
233
+ candidates.append(repaired)
234
+
235
+ candidates.sort(key=len, reverse=True)
236
+
237
+ for candidate in candidates:
238
+ try:
239
+ parsed = json.loads(candidate)
240
+ except json.JSONDecodeError:
241
+ continue
242
+ if isinstance(parsed, dict):
243
+ return parsed
244
+
245
+ return None
246
+
247
+
248
+ def _edit_distance(a: str, b: str) -> int:
249
+ if len(a) < len(b):
250
+ return _edit_distance(b, a)
251
+ if not b:
252
+ return len(a)
253
+ prev = list(range(len(b) + 1))
254
+ for i, ca in enumerate(a):
255
+ curr = [i + 1]
256
+ for j, cb in enumerate(b):
257
+ curr.append(min(prev[j + 1] + 1, curr[j] + 1, prev[j] + (ca != cb)))
258
+ prev = curr
259
+ return prev[-1]
260
+
261
+
262
+ def get_payload_value(payload: Dict[str, Any], *names: str) -> Any:
263
+ for name in names:
264
+ if name in payload:
265
+ return payload[name]
266
+
267
+ lowered = {
268
+ str(key).lower(): value
269
+ for key, value in payload.items()
270
+ }
271
+ for name in names:
272
+ if name.lower() in lowered:
273
+ return lowered[name.lower()]
274
+
275
+ for key, value in lowered.items():
276
+ for name in names:
277
+ threshold = max(2, len(name) // 3)
278
+ if _edit_distance(key, name.lower()) <= threshold:
279
+ return value
280
+ return None
281
+
282
+
283
+ def normalize_optional_string(value: Any) -> Optional[str]:
284
+ if value is None or isinstance(value, bool):
285
+ return None
286
+ if isinstance(value, str):
287
+ value = value.strip()
288
+ return value or None
289
+ if isinstance(value, (int, float)):
290
+ return str(value)
291
+ return compact_preview(value, 80)
292
+
293
+
294
+ def normalize_action_type(raw_action_type: Any) -> Optional[str]:
295
+ if not isinstance(raw_action_type, str):
296
+ return None
297
+
298
+ candidate = raw_action_type.strip().lower()
299
+ if candidate in ACTION_TYPES:
300
+ return candidate
301
+ if candidate in ACTION_TYPE_ALIASES:
302
+ return ACTION_TYPE_ALIASES[candidate]
303
+
304
+ candidate = re.sub(r"[^a-z0-9]+", "_", candidate).strip("_")
305
+ if candidate in ACTION_TYPES:
306
+ return candidate
307
+ if candidate in ACTION_TYPE_ALIASES:
308
+ return ACTION_TYPE_ALIASES[candidate]
309
+
310
+ heuristics = [
311
+ (("collect", "sample"), ActionType.COLLECT_SAMPLE.value),
312
+ (("library",), ActionType.PREPARE_LIBRARY.value),
313
+ (("sequence",), ActionType.SEQUENCE_CELLS.value),
314
+ (("qc",), ActionType.RUN_QC.value),
315
+ (("quality", "control"), ActionType.RUN_QC.value),
316
+ (("filter",), ActionType.FILTER_DATA.value),
317
+ (("normal",), ActionType.NORMALIZE_DATA.value),
318
+ (("integrat", "batch"), ActionType.INTEGRATE_BATCHES.value),
319
+ (("cluster",), ActionType.CLUSTER_CELLS.value),
320
+ (("differential", "expression"), ActionType.DIFFERENTIAL_EXPRESSION.value),
321
+ (("pathway",), ActionType.PATHWAY_ENRICHMENT.value),
322
+ (("trajectory",), ActionType.TRAJECTORY_ANALYSIS.value),
323
+ (("network",), ActionType.REGULATORY_NETWORK_INFERENCE.value),
324
+ (("marker",), ActionType.MARKER_SELECTION.value),
325
+ (("validat", "marker"), ActionType.VALIDATE_MARKER.value),
326
+ (("followup",), ActionType.DESIGN_FOLLOWUP.value),
327
+ (("review",), ActionType.REQUEST_SUBAGENT_REVIEW.value),
328
+ (("conclusion",), ActionType.SYNTHESIZE_CONCLUSION.value),
329
+ ]
330
+ for fragments, normalized in heuristics:
331
+ if all(fragment in candidate for fragment in fragments):
332
+ return normalized
333
+ return None
334
+
335
+
336
+ def should_block_failed_reattempt(
337
+ history: List[Any], action_type: ActionType
338
+ ) -> bool:
339
+ last_failed_idx = None
340
+ last_success_idx = None
341
+
342
+ for idx, record in enumerate(history):
343
+ if record.action_type != action_type:
344
+ continue
345
+ if record.success:
346
+ last_success_idx = idx
347
+ else:
348
+ last_failed_idx = idx
349
+
350
+ if last_failed_idx is None:
351
+ return False
352
+
353
+ # Allow retry after the same action has already succeeded once, or after the
354
+ # pipeline made progress with a different successful step since the failure.
355
+ if last_success_idx is not None and last_success_idx > last_failed_idx:
356
+ return False
357
+ for record in history[last_failed_idx + 1:]:
358
+ if record.success and record.action_type != action_type:
359
+ return False
360
+ return True
361
+
362
+
363
+ def parse_action(text: str) -> Optional[ExperimentAction]:
364
+ d = extract_json_object(text)
365
+ if d is not None:
366
+ action_type = normalize_action_type(get_payload_value(d, "action_type"))
367
+ if action_type is None:
368
+ return None
369
+
370
+ parameters = get_payload_value(d, "parameters", "params") or {}
371
+ if not isinstance(parameters, dict):
372
+ parameters = {}
373
+
374
+ confidence = get_payload_value(d, "confidence")
375
+ if confidence is None:
376
+ confidence = 0.5
377
+ try:
378
+ confidence = float(confidence)
379
+ except (TypeError, ValueError):
380
+ confidence = 0.5
381
+
382
+ justification = get_payload_value(
383
+ d, "justification", "reasoning", "rationale", "reason"
384
+ )
385
+ if justification is not None and not isinstance(justification, str):
386
+ justification = compact_preview(justification, 200)
387
+ method = normalize_optional_string(get_payload_value(d, "method"))
388
+
389
+ return ExperimentAction(
390
+ action_type=ActionType(action_type),
391
+ method=method,
392
+ parameters=parameters,
393
+ justification=justification,
394
+ confidence=min(1.0, max(0.0, confidence)),
395
+ )
396
+
397
+ action_match = re.search(
398
+ r'["\']action_type["\']\s*:\s*["\']([^"\']+)',
399
+ text,
400
+ re.IGNORECASE,
401
+ )
402
+ if not action_match:
403
+ return None
404
+
405
+ action_type = normalize_action_type(action_match.group(1))
406
+ if action_type is None:
407
+ return None
408
+
409
+ method_match = re.search(
410
+ r'["\']method["\']\s*:\s*("((?:[^"\\]|\\.)*)"|null|none|true|false|-?\d+(?:\.\d+)?)',
411
+ text,
412
+ re.IGNORECASE,
413
+ )
414
+ confidence_match = re.search(
415
+ r'["\']confidence["\']\s*:\s*([0-9]*\.?[0-9]+)',
416
+ text,
417
+ re.IGNORECASE,
418
+ )
419
+ justification_match = re.search(
420
+ r'["\'](?:justif\w*|reasoning|rationale|reason)["\']\s*:\s*"((?:[^"\\]|\\.)*)',
421
+ text,
422
+ re.DOTALL | re.IGNORECASE,
423
+ )
424
+
425
+ confidence = 0.5
426
+ if confidence_match:
427
+ try:
428
+ confidence = float(confidence_match.group(1))
429
+ except ValueError:
430
+ confidence = 0.5
431
+
432
+ justification = None
433
+ if justification_match:
434
+ try:
435
+ justification = json.loads(f'"{justification_match.group(1)}"')
436
+ except json.JSONDecodeError:
437
+ justification = justification_match.group(1)
438
+
439
+ method = None
440
+ if method_match:
441
+ raw_method = method_match.group(1)
442
+ if raw_method.startswith('"') and raw_method.endswith('"'):
443
+ try:
444
+ method = json.loads(raw_method)
445
+ except json.JSONDecodeError:
446
+ method = raw_method.strip('"')
447
+ elif raw_method.lower() not in {"null", "none", "true", "false"}:
448
+ method = raw_method
449
+ method = normalize_optional_string(method)
450
+
451
+ return ExperimentAction(
452
+ action_type=ActionType(action_type),
453
+ method=method,
454
+ parameters={},
455
+ justification=justification,
456
+ confidence=min(1.0, max(0.0, confidence)),
457
+ )
458
+
459
+
460
+ def should_force_terminal_conclusion(
461
+ action: ExperimentAction,
462
+ completed_types: set[ActionType],
463
+ ) -> bool:
464
+ meta_repeatables = {
465
+ ActionType.DESIGN_FOLLOWUP,
466
+ ActionType.REQUEST_SUBAGENT_REVIEW,
467
+ }
468
+ return (
469
+ action.action_type in meta_repeatables
470
+ and action.action_type in completed_types
471
+ and ActionType.SYNTHESIZE_CONCLUSION not in completed_types
472
+ )
473
+
474
+
475
+
476
+ def write_dashboard_state(
477
+ env: BioExperimentEnvironment,
478
+ obs: ExperimentObservation,
479
+ *,
480
+ step: int,
481
+ cumulative_reward: float,
482
+ model_response: str = "",
483
+ model_thinking: str = "",
484
+ action: Optional[ExperimentAction] = None,
485
+ gen_time: float = 0.0,
486
+ episode_done: bool = False,
487
+ ) -> None:
488
+ """Serialise the full world state (observable + latent) for the dashboard."""
489
+ latent = env._latent
490
+ snapshot: Dict[str, Any] = {
491
+ "timestamp": time.time(),
492
+ "step": step,
493
+ "episode_done": episode_done,
494
+ "cumulative_reward": cumulative_reward,
495
+ "gen_time_s": round(gen_time, 2),
496
+ "model_response_raw": model_response[:600],
497
+ "model_thinking": model_thinking[:800],
498
+ "thinking_enabled": ENABLE_THINKING,
499
+ }
500
+
501
+ snapshot["task"] = {
502
+ "problem_statement": obs.task.problem_statement,
503
+ "organism": obs.task.organism,
504
+ "tissue": obs.task.tissue,
505
+ "modality": obs.task.modality,
506
+ "conditions": obs.task.conditions,
507
+ "budget_limit": obs.task.budget_limit,
508
+ "time_limit_days": obs.task.time_limit_days,
509
+ }
510
+
511
+ snapshot["resources"] = {
512
+ "budget_used": round(obs.resource_usage.budget_used, 2),
513
+ "budget_remaining": round(obs.resource_usage.budget_remaining, 2),
514
+ "time_used_days": round(obs.resource_usage.time_used_days, 1),
515
+ "time_remaining_days": round(obs.resource_usage.time_remaining_days, 1),
516
+ "samples_consumed": obs.resource_usage.samples_consumed,
517
+ "compute_hours_used": round(obs.resource_usage.compute_hours_used, 2),
518
+ }
519
+
520
+ snapshot["pipeline_history"] = [
521
+ {
522
+ "step_index": h.step_index,
523
+ "action_type": h.action_type.value,
524
+ "method": h.method,
525
+ "output_summary": h.output_summary[:120],
526
+ "success": h.success,
527
+ "quality_score": round(h.quality_score, 3),
528
+ "resource_cost": round(h.resource_cost, 2),
529
+ "time_cost_days": round(h.time_cost_days, 1),
530
+ }
531
+ for h in obs.pipeline_history
532
+ ]
533
+
534
+ if action:
535
+ snapshot["current_action"] = {
536
+ "action_type": action.action_type.value,
537
+ "method": action.method,
538
+ "parameters": action.parameters,
539
+ "justification": action.justification,
540
+ "confidence": action.confidence,
541
+ }
542
+
543
+ if obs.latest_output:
544
+ lo = obs.latest_output
545
+ snapshot["latest_output"] = {
546
+ "summary": lo.summary,
547
+ "success": lo.success,
548
+ "quality_score": round(lo.quality_score, 3),
549
+ "uncertainty": round(lo.uncertainty, 3),
550
+ "warnings": lo.warnings,
551
+ "data_preview": compact_preview(lo.data, 300) if lo.data else None,
552
+ }
553
+
554
+ snapshot["discovered_markers"] = obs.discovered_markers[:20]
555
+ snapshot["candidate_mechanisms"] = obs.candidate_mechanisms[:20]
556
+ snapshot["rule_violations"] = obs.rule_violations
557
+ snapshot["uncertainty_summary"] = {
558
+ k: round(v, 3) for k, v in obs.uncertainty_summary.items()
559
+ }
560
+ snapshot["reward_breakdown"] = {
561
+ k: round(v, 4) for k, v in obs.step_reward_breakdown.items()
562
+ }
563
+
564
+ if obs.conclusions:
565
+ snapshot["conclusions"] = [
566
+ {
567
+ "claim": c.claim,
568
+ "claim_type": c.claim_type,
569
+ "confidence": c.confidence,
570
+ "top_markers": c.top_markers,
571
+ "causal_mechanisms": c.causal_mechanisms,
572
+ "predicted_pathways": c.predicted_pathways,
573
+ }
574
+ for c in obs.conclusions
575
+ ]
576
+
577
+ if latent:
578
+ bio = latent.biology
579
+ snapshot["latent"] = {
580
+ "cell_populations": [
581
+ {
582
+ "name": cp.name,
583
+ "proportion": round(cp.proportion, 3),
584
+ "marker_genes": cp.marker_genes[:8],
585
+ "state": cp.state,
586
+ }
587
+ for cp in bio.cell_populations
588
+ ],
589
+ "true_markers": bio.true_markers,
590
+ "causal_mechanisms": bio.causal_mechanisms,
591
+ "true_pathways": {
592
+ k: round(v, 3) for k, v in list(bio.true_pathways.items())[:15]
593
+ },
594
+ "true_de_genes_count": sum(
595
+ len(genes) for genes in bio.true_de_genes.values()
596
+ ),
597
+ "true_regulatory_network_size": sum(
598
+ len(targets) for targets in bio.true_regulatory_network.values()
599
+ ),
600
+ "confounders": bio.confounders,
601
+ "n_true_cells": bio.n_true_cells,
602
+ "technical": {
603
+ "ambient_rna_fraction": latent.technical.ambient_rna_fraction,
604
+ "doublet_rate": latent.technical.doublet_rate,
605
+ "dropout_rate": latent.technical.dropout_rate,
606
+ "sample_quality": latent.technical.sample_quality,
607
+ "library_complexity": latent.technical.library_complexity,
608
+ "capture_efficiency": latent.technical.capture_efficiency,
609
+ },
610
+ "progress": latent.progress.model_dump(),
611
+ "hidden_failure_conditions": latent.hidden_failure_conditions,
612
+ }
613
+
614
+ try:
615
+ DASHBOARD_STATE_PATH.write_text(
616
+ json.dumps(snapshot, indent=2, default=str), encoding="utf-8"
617
+ )
618
+ except Exception:
619
+ pass
620
+
621
+
622
+ def log(msg: str) -> None:
623
+ print(msg, flush=True)
624
+
625
+
626
+ def build_observation_prompt(obs: ExperimentObservation) -> str:
627
+ return format_observation(obs)
628
+
629
+
630
+ def run_with_pipeline(pipe, prompt: str) -> str:
631
+ try:
632
+ _pipe_max = 2048 if ENABLE_THINKING else 300
633
+ result = pipe(prompt, max_new_tokens=_pipe_max, return_full_text=False)
634
+ except Exception:
635
+ return ""
636
+
637
+ if isinstance(result, list) and result:
638
+ result = result[0]
639
+ if isinstance(result, dict):
640
+ text = result.get("generated_text") or result.get("text") or result.get("answer")
641
+ elif isinstance(result, str):
642
+ text = result
643
+ else:
644
+ text = ""
645
+ return text.strip() if isinstance(text, str) else ""
646
+
647
+
648
+ def resolve_torch_runtime() -> Dict[str, Any]:
649
+ use_cuda = torch.cuda.is_available()
650
+ bf16 = bool(getattr(torch.cuda, "is_bf16_supported", lambda: False)()) if use_cuda else False
651
+ dtype = torch.bfloat16 if bf16 else (
652
+ torch.float16 if use_cuda else torch.float32
653
+ )
654
+ return {
655
+ "use_cuda": use_cuda,
656
+ "device": "cuda:0" if use_cuda else "cpu",
657
+ "dtype": dtype,
658
+ "device_map": "auto" if use_cuda else None,
659
+ "device_name": torch.cuda.get_device_name(0) if use_cuda else "cpu",
660
+ }
661
+
662
+
663
+ def main():
664
+ tokenizer = None
665
+ model = None
666
+ eos_ids: List[int] = []
667
+ active_pipeline = None
668
+
669
+ runtime = resolve_torch_runtime()
670
+ log(
671
+ f"Using local model runtime: device={runtime['device']} "
672
+ f"name={runtime['device_name']} dtype={runtime['dtype']}"
673
+ )
674
+
675
+ if USE_PIPELINE:
676
+ log(f"Loading pipeline ({PIPELINE_TASK}) for {MODEL_ID} ...")
677
+ try:
678
+ active_pipeline = pipeline(
679
+ PIPELINE_TASK,
680
+ model=MODEL_ID,
681
+ trust_remote_code=True,
682
+ dtype=runtime["dtype"],
683
+ device=0 if runtime["use_cuda"] else -1,
684
+ )
685
+ log("Pipeline loaded.")
686
+ except Exception as exc:
687
+ log(f"Pipeline load failed ({exc}), falling back to tokenizer+model.")
688
+
689
+ if active_pipeline is None:
690
+ log(f"Loading tokenizer for {MODEL_ID} ...")
691
+ tokenizer = AutoTokenizer.from_pretrained(
692
+ MODEL_ID, trust_remote_code=True,
693
+ )
694
+ log("Tokenizer loaded. Loading model (this may download files on first run) ...")
695
+
696
+ model = AutoModelForCausalLM.from_pretrained(
697
+ MODEL_ID,
698
+ dtype=runtime["dtype"],
699
+ device_map=runtime["device_map"],
700
+ trust_remote_code=True,
701
+ )
702
+ log(f"Model loaded. Device: {model.device}")
703
+
704
+ if tokenizer.eos_token_id is not None:
705
+ eos_ids.append(tokenizer.eos_token_id)
706
+ extra = tokenizer.convert_tokens_to_ids(["<|im_end|>", "<|endoftext|>"])
707
+ for tid in extra:
708
+ if isinstance(tid, int) and tid not in eos_ids:
709
+ eos_ids.append(tid)
710
+ log(f"EOS token ids: {eos_ids}")
711
+
712
+ def check_dashboard_command() -> Optional[Dict[str, Any]]:
713
+ """Read and consume a command file written by the dashboard."""
714
+ try:
715
+ raw = DASHBOARD_CMD_PATH.read_text(encoding="utf-8")
716
+ DASHBOARD_CMD_PATH.unlink(missing_ok=True)
717
+ return json.loads(raw)
718
+ except (FileNotFoundError, json.JSONDecodeError):
719
+ return None
720
+
721
+ def run_episode(
722
+ scenario_name: Optional[str] = None,
723
+ custom_ground_truth: Optional[Dict[str, Any]] = None,
724
+ ):
725
+ env = BioExperimentEnvironment(scenario_name=scenario_name)
726
+ obs = env.reset()
727
+
728
+ if custom_ground_truth and env._latent:
729
+ gt = custom_ground_truth
730
+ bio = env._latent.biology
731
+ if gt.get("true_markers"):
732
+ bio.true_markers = gt["true_markers"]
733
+ if gt.get("causal_mechanisms"):
734
+ bio.causal_mechanisms = gt["causal_mechanisms"]
735
+ if gt.get("true_pathways"):
736
+ bio.true_pathways = {
737
+ k: float(v) for k, v in gt["true_pathways"].items()
738
+ }
739
+
740
+ log("\n" + "=" * 70)
741
+ log(f"TASK: {obs.task.problem_statement}")
742
+ log(f"Conditions: {obs.task.conditions}")
743
+ log(f"Budget: ${obs.task.budget_limit:,.0f} | Time: {obs.task.time_limit_days:.0f} days")
744
+ if ENABLE_THINKING:
745
+ log("Reasoning mode: ENABLED")
746
+ log("=" * 70)
747
+
748
+ cumulative_reward = 0.0
749
+ write_dashboard_state(env, obs, step=0, cumulative_reward=0.0)
750
+
751
+ for step in range(MAX_EPISODE_STEPS):
752
+ cmd = check_dashboard_command()
753
+ if cmd and cmd.get("action") == "restart":
754
+ log("\n[DASHBOARD] Restart requested — ending episode early.")
755
+ break
756
+
757
+ user_msg = build_observation_prompt(obs)
758
+
759
+ messages = [
760
+ {"role": "system", "content": SYSTEM_PROMPT},
761
+ {"role": "user", "content": user_msg},
762
+ ]
763
+
764
+ if active_pipeline is not None:
765
+ prompt = f"{SYSTEM_PROMPT}\n\n{user_msg}"
766
+ else:
767
+ try:
768
+ prompt = tokenizer.apply_chat_template(
769
+ messages,
770
+ tokenize=False,
771
+ add_generation_prompt=True,
772
+ enable_thinking=ENABLE_THINKING,
773
+ )
774
+ except TypeError:
775
+ prompt = tokenizer.apply_chat_template(
776
+ messages,
777
+ tokenize=False,
778
+ add_generation_prompt=True,
779
+ )
780
+
781
+ t0 = time.time()
782
+ if active_pipeline is not None:
783
+ response = run_with_pipeline(active_pipeline, prompt)
784
+ if not response:
785
+ response = format_observation(obs)
786
+ else:
787
+ assert tokenizer is not None and model is not None
788
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
789
+ n_input = inputs["input_ids"].shape[1]
790
+ max_new = 2048 if ENABLE_THINKING else 300
791
+ with torch.no_grad():
792
+ output_ids = model.generate(
793
+ **inputs,
794
+ max_new_tokens=max_new,
795
+ do_sample=True,
796
+ temperature=0.7,
797
+ top_p=0.8,
798
+ top_k=20,
799
+ repetition_penalty=1.3,
800
+ eos_token_id=eos_ids if eos_ids else None,
801
+ )
802
+ new_tokens = output_ids[0][n_input:]
803
+ response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
804
+ gen_time = time.time() - t0
805
+
806
+ thinking = ""
807
+ if ENABLE_THINKING:
808
+ think_match = re.search(
809
+ r"<think>(.*?)</think>", response, re.DOTALL
810
+ )
811
+ if think_match:
812
+ thinking = think_match.group(1).strip()
813
+ response = response[think_match.end():].strip()
814
+ elif response.startswith("<think>"):
815
+ parts = response.split("</think>", 1)
816
+ if len(parts) == 2:
817
+ thinking = parts[0].replace("<think>", "").strip()
818
+ response = parts[1].strip()
819
+
820
+ is_last_step = (step == MAX_EPISODE_STEPS - 1)
821
+
822
+ action = parse_action(response)
823
+ if action is None:
824
+ if is_last_step:
825
+ log("\n [!] Parse failed on final step — forcing synthesize_conclusion.")
826
+ action = ExperimentAction(
827
+ action_type=ActionType.SYNTHESIZE_CONCLUSION,
828
+ justification="forced terminal conclusion",
829
+ confidence=0.5,
830
+ )
831
+ else:
832
+ log(f"\n [!] Parse failed, skipping step. Raw: {response[:150]}")
833
+ continue
834
+
835
+ completed_types = {
836
+ r.action_type for r in obs.pipeline_history if r.success
837
+ }
838
+ failed_types = {
839
+ r.action_type
840
+ for r in obs.pipeline_history
841
+ if not r.success
842
+ }
843
+
844
+ if should_force_terminal_conclusion(action, completed_types):
845
+ log(
846
+ f"\n [!] repeated completed meta step {action.action_type.value} "
847
+ f"— forcing synthesize_conclusion."
848
+ )
849
+ action = ExperimentAction(
850
+ action_type=ActionType.SYNTHESIZE_CONCLUSION,
851
+ justification="repeated completed meta step forced terminal conclusion",
852
+ confidence=action.confidence,
853
+ )
854
+ completed_types = {
855
+ r.action_type for r in obs.pipeline_history if r.success
856
+ }
857
+
858
+ skip_reason = None
859
+ if action.action_type in completed_types:
860
+ skip_reason = (
861
+ f"blocked repeat of completed step {action.action_type.value}"
862
+ )
863
+ elif action.action_type in failed_types:
864
+ if should_block_failed_reattempt(
865
+ obs.pipeline_history, action.action_type
866
+ ):
867
+ skip_reason = (
868
+ f"blocked re-attempt of failed step {action.action_type.value}"
869
+ )
870
+
871
+ if skip_reason:
872
+ if is_last_step:
873
+ log(f"\n [!] {skip_reason} on final step — forcing synthesize_conclusion.")
874
+ action = ExperimentAction(
875
+ action_type=ActionType.SYNTHESIZE_CONCLUSION,
876
+ justification="forced terminal conclusion",
877
+ confidence=0.5,
878
+ )
879
+ else:
880
+ log(f"\n [!] {skip_reason}, skipping step.")
881
+ continue
882
+
883
+ if is_last_step and action.action_type != ActionType.SYNTHESIZE_CONCLUSION:
884
+ log(f"\n [!] Final step — overriding {action.action_type.value} with synthesize_conclusion.")
885
+ action = ExperimentAction(
886
+ action_type=ActionType.SYNTHESIZE_CONCLUSION,
887
+ justification="forced terminal conclusion",
888
+ confidence=action.confidence,
889
+ )
890
+
891
+ log(f"\nStep {step + 1}: {action.action_type.value} ({gen_time:.1f}s)")
892
+ if thinking:
893
+ log(f" Thinking: {thinking[:200]}")
894
+ if action.justification:
895
+ log(f" Rationale: {action.justification}")
896
+ else:
897
+ log(" Rationale: [model did not provide one]")
898
+ if action.parameters:
899
+ log(f" Parameters: {compact_preview(action.parameters, 200)}")
900
+ elif not action.justification and response:
901
+ log(
902
+ f" Model response: "
903
+ f"{compact_preview(response, MODEL_RESPONSE_PREVIEW_CHARS)}"
904
+ )
905
+
906
+ obs = env.step(action)
907
+
908
+ if obs.latest_output:
909
+ lo = obs.latest_output
910
+ status = "OK" if lo.success else "FAIL"
911
+ log(f" [{status}] {lo.summary}")
912
+ if lo.warnings:
913
+ log(f" Warnings: {lo.warnings}")
914
+
915
+ step_reward = obs.reward
916
+ cumulative_reward += step_reward
917
+ log(f" Reward: {step_reward:+.3f} (cum: {cumulative_reward:+.3f})")
918
+ log(f" Budget: ${obs.resource_usage.budget_remaining:,.0f} | Time: {obs.resource_usage.time_remaining_days:.0f}d")
919
+
920
+ write_dashboard_state(
921
+ env, obs,
922
+ step=step + 1,
923
+ cumulative_reward=cumulative_reward,
924
+ model_response=response,
925
+ model_thinking=thinking,
926
+ action=action,
927
+ gen_time=gen_time,
928
+ episode_done=obs.done,
929
+ )
930
+
931
+ if obs.rule_violations:
932
+ log(f" Violations: {obs.rule_violations}")
933
+
934
+ if obs.done:
935
+ break
936
+
937
+ log(f"\n{'=' * 70}")
938
+ log("EPISODE COMPLETE" if obs.done else f"MAX STEPS ({MAX_EPISODE_STEPS})")
939
+ log(f" Steps: {obs.step_index}")
940
+ log(f" Total reward: {cumulative_reward:+.3f}")
941
+ log(f" Budget used: ${obs.resource_usage.budget_used:,.0f}")
942
+ log(f" Time used: {obs.resource_usage.time_used_days:.0f} days")
943
+ if obs.conclusions:
944
+ log(" Conclusions:")
945
+ for c in obs.conclusions:
946
+ log(f" [{c.claim_type}, conf={c.confidence:.2f}] {c.claim}")
947
+ if c.top_markers:
948
+ log(f" Markers: {c.top_markers}")
949
+ if c.causal_mechanisms:
950
+ log(f" Mechanisms: {c.causal_mechanisms}")
951
+ if c.predicted_pathways:
952
+ log(f" Pathways: {c.predicted_pathways}")
953
+ log("=" * 70)
954
+
955
+ DASHBOARD_CMD_PATH.unlink(missing_ok=True)
956
+ run_episode()
957
+
958
+ while True:
959
+ log("\nWaiting for dashboard command (restart / new task) ...")
960
+ while True:
961
+ cmd = check_dashboard_command()
962
+ if cmd:
963
+ break
964
+ time.sleep(1.0)
965
+
966
+ action_type = cmd.get("action", "restart")
967
+ if action_type == "quit":
968
+ log("Quit requested.")
969
+ break
970
+
971
+ scenario = cmd.get("scenario_name")
972
+ ground_truth = cmd.get("ground_truth")
973
+ log(f"\n[DASHBOARD] {action_type} — scenario={scenario}")
974
+ run_episode(scenario_name=scenario, custom_ground_truth=ground_truth)
975
+
976
+
977
+ if __name__ == "__main__":
978
+ main()
server/app.py CHANGED
@@ -6,8 +6,12 @@ Endpoints:
 - GET /state: Get current environment state
 - GET /schema: Get action/observation schemas
 - WS /ws: WebSocket endpoint for persistent sessions
+- GET /: Demo UI
 """
 
+import os
+from pathlib import Path
+
 try:
     from openenv.core.env_server.http_server import create_app
 except Exception as e:  # pragma: no cover
@@ -16,6 +20,7 @@ except Exception as e:  # pragma: no cover
         "Install dependencies with 'uv sync'"
     ) from e
 
+from fastapi.responses import HTMLResponse
 from models import ExperimentAction, ExperimentObservation
 from .hackathon_environment import BioExperimentEnvironment
 
@@ -24,12 +29,24 @@ app = create_app(
     ExperimentAction,
     ExperimentObservation,
     env_name="bio_experiment",
-    max_concurrent_envs=1,
+    max_concurrent_envs=int(os.environ.get("MAX_ENVS", "4")),
 )
 
+# Serve demo UI at root
+DEMO_HTML = Path(__file__).resolve().parent.parent / "demo.html"
+
+
+@app.get("/", response_class=HTMLResponse)
+async def demo_ui():
+    if DEMO_HTML.exists():
+        return HTMLResponse(content=DEMO_HTML.read_text(), status_code=200)
+    return HTMLResponse(content="<h1>BioEnv API</h1><p>Visit /docs for API documentation.</p>", status_code=200)
+
 
-def main(host: str = "0.0.0.0", port: int = 8000):
+def main(host: str = "0.0.0.0", port: int | None = None):
     import uvicorn
+    if port is None:
+        port = int(os.environ.get("PORT", "8000"))
     uvicorn.run(app, host=host, port=port)
 
 
@@ -37,9 +54,6 @@ if __name__ == "__main__":
     import argparse
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", default="0.0.0.0")
-    parser.add_argument("--port", type=int, default=8000)
+    parser.add_argument("--port", type=int, default=None)
     args = parser.parse_args()
-    if args.host == "0.0.0.0" and args.port == 8000:
-        main()
-    else:
-        main(host=args.host, port=args.port)
+    main(host=args.host, port=args.port)
server/biology/__init__.py ADDED
File without changes
server/biology/gene_index.py ADDED
@@ -0,0 +1,225 @@
+"""Pathway-aware gene similarity index for structured reward scoring.
+
+Uses gseapy pathway libraries (KEGG + Reactome) to build binary pathway
+membership vectors per gene, enabling cosine-similarity-based set scoring
+instead of substring matching.
+
+Mechanism comparison uses sentence-transformers for semantic similarity.
+"""
+
+from __future__ import annotations
+
+import logging
+from functools import lru_cache
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+
+logger = logging.getLogger(__name__)
+
+_PATHWAY_SETS: Optional[Dict[str, List[str]]] = None
+_PATHWAY_NAMES: Optional[List[str]] = None
+_GENE_TO_PATHWAY_IDX: Optional[Dict[str, List[int]]] = None
+_N_PATHWAYS: int = 0
+
+_SENTENCE_MODEL = None
+
+
+def _ensure_pathway_index() -> None:
+    """Lazily build the inverted gene→pathway index on first use."""
+    global _PATHWAY_SETS, _PATHWAY_NAMES, _GENE_TO_PATHWAY_IDX, _N_PATHWAYS
+
+    if _PATHWAY_NAMES is not None:
+        return
+
+    try:
+        import gseapy as gp
+    except ImportError:
+        logger.warning("gseapy not installed; pathway scoring will use fallback.")
+        _PATHWAY_SETS = {}
+        _PATHWAY_NAMES = []
+        _GENE_TO_PATHWAY_IDX = {}
+        _N_PATHWAYS = 0
+        return
+
+    combined: Dict[str, List[str]] = {}
+    for lib_name in ("KEGG_2021_Human", "Reactome_2022"):
+        try:
+            combined.update(gp.get_library(lib_name))
+        except Exception as exc:
+            logger.warning("Failed to load %s: %s", lib_name, exc)
+
+    _PATHWAY_SETS = combined
+    _PATHWAY_NAMES = sorted(combined.keys())
+    _N_PATHWAYS = len(_PATHWAY_NAMES)
+
+    inv: Dict[str, List[int]] = {}
+    for idx, pw_name in enumerate(_PATHWAY_NAMES):
+        for gene in combined[pw_name]:
+            gene_upper = gene.upper().strip()
+            inv.setdefault(gene_upper, []).append(idx)
+
+    _GENE_TO_PATHWAY_IDX = inv
+    logger.info(
+        "Pathway index built: %d pathways, %d genes indexed.",
+        _N_PATHWAYS, len(inv),
+    )
+
+
+def _ensure_sentence_model():
+    """Lazily load the sentence-transformer model."""
+    global _SENTENCE_MODEL
+    if _SENTENCE_MODEL is not None:
+        return
+
+    try:
+        from sentence_transformers import SentenceTransformer
+        _SENTENCE_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
+    except ImportError:
+        logger.warning(
+            "sentence-transformers not installed; mechanism scoring will use fallback."
+        )
+        _SENTENCE_MODEL = None
+
+
+def gene_vector(gene: str) -> np.ndarray:
+    """L2-normalised binary pathway membership vector for *gene*."""
+    _ensure_pathway_index()
+    vec = np.zeros(_N_PATHWAYS, dtype=np.float32)
+    indices = _GENE_TO_PATHWAY_IDX.get(gene.upper().strip(), [])
+    if indices:
+        vec[indices] = 1.0
+    norm = np.linalg.norm(vec)
+    if norm > 0:
+        vec /= norm
+    return vec
+
+
+def pathway_similarity(g1: str, g2: str) -> float:
+    """Cosine similarity between two genes in pathway space."""
+    v1 = gene_vector(g1)
+    v2 = gene_vector(g2)
+    dot = float(np.dot(v1, v2))
+    return max(0.0, min(1.0, dot))
+
+
+def marker_set_score(
+    predicted: List[str],
+    truth: List[str],
+    sigma: float = 0.3,
+) -> float:
+    """Pathway-weighted Gaussian set similarity for marker genes.
+
+    For each true marker, finds the best-matching predicted gene by
+    pathway cosine similarity, then applies a Gaussian kernel:
+        score_i = exp(-d^2 / (2 * sigma^2))  where d = 1 - sim
+    Returns the mean score over all true markers.
+    """
+    if not truth:
+        return 0.0
+    if not predicted:
+        return 0.0
+
+    _ensure_pathway_index()
+
+    if _N_PATHWAYS == 0:
+        return _fallback_marker_score(predicted, truth)
+
+    pred_vecs = [gene_vector(g) for g in predicted]
+    scores: List[float] = []
+
+    for true_gene in truth:
+        tv = gene_vector(true_gene)
+        best_sim = 0.0
+        for pv in pred_vecs:
+            sim = float(np.dot(tv, pv))
+            if sim > best_sim:
+                best_sim = sim
+        d = 1.0 - best_sim
+        scores.append(float(np.exp(-(d ** 2) / (2.0 * sigma ** 2))))
+
+    return sum(scores) / len(scores)
+
+
+def _fallback_marker_score(predicted: List[str], truth: List[str]) -> float:
+    """Exact-match fallback when pathway data is unavailable."""
+    pred_set = {g.upper().strip() for g in predicted}
+    hits = sum(1 for g in truth if g.upper().strip() in pred_set)
+    return hits / len(truth) if truth else 0.0
+
+
+def mechanism_set_score(predicted: List[str], truth: List[str]) -> float:
+    """Sentence-transformer semantic similarity for mechanism strings.
+
+    For each truth mechanism, finds the best-matching predicted mechanism
+    by cosine similarity and returns the mean of best matches.
+    """
+    if not truth:
+        return 0.0
+    if not predicted:
+        return 0.0
+
+    _ensure_sentence_model()
+
+    if _SENTENCE_MODEL is None:
+        return _fallback_mechanism_score(predicted, truth)
+
+    pred_embs = _SENTENCE_MODEL.encode(predicted, convert_to_numpy=True)
+    truth_embs = _SENTENCE_MODEL.encode(truth, convert_to_numpy=True)
+
+    pred_norms = pred_embs / (
+        np.linalg.norm(pred_embs, axis=1, keepdims=True) + 1e-9
+    )
+    truth_norms = truth_embs / (
+        np.linalg.norm(truth_embs, axis=1, keepdims=True) + 1e-9
+    )
+
+    sim_matrix = truth_norms @ pred_norms.T
+    best_per_truth = sim_matrix.max(axis=1)
+    return float(np.mean(np.clip(best_per_truth, 0.0, 1.0)))
+
+
+def _fallback_mechanism_score(predicted: List[str], truth: List[str]) -> float:
+    """Token-overlap fallback when sentence-transformers is unavailable."""
+    scores: List[float] = []
+    for t in truth:
+        t_tokens = set(t.lower().split())
+        best = 0.0
+        for p in predicted:
+            p_tokens = set(p.lower().split())
+            union = t_tokens | p_tokens
+            if union:
+                overlap = len(t_tokens & p_tokens) / len(union)
+                best = max(best, overlap)
+        scores.append(best)
+    return sum(scores) / len(scores) if scores else 0.0
+
+
+def score_pathways(
+    predicted: Dict[str, float],
+    truth: Dict[str, float],
+) -> float:
+    """Score predicted pathway activations against ground truth.
+
+    Uses normalised key matching with activity-level weighting.
+    """
+    if not truth:
+        return 0.0
+    if not predicted:
+        return 0.0
+
+    pred_norm = {k.lower().strip(): v for k, v in predicted.items()}
+    total_weight = 0.0
+    weighted_score = 0.0
+
+    for pw, true_activity in truth.items():
+        pw_key = pw.lower().strip()
+        weight = true_activity
+        total_weight += weight
+        if pw_key in pred_norm:
+            pred_activity = pred_norm[pw_key]
+            diff = abs(pred_activity - true_activity)
+            match_score = max(0.0, 1.0 - diff)
+            weighted_score += weight * match_score
+
+    return weighted_score / total_weight if total_weight > 0 else 0.0
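
The Gaussian kernel in `marker_set_score` above turns a best-match cosine similarity into a per-marker score. A stdlib-only sketch of just that kernel (the helper name `gaussian_set_score` is ours for illustration, not part of the commit):

```python
import math


def gaussian_set_score(best_sims, sigma=0.3):
    """Mean Gaussian-kernel score over best-match similarities.

    Mirrors the kernel used in marker_set_score:
        score_i = exp(-d^2 / (2 * sigma^2))  with d = 1 - sim
    so a perfect match (sim = 1.0) scores 1.0, while an unrelated
    gene (sim = 0.0) scores exp(-1 / (2 * 0.09)) ≈ 0.0039.
    """
    scores = [
        math.exp(-((1.0 - s) ** 2) / (2.0 * sigma ** 2))
        for s in best_sims
    ]
    return sum(scores) / len(scores)
```

With the default `sigma = 0.3`, partial pathway overlap still earns partial credit, but genes with no shared pathways contribute almost nothing to the mean.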
server/hackathon_environment.py CHANGED
@@ -28,7 +28,7 @@ from server.rules.engine import RuleEngine
 from server.rewards.reward import RewardBreakdown, RewardComputer
 from server.simulator.latent_state import FullLatentState
 from server.simulator.noise import NoiseModel
-from server.simulator.transition import ACTION_COSTS, TransitionEngine
+from server.simulator.transition import ACTION_COSTS, TransitionEngine, compute_action_cost
 from server.tasks.generator import TaskGenerator

@@ -70,8 +70,8 @@ class BioExperimentEnvironment(Environment):

     # ── Environment interface ───────────────────────────────────────────

-    def reset(self) -> ExperimentObservation:
-        seed = hash(uuid4()) % (2**31)
+    def reset(self, seed: Optional[int] = None) -> ExperimentObservation:
+        seed = seed if seed is not None else hash(uuid4()) % (2**31)
         self._noise.reseed(seed)
         self._state = State(episode_id=str(uuid4()), step_count=0)

@@ -116,7 +116,7 @@
             action, prev_state, self._latent, result.output, hard_v, soft_v,
         )

-        cost_budget, cost_time = ACTION_COSTS.get(action.action_type, (0, 0))
+        cost_budget, cost_time = compute_action_cost(action)
         self._history.append(PipelineStepRecord(
             step_index=self._state.step_count,
             action_type=action.action_type,

@@ -143,7 +143,11 @@
         terminal_rb = RewardBreakdown()
         if done:
             terminal_rb = self._rewards.terminal_reward(
-                self._latent, self._conclusions, self._task.success_criteria,
+                self._latent,
+                self._conclusions,
+                self._task.success_criteria,
+                discovered_markers=self._discovered_markers,
+                candidate_mechanisms=self._candidate_mechanisms,
             )

         total_reward = step_rb.total + terminal_rb.total

@@ -158,6 +162,7 @@
             latest_output=result.output,
             rule_violations=hard_v + soft_v,
             reward_breakdown=breakdown,
+            metadata_extra={"reward_breakdown": breakdown},
         )

     @property

@@ -179,10 +184,18 @@
         latest_output: Optional[IntermediateOutput] = None,
         rule_violations: Optional[List[str]] = None,
         reward_breakdown: Optional[Dict[str, float]] = None,
+        metadata_extra: Optional[Dict[str, Any]] = None,
     ) -> ExperimentObservation:
         assert self._task is not None
         assert self._latent is not None
         res = self._latent.resources
+        meta: Dict[str, Any] = {
+            "episode_id": self._state.episode_id,
+            "step": self._state.step_count,
+            "cumulative_reward": self._cumulative_reward,
+        }
+        if metadata_extra:
+            meta.update(metadata_extra)
         return ExperimentObservation(
             task=self._task,
             step_index=self._state.step_count,

@@ -205,14 +218,10 @@
             subagent_outputs=list(self._subagent_outputs),
             conclusions=list(self._conclusions),
             rule_violations=rule_violations or [],
-            step_reward_breakdown=reward_breakdown or {},
+            step_reward_breakdown={},
             done=done,
             reward=reward,
-            metadata={
-                "episode_id": self._state.episode_id,
-                "step": self._state.step_count,
-                "cumulative_reward": self._cumulative_reward,
-            },
+            metadata=meta,
         )

     def _compute_uncertainty_summary(self) -> Dict[str, float]:

@@ -228,12 +237,22 @@
     ) -> None:
         if action.action_type == ActionType.MARKER_SELECTION:
             markers = output.data.get("markers", [])
-            self._discovered_markers.extend(markers)
+            existing = set(self._discovered_markers)
+            for m in markers:
+                if m not in existing:
+                    self._discovered_markers.append(m)
+                    existing.add(m)
         if action.action_type == ActionType.REGULATORY_NETWORK_INFERENCE:
             regs = output.data.get("top_regulators", [])
-            self._candidate_mechanisms.extend(regs)
+            existing = set(self._candidate_mechanisms)
+            for r in regs:
+                if r not in existing:
+                    self._candidate_mechanisms.append(r)
+                    existing.add(r)
         if action.action_type == ActionType.PATHWAY_ENRICHMENT:
             pathways = output.data.get("top_pathways", [])
-            self._candidate_mechanisms.extend(
-                [p["pathway"] for p in pathways if isinstance(p, dict)]
-            )
+            existing = set(self._candidate_mechanisms)
+            for p in pathways:
+                if isinstance(p, dict) and p["pathway"] not in existing:
+                    self._candidate_mechanisms.append(p["pathway"])
+                    existing.add(p["pathway"])
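
The three branches of `_record_discoveries` in this file share one pattern: append new items while preserving first-seen order and skipping duplicates. As a standalone sketch (the helper name `extend_unique` is illustrative, not in the repo):

```python
def extend_unique(seq, new_items):
    """Append items from new_items that are not already in seq.

    Preserves first-seen order, unlike converting to a set, which
    matters when downstream scoring treats earlier discoveries as
    higher-priority. O(1) membership checks via a shadow set.
    """
    existing = set(seq)
    for item in new_items:
        if item not in existing:
            seq.append(item)
            existing.add(item)
    return seq
```

Using a plain `set` for `self._discovered_markers` would lose discovery order, which is why the diff keeps a list plus a shadow set.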
server/requirements.txt CHANGED
@@ -1,6 +1,9 @@
-openenv[core]>=0.2.0
+openenv-core[core]>=0.2.0
 fastapi>=0.115.0
 uvicorn>=0.24.0
-
-
-
+numpy>=1.24.0
+scipy>=1.10.0
+pydantic>=2.0.0
+gseapy>=1.1.0
+sentence-transformers>=3.0.0
+scikit-learn>=1.4.0
server/rewards/reward.py CHANGED
@@ -15,7 +15,7 @@ Potential-based shaping

 The final step reward is:
     R_t = r_validity + r_ordering + r_info_gain + r_efficiency
-        + r_novelty + r_penalty + γ[φ(s_{t+1}) − φ(s_t)]
+        + r_novelty + r_penalty + [φ(s_{t+1}) − φ(s_t)]

 The terminal reward adds:
     R_T += r_terminal

@@ -32,9 +32,15 @@ from models import (
     ExperimentAction,
     IntermediateOutput,
     META_ACTIONS,
+    TOOL_REGISTRY,
     WET_LAB_ACTIONS,
 )

+from server.biology.gene_index import (
+    marker_set_score,
+    mechanism_set_score,
+    score_pathways,
+)
 from server.simulator.latent_state import FullLatentState

@@ -84,20 +90,16 @@ class RewardComputer:

     Parameters
     ----------
-    gamma : float
-        Discount factor for potential-based shaping (default 0.99).
     efficiency_weight : float
         Relative importance of resource efficiency.
     """

     def __init__(
         self,
-        gamma: float = 0.99,
         efficiency_weight: float = 0.3,
         info_gain_weight: float = 0.4,
         validity_weight: float = 0.3,
     ):
-        self.gamma = gamma
         self.w_eff = efficiency_weight
         self.w_ig = info_gain_weight
         self.w_val = validity_weight

@@ -124,11 +126,19 @@

         rb.validity = self.w_val * (1.0 if output.success else 0.0)

-        # ordering bonus: +0.2 if the step was a natural next step
-        rb.ordering = 0.2 * self._ordering_score(action, prev_state)
+        ordering_score = self._ordering_score(action, prev_state)
+        rb.ordering = 0.2 * ordering_score
+        if ordering_score < 0:
+            rb.penalty += ordering_score * 0.3

         # information gain proxy: quality × (1 - uncertainty)
         rb.info_gain = self.w_ig * output.quality_score * (1.0 - output.uncertainty)
+        if action.action_type in META_ACTIONS and not (
+            prev_state.progress.de_performed
+            or prev_state.progress.cells_clustered
+        ):
+            # Meta actions before substantive analysis should not dominate reward.
+            rb.info_gain *= 0.2

         # efficiency: normalised cost relative to budget
         budget_frac = (

@@ -141,13 +151,25 @@
         if not soft_violations:
             rb.novelty = 0.1

+        # tool-modality fit bonus/penalty
+        tool_fit = self._tool_fit_score(action, prev_state)
+        rb.components["tool_fit"] = tool_fit
+        rb.validity += 0.15 * tool_fit
+
         # penalties
         rb.penalty = -0.15 * len(soft_violations)
-
-        # potential-based shaping
+        if action.action_type in META_ACTIONS and not (
+            prev_state.progress.de_performed
+            or prev_state.progress.cells_clustered
+        ):
+            rb.penalty -= 0.25
+            rb.components["premature_meta_action_penalty"] = -0.25
+
+        # potential-based shaping (γ=1 so it doesn't depend on the
+        # training algorithm's discount factor)
         phi_prev = self._potential(prev_state)
         phi_next = self._potential(next_state)
-        rb.shaping = self.gamma * phi_next - phi_prev
+        rb.shaping = phi_next - phi_prev

         return rb

@@ -158,8 +180,12 @@
         state: FullLatentState,
         conclusions: List[ConclusionClaim],
         task_success_criteria: List[str],
+        discovered_markers: Optional[List[str]] = None,
+        candidate_mechanisms: Optional[List[str]] = None,
     ) -> RewardBreakdown:
         rb = RewardBreakdown()
+        discovered_markers = discovered_markers or []
+        candidate_mechanisms = candidate_mechanisms or []

         # pipeline completeness (0-1)
         completeness = self._completeness(state)

@@ -183,11 +209,22 @@
         overconf = self._overconfidence_penalty(state, conclusions)
         rb.components["overconfidence_penalty"] = overconf

+        discovery_alignment = self._discovery_alignment(
+            state,
+            discovered_markers,
+            candidate_mechanisms,
+        )
+        discovery_error_penalty = -2.5 * (1.0 - discovery_alignment)
+        rb.components["discovery_alignment"] = discovery_alignment
+        rb.components["discovery_error_penalty"] = discovery_error_penalty
+
+        eff_bonus = (budget_eff + time_eff) / 2.0 if completeness >= 0.3 else 0.0
         rb.terminal = (
             3.0 * completeness
             + 4.0 * calibration
-            + 1.0 * (budget_eff + time_eff) / 2.0
+            + 1.0 * eff_bonus
             + overconf
+            + discovery_error_penalty
         )
         return rb

@@ -196,7 +233,7 @@
     def _ordering_score(
         self, action: ExperimentAction, s: FullLatentState
     ) -> float:
-        """Heuristic: 1.0 if this step naturally follows the current progress."""
+        """Heuristic: 1.0 if natural next, 0.3 if acceptable, -1.0 if premature."""
        at = action.action_type
        p = s.progress
        NATURAL_NEXT = {

@@ -215,10 +252,26 @@
                 p.de_performed or p.cells_clustered
             ) and not p.conclusion_reached,
         }
-        return 1.0 if NATURAL_NEXT.get(at, False) else 0.3
+        if NATURAL_NEXT.get(at, False):
+            return 1.0
+
+        has_evidence = any([
+            p.cells_clustered, p.de_performed, p.trajectories_inferred,
+            p.pathways_analyzed, p.networks_inferred, p.markers_discovered,
+        ])
+        if at in META_ACTIONS and not has_evidence:
+            return -1.0
+
+        return 0.3

     def _potential(self, s: FullLatentState) -> float:
-        """Progress potential φ(s) — counts completed milestones."""
+        """Progress potential φ(s) — counts completed milestones.
+
+        Returns 0.0 at terminal states so that the shaping signal
+        telescopes correctly over the episode.
+        """
+        if s.progress.conclusion_reached:
+            return 0.0
         p = s.progress
         milestones = [
             p.samples_collected,

@@ -252,9 +305,38 @@
     def _calibration(
         self, s: FullLatentState, conclusions: List[ConclusionClaim]
     ) -> float:
+        """Structured set-similarity calibration against hidden ground truth.
+
+        Uses pathway-weighted Gaussian similarity for markers, semantic
+        similarity for mechanisms, and activity-weighted matching for pathways.
+        Falls back to legacy substring matching when structured fields are empty.
+        """
         if not conclusions:
             return 0.0

+        pred_markers = [g for c in conclusions for g in c.top_markers]
+        pred_mechs = [m for c in conclusions for m in c.causal_mechanisms]
+        pred_pathways = {
+            p: v for c in conclusions for p, v in c.predicted_pathways.items()
+        }
+
+        has_structured = bool(pred_markers or pred_mechs or pred_pathways)
+
+        if has_structured:
+            m_score = marker_set_score(pred_markers, s.biology.true_markers)
+            mech_score = mechanism_set_score(
+                pred_mechs, s.biology.causal_mechanisms
+            )
+            pw_score = score_pathways(pred_pathways, s.biology.true_pathways)
+            return 0.50 * m_score + 0.35 * mech_score + 0.15 * pw_score
+
+        return self._legacy_calibration(s, conclusions)
+
+    @staticmethod
+    def _legacy_calibration(
+        s: FullLatentState, conclusions: List[ConclusionClaim]
+    ) -> float:
+        """Substring-based calibration kept for backward compatibility."""
         true_mechanisms = set(s.biology.causal_mechanisms)
         true_markers = set(s.biology.true_markers)
         score = 0.0

@@ -270,16 +352,121 @@
             score -= 0.3
         return max(0.0, min(1.0, score / max(n, 1)))

+    _METHOD_TO_TOOL: Dict[str, str] = {
+        "scanpy.pp.calculate_qc_metrics": "Scanpy",
+        "scanpy.pp.filter_cells": "Scanpy",
+        "scanpy.pp.filter_genes": "Scanpy",
+        "scanpy.pp.normalize_total": "Scanpy",
+        "scanpy.pp.log1p": "Scanpy",
+        "scanpy.pp.highly_variable_genes": "Scanpy",
+        "scanpy.pp.neighbors": "Scanpy",
+        "scanpy.tl.leiden": "Leiden",
+        "scanpy.tl.louvain": "Louvain",
+        "scanpy.tl.rank_genes_groups": "Scanpy",
+        "scanpy.tl.paga": "PAGA",
+        "scanpy.tl.umap": "UMAP",
+        "gseapy.prerank": "Scanpy",
+        "gseapy.gsea": "Scanpy",
+        "10x_chromium": "CellRanger",
+        "NovaSeq": "CellRanger",
+    }
+
+    @staticmethod
+    def _tool_fit_score(
+        action: ExperimentAction, s: FullLatentState
+    ) -> float:
+        """Score how well the chosen tool matches the task modality.
+
+        Returns +1.0 for a perfect match, 0.0 if no tool specified,
+        -1.0 for a known tool used on an incompatible modality.
+        """
+        method = action.method
+        if not method:
+            return 0.0
+        resolved = RewardComputer._METHOD_TO_TOOL.get(method, method)
+        tool_spec = TOOL_REGISTRY.get(resolved)
+        if tool_spec is None:
+            return -0.5
+        modality = getattr(s, "task_modality", None)
+        if not modality or not tool_spec.modalities:
+            return 0.0
+        if modality in tool_spec.modalities:
+            return 1.0
+        return -1.0
+
     def _overconfidence_penalty(
         self, s: FullLatentState, conclusions: List[ConclusionClaim]
     ) -> float:
-        """Penalise high-confidence claims that disagree with ground truth."""
+        """Penalise high-confidence claims that disagree with ground truth.
+
+        Checks structured fields (top_markers, causal_mechanisms) first;
+        falls back to claim substring matching for backward compatibility.
+        """
         penalty = 0.0
-        true_set = set(
-            m.lower() for m in s.biology.causal_mechanisms + s.biology.true_markers
-        )
+        true_markers_lower = {m.lower() for m in s.biology.true_markers}
+        true_mechs_lower = {m.lower() for m in s.biology.causal_mechanisms}
+        true_set = true_markers_lower | true_mechs_lower
+
         for c in conclusions:
-            is_correct = any(t in c.claim.lower() for t in true_set)
-            if c.confidence > 0.8 and not is_correct:
+            if c.confidence <= 0.8:
+                continue
+
+            has_structured = bool(c.top_markers or c.causal_mechanisms)
+            if has_structured:
+                marker_hit = any(
+                    g.upper().strip() in {m.upper() for m in s.biology.true_markers}
+                    for g in c.top_markers
+                )
+                mech_hit = any(
+                    any(kw in m.lower() for kw in t.lower().split())
+                    for m in c.causal_mechanisms
+                    for t in s.biology.causal_mechanisms
+                )
+                is_correct = marker_hit or mech_hit
+            else:
+                is_correct = any(t in c.claim.lower() for t in true_set)
+
+            if not is_correct:
                 penalty -= 0.5 * c.confidence
+
         return penalty
+
+    def _discovery_alignment(
+        self,
+        s: FullLatentState,
+        discovered_markers: List[str],
+        candidate_mechanisms: List[str],
+    ) -> float:
+        """Symmetric end-of-episode similarity for discovered biology.
+
+        Forward scoring measures recall against hidden truth. Reverse scoring
+        measures how well the agent's discoveries map back onto real biology,
+        which penalizes extra hallucinated markers or mechanisms.
+        """
+        components: List[float] = []
+
+        if s.biology.true_markers or discovered_markers:
+            marker_recall = marker_set_score(
+                discovered_markers,
+                s.biology.true_markers,
+            )
+            marker_precision = marker_set_score(
+                s.biology.true_markers,
+                discovered_markers,
+            )
+            components.append((marker_recall + marker_precision) / 2.0)
+
+        if s.biology.causal_mechanisms or candidate_mechanisms:
+            mechanism_recall = mechanism_set_score(
+                candidate_mechanisms,
+                s.biology.causal_mechanisms,
+            )
+            mechanism_precision = mechanism_set_score(
+                s.biology.causal_mechanisms,
+                candidate_mechanisms,
+            )
+            components.append((mechanism_recall + mechanism_precision) / 2.0)
+
+        if not components:
+            return 1.0
+        return sum(components) / len(components)
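
The diff's shaping change (dropping `self.gamma` and forcing `_potential` to 0 at terminal states) relies on the telescoping property of potential differences. A minimal sketch of why that works (the helper name `shaping_rewards` is ours, not part of the commit):

```python
def shaping_rewards(potentials):
    """Per-step shaping terms φ(s_{t+1}) − φ(s_t) along a trajectory.

    With γ = 1 the sum telescopes: it equals φ(s_T) − φ(s_0),
    so if the terminal potential is forced to 0, the total shaping
    added over an episode is exactly -φ(s_0), independent of the
    path the agent took through the milestones.
    """
    return [b - a for a, b in zip(potentials, potentials[1:])]


# Hypothetical milestone potentials for one episode; the final entry
# is 0.0 because _potential returns 0 once conclusion_reached is set.
phi = [0.0, 1.0, 3.0, 5.0, 0.0]
total_shaping = sum(shaping_rewards(phi))
```

Because the total depends only on the endpoints, the shaping term cannot change which policies are optimal; it only densifies the learning signal along the way.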
server/rules/engine.py CHANGED
@@ -10,7 +10,7 @@ from dataclasses import dataclass
10
  from enum import Enum
11
  from typing import List
12
 
13
- from models import ActionType, ExperimentAction
14
 
15
  from server.simulator.latent_state import FullLatentState
16
 
@@ -10,6 +10,7 @@
 from enum import Enum
 from typing import List
 
+from models import ActionType, ExperimentAction, TOOL_REGISTRY
 
 from server.simulator.latent_state import FullLatentState
 
@@ -32,6 +32,19 @@ class RuleEngine:
     latent state before each action is applied.
     """
 
+    @staticmethod
+    def _has_analysis_evidence(s: FullLatentState) -> bool:
+        p = s.progress
+        return any([
+            p.cells_clustered,
+            p.de_performed,
+            p.trajectories_inferred,
+            p.pathways_analyzed,
+            p.networks_inferred,
+            p.markers_discovered,
+            p.markers_validated,
+        ])
+
     def check(
         self, action: ExperimentAction, state: FullLatentState
     ) -> List[RuleViolation]:
@@ -40,6 +53,7 @@ class RuleEngine:
         violations.extend(self._check_resource_constraints(action, state))
         violations.extend(self._check_redundancy(action, state))
         violations.extend(self._check_causal_validity(action, state))
+        violations.extend(self._check_tool_compatibility(action, state))
        return violations
 
     def hard_violations(self, violations: List[RuleViolation]) -> List[str]:
@@ -106,6 +120,9 @@ class RuleEngine:
             ActionType.CULTURE_CELLS: [
                 ("samples_collected", "Cannot culture without samples"),
             ],
+            ActionType.SYNTHESIZE_CONCLUSION: [
+                ("data_normalized", "Cannot synthesize conclusions before data normalization"),
+            ],
         }
 
         for flag, msg in REQUIRES.get(at, []):
@@ -127,22 +144,22 @@ class RuleEngine:
             vs.append(RuleViolation(
                 rule_id="budget_exhausted",
                 severity=Severity.HARD,
-                message="Budget exhausted no further actions possible",
+                message="Budget exhausted - no further actions possible",
             ))
         if s.resources.time_exhausted:
             vs.append(RuleViolation(
                 rule_id="time_exhausted",
                 severity=Severity.HARD,
-                message="Time limit reached no further actions possible",
+                message="Time limit reached - no further actions possible",
             ))
 
         remaining = s.resources.budget_remaining
-        from server.simulator.transition import ACTION_COSTS
-        cost, _ = ACTION_COSTS.get(action.action_type, (0, 0))
+        from server.simulator.transition import compute_action_cost
+        cost, _ = compute_action_cost(action)
         if cost > remaining and remaining > 0:
             vs.append(RuleViolation(
                 rule_id="budget_insufficient",
-                severity=Severity.SOFT,
+                severity=Severity.HARD,
                 message=f"Action costs ${cost:,.0f} but only ${remaining:,.0f} remains",
             ))
         return vs
@@ -163,13 +180,23 @@ class RuleEngine:
             ActionType.RUN_QC: "qc_performed",
             ActionType.FILTER_DATA: "data_filtered",
             ActionType.NORMALIZE_DATA: "data_normalized",
+            ActionType.CLUSTER_CELLS: "cells_clustered",
+            ActionType.DIFFERENTIAL_EXPRESSION: "de_performed",
+            ActionType.TRAJECTORY_ANALYSIS: "trajectories_inferred",
+            ActionType.PATHWAY_ENRICHMENT: "pathways_analyzed",
+            ActionType.REGULATORY_NETWORK_INFERENCE: "networks_inferred",
+            ActionType.MARKER_SELECTION: "markers_discovered",
+            ActionType.VALIDATE_MARKER: "markers_validated",
+            ActionType.DESIGN_FOLLOWUP: "followup_designed",
+            ActionType.REQUEST_SUBAGENT_REVIEW: "subagent_review_requested",
+            ActionType.SYNTHESIZE_CONCLUSION: "conclusion_reached",
         }
         flag = REDUNDANT.get(at)
         if flag and getattr(p, flag, False):
             vs.append(RuleViolation(
                 rule_id=f"redundant_{at.value}",
-                severity=Severity.SOFT,
-                message=f"Step '{at.value}' already completed — redundant action",
+                severity=Severity.HARD,
+                message=f"Step '{at.value}' already completed — redundant action blocked",
             ))
         return vs
 
@@ -179,12 +206,36 @@ class RuleEngine:
         self, action: ExperimentAction, s: FullLatentState
     ) -> List[RuleViolation]:
         vs: List[RuleViolation] = []
+        has_analysis_evidence = self._has_analysis_evidence(s)
+
+        if action.action_type == ActionType.DESIGN_FOLLOWUP:
+            if not has_analysis_evidence:
+                vs.append(RuleViolation(
+                    rule_id="premature_followup_design",
+                    severity=Severity.HARD,
+                    message=(
+                        "Follow-up design without prior analysis is blocked; "
+                        "complete wet-lab and computational steps first"
+                    ),
+                ))
+
+        if action.action_type == ActionType.REQUEST_SUBAGENT_REVIEW:
+            if not has_analysis_evidence:
+                vs.append(RuleViolation(
+                    rule_id="premature_subagent_review",
+                    severity=Severity.HARD,
+                    message=(
+                        "Subagent review without prior analysis is blocked; "
+                        "generate evidence first"
+                    ),
+                ))
+
         if action.action_type == ActionType.SYNTHESIZE_CONCLUSION:
             if not s.progress.de_performed and not s.progress.cells_clustered:
                 vs.append(RuleViolation(
                     rule_id="premature_conclusion",
-                    severity=Severity.SOFT,
-                    message="Synthesising conclusion without substantive analysis",
+                    severity=Severity.HARD,
+                    message="Cannot synthesise conclusion without substantive analysis",
                 ))
 
         claims = action.parameters.get("claims", [])
@@ -206,3 +257,72 @@ class RuleEngine:
                 message="Pathway enrichment without DE may yield unreliable results",
             ))
         return vs
+
+    # ── tool / modality compatibility ────────────────────────────────────
+
+    _KNOWN_METHODS = {
+        "scanpy.pp.calculate_qc_metrics", "scanpy.pp.filter_cells",
+        "scanpy.pp.filter_genes", "scanpy.pp.normalize_total",
+        "scanpy.pp.log1p", "scanpy.pp.highly_variable_genes",
+        "scanpy.pp.neighbors", "scanpy.tl.leiden", "scanpy.tl.louvain",
+        "scanpy.tl.rank_genes_groups", "scanpy.tl.paga", "scanpy.tl.umap",
+        "gseapy.prerank", "gseapy.gsea", "10x_chromium", "NovaSeq",
+    }
+    _METHOD_TO_TOOL = {
+        "scanpy.pp.calculate_qc_metrics": "Scanpy",
+        "scanpy.pp.filter_cells": "Scanpy",
+        "scanpy.pp.filter_genes": "Scanpy",
+        "scanpy.pp.normalize_total": "Scanpy",
+        "scanpy.pp.log1p": "Scanpy",
+        "scanpy.pp.highly_variable_genes": "Scanpy",
+        "scanpy.pp.neighbors": "Scanpy",
+        "scanpy.tl.leiden": "Leiden",
+        "scanpy.tl.louvain": "Louvain",
+        "scanpy.tl.rank_genes_groups": "Scanpy",
+        "scanpy.tl.paga": "PAGA",
+        "scanpy.tl.umap": "UMAP",
+        "gseapy.prerank": "Scanpy",
+        "gseapy.gsea": "Scanpy",
+        "10x_chromium": "CellRanger",
+        "NovaSeq": "CellRanger",
+    }
+
+    def _check_tool_compatibility(
+        self, action: ExperimentAction, s: FullLatentState
+    ) -> List[RuleViolation]:
+        """Warn when the chosen tool is incompatible with the task modality."""
+        vs: List[RuleViolation] = []
+        method = action.method
+        if not method:
+            return vs
+
+        resolved = self._METHOD_TO_TOOL.get(method, method)
+        tool_spec = TOOL_REGISTRY.get(resolved)
+        if tool_spec is None and method not in self._KNOWN_METHODS:
+            vs.append(RuleViolation(
+                rule_id="unknown_tool",
+                severity=Severity.SOFT,
+                message=f"Tool '{method}' is not in the registry — results may be unreliable",
+            ))
+            return vs
+        if tool_spec is None:
+            return vs
+
+        # Check modality compatibility (modality lives on the task, which is
+        # stored in the latent state's associated TaskSpec — but the latent
+        # state doesn't carry the TaskSpec directly. We can still check via
+        # the action's own context or fall back gracefully).
+        task_modality = getattr(s, "task_modality", None)
+        if task_modality and tool_spec.modalities:
+            if task_modality not in tool_spec.modalities:
+                vs.append(RuleViolation(
+                    rule_id="tool_modality_mismatch",
+                    severity=Severity.SOFT,
+                    message=(
+                        f"Tool '{method}' is designed for "
+                        f"{', '.join(tool_spec.modalities)} but task modality "
+                        f"is '{task_modality}'"
+                    ),
+                ))
+
+        return vs
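
The prerequisite and redundancy gates in this diff reduce to dictionary lookups against boolean progress flags. A minimal standalone sketch of that pattern, using plain `Enum`/`dataclass` stand-ins (not the environment's real `ActionType`, `ExperimentProgress`, or `RuleViolation` models):

```python
from dataclasses import dataclass
from enum import Enum


class ActionType(Enum):
    RUN_QC = "run_qc"
    NORMALIZE_DATA = "normalize_data"
    CLUSTER_CELLS = "cluster_cells"


@dataclass
class Progress:
    qc_performed: bool = False
    data_normalized: bool = False
    cells_clustered: bool = False


# Maps an action to the progress flag that marks it "already done".
REDUNDANT = {
    ActionType.RUN_QC: "qc_performed",
    ActionType.NORMALIZE_DATA: "data_normalized",
    ActionType.CLUSTER_CELLS: "cells_clustered",
}


def redundancy_violations(action_type: ActionType, progress: Progress) -> list:
    """Return hard-violation rule ids for steps already completed."""
    flag = REDUNDANT.get(action_type)
    if flag and getattr(progress, flag, False):
        return [f"redundant_{action_type.value}"]
    return []
```

The same shape serves the prerequisite check in reverse: there, the flag must already be `True` before the action is allowed.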
server/simulator/latent_state.py CHANGED
@@ -88,8 +88,11 @@ class ExperimentProgress(BaseModel):
     networks_inferred: bool = False
     markers_discovered: bool = False
     markers_validated: bool = False
+    followup_designed: bool = False
+    subagent_review_requested: bool = False
     conclusion_reached: bool = False
 
+    n_cells_sequenced: Optional[int] = None
     n_cells_after_filter: Optional[int] = None
     n_clusters_found: Optional[int] = None
     n_de_genes_found: Optional[int] = None
@@ -139,5 +142,12 @@ class FullLatentState(BaseModel):
     mechanism_confidence: Dict[str, float] = Field(default_factory=dict)
     discovered_de_genes: List[str] = Field(default_factory=list)
     discovered_clusters: List[str] = Field(default_factory=list)
+    task_modality: str = "scRNA-seq"
     step_count: int = 0
     rng_seed: int = 42
+
+    # Transient fields for passing sampled values from the transition engine
+    # to the output generator within a single step (not serialized).
+    last_retain_frac: Optional[float] = Field(None, exclude=True)
+    last_n_clusters: Optional[int] = Field(None, exclude=True)
+    last_perturbation_efficiency: Optional[float] = Field(None, exclude=True)
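
The transient `last_*` fields rely on pydantic's `Field(..., exclude=True)` to stay out of serialized state. A stdlib-only sketch of the same export behaviour, with a `dataclass` stand-in for the real `FullLatentState` and an explicit transient-field set (names mirror the diff but the helper is illustrative):

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional

# Fields sampled during one step and never exposed or persisted.
TRANSIENT = {"last_retain_frac", "last_n_clusters", "last_perturbation_efficiency"}


@dataclass
class LatentStateSketch:
    task_modality: str = "scRNA-seq"
    step_count: int = 0
    last_retain_frac: Optional[float] = None
    last_n_clusters: Optional[int] = None
    last_perturbation_efficiency: Optional[float] = None


def serialize(state: LatentStateSketch) -> Dict[str, Any]:
    """Export persistent fields only, mimicking Field(..., exclude=True)."""
    return {
        f.name: getattr(state, f.name)
        for f in fields(state)
        if f.name not in TRANSIENT
    }
```

This keeps per-step samples consistent between the transition engine and the output generator without leaking them into observations or checkpoints.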
server/simulator/noise.py CHANGED
@@ -30,7 +30,11 @@ class NoiseModel:
     ) -> Dict[str, float]:
         noisy: Dict[str, float] = {}
         for gene, value in true_values.items():
-            if self.rng.random() < dropout_rate:
+            # Dropout probability is inversely proportional to expression
+            # magnitude: lowly expressed genes drop out much more readily,
+            # matching the zero-inflation pattern in real scRNA-seq data.
+            p_drop = dropout_rate / (1.0 + abs(value))
+            if self.rng.random() < p_drop:
                 noisy[gene] = 0.0
             else:
                 sigma = noise_level * abs(value) + 0.1
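
The new dropout rule is a one-line function of expression magnitude. A self-contained sketch, substituting the stdlib `random.Random` for the simulator's numpy RNG (an assumption made for self-containment):

```python
import random
from typing import Dict


def dropout_probability(dropout_rate: float, value: float) -> float:
    # Higher expression -> lower dropout chance, mimicking zero inflation.
    return dropout_rate / (1.0 + abs(value))


def apply_dropout(true_values: Dict[str, float], dropout_rate: float,
                  rng: random.Random) -> Dict[str, float]:
    """Zero out each gene with an expression-dependent probability."""
    return {
        gene: 0.0 if rng.random() < dropout_probability(dropout_rate, v) else v
        for gene, v in true_values.items()
    }
```

With a base rate of 0.3, a silent gene (value 0.0) drops out with probability 0.3, while a strongly expressed gene (value 2.0) drops out with probability 0.1.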
server/simulator/output_generator.py CHANGED
@@ -14,6 +14,15 @@ from models import (
 from .latent_state import FullLatentState
 from .noise import NoiseModel
 
+# Pool of common transcription factors used to generate realistic false-positive
+# regulators, so the agent cannot trivially distinguish true vs. false hits by
+# gene-name format alone.
+_NOISE_TFS: List[str] = [
+    "NR3C1", "KLF4", "EGR1", "IRF1", "FOSL2", "JUN", "FOS", "ATF3",
+    "NFKB1", "RELA", "SP1", "MYC", "MAX", "E2F1", "CTCF", "YY1",
+    "TP53", "STAT5A", "SMAD3", "TCF7L2", "NFE2L2", "HIF1A", "CREB1",
+]
+
 
 class OutputGenerator:
     """Creates structured ``IntermediateOutput`` objects conditioned on the
@@ -91,7 +100,13 @@ class OutputGenerator:
         self, action: ExperimentAction, s: FullLatentState, idx: int
     ) -> IntermediateOutput:
         days = action.parameters.get("days", 7)
-        viability = self.noise.sample_qc_metric(0.92, 0.05, 0.5, 1.0)
+        # Viability decays with culture duration: each day adds ~0.5%
+        # cumulative stress, reflecting senescence, media depletion, and
+        # passaging artefacts common in primary cell cultures.
+        decay = 0.005 * days
+        viability = self.noise.sample_qc_metric(
+            max(0.50, 0.95 - decay), 0.05, 0.30, 1.0
+        )
         return IntermediateOutput(
             output_type=OutputType.CULTURE_RESULT,
             step_index=idx,
@@ -101,20 +116,54 @@
             artifacts_available=["cultured_cells"],
         )
 
-    def _perturb(
+    def _perturb_gene(
         self, action: ExperimentAction, s: FullLatentState, idx: int
     ) -> IntermediateOutput:
+        """Genetic perturbation (CRISPR/RNAi): high on-target efficiency,
+        binary effect, non-trivial off-target risk."""
         target = action.parameters.get("target", "unknown")
-        efficiency = self.noise.sample_qc_metric(0.75, 0.15, 0.0, 1.0)
+        efficiency = s.last_perturbation_efficiency if s.last_perturbation_efficiency is not None else self.noise.sample_qc_metric(0.80, 0.12, 0.0, 1.0)
+        off_target_risk = self.noise.sample_qc_metric(0.10, 0.05, 0.0, 0.5)
         return IntermediateOutput(
             output_type=OutputType.PERTURBATION_RESULT,
             step_index=idx,
             quality_score=efficiency,
-            summary=f"Perturbation of {target} (efficiency={efficiency:.2f})",
+            summary=(
+                f"Genetic perturbation of {target} "
+                f"(efficiency={efficiency:.2f}, off-target risk={off_target_risk:.2f})"
+            ),
+            data={
+                "target": target,
+                "efficiency": efficiency,
+                "type": action.action_type.value,
+                "off_target_risk": off_target_risk,
+            },
+            artifacts_available=["perturbed_cells"],
+        )
+
+    def _perturb_compound(
+        self, action: ExperimentAction, s: FullLatentState, idx: int
+    ) -> IntermediateOutput:
+        """Small-molecule perturbation: dose-dependent, partial on-target
+        activity, systemic effects possible."""
+        target = action.parameters.get("target", "unknown")
+        dose_um = action.parameters.get("dose_uM", 1.0)
+        efficiency = s.last_perturbation_efficiency if s.last_perturbation_efficiency is not None else self.noise.sample_qc_metric(0.70, 0.15, 0.0, 1.0)
+        on_target_frac = self.noise.sample_qc_metric(0.75, 0.10, 0.0, 1.0)
+        return IntermediateOutput(
+            output_type=OutputType.PERTURBATION_RESULT,
+            step_index=idx,
+            quality_score=efficiency * on_target_frac,
+            summary=(
+                f"Compound perturbation targeting {target} at {dose_um} µM "
+                f"(efficiency={efficiency:.2f}, on-target={on_target_frac:.2f})"
+            ),
             data={
                 "target": target,
                 "efficiency": efficiency,
                 "type": action.action_type.value,
+                "dose_uM": dose_um,
+                "on_target_fraction": on_target_frac,
             },
             artifacts_available=["perturbed_cells"],
         )
@@ -122,11 +171,18 @@
     def _sequence_cells(
         self, action: ExperimentAction, s: FullLatentState, idx: int
     ) -> IntermediateOutput:
+        import math
         depth = s.technical.sequencing_depth_factor
-        n_cells = self.noise.sample_count(
+        n_cells = s.progress.n_cells_sequenced or self.noise.sample_count(
             s.biology.n_true_cells * s.technical.capture_efficiency
         )
-        n_genes = self.noise.sample_count(18_000)
+        # Gene detection saturates with sequencing depth: follows a
+        # 1 - exp(-k) saturation curve, scaled by library complexity.
+        max_genes = 20_000
+        saturation_arg = depth * s.technical.library_complexity * 0.8
+        n_genes = self.noise.sample_count(
+            int(max_genes * (1.0 - math.exp(-saturation_arg)))
+        )
         median_umi = self.noise.sample_count(int(3000 * depth))
         quality = self.noise.quality_degradation(
             s.technical.sample_quality,
@@ -157,7 +213,18 @@
         doublet_frac = self.noise.sample_qc_metric(
             s.technical.doublet_rate, 0.01, 0.0, 0.2
         )
-        mito_frac = self.noise.sample_qc_metric(0.05, 0.02, 0.0, 0.3)
+        # Mitochondrial fraction reflects cellular stress: activated,
+        # inflammatory, or pro-fibrotic populations have elevated mito
+        # transcription compared to quiescent/resting cells.
+        _stressed_states = {"activated", "stressed", "pro-fibrotic", "inflammatory"}
+        has_stressed_cells = any(
+            p.state in _stressed_states for p in s.biology.cell_populations
+        )
+        # Means are kept close (0.09 vs 0.06) with a wider SD (0.03) so the
+        # mito fraction is informative but not a near-perfect oracle for
+        # stressed-cell presence.
+        mito_mean = 0.09 if has_stressed_cells else 0.06
+        mito_frac = self.noise.sample_qc_metric(mito_mean, 0.03, 0.0, 0.3)
         ambient_frac = self.noise.sample_qc_metric(
             s.technical.ambient_rna_fraction, 0.01, 0.0, 0.2
         )
@@ -186,9 +253,9 @@
     def _filter_data(
         self, action: ExperimentAction, s: FullLatentState, idx: int
     ) -> IntermediateOutput:
-        retain_frac = self.noise.sample_qc_metric(0.85, 0.05, 0.5, 1.0)
-        n_before = s.biology.n_true_cells
-        n_after = max(100, int(n_before * retain_frac))
+        retain_frac = s.last_retain_frac if s.last_retain_frac is not None else self.noise.sample_qc_metric(0.85, 0.05, 0.5, 1.0)
+        n_before = s.progress.n_cells_sequenced or s.biology.n_true_cells
+        n_after = s.progress.n_cells_after_filter or max(100, int(n_before * retain_frac))
         return IntermediateOutput(
             output_type=OutputType.COUNT_MATRIX_SUMMARY,
             step_index=idx,
@@ -238,14 +305,15 @@
     ) -> IntermediateOutput:
         n_true = len(s.biology.cell_populations) or 5
         quality = self.noise.quality_degradation(0.8, [0.95])
-        n_clusters = self.noise.sample_cluster_count(n_true, quality)
+        n_clusters = s.last_n_clusters if s.last_n_clusters is not None else self.noise.sample_cluster_count(n_true, quality)
         cluster_names = [f"cluster_{i}" for i in range(n_clusters)]
-        sizes = self._random_partition(s.biology.n_true_cells, n_clusters)
+        n_cells = s.progress.n_cells_after_filter or s.biology.n_true_cells
+        sizes = self._partition_by_population(n_cells, n_clusters, s.biology.cell_populations)
         return IntermediateOutput(
             output_type=OutputType.CLUSTER_RESULT,
             step_index=idx,
             quality_score=quality,
-            summary=f"Found {n_clusters} clusters (ground-truth populations: {n_true})",
+            summary=f"Found {n_clusters} clusters",
             data={
                 "n_clusters": n_clusters,
                 "cluster_names": cluster_names,
@@ -260,10 +328,22 @@
         self, action: ExperimentAction, s: FullLatentState, idx: int
     ) -> IntermediateOutput:
         comparison = action.parameters.get("comparison", "disease_vs_healthy")
+        # Fall back to the first available comparison key if the requested one
+        # is absent, rather than silently returning an empty effect dict.
+        if comparison not in s.biology.true_de_genes and s.biology.true_de_genes:
+            comparison = next(iter(s.biology.true_de_genes))
         true_effects = s.biology.true_de_genes.get(comparison, {})
 
         n_cells = s.progress.n_cells_after_filter or s.biology.n_true_cells
-        noise_level = s.technical.dropout_rate + 0.1 * (1.0 - s.technical.sample_quality)
+        batch_noise = (
+            sum(s.technical.batch_effects.values())
+            / max(len(s.technical.batch_effects), 1)
+        )
+        noise_level = (
+            s.technical.dropout_rate
+            + 0.1 * (1.0 - s.technical.sample_quality)
+            + 0.5 * batch_noise
+        )
         observed = self.noise.sample_effect_sizes(true_effects, n_cells, noise_level)
 
         fp_genes = self.noise.generate_false_positives(5000, 0.002 + noise_level * 0.01)
@@ -299,10 +379,16 @@
         quality = self.noise.quality_degradation(0.7 if has_trajectory else 0.3, [0.9])
         summary_data: Dict[str, Any] = {"method": action.method or "monocle3"}
         if has_trajectory:
+            true_n_lineages = s.biology.true_trajectory.get("n_lineages", 1)
+            true_branching = s.biology.true_trajectory.get("branching", False)
+            # Perturb lineage count by ±1 and flip the branching flag with 20%
+            # probability so the output is informative but not an exact oracle.
+            noisy_n_lineages = max(1, true_n_lineages + int(self.noise.rng.choice([-1, 0, 0, 1])))
+            noisy_branching = true_branching if not self.noise.coin_flip(0.20) else not true_branching
             summary_data.update({
-                "n_lineages": s.biology.true_trajectory.get("n_lineages", 1),
+                "n_lineages": noisy_n_lineages,
                 "pseudotime_range": [0.0, 1.0],
-                "branching_detected": s.biology.true_trajectory.get("branching", False),
+                "branching_detected": noisy_branching,
             })
         else:
             summary_data["n_lineages"] = self.noise.sample_count(1) + 1
@@ -323,24 +409,38 @@
         self, action: ExperimentAction, s: FullLatentState, idx: int
     ) -> IntermediateOutput:
         true_pathways = s.biology.true_pathways
-        noise_level = 0.15
+        # Pathway enrichment quality is tightly coupled to the quality of the
+        # preceding DE step: more DE genes found → better gene-set coverage →
+        # lower noise and fewer spurious pathway hits.
+        de_genes_found = s.progress.n_de_genes_found or 0
+        de_was_run = s.progress.de_performed
+        if de_was_run and de_genes_found > 0:
+            # Noise shrinks as the DE gene list grows (more signal in input).
+            noise_level = max(0.05, 0.25 - 0.001 * min(de_genes_found, 200))
+            n_fp_mean = max(1, int(5 - de_genes_found / 50))
+        else:
+            # Without a DE step, enrichment is unreliable.
+            noise_level = 0.40
+            n_fp_mean = 8
+
         observed: Dict[str, float] = {}
         for pw, activity in true_pathways.items():
             observed[pw] = activity + float(self.noise.rng.normal(0, noise_level))
 
-        for i in range(self.noise.sample_count(2)):
+        for i in range(self.noise.sample_count(n_fp_mean)):
            observed[f"FP_PATHWAY_{i}"] = float(self.noise.rng.uniform(0.3, 0.6))
 
         top = sorted(observed.items(), key=lambda kv: kv[1], reverse=True)[:15]
+        base_quality = 0.80 if de_was_run else 0.45
         return IntermediateOutput(
            output_type=OutputType.PATHWAY_RESULT,
            step_index=idx,
-            quality_score=self.noise.quality_degradation(0.8, [0.95]),
+            quality_score=self.noise.quality_degradation(base_quality, [0.95]),
            summary=f"Pathway enrichment: {len(top)} significant pathways",
            data={
                "method": action.method or "GSEA",
                "top_pathways": [
-                    {"pathway": p, "score": round(s, 3)} for p, s in top
+                    {"pathway": p, "score": round(sc, 3)} for p, sc in top
                ],
            },
            uncertainty=noise_level,
@@ -353,6 +453,25 @@
         true_net = s.biology.true_regulatory_network
         n_edges_true = sum(len(v) for v in true_net.values())
         noise_edges = self.noise.sample_count(max(5, int(n_edges_true * 0.3)))
+
+        true_tfs = list(true_net.keys())
+        # Drop ~25% of true regulators (false-negative rate).
+        fn_set = set(self.noise.generate_false_negatives(true_tfs, 0.25))
+        observed_tfs = [tf for tf in true_tfs if tf not in fn_set]
+        # Inject realistic false-positive TFs drawn from a background pool so
+        # the agent cannot distinguish true from false hits by name format.
+        fp_candidates = [t for t in _NOISE_TFS if t not in set(true_tfs)]
+        n_fp = self.noise.sample_count(max(2, int(len(true_tfs) * 0.5) + 2))
+        if fp_candidates and n_fp > 0:
+            chosen = self.noise.rng.choice(
+                fp_candidates,
+                size=min(n_fp, len(fp_candidates)),
+                replace=False,
+            )
+            observed_tfs.extend(chosen.tolist())
+        # Shuffle so rank order does not reveal true-vs-false identity.
+        observed_tfs = self.noise.shuffle_ranking(observed_tfs, 0.5)
+
         return IntermediateOutput(
             output_type=OutputType.NETWORK_RESULT,
             step_index=idx,
@@ -362,7 +481,7 @@
                 "method": action.method or "SCENIC",
                 "n_regulons": len(true_net) + self.noise.sample_count(3),
                 "n_edges": n_edges_true + noise_edges,
-                "top_regulators": list(true_net.keys())[:10],
+                "top_regulators": observed_tfs[:10],
             },
             uncertainty=0.35,
             artifacts_available=["regulon_table", "grn_adjacency"],
@@ -407,8 +526,11 @@
                 "marker": marker,
                 "validated": validated,
                 "assay": action.method or "qPCR",
+                # Means are kept close (0.85 vs 0.45) with a wide SD (0.4)
+                # so the effect size is correlated with, but not a near-perfect
+                # oracle for, true marker membership.
                 "effect_size": self.noise.sample_qc_metric(
-                    1.5 if is_true else 0.2, 0.3, -0.5, 5.0
+                    0.85 if is_true else 0.45, 0.4, -0.5, 5.0
                 ),
             },
             artifacts_available=["validation_data"],
@@ -417,22 +539,54 @@
     def _design_followup(
         self, action: ExperimentAction, s: FullLatentState, idx: int
     ) -> IntermediateOutput:
+        evidence_signals = sum([
+            int(s.progress.cells_clustered),
+            int(s.progress.de_performed),
+            int(s.progress.trajectories_inferred),
+            int(s.progress.pathways_analyzed),
+            int(s.progress.networks_inferred),
+            int(s.progress.markers_discovered),
+            int(s.progress.markers_validated),
+        ])
         return IntermediateOutput(
             output_type=OutputType.FOLLOWUP_DESIGN,
             step_index=idx,
-            summary="Follow-up experiment design proposed",
-            data={"proposal": action.parameters},
+            quality_score=min(0.75, 0.2 + 0.08 * evidence_signals),
+            summary=(
+                f"Follow-up experiment design proposed "
+                f"(evidence_signals={evidence_signals})"
+            ),
+            data={
+                "proposal": action.parameters,
+                "evidence_signals": evidence_signals,
+            },
+            uncertainty=max(0.25, 0.8 - 0.08 * evidence_signals),
             artifacts_available=["followup_proposal"],
         )
 
     def _subagent_review(
         self, action: ExperimentAction, s: FullLatentState, idx: int
     ) -> IntermediateOutput:
+        evidence_signals = sum([
+            int(s.progress.cells_clustered),
+            int(s.progress.de_performed),
+            int(s.progress.trajectories_inferred),
+            int(s.progress.pathways_analyzed),
+            int(s.progress.networks_inferred),
+            int(s.progress.markers_discovered),
+            int(s.progress.markers_validated),
+        ])
         return IntermediateOutput(
             output_type=OutputType.SUBAGENT_REPORT,
             step_index=idx,
+            quality_score=min(0.7, 0.15 + 0.07 * evidence_signals),
             summary=f"Subagent review ({action.invoked_subagent or 'general'})",
-            data={"subagent": action.invoked_subagent, "notes": "Review complete."},
+            data={
+                "subagent": action.invoked_subagent,
+                "notes": "Review complete.",
+                "evidence_signals": evidence_signals,
+            },
+            uncertainty=max(0.3, 0.85 - 0.08 * evidence_signals),
             artifacts_available=["subagent_report"],
         )
 
@@ -469,14 +623,46 @@
             sizes[0] += diff
         return sizes
 
+    def _partition_by_population(
+        self,
+        total: int,
+        k: int,
+        populations: list,
+    ) -> List[int]:
+        """Partition cells into k clusters using true population proportions
+        as Dirichlet concentration parameters, so majority cell types produce
+        larger clusters rather than uniformly random sizes."""
+        if k <= 0:
+            return []
+        if populations:
+            # Use true proportions as Dirichlet alpha — larger proportions
+            # concentrate probability mass, yielding realistic size ratios.
+            raw = [max(p.proportion, 1e-3) for p in populations]
+            # Align alpha length to k: repeat/truncate as needed.
+            if len(raw) >= k:
+                alpha = raw[:k]
+            else:
+                alpha = raw + [sum(raw) / len(raw)] * (k - len(raw))
+            # Scale alpha so the total magnitude is proportional to k,
+            # giving reasonable Dirichlet variance.
 
 _HANDLERS = {
     ActionType.COLLECT_SAMPLE: OutputGenerator._collect_sample,
     ActionType.SELECT_COHORT: OutputGenerator._select_cohort,
     ActionType.PREPARE_LIBRARY: OutputGenerator._prepare_library,
     ActionType.CULTURE_CELLS: OutputGenerator._culture_cells,
-    ActionType.PERTURB_GENE: OutputGenerator._perturb,
-    ActionType.PERTURB_COMPOUND: OutputGenerator._perturb,
     ActionType.SEQUENCE_CELLS: OutputGenerator._sequence_cells,
     ActionType.RUN_QC: OutputGenerator._run_qc,
     ActionType.FILTER_DATA: OutputGenerator._filter_data,
648
+ scale = k / max(sum(alpha), 1e-6)
649
+ alpha = [a * scale for a in alpha]
650
+ else:
651
+ alpha = [1.0] * k
652
+ fracs = self.noise.rng.dirichlet(alpha=alpha)
653
+ sizes = [max(1, int(total * f)) for f in fracs]
654
+ diff = total - sum(sizes)
655
+ sizes[0] += diff
656
+ return sizes
657
+
658
 
659
  _HANDLERS = {
660
  ActionType.COLLECT_SAMPLE: OutputGenerator._collect_sample,
661
  ActionType.SELECT_COHORT: OutputGenerator._select_cohort,
662
  ActionType.PREPARE_LIBRARY: OutputGenerator._prepare_library,
663
  ActionType.CULTURE_CELLS: OutputGenerator._culture_cells,
664
+ ActionType.PERTURB_GENE: OutputGenerator._perturb_gene,
665
+ ActionType.PERTURB_COMPOUND: OutputGenerator._perturb_compound,
666
  ActionType.SEQUENCE_CELLS: OutputGenerator._sequence_cells,
667
  ActionType.RUN_QC: OutputGenerator._run_qc,
668
  ActionType.FILTER_DATA: OutputGenerator._filter_data,
server/simulator/transition.py CHANGED
@@ -15,6 +15,7 @@ from models import (
     ExperimentAction,
     IntermediateOutput,
     OutputType,
+    TOOL_REGISTRY,
 )
 
 from .latent_state import FullLatentState
@@ -22,7 +23,8 @@ from .noise import NoiseModel
 from .output_generator import OutputGenerator
 
 
-ACTION_COSTS: Dict[ActionType, Tuple[float, float]] = {
+# Fallback costs per ActionType when the agent doesn't specify a known tool.
+_BASE_ACTION_COSTS: Dict[ActionType, Tuple[float, float]] = {
     ActionType.COLLECT_SAMPLE: (5_000, 7.0),
     ActionType.SELECT_COHORT: (  500, 1.0),
     ActionType.PREPARE_LIBRARY: (8_000, 3.0),
@@ -41,11 +43,30 @@ ACTION_COSTS: Dict[ActionType, Tuple[float, float]] = {
     ActionType.REGULATORY_NETWORK_INFERENCE: (  300, 1.0),
     ActionType.MARKER_SELECTION: (  100, 0.5),
     ActionType.VALIDATE_MARKER: (5_000, 14.0),
-    ActionType.DESIGN_FOLLOWUP: (    0, 0.5),
-    ActionType.REQUEST_SUBAGENT_REVIEW: (    0, 0.25),
+    ActionType.DESIGN_FOLLOWUP: (  100, 0.5),
+    ActionType.REQUEST_SUBAGENT_REVIEW: (   50, 0.25),
     ActionType.SYNTHESIZE_CONCLUSION: (    0, 0.5),
 }
 
+# Kept as public alias so existing imports (e.g. hackathon_environment) still work.
+ACTION_COSTS = _BASE_ACTION_COSTS
+
+
+def compute_action_cost(action: ExperimentAction) -> Tuple[float, float]:
+    """Return (budget_cost, time_cost_days) for an action.
+
+    If the action specifies a ``method`` that exists in ``TOOL_REGISTRY``,
+    the tool's ``typical_cost_usd`` and ``typical_runtime_hours`` are used
+    (converted to days). Otherwise we fall back to the per-ActionType base
+    cost table.
+    """
+    tool_spec = TOOL_REGISTRY.get(action.method or "")
+    if tool_spec is not None:
+        budget = tool_spec.typical_cost_usd
+        time_days = tool_spec.typical_runtime_hours / 24.0
+        return (budget, time_days)
+    return _BASE_ACTION_COSTS.get(action.action_type, (0.0, 0.0))
+
 
 @dataclass
 class TransitionResult:
@@ -138,9 +159,7 @@
     def _apply_resource_cost(
         self, s: FullLatentState, action: ExperimentAction
     ) -> None:
-        budget_cost, time_cost = ACTION_COSTS.get(
-            action.action_type, (0.0, 0.0)
-        )
+        budget_cost, time_cost = compute_action_cost(action)
         s.resources.budget_used += budget_cost
         s.resources.time_used_days += time_cost
         if action.action_type in {
@@ -176,6 +195,8 @@
             ActionType.REGULATORY_NETWORK_INFERENCE: "networks_inferred",
             ActionType.MARKER_SELECTION: "markers_discovered",
             ActionType.VALIDATE_MARKER: "markers_validated",
+            ActionType.DESIGN_FOLLOWUP: "followup_designed",
+            ActionType.REQUEST_SUBAGENT_REVIEW: "subagent_review_requested",
             ActionType.SYNTHESIZE_CONCLUSION: "conclusion_reached",
         }
         flag = _MAP.get(at)
@@ -188,16 +209,43 @@
 
         if at == ActionType.SEQUENCE_CELLS:
             s.resources.sequencing_lanes_used += 1
+            p.n_cells_sequenced = self.noise.sample_count(
+                s.biology.n_true_cells * s.technical.capture_efficiency
+            )
+
+        if at in {ActionType.PERTURB_GENE, ActionType.PERTURB_COMPOUND}:
+            self._apply_perturbation_effects(s, action)
 
         if at == ActionType.FILTER_DATA:
             retain = self.noise.sample_qc_metric(0.85, 0.05, 0.5, 1.0)
-            p.n_cells_after_filter = max(
-                100, int(s.biology.n_true_cells * retain)
-            )
+            base = p.n_cells_sequenced or s.biology.n_true_cells
+            p.n_cells_after_filter = max(100, int(base * retain))
+            s.last_retain_frac = retain
 
         if at == ActionType.CLUSTER_CELLS:
             n_true = len(s.biology.cell_populations) or 5
             p.n_clusters_found = self.noise.sample_cluster_count(n_true, 0.8)
+            s.last_n_clusters = p.n_clusters_found
+
+    def _apply_perturbation_effects(
+        self, s: FullLatentState, action: ExperimentAction
+    ) -> None:
+        """Fold perturbation-specific gene effects into true_de_genes so
+        downstream DE analysis reflects the perturbed biology."""
+        target = action.parameters.get("target", "")
+        effects = s.biology.perturbation_effects.get(target, {})
+        if not effects:
+            return
+        # Efficiency drawn from the same distribution as the output handler
+        # so latent state and observable output are coherent.
+        if action.action_type == ActionType.PERTURB_GENE:
+            efficiency = self.noise.sample_qc_metric(0.80, 0.12, 0.0, 1.0)
+        else:
+            efficiency = self.noise.sample_qc_metric(0.70, 0.15, 0.0, 1.0)
+        s.last_perturbation_efficiency = efficiency
+        for gene_map in s.biology.true_de_genes.values():
+            for gene, delta in effects.items():
+                gene_map[gene] = gene_map.get(gene, 0.0) + delta * efficiency
 
     def _propagate_artifacts(
         self,
@@ -208,6 +256,7 @@
         if action.action_type == ActionType.DIFFERENTIAL_EXPRESSION:
             top = output.data.get("top_genes", [])
             s.discovered_de_genes = [g["gene"] for g in top[:20]]
+            s.progress.n_de_genes_found = output.data.get("n_significant", 0)
 
         if action.action_type == ActionType.CLUSTER_CELLS:
             s.discovered_clusters = output.data.get("cluster_names", [])
server/tasks/bio_palette.py ADDED
@@ -0,0 +1,692 @@
+"""Curated biological building blocks for procedural scenario generation.
+
+Provides tissue-specific cell types, disease profiles, pathway libraries,
+regulatory network templates, and perturbation effect profiles. The
+procedural generator composes these into complete ``Scenario`` objects
+with fully populated ``LatentBiologicalState``.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Tuple
+
+
+# ── Cell type templates ─────────────────────────────────────────────────────
+
+
+@dataclass
+class CellTypeTemplate:
+    name: str
+    marker_genes: List[str]
+    proportion_range: Tuple[float, float] = (0.05, 0.30)
+    states: List[str] = field(default_factory=lambda: ["quiescent"])
+    disease_responsive: bool = False
+    response_range: Tuple[float, float] = (0.5, 1.5)
+
+
+TISSUE_CELL_TYPES: Dict[str, List[CellTypeTemplate]] = {
+    "heart": [
+        CellTypeTemplate("cardiomyocyte", ["TNNT2", "MYH7", "ACTC1"], (0.25, 0.40), ["contractile", "stressed"]),
+        CellTypeTemplate("cardiac_fibroblast", ["COL1A1", "DCN", "LUM"], (0.15, 0.30), ["quiescent", "activated"], True, (1.1, 1.8)),
+        CellTypeTemplate("endothelial", ["PECAM1", "VWF", "CDH5"], (0.10, 0.20), ["quiescent"]),
+        CellTypeTemplate("macrophage", ["CD68", "CD163", "CSF1R"], (0.05, 0.15), ["quiescent", "activated", "inflammatory"], True, (1.2, 2.0)),
+        CellTypeTemplate("smooth_muscle", ["ACTA2", "MYH11", "TAGLN"], (0.08, 0.18), ["quiescent"]),
+        CellTypeTemplate("pericyte", ["PDGFRB", "RGS5", "NOTCH3"], (0.03, 0.10), ["quiescent"]),
+    ],
+    "lung": [
+        CellTypeTemplate("AT2", ["SFTPC", "SFTPB", "ABCA3"], (0.15, 0.25), ["normal", "stressed"]),
+        CellTypeTemplate("AT1", ["AGER", "PDPN", "CAV1"], (0.10, 0.18), ["normal"]),
+        CellTypeTemplate("alveolar_macrophage", ["MARCO", "FABP4", "MCEMP1"], (0.10, 0.20), ["resident", "activated"]),
+        CellTypeTemplate("fibroblast", ["COL1A1", "COL3A1", "POSTN"], (0.12, 0.25), ["quiescent", "activated"], True, (1.2, 2.0)),
+        CellTypeTemplate("endothelial", ["PECAM1", "CLDN5", "VWF"], (0.08, 0.15), ["quiescent"]),
+        CellTypeTemplate("T_cell", ["CD3D", "CD3E", "IL7R"], (0.08, 0.18), ["quiescent", "activated"]),
+        CellTypeTemplate("ciliated", ["FOXJ1", "DNAH5", "TPPP3"], (0.05, 0.12), ["normal"]),
+    ],
+    "brain": [
+        CellTypeTemplate("excitatory_neuron", ["SLC17A7", "CAMK2A", "NRGN"], (0.25, 0.40), ["normal", "stressed"]),
+        CellTypeTemplate("inhibitory_neuron", ["GAD1", "GAD2", "SLC32A1"], (0.12, 0.22), ["normal"]),
+        CellTypeTemplate("astrocyte", ["GFAP", "AQP4", "SLC1A3"], (0.10, 0.20), ["quiescent", "activated"], True, (1.2, 1.8)),
+        CellTypeTemplate("microglia", ["CX3CR1", "P2RY12", "TMEM119"], (0.05, 0.12), ["homeostatic", "activated", "inflammatory"], True, (1.3, 2.5)),
+        CellTypeTemplate("oligodendrocyte", ["MBP", "PLP1", "MOG"], (0.10, 0.18), ["myelinating"]),
+        CellTypeTemplate("OPC", ["PDGFRA", "CSPG4", "OLIG2"], (0.03, 0.08), ["progenitor"]),
+        CellTypeTemplate("endothelial", ["CLDN5", "FLT1", "PECAM1"], (0.03, 0.08), ["quiescent"]),
+    ],
+    "liver": [
+        CellTypeTemplate("hepatocyte", ["ALB", "APOB", "CYP3A4"], (0.55, 0.70), ["normal", "stressed"]),
+        CellTypeTemplate("cholangiocyte", ["KRT19", "KRT7", "EPCAM"], (0.05, 0.10), ["normal"]),
+        CellTypeTemplate("kupffer_cell", ["CD68", "CLEC4F", "MARCO"], (0.08, 0.15), ["quiescent", "activated", "inflammatory"], True, (1.2, 2.0)),
+        CellTypeTemplate("stellate_cell", ["ACTA2", "LRAT", "PDGFRB"], (0.05, 0.12), ["quiescent", "activated"], True, (1.3, 2.0)),
+        CellTypeTemplate("endothelial", ["PECAM1", "LYVE1", "STAB2"], (0.05, 0.10), ["quiescent"]),
+        CellTypeTemplate("NK_cell", ["NKG7", "GNLY", "KLRD1"], (0.03, 0.08), ["quiescent", "activated"]),
+    ],
+    "bone_marrow": [
+        CellTypeTemplate("HSC", ["CD34", "KIT", "THY1"], (0.03, 0.08), ["stem"]),
+        CellTypeTemplate("CMP", ["CD34", "FLT3"], (0.08, 0.15), ["progenitor"]),
+        CellTypeTemplate("GMP", ["CSF3R", "CEBPA"], (0.08, 0.15), ["progenitor"]),
+        CellTypeTemplate("MEP", ["GATA1", "KLF1"], (0.06, 0.12), ["progenitor"]),
+        CellTypeTemplate("erythrocyte", ["HBA1", "HBB", "GYPA"], (0.15, 0.25), ["mature"]),
+        CellTypeTemplate("neutrophil", ["ELANE", "MPO", "CTSG"], (0.12, 0.22), ["mature"]),
+        CellTypeTemplate("monocyte", ["CD14", "CSF1R", "FCGR3A"], (0.10, 0.18), ["mature"]),
+        CellTypeTemplate("megakaryocyte", ["ITGA2B", "GP1BA", "PF4"], (0.05, 0.12), ["mature"]),
+    ],
+    "kidney": [
+        CellTypeTemplate("proximal_tubule", ["SLC34A1", "LRP2", "CUBN"], (0.30, 0.45), ["normal", "stressed"]),
+        CellTypeTemplate("distal_tubule", ["SLC12A3", "CALB1"], (0.10, 0.18), ["normal"]),
+        CellTypeTemplate("collecting_duct", ["AQP2", "FXYD4"], (0.08, 0.15), ["normal"]),
+        CellTypeTemplate("podocyte", ["NPHS1", "NPHS2", "WT1"], (0.05, 0.10), ["normal", "stressed"]),
+        CellTypeTemplate("endothelial", ["PECAM1", "EMCN", "FLT1"], (0.05, 0.12), ["quiescent"]),
+        CellTypeTemplate("macrophage", ["CD68", "CD163", "CSF1R"], (0.05, 0.10), ["quiescent", "inflammatory"], True, (1.3, 2.0)),
+        CellTypeTemplate("fibroblast", ["COL1A1", "PDGFRA", "DCN"], (0.05, 0.12), ["quiescent", "activated"], True, (1.2, 1.8)),
+    ],
+    "colon": [
+        CellTypeTemplate("colonocyte", ["CA2", "AQP8", "SLC26A3"], (0.25, 0.40), ["normal", "stressed"]),
+        CellTypeTemplate("goblet_cell", ["MUC2", "TFF3", "FCGBP"], (0.10, 0.18), ["secretory"]),
+        CellTypeTemplate("stem_cell", ["LGR5", "ASCL2", "OLFM4"], (0.05, 0.10), ["stem"]),
+        CellTypeTemplate("T_cell", ["CD3D", "CD3E", "IL7R"], (0.10, 0.18), ["quiescent", "activated"]),
+        CellTypeTemplate("macrophage", ["CD68", "CD163", "CSF1R"], (0.05, 0.12), ["quiescent", "inflammatory"], True, (1.3, 2.0)),
+        CellTypeTemplate("fibroblast", ["COL1A1", "COL3A1", "VIM"], (0.08, 0.15), ["quiescent", "activated"], True, (1.2, 1.8)),
+        CellTypeTemplate("endothelial", ["PECAM1", "VWF", "CDH5"], (0.05, 0.10), ["quiescent"]),
+    ],
+    "pancreas": [
+        CellTypeTemplate("beta_cell", ["INS", "MAFA", "NKX6-1"], (0.25, 0.40), ["normal", "stressed"], True, (0.4, 0.8)),
+        CellTypeTemplate("alpha_cell", ["GCG", "ARX", "IRX2"], (0.15, 0.25), ["normal"]),
+        CellTypeTemplate("delta_cell", ["SST", "HHEX"], (0.05, 0.10), ["normal"]),
+        CellTypeTemplate("ductal", ["KRT19", "SOX9", "CFTR"], (0.10, 0.18), ["normal"]),
+        CellTypeTemplate("acinar", ["PRSS1", "CPA1", "CELA3A"], (0.10, 0.20), ["normal"]),
+        CellTypeTemplate("stellate", ["ACTA2", "PDGFRA", "COL1A1"], (0.05, 0.10), ["quiescent", "activated"], True, (1.2, 1.8)),
+        CellTypeTemplate("macrophage", ["CD68", "CD163"], (0.03, 0.08), ["quiescent", "inflammatory"]),
+    ],
+    "skin": [
+        CellTypeTemplate("keratinocyte", ["KRT14", "KRT5", "KRT1"], (0.40, 0.55), ["basal", "differentiated"]),
+        CellTypeTemplate("melanocyte", ["MLANA", "PMEL", "TYR"], (0.05, 0.10), ["normal", "activated"]),
+        CellTypeTemplate("fibroblast", ["COL1A1", "COL3A1", "DCN"], (0.10, 0.20), ["quiescent", "activated"]),
+        CellTypeTemplate("T_cell", ["CD3D", "CD3E", "IL7R"], (0.08, 0.15), ["quiescent", "activated"]),
+        CellTypeTemplate("macrophage", ["CD68", "CD163", "CSF1R"], (0.05, 0.10), ["quiescent", "inflammatory"]),
+        CellTypeTemplate("endothelial", ["PECAM1", "VWF"], (0.05, 0.10), ["quiescent"]),
+    ],
+    "breast": [
+        CellTypeTemplate("luminal_epithelial", ["KRT8", "KRT18", "EPCAM"], (0.25, 0.40), ["normal", "stressed"]),
+        CellTypeTemplate("basal_epithelial", ["KRT14", "KRT5", "TP63"], (0.10, 0.20), ["normal"]),
+        CellTypeTemplate("fibroblast", ["COL1A1", "COL3A1", "FAP"], (0.10, 0.20), ["quiescent", "activated"], True, (1.2, 1.8)),
+        CellTypeTemplate("T_cell", ["CD3D", "CD3E", "CD8A"], (0.08, 0.15), ["quiescent", "activated", "exhausted"]),
+        CellTypeTemplate("macrophage", ["CD68", "CD163", "CSF1R"], (0.05, 0.12), ["quiescent", "inflammatory"], True, (1.3, 2.0)),
+        CellTypeTemplate("endothelial", ["PECAM1", "VWF", "CDH5"], (0.05, 0.10), ["quiescent"]),
+    ],
+    "synovium": [
+        CellTypeTemplate("fibroblast", ["COL1A1", "FAP", "THY1"], (0.20, 0.30), ["quiescent", "activated"], True, (1.2, 1.8)),
+        CellTypeTemplate("CD4_T_cell", ["CD3D", "CD4", "IL7R"], (0.12, 0.22), ["quiescent", "activated"]),
+        CellTypeTemplate("CD8_T_cell", ["CD3D", "CD8A", "GZMB"], (0.08, 0.15), ["quiescent", "activated"]),
+        CellTypeTemplate("macrophage", ["CD68", "CD163", "MARCO"], (0.10, 0.18), ["quiescent", "inflammatory"], True, (1.3, 2.0)),
+        CellTypeTemplate("B_cell", ["CD19", "MS4A1", "CD79A"], (0.05, 0.12), ["quiescent"]),
+        CellTypeTemplate("endothelial", ["PECAM1", "VWF"], (0.05, 0.10), ["quiescent"]),
+        CellTypeTemplate("mast_cell", ["KIT", "TPSAB1", "CPA3"], (0.03, 0.08), ["quiescent"]),
+    ],
+    "aorta": [
+        CellTypeTemplate("smooth_muscle", ["ACTA2", "MYH11", "TAGLN"], (0.30, 0.45), ["contractile", "synthetic"], True, (0.6, 0.9)),
+        CellTypeTemplate("endothelial", ["PECAM1", "VWF", "CDH5"], (0.15, 0.25), ["quiescent", "activated"]),
+        CellTypeTemplate("macrophage", ["CD68", "CD163", "TREM2"], (0.08, 0.15), ["quiescent", "inflammatory"], True, (1.5, 2.5)),
+        CellTypeTemplate("fibroblast", ["COL1A1", "LUM", "DCN"], (0.08, 0.15), ["quiescent", "activated"]),
+        CellTypeTemplate("T_cell", ["CD3D", "CD3E", "IL7R"], (0.05, 0.12), ["quiescent", "activated"]),
+        CellTypeTemplate("dendritic_cell", ["FCER1A", "CD1C", "CLEC10A"], (0.03, 0.08), ["quiescent"]),
+    ],
+    "blood": [
+        CellTypeTemplate("CD4_T_cell", ["CD3D", "CD4", "IL7R"], (0.15, 0.25), ["quiescent", "activated"]),
+        CellTypeTemplate("CD8_T_cell", ["CD3D", "CD8A", "GZMB"], (0.10, 0.18), ["quiescent", "activated"]),
+        CellTypeTemplate("B_cell", ["CD19", "MS4A1", "CD79A"], (0.08, 0.15), ["quiescent"]),
+        CellTypeTemplate("NK_cell", ["NKG7", "GNLY", "KLRD1"], (0.05, 0.12), ["quiescent", "activated"]),
+        CellTypeTemplate("monocyte", ["CD14", "CSF1R", "FCGR3A"], (0.15, 0.25), ["classical", "non_classical"]),
+        CellTypeTemplate("neutrophil", ["ELANE", "MPO", "CTSG"], (0.10, 0.20), ["mature"]),
+        CellTypeTemplate("platelet", ["ITGA2B", "GP1BA", "PF4"], (0.03, 0.08), ["normal"]),
+    ],
+    "spleen": [
+        CellTypeTemplate("B_cell", ["CD19", "MS4A1", "CD79A"], (0.20, 0.35), ["quiescent", "activated"]),
+        CellTypeTemplate("T_cell", ["CD3D", "CD3E", "IL7R"], (0.15, 0.25), ["quiescent", "activated"]),
+        CellTypeTemplate("macrophage", ["CD68", "CD163", "CSF1R"], (0.10, 0.18), ["quiescent", "inflammatory"]),
+        CellTypeTemplate("dendritic_cell", ["FCER1A", "CD1C", "CLEC10A"], (0.05, 0.10), ["quiescent"]),
+        CellTypeTemplate("NK_cell", ["NKG7", "GNLY", "KLRD1"], (0.05, 0.12), ["quiescent"]),
+        CellTypeTemplate("endothelial", ["PECAM1", "STAB2"], (0.05, 0.10), ["quiescent"]),
+    ],
+    "thymus": [
+        CellTypeTemplate("double_negative_T", ["CD3D", "PTCRA"], (0.10, 0.18), ["progenitor"]),
+        CellTypeTemplate("double_positive_T", ["CD3D", "CD4", "CD8A"], (0.30, 0.45), ["progenitor"]),
+        CellTypeTemplate("CD4_SP", ["CD3D", "CD4", "IL7R"], (0.10, 0.18), ["mature"]),
+        CellTypeTemplate("CD8_SP", ["CD3D", "CD8A", "CD8B"], (0.08, 0.15), ["mature"]),
+        CellTypeTemplate("thymic_epithelial", ["FOXN1", "KRT5", "KRT8"], (0.05, 0.12), ["cortical", "medullary"]),
+        CellTypeTemplate("dendritic_cell", ["FCER1A", "CD1C"], (0.03, 0.08), ["quiescent"]),
+        CellTypeTemplate("macrophage", ["CD68", "CD163"], (0.03, 0.08), ["quiescent"]),
+    ],
+}
+
+
+# ── Disease profiles ────────────────────────────────────────────────────────
+
+
+@dataclass
+class DiseaseProfile:
+    name: str
+    display_name: str
+    tissue: str
+    condition_name: str
+    de_genes: Dict[str, Tuple[float, float]]
+    pathways: Dict[str, Tuple[float, float]]
+    markers: List[str]
+    mechanism_templates: List[str]
+    responding_cell_types: List[str] = field(default_factory=list)
+    hidden_failure_templates: List[str] = field(default_factory=list)
+
+
+DISEASE_PROFILES: Dict[str, DiseaseProfile] = {
+    "dilated_cardiomyopathy": DiseaseProfile(
+        name="dilated_cardiomyopathy",
+        display_name="dilated cardiomyopathy",
+        tissue="heart",
+        condition_name="dilated_cardiomyopathy",
+        de_genes={
+            "NPPA": (1.8, 3.5), "NPPB": (2.0, 4.0), "MYH7": (1.0, 2.5),
+            "COL1A1": (1.0, 2.2), "COL3A1": (0.8, 1.8), "POSTN": (1.5, 3.0),
+            "CCL2": (0.8, 1.8), "IL6": (0.5, 1.5), "TGFB1": (0.7, 1.6),
+            "ANKRD1": (1.5, 3.0), "XIRP2": (-2.0, -0.8), "MYL2": (-1.5, -0.5),
+        },
+        pathways={
+            "cardiac_muscle_contraction": (0.3, 0.6),
+            "extracellular_matrix_organisation": (0.7, 0.95),
+            "inflammatory_response": (0.5, 0.8),
+            "TGF_beta_signalling": (0.6, 0.85),
+            "apoptosis": (0.4, 0.65),
+        },
+        markers=["NPPA", "NPPB", "POSTN", "COL1A1"],
+        mechanism_templates=[
+            "TGF-beta-driven fibrosis",
+            "inflammatory macrophage infiltration",
+        ],
+        responding_cell_types=["cardiac_fibroblast", "macrophage"],
+    ),
+    "IPF": DiseaseProfile(
+        name="IPF",
+        display_name="idiopathic pulmonary fibrosis",
+        tissue="lung",
+        condition_name="IPF",
+        de_genes={
+            "SPP1": (2.0, 4.0), "MERTK": (0.8, 2.0), "MMP9": (1.0, 2.5),
+            "TREM2": (0.8, 2.0), "COL1A1": (1.5, 3.0), "COL3A1": (1.2, 2.5),
+            "POSTN": (1.5, 3.5), "SFTPC": (-2.0, -0.8), "AGER": (-2.5, -1.0),
+        },
+        pathways={
+            "extracellular_matrix_organisation": (0.75, 0.95),
+            "integrin_signalling": (0.6, 0.85),
+            "macrophage_activation": (0.65, 0.9),
+            "Wnt_signalling": (0.4, 0.7),
+        },
+        markers=["SPP1", "MERTK", "POSTN", "MMP9"],
+        mechanism_templates=[
+            "SPP1+ macrophage-driven fibroblast activation",
+            "integrin-mediated SPP1 signalling in fibrosis",
+        ],
+        responding_cell_types=["fibroblast", "alveolar_macrophage"],
+    ),
+    "Alzheimer": DiseaseProfile(
+        name="Alzheimer",
+        display_name="Alzheimer's disease",
+        tissue="brain",
+        condition_name="Alzheimer",
+        de_genes={
+            "TREM2": (1.0, 2.5), "APOE": (1.2, 2.8), "CLU": (0.8, 2.0),
+            "C1QA": (1.0, 2.2), "C1QB": (0.9, 2.0), "GFAP": (1.5, 3.0),
+            "AQP4": (0.6, 1.5), "SLC17A7": (-1.5, -0.5), "NRGN": (-2.0, -0.8),
+            "SNAP25": (-1.2, -0.4),
+        },
+        pathways={
+            "complement_cascade": (0.7, 0.9),
+            "neuroinflammation": (0.65, 0.85),
+            "amyloid_processing": (0.6, 0.8),
+            "synaptic_signalling": (0.3, 0.5),
+            "lipid_metabolism": (0.5, 0.7),
+        },
+        markers=["TREM2", "APOE", "GFAP", "C1QA"],
+        mechanism_templates=[
+            "TREM2-mediated microglial activation in amyloid clearance",
+            "complement-driven synaptic pruning",
+            "reactive astrogliosis amplifying neuroinflammation",
+        ],
+        responding_cell_types=["microglia", "astrocyte"],
+    ),
+    "colorectal_cancer": DiseaseProfile(
+        name="colorectal_cancer",
+        display_name="colorectal cancer",
+        tissue="colon",
+        condition_name="colorectal_cancer",
+        de_genes={
+            "MYC": (1.5, 3.0), "KRAS": (0.8, 1.8), "TP53": (-1.5, -0.5),
+            "APC": (-1.8, -0.8), "CDH1": (-1.2, -0.4), "VIM": (1.0, 2.5),
+            "MKI67": (1.5, 3.0), "CD44": (1.0, 2.2), "LGR5": (0.8, 2.0),
+        },
+        pathways={
+            "Wnt_signalling": (0.75, 0.95),
+            "cell_cycle": (0.7, 0.9),
+            "EMT": (0.6, 0.85),
+            "p53_signalling": (0.3, 0.5),
+            "MAPK_signalling": (0.55, 0.75),
+        },
+        markers=["MYC", "CD44", "VIM", "MKI67"],
+        mechanism_templates=[
+            "Wnt/beta-catenin-driven tumour stem cell expansion",
+            "epithelial-mesenchymal transition promoting invasion",
+        ],
+        responding_cell_types=["stem_cell", "macrophage", "fibroblast"],
+    ),
+    "type2_diabetes": DiseaseProfile(
+        name="type2_diabetes",
+        display_name="type 2 diabetes",
+        tissue="pancreas",
+        condition_name="type2_diabetes",
+        de_genes={
+            "INS": (-2.0, -0.8), "MAFA": (-1.5, -0.5), "PDX1": (-1.2, -0.4),
+            "UCN3": (-1.8, -0.6), "GCG": (0.8, 2.0), "ARX": (0.5, 1.5),
+            "IAPP": (0.6, 1.8), "TXNIP": (1.0, 2.5), "DDIT3": (0.8, 2.0),
+        },
+        pathways={
+            "insulin_signalling": (0.3, 0.5),
+            "ER_stress_response": (0.7, 0.9),
+            "oxidative_stress": (0.6, 0.8),
+            "glucagon_signalling": (0.6, 0.8),
+            "apoptosis": (0.5, 0.7),
+        },
+        markers=["INS", "TXNIP", "IAPP", "DDIT3"],
+        mechanism_templates=[
+            "ER stress-induced beta cell apoptosis",
+            "glucotoxicity-driven beta cell dedifferentiation",
+        ],
+        responding_cell_types=["beta_cell", "stellate"],
+    ),
+    "AML": DiseaseProfile(
+        name="AML",
+        display_name="acute myeloid leukemia",
+        tissue="bone_marrow",
+        condition_name="AML",
+        de_genes={
+            "FLT3": (1.5, 3.0), "NPM1": (0.8, 2.0), "IDH2": (0.6, 1.5),
+            "RUNX1": (-1.5, -0.5), "CEBPA": (-1.2, -0.4), "KIT": (1.0, 2.5),
+            "WT1": (1.2, 2.8), "MYC": (0.8, 2.0),
+        },
+        pathways={
+            "hematopoietic_cell_lineage": (0.3, 0.5),
+            "MAPK_signalling": (0.7, 0.9),
+            "PI3K_AKT_signalling": (0.65, 0.85),
+            "cell_cycle": (0.7, 0.9),
+            "apoptosis": (0.3, 0.5),
+        },
+        markers=["FLT3", "NPM1", "WT1", "KIT"],
+        mechanism_templates=[
+            "FLT3-ITD-driven proliferative advantage",
+            "myeloid differentiation block via RUNX1 loss",
+        ],
+        responding_cell_types=["HSC", "GMP"],
+    ),
+    "rheumatoid_arthritis": DiseaseProfile(
+        name="rheumatoid_arthritis",
+        display_name="rheumatoid arthritis",
+        tissue="synovium",
+        condition_name="rheumatoid_arthritis",
+        de_genes={
+            "IFNG": (1.0, 2.5), "TBX21": (0.8, 1.8), "IL17A": (1.0, 2.2),
+            "RORC": (0.6, 1.5), "TNF": (1.2, 2.5), "IL6": (1.0, 2.2),
+            "MMP3": (1.5, 3.0), "MMP1": (1.2, 2.5), "CXCL13": (1.0, 2.5),
+        },
+        pathways={
+            "JAK_STAT_signalling": (0.7, 0.9),
+            "TNF_signalling": (0.7, 0.9),
+            "Th17_differentiation": (0.6, 0.8),
+            "NF_kB_signalling": (0.65, 0.85),
+            "matrix_metalloproteinase_activity": (0.7, 0.9),
+        },
+        markers=["TNF", "IL6", "MMP3", "CXCL13"],
+        mechanism_templates=[
+            "TNF/NF-kB-driven synovial inflammation",
+            "Th17-mediated cartilage destruction via MMPs",
+        ],
+        responding_cell_types=["fibroblast", "macrophage", "CD4_T_cell"],
+    ),
+    "hepatocellular_carcinoma": DiseaseProfile(
+        name="hepatocellular_carcinoma",
+        display_name="hepatocellular carcinoma",
+        tissue="liver",
+        condition_name="HCC",
+        de_genes={
+            "GPC3": (2.0, 4.0), "AFP": (1.5, 3.5), "EPCAM": (1.0, 2.5),
+            "MYC": (1.0, 2.5), "VEGFA": (1.2, 2.8), "MKI67": (1.5, 3.0),
+            "ALB": (-2.0, -0.8), "CYP3A4": (-1.8, -0.6), "APOB": (-1.5, -0.5),
+        },
+        pathways={
+            "Wnt_signalling": (0.7, 0.9),
+            "cell_cycle": (0.75, 0.95),
+            "angiogenesis": (0.6, 0.8),
+            "PI3K_AKT_signalling": (0.65, 0.85),
+            "p53_signalling": (0.3, 0.5),
+        },
+        markers=["GPC3", "AFP", "VEGFA", "MKI67"],
+        mechanism_templates=[
+            "Wnt/beta-catenin-driven hepatocyte dedifferentiation",
+            "VEGF-mediated tumour angiogenesis",
+        ],
+        responding_cell_types=["kupffer_cell", "stellate_cell"],
+        hidden_failure_templates=[
+            "Tumour heterogeneity may confound DE in mixed biopsies",
+        ],
+    ),
+    "atherosclerosis": DiseaseProfile(
+        name="atherosclerosis",
+        display_name="atherosclerosis",
+        tissue="aorta",
+        condition_name="atherosclerosis",
+        de_genes={
+            "TREM2": (1.5, 3.0), "CD9": (1.0, 2.2), "LGALS3": (1.2, 2.5),
+            "APOE": (0.8, 2.0), "MMP9": (1.0, 2.5), "IL1B": (0.8, 2.0),
+            "ACTA2": (-1.5, -0.5), "MYH11": (-2.0, -0.8), "CNN1": (-1.5, -0.5),
+        },
+        pathways={
+            "lipid_metabolism": (0.7, 0.9),
+            "inflammatory_response": (0.65, 0.85),
+            "foam_cell_formation": (0.7, 0.9),
+            "smooth_muscle_contraction": (0.3, 0.5),
+            "complement_cascade": (0.5, 0.7),
+        },
+        markers=["TREM2", "LGALS3", "MMP9", "CD9"],
+        mechanism_templates=[
+            "TREM2+ macrophage-driven foam cell formation",
+            "smooth muscle cell phenotypic switching in plaque",
+        ],
+        responding_cell_types=["macrophage", "smooth_muscle"],
+    ),
+    "breast_cancer": DiseaseProfile(
+        name="breast_cancer",
+        display_name="breast cancer",
+        tissue="breast",
+        condition_name="breast_cancer",
+        de_genes={
+            "ERBB2": (1.5, 3.5), "ESR1": (-1.5, 1.5), "MKI67": (1.5, 3.0),
+            "MYC": (1.0, 2.5), "CDH1": (-1.5, -0.3), "VIM": (0.8, 2.2),
+            "CD274": (0.8, 2.0), "FOXP3": (0.6, 1.5), "GZMB": (0.8, 2.0),
+        },
+        pathways={
+            "cell_cycle": (0.7, 0.9),
+            "PI3K_AKT_signalling": (0.65, 0.85),
+            "EMT": (0.55, 0.8),
+            "immune_checkpoint": (0.5, 0.75),
+            "estrogen_signalling": (0.3, 0.7),
+        },
+        markers=["ERBB2", "MKI67", "CD274", "VIM"],
+        mechanism_templates=[
+            "ERBB2-driven proliferative signalling",
+            "immune evasion via PD-L1 upregulation",
+        ],
+        responding_cell_types=["macrophage", "T_cell", "fibroblast"],
+    ),
+    "multiple_sclerosis": DiseaseProfile(
+        name="multiple_sclerosis",
+        display_name="multiple sclerosis",
+        tissue="brain",
+        condition_name="multiple_sclerosis",
+        de_genes={
+            "CD68": (1.0, 2.5), "CXCL10": (1.2, 2.8), "STAT1": (0.8, 2.0),
+            "IRF1": (0.8, 1.8), "MBP": (-2.0, -0.8), "PLP1": (-1.8, -0.6),
+            "MOG": (-1.5, -0.5), "GFAP": (1.0, 2.5), "C3": (0.8, 2.0),
+        },
+        pathways={
+            "interferon_signalling": (0.7, 0.9),
+            "neuroinflammation": (0.7, 0.9),
+            "complement_cascade": (0.6, 0.8),
+            "myelination": (0.2, 0.4),
+            "T_cell_activation": (0.6, 0.8),
+        },
+        markers=["CXCL10", "STAT1", "GFAP", "C3"],
+        mechanism_templates=[
+            "interferon-driven microglial activation in demyelination",
+            "complement-mediated oligodendrocyte damage",
+        ],
+        responding_cell_types=["microglia", "astrocyte"],
+    ),
+    "diabetic_nephropathy": DiseaseProfile(
+        name="diabetic_nephropathy",
+        display_name="diabetic nephropathy",
+        tissue="kidney",
+        condition_name="diabetic_nephropathy",
454
+ de_genes={
455
+ "HAVCR1": (1.5, 3.0), "LCN2": (1.2, 2.8), "COL4A1": (1.0, 2.5),
456
+ "VEGFA": (0.8, 2.0), "NPHS1": (-1.8, -0.6), "NPHS2": (-1.5, -0.5),
457
+ "WT1": (-1.2, -0.4), "TGFB1": (1.0, 2.2), "FN1": (1.2, 2.5),
458
+ },
459
+ pathways={
460
+ "TGF_beta_signalling": (0.7, 0.9),
461
+ "extracellular_matrix_organisation": (0.7, 0.9),
462
+ "oxidative_stress": (0.6, 0.8),
463
+ "VEGF_signalling": (0.5, 0.7),
464
+ "apoptosis": (0.5, 0.7),
465
+ },
466
+ markers=["HAVCR1", "LCN2", "TGFB1", "FN1"],
467
+ mechanism_templates=[
468
+ "TGF-beta-driven glomerular fibrosis",
469
+ "podocyte loss via oxidative stress",
470
+ ],
471
+ responding_cell_types=["fibroblast", "macrophage"],
472
+ ),
473
+ "melanoma": DiseaseProfile(
474
+ name="melanoma",
475
+ display_name="melanoma",
476
+ tissue="skin",
477
+ condition_name="melanoma",
478
+ de_genes={
479
+ "MLANA": (1.5, 3.0), "PMEL": (1.2, 2.5), "SOX10": (1.0, 2.2),
480
+ "MKI67": (1.5, 3.0), "CD274": (0.8, 2.0), "PDCD1": (0.8, 2.0),
481
+ "GZMB": (0.8, 2.0), "HAVCR2": (0.6, 1.5), "LAG3": (0.6, 1.5),
482
+ },
483
+ pathways={
484
+ "MAPK_signalling": (0.7, 0.9),
485
+ "immune_checkpoint": (0.6, 0.85),
486
+ "cell_cycle": (0.7, 0.9),
487
+ "melanogenesis": (0.5, 0.7),
488
+ "T_cell_exhaustion": (0.55, 0.8),
489
+ },
490
+ markers=["MLANA", "CD274", "GZMB", "MKI67"],
491
+ mechanism_templates=[
492
+ "MAPK-driven melanocyte proliferation",
493
+ "T cell exhaustion via immune checkpoint upregulation",
494
+ ],
495
+ responding_cell_types=["T_cell", "macrophage"],
496
+ ),
497
+ }
498
+
499
+
500
+ # ── Pathway library ─────────────────────────────────────────────────────────
501
+
502
+ PATHWAY_LIBRARY: Dict[str, List[str]] = {
503
+ "TGF_beta_signalling": ["TGFB1", "TGFB2", "SMAD2", "SMAD3", "SMAD4", "ACVR1"],
504
+ "Wnt_signalling": ["WNT3A", "CTNNB1", "APC", "AXIN2", "LGR5", "TCF7L2"],
505
+ "MAPK_signalling": ["KRAS", "BRAF", "MAP2K1", "MAPK1", "MAPK3", "FOS", "JUN"],
506
+ "JAK_STAT_signalling": ["JAK1", "JAK2", "STAT1", "STAT3", "STAT5A", "SOCS1", "SOCS3"],
507
+ "PI3K_AKT_signalling": ["PIK3CA", "AKT1", "MTOR", "PTEN", "TSC2"],
508
+ "NF_kB_signalling": ["NFKB1", "RELA", "IKBKB", "TNF", "IL1B"],
509
+ "cell_cycle": ["CDK4", "CDK6", "CCND1", "CCNE1", "RB1", "E2F1", "MKI67"],
510
+ "apoptosis": ["BCL2", "BAX", "BAK1", "CASP3", "CASP9", "TP53", "BID"],
511
+ "inflammatory_response": ["TNF", "IL6", "IL1B", "CCL2", "CXCL8", "NFKB1"],
512
+ "extracellular_matrix_organisation": ["COL1A1", "COL3A1", "FN1", "POSTN", "MMP2", "MMP9", "TIMP1"],
513
+ "complement_cascade": ["C1QA", "C1QB", "C3", "C4A", "C5", "CFB"],
514
+ "neuroinflammation": ["TREM2", "CX3CR1", "P2RY12", "IL1B", "TNF", "C1QA"],
515
+ "synaptic_signalling": ["SLC17A7", "GRIA1", "GRIN1", "DLG4", "SNAP25", "SYP"],
516
+ "hematopoietic_cell_lineage": ["CD34", "KIT", "FLT3", "GATA1", "CEBPA", "SPI1"],
517
+ "insulin_signalling": ["INS", "INSR", "IRS1", "PIK3CA", "AKT1", "SLC2A4"],
518
+ "ER_stress_response": ["DDIT3", "ATF4", "XBP1", "HSPA5", "EIF2AK3"],
519
+ "oxidative_stress": ["SOD1", "SOD2", "CAT", "GPX1", "NFE2L2", "HMOX1"],
520
+ "angiogenesis": ["VEGFA", "VEGFB", "KDR", "FLT1", "ANGPT1", "ANGPT2"],
521
+ "EMT": ["CDH1", "CDH2", "VIM", "SNAI1", "SNAI2", "TWIST1", "ZEB1"],
522
+ "immune_checkpoint": ["CD274", "PDCD1", "CTLA4", "HAVCR2", "LAG3", "TIGIT"],
523
+ "T_cell_activation": ["CD3D", "CD28", "LCK", "ZAP70", "IL2", "IFNG"],
524
+ "T_cell_exhaustion": ["PDCD1", "HAVCR2", "LAG3", "TIGIT", "TOX", "ENTPD1"],
525
+ "TNF_signalling": ["TNF", "TNFRSF1A", "TRADD", "RIPK1", "NFKB1", "CASP8"],
526
+ "Th17_differentiation": ["IL17A", "IL17F", "RORC", "IL23R", "CCR6", "STAT3"],
527
+ "interferon_signalling": ["IFNG", "IFNB1", "STAT1", "IRF1", "IRF7", "MX1", "OAS1"],
528
+ "lipid_metabolism": ["APOE", "APOB", "LDLR", "HMGCR", "ABCA1", "PPARG"],
529
+ "myelination": ["MBP", "PLP1", "MOG", "MAG", "OLIG2", "SOX10"],
530
+ "foam_cell_formation": ["CD36", "MSR1", "ABCA1", "APOE", "LGALS3", "TREM2"],
531
+ "smooth_muscle_contraction": ["ACTA2", "MYH11", "TAGLN", "CNN1", "MYLK"],
532
+ "glucagon_signalling": ["GCG", "GCGR", "CREB1", "PCK1", "G6PC"],
533
+ "matrix_metalloproteinase_activity": ["MMP1", "MMP2", "MMP3", "MMP9", "TIMP1", "TIMP2"],
534
+ "estrogen_signalling": ["ESR1", "ESR2", "PGR", "GREB1", "TFF1"],
535
+ "melanogenesis": ["MITF", "TYR", "TYRP1", "DCT", "MLANA", "PMEL"],
536
+ "VEGF_signalling": ["VEGFA", "VEGFB", "KDR", "FLT1", "NRP1"],
537
+ }
538
+
539
+
540
+ # ── Regulatory network templates ────────────────────────────────────────────
541
+
542
+ REGULATORY_TEMPLATES: Dict[str, Dict[str, List[str]]] = {
543
+ "erythroid": {
544
+ "GATA1": ["KLF1", "HBB", "HBA1", "GYPA", "ALAS2"],
545
+ "KLF1": ["HBB", "HBA1", "SLC4A1"],
546
+ },
547
+ "myeloid": {
548
+ "CEBPA": ["CSF3R", "ELANE", "MPO", "CTSG"],
549
+ "SPI1": ["CSF1R", "CD14", "FCGR3A", "CD68"],
550
+ },
551
+ "lymphoid": {
552
+ "TCF7": ["CD3D", "CD3E", "IL7R", "LEF1"],
553
+ "PAX5": ["CD19", "MS4A1", "CD79A"],
554
+ },
555
+ "fibrotic": {
556
+ "SMAD3": ["COL1A1", "COL3A1", "FN1", "POSTN"],
557
+ "TGFB1": ["ACTA2", "COL1A1", "CTGF"],
558
+ },
559
+ "inflammatory": {
560
+ "NFKB1": ["TNF", "IL6", "IL1B", "CCL2", "CXCL8"],
561
+ "STAT1": ["IRF1", "CXCL10", "MX1", "OAS1"],
562
+ },
563
+ "stem_cell": {
564
+ "RUNX1": ["CD34", "KIT", "FLT3"],
565
+ "MYC": ["CDK4", "CCND1", "E2F1"],
566
+ },
567
+ "neuronal": {
568
+ "NEUROD1": ["SLC17A7", "NRGN", "SNAP25"],
569
+ "DLX1": ["GAD1", "GAD2", "SLC32A1"],
570
+ },
571
+ }
572
+
573
+
574
+ # ── Perturbation templates ──────────────────────────────────────────────────
575
+
576
+ @dataclass
577
+ class PerturbationTemplate:
578
+ name: str
579
+ target_pathway: str
580
+ gene_effects: Dict[str, float]
581
+ description: str
582
+
583
+
584
+ PERTURBATION_TEMPLATES: Dict[str, PerturbationTemplate] = {
585
+ "JAK_inhibitor": PerturbationTemplate(
586
+ name="JAK_inhibitor",
587
+ target_pathway="JAK_STAT_signalling",
588
+ gene_effects={"STAT1": -0.8, "STAT3": -0.7, "IFNG": -1.5, "IL17A": -1.3, "SOCS1": 1.2},
589
+ description="JAK inhibitor treatment",
590
+ ),
591
+ "anti_TNF": PerturbationTemplate(
592
+ name="anti_TNF",
593
+ target_pathway="TNF_signalling",
594
+ gene_effects={"TNF": -1.5, "IL6": -1.0, "IL1B": -0.8, "MMP3": -1.2, "SOCS3": 0.8},
595
+ description="anti-TNF biologic therapy",
596
+ ),
597
+ "PD1_blockade": PerturbationTemplate(
598
+ name="PD1_blockade",
599
+ target_pathway="immune_checkpoint",
600
+ gene_effects={"PDCD1": -1.0, "GZMB": 1.5, "IFNG": 1.2, "PRF1": 1.0, "TNF": 0.8},
601
+ description="anti-PD-1 immune checkpoint blockade",
602
+ ),
603
+ "BRAF_inhibitor": PerturbationTemplate(
604
+ name="BRAF_inhibitor",
605
+ target_pathway="MAPK_signalling",
606
+ gene_effects={"BRAF": -0.5, "MAPK1": -1.0, "MKI67": -1.5, "CCND1": -1.2, "FOS": -0.8},
607
+ description="BRAF inhibitor treatment",
608
+ ),
609
+ "TGFb_inhibitor": PerturbationTemplate(
610
+ name="TGFb_inhibitor",
611
+ target_pathway="TGF_beta_signalling",
612
+ gene_effects={"TGFB1": -0.8, "COL1A1": -1.2, "COL3A1": -1.0, "POSTN": -1.5, "ACTA2": -0.8},
613
+ description="TGF-beta pathway inhibitor",
614
+ ),
615
+ "mTOR_inhibitor": PerturbationTemplate(
616
+ name="mTOR_inhibitor",
617
+ target_pathway="PI3K_AKT_signalling",
618
+ gene_effects={"MTOR": -0.8, "AKT1": -0.6, "MKI67": -1.2, "CCND1": -1.0, "HIF1A": -0.7},
619
+ description="mTOR inhibitor treatment",
620
+ ),
621
+ "CRISPR_TP53_KO": PerturbationTemplate(
622
+ name="CRISPR_TP53_KO",
623
+ target_pathway="p53_signalling",
624
+ gene_effects={"TP53": -2.0, "BAX": -1.0, "CDKN1A": -1.5, "MDM2": -0.8, "MKI67": 1.0},
625
+ description="CRISPR knockout of TP53",
626
+ ),
627
+ }
628
+
629
+
630
+ # ── Trajectory templates ────────────────────────────────────────────────────
631
+
632
+ @dataclass
633
+ class TrajectoryTemplate:
634
+ """Template for a developmental trajectory through cell populations."""
635
+ root_population: str
636
+ branches: List[List[str]]
637
+ n_lineages: int
638
+ tissue: str
639
+
640
+
641
+ TRAJECTORY_TEMPLATES: List[TrajectoryTemplate] = [
642
+ TrajectoryTemplate(
643
+ root_population="HSC",
644
+ branches=[
645
+ ["HSC", "CMP", "GMP", "neutrophil"],
646
+ ["HSC", "CMP", "GMP", "monocyte"],
647
+ ["HSC", "MEP", "erythrocyte"],
648
+ ["HSC", "MEP", "megakaryocyte"],
649
+ ],
650
+ n_lineages=3,
651
+ tissue="bone_marrow",
652
+ ),
653
+ TrajectoryTemplate(
654
+ root_population="double_negative_T",
655
+ branches=[
656
+ ["double_negative_T", "double_positive_T", "CD4_SP"],
657
+ ["double_negative_T", "double_positive_T", "CD8_SP"],
658
+ ],
659
+ n_lineages=2,
660
+ tissue="thymus",
661
+ ),
662
+ TrajectoryTemplate(
663
+ root_population="stem_cell",
664
+ branches=[
665
+ ["stem_cell", "colonocyte"],
666
+ ["stem_cell", "goblet_cell"],
667
+ ],
668
+ n_lineages=2,
669
+ tissue="colon",
670
+ ),
671
+ TrajectoryTemplate(
672
+ root_population="OPC",
673
+ branches=[
674
+ ["OPC", "oligodendrocyte"],
675
+ ],
676
+ n_lineages=1,
677
+ tissue="brain",
678
+ ),
679
+ ]
680
+
681
+
682
+ # ── Hidden failure condition templates ──────────────────────────────────────
683
+
684
+ HIDDEN_FAILURE_TEMPLATES: List[str] = [
685
+ "High ambient RNA may confound DE in low-abundance transcripts",
686
+ "Strong batch effects between conditions may inflate false positives",
687
+ "Low cell viability in disease samples reduces statistical power",
688
+ "Doublet contamination in dense populations obscures rare cell types",
689
+ "Sample degradation during processing introduces 3' bias artefacts",
690
+ "Dissociation-induced gene expression changes confound stress signatures",
691
+ "Unbalanced sample sizes between conditions reduce DE sensitivity",
692
+ ]
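The disease profiles above store each DE gene as a `(lo, hi)` log2 fold-change range rather than a point value, so every generated episode draws a fresh concrete effect size. A minimal, self-contained sketch of that sampling pattern (using the stdlib `random` module and a hypothetical two-gene excerpt for illustration; the repository itself uses `numpy.random.Generator`):

```python
import random

# Hypothetical excerpt of a profile: gene -> (lo, hi) log2 fold-change range.
de_gene_ranges = {
    "GPC3": (2.0, 4.0),   # up-regulated in the hypothetical disease
    "ALB": (-2.0, -0.8),  # down-regulated
}

def sample_effects(ranges, seed=0):
    """Draw one concrete log2FC per gene, uniformly from its range."""
    rng = random.Random(seed)
    return {gene: round(rng.uniform(lo, hi), 3) for gene, (lo, hi) in ranges.items()}

effects = sample_effects(de_gene_ranges, seed=1)
```

Because the draw is seeded, the same seed always reproduces the same latent effect sizes, which is what lets an episode be replayed exactly.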
server/tasks/generator.py CHANGED
@@ -12,7 +12,7 @@ from typing import List, Optional, Tuple
 
 import numpy as np
 
-from models import TaskSpec
+from models import TaskSpec, tools_for_modality, assays_for_modality
 
 from server.simulator.latent_state import (
     CellPopulation,
@@ -24,6 +24,7 @@ from server.simulator.latent_state import (
     TechnicalState,
 )
 from .scenarios import SCENARIO_LIBRARY, Scenario
+from .procedural_generator import generate_procedural_scenarios
 
 
 class TaskGenerator:
@@ -34,7 +35,10 @@ class TaskGenerator:
         scenarios: Optional[List[Scenario]] = None,
         domain_randomise: bool = True,
     ):
-        self.scenarios = scenarios or SCENARIO_LIBRARY
+        if scenarios is not None:
+            self.scenarios = scenarios
+        else:
+            self.scenarios = list(SCENARIO_LIBRARY) + generate_procedural_scenarios(n=20, seed=42)
         self.domain_randomise = domain_randomise
 
     def generate(
@@ -58,6 +62,14 @@ class TaskGenerator:
         if self.domain_randomise:
             self._randomise(rng, task, biology, technical)
 
+        # Filter available tools/assays to those compatible with the modality.
+        compatible_tools = [t.name for t in tools_for_modality(task.modality)]
+        compatible_assays = [a.name for a in assays_for_modality(task.modality)]
+        if compatible_tools:
+            task.available_tools = compatible_tools
+        if compatible_assays:
+            task.available_assays = compatible_assays
+
         latent = FullLatentState(
             biology=biology,
             technical=technical,
@@ -67,6 +79,7 @@ class TaskGenerator:
                 time_limit_days=task.time_limit_days,
             ),
            hidden_failure_conditions=list(scenario.hidden_failure_conditions),
+            task_modality=task.modality,
             rng_seed=seed or 0,
         )
         return task, latent
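The change to `TaskGenerator.__init__` above distinguishes between "no scenarios supplied" and "an explicitly empty list supplied", which `scenarios or SCENARIO_LIBRARY` conflated. A small sketch of that design choice, with placeholder lists standing in for the real scenario objects:

```python
# Illustrative stand-ins for the real curated and procedural pools.
CURATED = ["cardiac_disease_de", "hematopoiesis_trajectory"]
PROCEDURAL = ["proc_melanoma_de_7", "proc_ra_perturbation_11"]

def resolve_scenarios(scenarios=None):
    """Use the caller's list verbatim if given (even if empty);
    otherwise fall back to curated + procedural pools."""
    if scenarios is not None:
        return list(scenarios)
    return CURATED + PROCEDURAL
```

With the old truthiness-based fallback, passing `[]` would silently re-enable the default pool; the explicit `is not None` check preserves the caller's intent.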
server/tasks/procedural_generator.py ADDED
@@ -0,0 +1,501 @@
+"""Procedural scenario generator.
+
+Composes biologically coherent ``Scenario`` objects from the curated
+palette in ``bio_palette``, producing fully populated
+``LatentBiologicalState`` instances that drive every simulator tool
+(clustering, DE, pathway enrichment, trajectory, regulatory networks,
+marker validation) with realistic intermediate outputs.
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import Any, Dict, List, Optional, Tuple
+
+import numpy as np
+
+from models import TaskSpec
+
+from server.simulator.latent_state import (
+    CellPopulation,
+    LatentBiologicalState,
+    TechnicalState,
+)
+
+from .bio_palette import (
+    DISEASE_PROFILES,
+    HIDDEN_FAILURE_TEMPLATES,
+    PATHWAY_LIBRARY,
+    PERTURBATION_TEMPLATES,
+    REGULATORY_TEMPLATES,
+    TISSUE_CELL_TYPES,
+    TRAJECTORY_TEMPLATES,
+    CellTypeTemplate,
+    DiseaseProfile,
+)
+from .scenarios import Scenario
+
+logger = logging.getLogger(__name__)
+
+SCENARIO_TYPES = ("de", "trajectory", "perturbation", "biomarker")
+
+_DIFFICULTY_PARAMS = {
+    "easy": {
+        "n_pops": (4, 5),
+        "de_scale": (1.2, 1.6),
+        "noise_dropout": (0.05, 0.10),
+        "noise_doublet": (0.03, 0.06),
+        "noise_ambient": (0.02, 0.05),
+        "noise_batch_strength": (0.05, 0.12),
+        "n_batches": (1, 2),
+        "budget_range": (70_000, 100_000),
+        "time_range": (100, 150),
+        "sample_quality": (0.85, 0.95),
+        "include_trajectory": False,
+        "include_perturbation": False,
+        "include_network": False,
+        "include_failure_conditions": False,
+    },
+    "medium": {
+        "n_pops": (5, 7),
+        "de_scale": (0.9, 1.3),
+        "noise_dropout": (0.08, 0.14),
+        "noise_doublet": (0.04, 0.08),
+        "noise_ambient": (0.03, 0.07),
+        "noise_batch_strength": (0.08, 0.18),
+        "n_batches": (1, 3),
+        "budget_range": (80_000, 120_000),
+        "time_range": (120, 180),
+        "sample_quality": (0.78, 0.92),
+        "include_trajectory": True,
+        "include_perturbation": False,
+        "include_network": True,
+        "include_failure_conditions": False,
+    },
+    "hard": {
+        "n_pops": (6, 8),
+        "de_scale": (0.6, 1.0),
+        "noise_dropout": (0.10, 0.20),
+        "noise_doublet": (0.06, 0.12),
+        "noise_ambient": (0.05, 0.10),
+        "noise_batch_strength": (0.12, 0.25),
+        "n_batches": (2, 4),
+        "budget_range": (90_000, 140_000),
+        "time_range": (140, 200),
+        "sample_quality": (0.65, 0.85),
+        "include_trajectory": True,
+        "include_perturbation": True,
+        "include_network": True,
+        "include_failure_conditions": True,
+    },
+}
+
+
+def generate_scenario(
+    seed: int,
+    difficulty: str = "medium",
+    scenario_type: Optional[str] = None,
+) -> Scenario:
+    """Generate a single procedural scenario with complete latent state.
+
+    Parameters
+    ----------
+    seed
+        RNG seed for reproducibility.
+    difficulty
+        One of ``"easy"``, ``"medium"``, ``"hard"``.
+    scenario_type
+        One of ``"de"``, ``"trajectory"``, ``"perturbation"``,
+        ``"biomarker"``, or ``None`` for random selection.
+    """
+    rng = np.random.default_rng(seed)
+    params = _DIFFICULTY_PARAMS[difficulty]
+
+    if scenario_type is None:
+        scenario_type = rng.choice(SCENARIO_TYPES)
+
+    disease_key = rng.choice(list(DISEASE_PROFILES.keys()))
+    disease = DISEASE_PROFILES[disease_key]
+    tissue = disease.tissue
+
+    cell_templates = TISSUE_CELL_TYPES.get(tissue, [])
+    if not cell_templates:
+        tissue = rng.choice(list(TISSUE_CELL_TYPES.keys()))
+        cell_templates = TISSUE_CELL_TYPES[tissue]
+
+    populations = _sample_populations(rng, cell_templates, disease, params)
+    de_genes = _build_de_genes(rng, disease, params)
+    pathways = _build_pathways(rng, disease)
+    markers = _derive_markers(rng, de_genes, disease)
+    mechanisms = list(disease.mechanism_templates)
+    n_cells = int(rng.integers(8_000, 22_000))
+
+    trajectory = None
+    if scenario_type == "trajectory" or (
+        params["include_trajectory"] and rng.random() < 0.4
+    ):
+        trajectory = _build_trajectory(rng, tissue, populations)
+
+    reg_network: Dict[str, List[str]] = {}
+    if scenario_type == "trajectory" or (
+        params["include_network"] and rng.random() < 0.5
+    ):
+        reg_network = _build_regulatory_network(rng, tissue, populations)
+
+    perturbation_effects: Dict[str, Dict[str, float]] = {}
+    if scenario_type == "perturbation" or (
+        params["include_perturbation"] and rng.random() < 0.5
+    ):
+        perturbation_effects = _build_perturbation(rng, disease)
+
+    technical = _build_technical(rng, params)
+
+    hidden_failures: List[str] = []
+    if params["include_failure_conditions"] and rng.random() < 0.6:
+        n_failures = int(rng.integers(1, 3))
+        indices = rng.choice(
+            len(HIDDEN_FAILURE_TEMPLATES),
+            size=min(n_failures, len(HIDDEN_FAILURE_TEMPLATES)),
+            replace=False,
+        )
+        hidden_failures = [HIDDEN_FAILURE_TEMPLATES[i] for i in indices]
+
+    task = _build_task(rng, disease, tissue, scenario_type, params, perturbation_effects)
+
+    biology = LatentBiologicalState(
+        cell_populations=populations,
+        true_de_genes=de_genes,
+        true_pathways=pathways,
+        true_trajectory=trajectory,
+        true_regulatory_network=reg_network,
+        perturbation_effects=perturbation_effects,
+        true_markers=markers,
+        causal_mechanisms=mechanisms,
+        n_true_cells=n_cells,
+    )
+
+    name = f"proc_{disease.name}_{scenario_type}_{seed}"
+
+    tags = [scenario_type, "scRNA-seq", tissue, disease.name, difficulty]
+
+    return Scenario(
+        name=name,
+        task=task,
+        biology=biology,
+        technical=technical,
+        hidden_failure_conditions=hidden_failures,
+        difficulty=difficulty,
+        tags=tags,
+    )
+
+
+def generate_procedural_scenarios(
+    n: int = 20,
+    seed: int = 42,
+) -> List[Scenario]:
+    """Pre-generate a pool of procedural scenarios across difficulties."""
+    rng = np.random.default_rng(seed)
+    scenarios: List[Scenario] = []
+    difficulties = ["easy", "medium", "hard"]
+
+    for i in range(n):
+        diff = difficulties[i % len(difficulties)]
+        child_seed = int(rng.integers(0, 2**31))
+        scenario = generate_scenario(
+            seed=child_seed,
+            difficulty=diff,
+            scenario_type=None,
+        )
+        scenarios.append(scenario)
+
+    logger.info("Generated %d procedural scenarios.", len(scenarios))
+    return scenarios
+
+
+# ── Internal builders ───────────────────────────────────────────────────────
+
+
+def _sample_populations(
+    rng: np.random.Generator,
+    templates: List[CellTypeTemplate],
+    disease: DiseaseProfile,
+    params: dict,
+) -> List[CellPopulation]:
+    lo, hi = params["n_pops"]
+    n_pops = int(rng.integers(lo, hi + 1))
+    n_pops = min(n_pops, len(templates))
+
+    indices = rng.choice(len(templates), size=n_pops, replace=False)
+    selected = [templates[i] for i in sorted(indices)]
+
+    responding_names = set(disease.responding_cell_types)
+
+    populations: List[CellPopulation] = []
+    for tmpl in selected:
+        prop = float(rng.uniform(*tmpl.proportion_range))
+        state = rng.choice(tmpl.states)
+
+        condition_response: Dict[str, float] = {}
+        if tmpl.disease_responsive and tmpl.name in responding_names:
+            condition_response[disease.condition_name] = float(
+                rng.uniform(*tmpl.response_range)
+            )
+
+        populations.append(CellPopulation(
+            name=tmpl.name,
+            proportion=prop,
+            marker_genes=list(tmpl.marker_genes),
+            state=state,
+            condition_response=condition_response,
+        ))
+
+    total = sum(p.proportion for p in populations)
+    if total > 0:
+        for p in populations:
+            p.proportion = round(p.proportion / total, 4)
+
+    return populations
+
+
+def _build_de_genes(
+    rng: np.random.Generator,
+    disease: DiseaseProfile,
+    params: dict,
+) -> Dict[str, Dict[str, float]]:
+    comparison = f"{disease.condition_name}_vs_healthy"
+    scale_lo, scale_hi = params["de_scale"]
+
+    effects: Dict[str, float] = {}
+    for gene, (lo, hi) in disease.de_genes.items():
+        base = float(rng.uniform(lo, hi))
+        scale = float(rng.uniform(scale_lo, scale_hi))
+        effects[gene] = round(base * scale, 3)
+
+    return {comparison: effects}
+
+
+def _build_pathways(
+    rng: np.random.Generator,
+    disease: DiseaseProfile,
+) -> Dict[str, float]:
+    pathways: Dict[str, float] = {}
+    for pw, (lo, hi) in disease.pathways.items():
+        pathways[pw] = round(float(rng.uniform(lo, hi)), 3)
+    return pathways
+
+
+def _derive_markers(
+    rng: np.random.Generator,
+    de_genes: Dict[str, Dict[str, float]],
+    disease: DiseaseProfile,
+) -> List[str]:
+    markers = list(disease.markers)
+
+    all_effects: Dict[str, float] = {}
+    for effects in de_genes.values():
+        all_effects.update(effects)
+
+    for gene in markers:
+        if gene not in all_effects:
+            all_effects[gene] = float(rng.uniform(1.0, 2.5))
+            for comp_effects in de_genes.values():
+                comp_effects[gene] = all_effects[gene]
+
+    n_markers = min(len(markers), int(rng.integers(3, 7)))
+    return markers[:n_markers]
+
+
+def _build_trajectory(
+    rng: np.random.Generator,
+    tissue: str,
+    populations: List[CellPopulation],
+) -> Optional[Dict[str, Any]]:
+    pop_names = {p.name for p in populations}
+
+    for tmpl in TRAJECTORY_TEMPLATES:
+        if tmpl.tissue == tissue:
+            valid_branches = [
+                branch for branch in tmpl.branches
+                if all(node in pop_names for node in branch)
+            ]
+            if valid_branches:
+                return {
+                    "root": tmpl.root_population,
+                    "n_lineages": len(valid_branches),
+                    "branching": len(valid_branches) > 1,
+                    "branches": valid_branches,
+                }
+
+    if len(populations) >= 3:
+        root = populations[0].name
+        branches = [[root, p.name] for p in populations[1:]]
+        selected = branches[:int(rng.integers(2, min(4, len(branches)) + 1))]
+        return {
+            "root": root,
+            "n_lineages": len(selected),
+            "branching": len(selected) > 1,
+            "branches": selected,
+        }
+
+    return None
+
+
+def _build_regulatory_network(
+    rng: np.random.Generator,
+    tissue: str,
+    populations: List[CellPopulation],
+) -> Dict[str, List[str]]:
+    all_genes = set()
+    for p in populations:
+        all_genes.update(p.marker_genes)
+
+    network: Dict[str, List[str]] = {}
+
+    tissue_to_programs = {
+        "bone_marrow": ["erythroid", "myeloid", "stem_cell"],
+        "thymus": ["lymphoid"],
+        "blood": ["lymphoid", "myeloid"],
+        "spleen": ["lymphoid"],
+        "brain": ["neuronal", "inflammatory"],
+        "heart": ["fibrotic", "inflammatory"],
+        "lung": ["fibrotic", "inflammatory"],
+        "liver": ["fibrotic", "inflammatory"],
+        "kidney": ["fibrotic", "inflammatory"],
+        "colon": ["inflammatory", "stem_cell"],
+        "pancreas": ["inflammatory"],
+        "skin": ["inflammatory"],
+        "breast": ["inflammatory"],
+        "synovium": ["inflammatory", "lymphoid"],
+        "aorta": ["inflammatory"],
+    }
+
+    programs = tissue_to_programs.get(tissue, ["inflammatory"])
+    for prog_name in programs:
+        prog = REGULATORY_TEMPLATES.get(prog_name, {})
+        for tf, targets in prog.items():
+            network[tf] = list(targets)
+
+    if not network:
+        for p in populations[:2]:
+            if len(p.marker_genes) >= 2:
+                tf = p.marker_genes[0]
+                network[tf] = p.marker_genes[1:]
+
+    return network
+
+
+def _build_perturbation(
+    rng: np.random.Generator,
+    disease: DiseaseProfile,
+) -> Dict[str, Dict[str, float]]:
+    disease_pathways = set(disease.pathways.keys())
+
+    matching = [
+        (name, tmpl) for name, tmpl in PERTURBATION_TEMPLATES.items()
+        if tmpl.target_pathway in disease_pathways
+    ]
+
+    if matching:
+        name, tmpl = matching[int(rng.integers(0, len(matching)))]
+    else:
+        name = rng.choice(list(PERTURBATION_TEMPLATES.keys()))
+        tmpl = PERTURBATION_TEMPLATES[name]
+
+    scaled: Dict[str, float] = {}
+    for gene, effect in tmpl.gene_effects.items():
+        scale = float(rng.uniform(0.7, 1.3))
+        scaled[gene] = round(effect * scale, 3)
+
+    return {name: scaled}
+
+
+def _build_technical(
+    rng: np.random.Generator,
+    params: dict,
+) -> TechnicalState:
+    n_batches = int(rng.integers(*params["n_batches"]))
+    batch_effects: Dict[str, float] = {}
+    for i in range(max(1, n_batches)):
+        strength = float(rng.uniform(*params["noise_batch_strength"]))
+        batch_effects[f"batch_{i}"] = round(strength, 3)
+
+    return TechnicalState(
+        batch_effects=batch_effects,
+        dropout_rate=round(float(rng.uniform(*params["noise_dropout"])), 3),
+        doublet_rate=round(float(rng.uniform(*params["noise_doublet"])), 3),
+        ambient_rna_fraction=round(float(rng.uniform(*params["noise_ambient"])), 3),
+        sample_quality=round(float(rng.uniform(*params["sample_quality"])), 3),
+    )
+
+
+def _build_task(
+    rng: np.random.Generator,
+    disease: DiseaseProfile,
+    tissue: str,
+    scenario_type: str,
+    params: dict,
+    perturbation_effects: Dict[str, Dict[str, float]],
+) -> TaskSpec:
+    budget = float(rng.integers(*params["budget_range"]))
+    time_days = float(rng.integers(*params["time_range"]))
+
+    if scenario_type == "de":
+        problem = (
+            f"Identify differentially expressed genes between "
+            f"{disease.display_name} and healthy {tissue} tissue "
+            f"using single-cell RNA sequencing."
+        )
+        criteria = [
+            f"Identify DE genes between {disease.condition_name} and healthy",
+            "Validate at least one candidate marker",
+        ]
+    elif scenario_type == "trajectory":
+        problem = (
+            f"Infer the developmental trajectory of cell populations "
+            f"in {tissue} tissue in the context of {disease.display_name}."
+        )
+        criteria = [
+            "Reconstruct branching lineage structure",
+            "Identify key transcription factors driving fate decisions",
+        ]
+    elif scenario_type == "perturbation":
+        pert_name = next(iter(perturbation_effects), "treatment")
+        pert_tmpl = PERTURBATION_TEMPLATES.get(pert_name)
+        pert_desc = pert_tmpl.description if pert_tmpl else pert_name
+        problem = (
+            f"Determine the effect of {pert_desc} on cell states "
+            f"in {tissue} tissue affected by {disease.display_name}."
+        )
+        criteria = [
+            "Quantify shift in cell activation states",
+            f"Identify pathways modulated by {pert_name}",
+            "Propose validation strategy",
+        ]
+    else:
+        top_marker = disease.markers[0] if disease.markers else "candidate"
+        problem = (
+            f"Validate candidate biomarker {top_marker} for "
+            f"{disease.display_name} in {tissue} tissue using "
+            f"single-cell RNA sequencing."
+        )
+        criteria = [
+            f"Validate {top_marker} as a disease marker",
+            "Confirm expression specificity across cell types",
+        ]
+
+    conditions = ["healthy", disease.condition_name]
+    if scenario_type == "perturbation" and perturbation_effects:
+        pert_name = next(iter(perturbation_effects))
+        conditions = [f"untreated_{disease.condition_name}", f"{pert_name}_treated"]
+
+    return TaskSpec(
+        problem_statement=problem,
+        modality="scRNA-seq",
+        organism="human",
+        tissue=tissue,
+        conditions=conditions,
+        budget_limit=budget,
+        time_limit_days=time_days,
+        success_criteria=criteria,
+    )
server/tasks/scenarios.py CHANGED
@@ -353,8 +353,8 @@ SCENARIO_LIBRARY: List[Scenario] = [
         budget_limit=90_000.0,
         time_limit_days=150.0,
         prior_observations=[
-            "SPP1 identified as top DE gene in prior pilot study",
-            "SPP1+ macrophages enriched in fibrotic regions",
+            "A macrophage subpopulation shows elevated expression in IPF tissue relative to controls",
+            "Pro-fibrotic macrophage enrichment has been observed in fibrotic regions by spatial profiling",
         ],
         success_criteria=[
             "Validate SPP1 as a marker for pro-fibrotic macrophages",
@@ -452,3 +452,5 @@ SCENARIO_LIBRARY: List[Scenario] = [
         ),
     ),
 ]
+
+
tests/test_environment.py CHANGED
@@ -56,6 +56,17 @@ class TestEnvironmentLifecycle:
         assert obs.latest_output is not None
         assert obs.latest_output.success is False
 
+    def test_premature_followup_design_is_flagged(self):
+        env = BioExperimentEnvironment()
+        env.reset()
+        obs = env.step(ExperimentAction(
+            action_type=ActionType.DESIGN_FOLLOWUP,
+            parameters={"assay": "qPCR"},
+        ))
+        assert obs.latest_output is not None
+        assert obs.latest_output.success is True
+        assert any("follow-up design" in msg.lower() for msg in obs.rule_violations)
+
     def test_conclusion_ends_episode(self):
         env = BioExperimentEnvironment()
         env.reset()
tests/test_rewards.py CHANGED
@@ -61,6 +61,29 @@ class TestStepReward:
         )
         assert rb.total < 0
 
+    def test_premature_meta_action_gets_penalized(self):
+        rc = RewardComputer()
+        prev, nxt = _states(
+            prev_flags={"data_normalized": True},
+            next_flags={"followup_designed": True},
+            budget_used=2_000,
+        )
+        output = IntermediateOutput(
+            output_type=OutputType.FOLLOWUP_DESIGN,
+            step_index=2,
+            quality_score=1.0,
+            uncertainty=0.0,
+        )
+        rb = rc.step_reward(
+            ExperimentAction(action_type=ActionType.DESIGN_FOLLOWUP),
+            prev,
+            nxt,
+            output,
+            [],
+            [],
+        )
+        assert rb.components.get("premature_meta_action_penalty", 0.0) < 0.0
+
 
 class TestTerminalReward:
     def test_correct_conclusion_rewarded(self):
@@ -103,3 +126,42 @@ class TestTerminalReward:
         ]
         rb = rc.terminal_reward(state, claims, [])
         assert rb.components.get("overconfidence_penalty", 0) < 0
+
+    def test_discovery_error_penalizes_wrong_markers_and_mechanisms(self):
+        rc = RewardComputer()
+        state = FullLatentState(
+            biology=LatentBiologicalState(
+                true_markers=["NPPA", "NPPB"],
+                causal_mechanisms=["TGF-beta-driven fibrosis"],
+            ),
+            progress=ExperimentProgress(
+                samples_collected=True,
+                cells_sequenced=True,
+                qc_performed=True,
+                data_filtered=True,
+                data_normalized=True,
+                de_performed=True,
+                markers_discovered=True,
+                conclusion_reached=True,
+            ),
+            resources=ResourceState(budget_total=100_000, budget_used=40_000),
+        )
+
+        aligned = rc.terminal_reward(
+            state,
+            [],
+            [],
+            discovered_markers=["NPPA", "NPPB"],
+            candidate_mechanisms=["TGF-beta-driven fibrosis"],
+        )
+        misaligned = rc.terminal_reward(
+            state,
+            [],
+            [],
+            discovered_markers=["WRONG1", "WRONG2"],
+            candidate_mechanisms=["unrelated inflammatory process"],
+        )
+
+        assert aligned.components["discovery_alignment"] > misaligned.components["discovery_alignment"]
+        assert aligned.components["discovery_error_penalty"] > misaligned.components["discovery_error_penalty"]
+        assert aligned.terminal > misaligned.terminal
tests/test_rules.py CHANGED
@@ -66,6 +66,36 @@ class TestRedundancy:
         assert not hard
         assert any("redundant" in m.lower() for m in soft)
 
+    def test_repeated_followup_design_is_soft(self):
+        engine = RuleEngine()
+        violations = engine.check(
+            ExperimentAction(action_type=ActionType.DESIGN_FOLLOWUP),
+            _state(followup_designed=True, de_performed=True),
+        )
+        hard = engine.hard_violations(violations)
+        soft = engine.soft_violations(violations)
+        assert not hard
+        assert any("redundant" in m.lower() for m in soft)
+
+
+class TestMetaActionTiming:
+    def test_followup_design_without_analysis_is_soft(self):
+        engine = RuleEngine()
+        violations = engine.check(
+            ExperimentAction(action_type=ActionType.DESIGN_FOLLOWUP),
+            _state(),
+        )
+        soft = engine.soft_violations(violations)
+        assert any("follow-up design" in m.lower() for m in soft)
+
+    def test_subagent_review_without_analysis_is_soft(self):
+        engine = RuleEngine()
+        violations = engine.check(
+            ExperimentAction(action_type=ActionType.REQUEST_SUBAGENT_REVIEW),
+            _state(),
+        )
+        soft = engine.soft_violations(violations)
+        assert any("subagent review" in m.lower() for m in soft)
 
 class TestResourceConstraints:
     def test_exhausted_budget_blocked(self):
tests/test_run_agent.py ADDED
@@ -0,0 +1,36 @@
+"""Tests for run_agent parser and fallback helpers."""
+
+from models import ActionType, ExperimentAction
+from run_agent import fallback_action, parse_action
+from server.hackathon_environment import BioExperimentEnvironment
+
+
+def test_parse_action_accepts_reasoning_variant():
+    action = parse_action(
+        '{"action_type":"run_qc","parameters":{},"Reasoning":"check quality","confidence":0.8}'
+    )
+    assert action is not None
+    assert action.action_type == ActionType.RUN_QC
+    assert action.justification == "check quality"
+
+
+def test_parse_action_accepts_justifyement_typo():
+    action = parse_action(
+        '{"action_type":"collect_sample","parameters":{},"justifyement":"typo key","confidence":0.7}'
+    )
+    assert action is not None
+    assert action.action_type == ActionType.COLLECT_SAMPLE
+    assert action.justification == "typo key"
+
+
+def test_fallback_uses_observation_progress_not_step_index():
+    env = BioExperimentEnvironment(scenario_name="cardiac_disease_de", domain_randomise=False)
+    obs = env.reset(seed=0)
+    for action_type in (
+        ActionType.COLLECT_SAMPLE,
+        ActionType.PREPARE_LIBRARY,
+        ActionType.SEQUENCE_CELLS,
+    ):
+        obs = env.step(ExperimentAction(action_type=action_type))
+    action = fallback_action(obs)
+    assert action.action_type == ActionType.RUN_QC
tests/test_training_script.py ADDED
@@ -0,0 +1,123 @@
+"""Tests for GRPO training helpers."""
+
+from pathlib import Path
+
+from models import ActionType
+from training_script import (
+    INVALID_ACTION_PENALTY,
+    OpenEnvReward,
+    available_numeric_log_keys,
+    build_prompt_examples,
+    completion_to_text,
+    parse_action_completion,
+    save_training_plots,
+    select_metric_key,
+    select_reward_key,
+)
+
+
+def test_completion_to_text_from_chat_messages():
+    completion = [
+        {"role": "assistant", "content": '{"action_type":"collect_sample"}'}
+    ]
+    assert completion_to_text(completion) == '{"action_type":"collect_sample"}'
+
+
+def test_parse_action_completion_roundtrip():
+    action = parse_action_completion(
+        '{"action_type":"run_qc","method":"scanpy.pp.calculate_qc_metrics",'
+        '"parameters":{"min_genes":200},"confidence":0.8}'
+    )
+    assert action is not None
+    assert action.action_type == ActionType.RUN_QC
+    assert action.method == "scanpy.pp.calculate_qc_metrics"
+    assert action.parameters["min_genes"] == 200
+    assert action.confidence == 0.8
+
+
+def test_parse_action_completion_accepts_reasoning_alias():
+    action = parse_action_completion(
+        '{"action_type":"run_qc","reasoning":"Measure quality before filtering."}'
+    )
+    assert action is not None
+    assert action.justification == "Measure quality before filtering."
+
+
+def test_build_prompt_examples_contains_reference_action():
+    examples = build_prompt_examples(
+        dataset_episodes=1,
+        rollout_steps=2,
+        collection_policy="heuristic",
+        scenario_names=["cardiac_disease_de"],
+        seed=0,
+        domain_randomise=False,
+    )
+    assert len(examples) == 2
+    assert examples[0]["scenario_name"] == "cardiac_disease_de"
+    assert '"action_type": "collect_sample"' in examples[0]["reference_action"]
+
+
+def test_openenv_reward_penalizes_invalid_completion():
+    reward_fn = OpenEnvReward(
+        reward_backend="local",
+        base_url="http://localhost:8000",
+    )
+    rewards = reward_fn(
+        completions=[[{"role": "assistant", "content": "not valid json"}]],
+        scenario_name=["cardiac_disease_de"],
+        history_actions=["[]"],
+    )
+    assert rewards == [INVALID_ACTION_PENALTY]
+
+
+def test_openenv_reward_scores_valid_completion_locally():
+    examples = build_prompt_examples(
+        dataset_episodes=1,
+        rollout_steps=1,
+        collection_policy="heuristic",
+        scenario_names=["cardiac_disease_de"],
+        seed=0,
+        domain_randomise=False,
+    )
+    reward_fn = OpenEnvReward(
+        reward_backend="local",
+        base_url="http://localhost:8000",
+    )
+    sample = examples[0]
+    rewards = reward_fn(
+        completions=[[{"role": "assistant", "content": sample["reference_action"]}]],
+        scenario_name=[sample["scenario_name"]],
+        history_actions=[sample["history_actions"]],
+    )
+    assert len(rewards) == 1
+    assert rewards[0] > 0.0
+
+
+def test_log_key_selection_prefers_reward_and_metric_keys():
+    log_history = [
+        {"step": 1, "loss": 1.2, "rewards/open_env_reward": 0.4, "objective/kl": 0.05},
+        {"step": 2, "loss": 1.0, "rewards/open_env_reward": 0.6, "objective/kl": 0.04},
+    ]
+    assert available_numeric_log_keys(log_history) == [
+        "loss",
+        "objective/kl",
+        "rewards/open_env_reward",
+    ]
+    reward_key = select_reward_key(log_history)
+    assert reward_key == "rewards/open_env_reward"
+    assert select_metric_key(log_history, reward_key=reward_key) == "objective/kl"
+
+
+def test_save_training_plots_writes_expected_files(tmp_path):
+    log_history = [
+        {"step": 1, "loss": 1.2, "reward": 0.4, "grad_norm": 0.8},
+        {"step": 2, "loss": 0.9, "reward": 0.7, "grad_norm": 0.5},
+    ]
+    plot_paths = save_training_plots(log_history, tmp_path, metric_key="grad_norm")
+
+    assert set(plot_paths) == {"loss", "reward", "metric", "dashboard"}
+    for plot_path in plot_paths.values():
+        assert Path(plot_path).exists()
+
+    manifest_path = tmp_path / "training_plot_manifest.json"
+    assert manifest_path.exists()
training/__init__.py CHANGED
@@ -1,9 +1,7 @@
 from .evaluation import EvaluationSuite
-from .gym_wrapper import BioExperimentGymEnv
 from .trajectory import Trajectory, TrajectoryDataset
 
 __all__ = [
-    "BioExperimentGymEnv",
     "EvaluationSuite",
     "PaperBenchmarkResult",
     "Trajectory",
training/evaluation.py CHANGED
@@ -118,7 +118,7 @@ class EvaluationSuite:
         for t in ds.trajectories:
             violations = sum(
                 1 for s in t.steps
-                if not s.observation.get("rule_violations") == []
+                if s.observation.get("rule_violations", []) != []
                 and s.observation.get("rule_violations") is not None
             )
             if violations == 0:
@@ -146,7 +146,8 @@ class EvaluationSuite:
             at = s.action.get("action_type")
             if at:
                 all_types.add(at)
-        return len(all_types)
+        from models import ActionType
+        return len(all_types) / max(len(ActionType), 1)
 
     @staticmethod
     def _mean_conclusion_confidence(ds: TrajectoryDataset) -> float:
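The second hunk normalises the distinct-action count by the size of the action vocabulary, so the diversity metric lands in [0, 1] instead of growing with the enum. A minimal illustration with a stand-in enum (the member names here are assumptions, not the full `models.ActionType`):

```python
from enum import Enum


class ActionType(Enum):
    # Stand-in for models.ActionType; the real enum has more members.
    COLLECT_SAMPLE = "collect_sample"
    RUN_QC = "run_qc"
    NORMALIZE_DATA = "normalize_data"
    SYNTHESIZE_CONCLUSION = "synthesize_conclusion"


# Distinct action types seen across an episode's steps.
seen = {"collect_sample", "run_qc"}

# Normalised diversity: 2 distinct types out of a 4-member vocabulary.
diversity = len(seen) / max(len(ActionType), 1)
# diversity == 0.5
```

The `max(..., 1)` guard only matters for an empty enum, but it keeps the division safe.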
training/literature_benchmark.py CHANGED
@@ -148,7 +148,6 @@ def run_paper_benchmark(
                 tool_call_spec=_tool_context(
                     obs.task,
                     libraries=["biopython"],
-                    include_expected_findings=True,
                 ),
             )
         )
@@ -353,19 +352,7 @@ def infer_conclusion_claims(obs: ExperimentObservation) -> List[ConclusionClaim]
             evidence_steps=_evidence_steps(obs, {OutputType.NETWORK_RESULT}),
         ))
 
-    if claims:
-        return claims
-
-    # Fallback: preserve the strongest expected findings verbatim if the
-    # heuristic extractors do not recover enough signal from the episode.
-    return [
-        ConclusionClaim(
-            claim=finding.finding,
-            confidence=0.65,
-            claim_type=finding.category,
-        )
-        for finding in obs.task.expected_findings[:3]
-    ]
+    return claims
 
 
 def compare_expected_findings(
@@ -444,11 +431,11 @@ def _default_comparison_name(task: TaskSpec) -> str:
 
 
 def _preferred_marker(task: TaskSpec) -> str:
-    for finding in task.expected_findings:
-        for keyword in finding.keywords:
-            if keyword.isupper():
-                return keyword
-    return "SPP1"
+    """Derive a candidate marker from the problem statement, not expected findings."""
+    tokens = [t for t in TOKEN_RE.findall(task.problem_statement) if t.isupper() and len(t) >= 3]
+    if tokens:
+        return tokens[0]
+    return "unknown"
 
 
 def _latest_output_data(
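The rewritten `_preferred_marker` relies on a `TOKEN_RE` defined elsewhere in the module. Assuming it is a plain alphanumeric word tokenizer, the gene-symbol heuristic can be reproduced standalone like this (the regex and function name here are assumptions for illustration):

```python
import re

# Assumed tokenizer: runs of letters and digits. The module's actual
# TOKEN_RE may differ.
TOKEN_RE = re.compile(r"[A-Za-z0-9]+")


def preferred_marker(problem_statement: str) -> str:
    """Pick the first all-caps token of length >= 3 as a candidate gene symbol."""
    tokens = [
        t for t in TOKEN_RE.findall(problem_statement)
        if t.isupper() and len(t) >= 3
    ]
    return tokens[0] if tokens else "unknown"


preferred_marker("Validate candidate biomarker SPP1 for IPF in lung tissue")
# -> "SPP1"
```

Note that `"SPP1".isupper()` is True because digits are ignored as long as every cased character is uppercase, which is why mixed letter-digit gene symbols pass the filter.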
training/rollout_collection.py ADDED
@@ -0,0 +1,219 @@
+"""Collect trajectories with direct OpenEnv environment access."""
+
+from __future__ import annotations
+
+import argparse
+import random
+import uuid
+from pathlib import Path
+from typing import Dict, List, Sequence
+
+from models import ActionType, ExperimentAction
+from server.hackathon_environment import BioExperimentEnvironment
+from training.evaluation import EvaluationSuite
+from training.trajectory import Trajectory, TrajectoryDataset
+
+
+HEURISTIC_SEQUENCE = [
+    ActionType.COLLECT_SAMPLE,
+    ActionType.PREPARE_LIBRARY,
+    ActionType.SEQUENCE_CELLS,
+    ActionType.RUN_QC,
+    ActionType.FILTER_DATA,
+    ActionType.NORMALIZE_DATA,
+    ActionType.CLUSTER_CELLS,
+    ActionType.TRAJECTORY_ANALYSIS,
+    ActionType.MARKER_SELECTION,
+    ActionType.SYNTHESIZE_CONCLUSION,
+]
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(
+        description="Run rollout episodes and persist trajectories."
+    )
+    parser.add_argument("--episodes", type=int, default=10, help="Number of episodes.")
+    parser.add_argument(
+        "--policy",
+        choices=["random", "heuristic"],
+        default="heuristic",
+        help="Policy to use for rollouts.",
+    )
+    parser.add_argument(
+        "--max-steps",
+        type=int,
+        default=None,
+        help="Optional hard cutoff per episode (defaults to env limit).",
+    )
+    parser.add_argument(
+        "--output-dir",
+        default="training/rollouts",
+        help="Directory for JSON trajectory outputs.",
+    )
+    parser.add_argument("--seed", type=int, default=None, help="RNG seed.")
+    return parser.parse_args()
+
+
+def heuristic_next_action(history: Sequence[ActionType], step_index: int) -> ActionType:
+    seen = set(history)
+    for action in HEURISTIC_SEQUENCE:
+        if action not in seen:
+            return action
+    if step_index >= 2 and ActionType.VALIDATE_MARKER not in seen:
+        return ActionType.VALIDATE_MARKER
+    if ActionType.SYNTHESIZE_CONCLUSION in seen:
+        return ActionType.SYNTHESIZE_CONCLUSION
+    return ActionType.SYNTHESIZE_CONCLUSION
+
+
+def pick_action(policy: str, step_index: int, history: Sequence[ActionType]) -> ActionType:
+    if policy == "random":
+        return random.choice(list(ActionType))
+    return heuristic_next_action(history, step_index)
+
+
+def default_comparison_name(conditions: Sequence[str]) -> str:
+    normalized = {condition.lower() for condition in conditions}
+    if {"healthy", "ipf"} <= normalized:
+        return "IPF_vs_healthy"
+    if any("treated" in condition for condition in normalized) and any(
+        "untreated" in condition for condition in normalized
+    ):
+        return "treated_vs_untreated"
+    if any("healthy" in condition for condition in normalized):
+        return "disease_vs_healthy"
+    return "disease_vs_healthy"
+
+
+def build_experiment_action(
+    action_type: ActionType,
+    discovered_markers: Sequence[str],
+    conditions: Sequence[str],
+) -> ExperimentAction:
+    method = None
+    parameters: Dict[str, object] = {}
+
+    if action_type == ActionType.COLLECT_SAMPLE:
+        parameters = {"n_samples": 6}
+    elif action_type == ActionType.PREPARE_LIBRARY:
+        method = "10x_chromium"
+    elif action_type == ActionType.RUN_QC:
+        method = "scanpy.pp.calculate_qc_metrics"
+    elif action_type == ActionType.FILTER_DATA:
+        method = "scanpy.pp.filter_cells"
+    elif action_type == ActionType.NORMALIZE_DATA:
+        method = "scanpy.pp.normalize_total"
+    elif action_type == ActionType.CLUSTER_CELLS:
+        method = "scanpy.tl.leiden"
+    elif action_type == ActionType.DIFFERENTIAL_EXPRESSION:
+        method = "scanpy.tl.rank_genes_groups"
+        parameters = {"comparison": default_comparison_name(conditions)}
+    elif action_type == ActionType.TRAJECTORY_ANALYSIS:
+        method = "scanpy.tl.dpt"
+    elif action_type == ActionType.MARKER_SELECTION:
+        method = "scanpy.tl.rank_genes_groups"
+    elif action_type == ActionType.VALIDATE_MARKER:
+        method = "qPCR"
+        parameters = {"marker": discovered_markers[0] if discovered_markers else "SPP1"}
+    elif action_type == ActionType.SYNTHESIZE_CONCLUSION:
+        parameters = {"claims": []}
+
+    return ExperimentAction(
+        action_type=action_type,
+        method=method,
+        parameters=parameters,
+        confidence=0.75,
+    )
+
+
+def run_episode(
+    env: BioExperimentEnvironment,
+    episode_id: str,
+    policy: str,
+    max_steps: int | None = None,
+) -> Trajectory:
+    structured_obs = env.reset()
+    traj = Trajectory(
+        episode_id=episode_id,
+        task=structured_obs.task.model_dump(),
+        metadata={
+            "task_problem": structured_obs.task.problem_statement,
+            "policy": policy,
+        },
+    )
+
+    done = structured_obs.done
+    step_num = 0
+    while not done:
+        if max_steps is not None and step_num >= max_steps:
+            break
+
+        history = [rec.action_type for rec in structured_obs.pipeline_history]
+        action_type = pick_action(policy, step_num, history)
+        experiment_action = build_experiment_action(
+            action_type=action_type,
+            discovered_markers=structured_obs.discovered_markers,
+            conditions=structured_obs.task.conditions,
+        )
+
+        structured_obs = env.step(experiment_action)
+        reward = structured_obs.reward
+        done = structured_obs.done
+        step_num += 1
+
+        traj.add_step(
+            action=experiment_action,
+            observation=structured_obs,
+            reward=reward,
+            done=done,
+            reward_breakdown=structured_obs.step_reward_breakdown,
+        )
+
+        print(
+            f"  step={structured_obs.step_index:02d} "
+            f"action={action_type.value:>28} "
+            f"reward={reward:+.3f}"
+        )
+
+    return traj
+
+
+def main() -> None:
+    args = parse_args()
+    if args.seed is not None:
+        random.seed(args.seed)
+
+    out_dir = Path(args.output_dir)
+    out_dir.mkdir(parents=True, exist_ok=True)
+
+    env = BioExperimentEnvironment()
+    trajectories: List[Trajectory] = []
+
+    print(
+        f"Starting rollout collection: episodes={args.episodes}, policy={args.policy}"
+    )
+    for ep in range(args.episodes):
+        print(f"Episode {ep + 1}/{args.episodes}")
+        traj = run_episode(
+            env=env,
+            episode_id=str(uuid.uuid4()),
+            policy=args.policy,
+            max_steps=args.max_steps,
+        )
+        traj.save(out_dir / f"{traj.episode_id}.json")
+        trajectories.append(traj)
+
+    dataset = TrajectoryDataset(trajectories)
+    stats = EvaluationSuite.online_metrics(trajectories)
+
+    print("\nRun complete.")
+    print(f"Saved trajectories to: {out_dir}")
+    print("Online metrics:")
+    for metric in stats:
+        print(f"  - {metric.name}: {metric.value:.4f}")
+
+    print(f"Summary: {dataset.summary()}")
+
+
+if __name__ == "__main__":
+    main()
training_script.py ADDED
@@ -0,0 +1,1250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Train a planner with TRL GRPO and OpenEnv rewards."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import json
7
+ import random
8
+ import re
9
+ from numbers import Real
10
+ from pathlib import Path
11
+ from typing import Any, Dict, List, Optional, Sequence, Tuple
12
+
13
+ from client import BioExperimentEnv
14
+ from models import (
15
+ ActionType,
16
+ ExperimentAction,
17
+ ExperimentObservation,
18
+ build_agent_observation_context,
19
+ build_agent_system_prompt,
20
+ )
21
+ from server.hackathon_environment import BioExperimentEnvironment
22
+ from server.tasks.scenarios import SCENARIO_LIBRARY
23
+
24
+ DEFAULT_MODEL_ID = "Qwen/Qwen3.5-0.8B"
25
+ DEFAULT_OUTPUT_DIR = "training/grpo-output"
26
+ DEFAULT_BASE_URL = "http://localhost:8000"
27
+ INVALID_ACTION_PENALTY = -2.0
28
+ ENVIRONMENT_ERROR_PENALTY = -4.0
29
+
30
+ SYSTEM_PROMPT = build_agent_system_prompt()
31
+
32
+ HEURISTIC_SEQUENCE = [
33
+ ActionType.COLLECT_SAMPLE,
34
+ ActionType.PREPARE_LIBRARY,
35
+ ActionType.SEQUENCE_CELLS,
36
+ ActionType.RUN_QC,
37
+ ActionType.FILTER_DATA,
38
+ ActionType.NORMALIZE_DATA,
39
+ ActionType.CLUSTER_CELLS,
40
+ ActionType.DIFFERENTIAL_EXPRESSION,
41
+ ActionType.PATHWAY_ENRICHMENT,
42
+ ActionType.MARKER_SELECTION,
43
+ ActionType.SYNTHESIZE_CONCLUSION,
44
+ ]
45
+
46
+ VALID_ACTION_TYPES = {action.value for action in ActionType}
47
+
48
+
49
+ def compact_preview(value: Any, max_chars: int = 160) -> str:
50
+ try:
51
+ text = json.dumps(value, ensure_ascii=True, sort_keys=True)
52
+ except TypeError:
53
+ text = str(value)
54
+ text = re.sub(r"\s+", " ", text).strip()
55
+ if len(text) <= max_chars:
56
+ return text
57
+ return text[: max_chars - 3] + "..."
58
+
59
+
60
+ def _edit_distance(a: str, b: str) -> int:
61
+ if len(a) < len(b):
62
+ return _edit_distance(b, a)
63
+ if not b:
64
+ return len(a)
65
+ prev = list(range(len(b) + 1))
66
+ for i, ca in enumerate(a):
67
+ curr = [i + 1]
68
+ for j, cb in enumerate(b):
69
+ curr.append(min(prev[j + 1] + 1, curr[j] + 1, prev[j] + (ca != cb)))
70
+ prev = curr
71
+ return prev[-1]
72
+
73
+
74
+ def get_payload_value(payload: Dict[str, Any], *names: str) -> Any:
75
+ for name in names:
76
+ if name in payload:
77
+ return payload[name]
78
+
79
+ lowered = {
80
+ str(key).lower(): value
81
+ for key, value in payload.items()
82
+ }
83
+ for name in names:
84
+ if name.lower() in lowered:
85
+ return lowered[name.lower()]
86
+
87
+ for key, value in lowered.items():
88
+ for name in names:
89
+ threshold = max(2, len(name) // 3)
90
+ if _edit_distance(key, name.lower()) <= threshold:
91
+ return value
92
+ return None
93
+
94
+
95
+ def build_argument_parser() -> argparse.ArgumentParser:
96
+ parser = argparse.ArgumentParser(
97
+ description="Train a GRPO policy against the OpenEnv bio experiment environment."
98
+ )
99
+ parser.add_argument("--model-id", default=DEFAULT_MODEL_ID)
100
+ parser.add_argument("--output-dir", default=DEFAULT_OUTPUT_DIR)
101
+ parser.add_argument("--dataset-episodes", type=int, default=8)
102
+ parser.add_argument("--rollout-steps", type=int, default=6)
103
+ parser.add_argument(
104
+ "--collection-policy",
105
+ choices=["random", "heuristic"],
106
+ default="heuristic",
107
+ help="Policy used to build prompt states for GRPO training.",
108
+ )
109
+ parser.add_argument(
110
+ "--reward-backend",
111
+ choices=["local", "remote"],
112
+ default="local",
113
+ help="Use local in-process scoring or a live OpenEnv server.",
114
+ )
115
+ parser.add_argument(
116
+ "--base-url",
117
+ default=DEFAULT_BASE_URL,
118
+ help="Base URL for the OpenEnv server when reward-backend=remote.",
119
+ )
120
+ parser.add_argument(
121
+ "--scenario-name",
122
+ action="append",
123
+ default=None,
124
+ help="Repeatable scenario selector. Defaults to all curated scenarios.",
125
+ )
126
+ parser.add_argument(
127
+ "--domain-randomise",
128
+ action="store_true",
129
+ help="Enable domain randomisation while building prompts and local rewards.",
130
+ )
131
+ parser.add_argument("--num-generations", type=int, default=2)
132
+ parser.add_argument("--max-completion-length", type=int, default=220)
133
+ parser.add_argument("--max-prompt-length", type=int, default=768)
134
+ parser.add_argument("--per-device-train-batch-size", type=int, default=2)
135
+ parser.add_argument("--gradient-accumulation-steps", type=int, default=1)
136
+ parser.add_argument("--learning-rate", type=float, default=5e-6)
137
+ parser.add_argument("--num-train-epochs", type=float, default=1.0)
138
+ parser.add_argument("--logging-steps", type=int, default=1)
139
+ parser.add_argument("--save-steps", type=int, default=50)
140
+ parser.add_argument(
141
+ "--plot-metric-key",
142
+ default=None,
143
+ help="Optional extra metric key from trainer log history to plot.",
144
+ )
145
+ parser.add_argument("--seed", type=int, default=0)
146
+ parser.add_argument(
147
+ "--load-model-only",
148
+ action="store_true",
149
+ help="Download and load the selected model and tokenizer, then exit.",
150
+ )
151
+ parser.add_argument(
152
+ "--trust-remote-code",
153
+ action="store_true",
154
+ help="Pass trust_remote_code=True to model/tokenizer loading.",
155
+ )
156
+ parser.add_argument(
157
+ "--dry-run",
158
+ action="store_true",
159
+ help="Build the prompt dataset and smoke-test the reward function without training.",
160
+ )
161
+ parser.add_argument(
162
+ "--push-to-hub",
163
+ type=str,
164
+ default=None,
165
+ help="HuggingFace Hub repo id to push the trained model to (e.g. 'myuser/my-model').",
166
+ )
167
+ return parser
168
+
169
+
170
+ def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
171
+ return build_argument_parser().parse_args(argv)
172
+
173
+
174
+ def make_training_args(**overrides: Any) -> argparse.Namespace:
175
+ """Build an argparse-style namespace for notebooks and scripts."""
176
+ parser = build_argument_parser()
177
+ defaults = vars(parser.parse_args([]))
178
+ unknown = sorted(set(overrides) - set(defaults))
179
+ if unknown:
180
+ raise ValueError(f"Unknown training args: {', '.join(unknown)}")
181
+ defaults.update(overrides)
182
+ return argparse.Namespace(**defaults)
183
+
184
+
185
def format_observation(obs: ExperimentObservation) -> str:
    parts = [
        f"TASK: {obs.task.problem_statement}",
        f"Organism: {obs.task.organism} | Tissue: {obs.task.tissue}",
        f"Conditions: {', '.join(obs.task.conditions) or 'N/A'}",
        (
            "Step: "
            f"{obs.step_index} | Budget: ${obs.resource_usage.budget_remaining:,.0f} "
            f"| Time: {obs.resource_usage.time_remaining_days:.0f}d"
        ),
    ]
    context = build_agent_observation_context(obs, max_tools=5, max_assays=2)
    if context:
        parts.append(context)
    if obs.pipeline_history:
        parts.append("History:")
        for step in obs.pipeline_history[-5:]:
            tag = "OK" if step.success else "FAIL"
            line = f" [{tag}] {step.action_type.value}: {step.output_summary[:100]}"
            if step.parameters:
                line += f" | params={compact_preview(step.parameters, 120)}"
            parts.append(line)
    if obs.latest_output and obs.latest_output.data:
        parts.append(
            f"Latest output data: {compact_preview(obs.latest_output.data, 200)}"
        )
    if obs.rule_violations:
        parts.append(f"Violations: {obs.rule_violations}")
    if obs.discovered_markers:
        parts.append(f"Markers: {obs.discovered_markers[:5]}")
    if obs.candidate_mechanisms:
        parts.append(f"Mechanisms: {obs.candidate_mechanisms[:5]}")
    return "\n".join(parts)


def build_training_prompt(obs: ExperimentObservation) -> str:
    return f"{SYSTEM_PROMPT}\n\n{format_observation(obs)}"


def heuristic_next_action(history: Sequence[ActionType], step_index: int) -> ActionType:
    seen = set(history)
    for action in HEURISTIC_SEQUENCE:
        if action not in seen:
            return action
    if step_index >= 2 and ActionType.VALIDATE_MARKER not in seen:
        return ActionType.VALIDATE_MARKER
    return ActionType.SYNTHESIZE_CONCLUSION


def pick_action(policy: str, step_index: int, history: Sequence[ActionType]) -> ActionType:
    if policy == "random":
        return random.choice(list(ActionType))
    return heuristic_next_action(history, step_index)


def default_comparison_name(conditions: Sequence[str]) -> str:
    normalized = {condition.lower() for condition in conditions}
    if {"healthy", "ipf"} <= normalized:
        return "IPF_vs_healthy"
    if any("treated" in condition for condition in normalized) and any(
        "untreated" in condition for condition in normalized
    ):
        return "treated_vs_untreated"
    return "disease_vs_healthy"


def build_experiment_action(
    action_type: ActionType,
    discovered_markers: Sequence[str],
    conditions: Sequence[str],
) -> ExperimentAction:
    method = None
    parameters: Dict[str, object] = {}
    justification = f"Advance the experiment with {action_type.value}."

    if action_type == ActionType.COLLECT_SAMPLE:
        parameters = {"n_samples": 6}
        justification = "Collect enough samples to start the experiment."
    elif action_type == ActionType.PREPARE_LIBRARY:
        method = "10x_chromium"
        justification = "Prepare a single-cell library for sequencing."
    elif action_type == ActionType.SEQUENCE_CELLS:
        method = "NovaSeq"
        justification = "Generate reads for downstream single-cell analysis."
    elif action_type == ActionType.RUN_QC:
        method = "scanpy.pp.calculate_qc_metrics"
        justification = "Measure technical quality before filtering."
    elif action_type == ActionType.FILTER_DATA:
        method = "scanpy.pp.filter_cells"
        justification = "Remove low-quality cells and technical artifacts."
    elif action_type == ActionType.NORMALIZE_DATA:
        method = "scanpy.pp.normalize_total"
        justification = "Normalize counts for comparable expression profiles."
    elif action_type == ActionType.CLUSTER_CELLS:
        method = "scanpy.tl.leiden"
        justification = "Resolve cell states before interpretation."
    elif action_type == ActionType.DIFFERENTIAL_EXPRESSION:
        method = "scanpy.tl.rank_genes_groups"
        parameters = {"comparison": default_comparison_name(conditions)}
        justification = "Identify genes associated with the phenotype of interest."
    elif action_type == ActionType.TRAJECTORY_ANALYSIS:
        method = "scanpy.tl.dpt"
        justification = "Recover pseudotime and lineage structure."
    elif action_type == ActionType.PATHWAY_ENRICHMENT:
        method = "gseapy.prerank"
        justification = "Translate gene-level changes into pathway programs."
    elif action_type == ActionType.MARKER_SELECTION:
        method = "scanpy.tl.rank_genes_groups"
        justification = "Nominate marker genes for validation."
    elif action_type == ActionType.VALIDATE_MARKER:
        method = "qPCR"
        parameters = {"marker": discovered_markers[0] if discovered_markers else "SPP1"}
        justification = "Validate the strongest discovered marker."
    elif action_type == ActionType.SYNTHESIZE_CONCLUSION:
        top = list(discovered_markers[:5]) if discovered_markers else []
        parameters = {
            "claims": [{
                "top_markers": top,
                "causal_mechanisms": [],
                "predicted_pathways": {},
                "confidence": 0.6,
                "claim_type": "correlational",
                "claim": "",
            }],
        }
        justification = "Summarize the current evidence into a conclusion."

    return ExperimentAction(
        action_type=action_type,
        method=method,
        parameters=parameters,
        justification=justification,
        confidence=0.75,
    )


def selected_scenarios(requested: Optional[Sequence[str]]) -> List[str]:
    from server.tasks.procedural_generator import generate_procedural_scenarios

    all_scenarios = list(SCENARIO_LIBRARY) + generate_procedural_scenarios(n=20, seed=42)
    available = [scenario.name for scenario in all_scenarios]
    if not requested:
        return available
    unknown = sorted(set(requested) - set(available))
    if unknown:
        raise ValueError(f"Unknown scenarios requested: {', '.join(unknown)}")
    return list(requested)


def action_completion_json(action: ExperimentAction) -> str:
    payload = {
        "action_type": action.action_type.value,
        "method": action.method,
        "parameters": action.parameters,
        "justification": action.justification,
        "confidence": action.confidence,
    }
    return json.dumps(payload)


def build_prompt_examples(
    *,
    dataset_episodes: int,
    rollout_steps: int,
    collection_policy: str,
    scenario_names: Sequence[str],
    seed: int,
    domain_randomise: bool,
) -> List[Dict[str, str]]:
    rng = random.Random(seed)
    examples: List[Dict[str, str]] = []
    scenario_cycle = list(scenario_names)
    rng.shuffle(scenario_cycle)

    for episode_idx in range(dataset_episodes):
        scenario_name = scenario_cycle[episode_idx % len(scenario_cycle)]
        env = BioExperimentEnvironment(
            scenario_name=scenario_name,
            domain_randomise=domain_randomise,
        )
        obs = env.reset()
        history_actions: List[ExperimentAction] = []

        for step_idx in range(rollout_steps):
            if obs.done:
                break

            next_action = build_experiment_action(
                action_type=pick_action(
                    collection_policy,
                    step_idx,
                    [action.action_type for action in history_actions],
                ),
                discovered_markers=obs.discovered_markers,
                conditions=obs.task.conditions,
            )
            examples.append({
                "prompt": build_training_prompt(obs),
                "scenario_name": scenario_name,
                "history_actions": json.dumps(
                    [action.model_dump() for action in history_actions]
                ),
                "rng_seed": str(env._latent.rng_seed),
                "reference_action": action_completion_json(next_action),
                "problem_statement": obs.task.problem_statement,
            })

            history_actions.append(next_action)
            obs = env.step(next_action)

    return examples


def completion_to_text(completion: Any) -> str:
    if isinstance(completion, str):
        return completion.strip()
    if isinstance(completion, dict):
        return content_to_text(completion.get("content", ""))
    if isinstance(completion, list):
        for item in reversed(completion):
            if isinstance(item, dict) and "content" in item:
                text = content_to_text(item["content"])
                if text:
                    return text
            if isinstance(item, str) and item.strip():
                return item.strip()
    return str(completion).strip()


def content_to_text(content: Any) -> str:
    if isinstance(content, str):
        return content.strip()
    if isinstance(content, list):
        parts: List[str] = []
        for part in content:
            if isinstance(part, str):
                parts.append(part)
            elif isinstance(part, dict):
                if isinstance(part.get("text"), str):
                    parts.append(part["text"])
                elif isinstance(part.get("content"), str):
                    parts.append(part["content"])
        return "".join(parts).strip()
    return str(content).strip()


def _repair_truncated_json(text: str) -> Optional[str]:
    """Try to repair JSON truncated mid-value (common with small LLMs)."""
    s = text.strip()
    if not s.startswith("{"):
        return None

    s = re.sub(r',\s*"[^"\n]*$', '', s)
    s = re.sub(r',\s*"[^"\n]*"\s*:\s*$', '', s)

    in_string = False
    escape = False
    for ch in s:
        if escape:
            escape = False
            continue
        if ch == "\\":
            escape = True
            continue
        if ch == '"':
            in_string = not in_string

    if in_string:
        s += '"'

    open_braces = s.count("{") - s.count("}")
    open_brackets = s.count("[") - s.count("]")
    s += "]" * max(0, open_brackets)
    s += "}" * max(0, open_braces)

    try:
        obj = json.loads(s)
        if isinstance(obj, dict):
            return s
    except json.JSONDecodeError:
        pass

    s = re.sub(r',\s*([}\]])', r'\1', s)
    try:
        obj = json.loads(s)
        if isinstance(obj, dict):
            return s
    except json.JSONDecodeError:
        pass
    return None


def _normalize_jsonish_text(text: str) -> str:
    """Normalize common near-JSON artifacts emitted by small local models."""
    text = _strip_js_comments(text)
    text = re.sub(r'(?<=:\s)\bNone\b', 'null', text)
    text = re.sub(r'(?<=:\s)\bTrue\b', 'true', text)
    text = re.sub(r'(?<=:\s)\bFalse\b', 'false', text)
    text = re.sub(r'"([^"\n]+?):"\s*,', r'"\1": "",', text)
    return text


def _strip_js_comments(text: str) -> str:
    """Remove // and /* */ comments that small LLMs inject into JSON."""
    text = re.sub(r'//[^\n]*', '', text)
    text = re.sub(r'/\*.*?\*/', '', text, flags=re.DOTALL)
    return text


def extract_json_object(text: str) -> Optional[Dict[str, Any]]:
    stripped = _normalize_jsonish_text(text).strip()
    fence_prefix = "```"
    if stripped.startswith(fence_prefix) and stripped.endswith(fence_prefix):
        lines = stripped.splitlines()
        if len(lines) >= 3:
            stripped = "\n".join(lines[1:-1]).strip()

    candidates: List[str] = [stripped]
    start = stripped.find("{")
    while start != -1:
        depth = 0
        for idx in range(start, len(stripped)):
            char = stripped[idx]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    candidates.append(stripped[start:idx + 1])
                    break
        start = stripped.find("{", start + 1)

    first_brace = stripped.find("{")
    if first_brace != -1:
        repaired = _repair_truncated_json(stripped[first_brace:])
        if repaired is not None:
            candidates.append(repaired)

    candidates.sort(key=len, reverse=True)

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed

    return None


def normalize_optional_string(value: Any) -> Optional[str]:
    if value is None or isinstance(value, bool):
        return None
    if isinstance(value, str):
        value = value.strip()
        return value or None
    if isinstance(value, (int, float)):
        return str(value)
    return compact_preview(value, 80)


def parse_action_completion(text: str) -> Optional[ExperimentAction]:
    payload = extract_json_object(text)
    if payload is not None:
        action_type = get_payload_value(payload, "action_type")
        if action_type not in VALID_ACTION_TYPES:
            return None

        parameters = get_payload_value(payload, "parameters", "params") or {}
        if not isinstance(parameters, dict):
            parameters = {}

        confidence = get_payload_value(payload, "confidence")
        if confidence is None:
            confidence = 0.5
        try:
            confidence = float(confidence)
        except (TypeError, ValueError):
            confidence = 0.5

        justification = get_payload_value(
            payload, "justification", "reasoning", "rationale", "reason"
        )
        if justification is not None and not isinstance(justification, str):
            justification = compact_preview(justification, 200)

        return ExperimentAction(
            action_type=ActionType(action_type),
            method=normalize_optional_string(get_payload_value(payload, "method")),
            parameters=parameters,
            justification=justification,
            confidence=min(1.0, max(0.0, confidence)),
        )

    action_match = re.search(
        r'["\']action_type["\']\s*:\s*["\']([^"\']+)',
        text,
        re.IGNORECASE,
    )
    if not action_match:
        return None

    action_type = action_match.group(1).strip()
    if action_type not in VALID_ACTION_TYPES:
        return None

    method_match = re.search(
        r'["\']method["\']\s*:\s*("((?:[^"\\]|\\.)*)"|null|none|true|false|-?\d+(?:\.\d+)?)',
        text,
        re.IGNORECASE,
    )
    confidence_match = re.search(
        r'["\']confidence["\']\s*:\s*([0-9]*\.?[0-9]+)',
        text,
        re.IGNORECASE,
    )
    justification_match = re.search(
        r'["\'](?:justif\w*|reasoning|rationale|reason)["\']\s*:\s*"((?:[^"\\]|\\.)*)',
        text,
        re.DOTALL | re.IGNORECASE,
    )

    confidence = 0.5
    if confidence_match:
        try:
            confidence = float(confidence_match.group(1))
        except ValueError:
            confidence = 0.5

    justification = None
    if justification_match:
        try:
            justification = json.loads(f'"{justification_match.group(1)}"')
        except json.JSONDecodeError:
            justification = justification_match.group(1)

    method = None
    if method_match:
        raw_method = method_match.group(1)
        if raw_method.startswith('"') and raw_method.endswith('"'):
            try:
                method = json.loads(raw_method)
            except json.JSONDecodeError:
                method = raw_method.strip('"')
        elif raw_method.lower() not in {"null", "none", "true", "false"}:
            method = raw_method

    return ExperimentAction(
        action_type=ActionType(action_type),
        method=normalize_optional_string(method),
        parameters={},
        justification=justification,
        confidence=min(1.0, max(0.0, confidence)),
    )


def decode_history_actions(history_actions: Optional[str]) -> List[ExperimentAction]:
    if not history_actions:
        return []
    raw_actions = json.loads(history_actions)
    return [
        ExperimentAction(**action_payload)
        for action_payload in raw_actions
        if isinstance(action_payload, dict)
    ]


def normalise_column(values: Any, length: int) -> List[Any]:
    if values is None:
        return [None] * length
    if isinstance(values, list):
        if len(values) == length:
            return values
        if len(values) == 1:
            return values * length
        return values[:length] + [None] * max(0, length - len(values))
    return [values] * length


class OpenEnvReward:
    """Reward function compatible with TRL GRPOTrainer."""

    def __init__(
        self,
        *,
        reward_backend: str,
        base_url: str,
        invalid_action_penalty: float = INVALID_ACTION_PENALTY,
        environment_error_penalty: float = ENVIRONMENT_ERROR_PENALTY,
        domain_randomise: bool = False,
    ) -> None:
        self.__name__ = "openenv_reward"
        self.reward_backend = reward_backend
        self.base_url = base_url
        self.invalid_action_penalty = invalid_action_penalty
        self.environment_error_penalty = environment_error_penalty
        self.domain_randomise = domain_randomise

    def __call__(
        self,
        completions: List[Any],
        scenario_name: Optional[List[str]] = None,
        history_actions: Optional[List[str]] = None,
        rng_seed: Optional[List[str]] = None,
        **_: Any,
    ) -> List[float]:
        scenario_names = normalise_column(scenario_name, len(completions))
        history_columns = normalise_column(history_actions, len(completions))
        seed_columns = normalise_column(rng_seed, len(completions))
        rewards: List[float] = []

        for completion, current_scenario, current_history, current_seed in zip(
            completions,
            scenario_names,
            history_columns,
            seed_columns,
        ):
            action = parse_action_completion(completion_to_text(completion))
            if action is None:
                rewards.append(self.invalid_action_penalty)
                continue

            try:
                if self.reward_backend == "remote":
                    reward = self._score_remote(action, current_scenario, current_history)
                else:
                    reward = self._score_local(action, current_scenario, current_history, current_seed)
            except Exception:
                reward = self.environment_error_penalty

            rewards.append(float(reward))

        return rewards

    def _score_local(
        self,
        action: ExperimentAction,
        scenario_name: Optional[str],
        history_actions: Optional[str],
        rng_seed: Optional[str] = None,
    ) -> float:
        env = BioExperimentEnvironment(
            scenario_name=scenario_name,
            domain_randomise=self.domain_randomise,
        )
        seed = int(rng_seed) if rng_seed else None
        obs = env.reset(seed=seed)
        for previous_action in decode_history_actions(history_actions):
            obs = env.step(previous_action)
            if obs.done:
                return float(obs.reward)
        obs = env.step(action)
        return float(obs.reward)

    def _score_remote(
        self,
        action: ExperimentAction,
        scenario_name: Optional[str],
        history_actions: Optional[str],
    ) -> float:
        with BioExperimentEnv(base_url=self.base_url) as env:
            # NOTE: scenario_name is accepted for API parity with _score_local
            # but the OpenEnv HTTP protocol does not yet support passing it
            # through reset(). The server will use its configured default.
            result = env.reset()
            for previous_action in decode_history_actions(history_actions):
                result = env.step(previous_action)
                if result.done:
                    return float(result.reward or 0.0)
            result = env.step(action)
            if result.reward is not None:
                return float(result.reward)
            return float(result.observation.reward)


def is_numeric_log_value(value: Any) -> bool:
    return isinstance(value, Real) and not isinstance(value, bool)


def available_numeric_log_keys(log_history: Sequence[Dict[str, Any]]) -> List[str]:
    keys = {
        key
        for entry in log_history
        if isinstance(entry, dict)
        for key, value in entry.items()
        if key != "step" and is_numeric_log_value(value)
    }
    return sorted(keys)


def extract_log_series(
    log_history: Sequence[Dict[str, Any]],
    key: Optional[str],
) -> List[Tuple[float, float]]:
    if not key:
        return []

    series: List[Tuple[float, float]] = []
    synthetic_step = 0
    for entry in log_history:
        if not isinstance(entry, dict) or key not in entry:
            continue
        value = entry.get(key)
        if not is_numeric_log_value(value):
            continue

        raw_step = entry.get("step")
        if is_numeric_log_value(raw_step):
            step = float(raw_step)
        else:
            synthetic_step += 1
            step = float(synthetic_step)

        series.append((step, float(value)))

    return series


def select_reward_key(log_history: Sequence[Dict[str, Any]]) -> Optional[str]:
    numeric_keys = available_numeric_log_keys(log_history)
    reward_keys = [key for key in numeric_keys if "reward" in key.lower()]
    if not reward_keys:
        return None

    preferred = [
        "reward",
        "mean_reward",
        "reward_mean",
        "rewards/openenv_reward",  # matches OpenEnvReward.__name__
    ]
    lowered = {key.lower(): key for key in reward_keys}
    for key in preferred:
        if key in lowered:
            return lowered[key]

    reward_keys.sort(key=lambda key: ("/" in key, len(key), key))
    return reward_keys[0]


def select_metric_key(
    log_history: Sequence[Dict[str, Any]],
    *,
    reward_key: Optional[str],
    requested_key: Optional[str] = None,
) -> Optional[str]:
    numeric_keys = available_numeric_log_keys(log_history)
    if requested_key:
        if requested_key not in numeric_keys:
            available = ", ".join(numeric_keys) or "none"
            raise ValueError(
                f"Requested plot metric '{requested_key}' was not logged. "
                f"Available numeric keys: {available}"
            )
        return requested_key

    excluded = {
        "epoch",
        "loss",
        "learning_rate",
        "step",
        "total_flos",
        "train_loss",
        "train_runtime",
        "train_samples_per_second",
        "train_steps_per_second",
    }
    if reward_key:
        excluded.add(reward_key)

    preferred = [
        "kl",
        "objective/kl",
        "completion_length",
        "mean_completion_length",
        "grad_norm",
        "entropy",
        "accuracy",
        "learning_rate",
        "epoch",
    ]
    numeric_set = set(numeric_keys)
    for key in preferred:
        if key in numeric_set and key not in excluded:
            return key

    candidates = [
        key for key in numeric_keys
        if key not in excluded and "reward" not in key.lower()
    ]
    if candidates:
        return candidates[0]

    for fallback in ("learning_rate", "epoch"):
        if fallback in numeric_set:
            return fallback

    return None


def save_plot(
    path: Path,
    *,
    series: Sequence[Tuple[float, float]],
    title: str,
    ylabel: str,
) -> None:
    import matplotlib

    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(8, 4.5))
    if series:
        x_values, y_values = zip(*series)
        ax.plot(x_values, y_values, marker="o", linewidth=1.8)
    else:
        ax.text(
            0.5,
            0.5,
            "No logged data available",
            ha="center",
            va="center",
            transform=ax.transAxes,
        )
    ax.set_title(title)
    ax.set_xlabel("Step")
    ax.set_ylabel(ylabel)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(path, dpi=150)
    plt.close(fig)


def save_training_plots(
    log_history: Sequence[Dict[str, Any]],
    output_dir: str | Path,
    metric_key: Optional[str] = None,
) -> Dict[str, str]:
    import matplotlib

    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    reward_key = select_reward_key(log_history)
    selected_metric_key = select_metric_key(
        log_history,
        reward_key=reward_key,
        requested_key=metric_key,
    )

    loss_series = extract_log_series(log_history, "loss")
    reward_series = extract_log_series(log_history, reward_key)
    metric_series = extract_log_series(log_history, selected_metric_key)

    loss_path = output_path / "training_loss.png"
    reward_path = output_path / "training_reward.png"
    metric_path = output_path / "training_metric.png"
    dashboard_path = output_path / "training_dashboard.png"
    manifest_path = output_path / "training_plot_manifest.json"

    save_plot(loss_path, series=loss_series, title="Training Loss", ylabel="Loss")
    save_plot(
        reward_path,
        series=reward_series,
        title=f"Training Reward ({reward_key or 'not logged'})",
        ylabel="Reward",
    )
    save_plot(
        metric_path,
        series=metric_series,
        title=f"Training Metric ({selected_metric_key or 'not logged'})",
        ylabel=selected_metric_key or "Metric",
    )

    fig, axes = plt.subplots(3, 1, figsize=(10, 12))
    plot_specs = [
        (axes[0], loss_series, "Training Loss", "Loss"),
        (axes[1], reward_series, f"Training Reward ({reward_key or 'not logged'})", "Reward"),
        (
            axes[2],
            metric_series,
            f"Training Metric ({selected_metric_key or 'not logged'})",
            selected_metric_key or "Metric",
        ),
    ]
    for axis, series, title, ylabel in plot_specs:
        if series:
            x_values, y_values = zip(*series)
            axis.plot(x_values, y_values, marker="o", linewidth=1.8)
        else:
            axis.text(
                0.5,
                0.5,
                "No logged data available",
                ha="center",
                va="center",
                transform=axis.transAxes,
            )
        axis.set_title(title)
        axis.set_xlabel("Step")
        axis.set_ylabel(ylabel)
        axis.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(dashboard_path, dpi=150)
    plt.close(fig)

    manifest = {
        "available_numeric_keys": available_numeric_log_keys(log_history),
        "reward_key": reward_key,
        "metric_key": selected_metric_key,
        "plots": {
            "loss": str(loss_path),
            "reward": str(reward_path),
            "metric": str(metric_path),
            "dashboard": str(dashboard_path),
        },
    }
    manifest_path.write_text(json.dumps(manifest, indent=2), encoding="utf-8")
    return manifest["plots"]


def run_dry_run_preview(
    examples: Sequence[Dict[str, str]],
    reward_fn: OpenEnvReward,
    output_dir: str,
) -> None:
    if not examples:
        raise ValueError("No training prompts were generated for the dry run.")

    sample = examples[0]
    sample_reward = reward_fn(
        completions=[[{"role": "assistant", "content": sample["reference_action"]}]],
        scenario_name=[sample["scenario_name"]],
        history_actions=[sample["history_actions"]],
    )[0]

    print(f"Built {len(examples)} prompt states.")
    print(f"Output directory: {Path(output_dir)}")
    print(f"Sample scenario: {sample['scenario_name']}")
    print(f"Sample reward for reference action: {sample_reward:+.3f}")
    print("\nSample prompt:\n")
    print(sample["prompt"])


def resolve_torch_runtime() -> Dict[str, Any]:
    import torch

    use_cuda = torch.cuda.is_available()
    bf16 = bool(getattr(torch.cuda, "is_bf16_supported", lambda: False)()) if use_cuda else False
    dtype = torch.bfloat16 if bf16 else (
        torch.float16 if use_cuda else torch.float32
    )
    return {
        "use_cuda": use_cuda,
        "device": "cuda:0" if use_cuda else "cpu",
        "dtype": dtype,
        "bf16": bf16,
        "fp16": use_cuda and not bf16,
        "device_name": torch.cuda.get_device_name(0) if use_cuda else "cpu",
    }


def load_model_artifacts(
    model_id: str,
    *,
    trust_remote_code: bool,
):
    from transformers import AutoModelForCausalLM, AutoTokenizer

    runtime = resolve_torch_runtime()
    print(f"Loading tokenizer for {model_id} ...")
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=trust_remote_code,
    )
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading model for {model_id} ...")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=trust_remote_code,
        torch_dtype=runtime["dtype"],
    )
    if runtime["use_cuda"]:
        model = model.to(runtime["device"])
    else:
        model = model.to("cpu")
    return tokenizer, model


def generate_action_with_model(
    model: Any,
    tokenizer: Any,
    prompt_or_observation: str | ExperimentObservation,
    *,
    max_new_tokens: int = 220,
    temperature: float = 0.2,
    top_p: float = 0.9,
    do_sample: bool = True,
) -> Dict[str, Any]:
    import torch

    if isinstance(prompt_or_observation, ExperimentObservation):
        prompt = build_training_prompt(prompt_or_observation)
    else:
        prompt = str(prompt_or_observation)

    model_device = getattr(model, "device", None)
    if model_device is None:
        model_device = resolve_torch_runtime()["device"]

    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {key: value.to(model_device) for key, value in inputs.items()}
    prompt_tokens = inputs["input_ids"].shape[1]

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "temperature": temperature,
        "top_p": top_p,
        "pad_token_id": tokenizer.pad_token_id,
    }
    with torch.no_grad():
        output_ids = model.generate(**inputs, **generation_kwargs)

    new_tokens = output_ids[0][prompt_tokens:]
    response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    action = parse_action_completion(response_text)
    return {
        "prompt": prompt,
        "response_text": response_text,
        "action": action,
    }


def run_training(args: argparse.Namespace) -> Dict[str, Any]:
    random.seed(args.seed)
    runtime = resolve_torch_runtime()

    if args.load_model_only:
        tokenizer, model = load_model_artifacts(
            args.model_id,
            trust_remote_code=args.trust_remote_code,
        )
        device = getattr(model, "device", "unknown")
        print(f"Model ready: {args.model_id}")
        print(f"Tokenizer vocab size: {len(tokenizer)}")
        print(f"Model device: {device}")
        print(f"Runtime device name: {runtime['device_name']}")
        return {
            "args": args,
            "runtime": runtime,
            "tokenizer": tokenizer,
            "model": model,
        }

    scenario_names = selected_scenarios(args.scenario_name)
    examples = build_prompt_examples(
        dataset_episodes=args.dataset_episodes,
        rollout_steps=args.rollout_steps,
        collection_policy=args.collection_policy,
        scenario_names=scenario_names,
        seed=args.seed,
        domain_randomise=args.domain_randomise,
    )
    reward_fn = OpenEnvReward(
        reward_backend=args.reward_backend,
        base_url=args.base_url,
        domain_randomise=args.domain_randomise,
    )

    if args.dry_run:
        run_dry_run_preview(examples, reward_fn, args.output_dir)
        return {
            "args": args,
            "runtime": runtime,
            "scenario_names": scenario_names,
            "examples": examples,
            "reward_fn": reward_fn,
        }

    from datasets import Dataset
    from trl import GRPOConfig, GRPOTrainer

    train_dataset = Dataset.from_list(examples)
    tokenizer, model = load_model_artifacts(
        args.model_id,
        trust_remote_code=args.trust_remote_code,
    )
    config = GRPOConfig(
        output_dir=args.output_dir,
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        num_generations=args.num_generations,
        max_completion_length=args.max_completion_length,
        num_train_epochs=args.num_train_epochs,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        bf16=runtime["bf16"],
        fp16=runtime["fp16"],
        report_to="none",
        remove_unused_columns=False,
    )

    print(
        f"Training runtime: device={runtime['device']} "
        f"name={runtime['device_name']} "
        f"dtype={runtime['dtype']}"
    )

    trainer = GRPOTrainer(
        model=model,
        reward_funcs=reward_fn,
        args=config,
        train_dataset=train_dataset,
        processing_class=tokenizer,
    )
    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    if args.push_to_hub:
        from huggingface_hub import HfApi
        api = HfApi()
        api.create_repo(repo_id=args.push_to_hub, repo_type="model", exist_ok=True)
        print(f"Pushing model to HuggingFace Hub: {args.push_to_hub}")
        api.upload_folder(
            folder_path=args.output_dir,
            repo_id=args.push_to_hub,
            repo_type="model",
            create_pr=False,
        )
        print(f"Model pushed to https://huggingface.co/{args.push_to_hub}")
    plot_paths = save_training_plots(
        trainer.state.log_history,
        args.output_dir,
        metric_key=args.plot_metric_key,
    )
    print("Saved training plots:")
    for plot_name, plot_path in plot_paths.items():
        print(f" - {plot_name}: {plot_path}")

    return {
        "args": args,
        "runtime": runtime,
        "scenario_names": scenario_names,
        "examples": examples,
        "reward_fn": reward_fn,
        "train_dataset": train_dataset,
        "tokenizer": tokenizer,
        "model": model,
        "trainer": trainer,
        "plot_paths": plot_paths,
    }


def main() -> None:
    run_training(parse_args())


if __name__ == "__main__":
    main()
uv.lock CHANGED
The diff for this file is too large to render.