Upload folder using huggingface_hub
Browse files- Dockerfile +81 -0
- README.md +337 -10
- __init__.py +48 -0
- client.py +53 -0
- models.py +268 -0
- openenv.yaml +7 -0
- outputs/.gitkeep +1 -0
- pyproject.toml +66 -0
- run_agent.py +292 -0
- server/__init__.py +3 -0
- server/app.py +45 -0
- server/hackathon_environment.py +239 -0
- server/requirements.txt +6 -0
- server/rewards/__init__.py +3 -0
- server/rewards/reward.py +285 -0
- server/rules/__init__.py +3 -0
- server/rules/engine.py +208 -0
- server/simulator/__init__.py +25 -0
- server/simulator/latent_state.py +143 -0
- server/simulator/noise.py +124 -0
- server/simulator/output_generator.py +495 -0
- server/simulator/transition.py +216 -0
- server/subagents/__init__.py +0 -0
- server/tasks/__init__.py +4 -0
- server/tasks/generator.py +129 -0
- server/tasks/scenarios.py +454 -0
- tests/__init__.py +0 -0
- tests/test_environment.py +85 -0
- tests/test_literature_benchmark.py +36 -0
- tests/test_models.py +88 -0
- tests/test_rewards.py +105 -0
- tests/test_rules.py +79 -0
- tests/test_simulator.py +121 -0
- training/__init__.py +34 -0
- training/evaluation.py +160 -0
- training/gym_wrapper.py +174 -0
- training/literature_benchmark.py +557 -0
- training/trajectory.py +159 -0
- uv.lock +0 -0
Dockerfile
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Multi-stage build on top of openenv-base.
# This Dockerfile is flexible and works for both:
#   - in-repo environments (local OpenEnv sources vendored in the context)
#   - standalone environments (openenv installed from PyPI/Git)
# The build script (`openenv build`) detects the context and sets the
# appropriate build args.

ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
FROM ${BASE_IMAGE} AS builder

WORKDIR /app

# git is required so uv can resolve VCS dependencies
RUN apt-get update && \
    apt-get install -y --no-install-recommends git && \
    rm -rf /var/lib/apt/lists/*

# Selects standalone vs in-repo behavior (set by the build script)
ARG BUILD_MODE=in-repo
ARG ENV_NAME=hackathon

# The environment code always sits at the root of the build context
COPY . /app/env

# In-repo builds ship openenv inside the build context; standalone
# builds pull it in through pyproject.toml during dependency resolution.
WORKDIR /app/env

# Bootstrap uv when the base image does not already provide it
RUN if ! command -v uv >/dev/null 2>&1; then \
        curl -LsSf https://astral.sh/uv/install.sh | sh && \
        mv /root/.local/bin/uv /usr/local/bin/uv && \
        mv /root/.local/bin/uvx /usr/local/bin/uvx; \
    fi

# Install dependencies with uv sync, honoring uv.lock when it exists
RUN --mount=type=cache,target=/root/.cache/uv \
    if [ -f uv.lock ]; then \
        uv sync --frozen --no-install-project --no-editable; \
    else \
        uv sync --no-install-project --no-editable; \
    fi

RUN --mount=type=cache,target=/root/.cache/uv \
    if [ -f uv.lock ]; then \
        uv sync --frozen --no-editable; \
    else \
        uv sync --no-editable; \
    fi

# ---- Final runtime stage ----
FROM ${BASE_IMAGE}

WORKDIR /app

# Virtual environment produced by the builder stage
COPY --from=builder /app/env/.venv /app/.venv

# Environment source code
COPY --from=builder /app/env /app/env

# Prefer the venv's interpreter and console scripts
ENV PATH="/app/.venv/bin:$PATH"

# Make the environment modules importable
ENV PYTHONPATH="/app/env:$PYTHONPATH"

# Container health probe against the FastAPI /health endpoint
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

# Run the FastAPI server from the environment directory so that the
# server.* module paths resolve under /app/env
ENV ENABLE_WEB_INTERFACE=true
CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
|
README.md
CHANGED
|
@@ -1,10 +1,337 @@
|
|
| 1 |
-
---
|
| 2 |
-
title:
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
-
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Bio Experiment Environment Server
|
| 3 |
+
sdk: docker
|
| 4 |
+
pinned: false
|
| 5 |
+
app_port: 8000
|
| 6 |
+
base_path: /web
|
| 7 |
+
tags:
|
| 8 |
+
- openenv
|
| 9 |
+
- reinforcement-learning
|
| 10 |
+
- bioinformatics
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
# Bio Experiment Environment
|
| 14 |
+
|
| 15 |
+
This repository implements an OpenEnv-compatible reinforcement learning environment for planning biological experiment pipelines. The agent does not directly see the true biological state. Instead, it proposes one structured experiment or analysis step at a time, receives a noisy simulated output, and is rewarded for valid, informative, efficient, well-calibrated plans.
|
| 16 |
+
|
| 17 |
+
The environment is designed as a partially observable Markov decision process (POMDP) with:
|
| 18 |
+
|
| 19 |
+
- hidden ground-truth biology
|
| 20 |
+
- hidden technical noise and failure conditions
|
| 21 |
+
- visible task metadata, resource usage, step history, and intermediate outputs
|
| 22 |
+
- dense step-wise reward plus terminal reward for conclusion quality
|
| 23 |
+
|
| 24 |
+
## What "how it works" means here
|
| 25 |
+
|
| 26 |
+
At a high level, each episode looks like this:
|
| 27 |
+
|
| 28 |
+
1. `reset()` picks a biological scenario and seeds the simulator.
|
| 29 |
+
2. The agent receives an `ExperimentObservation` describing the task and current visible state.
|
| 30 |
+
3. The agent submits an `ExperimentAction` such as `collect_sample`, `run_qc`, or `differential_expression`.
|
| 31 |
+
4. The rule engine checks whether the action is valid at this point in the pipeline.
|
| 32 |
+
5. The transition engine updates hidden state, spends resources, and asks the output generator to simulate the result.
|
| 33 |
+
6. The reward computer scores the step for validity, ordering, information gain, efficiency, novelty, and penalties.
|
| 34 |
+
7. The environment returns a new observation with updated history, outputs, discoveries, violations, and reward.
|
| 35 |
+
8. The episode ends when the agent synthesizes a conclusion, exhausts resources, or reaches the step limit.
|
| 36 |
+
|
| 37 |
+
## The core mental model
|
| 38 |
+
|
| 39 |
+
### Hidden state
|
| 40 |
+
|
| 41 |
+
The simulator keeps a `FullLatentState` that the agent never directly sees. It contains:
|
| 42 |
+
|
| 43 |
+
- true cell populations and marker genes
|
| 44 |
+
- true DE genes, pathways, trajectories, and regulatory networks
|
| 45 |
+
- technical factors such as dropout, doublets, ambient RNA, and batch effects
|
| 46 |
+
- experiment progress flags
|
| 47 |
+
- remaining budget and time
|
| 48 |
+
- hidden failure conditions
|
| 49 |
+
|
| 50 |
+
### Visible state
|
| 51 |
+
|
| 52 |
+
The agent only sees `ExperimentObservation`, which includes:
|
| 53 |
+
|
| 54 |
+
- the current `TaskSpec`
|
| 55 |
+
- pipeline history
|
| 56 |
+
- available assays and tools
|
| 57 |
+
- resource usage
|
| 58 |
+
- the latest and cumulative intermediate outputs
|
| 59 |
+
- discovered markers and candidate mechanisms
|
| 60 |
+
- rule violations
|
| 61 |
+
- per-step reward breakdown
|
| 62 |
+
|
| 63 |
+
This separation is what makes the environment a POMDP rather than a fully observed simulator.
|
| 64 |
+
|
| 65 |
+
## Main building blocks
|
| 66 |
+
|
| 67 |
+
### `models.py`
|
| 68 |
+
|
| 69 |
+
Defines the contracts that all other modules use:
|
| 70 |
+
|
| 71 |
+
- `ExperimentAction`: one structured step chosen by the agent
|
| 72 |
+
- `ExperimentObservation`: what the agent can see after each step
|
| 73 |
+
- `TaskSpec`: the problem statement, budget, time limit, assays, tools, and expected findings
|
| 74 |
+
- `IntermediateOutput`: the simulated artifact returned by a step
|
| 75 |
+
- `ConclusionClaim`: structured claims used for final synthesis
|
| 76 |
+
|
| 77 |
+
The action vocabulary is intentionally broad enough to mix wet-lab, computational, and meta-planning actions.
|
| 78 |
+
|
| 79 |
+
### `server/tasks/`
|
| 80 |
+
|
| 81 |
+
This is where episodes come from.
|
| 82 |
+
|
| 83 |
+
- `scenarios.py` defines a small library of curated biological scenarios
|
| 84 |
+
- `generator.py` turns a scenario into a `(TaskSpec, FullLatentState)` pair
|
| 85 |
+
- optional domain randomization perturbs budget, time, noise, batch effects, cell proportions, and effect sizes
|
| 86 |
+
|
| 87 |
+
Right now the scenario library includes:
|
| 88 |
+
|
| 89 |
+
- `cardiac_disease_de`: disease vs healthy differential expression in heart tissue
|
| 90 |
+
- `hematopoiesis_trajectory`: developmental trajectory inference in bone marrow
|
| 91 |
+
- `perturbation_immune`: treatment response under JAK inhibition
|
| 92 |
+
- `biomarker_validation_lung`: follow-up validation of `SPP1` in IPF
|
| 93 |
+
|
| 94 |
+
### `server/simulator/`
|
| 95 |
+
|
| 96 |
+
This is the simulator itself.
|
| 97 |
+
|
| 98 |
+
- `latent_state.py` defines hidden biological, technical, progress, and resource state
|
| 99 |
+
- `noise.py` centralizes stochasticity so episodes are reproducible from a seed
|
| 100 |
+
- `output_generator.py` turns an action plus hidden state into a realistic `IntermediateOutput`
|
| 101 |
+
- `transition.py` applies action costs, updates progress flags, propagates artifacts, and decides whether the episode is done
|
| 102 |
+
|
| 103 |
+
The output generator does not simply echo the action. It conditions outputs on the hidden state, then injects realistic noise such as dropout, false positives, false negatives, and imperfect clustering.
|
| 104 |
+
|
| 105 |
+
### `server/rules/engine.py`
|
| 106 |
+
|
| 107 |
+
The rule engine enforces scientific and procedural constraints before each action is applied.
|
| 108 |
+
|
| 109 |
+
- hard violations block the action entirely
|
| 110 |
+
- soft violations allow the action, but reduce output quality and add reward penalties
|
| 111 |
+
|
| 112 |
+
Examples:
|
| 113 |
+
|
| 114 |
+
- sequencing before library prep is a hard violation
|
| 115 |
+
- running QC twice is a soft redundancy violation
|
| 116 |
+
- making causal claims without enough evidence is a soft validity violation
|
| 117 |
+
|
| 118 |
+
### `server/rewards/reward.py`
|
| 119 |
+
|
| 120 |
+
Rewards are decomposed rather than being a single opaque number.
|
| 121 |
+
|
| 122 |
+
Per-step reward includes:
|
| 123 |
+
|
| 124 |
+
- validity
|
| 125 |
+
- ordering
|
| 126 |
+
- information gain
|
| 127 |
+
- efficiency
|
| 128 |
+
- novelty
|
| 129 |
+
- penalties
|
| 130 |
+
- potential-based shaping
|
| 131 |
+
|
| 132 |
+
Terminal reward adds:
|
| 133 |
+
|
| 134 |
+
- pipeline completeness
|
| 135 |
+
- calibration of conclusions against hidden truth
|
| 136 |
+
- remaining budget and time efficiency
|
| 137 |
+
- overconfidence penalties for strong but incorrect claims
|
| 138 |
+
|
| 139 |
+
This makes the environment easier to debug, benchmark, and train against.
|
| 140 |
+
|
| 141 |
+
### `server/hackathon_environment.py`
|
| 142 |
+
|
| 143 |
+
This is the orchestration layer that ties everything together.
|
| 144 |
+
|
| 145 |
+
On `reset()` it:
|
| 146 |
+
|
| 147 |
+
- seeds the noise model
|
| 148 |
+
- generates a task and latent state
|
| 149 |
+
- clears history, outputs, discoveries, conclusions, and cumulative reward
|
| 150 |
+
|
| 151 |
+
On `step()` it:
|
| 152 |
+
|
| 153 |
+
- checks rules
|
| 154 |
+
- calls the transition engine
|
| 155 |
+
- computes reward
|
| 156 |
+
- appends a `PipelineStepRecord`
|
| 157 |
+
- updates discovered markers and candidate mechanisms
|
| 158 |
+
- stores conclusion claims if the action is `synthesize_conclusion`
|
| 159 |
+
- builds the next `ExperimentObservation`
|
| 160 |
+
|
| 161 |
+
This file is the best place to read if you want the end-to-end control flow.
|
| 162 |
+
|
| 163 |
+
## What actually happens on one step
|
| 164 |
+
|
| 165 |
+
Here is the concrete order of operations for `env.step(action)`:
|
| 166 |
+
|
| 167 |
+
1. Increment the step counter.
|
| 168 |
+
2. Copy the previous latent state for reward comparison.
|
| 169 |
+
3. Run rule checks and split violations into hard vs soft.
|
| 170 |
+
4. If there is a hard violation, return a failure report without applying the action.
|
| 171 |
+
5. Otherwise deduct budget and time based on `ACTION_COSTS`.
|
| 172 |
+
6. Update latent progress flags like `samples_collected`, `qc_performed`, or `de_performed`.
|
| 173 |
+
7. Generate a structured simulated output for the chosen action.
|
| 174 |
+
8. If there were soft violations, degrade output quality and attach warnings.
|
| 175 |
+
9. Propagate artifacts back into latent state, such as discovered DE genes or cluster names.
|
| 176 |
+
10. Compute decomposed reward from state transition plus output quality.
|
| 177 |
+
11. If the episode is ending, compute terminal reward from completeness and conclusion calibration.
|
| 178 |
+
12. Return an observation that exposes the visible summary but not the hidden truth.
|
| 179 |
+
|
| 180 |
+
## Typical successful pipeline
|
| 181 |
+
|
| 182 |
+
Most scenarios reward a sensible experiment order similar to:
|
| 183 |
+
|
| 184 |
+
1. `collect_sample`
|
| 185 |
+
2. `prepare_library`
|
| 186 |
+
3. `sequence_cells`
|
| 187 |
+
4. `run_qc`
|
| 188 |
+
5. `filter_data`
|
| 189 |
+
6. `normalize_data`
|
| 190 |
+
7. `cluster_cells`
|
| 191 |
+
8. one or more of:
|
| 192 |
+
`differential_expression`, `trajectory_analysis`, `pathway_enrichment`,
|
| 193 |
+
`regulatory_network_inference`, `marker_selection`, `validate_marker`
|
| 194 |
+
9. `synthesize_conclusion`
|
| 195 |
+
|
| 196 |
+
The exact best sequence depends on the scenario. For example:
|
| 197 |
+
|
| 198 |
+
- trajectory scenarios benefit from `trajectory_analysis` and regulatory inference
|
| 199 |
+
- biomarker scenarios benefit from DE, marker selection, and validation
|
| 200 |
+
- perturbation scenarios benefit from pathway-level interpretation
|
| 201 |
+
|
| 202 |
+
## Interfaces you can use
|
| 203 |
+
|
| 204 |
+
### 1. In-process environment
|
| 205 |
+
|
| 206 |
+
Use `BioExperimentEnvironment` when you want direct Python access with full structured observations:
|
| 207 |
+
|
| 208 |
+
```python
|
| 209 |
+
from models import ActionType, ExperimentAction
|
| 210 |
+
from server.hackathon_environment import BioExperimentEnvironment
|
| 211 |
+
|
| 212 |
+
env = BioExperimentEnvironment(scenario_name="biomarker_validation_lung")
|
| 213 |
+
obs = env.reset()
|
| 214 |
+
|
| 215 |
+
obs = env.step(ExperimentAction(
|
| 216 |
+
action_type=ActionType.COLLECT_SAMPLE,
|
| 217 |
+
parameters={"n_samples": 8},
|
| 218 |
+
justification="Collect enough material for downstream single-cell analysis.",
|
| 219 |
+
confidence=0.8,
|
| 220 |
+
))
|
| 221 |
+
|
| 222 |
+
print(obs.task.problem_statement)
|
| 223 |
+
print(obs.latest_output.summary if obs.latest_output else "No output yet")
|
| 224 |
+
print(obs.reward)
|
| 225 |
+
```
|
| 226 |
+
|
| 227 |
+
### 2. OpenEnv client/server mode
|
| 228 |
+
|
| 229 |
+
Use the FastAPI app when you want to serve the environment over HTTP and WebSocket:
|
| 230 |
+
|
| 231 |
+
```bash
|
| 232 |
+
uv sync --extra dev
|
| 233 |
+
uv run uvicorn server.app:app --reload
|
| 234 |
+
```
|
| 235 |
+
|
| 236 |
+
Then connect with the client:
|
| 237 |
+
|
| 238 |
+
```python
|
| 239 |
+
from client import BioExperimentEnv
|
| 240 |
+
from models import ActionType, ExperimentAction
|
| 241 |
+
|
| 242 |
+
with BioExperimentEnv(base_url="http://localhost:8000") as env:
|
| 243 |
+
result = env.reset()
|
| 244 |
+
result = env.step(ExperimentAction(action_type=ActionType.COLLECT_SAMPLE))
|
| 245 |
+
print(result.observation.latest_output.summary)
|
| 246 |
+
```
|
| 247 |
+
|
| 248 |
+
The environment class supports concurrent sessions, but the bundled server is currently configured with `max_concurrent_envs=1` in `server/app.py`.
|
| 249 |
+
|
| 250 |
+
### 3. Gymnasium wrapper
|
| 251 |
+
|
| 252 |
+
Use `training/gym_wrapper.py` when you want a classic RL interface:
|
| 253 |
+
|
| 254 |
+
```python
|
| 255 |
+
from training.gym_wrapper import BioExperimentGymEnv
|
| 256 |
+
|
| 257 |
+
env = BioExperimentGymEnv()
|
| 258 |
+
obs, info = env.reset()
|
| 259 |
+
obs, reward, terminated, truncated, info = env.step({
|
| 260 |
+
"action_type": 0,
|
| 261 |
+
"confidence": 0.7,
|
| 262 |
+
})
|
| 263 |
+
```
|
| 264 |
+
|
| 265 |
+
This wrapper vectorizes the structured observation into arrays and reduces the action interface to:
|
| 266 |
+
|
| 267 |
+
- a discrete action type index
|
| 268 |
+
- a scalar confidence value
|
| 269 |
+
|
| 270 |
+
### 4. Benchmark and scripted agents
|
| 271 |
+
|
| 272 |
+
- `training/literature_benchmark.py` runs paper-aligned action sequences and compares outcomes against curated expected findings
|
| 273 |
+
- `run_agent.py` runs a local language model planner against the environment
|
| 274 |
+
- `training/trajectory.py` stores trajectories for offline RL, imitation learning, replay, and evaluation
|
| 275 |
+
- `training/evaluation.py` computes online, benchmark, expert-review, and fidelity-oriented metrics
|
| 276 |
+
|
| 277 |
+
## Episode termination
|
| 278 |
+
|
| 279 |
+
An episode ends when one of the following happens:
|
| 280 |
+
|
| 281 |
+
- the agent chooses `synthesize_conclusion`
|
| 282 |
+
- resources are exhausted
|
| 283 |
+
- the environment reaches `MAX_STEPS` which is currently `30`
|
| 284 |
+
|
| 285 |
+
## Why this is useful
|
| 286 |
+
|
| 287 |
+
This environment is trying to model a realistic scientific planning loop rather than a toy decision problem:
|
| 288 |
+
|
| 289 |
+
- actions have prerequisites
|
| 290 |
+
- outputs are noisy and imperfect
|
| 291 |
+
- budget and time matter
|
| 292 |
+
- not every correct-looking answer is well supported
|
| 293 |
+
- final conclusions are scored against hidden ground truth
|
| 294 |
+
|
| 295 |
+
That makes it suitable for:
|
| 296 |
+
|
| 297 |
+
- agent planning benchmarks
|
| 298 |
+
- RL experiments on long-horizon scientific reasoning
|
| 299 |
+
- literature-grounded evaluation
|
| 300 |
+
- comparing structured policies against LLM-driven planners
|
| 301 |
+
|
| 302 |
+
## Minimal project map
|
| 303 |
+
|
| 304 |
+
```text
|
| 305 |
+
.
|
| 306 |
+
βββ client.py # OpenEnv client
|
| 307 |
+
βββ models.py # Shared action / observation / task schemas
|
| 308 |
+
βββ server/
|
| 309 |
+
β βββ app.py # FastAPI/OpenEnv server
|
| 310 |
+
β βββ hackathon_environment.py # Main environment orchestration
|
| 311 |
+
β βββ rewards/ # Reward model
|
| 312 |
+
β βββ rules/ # Constraint checking
|
| 313 |
+
β βββ simulator/ # Latent state, noise, outputs, transitions
|
| 314 |
+
β βββ tasks/ # Scenario library and task generation
|
| 315 |
+
βββ training/
|
| 316 |
+
β βββ evaluation.py # Metrics
|
| 317 |
+
β βββ gym_wrapper.py # Gymnasium wrapper
|
| 318 |
+
β βββ literature_benchmark.py # Paper-backed benchmark flow
|
| 319 |
+
β βββ trajectory.py # Trajectory serialization
|
| 320 |
+
βββ tests/ # Unit and integration tests
|
| 321 |
+
```
|
| 322 |
+
|
| 323 |
+
## Quick sanity check
|
| 324 |
+
|
| 325 |
+
The current implementation was sanity-checked with:
|
| 326 |
+
|
| 327 |
+
```bash
|
| 328 |
+
uv run pytest tests/test_environment.py tests/test_literature_benchmark.py -q
|
| 329 |
+
```
|
| 330 |
+
|
| 331 |
+
Those tests verify:
|
| 332 |
+
|
| 333 |
+
- reset and step lifecycle
|
| 334 |
+
- valid vs invalid pipeline behavior
|
| 335 |
+
- conclusion termination
|
| 336 |
+
- literature-backed scenario selection
|
| 337 |
+
- benchmark matching for curated expected findings
|
__init__.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Public package API.  When installed as a package the relative imports
# succeed; when the repo is used as a flat module directory (e.g. inside
# the Docker image with PYTHONPATH=/app/env) we fall back to absolute
# imports of the same names.
try:  # pragma: no cover - package import path
    from .client import BioExperimentEnv
    from .models import (
        ActionType,
        ConclusionClaim,
        ExpectedFinding,
        ExperimentAction,
        ExperimentObservation,
        IntermediateOutput,
        OutputType,
        PaperReference,
        PipelineStepRecord,
        ResourceUsage,
        SubagentType,
        TaskSpec,
    )
except ImportError:  # pragma: no cover - direct module import path
    from client import BioExperimentEnv
    from models import (
        ActionType,
        ConclusionClaim,
        ExpectedFinding,
        ExperimentAction,
        ExperimentObservation,
        IntermediateOutput,
        OutputType,
        PaperReference,
        PipelineStepRecord,
        ResourceUsage,
        SubagentType,
        TaskSpec,
    )

# Explicit, alphabetically sorted export list.
__all__ = [
    "ActionType",
    "BioExperimentEnv",
    "ConclusionClaim",
    "ExpectedFinding",
    "ExperimentAction",
    "ExperimentObservation",
    "IntermediateOutput",
    "OutputType",
    "PaperReference",
    "PipelineStepRecord",
    "ResourceUsage",
    "SubagentType",
    "TaskSpec",
]
|
client.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Bio-Experiment Environment Client.
|
| 2 |
+
|
| 3 |
+
Provides the ``BioExperimentEnv`` class that communicates with the
|
| 4 |
+
environment server over WebSocket / HTTP using the OpenEnv protocol.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from typing import Any, Dict, List
|
| 8 |
+
|
| 9 |
+
from openenv.core.client_types import StepResult
|
| 10 |
+
from openenv.core.env_server.types import State
|
| 11 |
+
from openenv.core import EnvClient
|
| 12 |
+
|
| 13 |
+
try: # pragma: no cover - package import path
|
| 14 |
+
from .models import ExperimentAction, ExperimentObservation
|
| 15 |
+
except ImportError: # pragma: no cover - direct module import path
|
| 16 |
+
from models import ExperimentAction, ExperimentObservation
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class BioExperimentEnv(
    EnvClient[ExperimentAction, ExperimentObservation, State]
):
    """Client for the Bio-Experiment Planning Environment.

    Talks to the environment server over the OpenEnv HTTP/WebSocket
    protocol, serializing ``ExperimentAction`` on the way out and
    rebuilding ``ExperimentObservation`` on the way back.

    Example:
        >>> with BioExperimentEnv(base_url="http://localhost:8000") as env:
        ...     result = env.reset()
        ...     print(result.observation.task.problem_statement)
        ...     result = env.step(ExperimentAction(
        ...         action_type="collect_sample",
        ...         parameters={"n_samples": 6},
        ...     ))
        ...     print(result.observation.latest_output.summary)
    """

    def _step_payload(self, action: ExperimentAction) -> Dict:
        # Pydantic handles serialization of the structured action.
        return action.model_dump()

    def _parse_result(
        self, payload: Dict
    ) -> StepResult[ExperimentObservation]:
        # Rebuild the structured observation from the wire payload;
        # reward/done fall back to None/False when absent.
        raw_obs = payload.get("observation", {})
        return StepResult(
            observation=ExperimentObservation(**raw_obs),
            reward=payload.get("reward"),
            done=payload.get("done", False),
        )

    def _parse_state(self, payload: Dict) -> State:
        # Minimal session state: episode identity plus step counter.
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )
|
models.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data models for the Bio-Experiment Planning RL Environment.
|
| 3 |
+
|
| 4 |
+
Defines the POMDP action and observation contracts for a scientific agent
|
| 5 |
+
that constructs biological experiment pipelines step-by-step.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from enum import Enum
|
| 11 |
+
from typing import Any, Dict, List, Optional
|
| 12 |
+
|
| 13 |
+
from pydantic import BaseModel, Field
|
| 14 |
+
|
| 15 |
+
from openenv.core.env_server.types import Action, Observation
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# ββ Action vocabulary βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class ActionType(str, Enum):
    """Discrete vocabulary of experiment and analysis steps.

    String-valued so actions serialize transparently over the wire.
    Covers wet-lab work, computational analysis, and meta-planning;
    the category sets below group members by kind.
    """

    COLLECT_SAMPLE = "collect_sample"
    SELECT_COHORT = "select_cohort"
    PREPARE_LIBRARY = "prepare_library"
    CULTURE_CELLS = "culture_cells"
    PERTURB_GENE = "perturb_gene"
    PERTURB_COMPOUND = "perturb_compound"
    SEQUENCE_CELLS = "sequence_cells"
    RUN_QC = "run_qc"
    FILTER_DATA = "filter_data"
    NORMALIZE_DATA = "normalize_data"
    INTEGRATE_BATCHES = "integrate_batches"
    CLUSTER_CELLS = "cluster_cells"
    DIFFERENTIAL_EXPRESSION = "differential_expression"
    TRAJECTORY_ANALYSIS = "trajectory_analysis"
    PATHWAY_ENRICHMENT = "pathway_enrichment"
    REGULATORY_NETWORK_INFERENCE = "regulatory_network_inference"
    MARKER_SELECTION = "marker_selection"
    VALIDATE_MARKER = "validate_marker"
    # NOTE: the member name is shorter than its wire value here.
    DESIGN_FOLLOWUP = "design_followup_experiment"
    REQUEST_SUBAGENT_REVIEW = "request_subagent_review"
    SYNTHESIZE_CONCLUSION = "synthesize_conclusion"
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# Category sets over ActionType, used elsewhere to route and score
# actions by kind.  Frozen so they cannot be mutated at runtime.

# Physical/bench steps that consume samples, reagents, or instruments.
WET_LAB_ACTIONS = frozenset({
    ActionType.COLLECT_SAMPLE,
    ActionType.SELECT_COHORT,
    ActionType.PREPARE_LIBRARY,
    ActionType.CULTURE_CELLS,
    ActionType.PERTURB_GENE,
    ActionType.PERTURB_COMPOUND,
    ActionType.SEQUENCE_CELLS,
    ActionType.VALIDATE_MARKER,
})

# In-silico analysis steps operating on already-generated data.
COMPUTATIONAL_ACTIONS = frozenset({
    ActionType.RUN_QC,
    ActionType.FILTER_DATA,
    ActionType.NORMALIZE_DATA,
    ActionType.INTEGRATE_BATCHES,
    ActionType.CLUSTER_CELLS,
    ActionType.DIFFERENTIAL_EXPRESSION,
    ActionType.TRAJECTORY_ANALYSIS,
    ActionType.PATHWAY_ENRICHMENT,
    ActionType.REGULATORY_NETWORK_INFERENCE,
    ActionType.MARKER_SELECTION,
})

# Planning / delegation / synthesis steps that act on the episode itself.
META_ACTIONS = frozenset({
    ActionType.DESIGN_FOLLOWUP,
    ActionType.REQUEST_SUBAGENT_REVIEW,
    ActionType.SYNTHESIZE_CONCLUSION,
})
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class SubagentType(str, Enum):
    """Specialist sub-agents an action may delegate to.

    Values are the wire-format strings used in ``ExperimentAction.invoked_subagent``.
    """

    WET_LAB_PLANNER = "wet_lab_planner"
    COMPUTATIONAL_ANALYST = "computational_analyst"
    OMICS_QC_AGENT = "omics_qc_agent"
    CAUSAL_REASONING_AGENT = "causal_reasoning_agent"
    BUDGET_SCHEDULER = "budget_scheduler"
    BIOLOGICAL_RULE_CHECKER = "biological_rule_checker"
    TOOL_EXECUTOR = "tool_executor"
    RETROSPECTIVE_CRITIC = "retrospective_critic"
    REPORT_SYNTHESIZER = "report_synthesizer"
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# ββ Action schema βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class ExperimentAction(Action):
    """Structured, compositional action for one experiment / analysis step.

    Hybrid representation: discrete *action_type* plus typed arguments,
    optional sub-agent / tool invocation, and calibration fields
    (``confidence`` feeds calibration-aware reward terms downstream).
    """

    action_type: ActionType = Field(
        ..., description="Discrete experiment or analysis step type"
    )
    input_targets: List[str] = Field(
        default_factory=list,
        description="References to prior outputs, samples, or artifacts",
    )
    method: Optional[str] = Field(
        None, description="Specific method or tool (e.g. 'Seurat', 'CellRanger')"
    )
    parameters: Dict[str, Any] = Field(
        default_factory=dict, description="Method-specific parameters"
    )
    expected_output_type: Optional[str] = Field(
        None, description="What the agent expects this step to produce"
    )
    justification: Optional[str] = Field(
        None, description="Scientific rationale for this step"
    )
    invoked_subagent: Optional[SubagentType] = Field(
        None, description="Sub-agent to delegate to, if any"
    )
    tool_call_spec: Optional[Dict[str, Any]] = Field(
        None, description="Structured tool invocation specification"
    )
    confidence: float = Field(
        0.5, ge=0.0, le=1.0, description="Agent confidence in this step"
    )
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
# ββ Intermediate outputs ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class OutputType(str, Enum):
    """Kind of simulated artifact a pipeline step can produce.

    Roughly one value per ActionType family, plus generic FAILURE_REPORT /
    SUBAGENT_REPORT / CONCLUSION outputs.
    """

    QC_METRICS = "qc_metrics"
    COUNT_MATRIX_SUMMARY = "count_matrix_summary"
    EMBEDDING_SUMMARY = "embedding_summary"
    CLUSTER_RESULT = "cluster_result"
    DE_RESULT = "de_result"
    PATHWAY_RESULT = "pathway_result"
    TRAJECTORY_RESULT = "trajectory_result"
    VALIDATION_RESULT = "validation_result"
    NETWORK_RESULT = "network_result"
    SAMPLE_COLLECTION_RESULT = "sample_collection_result"
    LIBRARY_PREP_RESULT = "library_prep_result"
    SEQUENCING_RESULT = "sequencing_result"
    PERTURBATION_RESULT = "perturbation_result"
    CULTURE_RESULT = "culture_result"
    COHORT_RESULT = "cohort_result"
    FOLLOWUP_DESIGN = "followup_design"
    MARKER_RESULT = "marker_result"
    FAILURE_REPORT = "failure_report"
    SUBAGENT_REPORT = "subagent_report"
    CONCLUSION = "conclusion"
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class IntermediateOutput(BaseModel):
    """A single simulated output from one pipeline step."""

    output_type: OutputType
    # 1-based index of the step that produced this output.
    step_index: int
    success: bool = True
    # Simulator-assessed quality of the result, in [0, 1].
    quality_score: float = Field(1.0, ge=0.0, le=1.0)
    # Short human-readable description shown to the agent.
    summary: str = ""
    # Structured, output-type-specific payload.
    data: Dict[str, Any] = Field(default_factory=dict)
    # Simulator uncertainty about this output, in [0, 1].
    uncertainty: float = Field(0.0, ge=0.0, le=1.0)
    warnings: List[str] = Field(default_factory=list)
    # Names of downstream artifacts this output makes referenceable.
    artifacts_available: List[str] = Field(default_factory=list)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
# ββ Observable state components βββββββββββββββββββββββββββββββββββββββββββββ
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class ResourceUsage(BaseModel):
    """Running totals of episode resources; defaults mirror TaskSpec limits."""

    budget_used: float = 0.0
    budget_remaining: float = 100_000.0  # matches TaskSpec.budget_limit default
    time_used_days: float = 0.0
    time_remaining_days: float = 180.0  # matches TaskSpec.time_limit_days default
    samples_consumed: int = 0
    compute_hours_used: float = 0.0
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class PipelineStepRecord(BaseModel):
    """Compact history entry for one executed step (agent-visible log)."""

    step_index: int
    action_type: ActionType
    method: Optional[str] = None
    parameters: Dict[str, Any] = Field(default_factory=dict)
    # Summary copied from the step's IntermediateOutput.
    output_summary: str = ""
    output_type: OutputType
    success: bool = True
    quality_score: float = 1.0
    # Budget (currency units) and wall-clock days charged for this step.
    resource_cost: float = 0.0
    time_cost_days: float = 0.0
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
class PaperReference(BaseModel):
    """Metadata for a literature source used to ground a task.

    All identifier fields are optional; a title alone is a valid reference.
    """

    title: str
    citation: Optional[str] = None
    doi: Optional[str] = None
    pmid: Optional[str] = None
    url: Optional[str] = None
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class ExpectedFinding(BaseModel):
    """A paper-backed result that the agent should try to recover."""

    finding: str
    # Free-form category label; defaults to a generic "claim".
    category: str = "claim"
    # Keywords used to match agent conclusions against this finding.
    keywords: List[str] = Field(default_factory=list)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class TaskSpec(BaseModel):
    """Specification of the biological problem to solve.

    This is the agent-visible task description; hidden ground truth lives in
    the simulator's latent state, not here.
    """

    problem_statement: str = "Unspecified biological problem"
    modality: str = "scRNA-seq"
    organism: str = "human"
    tissue: str = "blood"
    # Experimental conditions to compare (e.g. treated vs. control).
    conditions: List[str] = Field(default_factory=list)
    available_assays: List[str] = Field(default_factory=lambda: [
        "10x_chromium", "smart-seq2", "bulk_rna_seq",
        "atac-seq", "cite-seq", "spatial_transcriptomics",
    ])
    available_tools: List[str] = Field(default_factory=lambda: [
        "CellRanger", "Seurat", "Scanpy", "DESeq2", "GSEA",
        "Monocle", "scVelo", "CellChat", "SCENIC",
    ])
    budget_limit: float = 100_000.0
    time_limit_days: float = 180.0
    # Hints and grading criteria surfaced to the agent at episode start.
    prior_observations: List[str] = Field(default_factory=list)
    success_criteria: List[str] = Field(default_factory=list)
    dataset_metadata: Dict[str, Any] = Field(default_factory=dict)
    # Literature grounding for benchmark-style tasks.
    paper_references: List[PaperReference] = Field(default_factory=list)
    expected_findings: List[ExpectedFinding] = Field(default_factory=list)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class ConclusionClaim(BaseModel):
    """One claim made by the agent in a synthesize_conclusion action."""

    claim: str
    # Step indices cited as supporting evidence for the claim.
    evidence_steps: List[int] = Field(default_factory=list)
    confidence: float = Field(0.5, ge=0.0, le=1.0)
    # Epistemic strength of the claim; defaults to the weaker "correlational".
    claim_type: str = "correlational"
    supporting_data: Dict[str, Any] = Field(default_factory=dict)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
# ββ Observation schema ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class ExperimentObservation(Observation):
    """Full observable state returned to the agent at each timestep.

    Deliberately excludes hidden latent biological truth, hidden failure
    conditions, and ground-truth mechanisms.
    """

    task: TaskSpec = Field(default_factory=TaskSpec)
    step_index: int = 0
    # Chronological log of executed steps (summaries only).
    pipeline_history: List[PipelineStepRecord] = Field(default_factory=list)
    available_assays: List[str] = Field(default_factory=list)
    available_tools: List[str] = Field(default_factory=list)
    resource_usage: ResourceUsage = Field(default_factory=ResourceUsage)
    # Output of the most recent step; None right after reset.
    latest_output: Optional[IntermediateOutput] = None
    all_outputs: List[IntermediateOutput] = Field(default_factory=list)
    # Accumulated findings surfaced to the agent so far.
    discovered_markers: List[str] = Field(default_factory=list)
    candidate_mechanisms: List[str] = Field(default_factory=list)
    uncertainty_summary: Dict[str, float] = Field(default_factory=dict)
    subagent_outputs: List[Dict[str, Any]] = Field(default_factory=list)
    conclusions: List[ConclusionClaim] = Field(default_factory=list)
    # Human-readable rule-violation messages from the last step.
    rule_violations: List[str] = Field(default_factory=list)
    # Per-component reward terms for the last step.
    step_reward_breakdown: Dict[str, float] = Field(default_factory=dict)
|
openenv.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: hackathon
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: server.app:app
|
| 6 |
+
port: 8000
|
| 7 |
+
|
outputs/.gitkeep
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
pyproject.toml
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
[build-system]
|
| 8 |
+
requires = ["setuptools>=45", "wheel"]
|
| 9 |
+
build-backend = "setuptools.build_meta"
|
| 10 |
+
|
| 11 |
+
[project]
|
| 12 |
+
name = "openenv-bio-experiment"
|
| 13 |
+
version = "0.1.0"
|
| 14 |
+
description = "RL environment for biological experiment pipeline planning"
|
| 15 |
+
requires-python = ">=3.10"
|
| 16 |
+
dependencies = [
|
| 17 |
+
"openenv-core[core]>=0.2.0",
|
| 18 |
+
"numpy>=1.24.0",
|
| 19 |
+
"scipy>=1.10.0",
|
| 20 |
+
"pydantic>=2.0.0",
|
| 21 |
+
]
|
| 22 |
+
|
| 23 |
+
[project.optional-dependencies]
|
| 24 |
+
train = [
|
| 25 |
+
"gymnasium>=0.29.0",
|
| 26 |
+
]
|
| 27 |
+
bio = [
|
| 28 |
+
"biopython>=1.84",
|
| 29 |
+
"gseapy>=1.1.3",
|
| 30 |
+
"scanpy>=1.10.0",
|
| 31 |
+
]
|
| 32 |
+
dev = [
|
| 33 |
+
"pytest>=8.0.0",
|
| 34 |
+
"pytest-cov>=4.0.0",
|
| 35 |
+
"gymnasium>=0.29.0",
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
[project.scripts]
|
| 39 |
+
server = "hackathon.server.app:main"
|
| 40 |
+
|
| 41 |
+
[tool.uv]
|
| 42 |
+
package = false
|
| 43 |
+
|
| 44 |
+
[tool.setuptools]
|
| 45 |
+
include-package-data = true
|
| 46 |
+
packages = [
|
| 47 |
+
"hackathon",
|
| 48 |
+
"hackathon.server",
|
| 49 |
+
"hackathon.server.simulator",
|
| 50 |
+
"hackathon.server.rules",
|
| 51 |
+
"hackathon.server.rewards",
|
| 52 |
+
"hackathon.server.tasks",
|
| 53 |
+
"hackathon.server.subagents",
|
| 54 |
+
"hackathon.training",
|
| 55 |
+
"hackathon.tests",
|
| 56 |
+
]
|
| 57 |
+
[tool.setuptools.package-dir]
|
| 58 |
+
hackathon = "."
|
| 59 |
+
"hackathon.server" = "server"
|
| 60 |
+
"hackathon.server.simulator" = "server/simulator"
|
| 61 |
+
"hackathon.server.rules" = "server/rules"
|
| 62 |
+
"hackathon.server.rewards" = "server/rewards"
|
| 63 |
+
"hackathon.server.tasks" = "server/tasks"
|
| 64 |
+
"hackathon.server.subagents" = "server/subagents"
|
| 65 |
+
"hackathon.training" = "training"
|
| 66 |
+
"hackathon.tests" = "tests"
|
run_agent.py
ADDED
|
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Run the bio-experiment environment with Qwen3.5-2B as the planning agent."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import re
|
| 7 |
+
import sys
|
| 8 |
+
import time
|
| 9 |
+
from typing import Any, Dict, List, Optional
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
| 13 |
+
|
| 14 |
+
from models import ActionType, ExperimentAction, ExperimentObservation
|
| 15 |
+
from server.hackathon_environment import BioExperimentEnvironment
|
| 16 |
+
|
| 17 |
+
# NOTE(review): the module docstring says "Qwen3.5-2B" but this id is 0.8B,
# and the "~4 GB" download log below assumes the larger model -- confirm
# which checkpoint is intended.
MODEL_ID = "Qwen/Qwen3.5-0.8B"
MAX_EPISODE_STEPS = 12  # hard cap on agent turns per episode
# NOTE(review): a multimodal pipeline task for a text-only planning loop --
# run_with_pipeline() carries text/image payload fallbacks for this; confirm.
PIPELINE_TASK = "image-text-to-text"
USE_PIPELINE = True  # try transformers.pipeline first, fall back to tokenizer+model

# String values the LLM is allowed to emit as "action_type" in its JSON reply.
ACTION_TYPES = [a.value for a in ActionType]
|
| 23 |
+
|
| 24 |
+
SYSTEM_PROMPT = """\
|
| 25 |
+
You are an expert biologist planning a single-cell experiment pipeline.
|
| 26 |
+
|
| 27 |
+
At each turn you see the experiment state and must pick the next step.
|
| 28 |
+
|
| 29 |
+
Action types (in typical order):
|
| 30 |
+
collect_sample, prepare_library, sequence_cells, run_qc, filter_data,
|
| 31 |
+
normalize_data, cluster_cells, differential_expression,
|
| 32 |
+
pathway_enrichment, marker_selection, validate_marker, synthesize_conclusion
|
| 33 |
+
|
| 34 |
+
Other actions: select_cohort, culture_cells, perturb_gene, perturb_compound,
|
| 35 |
+
integrate_batches, trajectory_analysis, regulatory_network_inference,
|
| 36 |
+
design_followup_experiment, request_subagent_review
|
| 37 |
+
|
| 38 |
+
Respond with ONLY valid JSON, nothing else:
|
| 39 |
+
{"action_type": "...", "method": null, "parameters": {}, "justification": "...", "confidence": 0.8}
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def format_observation(obs: ExperimentObservation) -> str:
    """Render an observation as the compact text block shown to the LLM.

    Includes the task header, resource status, up to the last 5 history
    entries, and (when present) rule violations and up to 5 markers.
    """
    parts = [
        f"TASK: {obs.task.problem_statement}",
        f"Organism: {obs.task.organism} | Tissue: {obs.task.tissue}",
        f"Conditions: {', '.join(obs.task.conditions) or 'N/A'}",
        f"Step: {obs.step_index} | Budget: ${obs.resource_usage.budget_remaining:,.0f} | Time: {obs.resource_usage.time_remaining_days:.0f}d",
    ]
    if obs.pipeline_history:
        # Keep the prompt short: only the 5 most recent steps.
        last5 = obs.pipeline_history[-5:]
        parts.append("History:")
        for h in last5:
            tag = "OK" if h.success else "FAIL"
            # Truncate each summary to 80 chars to bound prompt growth.
            parts.append(f" [{tag}] {h.action_type.value}: {h.output_summary[:80]}")
    if obs.rule_violations:
        parts.append(f"VIOLATIONS: {obs.rule_violations}")
    if obs.discovered_markers:
        parts.append(f"Markers: {obs.discovered_markers[:5]}")
    return "\n".join(parts)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def parse_action(text: str) -> Optional[ExperimentAction]:
|
| 64 |
+
match = re.search(r"\{[^{}]*\}", text, re.DOTALL)
|
| 65 |
+
if not match:
|
| 66 |
+
return None
|
| 67 |
+
try:
|
| 68 |
+
d = json.loads(match.group())
|
| 69 |
+
except json.JSONDecodeError:
|
| 70 |
+
return None
|
| 71 |
+
|
| 72 |
+
action_type = d.get("action_type")
|
| 73 |
+
if action_type not in ACTION_TYPES:
|
| 74 |
+
return None
|
| 75 |
+
|
| 76 |
+
return ExperimentAction(
|
| 77 |
+
action_type=ActionType(action_type),
|
| 78 |
+
method=d.get("method"),
|
| 79 |
+
parameters=d.get("parameters") or {},
|
| 80 |
+
justification=d.get("justification"),
|
| 81 |
+
confidence=min(1.0, max(0.0, float(d.get("confidence", 0.5)))),
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# Scripted canonical scRNA-seq pipeline, indexed by episode step.  Used by
# fallback_action() whenever the LLM reply cannot be parsed into an action.
FALLBACK_SEQUENCE = [
    ActionType.COLLECT_SAMPLE,
    ActionType.PREPARE_LIBRARY,
    ActionType.SEQUENCE_CELLS,
    ActionType.RUN_QC,
    ActionType.FILTER_DATA,
    ActionType.NORMALIZE_DATA,
    ActionType.CLUSTER_CELLS,
    ActionType.DIFFERENTIAL_EXPRESSION,
    ActionType.PATHWAY_ENRICHMENT,
    ActionType.MARKER_SELECTION,
    ActionType.SYNTHESIZE_CONCLUSION,
]
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def fallback_action(step: int) -> ExperimentAction:
    """Return the scripted action for *step*, clamping past the sequence end.

    Low confidence (0.3) marks the action as a deterministic fallback rather
    than a model decision.
    """
    last_index = len(FALLBACK_SEQUENCE) - 1
    position = step if step < last_index else last_index
    chosen = FALLBACK_SEQUENCE[position]
    return ExperimentAction(
        action_type=chosen,
        justification="fallback",
        confidence=0.3,
    )
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def log(msg: str) -> None:
    """Write *msg* plus a newline to stdout and flush immediately,
    so progress is visible even under buffered output."""
    sys.stdout.write(f"{msg}\n")
    sys.stdout.flush()
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def build_observation_prompt(obs: ExperimentObservation) -> str:
    """Thin alias: render *obs* into the user-turn prompt text."""
    prompt_text = format_observation(obs)
    return prompt_text
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def run_with_pipeline(pipe, prompt: str) -> str:
    """Call a transformers pipeline with *prompt*, tolerating payload shapes.

    Different pipeline tasks expect different input keys, so several payload
    shapes are tried in order; the first attempt yielding non-empty text wins.
    Returns "" when every attempt fails (caller falls back elsewhere).
    """
    # The {"image": prompt} form presumably works around the multimodal
    # PIPELINE_TASK configured above -- TODO confirm against the task's API.
    attempts = [
        {"text": prompt},
        {"text": prompt, "image": None},
        {"image": prompt},
    ]

    for payload in attempts:
        try:
            result = pipe(payload, max_new_tokens=220)
            # Pipelines usually return a list of result dicts; unwrap item 0.
            if isinstance(result, list) and result:
                result = result[0]
            if isinstance(result, dict):
                # The output key varies by task type.
                text = result.get("generated_text") or result.get("text") or result.get("answer")
            elif isinstance(result, str):
                text = result
            else:
                text = ""
            if isinstance(text, str) and text.strip():
                return text.strip()
        except Exception:
            # Deliberate best-effort: a failing payload shape means "try next".
            continue

    return ""
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def main():
    """Run one episode: load the LLM, then loop observe -> generate -> act.

    Two inference paths: a transformers pipeline (preferred when USE_PIPELINE)
    and a tokenizer+model fallback.  Unparseable model replies fall back to a
    scripted pipeline action.
    """
    tokenizer = None
    model = None
    eos_ids: List[int] = []
    active_pipeline = None

    # --- model loading: pipeline first, tokenizer+model as fallback --------
    if USE_PIPELINE:
        log(f"Loading pipeline ({PIPELINE_TASK}) for {MODEL_ID} ...")
        try:
            active_pipeline = pipeline(
                PIPELINE_TASK,
                model=MODEL_ID,
                trust_remote_code=True,
                torch_dtype=torch.bfloat16,
            )
            log("Pipeline loaded.")
        except Exception as exc:
            log(f"Pipeline load failed ({exc}), falling back to tokenizer+model.")

    if active_pipeline is None:
        log(f"Loading tokenizer for {MODEL_ID} ...")
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_ID, trust_remote_code=True,
        )
        log("Tokenizer loaded. Loading model (this downloads ~4 GB on first run) ...")

        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        log(f"Model loaded. Device: {model.device}")

        # Collect all plausible end-of-turn token ids (Qwen-style chat tokens).
        if tokenizer.eos_token_id is not None:
            eos_ids.append(tokenizer.eos_token_id)
        extra = tokenizer.convert_tokens_to_ids(["<|im_end|>", "<|endoftext|>"])
        for tid in extra:
            if isinstance(tid, int) and tid not in eos_ids:
                eos_ids.append(tid)
        log(f"EOS token ids: {eos_ids}")

    # --- environment setup --------------------------------------------------
    env = BioExperimentEnvironment()
    obs = env.reset()

    log("\n" + "=" * 70)
    log(f"TASK: {obs.task.problem_statement}")
    log(f"Conditions: {obs.task.conditions}")
    log(f"Budget: ${obs.task.budget_limit:,.0f} | Time: {obs.task.time_limit_days:.0f} days")
    log("=" * 70)

    cumulative_reward = 0.0

    # --- episode loop -------------------------------------------------------
    for step in range(MAX_EPISODE_STEPS):
        user_msg = build_observation_prompt(obs)

        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_msg},
        ]

        if tokenizer is None:
            # Pipeline path usually ignores chat templates.
            prompt = f"{SYSTEM_PROMPT}\n\n{user_msg}"
        else:
            # enable_thinking is Qwen-specific; older templates raise TypeError.
            try:
                prompt = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                    enable_thinking=False,
                )
            except TypeError:
                prompt = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )

        t0 = time.time()
        if active_pipeline is not None:
            response = run_with_pipeline(active_pipeline, prompt)
            if not response:
                # NOTE(review): on total pipeline failure the observation text
                # itself is used as the "response", which cannot parse as an
                # action and thus forces the fallback path -- confirm intended.
                response = format_observation(obs)
        else:
            assert tokenizer is not None and model is not None
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
            n_input = inputs["input_ids"].shape[1]
            with torch.no_grad():
                output_ids = model.generate(
                    **inputs,
                    max_new_tokens=200,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.8,
                    top_k=20,
                    repetition_penalty=1.3,
                    eos_token_id=eos_ids if eos_ids else None,
                )
            # Decode only the newly generated continuation, not the prompt.
            new_tokens = output_ids[0][n_input:]
            response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        gen_time = time.time() - t0

        action = parse_action(response)
        used_fallback = False
        if action is None:
            log(f"\n [!] Parse failed, using fallback. Raw: {response[:150]}")
            action = fallback_action(step)
            used_fallback = True

        tag = " [FALLBACK]" if used_fallback else ""
        log(f"\nStep {step + 1}: {action.action_type.value}{tag} ({gen_time:.1f}s)")
        if action.justification:
            log(f" Rationale: {action.justification}")

        obs = env.step(action)

        if obs.latest_output:
            lo = obs.latest_output
            status = "OK" if lo.success else "FAIL"
            log(f" [{status}] {lo.summary}")
            if lo.warnings:
                log(f" Warnings: {lo.warnings}")

        step_reward = obs.reward
        cumulative_reward += step_reward
        log(f" Reward: {step_reward:+.3f} (cum: {cumulative_reward:+.3f})")
        log(f" Budget: ${obs.resource_usage.budget_remaining:,.0f} | Time: {obs.resource_usage.time_remaining_days:.0f}d")

        if obs.rule_violations:
            log(f" Violations: {obs.rule_violations}")

        if obs.done:
            break

    # --- episode summary ----------------------------------------------------
    log(f"\n{'=' * 70}")
    log("EPISODE COMPLETE" if obs.done else f"MAX STEPS ({MAX_EPISODE_STEPS})")
    log(f" Steps: {obs.step_index}")
    log(f" Total reward: {cumulative_reward:+.3f}")
    log(f" Budget used: ${obs.resource_usage.budget_used:,.0f}")
    log(f" Time used: {obs.resource_usage.time_used_days:.0f} days")
    if obs.conclusions:
        log(" Conclusions:")
        for c in obs.conclusions:
            log(f" [{c.claim_type}, conf={c.confidence:.2f}] {c.claim}")
    log("=" * 70)
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
if __name__ == "__main__":
|
| 292 |
+
main()
|
server/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .hackathon_environment import BioExperimentEnvironment
|
| 2 |
+
|
| 3 |
+
__all__ = ["BioExperimentEnvironment"]
|
server/app.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI application for the Bio-Experiment Planning Environment.
|
| 2 |
+
|
| 3 |
+
Endpoints:
|
| 4 |
+
- POST /reset: Reset the environment
|
| 5 |
+
- POST /step: Execute an action
|
| 6 |
+
- GET /state: Get current environment state
|
| 7 |
+
- GET /schema: Get action/observation schemas
|
| 8 |
+
- WS /ws: WebSocket endpoint for persistent sessions
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
from openenv.core.env_server.http_server import create_app
|
| 13 |
+
except Exception as e: # pragma: no cover
|
| 14 |
+
raise ImportError(
|
| 15 |
+
"openenv is required for the web interface. "
|
| 16 |
+
"Install dependencies with 'uv sync'"
|
| 17 |
+
) from e
|
| 18 |
+
|
| 19 |
+
from models import ExperimentAction, ExperimentObservation
|
| 20 |
+
from .hackathon_environment import BioExperimentEnvironment
|
| 21 |
+
|
| 22 |
+
app = create_app(
|
| 23 |
+
BioExperimentEnvironment,
|
| 24 |
+
ExperimentAction,
|
| 25 |
+
ExperimentObservation,
|
| 26 |
+
env_name="bio_experiment",
|
| 27 |
+
max_concurrent_envs=1,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def main(host: str = "0.0.0.0", port: int = 8000):
|
| 32 |
+
import uvicorn
|
| 33 |
+
uvicorn.run(app, host=host, port=port)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
if __name__ == "__main__":
|
| 37 |
+
import argparse
|
| 38 |
+
parser = argparse.ArgumentParser()
|
| 39 |
+
parser.add_argument("--host", default="0.0.0.0")
|
| 40 |
+
parser.add_argument("--port", type=int, default=8000)
|
| 41 |
+
args = parser.parse_args()
|
| 42 |
+
if args.host == "0.0.0.0" and args.port == 8000:
|
| 43 |
+
main()
|
| 44 |
+
else:
|
| 45 |
+
main(host=args.host, port=args.port)
|
server/hackathon_environment.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Bio-Experiment Planning Environment.
|
| 2 |
+
|
| 3 |
+
Implements the OpenEnv ``Environment`` interface as a POMDP where the
|
| 4 |
+
agent proposes one structured experiment / analysis step at a time and
|
| 5 |
+
receives simulated intermediate outputs from a latent biological world.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from typing import Any, Dict, List, Optional
|
| 11 |
+
from uuid import uuid4
|
| 12 |
+
|
| 13 |
+
from openenv.core.env_server.interfaces import Environment
|
| 14 |
+
from openenv.core.env_server.types import State
|
| 15 |
+
|
| 16 |
+
from models import (
|
| 17 |
+
ActionType,
|
| 18 |
+
ConclusionClaim,
|
| 19 |
+
ExperimentAction,
|
| 20 |
+
ExperimentObservation,
|
| 21 |
+
IntermediateOutput,
|
| 22 |
+
PipelineStepRecord,
|
| 23 |
+
ResourceUsage,
|
| 24 |
+
TaskSpec,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
from server.rules.engine import RuleEngine
|
| 28 |
+
from server.rewards.reward import RewardBreakdown, RewardComputer
|
| 29 |
+
from server.simulator.latent_state import FullLatentState
|
| 30 |
+
from server.simulator.noise import NoiseModel
|
| 31 |
+
from server.simulator.transition import ACTION_COSTS, TransitionEngine
|
| 32 |
+
from server.tasks.generator import TaskGenerator
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
MAX_STEPS = 30
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class BioExperimentEnvironment(Environment):
|
| 39 |
+
"""POMDP environment for iterative biological experiment planning.
|
| 40 |
+
|
| 41 |
+
The agent observes ``ExperimentObservation`` (partial view) while the
|
| 42 |
+
environment maintains a ``FullLatentState`` (hidden ground truth).
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
SUPPORTS_CONCURRENT_SESSIONS: bool = True
|
| 46 |
+
|
| 47 |
+
    def __init__(
        self,
        scenario_name: Optional[str] = None,
        *,
        domain_randomise: bool = True,
    ) -> None:
        """Create the environment.

        Args:
            scenario_name: Optional fixed scenario name forwarded to the task
                generator on every reset; ``None`` lets the generator choose.
            domain_randomise: Forwarded to ``TaskGenerator`` -- presumably
                varies task parameters between episodes (confirm in generator).
        """
        # OpenEnv bookkeeping: episode id + step counter.
        self._state = State(episode_id=str(uuid4()), step_count=0)
        # Hidden ground truth and visible task spec; populated in reset().
        self._latent: Optional[FullLatentState] = None
        self._task: Optional[TaskSpec] = None
        self._scenario_name = scenario_name
        # Simulator components: stochastic noise, transition dynamics,
        # biological rule checking, reward shaping, and task generation.
        self._noise = NoiseModel()
        self._engine = TransitionEngine(self._noise)
        self._rules = RuleEngine()
        self._rewards = RewardComputer()
        self._task_gen = TaskGenerator(domain_randomise=domain_randomise)

        # Per-episode accumulators, all cleared again on reset().
        self._history: List[PipelineStepRecord] = []
        self._outputs: List[IntermediateOutput] = []
        self._conclusions: List[ConclusionClaim] = []
        self._subagent_outputs: List[Dict[str, Any]] = []
        self._discovered_markers: List[str] = []
        self._candidate_mechanisms: List[str] = []
        self._cumulative_reward: float = 0.0
|
| 71 |
+
# ββ Environment interface βββββββββββββββββββββββββββββββββββββββββββ
|
| 72 |
+
|
| 73 |
+
    def reset(self) -> ExperimentObservation:
        """Start a fresh episode and return the initial observation.

        Generates a new (task, latent ground truth) pair and clears all
        per-episode accumulators.
        """
        # NOTE(review): hash(uuid4()) depends on PYTHONHASHSEED, so episode
        # seeds are not reproducible across interpreter runs -- confirm that
        # reproducibility is not required here.
        seed = hash(uuid4()) % (2**31)
        self._noise.reseed(seed)
        self._state = State(episode_id=str(uuid4()), step_count=0)

        # The agent-visible task and the hidden latent state are generated
        # together so the observations stay consistent with ground truth.
        self._task, self._latent = self._task_gen.generate(
            seed=seed,
            scenario_name=self._scenario_name,
        )
        self._latent.rng_seed = seed

        # Wipe everything accumulated during the previous episode.
        self._history.clear()
        self._outputs.clear()
        self._conclusions.clear()
        self._subagent_outputs.clear()
        self._discovered_markers.clear()
        self._candidate_mechanisms.clear()
        self._cumulative_reward = 0.0

        return self._build_observation(reward=0.0, done=False)
|
| 93 |
+
|
| 94 |
+
    def step(  # type: ignore[override]
        self, action: ExperimentAction
    ) -> ExperimentObservation:
        """Apply one agent action and return the resulting observation.

        Sequence: rule check -> latent-state transition -> step reward ->
        history bookkeeping -> optional terminal reward when the episode
        ends (engine-signalled or step cap reached).
        """
        assert self._latent is not None, "Call reset() before step()"
        assert self._task is not None

        self._state.step_count += 1
        # Snapshot the hidden state so reward shaping can compare before/after.
        prev_state = self._latent.model_copy(deep=True)

        # Constraint check happens before the transition; hard violations are
        # handed to the engine, which is expected to block the action.
        violations = self._rules.check(action, self._latent)
        hard_v = self._rules.hard_violations(violations)
        soft_v = self._rules.soft_violations(violations)

        result = self._engine.step(
            self._latent,
            action,
            hard_violations=hard_v,
            soft_violations=soft_v,
        )
        self._latent = result.next_state

        step_rb = self._rewards.step_reward(
            action, prev_state, self._latent, result.output, hard_v, soft_v,
        )

        # Record the step in the agent-visible pipeline history.
        cost_budget, cost_time = ACTION_COSTS.get(action.action_type, (0, 0))
        self._history.append(PipelineStepRecord(
            step_index=self._state.step_count,
            action_type=action.action_type,
            method=action.method,
            parameters=action.parameters,
            output_summary=result.output.summary,
            output_type=result.output.output_type,
            success=result.output.success,
            quality_score=result.output.quality_score,
            resource_cost=cost_budget,
            time_cost_days=cost_time,
        ))
        self._outputs.append(result.output)
        self._update_discoveries(action, result.output)

        # Collect conclusions for terminal scoring; non-dict claim entries
        # are silently ignored.
        if action.action_type == ActionType.SYNTHESIZE_CONCLUSION:
            raw_claims = action.parameters.get("claims", [])
            for c in raw_claims:
                if isinstance(c, dict):
                    self._conclusions.append(ConclusionClaim(**c))

        done = result.done or self._state.step_count >= MAX_STEPS

        # Terminal quality/calibration reward is only added on the last step.
        terminal_rb = RewardBreakdown()
        if done:
            terminal_rb = self._rewards.terminal_reward(
                self._latent, self._conclusions, self._task.success_criteria,
            )

        total_reward = step_rb.total + terminal_rb.total
        self._cumulative_reward += total_reward

        # Merge step and terminal components; terminal keys get a "term_"
        # prefix so they do not clobber the step-level keys.
        breakdown = step_rb.to_dict()
        breakdown.update({f"term_{k}": v for k, v in terminal_rb.to_dict().items()})

        return self._build_observation(
            reward=total_reward,
            done=done,
            latest_output=result.output,
            rule_violations=hard_v + soft_v,
            reward_breakdown=breakdown,
        )
|
| 162 |
+
|
| 163 |
+
    @property
    def state(self) -> State:
        """Public episode bookkeeping (episode id and step counter)."""
        return self._state
|
| 166 |
+
|
| 167 |
+
    def set_scenario(self, scenario_name: Optional[str]) -> None:
        """Set the scenario used on the next reset.

        Takes effect only when ``reset()`` is next called; the episode in
        progress is unaffected. ``None`` defers the choice to the task
        generator (presumably random — confirm against TaskGenerator).
        """

        self._scenario_name = scenario_name
|
| 171 |
+
|
| 172 |
+
    # -- internal helpers ------------------------------------------------
|
| 173 |
+
|
| 174 |
+
    def _build_observation(
        self,
        *,
        reward: float,
        done: bool,
        latest_output: Optional[IntermediateOutput] = None,
        rule_violations: Optional[List[str]] = None,
        reward_breakdown: Optional[Dict[str, float]] = None,
    ) -> ExperimentObservation:
        """Assemble the agent-visible observation from internal state.

        Only derived, observable quantities are exposed; the hidden
        biological ground truth in ``self._latent.biology`` never leaks
        into the observation. Internal lists are shallow-copied so callers
        cannot mutate the environment's buffers.
        """
        assert self._task is not None
        assert self._latent is not None
        res = self._latent.resources
        return ExperimentObservation(
            task=self._task,
            step_index=self._state.step_count,
            pipeline_history=list(self._history),
            available_assays=list(self._task.available_assays),
            available_tools=list(self._task.available_tools),
            resource_usage=ResourceUsage(
                budget_used=res.budget_used,
                budget_remaining=res.budget_remaining,
                time_used_days=res.time_used_days,
                time_remaining_days=res.time_remaining_days,
                samples_consumed=res.samples_consumed,
                compute_hours_used=res.compute_hours_used,
            ),
            latest_output=latest_output,
            all_outputs=list(self._outputs),
            discovered_markers=list(self._discovered_markers),
            candidate_mechanisms=list(self._candidate_mechanisms),
            uncertainty_summary=self._compute_uncertainty_summary(),
            subagent_outputs=list(self._subagent_outputs),
            conclusions=list(self._conclusions),
            rule_violations=rule_violations or [],
            step_reward_breakdown=reward_breakdown or {},
            done=done,
            reward=reward,
            metadata={
                "episode_id": self._state.episode_id,
                "step": self._state.step_count,
                "cumulative_reward": self._cumulative_reward,
            },
        )
|
| 217 |
+
|
| 218 |
+
def _compute_uncertainty_summary(self) -> Dict[str, float]:
|
| 219 |
+
if not self._outputs:
|
| 220 |
+
return {}
|
| 221 |
+
recent = self._outputs[-5:]
|
| 222 |
+
avg_unc = sum(o.uncertainty for o in recent) / len(recent)
|
| 223 |
+
avg_qual = sum(o.quality_score for o in recent) / len(recent)
|
| 224 |
+
return {"avg_uncertainty": avg_unc, "avg_quality": avg_qual}
|
| 225 |
+
|
| 226 |
+
def _update_discoveries(
|
| 227 |
+
self, action: ExperimentAction, output: IntermediateOutput
|
| 228 |
+
) -> None:
|
| 229 |
+
if action.action_type == ActionType.MARKER_SELECTION:
|
| 230 |
+
markers = output.data.get("markers", [])
|
| 231 |
+
self._discovered_markers.extend(markers)
|
| 232 |
+
if action.action_type == ActionType.REGULATORY_NETWORK_INFERENCE:
|
| 233 |
+
regs = output.data.get("top_regulators", [])
|
| 234 |
+
self._candidate_mechanisms.extend(regs)
|
| 235 |
+
if action.action_type == ActionType.PATHWAY_ENRICHMENT:
|
| 236 |
+
pathways = output.data.get("top_pathways", [])
|
| 237 |
+
self._candidate_mechanisms.extend(
|
| 238 |
+
[p["pathway"] for p in pathways if isinstance(p, dict)]
|
| 239 |
+
)
|
server/requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv[core]>=0.2.0
|
| 2 |
+
fastapi>=0.115.0
|
| 3 |
+
uvicorn>=0.24.0
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
|
server/rewards/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .reward import RewardBreakdown, RewardComputer
|
| 2 |
+
|
| 3 |
+
__all__ = ["RewardBreakdown", "RewardComputer"]
|
server/rewards/reward.py
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Decomposable reward function for the bio-experiment planning POMDP.
|
| 2 |
+
|
| 3 |
+
Reward components
|
| 4 |
+
βββββββββββββββββ
|
| 5 |
+
r_validity β biological validity of the chosen action
|
| 6 |
+
r_ordering β correct ordering of experiment steps
|
| 7 |
+
r_info_gain β information gain from the step's output
|
| 8 |
+
r_efficiency β resource efficiency (budget & time normalised)
|
| 9 |
+
r_novelty β bonus for non-redundant, non-trivial actions
|
| 10 |
+
r_penalty β penalties for violations, redundancy, waste
|
| 11 |
+
r_terminal β terminal quality & calibration against hidden truth
|
| 12 |
+
|
| 13 |
+
Potential-based shaping
|
| 14 |
+
Ο(s) β progress potential used for dense shaping signal
|
| 15 |
+
|
| 16 |
+
The final step reward is:
|
| 17 |
+
R_t = r_validity + r_ordering + r_info_gain + r_efficiency
|
| 18 |
+
+ r_novelty + r_penalty + Ξ³[Ο(s_{t+1}) β Ο(s_t)]
|
| 19 |
+
|
| 20 |
+
The terminal reward adds:
|
| 21 |
+
R_T += r_terminal
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
from dataclasses import dataclass, field
|
| 27 |
+
from typing import Any, Dict, List, Optional
|
| 28 |
+
|
| 29 |
+
from models import (
|
| 30 |
+
ActionType,
|
| 31 |
+
ConclusionClaim,
|
| 32 |
+
ExperimentAction,
|
| 33 |
+
IntermediateOutput,
|
| 34 |
+
META_ACTIONS,
|
| 35 |
+
WET_LAB_ACTIONS,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
from server.simulator.latent_state import FullLatentState
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@dataclass
class RewardBreakdown:
    """Decomposed reward for a single step (or terminal transition).

    Each scalar field is one additive reward term; ``components`` carries
    extra named diagnostics that are reported but NOT added to ``total``.
    """

    validity: float = 0.0
    ordering: float = 0.0
    info_gain: float = 0.0
    efficiency: float = 0.0
    novelty: float = 0.0
    penalty: float = 0.0
    shaping: float = 0.0
    terminal: float = 0.0
    components: Dict[str, float] = field(default_factory=dict)

    # Names of the scalar terms that contribute to the total, in report order.
    _TERMS = (
        "validity",
        "ordering",
        "info_gain",
        "efficiency",
        "novelty",
        "penalty",
        "shaping",
        "terminal",
    )

    @property
    def total(self) -> float:
        """Sum of all scalar reward terms (diagnostics excluded)."""
        return sum(getattr(self, name) for name in self._TERMS)

    def to_dict(self) -> Dict[str, float]:
        """Flatten to a name -> value mapping, including the grand total
        and any diagnostic components."""
        flat: Dict[str, float] = {name: getattr(self, name) for name in self._TERMS}
        flat["total"] = self.total
        flat.update(self.components)
        return flat
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class RewardComputer:
    """Computes step-wise and terminal rewards.

    Parameters
    ----------
    gamma : float
        Discount factor for potential-based shaping (default 0.99).
    efficiency_weight : float
        Relative importance of resource efficiency.
    info_gain_weight : float
        Relative importance of per-step information gain.
    validity_weight : float
        Relative importance of biological validity of the action.
    """

    def __init__(
        self,
        gamma: float = 0.99,
        efficiency_weight: float = 0.3,
        info_gain_weight: float = 0.4,
        validity_weight: float = 0.3,
    ):
        self.gamma = gamma
        self.w_eff = efficiency_weight
        self.w_ig = info_gain_weight
        self.w_val = validity_weight

    # -- step reward -----------------------------------------------------

    def step_reward(
        self,
        action: ExperimentAction,
        prev_state: FullLatentState,
        next_state: FullLatentState,
        output: IntermediateOutput,
        hard_violations: List[str],
        soft_violations: List[str],
    ) -> RewardBreakdown:
        """Dense per-step reward; see the module docstring for the formula.

        ``prev_state``/``next_state`` are the hidden latent states before
        and after the transition; the shaping term compares their progress
        potentials.
        """
        rb = RewardBreakdown()

        # validity: a hard violation short-circuits the whole step reward —
        # no info-gain, efficiency or shaping is credited for an invalid step.
        if hard_violations:
            rb.validity = -1.0
            rb.penalty = -0.5 * len(hard_violations)
            rb.components["hard_violations"] = len(hard_violations)
            return rb

        rb.validity = self.w_val * (1.0 if output.success else 0.0)

        # ordering bonus: +0.2 if the step was a natural next step
        rb.ordering = 0.2 * self._ordering_score(action, prev_state)

        # information gain proxy: quality x (1 - uncertainty)
        rb.info_gain = self.w_ig * output.quality_score * (1.0 - output.uncertainty)

        # efficiency: normalised cost relative to budget; the factor 5 means
        # spending >=20% of the total budget in one step zeroes the bonus
        budget_frac = (
            (next_state.resources.budget_used - prev_state.resources.budget_used)
            / max(next_state.resources.budget_total, 1)
        )
        rb.efficiency = self.w_eff * max(0.0, 1.0 - 5.0 * budget_frac)

        # novelty: small bonus for non-redundant steps
        if not soft_violations:
            rb.novelty = 0.1

        # penalties (soft violations only; hard ones returned above)
        rb.penalty = -0.15 * len(soft_violations)

        # potential-based shaping: gamma * phi(s') - phi(s)
        phi_prev = self._potential(prev_state)
        phi_next = self._potential(next_state)
        rb.shaping = self.gamma * phi_next - phi_prev

        return rb

    # -- terminal reward -------------------------------------------------

    def terminal_reward(
        self,
        state: FullLatentState,
        conclusions: List[ConclusionClaim],
        task_success_criteria: List[str],
    ) -> RewardBreakdown:
        """Episode-end reward scoring completeness, calibration and thrift.

        NOTE(review): ``task_success_criteria`` is currently unused —
        scoring relies solely on the hidden ground truth; confirm intended.
        """
        rb = RewardBreakdown()

        # pipeline completeness (0-1)
        completeness = self._completeness(state)
        rb.components["completeness"] = completeness

        # calibration: how well conclusions align with hidden ground truth
        calibration = self._calibration(state, conclusions)
        rb.components["calibration"] = calibration

        # efficiency bonus at terminal (fraction of each budget left over)
        budget_eff = state.resources.budget_remaining / max(
            state.resources.budget_total, 1
        )
        time_eff = state.resources.time_remaining_days / max(
            state.resources.time_limit_days, 1
        )
        rb.components["budget_efficiency"] = budget_eff
        rb.components["time_efficiency"] = time_eff

        # over-confidence penalty (negative or zero)
        overconf = self._overconfidence_penalty(state, conclusions)
        rb.components["overconfidence_penalty"] = overconf

        # Weighting: calibration dominates, then completeness, then thrift.
        rb.terminal = (
            3.0 * completeness
            + 4.0 * calibration
            + 1.0 * (budget_eff + time_eff) / 2.0
            + overconf
        )
        return rb

    # -- helpers ---------------------------------------------------------

    def _ordering_score(
        self, action: ExperimentAction, s: FullLatentState
    ) -> float:
        """Heuristic: 1.0 if this step naturally follows the current progress."""
        at = action.action_type
        p = s.progress
        # action type -> "is this the natural next pipeline step right now?"
        NATURAL_NEXT = {
            ActionType.COLLECT_SAMPLE: not p.samples_collected,
            ActionType.PREPARE_LIBRARY: p.samples_collected and not p.library_prepared,
            ActionType.SEQUENCE_CELLS: p.library_prepared and not p.cells_sequenced,
            ActionType.RUN_QC: p.cells_sequenced and not p.qc_performed,
            ActionType.FILTER_DATA: p.qc_performed and not p.data_filtered,
            ActionType.NORMALIZE_DATA: p.data_filtered and not p.data_normalized,
            ActionType.CLUSTER_CELLS: p.data_normalized and not p.cells_clustered,
            ActionType.DIFFERENTIAL_EXPRESSION: p.data_normalized and not p.de_performed,
            ActionType.PATHWAY_ENRICHMENT: p.de_performed and not p.pathways_analyzed,
            ActionType.MARKER_SELECTION: p.de_performed and not p.markers_discovered,
            ActionType.VALIDATE_MARKER: p.markers_discovered and not p.markers_validated,
            ActionType.SYNTHESIZE_CONCLUSION: (
                p.de_performed or p.cells_clustered
            ) and not p.conclusion_reached,
        }
        # Legal-but-out-of-order steps still earn a reduced 0.3 score.
        return 1.0 if NATURAL_NEXT.get(at, False) else 0.3

    def _potential(self, s: FullLatentState) -> float:
        """Progress potential phi(s) — fraction of completed milestones."""
        p = s.progress
        milestones = [
            p.samples_collected,
            p.library_prepared,
            p.cells_sequenced,
            p.qc_performed,
            p.data_filtered,
            p.data_normalized,
            p.cells_clustered,
            p.de_performed,
            p.pathways_analyzed,
            p.markers_discovered,
            p.markers_validated,
            p.conclusion_reached,
        ]
        return sum(milestones) / len(milestones)

    def _completeness(self, s: FullLatentState) -> float:
        """Fraction of core pipeline stages completed at episode end."""
        p = s.progress
        core = [
            p.samples_collected,
            p.cells_sequenced,
            p.qc_performed,
            p.data_filtered,
            p.data_normalized,
            p.de_performed or p.cells_clustered,
            p.conclusion_reached,
        ]
        return sum(core) / len(core)

    def _calibration(
        self, s: FullLatentState, conclusions: List[ConclusionClaim]
    ) -> float:
        """Score conclusions against hidden truth via case-insensitive
        substring matching; each hit is +1, each miss -0.3, and the mean
        is clamped into [0, 1]. Returns 0.0 when no conclusions exist.
        """
        if not conclusions:
            return 0.0

        true_mechanisms = set(s.biology.causal_mechanisms)
        true_markers = set(s.biology.true_markers)
        score = 0.0
        n = len(conclusions)

        for c in conclusions:
            claim_lower = c.claim.lower()
            match = any(m.lower() in claim_lower for m in true_mechanisms)
            marker_match = any(m.lower() in claim_lower for m in true_markers)
            if match or marker_match:
                score += 1.0
            else:
                score -= 0.3
        return max(0.0, min(1.0, score / max(n, 1)))

    def _overconfidence_penalty(
        self, s: FullLatentState, conclusions: List[ConclusionClaim]
    ) -> float:
        """Penalise high-confidence claims that disagree with ground truth."""
        penalty = 0.0
        true_set = set(
            m.lower() for m in s.biology.causal_mechanisms + s.biology.true_markers
        )
        for c in conclusions:
            is_correct = any(t in c.claim.lower() for t in true_set)
            # Only confident (>0.8) wrong claims are penalised, scaled by
            # how confident the claim was.
            if c.confidence > 0.8 and not is_correct:
                penalty -= 0.5 * c.confidence
        return penalty
|
server/rules/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .engine import RuleEngine, RuleViolation
|
| 2 |
+
|
| 3 |
+
__all__ = ["RuleEngine", "RuleViolation"]
|
server/rules/engine.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Biological rule engine β hard and soft constraint checking.
|
| 2 |
+
|
| 3 |
+
Hard constraints block action execution entirely.
|
| 4 |
+
Soft constraints allow execution but degrade output quality and incur penalties.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
from enum import Enum
|
| 11 |
+
from typing import List
|
| 12 |
+
|
| 13 |
+
from models import ActionType, ExperimentAction
|
| 14 |
+
|
| 15 |
+
from server.simulator.latent_state import FullLatentState
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class Severity(str, Enum):
    """Severity of a rule violation.

    HARD violations block the action entirely; SOFT violations allow it
    but degrade output quality and incur penalties (see module docstring).
    """

    HARD = "hard"
    SOFT = "soft"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass
class RuleViolation:
    """A single constraint breach detected by the RuleEngine."""

    rule_id: str        # stable machine-readable identifier of the rule
    severity: Severity  # HARD blocks execution, SOFT only penalises
    message: str        # human-readable explanation surfaced to the agent
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class RuleEngine:
    """Evaluates biological and resource constraints against the current
    latent state before each action is applied.

    ``check`` gathers violations from all rule families; callers then
    split them with ``hard_violations`` / ``soft_violations``.
    """

    def check(
        self, action: ExperimentAction, state: FullLatentState
    ) -> List[RuleViolation]:
        """Run every rule family and return the combined violation list."""
        violations: List[RuleViolation] = []
        violations.extend(self._check_prerequisites(action, state))
        violations.extend(self._check_resource_constraints(action, state))
        violations.extend(self._check_redundancy(action, state))
        violations.extend(self._check_causal_validity(action, state))
        return violations

    def hard_violations(self, violations: List[RuleViolation]) -> List[str]:
        """Messages of violations that must block the action."""
        return [v.message for v in violations if v.severity == Severity.HARD]

    def soft_violations(self, violations: List[RuleViolation]) -> List[str]:
        """Messages of violations that only penalise the action."""
        return [v.message for v in violations if v.severity == Severity.SOFT]

    # -- prerequisite rules ----------------------------------------------

    def _check_prerequisites(
        self, action: ExperimentAction, s: FullLatentState
    ) -> List[RuleViolation]:
        """HARD: pipeline-stage dependencies (e.g. sequencing needs a library)."""
        vs: List[RuleViolation] = []
        at = action.action_type
        p = s.progress

        # action type -> list of (required progress flag, failure message)
        REQUIRES = {
            ActionType.PREPARE_LIBRARY: [
                ("samples_collected", "Cannot prepare library without collected samples"),
            ],
            ActionType.SEQUENCE_CELLS: [
                ("library_prepared", "Cannot sequence without library preparation"),
            ],
            ActionType.RUN_QC: [
                ("cells_sequenced", "Cannot run QC before sequencing"),
            ],
            ActionType.FILTER_DATA: [
                ("qc_performed", "Cannot filter data before QC"),
            ],
            ActionType.NORMALIZE_DATA: [
                ("data_filtered", "Cannot normalise before filtering"),
            ],
            ActionType.INTEGRATE_BATCHES: [
                ("data_normalized", "Cannot integrate batches before normalisation"),
            ],
            ActionType.CLUSTER_CELLS: [
                ("data_normalized", "Cannot cluster before normalisation"),
            ],
            ActionType.DIFFERENTIAL_EXPRESSION: [
                ("data_normalized", "Cannot run DE before normalisation"),
            ],
            ActionType.TRAJECTORY_ANALYSIS: [
                ("data_normalized", "Cannot infer trajectories before normalisation"),
            ],
            ActionType.PATHWAY_ENRICHMENT: [
                ("de_performed", "Cannot run pathway enrichment without DE results"),
            ],
            ActionType.REGULATORY_NETWORK_INFERENCE: [
                ("data_normalized", "Cannot infer networks before normalisation"),
            ],
            ActionType.MARKER_SELECTION: [
                ("de_performed", "Cannot select markers without DE results"),
            ],
            ActionType.VALIDATE_MARKER: [
                ("markers_discovered", "Cannot validate markers before discovery"),
            ],
            ActionType.PERTURB_GENE: [
                ("samples_collected", "Cannot perturb without samples"),
            ],
            ActionType.PERTURB_COMPOUND: [
                ("samples_collected", "Cannot perturb without samples"),
            ],
            ActionType.CULTURE_CELLS: [
                ("samples_collected", "Cannot culture without samples"),
            ],
        }

        for flag, msg in REQUIRES.get(at, []):
            # getattr with a default keeps this robust if a flag is renamed.
            if not getattr(p, flag, False):
                vs.append(RuleViolation(
                    rule_id=f"prereq_{at.value}_{flag}",
                    severity=Severity.HARD,
                    message=msg,
                ))
        return vs

    # -- resource constraints --------------------------------------------

    def _check_resource_constraints(
        self, action: ExperimentAction, s: FullLatentState
    ) -> List[RuleViolation]:
        """HARD when a budget is fully spent; SOFT when this action overdraws."""
        vs: List[RuleViolation] = []
        if s.resources.budget_exhausted:
            vs.append(RuleViolation(
                rule_id="budget_exhausted",
                severity=Severity.HARD,
                message="Budget exhausted — no further actions possible",
            ))
        if s.resources.time_exhausted:
            vs.append(RuleViolation(
                rule_id="time_exhausted",
                severity=Severity.HARD,
                message="Time limit reached — no further actions possible",
            ))

        remaining = s.resources.budget_remaining
        # Local import: presumably avoids a circular dependency between the
        # rules and simulator packages — confirm.
        from server.simulator.transition import ACTION_COSTS
        cost, _ = ACTION_COSTS.get(action.action_type, (0, 0))
        # Soft warning only when the action would overdraw a non-empty budget
        # (a fully-empty budget is already a hard violation above).
        if cost > remaining and remaining > 0:
            vs.append(RuleViolation(
                rule_id="budget_insufficient",
                severity=Severity.SOFT,
                message=f"Action costs ${cost:,.0f} but only ${remaining:,.0f} remains",
            ))
        return vs

    # -- redundancy checks -----------------------------------------------

    def _check_redundancy(
        self, action: ExperimentAction, s: FullLatentState
    ) -> List[RuleViolation]:
        """SOFT: repeating an already-completed pipeline stage."""
        vs: List[RuleViolation] = []
        at = action.action_type
        p = s.progress

        # action type -> progress flag that marks the stage as done
        REDUNDANT = {
            ActionType.COLLECT_SAMPLE: "samples_collected",
            ActionType.PREPARE_LIBRARY: "library_prepared",
            ActionType.SEQUENCE_CELLS: "cells_sequenced",
            ActionType.RUN_QC: "qc_performed",
            ActionType.FILTER_DATA: "data_filtered",
            ActionType.NORMALIZE_DATA: "data_normalized",
        }
        flag = REDUNDANT.get(at)
        if flag and getattr(p, flag, False):
            vs.append(RuleViolation(
                rule_id=f"redundant_{at.value}",
                severity=Severity.SOFT,
                message=f"Step '{at.value}' already completed — redundant action",
            ))
        return vs

    # -- causal validity -------------------------------------------------

    def _check_causal_validity(
        self, action: ExperimentAction, s: FullLatentState
    ) -> List[RuleViolation]:
        """SOFT: conclusions/claims unsupported by the analyses run so far."""
        vs: List[RuleViolation] = []
        if action.action_type == ActionType.SYNTHESIZE_CONCLUSION:
            if not s.progress.de_performed and not s.progress.cells_clustered:
                vs.append(RuleViolation(
                    rule_id="premature_conclusion",
                    severity=Severity.SOFT,
                    message="Synthesising conclusion without substantive analysis",
                ))

            claims = action.parameters.get("claims", [])
            for claim in claims:
                # Only the first causal claim is inspected; one violation
                # suffices regardless of how many causal claims were made.
                if isinstance(claim, dict) and claim.get("claim_type") == "causal":
                    if not s.progress.markers_validated and not s.progress.networks_inferred:
                        vs.append(RuleViolation(
                            rule_id="unsupported_causal_claim",
                            severity=Severity.SOFT,
                            message="Causal claim without validation or network evidence",
                        ))
                    break

        if action.action_type == ActionType.PATHWAY_ENRICHMENT:
            # NOTE(review): this case is also a HARD prerequisite above, so
            # this soft warning appears to be unreachable in practice.
            if not s.progress.de_performed:
                vs.append(RuleViolation(
                    rule_id="pathway_without_de",
                    severity=Severity.SOFT,
                    message="Pathway enrichment without DE may yield unreliable results",
                ))
        return vs
|
server/simulator/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .latent_state import (
|
| 2 |
+
CellPopulation,
|
| 3 |
+
ExperimentProgress,
|
| 4 |
+
FullLatentState,
|
| 5 |
+
GeneProgram,
|
| 6 |
+
LatentBiologicalState,
|
| 7 |
+
ResourceState,
|
| 8 |
+
TechnicalState,
|
| 9 |
+
)
|
| 10 |
+
from .noise import NoiseModel
|
| 11 |
+
from .output_generator import OutputGenerator
|
| 12 |
+
from .transition import TransitionEngine
|
| 13 |
+
|
| 14 |
+
__all__ = [
|
| 15 |
+
"CellPopulation",
|
| 16 |
+
"ExperimentProgress",
|
| 17 |
+
"FullLatentState",
|
| 18 |
+
"GeneProgram",
|
| 19 |
+
"LatentBiologicalState",
|
| 20 |
+
"NoiseModel",
|
| 21 |
+
"OutputGenerator",
|
| 22 |
+
"ResourceState",
|
| 23 |
+
"TechnicalState",
|
| 24 |
+
"TransitionEngine",
|
| 25 |
+
]
|
server/simulator/latent_state.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Latent biological and technical state β hidden from the agent."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any, Dict, List, Optional
|
| 6 |
+
|
| 7 |
+
from pydantic import BaseModel, Field
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class CellPopulation(BaseModel):
    """Ground-truth cell sub-population in the simulated tissue."""

    # Human-readable population label (e.g. a cell type).
    name: str
    # Fraction of all cells belonging to this population.
    proportion: float = Field(ge=0.0, le=1.0)
    # Genes whose expression distinguishes this population.
    marker_genes: List[str] = Field(default_factory=list)
    # Biological state label; defaults to "quiescent".
    state: str = "quiescent"
    # condition name -> response magnitude (exact semantics defined by the
    # simulator — presumably an effect size; confirm).
    condition_response: Dict[str, float] = Field(default_factory=dict)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class GeneProgram(BaseModel):
    """A latent gene-regulatory programme."""

    # Programme identifier.
    name: str
    # Member genes co-regulated by this programme.
    genes: List[str] = Field(default_factory=list)
    # Current activation in [0, 1]; 0.5 is the neutral default.
    activity_level: float = Field(0.5, ge=0.0, le=1.0)
    # When True, the programme is active only under ``conditions_active``.
    condition_dependent: bool = False
    # Conditions under which the programme is active (used when
    # ``condition_dependent`` is True).
    conditions_active: List[str] = Field(default_factory=list)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class LatentBiologicalState(BaseModel):
    """Hidden ground-truth biology the agent cannot directly observe."""

    # True sub-populations present in the simulated tissue.
    cell_populations: List[CellPopulation] = Field(default_factory=list)
    # Ground-truth differentially expressed genes per comparison.
    true_de_genes: Dict[str, Dict[str, float]] = Field(
        default_factory=dict,
        description="comparison_key β {gene: log2FC}",
    )
    # Ground-truth pathway activity levels.
    true_pathways: Dict[str, float] = Field(
        default_factory=dict,
        description="pathway β activity level",
    )
    # Latent regulatory programmes active in this tissue.
    gene_programs: List[GeneProgram] = Field(default_factory=list)
    # Optional ground-truth differentiation trajectory description
    # (free-form dict; None when the scenario has no trajectory).
    true_trajectory: Optional[Dict[str, Any]] = None
    # Ground-truth regulatory network edges.
    true_regulatory_network: Dict[str, List[str]] = Field(
        default_factory=dict,
        description="TF β target genes",
    )
    # Ground-truth per-gene effects of each perturbation.
    perturbation_effects: Dict[str, Dict[str, float]] = Field(
        default_factory=dict,
        description="perturbation β {gene: effect_size}",
    )
    # Confounding variables and their strengths.
    confounders: Dict[str, float] = Field(default_factory=dict)
    # The set of genuinely valid markers for this scenario.
    true_markers: List[str] = Field(default_factory=list)
    # Names of the causal mechanisms underlying the scenario.
    causal_mechanisms: List[str] = Field(default_factory=list)
    # Number of cells that actually exist in the simulated tissue.
    n_true_cells: int = 10_000
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class TechnicalState(BaseModel):
    """Hidden technical parameters that shape experimental noise."""

    # Per-batch effect magnitudes, keyed by batch identifier.
    batch_effects: Dict[str, float] = Field(default_factory=dict)
    # Fraction of reads coming from ambient (cell-free) RNA.
    ambient_rna_fraction: float = 0.05
    # Expected fraction of droplets containing two cells.
    doublet_rate: float = 0.04
    # Probability a truly expressed gene reads as zero.
    dropout_rate: float = 0.1
    # Overall input-sample quality, validated to [0, 1].
    sample_quality: float = Field(0.9, ge=0.0, le=1.0)
    # Library diversity, validated to [0, 1].
    library_complexity: float = Field(0.8, ge=0.0, le=1.0)
    # Multiplier applied to nominal sequencing depth (1.0 = nominal).
    sequencing_depth_factor: float = 1.0
    # Fraction of cells successfully captured by the platform.
    capture_efficiency: float = 0.6
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class ExperimentProgress(BaseModel):
    """Flags tracking which experiment stages have been completed."""

    # Wet-lab stages.
    samples_collected: bool = False
    cohort_selected: bool = False
    cells_cultured: bool = False
    library_prepared: bool = False
    perturbation_applied: bool = False
    cells_sequenced: bool = False
    # Preprocessing stages.
    qc_performed: bool = False
    data_filtered: bool = False
    data_normalized: bool = False
    batches_integrated: bool = False
    # Analysis stages.
    cells_clustered: bool = False
    de_performed: bool = False
    trajectories_inferred: bool = False
    pathways_analyzed: bool = False
    networks_inferred: bool = False
    # Discovery / wrap-up stages.
    markers_discovered: bool = False
    markers_validated: bool = False
    conclusion_reached: bool = False

    # Headline counts recorded as stages complete; None until the
    # corresponding stage has run.
    n_cells_after_filter: Optional[int] = None
    n_clusters_found: Optional[int] = None
    n_de_genes_found: Optional[int] = None
    n_markers_found: Optional[int] = None
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class ResourceState(BaseModel):
    """Full internal resource tracking (superset of agent-visible ResourceUsage)."""

    # Monetary budget (arbitrary currency units).
    budget_total: float = 100_000.0
    budget_used: float = 0.0
    # Wall-clock project time in days.
    time_limit_days: float = 180.0
    time_used_days: float = 0.0
    # Physical sample inventory.
    samples_available: int = 0
    samples_consumed: int = 0
    # Consumables / compute accounting.
    compute_hours_used: float = 0.0
    sequencing_lanes_used: int = 0
    reagent_kits_used: int = 0

    @property
    def budget_remaining(self) -> float:
        """Budget left, floored at zero (never reports overdraft)."""
        return max(0.0, self.budget_total - self.budget_used)

    @property
    def time_remaining_days(self) -> float:
        """Days left, floored at zero."""
        return max(0.0, self.time_limit_days - self.time_used_days)

    @property
    def budget_exhausted(self) -> bool:
        """True once the remaining budget hits zero."""
        return self.budget_remaining <= 0

    @property
    def time_exhausted(self) -> bool:
        """True once the remaining time hits zero."""
        return self.time_remaining_days <= 0
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class FullLatentState(BaseModel):
    """Complete hidden state of the simulated biological world."""

    # Ground-truth biology (never shown to the agent directly).
    biology: LatentBiologicalState = Field(
        default_factory=LatentBiologicalState
    )
    # Hidden technical noise parameters.
    technical: TechnicalState = Field(default_factory=TechnicalState)
    # Which pipeline stages have completed so far.
    progress: ExperimentProgress = Field(default_factory=ExperimentProgress)
    # Budget / time / consumable accounting.
    resources: ResourceState = Field(default_factory=ResourceState)
    # Latent failure modes armed for this episode.
    hidden_failure_conditions: List[str] = Field(default_factory=list)
    # Confidence accumulated per causal mechanism.
    mechanism_confidence: Dict[str, float] = Field(default_factory=dict)
    # DE genes the agent has surfaced so far.
    discovered_de_genes: List[str] = Field(default_factory=list)
    # Cluster labels the agent has surfaced so far.
    discovered_clusters: List[str] = Field(default_factory=list)
    # Number of environment steps taken this episode.
    step_count: int = 0
    # Seed used to drive all simulator randomness.
    rng_seed: int = 42
|
server/simulator/noise.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Stochastic noise models for the biological simulator."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Dict, List, Tuple
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class NoiseModel:
    """Calibrated stochastic noise for simulated experimental outputs.

    Every random draw is routed through a single ``numpy.Generator``
    instance, which makes an entire episode reproducible given a seed.
    """

    def __init__(self, seed: int = 42):
        self.rng = np.random.default_rng(seed)

    def reseed(self, seed: int) -> None:
        """Swap in a freshly seeded generator."""
        self.rng = np.random.default_rng(seed)

    # -- expression-level noise ------------------------------------------

    def add_expression_noise(
        self,
        true_values: Dict[str, float],
        noise_level: float,
        dropout_rate: float,
    ) -> Dict[str, float]:
        """Corrupt true expression values with dropout and Gaussian jitter.

        One uniform draw per gene decides dropout; surviving genes get
        additive noise whose spread scales with their magnitude.
        """
        corrupted: Dict[str, float] = {}
        for gene, true_val in true_values.items():
            if self.rng.random() < dropout_rate:
                corrupted[gene] = 0.0
                continue
            spread = 0.1 + noise_level * abs(true_val)
            corrupted[gene] = float(true_val + self.rng.normal(0, spread))
        return corrupted

    # -- effect-size sampling --------------------------------------------

    def sample_effect_sizes(
        self,
        true_effects: Dict[str, float],
        sample_size: int,
        noise_level: float,
    ) -> Dict[str, float]:
        """Observed effect = true effect + N(0, se); se shrinks with n."""
        std_err = noise_level / max(np.sqrt(max(sample_size, 1)), 1e-6)
        observed: Dict[str, float] = {}
        for gene, true_fc in true_effects.items():
            observed[gene] = float(true_fc + self.rng.normal(0, std_err))
        return observed

    def sample_p_values(
        self,
        true_effects: Dict[str, float],
        sample_size: int,
        noise_level: float,
    ) -> Dict[str, float]:
        """Simulate approximate two-sided p-values from z-statistics."""
        from scipy import stats  # type: ignore[import-untyped]

        std_err = noise_level / max(np.sqrt(max(sample_size, 1)), 1e-6)
        result: Dict[str, float] = {}
        for gene, true_fc in true_effects.items():
            z_stat = abs(true_fc) / max(std_err, 1e-8)
            result[gene] = float(2 * stats.norm.sf(z_stat))
        return result

    # -- false discovery helpers -----------------------------------------

    def generate_false_positives(
        self, n_background_genes: int, fdr: float
    ) -> List[str]:
        """Draw Binomial(n, fdr) spurious hits with synthetic names."""
        how_many = int(self.rng.binomial(n_background_genes, fdr))
        return [f"FP_GENE_{i}" for i in range(how_many)]

    def generate_false_negatives(
        self, true_genes: List[str], fnr: float
    ) -> List[str]:
        """Return the subset of *true_genes* the assay misses (one draw each)."""
        missed: List[str] = []
        for gene in true_genes:
            if self.rng.random() < fnr:
                missed.append(gene)
        return missed

    # -- quality helpers -------------------------------------------------

    def quality_degradation(
        self, base_quality: float, factors: List[float]
    ) -> float:
        """Multiply quality by every factor, jitter, then clip to [0, 1]."""
        degraded = base_quality
        for factor in factors:
            degraded = degraded * factor
        return float(np.clip(degraded + self.rng.normal(0, 0.02), 0.0, 1.0))

    def sample_qc_metric(
        self, mean: float, std: float, clip_lo: float = 0.0, clip_hi: float = 1.0
    ) -> float:
        """One clipped Gaussian draw around *mean*."""
        draw = self.rng.normal(mean, std)
        return float(np.clip(draw, clip_lo, clip_hi))

    def sample_count(self, lam: float) -> int:
        """Poisson draw with the rate floored at zero."""
        return int(self.rng.poisson(max(lam, 0)))

    def coin_flip(self, p: float) -> bool:
        """Bernoulli(p) via one uniform draw."""
        return bool(self.rng.random() < p)

    def sample_cluster_count(
        self, n_true_populations: int, quality: float
    ) -> int:
        """Mimic over-/under-clustering driven by preprocessing quality."""
        jitter = self.rng.integers(-2, 3)
        spurious = max(0, int(round((1.0 - quality) * 3)))
        return max(1, n_true_populations + jitter + spurious)

    def shuffle_ranking(
        self, items: List[str], noise_level: float
    ) -> List[str]:
        """Perturb a ranking by adding Gaussian noise to ordinal scores."""
        count = len(items)
        if count == 0:
            return []
        jittered = np.arange(count, dtype=float) + self.rng.normal(
            0, noise_level * count, size=count
        )
        return [items[int(pos)] for pos in np.argsort(jittered)]
|
server/simulator/output_generator.py
ADDED
|
@@ -0,0 +1,495 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Generate simulated intermediate outputs conditioned on latent state."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any, Dict, List
|
| 6 |
+
|
| 7 |
+
from models import (
|
| 8 |
+
ActionType,
|
| 9 |
+
ExperimentAction,
|
| 10 |
+
IntermediateOutput,
|
| 11 |
+
OutputType,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
from .latent_state import FullLatentState
|
| 15 |
+
from .noise import NoiseModel
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class OutputGenerator:
|
| 19 |
+
"""Creates structured ``IntermediateOutput`` objects conditioned on the
|
| 20 |
+
hidden latent state, the action taken, and a stochastic noise model.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def __init__(self, noise: NoiseModel):
|
| 24 |
+
self.noise = noise
|
| 25 |
+
|
| 26 |
+
def generate(
|
| 27 |
+
self,
|
| 28 |
+
action: ExperimentAction,
|
| 29 |
+
state: FullLatentState,
|
| 30 |
+
step_index: int,
|
| 31 |
+
) -> IntermediateOutput:
|
| 32 |
+
handler = _HANDLERS.get(action.action_type, self._default)
|
| 33 |
+
return handler(self, action, state, step_index)
|
| 34 |
+
|
| 35 |
+
# ββ wet-lab outputs βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 36 |
+
|
| 37 |
+
def _collect_sample(
|
| 38 |
+
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 39 |
+
) -> IntermediateOutput:
|
| 40 |
+
n_samples = action.parameters.get("n_samples", 6)
|
| 41 |
+
quality = self.noise.quality_degradation(
|
| 42 |
+
s.technical.sample_quality, [s.technical.capture_efficiency]
|
| 43 |
+
)
|
| 44 |
+
return IntermediateOutput(
|
| 45 |
+
output_type=OutputType.SAMPLE_COLLECTION_RESULT,
|
| 46 |
+
step_index=idx,
|
| 47 |
+
quality_score=quality,
|
| 48 |
+
summary=f"Collected {n_samples} samples (quality={quality:.2f})",
|
| 49 |
+
data={
|
| 50 |
+
"n_samples": n_samples,
|
| 51 |
+
"quality": quality,
|
| 52 |
+
"organism": "human",
|
| 53 |
+
"tissue": "blood",
|
| 54 |
+
},
|
| 55 |
+
artifacts_available=["raw_samples"],
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
def _select_cohort(
|
| 59 |
+
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 60 |
+
) -> IntermediateOutput:
|
| 61 |
+
criteria = action.parameters.get("criteria", {})
|
| 62 |
+
n_selected = action.parameters.get("n_selected", 4)
|
| 63 |
+
return IntermediateOutput(
|
| 64 |
+
output_type=OutputType.COHORT_RESULT,
|
| 65 |
+
step_index=idx,
|
| 66 |
+
summary=f"Selected cohort of {n_selected} samples with criteria {criteria}",
|
| 67 |
+
data={"n_selected": n_selected, "criteria": criteria},
|
| 68 |
+
artifacts_available=["cohort_manifest"],
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
def _prepare_library(
|
| 72 |
+
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 73 |
+
) -> IntermediateOutput:
|
| 74 |
+
complexity = self.noise.quality_degradation(
|
| 75 |
+
s.technical.library_complexity,
|
| 76 |
+
[s.technical.sample_quality],
|
| 77 |
+
)
|
| 78 |
+
return IntermediateOutput(
|
| 79 |
+
output_type=OutputType.LIBRARY_PREP_RESULT,
|
| 80 |
+
step_index=idx,
|
| 81 |
+
quality_score=complexity,
|
| 82 |
+
summary=f"Library prepared (complexity={complexity:.2f})",
|
| 83 |
+
data={
|
| 84 |
+
"library_complexity": complexity,
|
| 85 |
+
"method": action.method or "10x_chromium",
|
| 86 |
+
},
|
| 87 |
+
artifacts_available=["prepared_library"],
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
def _culture_cells(
|
| 91 |
+
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 92 |
+
) -> IntermediateOutput:
|
| 93 |
+
days = action.parameters.get("days", 7)
|
| 94 |
+
viability = self.noise.sample_qc_metric(0.92, 0.05, 0.5, 1.0)
|
| 95 |
+
return IntermediateOutput(
|
| 96 |
+
output_type=OutputType.CULTURE_RESULT,
|
| 97 |
+
step_index=idx,
|
| 98 |
+
quality_score=viability,
|
| 99 |
+
summary=f"Cultured for {days}d, viability={viability:.2f}",
|
| 100 |
+
data={"days": days, "viability": viability},
|
| 101 |
+
artifacts_available=["cultured_cells"],
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
def _perturb(
|
| 105 |
+
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 106 |
+
) -> IntermediateOutput:
|
| 107 |
+
target = action.parameters.get("target", "unknown")
|
| 108 |
+
efficiency = self.noise.sample_qc_metric(0.75, 0.15, 0.0, 1.0)
|
| 109 |
+
return IntermediateOutput(
|
| 110 |
+
output_type=OutputType.PERTURBATION_RESULT,
|
| 111 |
+
step_index=idx,
|
| 112 |
+
quality_score=efficiency,
|
| 113 |
+
summary=f"Perturbation of {target} (efficiency={efficiency:.2f})",
|
| 114 |
+
data={
|
| 115 |
+
"target": target,
|
| 116 |
+
"efficiency": efficiency,
|
| 117 |
+
"type": action.action_type.value,
|
| 118 |
+
},
|
| 119 |
+
artifacts_available=["perturbed_cells"],
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
def _sequence_cells(
|
| 123 |
+
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 124 |
+
) -> IntermediateOutput:
|
| 125 |
+
depth = s.technical.sequencing_depth_factor
|
| 126 |
+
n_cells = self.noise.sample_count(
|
| 127 |
+
s.biology.n_true_cells * s.technical.capture_efficiency
|
| 128 |
+
)
|
| 129 |
+
n_genes = self.noise.sample_count(18_000)
|
| 130 |
+
median_umi = self.noise.sample_count(int(3000 * depth))
|
| 131 |
+
quality = self.noise.quality_degradation(
|
| 132 |
+
s.technical.sample_quality,
|
| 133 |
+
[s.technical.library_complexity, s.technical.capture_efficiency],
|
| 134 |
+
)
|
| 135 |
+
return IntermediateOutput(
|
| 136 |
+
output_type=OutputType.SEQUENCING_RESULT,
|
| 137 |
+
step_index=idx,
|
| 138 |
+
quality_score=quality,
|
| 139 |
+
summary=(
|
| 140 |
+
f"Sequenced {n_cells} cells, {n_genes} genes detected, "
|
| 141 |
+
f"median UMI={median_umi}"
|
| 142 |
+
),
|
| 143 |
+
data={
|
| 144 |
+
"n_cells": n_cells,
|
| 145 |
+
"n_genes": n_genes,
|
| 146 |
+
"median_umi": median_umi,
|
| 147 |
+
"sequencing_saturation": self.noise.sample_qc_metric(0.7, 0.1),
|
| 148 |
+
},
|
| 149 |
+
artifacts_available=["raw_count_matrix"],
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
# ββ computational outputs βββββββββββββββββββββββββββββββββββββββββββ
|
| 153 |
+
|
| 154 |
+
def _run_qc(
|
| 155 |
+
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 156 |
+
) -> IntermediateOutput:
|
| 157 |
+
doublet_frac = self.noise.sample_qc_metric(
|
| 158 |
+
s.technical.doublet_rate, 0.01, 0.0, 0.2
|
| 159 |
+
)
|
| 160 |
+
mito_frac = self.noise.sample_qc_metric(0.05, 0.02, 0.0, 0.3)
|
| 161 |
+
ambient_frac = self.noise.sample_qc_metric(
|
| 162 |
+
s.technical.ambient_rna_fraction, 0.01, 0.0, 0.2
|
| 163 |
+
)
|
| 164 |
+
warnings: List[str] = []
|
| 165 |
+
if doublet_frac > 0.08:
|
| 166 |
+
warnings.append(f"High doublet rate ({doublet_frac:.1%})")
|
| 167 |
+
if mito_frac > 0.1:
|
| 168 |
+
warnings.append(f"High mitochondrial fraction ({mito_frac:.1%})")
|
| 169 |
+
quality = 1.0 - (doublet_frac + mito_frac + ambient_frac)
|
| 170 |
+
return IntermediateOutput(
|
| 171 |
+
output_type=OutputType.QC_METRICS,
|
| 172 |
+
step_index=idx,
|
| 173 |
+
quality_score=max(0.0, quality),
|
| 174 |
+
summary="QC metrics computed",
|
| 175 |
+
data={
|
| 176 |
+
"doublet_fraction": doublet_frac,
|
| 177 |
+
"mitochondrial_fraction": mito_frac,
|
| 178 |
+
"ambient_rna_fraction": ambient_frac,
|
| 179 |
+
"median_genes_per_cell": self.noise.sample_count(2500),
|
| 180 |
+
"median_umi_per_cell": self.noise.sample_count(8000),
|
| 181 |
+
},
|
| 182 |
+
warnings=warnings,
|
| 183 |
+
artifacts_available=["qc_report"],
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
def _filter_data(
|
| 187 |
+
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 188 |
+
) -> IntermediateOutput:
|
| 189 |
+
retain_frac = self.noise.sample_qc_metric(0.85, 0.05, 0.5, 1.0)
|
| 190 |
+
n_before = s.biology.n_true_cells
|
| 191 |
+
n_after = max(100, int(n_before * retain_frac))
|
| 192 |
+
return IntermediateOutput(
|
| 193 |
+
output_type=OutputType.COUNT_MATRIX_SUMMARY,
|
| 194 |
+
step_index=idx,
|
| 195 |
+
quality_score=retain_frac,
|
| 196 |
+
summary=f"Filtered {n_before} β {n_after} cells ({retain_frac:.0%} retained)",
|
| 197 |
+
data={
|
| 198 |
+
"n_cells_before": n_before,
|
| 199 |
+
"n_cells_after": n_after,
|
| 200 |
+
"n_genes_retained": self.noise.sample_count(15_000),
|
| 201 |
+
"retain_fraction": retain_frac,
|
| 202 |
+
},
|
| 203 |
+
artifacts_available=["filtered_count_matrix"],
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
def _normalize_data(
|
| 207 |
+
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 208 |
+
) -> IntermediateOutput:
|
| 209 |
+
method = action.method or "log_normalize"
|
| 210 |
+
return IntermediateOutput(
|
| 211 |
+
output_type=OutputType.COUNT_MATRIX_SUMMARY,
|
| 212 |
+
step_index=idx,
|
| 213 |
+
summary=f"Normalized with {method}",
|
| 214 |
+
data={"method": method, "n_hvg": self.noise.sample_count(2000)},
|
| 215 |
+
artifacts_available=["normalized_matrix", "hvg_list"],
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
def _integrate_batches(
|
| 219 |
+
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 220 |
+
) -> IntermediateOutput:
|
| 221 |
+
method = action.method or "harmony"
|
| 222 |
+
residual = self.noise.sample_qc_metric(0.05, 0.03, 0.0, 0.3)
|
| 223 |
+
return IntermediateOutput(
|
| 224 |
+
output_type=OutputType.EMBEDDING_SUMMARY,
|
| 225 |
+
step_index=idx,
|
| 226 |
+
quality_score=1.0 - residual,
|
| 227 |
+
summary=f"Batch integration ({method}), residual batch effect={residual:.2f}",
|
| 228 |
+
data={
|
| 229 |
+
"method": method,
|
| 230 |
+
"residual_batch_effect": residual,
|
| 231 |
+
"n_batches": len(s.technical.batch_effects) or 1,
|
| 232 |
+
},
|
| 233 |
+
artifacts_available=["integrated_embedding"],
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
def _cluster_cells(
|
| 237 |
+
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 238 |
+
) -> IntermediateOutput:
|
| 239 |
+
n_true = len(s.biology.cell_populations) or 5
|
| 240 |
+
quality = self.noise.quality_degradation(0.8, [0.95])
|
| 241 |
+
n_clusters = self.noise.sample_cluster_count(n_true, quality)
|
| 242 |
+
cluster_names = [f"cluster_{i}" for i in range(n_clusters)]
|
| 243 |
+
sizes = self._random_partition(s.biology.n_true_cells, n_clusters)
|
| 244 |
+
return IntermediateOutput(
|
| 245 |
+
output_type=OutputType.CLUSTER_RESULT,
|
| 246 |
+
step_index=idx,
|
| 247 |
+
quality_score=quality,
|
| 248 |
+
summary=f"Found {n_clusters} clusters (ground-truth populations: {n_true})",
|
| 249 |
+
data={
|
| 250 |
+
"n_clusters": n_clusters,
|
| 251 |
+
"cluster_names": cluster_names,
|
| 252 |
+
"cluster_sizes": sizes,
|
| 253 |
+
"silhouette_score": self.noise.sample_qc_metric(0.35, 0.1, -1.0, 1.0),
|
| 254 |
+
},
|
| 255 |
+
uncertainty=abs(n_clusters - n_true) / max(n_true, 1),
|
| 256 |
+
artifacts_available=["cluster_assignments", "umap_embedding"],
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
def _differential_expression(
|
| 260 |
+
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 261 |
+
) -> IntermediateOutput:
|
| 262 |
+
comparison = action.parameters.get("comparison", "disease_vs_healthy")
|
| 263 |
+
true_effects = s.biology.true_de_genes.get(comparison, {})
|
| 264 |
+
|
| 265 |
+
n_cells = s.progress.n_cells_after_filter or s.biology.n_true_cells
|
| 266 |
+
noise_level = s.technical.dropout_rate + 0.1 * (1.0 - s.technical.sample_quality)
|
| 267 |
+
observed = self.noise.sample_effect_sizes(true_effects, n_cells, noise_level)
|
| 268 |
+
|
| 269 |
+
fp_genes = self.noise.generate_false_positives(5000, 0.002 + noise_level * 0.01)
|
| 270 |
+
for g in fp_genes:
|
| 271 |
+
observed[g] = float(self.noise.rng.normal(0, 0.3))
|
| 272 |
+
|
| 273 |
+
fn_genes = self.noise.generate_false_negatives(list(true_effects.keys()), 0.15)
|
| 274 |
+
for g in fn_genes:
|
| 275 |
+
observed.pop(g, None)
|
| 276 |
+
|
| 277 |
+
top_genes = sorted(observed.items(), key=lambda kv: abs(kv[1]), reverse=True)[:50]
|
| 278 |
+
return IntermediateOutput(
|
| 279 |
+
output_type=OutputType.DE_RESULT,
|
| 280 |
+
step_index=idx,
|
| 281 |
+
quality_score=self.noise.quality_degradation(0.8, [1.0 - noise_level]),
|
| 282 |
+
summary=f"DE analysis ({comparison}): {len(observed)} genes tested, {len(top_genes)} top hits",
|
| 283 |
+
data={
|
| 284 |
+
"comparison": comparison,
|
| 285 |
+
"n_tested": len(observed),
|
| 286 |
+
"top_genes": [
|
| 287 |
+
{"gene": g, "log2FC": round(fc, 3)} for g, fc in top_genes
|
| 288 |
+
],
|
| 289 |
+
"n_significant": sum(1 for _, fc in observed.items() if abs(fc) > 0.5),
|
| 290 |
+
},
|
| 291 |
+
uncertainty=noise_level,
|
| 292 |
+
artifacts_available=["de_table"],
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
def _trajectory_analysis(
|
| 296 |
+
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 297 |
+
) -> IntermediateOutput:
|
| 298 |
+
has_trajectory = s.biology.true_trajectory is not None
|
| 299 |
+
quality = self.noise.quality_degradation(0.7 if has_trajectory else 0.3, [0.9])
|
| 300 |
+
summary_data: Dict[str, Any] = {"method": action.method or "monocle3"}
|
| 301 |
+
if has_trajectory:
|
| 302 |
+
summary_data.update({
|
| 303 |
+
"n_lineages": s.biology.true_trajectory.get("n_lineages", 1),
|
| 304 |
+
"pseudotime_range": [0.0, 1.0],
|
| 305 |
+
"branching_detected": s.biology.true_trajectory.get("branching", False),
|
| 306 |
+
})
|
| 307 |
+
else:
|
| 308 |
+
summary_data["n_lineages"] = self.noise.sample_count(1) + 1
|
| 309 |
+
summary_data["pseudotime_range"] = [0.0, 1.0]
|
| 310 |
+
summary_data["branching_detected"] = self.noise.coin_flip(0.3)
|
| 311 |
+
|
| 312 |
+
return IntermediateOutput(
|
| 313 |
+
output_type=OutputType.TRAJECTORY_RESULT,
|
| 314 |
+
step_index=idx,
|
| 315 |
+
quality_score=quality,
|
| 316 |
+
summary="Trajectory / pseudotime analysis complete",
|
| 317 |
+
data=summary_data,
|
| 318 |
+
uncertainty=0.2 if has_trajectory else 0.6,
|
| 319 |
+
artifacts_available=["pseudotime_values", "lineage_graph"],
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
def _pathway_enrichment(
|
| 323 |
+
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 324 |
+
) -> IntermediateOutput:
|
| 325 |
+
true_pathways = s.biology.true_pathways
|
| 326 |
+
noise_level = 0.15
|
| 327 |
+
observed: Dict[str, float] = {}
|
| 328 |
+
for pw, activity in true_pathways.items():
|
| 329 |
+
observed[pw] = activity + float(self.noise.rng.normal(0, noise_level))
|
| 330 |
+
|
| 331 |
+
for i in range(self.noise.sample_count(2)):
|
| 332 |
+
observed[f"FP_PATHWAY_{i}"] = float(self.noise.rng.uniform(0.3, 0.6))
|
| 333 |
+
|
| 334 |
+
top = sorted(observed.items(), key=lambda kv: kv[1], reverse=True)[:15]
|
| 335 |
+
return IntermediateOutput(
|
| 336 |
+
output_type=OutputType.PATHWAY_RESULT,
|
| 337 |
+
step_index=idx,
|
| 338 |
+
quality_score=self.noise.quality_degradation(0.8, [0.95]),
|
| 339 |
+
summary=f"Pathway enrichment: {len(top)} significant pathways",
|
| 340 |
+
data={
|
| 341 |
+
"method": action.method or "GSEA",
|
| 342 |
+
"top_pathways": [
|
| 343 |
+
{"pathway": p, "score": round(s, 3)} for p, s in top
|
| 344 |
+
],
|
| 345 |
+
},
|
| 346 |
+
uncertainty=noise_level,
|
| 347 |
+
artifacts_available=["enrichment_table"],
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
def _regulatory_network(
|
| 351 |
+
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 352 |
+
) -> IntermediateOutput:
|
| 353 |
+
true_net = s.biology.true_regulatory_network
|
| 354 |
+
n_edges_true = sum(len(v) for v in true_net.values())
|
| 355 |
+
noise_edges = self.noise.sample_count(max(5, int(n_edges_true * 0.3)))
|
| 356 |
+
return IntermediateOutput(
|
| 357 |
+
output_type=OutputType.NETWORK_RESULT,
|
| 358 |
+
step_index=idx,
|
| 359 |
+
quality_score=self.noise.quality_degradation(0.6, [0.9]),
|
| 360 |
+
summary=f"Regulatory network inferred: {n_edges_true + noise_edges} edges",
|
| 361 |
+
data={
|
| 362 |
+
"method": action.method or "SCENIC",
|
| 363 |
+
"n_regulons": len(true_net) + self.noise.sample_count(3),
|
| 364 |
+
"n_edges": n_edges_true + noise_edges,
|
| 365 |
+
"top_regulators": list(true_net.keys())[:10],
|
| 366 |
+
},
|
| 367 |
+
uncertainty=0.35,
|
| 368 |
+
artifacts_available=["regulon_table", "grn_adjacency"],
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
def _marker_selection(
|
| 372 |
+
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 373 |
+
) -> IntermediateOutput:
|
| 374 |
+
true_markers = list(s.biology.true_markers)
|
| 375 |
+
noise_level = 0.2
|
| 376 |
+
observed_markers = [
|
| 377 |
+
m for m in true_markers if not self.noise.coin_flip(noise_level)
|
| 378 |
+
]
|
| 379 |
+
fp = self.noise.generate_false_positives(200, 0.01)
|
| 380 |
+
observed_markers.extend(fp)
|
| 381 |
+
return IntermediateOutput(
|
| 382 |
+
output_type=OutputType.MARKER_RESULT,
|
| 383 |
+
step_index=idx,
|
| 384 |
+
quality_score=self.noise.quality_degradation(0.75, [0.9]),
|
| 385 |
+
summary=f"Selected {len(observed_markers)} candidate markers",
|
| 386 |
+
data={
|
| 387 |
+
"markers": observed_markers[:20],
|
| 388 |
+
"n_candidates": len(observed_markers),
|
| 389 |
+
},
|
| 390 |
+
uncertainty=noise_level,
|
| 391 |
+
artifacts_available=["marker_list"],
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
def _validate_marker(
|
| 395 |
+
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 396 |
+
) -> IntermediateOutput:
|
| 397 |
+
marker = action.parameters.get("marker", "unknown")
|
| 398 |
+
is_true = marker in s.biology.true_markers
|
| 399 |
+
validation_correct = not self.noise.coin_flip(0.1)
|
| 400 |
+
validated = is_true == validation_correct
|
| 401 |
+
return IntermediateOutput(
|
| 402 |
+
output_type=OutputType.VALIDATION_RESULT,
|
| 403 |
+
step_index=idx,
|
| 404 |
+
quality_score=0.9 if validation_correct else 0.4,
|
| 405 |
+
summary=f"Marker {marker}: {'validated' if validated else 'not validated'}",
|
| 406 |
+
data={
|
| 407 |
+
"marker": marker,
|
| 408 |
+
"validated": validated,
|
| 409 |
+
"assay": action.method or "qPCR",
|
| 410 |
+
"effect_size": self.noise.sample_qc_metric(
|
| 411 |
+
1.5 if is_true else 0.2, 0.3, -0.5, 5.0
|
| 412 |
+
),
|
| 413 |
+
},
|
| 414 |
+
artifacts_available=["validation_data"],
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
def _design_followup(
|
| 418 |
+
self, action: ExperimentAction, s: FullLatentState, idx: int
|
| 419 |
+
) -> IntermediateOutput:
|
| 420 |
+
return IntermediateOutput(
|
| 421 |
+
output_type=OutputType.FOLLOWUP_DESIGN,
|
| 422 |
+
step_index=idx,
|
| 423 |
+
summary="Follow-up experiment design proposed",
|
| 424 |
+
data={"proposal": action.parameters},
|
| 425 |
+
artifacts_available=["followup_proposal"],
|
| 426 |
+
)
|
| 427 |
+
|
| 428 |
+
def _subagent_review(
    self, action: ExperimentAction, s: FullLatentState, idx: int
) -> IntermediateOutput:
    """Produce a stub report for a requested sub-agent review."""
    reviewer = action.invoked_subagent
    return IntermediateOutput(
        output_type=OutputType.SUBAGENT_REPORT,
        step_index=idx,
        summary=f"Subagent review ({reviewer or 'general'})",
        data={"subagent": reviewer, "notes": "Review complete."},
        artifacts_available=["subagent_report"],
    )
|
| 438 |
+
|
| 439 |
+
def _synthesize_conclusion(
    self, action: ExperimentAction, s: FullLatentState, idx: int
) -> IntermediateOutput:
    """Wrap the agent's claimed findings into the terminal conclusion output."""
    claims = action.parameters.get("claims", [])
    return IntermediateOutput(
        output_type=OutputType.CONCLUSION,
        step_index=idx,
        summary="Conclusion synthesised from pipeline evidence",
        data={"claims": claims},
        artifacts_available=["conclusion_report"],
    )
|
| 449 |
+
|
| 450 |
+
def _default(
    self, action: ExperimentAction, s: FullLatentState, idx: int
) -> IntermediateOutput:
    """Fallback handler: report failure for an action type with no handler."""
    reason = f"Unhandled action type: {action.action_type}"
    return IntermediateOutput(
        output_type=OutputType.FAILURE_REPORT,
        step_index=idx,
        success=False,
        summary=reason,
        data={},
    )
|
| 460 |
+
|
| 461 |
+
# ββ helpers βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 462 |
+
|
| 463 |
+
def _random_partition(self, total: int, k: int) -> List[int]:
|
| 464 |
+
if k <= 0:
|
| 465 |
+
return []
|
| 466 |
+
fracs = self.noise.rng.dirichlet(alpha=[1.0] * k)
|
| 467 |
+
sizes = [max(1, int(total * f)) for f in fracs]
|
| 468 |
+
diff = total - sum(sizes)
|
| 469 |
+
sizes[0] += diff
|
| 470 |
+
return sizes
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
# Dispatch table mapping each action type to its simulator handler.
# Unbound methods are stored, so the caller invokes them as
# handler(self, action, state, idx). Note that both perturbation
# action types share the same ``_perturb`` handler; action types
# absent from this table fall through to ``_default``.
_HANDLERS = {
    ActionType.COLLECT_SAMPLE: OutputGenerator._collect_sample,
    ActionType.SELECT_COHORT: OutputGenerator._select_cohort,
    ActionType.PREPARE_LIBRARY: OutputGenerator._prepare_library,
    ActionType.CULTURE_CELLS: OutputGenerator._culture_cells,
    ActionType.PERTURB_GENE: OutputGenerator._perturb,
    ActionType.PERTURB_COMPOUND: OutputGenerator._perturb,
    ActionType.SEQUENCE_CELLS: OutputGenerator._sequence_cells,
    ActionType.RUN_QC: OutputGenerator._run_qc,
    ActionType.FILTER_DATA: OutputGenerator._filter_data,
    ActionType.NORMALIZE_DATA: OutputGenerator._normalize_data,
    ActionType.INTEGRATE_BATCHES: OutputGenerator._integrate_batches,
    ActionType.CLUSTER_CELLS: OutputGenerator._cluster_cells,
    ActionType.DIFFERENTIAL_EXPRESSION: OutputGenerator._differential_expression,
    ActionType.TRAJECTORY_ANALYSIS: OutputGenerator._trajectory_analysis,
    ActionType.PATHWAY_ENRICHMENT: OutputGenerator._pathway_enrichment,
    ActionType.REGULATORY_NETWORK_INFERENCE: OutputGenerator._regulatory_network,
    ActionType.MARKER_SELECTION: OutputGenerator._marker_selection,
    ActionType.VALIDATE_MARKER: OutputGenerator._validate_marker,
    ActionType.DESIGN_FOLLOWUP: OutputGenerator._design_followup,
    ActionType.REQUEST_SUBAGENT_REVIEW: OutputGenerator._subagent_review,
    ActionType.SYNTHESIZE_CONCLUSION: OutputGenerator._synthesize_conclusion,
}
|
server/simulator/transition.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Transition dynamics engine β the heart of the biological simulator.
|
| 2 |
+
|
| 3 |
+
Orchestrates latent-state updates, output generation, resource accounting,
|
| 4 |
+
and constraint propagation for every agent action.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from copy import deepcopy
|
| 10 |
+
from dataclasses import dataclass, field
|
| 11 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 12 |
+
|
| 13 |
+
from models import (
|
| 14 |
+
ActionType,
|
| 15 |
+
ExperimentAction,
|
| 16 |
+
IntermediateOutput,
|
| 17 |
+
OutputType,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
from .latent_state import FullLatentState
|
| 21 |
+
from .noise import NoiseModel
|
| 22 |
+
from .output_generator import OutputGenerator
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# Fixed per-action cost as a (budget cost, wall-clock days) tuple.
# Wet-lab steps (sequencing, sample collection, validation) dominate the
# budget, while purely computational analyses are cheap in money but
# still consume calendar time; planning actions are free.
# Computational actions additionally accrue compute hours — see
# TransitionEngine._apply_resource_cost.
ACTION_COSTS: Dict[ActionType, Tuple[float, float]] = {
    ActionType.COLLECT_SAMPLE: (5_000, 7.0),
    ActionType.SELECT_COHORT: ( 500, 1.0),
    ActionType.PREPARE_LIBRARY: (8_000, 3.0),
    ActionType.CULTURE_CELLS: (3_000, 14.0),
    ActionType.PERTURB_GENE: (2_000, 3.0),
    ActionType.PERTURB_COMPOUND: (1_000, 2.0),
    ActionType.SEQUENCE_CELLS: (15_000, 5.0),
    ActionType.RUN_QC: ( 100, 0.5),
    ActionType.FILTER_DATA: ( 50, 0.25),
    ActionType.NORMALIZE_DATA: ( 50, 0.25),
    ActionType.INTEGRATE_BATCHES: ( 100, 0.5),
    ActionType.CLUSTER_CELLS: ( 100, 0.5),
    ActionType.DIFFERENTIAL_EXPRESSION: ( 100, 0.5),
    ActionType.TRAJECTORY_ANALYSIS: ( 200, 1.0),
    ActionType.PATHWAY_ENRICHMENT: ( 100, 0.5),
    ActionType.REGULATORY_NETWORK_INFERENCE: ( 300, 1.0),
    ActionType.MARKER_SELECTION: ( 100, 0.5),
    ActionType.VALIDATE_MARKER: (5_000, 14.0),
    ActionType.DESIGN_FOLLOWUP: ( 0, 0.5),
    ActionType.REQUEST_SUBAGENT_REVIEW: ( 0, 0.25),
    ActionType.SYNTHESIZE_CONCLUSION: ( 0, 0.5),
}
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@dataclass
class TransitionResult:
    """Bundle returned by the transition engine after one step."""

    # State after applying the action (a deep copy; the input state is
    # never mutated by the engine).
    next_state: FullLatentState
    # Simulated experimental output produced by this step.
    output: IntermediateOutput
    # Per-component reward terms; left empty by the engine and filled in
    # by the external reward module.
    reward_components: Dict[str, float] = field(default_factory=dict)
    # Rule violations that blocked the action outright.
    hard_violations: List[str] = field(default_factory=list)
    # Rule violations that degraded, but did not block, the action.
    soft_violations: List[str] = field(default_factory=list)
    # True when the episode terminated on this step (conclusion reached
    # or resources exhausted).
    done: bool = False
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class TransitionEngine:
    """Applies one action to the latent state, producing the next state
    and a simulated intermediate output.

    The engine delegates output generation to ``OutputGenerator`` and
    constraint checking to external rule engines (injected at call time).
    """

    def __init__(self, noise: NoiseModel):
        # Shared noise model; the output generator is built on the same
        # instance so all stochastic draws come from one RNG stream.
        self.noise = noise
        self.output_gen = OutputGenerator(noise)

    def step(
        self,
        state: FullLatentState,
        action: ExperimentAction,
        *,
        hard_violations: Optional[List[str]] = None,
        soft_violations: Optional[List[str]] = None,
    ) -> TransitionResult:
        """Apply ``action`` to a copy of ``state`` and return the result.

        ``hard_violations`` block the action entirely (no cost charged,
        no output generated); ``soft_violations`` let it proceed but
        halve the output's quality score.
        """
        # Advance a deep copy so the caller's state is never mutated.
        s = deepcopy(state)
        s.step_count += 1
        step_idx = s.step_count

        hard_v = hard_violations or []
        soft_v = soft_violations or []

        # Hard rule violations short-circuit before any resource is spent;
        # the step counter still advances.
        if hard_v:
            output = IntermediateOutput(
                output_type=OutputType.FAILURE_REPORT,
                step_index=step_idx,
                success=False,
                summary=f"Action blocked: {'; '.join(hard_v)}",
            )
            return TransitionResult(
                next_state=s,
                output=output,
                hard_violations=hard_v,
                soft_violations=soft_v,
            )

        self._apply_resource_cost(s, action)

        # Exhausting budget or time ends the episode immediately; note the
        # current action's cost has already been charged at this point.
        if s.resources.budget_exhausted or s.resources.time_exhausted:
            output = IntermediateOutput(
                output_type=OutputType.FAILURE_REPORT,
                step_index=step_idx,
                success=False,
                summary="Resources exhausted",
            )
            return TransitionResult(
                next_state=s, output=output, done=True,
                hard_violations=["resources_exhausted"],
            )

        self._update_progress(s, action)

        output = self.output_gen.generate(action, s, step_idx)

        # Soft violations do not block the action but halve its quality
        # and surface the violation messages as warnings.
        if soft_v:
            output.quality_score *= 0.5
            output.warnings.extend(soft_v)

        self._propagate_artifacts(s, action, output)

        # Synthesising a conclusion is the terminal action of an episode.
        done = action.action_type == ActionType.SYNTHESIZE_CONCLUSION

        return TransitionResult(
            next_state=s,
            output=output,
            soft_violations=soft_v,
            done=done,
        )

    # ── internals ────────────────────────────────────────────────────────

    def _apply_resource_cost(
        self, s: FullLatentState, action: ExperimentAction
    ) -> None:
        """Charge the fixed budget/time cost of ``action`` to the state."""
        budget_cost, time_cost = ACTION_COSTS.get(
            action.action_type, (0.0, 0.0)
        )
        s.resources.budget_used += budget_cost
        s.resources.time_used_days += time_cost
        # Computational analyses also consume compute hours, modelled as
        # 8 hours per wall-clock day of analysis.
        if action.action_type in {
            ActionType.RUN_QC, ActionType.FILTER_DATA,
            ActionType.NORMALIZE_DATA, ActionType.INTEGRATE_BATCHES,
            ActionType.CLUSTER_CELLS, ActionType.DIFFERENTIAL_EXPRESSION,
            ActionType.TRAJECTORY_ANALYSIS, ActionType.PATHWAY_ENRICHMENT,
            ActionType.REGULATORY_NETWORK_INFERENCE, ActionType.MARKER_SELECTION,
        }:
            s.resources.compute_hours_used += time_cost * 8

    def _update_progress(
        self, s: FullLatentState, action: ExperimentAction
    ) -> None:
        """Flip the progress flag for ``action`` and apply its side effects."""
        at = action.action_type
        p = s.progress
        # Action type -> boolean progress attribute it sets to True.
        _MAP = {
            ActionType.COLLECT_SAMPLE: "samples_collected",
            ActionType.SELECT_COHORT: "cohort_selected",
            ActionType.PREPARE_LIBRARY: "library_prepared",
            ActionType.CULTURE_CELLS: "cells_cultured",
            ActionType.PERTURB_GENE: "perturbation_applied",
            ActionType.PERTURB_COMPOUND: "perturbation_applied",
            ActionType.SEQUENCE_CELLS: "cells_sequenced",
            ActionType.RUN_QC: "qc_performed",
            ActionType.FILTER_DATA: "data_filtered",
            ActionType.NORMALIZE_DATA: "data_normalized",
            ActionType.INTEGRATE_BATCHES: "batches_integrated",
            ActionType.CLUSTER_CELLS: "cells_clustered",
            ActionType.DIFFERENTIAL_EXPRESSION: "de_performed",
            ActionType.TRAJECTORY_ANALYSIS: "trajectories_inferred",
            ActionType.PATHWAY_ENRICHMENT: "pathways_analyzed",
            ActionType.REGULATORY_NETWORK_INFERENCE: "networks_inferred",
            ActionType.MARKER_SELECTION: "markers_discovered",
            ActionType.VALIDATE_MARKER: "markers_validated",
            ActionType.SYNTHESIZE_CONCLUSION: "conclusion_reached",
        }
        flag = _MAP.get(at)
        if flag:
            setattr(p, flag, True)

        if at == ActionType.COLLECT_SAMPLE:
            # assumes "n_samples" is numeric if supplied — TODO confirm
            # against the action schema.
            n = action.parameters.get("n_samples", 6)
            s.resources.samples_available += n

        if at == ActionType.SEQUENCE_CELLS:
            s.resources.sequencing_lanes_used += 1

        if at == ActionType.FILTER_DATA:
            # Filtering retains ~85% of cells (noisy, clipped to 50-100%),
            # never dropping below 100 cells.
            retain = self.noise.sample_qc_metric(0.85, 0.05, 0.5, 1.0)
            p.n_cells_after_filter = max(
                100, int(s.biology.n_true_cells * retain)
            )

        if at == ActionType.CLUSTER_CELLS:
            # Cluster count is a noisy estimate around the true number of
            # populations; default to 5 when no populations are defined.
            n_true = len(s.biology.cell_populations) or 5
            p.n_clusters_found = self.noise.sample_cluster_count(n_true, 0.8)

    def _propagate_artifacts(
        self,
        s: FullLatentState,
        action: ExperimentAction,
        output: IntermediateOutput,
    ) -> None:
        """Copy key findings from ``output`` back into the latent state so
        later steps (and reward computation) can see what was discovered.
        """
        if action.action_type == ActionType.DIFFERENTIAL_EXPRESSION:
            # Keep only the top-20 DE genes by the generator's ordering.
            top = output.data.get("top_genes", [])
            s.discovered_de_genes = [g["gene"] for g in top[:20]]

        if action.action_type == ActionType.CLUSTER_CELLS:
            s.discovered_clusters = output.data.get("cluster_names", [])

        if action.action_type == ActionType.MARKER_SELECTION:
            s.progress.n_markers_found = output.data.get("n_candidates", 0)
|
server/subagents/__init__.py
ADDED
|
File without changes
|
server/tasks/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .generator import TaskGenerator
|
| 2 |
+
from .scenarios import SCENARIO_LIBRARY, Scenario
|
| 3 |
+
|
| 4 |
+
__all__ = ["SCENARIO_LIBRARY", "Scenario", "TaskGenerator"]
|
server/tasks/generator.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Task generator β produces (TaskSpec, FullLatentState) pairs for episodes.
|
| 2 |
+
|
| 3 |
+
Supports three modes:
|
| 4 |
+
1. Select from the pre-defined scenario library.
|
| 5 |
+
2. Randomly perturb a scenario for domain-randomisation.
|
| 6 |
+
3. Compose a fully procedural scenario (tissue Γ modality Γ difficulty).
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
from typing import List, Optional, Tuple
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
|
| 15 |
+
from models import TaskSpec
|
| 16 |
+
|
| 17 |
+
from server.simulator.latent_state import (
|
| 18 |
+
CellPopulation,
|
| 19 |
+
ExperimentProgress,
|
| 20 |
+
FullLatentState,
|
| 21 |
+
GeneProgram,
|
| 22 |
+
LatentBiologicalState,
|
| 23 |
+
ResourceState,
|
| 24 |
+
TechnicalState,
|
| 25 |
+
)
|
| 26 |
+
from .scenarios import SCENARIO_LIBRARY, Scenario
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class TaskGenerator:
    """Generates task + latent-state pairs for environment episodes."""

    def __init__(
        self,
        scenarios: Optional[List[Scenario]] = None,
        domain_randomise: bool = True,
    ):
        # Fall back to the built-in library when no scenarios are given.
        self.scenarios = scenarios or SCENARIO_LIBRARY
        # When True, each generated episode perturbs budgets, noise levels
        # and population proportions (see _randomise).
        self.domain_randomise = domain_randomise

    def generate(
        self,
        *,
        seed: Optional[int] = None,
        scenario_name: Optional[str] = None,
    ) -> Tuple[TaskSpec, FullLatentState]:
        """Build one episode as a ``(task, latent_state)`` pair.

        Uses ``scenario_name`` when given (raising ``ValueError`` for an
        unknown name); otherwise samples a scenario uniformly from the
        library with ``seed``. Templates are deep-copied, so the shared
        scenario library is never mutated.
        """
        rng = np.random.default_rng(seed)

        if scenario_name:
            scenario = self._find_scenario(scenario_name)
        else:
            idx = int(rng.integers(0, len(self.scenarios)))
            scenario = self.scenarios[idx]

        # Deep copies keep randomisation from leaking into the library.
        task = scenario.task.model_copy(deep=True)
        biology = scenario.biology.model_copy(deep=True)
        technical = scenario.technical.model_copy(deep=True)

        if self.domain_randomise:
            self._randomise(rng, task, biology, technical)

        latent = FullLatentState(
            biology=biology,
            technical=technical,
            progress=ExperimentProgress(),
            resources=ResourceState(
                budget_total=task.budget_limit,
                time_limit_days=task.time_limit_days,
            ),
            hidden_failure_conditions=list(scenario.hidden_failure_conditions),
            # NOTE(review): ``seed or 0`` maps both None and 0 to 0.
            rng_seed=seed or 0,
        )
        return task, latent

    def list_scenarios(self) -> List[str]:
        """Return the names of all scenarios this generator can emit."""
        return [s.name for s in self.scenarios]

    # ── internals ────────────────────────────────────────────────────────

    def _find_scenario(self, name: str) -> Scenario:
        """Look up a scenario by name; raise ``ValueError`` if unknown."""
        for s in self.scenarios:
            if s.name == name:
                return s
        available = ", ".join(self.list_scenarios())
        raise ValueError(f"Unknown scenario '{name}'. Available: {available}")

    def _randomise(
        self,
        rng: np.random.Generator,
        task: TaskSpec,
        bio: LatentBiologicalState,
        tech: TechnicalState,
    ) -> None:
        """Perturb budgets, technical noise, and biology in place.

        Every perturbation is clipped to a plausible range so randomised
        episodes remain solvable.
        """
        # Budget ±30%, time ±20%.
        budget_scale = float(rng.uniform(0.7, 1.3))
        task.budget_limit *= budget_scale
        task.time_limit_days *= float(rng.uniform(0.8, 1.2))

        tech.dropout_rate = float(np.clip(
            tech.dropout_rate + rng.normal(0, 0.02), 0.01, 0.3
        ))
        tech.doublet_rate = float(np.clip(
            tech.doublet_rate + rng.normal(0, 0.01), 0.01, 0.15
        ))
        tech.sample_quality = float(np.clip(
            tech.sample_quality + rng.normal(0, 0.05), 0.5, 1.0
        ))
        tech.ambient_rna_fraction = float(np.clip(
            tech.ambient_rna_fraction + rng.normal(0, 0.01), 0.01, 0.15
        ))
        for batch_id in list(tech.batch_effects.keys()):
            tech.batch_effects[batch_id] = float(np.clip(
                tech.batch_effects[batch_id] + rng.normal(0, 0.03), 0.0, 0.4
            ))

        # Jitter population proportions ±20%, then renormalise to sum 1.
        for pop in bio.cell_populations:
            pop.proportion = float(np.clip(
                pop.proportion * rng.uniform(0.8, 1.2), 0.01, 0.8
            ))
        total = sum(p.proportion for p in bio.cell_populations) or 1.0
        for pop in bio.cell_populations:
            pop.proportion /= total

        # Scale every true DE effect size by ±20%.
        for comparison, effects in bio.true_de_genes.items():
            for gene in list(effects.keys()):
                effects[gene] *= float(rng.uniform(0.8, 1.2))

        # Cell count ±40%, floored at 1000 cells.
        bio.n_true_cells = max(
            1000,
            int(bio.n_true_cells * rng.uniform(0.6, 1.4)),
        )
|
server/tasks/scenarios.py
ADDED
|
@@ -0,0 +1,454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pre-defined biological scenarios for task generation.
|
| 2 |
+
|
| 3 |
+
Each ``Scenario`` bundles a task specification together with the matching
|
| 4 |
+
hidden ground-truth biology so the simulator can instantiate consistent
|
| 5 |
+
episodes. The library is intentionally diverse: it covers differential
|
| 6 |
+
expression, trajectory inference, perturbation response, and biomarker
|
| 7 |
+
validation across tissues and modalities.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
from dataclasses import dataclass, field
|
| 13 |
+
from typing import Any, Dict, List, Optional
|
| 14 |
+
|
| 15 |
+
from models import ExpectedFinding, PaperReference, TaskSpec
|
| 16 |
+
|
| 17 |
+
from server.simulator.latent_state import (
|
| 18 |
+
CellPopulation,
|
| 19 |
+
GeneProgram,
|
| 20 |
+
LatentBiologicalState,
|
| 21 |
+
TechnicalState,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
class Scenario:
    """A reproducible (task, ground-truth) pair."""

    # Unique identifier used by TaskGenerator lookups.
    name: str
    # Agent-facing task specification (problem statement, budget, criteria).
    task: TaskSpec
    # Hidden ground-truth biology the simulator draws observations from.
    biology: LatentBiologicalState
    # Technical noise parameters (batch effects, dropout, doublets, ...).
    technical: TechnicalState = field(default_factory=TechnicalState)
    # Latent failure modes injected into the episode's latent state.
    hidden_failure_conditions: List[str] = field(default_factory=list)
    # Informal difficulty label ("easy" / "medium" / "hard").
    difficulty: str = "medium"
    # Free-form tags for filtering scenarios (e.g. "de", "trajectory").
    tags: List[str] = field(default_factory=list)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# ββ Scenario library ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 39 |
+
|
| 40 |
+
SCENARIO_LIBRARY: List[Scenario] = [
|
| 41 |
+
# ββ 1. Cardiac disease DE βββββββββββββββββββββββββββββββββββββββββββ
|
| 42 |
+
Scenario(
|
| 43 |
+
name="cardiac_disease_de",
|
| 44 |
+
difficulty="easy",
|
| 45 |
+
tags=["de", "scRNA-seq", "cardiac"],
|
| 46 |
+
task=TaskSpec(
|
| 47 |
+
problem_statement=(
|
| 48 |
+
"Identify differentially expressed genes between diseased "
|
| 49 |
+
"and healthy cardiomyocytes using single-cell RNA sequencing."
|
| 50 |
+
),
|
| 51 |
+
modality="scRNA-seq",
|
| 52 |
+
organism="human",
|
| 53 |
+
tissue="heart",
|
| 54 |
+
conditions=["healthy", "dilated_cardiomyopathy"],
|
| 55 |
+
budget_limit=80_000.0,
|
| 56 |
+
time_limit_days=120.0,
|
| 57 |
+
success_criteria=[
|
| 58 |
+
"Identify DE genes between conditions",
|
| 59 |
+
"Validate at least one candidate marker",
|
| 60 |
+
],
|
| 61 |
+
),
|
| 62 |
+
biology=LatentBiologicalState(
|
| 63 |
+
cell_populations=[
|
| 64 |
+
CellPopulation(
|
| 65 |
+
name="cardiomyocyte",
|
| 66 |
+
proportion=0.35,
|
| 67 |
+
marker_genes=["TNNT2", "MYH7", "ACTC1"],
|
| 68 |
+
state="contractile",
|
| 69 |
+
condition_response={"dilated_cardiomyopathy": 0.8},
|
| 70 |
+
),
|
| 71 |
+
CellPopulation(
|
| 72 |
+
name="fibroblast",
|
| 73 |
+
proportion=0.25,
|
| 74 |
+
marker_genes=["COL1A1", "DCN", "LUM"],
|
| 75 |
+
state="quiescent",
|
| 76 |
+
condition_response={"dilated_cardiomyopathy": 1.3},
|
| 77 |
+
),
|
| 78 |
+
CellPopulation(
|
| 79 |
+
name="endothelial",
|
| 80 |
+
proportion=0.15,
|
| 81 |
+
marker_genes=["PECAM1", "VWF", "CDH5"],
|
| 82 |
+
state="quiescent",
|
| 83 |
+
),
|
| 84 |
+
CellPopulation(
|
| 85 |
+
name="macrophage",
|
| 86 |
+
proportion=0.10,
|
| 87 |
+
marker_genes=["CD68", "CD163", "CSF1R"],
|
| 88 |
+
state="activated",
|
| 89 |
+
condition_response={"dilated_cardiomyopathy": 1.5},
|
| 90 |
+
),
|
| 91 |
+
CellPopulation(
|
| 92 |
+
name="smooth_muscle",
|
| 93 |
+
proportion=0.15,
|
| 94 |
+
marker_genes=["ACTA2", "MYH11", "TAGLN"],
|
| 95 |
+
state="quiescent",
|
| 96 |
+
),
|
| 97 |
+
],
|
| 98 |
+
true_de_genes={
|
| 99 |
+
"disease_vs_healthy": {
|
| 100 |
+
"NPPA": 2.5, "NPPB": 3.1, "MYH7": 1.8,
|
| 101 |
+
"COL1A1": 1.6, "COL3A1": 1.4, "POSTN": 2.0,
|
| 102 |
+
"CCL2": 1.2, "IL6": 0.9, "TGFB1": 1.1,
|
| 103 |
+
"ANKRD1": 2.2, "XIRP2": -1.3, "MYL2": -0.8,
|
| 104 |
+
},
|
| 105 |
+
},
|
| 106 |
+
true_pathways={
|
| 107 |
+
"cardiac_muscle_contraction": 0.4,
|
| 108 |
+
"extracellular_matrix_organisation": 0.85,
|
| 109 |
+
"inflammatory_response": 0.7,
|
| 110 |
+
"TGF_beta_signalling": 0.75,
|
| 111 |
+
"apoptosis": 0.55,
|
| 112 |
+
},
|
| 113 |
+
true_markers=["NPPA", "NPPB", "POSTN", "COL1A1"],
|
| 114 |
+
causal_mechanisms=[
|
| 115 |
+
"TGF-beta-driven fibrosis",
|
| 116 |
+
"inflammatory macrophage infiltration",
|
| 117 |
+
],
|
| 118 |
+
n_true_cells=12_000,
|
| 119 |
+
),
|
| 120 |
+
technical=TechnicalState(
|
| 121 |
+
batch_effects={"batch_1": 0.15, "batch_2": 0.10},
|
| 122 |
+
doublet_rate=0.05,
|
| 123 |
+
dropout_rate=0.08,
|
| 124 |
+
),
|
| 125 |
+
),
|
| 126 |
+
|
| 127 |
+
# ββ 2. Developmental trajectory ββββββββββββββββββββββββββββοΏ½οΏ½ββββββββ
|
| 128 |
+
Scenario(
|
| 129 |
+
name="hematopoiesis_trajectory",
|
| 130 |
+
difficulty="medium",
|
| 131 |
+
tags=["trajectory", "scRNA-seq", "hematopoiesis"],
|
| 132 |
+
task=TaskSpec(
|
| 133 |
+
problem_statement=(
|
| 134 |
+
"Infer the developmental trajectory of hematopoietic "
|
| 135 |
+
"stem cells differentiating into mature blood lineages."
|
| 136 |
+
),
|
| 137 |
+
modality="scRNA-seq",
|
| 138 |
+
organism="human",
|
| 139 |
+
tissue="bone_marrow",
|
| 140 |
+
conditions=["steady_state"],
|
| 141 |
+
budget_limit=100_000.0,
|
| 142 |
+
time_limit_days=150.0,
|
| 143 |
+
success_criteria=[
|
| 144 |
+
"Reconstruct branching lineage structure",
|
| 145 |
+
"Identify key transcription factors driving fate decisions",
|
| 146 |
+
],
|
| 147 |
+
paper_references=[
|
| 148 |
+
PaperReference(
|
| 149 |
+
title=(
|
| 150 |
+
"Single-cell RNA-sequencing uncovers transcriptional "
|
| 151 |
+
"states and fate decisions in haematopoiesis"
|
| 152 |
+
),
|
| 153 |
+
citation="Nature Communications (2018)",
|
| 154 |
+
doi="10.1038/s41467-017-02305-6",
|
| 155 |
+
url=(
|
| 156 |
+
"https://www.nature.com/articles/"
|
| 157 |
+
"s41467-017-02305-6"
|
| 158 |
+
),
|
| 159 |
+
),
|
| 160 |
+
],
|
| 161 |
+
expected_findings=[
|
| 162 |
+
ExpectedFinding(
|
| 163 |
+
finding=(
|
| 164 |
+
"Trajectory analysis should recover branching blood "
|
| 165 |
+
"lineages rooted in HSCs."
|
| 166 |
+
),
|
| 167 |
+
category="trajectory",
|
| 168 |
+
keywords=["HSC", "branching", "lineage", "trajectory"],
|
| 169 |
+
),
|
| 170 |
+
ExpectedFinding(
|
| 171 |
+
finding=(
|
| 172 |
+
"GATA1 should appear as a driver of erythroid fate "
|
| 173 |
+
"commitment."
|
| 174 |
+
),
|
| 175 |
+
category="regulatory_network",
|
| 176 |
+
keywords=["GATA1", "erythroid", "commitment"],
|
| 177 |
+
),
|
| 178 |
+
ExpectedFinding(
|
| 179 |
+
finding=(
|
| 180 |
+
"CEBPA and SPI1 should support myeloid branch "
|
| 181 |
+
"decisions."
|
| 182 |
+
),
|
| 183 |
+
category="regulatory_network",
|
| 184 |
+
keywords=["CEBPA", "SPI1", "myeloid", "branch"],
|
| 185 |
+
),
|
| 186 |
+
],
|
| 187 |
+
),
|
| 188 |
+
biology=LatentBiologicalState(
|
| 189 |
+
cell_populations=[
|
| 190 |
+
CellPopulation(name="HSC", proportion=0.05,
|
| 191 |
+
marker_genes=["CD34", "KIT", "THY1"],
|
| 192 |
+
state="stem"),
|
| 193 |
+
CellPopulation(name="CMP", proportion=0.10,
|
| 194 |
+
marker_genes=["CD34", "FLT3"],
|
| 195 |
+
state="progenitor"),
|
| 196 |
+
CellPopulation(name="GMP", proportion=0.12,
|
| 197 |
+
marker_genes=["CSF3R", "CEBPA"],
|
| 198 |
+
state="progenitor"),
|
| 199 |
+
CellPopulation(name="MEP", proportion=0.10,
|
| 200 |
+
marker_genes=["GATA1", "KLF1"],
|
| 201 |
+
state="progenitor"),
|
| 202 |
+
CellPopulation(name="erythrocyte", proportion=0.20,
|
| 203 |
+
marker_genes=["HBA1", "HBB", "GYPA"],
|
| 204 |
+
state="mature"),
|
| 205 |
+
CellPopulation(name="neutrophil", proportion=0.18,
|
| 206 |
+
marker_genes=["ELANE", "MPO", "CTSG"],
|
| 207 |
+
state="mature"),
|
| 208 |
+
CellPopulation(name="monocyte", proportion=0.15,
|
| 209 |
+
marker_genes=["CD14", "CSF1R", "FCGR3A"],
|
| 210 |
+
state="mature"),
|
| 211 |
+
CellPopulation(name="megakaryocyte", proportion=0.10,
|
| 212 |
+
marker_genes=["ITGA2B", "GP1BA"],
|
| 213 |
+
state="mature"),
|
| 214 |
+
],
|
| 215 |
+
true_de_genes={},
|
| 216 |
+
true_pathways={
|
| 217 |
+
"hematopoietic_cell_lineage": 0.9,
|
| 218 |
+
"MAPK_signalling": 0.6,
|
| 219 |
+
"JAK_STAT_signalling": 0.7,
|
| 220 |
+
},
|
| 221 |
+
true_trajectory={
|
| 222 |
+
"root": "HSC",
|
| 223 |
+
"n_lineages": 3,
|
| 224 |
+
"branching": True,
|
| 225 |
+
"branches": [
|
| 226 |
+
["HSC", "CMP", "GMP", "neutrophil"],
|
| 227 |
+
["HSC", "CMP", "GMP", "monocyte"],
|
| 228 |
+
["HSC", "MEP", "erythrocyte"],
|
| 229 |
+
["HSC", "MEP", "megakaryocyte"],
|
| 230 |
+
],
|
| 231 |
+
},
|
| 232 |
+
true_regulatory_network={
|
| 233 |
+
"GATA1": ["KLF1", "HBB", "HBA1", "GYPA"],
|
| 234 |
+
"CEBPA": ["CSF3R", "ELANE", "MPO"],
|
| 235 |
+
"SPI1": ["CSF1R", "CD14", "FCGR3A"],
|
| 236 |
+
"RUNX1": ["CD34", "KIT"],
|
| 237 |
+
},
|
| 238 |
+
true_markers=["GATA1", "CEBPA", "SPI1"],
|
| 239 |
+
causal_mechanisms=[
|
| 240 |
+
"GATA1-driven erythroid commitment",
|
| 241 |
+
"PU.1/CEBPA antagonism at myeloid branch point",
|
| 242 |
+
],
|
| 243 |
+
n_true_cells=15_000,
|
| 244 |
+
),
|
| 245 |
+
technical=TechnicalState(dropout_rate=0.12, doublet_rate=0.06),
|
| 246 |
+
),
|
| 247 |
+
|
| 248 |
+
# ββ 3. Perturbation response ββββββββββββββββββββββββββββββββββββββββ
|
| 249 |
+
Scenario(
|
| 250 |
+
name="perturbation_immune",
|
| 251 |
+
difficulty="hard",
|
| 252 |
+
tags=["perturbation", "scRNA-seq", "immune"],
|
| 253 |
+
task=TaskSpec(
|
| 254 |
+
problem_statement=(
|
| 255 |
+
"Determine the effect of JAK inhibitor treatment on "
|
| 256 |
+
"T-cell activation states in rheumatoid arthritis."
|
| 257 |
+
),
|
| 258 |
+
modality="scRNA-seq",
|
| 259 |
+
organism="human",
|
| 260 |
+
tissue="synovial_fluid",
|
| 261 |
+
conditions=["untreated_RA", "JAK_inhibitor_treated"],
|
| 262 |
+
budget_limit=120_000.0,
|
| 263 |
+
time_limit_days=180.0,
|
| 264 |
+
prior_observations=[
|
| 265 |
+
"Elevated JAK-STAT signalling observed in prior bulk RNA-seq",
|
| 266 |
+
],
|
| 267 |
+
success_criteria=[
|
| 268 |
+
"Quantify shift in T-cell activation states",
|
| 269 |
+
"Identify pathways modulated by JAK inhibitor",
|
| 270 |
+
"Propose validation strategy",
|
| 271 |
+
],
|
| 272 |
+
),
|
| 273 |
+
biology=LatentBiologicalState(
|
| 274 |
+
cell_populations=[
|
| 275 |
+
CellPopulation(name="CD4_Th1", proportion=0.20,
|
| 276 |
+
marker_genes=["IFNG", "TBX21", "IL2"],
|
| 277 |
+
state="activated",
|
| 278 |
+
condition_response={"JAK_inhibitor_treated": 0.5}),
|
| 279 |
+
CellPopulation(name="CD4_Th17", proportion=0.15,
|
| 280 |
+
marker_genes=["IL17A", "RORC", "CCR6"],
|
| 281 |
+
state="activated",
|
| 282 |
+
condition_response={"JAK_inhibitor_treated": 0.6}),
|
| 283 |
+
CellPopulation(name="CD4_Treg", proportion=0.08,
|
| 284 |
+
marker_genes=["FOXP3", "IL2RA", "CTLA4"],
|
| 285 |
+
state="regulatory",
|
| 286 |
+
condition_response={"JAK_inhibitor_treated": 1.2}),
|
| 287 |
+
CellPopulation(name="CD8_cytotoxic", proportion=0.18,
|
| 288 |
+
marker_genes=["GZMB", "PRF1", "CD8A"],
|
| 289 |
+
state="activated",
|
| 290 |
+
condition_response={"JAK_inhibitor_treated": 0.7}),
|
| 291 |
+
CellPopulation(name="macrophage", proportion=0.15,
|
| 292 |
+
marker_genes=["CD68", "CD163", "MARCO"],
|
| 293 |
+
state="inflammatory"),
|
| 294 |
+
CellPopulation(name="fibroblast", proportion=0.14,
|
| 295 |
+
marker_genes=["COL1A1", "FAP", "THY1"],
|
| 296 |
+
state="activated"),
|
| 297 |
+
CellPopulation(name="B_cell", proportion=0.10,
|
| 298 |
+
marker_genes=["CD19", "MS4A1", "CD79A"],
|
| 299 |
+
state="quiescent"),
|
| 300 |
+
],
|
| 301 |
+
true_de_genes={
|
| 302 |
+
"treated_vs_untreated": {
|
| 303 |
+
"IFNG": -1.8, "TBX21": -1.2, "IL17A": -1.5,
|
| 304 |
+
"RORC": -0.9, "JAK1": -0.3, "STAT1": -1.0,
|
| 305 |
+
"STAT3": -0.8, "SOCS1": 1.5, "SOCS3": 1.3,
|
| 306 |
+
"FOXP3": 0.6, "IL10": 0.7,
|
| 307 |
+
},
|
| 308 |
+
},
|
| 309 |
+
true_pathways={
|
| 310 |
+
"JAK_STAT_signalling": 0.3,
|
| 311 |
+
"Th1_differentiation": 0.35,
|
| 312 |
+
"Th17_differentiation": 0.4,
|
| 313 |
+
"cytokine_signalling": 0.45,
|
| 314 |
+
"regulatory_T_cell_function": 0.7,
|
| 315 |
+
},
|
| 316 |
+
perturbation_effects={
|
| 317 |
+
"JAK_inhibitor": {
|
| 318 |
+
"STAT1": -0.8, "STAT3": -0.7, "IFNG": -1.5,
|
| 319 |
+
"IL17A": -1.3, "SOCS1": 1.2,
|
| 320 |
+
},
|
| 321 |
+
},
|
| 322 |
+
true_markers=["STAT1", "SOCS1", "IFNG"],
|
| 323 |
+
causal_mechanisms=[
|
| 324 |
+
"JAK-STAT pathway inhibition reduces Th1/Th17 activation",
|
| 325 |
+
"Compensatory Treg expansion under JAK inhibition",
|
| 326 |
+
],
|
| 327 |
+
n_true_cells=18_000,
|
| 328 |
+
),
|
| 329 |
+
technical=TechnicalState(
|
| 330 |
+
batch_effects={"batch_ctrl": 0.12, "batch_treated": 0.18},
|
| 331 |
+
ambient_rna_fraction=0.07,
|
| 332 |
+
dropout_rate=0.10,
|
| 333 |
+
),
|
| 334 |
+
hidden_failure_conditions=[
|
| 335 |
+
"High ambient RNA may confound DE in low-abundance transcripts",
|
| 336 |
+
],
|
| 337 |
+
),
|
| 338 |
+
|
| 339 |
+
# ββ 4. Biomarker validation βββββββββββββββββββββββββββββββββββββββββ
|
| 340 |
+
Scenario(
|
| 341 |
+
name="biomarker_validation_lung",
|
| 342 |
+
difficulty="medium",
|
| 343 |
+
tags=["biomarker", "validation", "scRNA-seq", "lung"],
|
| 344 |
+
task=TaskSpec(
|
| 345 |
+
problem_statement=(
|
| 346 |
+
"Design a follow-up validation experiment for candidate "
|
| 347 |
+
"biomarker SPP1 in idiopathic pulmonary fibrosis (IPF)."
|
| 348 |
+
),
|
| 349 |
+
modality="scRNA-seq",
|
| 350 |
+
organism="human",
|
| 351 |
+
tissue="lung",
|
| 352 |
+
conditions=["healthy", "IPF"],
|
| 353 |
+
budget_limit=90_000.0,
|
| 354 |
+
time_limit_days=150.0,
|
| 355 |
+
prior_observations=[
|
| 356 |
+
"SPP1 identified as top DE gene in prior pilot study",
|
| 357 |
+
"SPP1+ macrophages enriched in fibrotic regions",
|
| 358 |
+
],
|
| 359 |
+
success_criteria=[
|
| 360 |
+
"Validate SPP1 as a marker for pro-fibrotic macrophages",
|
| 361 |
+
"Confirm spatial localisation in fibrotic tissue",
|
| 362 |
+
],
|
| 363 |
+
paper_references=[
|
| 364 |
+
PaperReference(
|
| 365 |
+
title=(
|
| 366 |
+
"Proliferating SPP1/MERTK-expressing macrophages in "
|
| 367 |
+
"idiopathic pulmonary fibrosis"
|
| 368 |
+
),
|
| 369 |
+
citation="European Respiratory Journal (2019)",
|
| 370 |
+
doi="10.1183/13993003.02441-2018",
|
| 371 |
+
pmid="31221805",
|
| 372 |
+
url="https://pubmed.ncbi.nlm.nih.gov/31221805/",
|
| 373 |
+
),
|
| 374 |
+
],
|
| 375 |
+
expected_findings=[
|
| 376 |
+
ExpectedFinding(
|
| 377 |
+
finding=(
|
| 378 |
+
"SPP1-positive macrophages should be enriched in IPF "
|
| 379 |
+
"fibrotic regions."
|
| 380 |
+
),
|
| 381 |
+
category="marker",
|
| 382 |
+
keywords=["SPP1", "macrophage", "IPF", "fibrotic"],
|
| 383 |
+
),
|
| 384 |
+
ExpectedFinding(
|
| 385 |
+
finding=(
|
| 386 |
+
"MERTK should co-occur with the profibrotic macrophage "
|
| 387 |
+
"state."
|
| 388 |
+
),
|
| 389 |
+
category="marker",
|
| 390 |
+
keywords=["MERTK", "macrophage", "SPP1"],
|
| 391 |
+
),
|
| 392 |
+
ExpectedFinding(
|
| 393 |
+
finding=(
|
| 394 |
+
"Extracellular matrix organization should emerge as a "
|
| 395 |
+
"top fibrotic program."
|
| 396 |
+
),
|
| 397 |
+
category="pathway",
|
| 398 |
+
keywords=["extracellular_matrix", "fibrosis", "pathway"],
|
| 399 |
+
),
|
| 400 |
+
],
|
| 401 |
+
dataset_metadata={
|
| 402 |
+
"literature_grounding": "single_cell_ipf_macrophages",
|
| 403 |
+
},
|
| 404 |
+
),
|
| 405 |
+
biology=LatentBiologicalState(
|
| 406 |
+
cell_populations=[
|
| 407 |
+
CellPopulation(name="alveolar_macrophage", proportion=0.18,
|
| 408 |
+
marker_genes=["MARCO", "FABP4", "MCEMP1"],
|
| 409 |
+
state="resident"),
|
| 410 |
+
CellPopulation(name="SPP1_macrophage", proportion=0.12,
|
| 411 |
+
marker_genes=["SPP1", "MERTK", "MMP9", "TREM2"],
|
| 412 |
+
state="pro-fibrotic",
|
| 413 |
+
condition_response={"IPF": 2.0}),
|
| 414 |
+
CellPopulation(name="AT2", proportion=0.20,
|
| 415 |
+
marker_genes=["SFTPC", "SFTPB", "ABCA3"],
|
| 416 |
+
state="normal"),
|
| 417 |
+
CellPopulation(name="fibroblast", proportion=0.22,
|
| 418 |
+
marker_genes=["COL1A1", "COL3A1", "POSTN"],
|
| 419 |
+
state="activated",
|
| 420 |
+
condition_response={"IPF": 1.5}),
|
| 421 |
+
CellPopulation(name="endothelial", proportion=0.13,
|
| 422 |
+
marker_genes=["PECAM1", "CLDN5"],
|
| 423 |
+
state="quiescent"),
|
| 424 |
+
CellPopulation(name="T_cell", proportion=0.15,
|
| 425 |
+
marker_genes=["CD3D", "CD3E", "IL7R"],
|
| 426 |
+
state="quiescent"),
|
| 427 |
+
],
|
| 428 |
+
true_de_genes={
|
| 429 |
+
"IPF_vs_healthy": {
|
| 430 |
+
"SPP1": 3.2, "MERTK": 1.4, "MMP9": 1.8, "TREM2": 1.5,
|
| 431 |
+
"COL1A1": 2.1, "COL3A1": 1.9, "POSTN": 2.4,
|
| 432 |
+
"SFTPC": -1.2, "AGER": -1.6,
|
| 433 |
+
},
|
| 434 |
+
},
|
| 435 |
+
true_pathways={
|
| 436 |
+
"extracellular_matrix_organisation": 0.9,
|
| 437 |
+
"integrin_signalling": 0.75,
|
| 438 |
+
"macrophage_activation": 0.8,
|
| 439 |
+
"Wnt_signalling": 0.6,
|
| 440 |
+
},
|
| 441 |
+
true_markers=["SPP1", "MERTK", "POSTN", "MMP9"],
|
| 442 |
+
causal_mechanisms=[
|
| 443 |
+
"SPP1+ macrophage-driven fibroblast activation",
|
| 444 |
+
"Integrin-mediated SPP1 signalling in fibrosis",
|
| 445 |
+
],
|
| 446 |
+
n_true_cells=14_000,
|
| 447 |
+
),
|
| 448 |
+
technical=TechnicalState(
|
| 449 |
+
batch_effects={"batch_1": 0.10},
|
| 450 |
+
dropout_rate=0.09,
|
| 451 |
+
sample_quality=0.85,
|
| 452 |
+
),
|
| 453 |
+
),
|
| 454 |
+
]
|
tests/__init__.py
ADDED
|
File without changes
|
tests/test_environment.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Integration tests for the full BioExperimentEnvironment."""
|
| 2 |
+
|
| 3 |
+
from models import ActionType, ExperimentAction
|
| 4 |
+
from server.hackathon_environment import BioExperimentEnvironment
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class TestEnvironmentLifecycle:
    """End-to-end checks of reset/step semantics for the environment."""

    def test_reset_returns_valid_observation(self):
        """A fresh reset yields step 0, not done, and a non-empty task."""
        environment = BioExperimentEnvironment()
        observation = environment.reset()
        assert observation.step_index == 0
        assert observation.done is False
        assert observation.task.problem_statement != ""

    def test_step_increments_step_count(self):
        """One step advances both the observation index and internal counter."""
        environment = BioExperimentEnvironment()
        environment.reset()
        observation = environment.step(
            ExperimentAction(action_type=ActionType.COLLECT_SAMPLE)
        )
        assert observation.step_index == 1
        assert environment.state.step_count == 1

    def test_valid_pipeline_trajectory(self):
        """A correctly ordered wet-lab + analysis pipeline succeeds throughout."""
        environment = BioExperimentEnvironment()
        environment.reset()

        pipeline = (
            ExperimentAction(
                action_type=ActionType.COLLECT_SAMPLE,
                parameters={"n_samples": 6},
            ),
            ExperimentAction(
                action_type=ActionType.PREPARE_LIBRARY,
                method="10x_chromium",
            ),
            ExperimentAction(action_type=ActionType.SEQUENCE_CELLS),
            ExperimentAction(action_type=ActionType.RUN_QC),
            ExperimentAction(action_type=ActionType.FILTER_DATA),
            ExperimentAction(action_type=ActionType.NORMALIZE_DATA),
            ExperimentAction(action_type=ActionType.CLUSTER_CELLS),
            ExperimentAction(
                action_type=ActionType.DIFFERENTIAL_EXPRESSION,
                parameters={"comparison": "disease_vs_healthy"},
            ),
        )

        for action in pipeline:
            observation = environment.step(action)
            assert observation.latest_output is not None
            # Include the violations in the failure message for debuggability.
            assert observation.latest_output.success is True, (
                f"Step {action.action_type} failed: {observation.rule_violations}"
            )

        assert observation.step_index == len(pipeline)
        assert observation.resource_usage.budget_used > 0

    def test_premature_de_blocked(self):
        """Differential expression straight after reset must be rejected."""
        environment = BioExperimentEnvironment()
        environment.reset()
        observation = environment.step(ExperimentAction(
            action_type=ActionType.DIFFERENTIAL_EXPRESSION,
        ))
        assert observation.latest_output is not None
        assert observation.latest_output.success is False

    def test_conclusion_ends_episode(self):
        """Synthesising a conclusion terminates the episode with a reward."""
        environment = BioExperimentEnvironment()
        environment.reset()

        quick_pipeline = (
            ExperimentAction(action_type=ActionType.COLLECT_SAMPLE),
            ExperimentAction(action_type=ActionType.PREPARE_LIBRARY),
            ExperimentAction(action_type=ActionType.SEQUENCE_CELLS),
            ExperimentAction(action_type=ActionType.RUN_QC),
            ExperimentAction(action_type=ActionType.FILTER_DATA),
            ExperimentAction(action_type=ActionType.NORMALIZE_DATA),
            ExperimentAction(action_type=ActionType.CLUSTER_CELLS),
            ExperimentAction(
                action_type=ActionType.DIFFERENTIAL_EXPRESSION,
                parameters={"comparison": "disease_vs_healthy"},
            ),
            ExperimentAction(
                action_type=ActionType.SYNTHESIZE_CONCLUSION,
                parameters={"claims": [
                    {"claim": "Test conclusion", "confidence": 0.7,
                     "claim_type": "correlational"},
                ]},
            ),
        )
        for action in quick_pipeline:
            observation = environment.step(action)

        assert observation.done is True
        assert observation.reward != 0.0
|
tests/test_literature_benchmark.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for literature-grounded benchmark utilities."""
|
| 2 |
+
|
| 3 |
+
from training.literature_benchmark import (
|
| 4 |
+
run_paper_benchmark,
|
| 5 |
+
select_literature_scenario,
|
| 6 |
+
)
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def test_select_literature_scenario_for_ipf_prompt():
    """An IPF/SPP1 prompt routes to the lung biomarker-validation scenario."""
    prompt = "Validate SPP1-positive macrophage findings in idiopathic pulmonary fibrosis."
    chosen = select_literature_scenario(prompt)
    assert chosen.name == "biomarker_validation_lung"
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def test_select_literature_scenario_for_trajectory_prompt():
    """A branching-lineage prompt routes to the hematopoiesis scenario."""
    prompt = "Recover branching hematopoietic lineages and branch point transcription factors."
    chosen = select_literature_scenario(prompt)
    assert chosen.name == "hematopoiesis_trajectory"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def test_run_paper_benchmark_matches_curated_findings():
    """Running the fixed IPF scenario recovers most curated paper findings."""
    statement = (
        "Design a follow-up validation experiment for candidate biomarker "
        "SPP1 in idiopathic pulmonary fibrosis."
    )
    result = run_paper_benchmark(
        problem_statement=statement,
        scenario_name="biomarker_validation_lung",
        domain_randomise=False,
    )

    assert result.total_steps >= 1
    assert result.matched_papers
    # At least two of the three curated findings should be recovered.
    assert result.match_ratio >= (2 / 3)
    assert any("SPP1" in finding for finding in result.matched_findings)
|
tests/test_models.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for POMDP schema models."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
from models import (
|
| 5 |
+
ActionType,
|
| 6 |
+
ConclusionClaim,
|
| 7 |
+
ExpectedFinding,
|
| 8 |
+
ExperimentAction,
|
| 9 |
+
ExperimentObservation,
|
| 10 |
+
IntermediateOutput,
|
| 11 |
+
OutputType,
|
| 12 |
+
PaperReference,
|
| 13 |
+
PipelineStepRecord,
|
| 14 |
+
ResourceUsage,
|
| 15 |
+
TaskSpec,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def test_experiment_action_roundtrip():
    """An action survives a model_dump / reconstruct cycle unchanged."""
    original = ExperimentAction(
        action_type=ActionType.COLLECT_SAMPLE,
        input_targets=["prior_cohort"],
        method="10x_chromium",
        parameters={"n_samples": 6},
        confidence=0.8,
    )
    dumped = original.model_dump()
    assert dumped["action_type"] == "collect_sample"
    assert dumped["confidence"] == 0.8
    rebuilt = ExperimentAction(**dumped)
    assert rebuilt.action_type == ActionType.COLLECT_SAMPLE
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def test_experiment_observation_defaults():
    """A minimal observation gets sensible step/history/budget defaults."""
    observation = ExperimentObservation(done=False, reward=0.0)
    assert observation.step_index == 0
    assert observation.pipeline_history == []
    assert observation.resource_usage.budget_remaining == 100_000.0
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def test_intermediate_output_quality_bounds():
    """An out-of-range quality_score (> 1.0) must be rejected at construction.

    Pydantic raises ``ValidationError``, which subclasses ``ValueError``;
    asserting ``ValueError`` instead of a bare ``Exception`` keeps the test
    from passing on unrelated errors (e.g. a ``TypeError`` from a renamed
    keyword argument).
    """
    with pytest.raises(ValueError):
        IntermediateOutput(
            output_type=OutputType.QC_METRICS,
            step_index=1,
            quality_score=1.5,  # valid range is [0.0, 1.0]
        )
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def test_task_spec_defaults():
    """A bare TaskSpec exposes the documented default assay, budget, and lists."""
    spec = TaskSpec()
    assert "10x_chromium" in spec.available_assays
    assert spec.budget_limit == 100_000.0
    assert spec.paper_references == []
    assert spec.expected_findings == []
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def test_paper_reference_and_expected_finding_roundtrip():
    """Nested reference/finding models serialise through TaskSpec.model_dump."""
    reference = PaperReference(
        title="Example paper",
        doi="10.0000/example",
    )
    finding = ExpectedFinding(
        finding="Example marker is enriched",
        category="marker",
        keywords=["EXAMPLE"],
    )
    task = TaskSpec(
        paper_references=[reference],
        expected_findings=[finding],
    )
    dumped = task.model_dump()
    assert dumped["paper_references"][0]["title"] == "Example paper"
    assert dumped["expected_findings"][0]["category"] == "marker"
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def test_conclusion_claim_serialization():
    """A conclusion claim serialises its type and confidence verbatim."""
    claim = ConclusionClaim(
        claim="NPPA is upregulated in disease",
        evidence_steps=[3, 5],
        confidence=0.85,
        claim_type="correlational",
    )
    dumped = claim.model_dump()
    assert dumped["claim_type"] == "correlational"
    assert dumped["confidence"] == 0.85
|
tests/test_rewards.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for the decomposable reward function."""
|
| 2 |
+
|
| 3 |
+
from models import ActionType, ConclusionClaim, ExperimentAction, IntermediateOutput, OutputType
|
| 4 |
+
from server.rewards.reward import RewardComputer
|
| 5 |
+
from server.simulator.latent_state import (
|
| 6 |
+
ExperimentProgress,
|
| 7 |
+
FullLatentState,
|
| 8 |
+
LatentBiologicalState,
|
| 9 |
+
ResourceState,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _states(
    prev_flags: dict | None = None,
    next_flags: dict | None = None,
    budget_used: float = 0.0,
):
    """Build a (before, after) latent-state pair for step-reward tests.

    The "after" state carries the merged progress flags and a fixed 5000
    budget debit on top of *budget_used*.
    """
    before = FullLatentState(
        progress=ExperimentProgress(**(prev_flags or {})),
        resources=ResourceState(budget_total=100_000, budget_used=budget_used),
    )
    # Next-state flags are the previous flags overridden by the new ones.
    merged_flags = {**(prev_flags or {}), **(next_flags or {})}
    after = FullLatentState(
        progress=ExperimentProgress(**merged_flags),
        resources=ResourceState(
            budget_total=100_000,
            budget_used=budget_used + 5000,
        ),
    )
    return before, after
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class TestStepReward:
    """Per-step reward shaping: valid progress up, violations down."""

    def test_valid_step_positive(self):
        """A successful sequencing step on a prepared library earns reward."""
        computer = RewardComputer()
        before, after = _states(
            prev_flags={"samples_collected": True, "library_prepared": True},
            next_flags={"cells_sequenced": True},
        )
        step_output = IntermediateOutput(
            output_type=OutputType.SEQUENCING_RESULT,
            step_index=1,
            quality_score=0.85,
            uncertainty=0.15,
        )
        breakdown = computer.step_reward(
            ExperimentAction(action_type=ActionType.SEQUENCE_CELLS),
            before,
            after,
            step_output,
            [],
            [],
        )
        assert breakdown.total > 0

    def test_hard_violation_negative(self):
        """A blocked action with a failure report yields a negative total."""
        computer = RewardComputer()
        before, after = _states()
        step_output = IntermediateOutput(
            output_type=OutputType.FAILURE_REPORT,
            step_index=1,
            success=False,
        )
        breakdown = computer.step_reward(
            ExperimentAction(action_type=ActionType.SEQUENCE_CELLS),
            before,
            after,
            step_output,
            ["blocked"],
            [],
        )
        assert breakdown.total < 0
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class TestTerminalReward:
    """Terminal scoring: correct claims rewarded, overconfidence penalised."""

    def test_correct_conclusion_rewarded(self):
        """A claim matching the true causal mechanism gets positive terminal reward."""
        computer = RewardComputer()
        full_pipeline_progress = ExperimentProgress(
            samples_collected=True, cells_sequenced=True,
            qc_performed=True, data_filtered=True,
            data_normalized=True, de_performed=True,
            conclusion_reached=True,
        )
        final_state = FullLatentState(
            biology=LatentBiologicalState(
                causal_mechanisms=["TGF-beta-driven fibrosis"],
                true_markers=["NPPA"],
            ),
            progress=full_pipeline_progress,
            resources=ResourceState(budget_total=100_000, budget_used=40_000),
        )
        correct_claim = ConclusionClaim(
            claim="TGF-beta-driven fibrosis observed",
            confidence=0.9,
            claim_type="causal",
        )
        breakdown = computer.terminal_reward(final_state, [correct_claim], [])
        assert breakdown.terminal > 0

    def test_overconfident_wrong_claim_penalised(self):
        """A high-confidence claim of a nonexistent mechanism is penalised."""
        computer = RewardComputer()
        final_state = FullLatentState(
            biology=LatentBiologicalState(causal_mechanisms=["real_mechanism"]),
            progress=ExperimentProgress(conclusion_reached=True),
        )
        wrong_claim = ConclusionClaim(
            claim="completely_wrong_mechanism",
            confidence=0.95,
            claim_type="causal",
        )
        breakdown = computer.terminal_reward(final_state, [wrong_claim], [])
        assert breakdown.components.get("overconfidence_penalty", 0) < 0
|
tests/test_rules.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for the biological rule engine."""
|
| 2 |
+
|
| 3 |
+
from models import ActionType, ExperimentAction
|
| 4 |
+
from server.rules.engine import RuleEngine, Severity
|
| 5 |
+
from server.simulator.latent_state import (
|
| 6 |
+
ExperimentProgress,
|
| 7 |
+
FullLatentState,
|
| 8 |
+
ResourceState,
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _state(**progress_flags) -> FullLatentState:
    """Build a latent state with the given progress flags and fixed resources."""
    progress = ExperimentProgress(**progress_flags)
    resources = ResourceState(budget_total=100_000, time_limit_days=180)
    return FullLatentState(progress=progress, resources=resources)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class TestPrerequisites:
    """Hard prerequisite rules: actions require the right prior steps."""

    def test_sequence_without_library_blocked(self):
        """Sequencing before library prep raises a hard library violation."""
        rules = RuleEngine()
        found = rules.check(
            ExperimentAction(action_type=ActionType.SEQUENCE_CELLS),
            _state(samples_collected=True),
        )
        blocking = rules.hard_violations(found)
        assert any("library" in message.lower() for message in blocking)

    def test_sequence_with_library_allowed(self):
        """Sequencing is permitted once samples and library exist."""
        rules = RuleEngine()
        found = rules.check(
            ExperimentAction(action_type=ActionType.SEQUENCE_CELLS),
            _state(samples_collected=True, library_prepared=True),
        )
        assert not rules.hard_violations(found)

    def test_de_without_normalization_blocked(self):
        """DE requires normalised data even after QC and filtering."""
        rules = RuleEngine()
        found = rules.check(
            ExperimentAction(action_type=ActionType.DIFFERENTIAL_EXPRESSION),
            _state(cells_sequenced=True, qc_performed=True, data_filtered=True),
        )
        blocking = rules.hard_violations(found)
        # Accept either British or American spelling in the message.
        assert any(
            "normalis" in message.lower() or "normaliz" in message.lower()
            for message in blocking
        )

    def test_validate_marker_without_discovery_blocked(self):
        """Marker validation needs a discovered marker, not just DE results."""
        rules = RuleEngine()
        found = rules.check(
            ExperimentAction(action_type=ActionType.VALIDATE_MARKER),
            _state(de_performed=True),
        )
        blocking = rules.hard_violations(found)
        assert any("marker" in message.lower() for message in blocking)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class TestRedundancy:
    """Redundant-but-legal actions produce soft, not hard, violations."""

    def test_double_qc_is_soft(self):
        """Re-running QC is flagged as redundant but never blocks the action."""
        rules = RuleEngine()
        found = rules.check(
            ExperimentAction(action_type=ActionType.RUN_QC),
            _state(cells_sequenced=True, qc_performed=True),
        )
        blocking = rules.hard_violations(found)
        advisory = rules.soft_violations(found)
        assert not blocking
        assert any("redundant" in message.lower() for message in advisory)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class TestResourceConstraints:
    """Resource rules: nothing proceeds once the budget is spent."""

    def test_exhausted_budget_blocked(self):
        """Any action with zero remaining budget is hard-blocked."""
        depleted = _state()
        depleted.resources.budget_used = 100_000
        rules = RuleEngine()
        found = rules.check(
            ExperimentAction(action_type=ActionType.COLLECT_SAMPLE),
            depleted,
        )
        blocking = rules.hard_violations(found)
        assert any("budget" in message.lower() for message in blocking)
|
tests/test_simulator.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for the latent-state simulator modules."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
|
| 5 |
+
from models import ActionType, ExperimentAction, OutputType
|
| 6 |
+
from server.simulator.latent_state import (
|
| 7 |
+
CellPopulation,
|
| 8 |
+
ExperimentProgress,
|
| 9 |
+
FullLatentState,
|
| 10 |
+
LatentBiologicalState,
|
| 11 |
+
ResourceState,
|
| 12 |
+
TechnicalState,
|
| 13 |
+
)
|
| 14 |
+
from server.simulator.noise import NoiseModel
|
| 15 |
+
from server.simulator.output_generator import OutputGenerator
|
| 16 |
+
from server.simulator.transition import TransitionEngine
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _make_state() -> FullLatentState:
    """Build a small two-population latent state shared by the simulator tests."""
    biology = LatentBiologicalState(
        cell_populations=[
            CellPopulation(name="A", proportion=0.6, marker_genes=["G1"]),
            CellPopulation(name="B", proportion=0.4, marker_genes=["G2"]),
        ],
        true_de_genes={"disease_vs_healthy": {"G1": 2.0, "G2": -1.5}},
        true_pathways={"apoptosis": 0.7},
        true_markers=["G1"],
        causal_mechanisms=["G1-driven apoptosis"],
        n_true_cells=5000,
    )
    technical = TechnicalState(dropout_rate=0.1, doublet_rate=0.04)
    resources = ResourceState(budget_total=50_000, time_limit_days=90)
    return FullLatentState(
        biology=biology,
        technical=technical,
        progress=ExperimentProgress(),
        resources=resources,
    )
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class TestNoiseModel:
    """Determinism and sanity of the stochastic noise model."""

    def test_deterministic_with_seed(self):
        """Two models with the same seed sample identical QC metrics."""
        first = NoiseModel(seed=42)
        second = NoiseModel(seed=42)
        assert first.sample_qc_metric(0.5, 0.1) == second.sample_qc_metric(0.5, 0.1)

    def test_false_positives(self):
        """Generated false positives all use the FP_GENE_ naming scheme."""
        model = NoiseModel(seed=0)
        false_hits = model.generate_false_positives(1000, 0.01)
        assert all(gene.startswith("FP_GENE_") for gene in false_hits)

    def test_quality_degradation_bounded(self):
        """Degraded quality always stays inside the [0, 1] interval."""
        model = NoiseModel(seed=0)
        for _ in range(100):
            degraded = model.quality_degradation(0.9, [0.8, 0.7])
            assert 0.0 <= degraded <= 1.0
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class TestOutputGenerator:
    """Action outputs carry the right type and reflect latent biology."""

    def test_collect_sample(self):
        """Sample collection reports its type and echoes n_samples."""
        generator = OutputGenerator(NoiseModel(seed=1))
        state = _make_state()
        collect = ExperimentAction(
            action_type=ActionType.COLLECT_SAMPLE,
            parameters={"n_samples": 4},
        )
        result = generator.generate(collect, state, 1)
        assert result.output_type == OutputType.SAMPLE_COLLECTION_RESULT
        assert result.data["n_samples"] == 4

    def test_de_includes_true_genes(self):
        """DE on normalised data surfaces at least one true DE gene."""
        generator = OutputGenerator(NoiseModel(seed=42))
        state = _make_state()
        state.progress.data_normalized = True
        de_action = ExperimentAction(
            action_type=ActionType.DIFFERENTIAL_EXPRESSION,
            parameters={"comparison": "disease_vs_healthy"},
        )
        result = generator.generate(de_action, state, 5)
        assert result.output_type == OutputType.DE_RESULT
        reported = [entry["gene"] for entry in result.data["top_genes"]]
        assert "G1" in reported or "G2" in reported
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class TestTransitionEngine:
    """State transitions: progress flags, blocking, costs, termination."""

    def test_progress_flags_set(self):
        """Collecting samples flips the samples_collected flag."""
        engine = TransitionEngine(NoiseModel(seed=0))
        state = _make_state()
        outcome = engine.step(
            state,
            ExperimentAction(action_type=ActionType.COLLECT_SAMPLE),
        )
        assert outcome.next_state.progress.samples_collected is True

    def test_hard_violation_blocks(self):
        """Any hard violation forces a failure report and success=False."""
        engine = TransitionEngine(NoiseModel(seed=0))
        outcome = engine.step(
            _make_state(),
            ExperimentAction(action_type=ActionType.COLLECT_SAMPLE),
            hard_violations=["test_block"],
        )
        assert outcome.output.success is False
        assert outcome.output.output_type == OutputType.FAILURE_REPORT

    def test_resource_deduction(self):
        """Sequencing debits its fixed 15k cost from the budget."""
        engine = TransitionEngine(NoiseModel(seed=0))
        state = _make_state()
        state.progress.library_prepared = True
        outcome = engine.step(
            state,
            ExperimentAction(action_type=ActionType.SEQUENCE_CELLS),
        )
        assert outcome.next_state.resources.budget_used == 15_000

    def test_conclusion_ends_episode(self):
        """Synthesising a conclusion after DE terminates the episode."""
        engine = TransitionEngine(NoiseModel(seed=0))
        state = _make_state()
        state.progress.de_performed = True
        outcome = engine.step(
            state,
            ExperimentAction(action_type=ActionType.SYNTHESIZE_CONCLUSION),
        )
        assert outcome.done is True
|
training/__init__.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .evaluation import EvaluationSuite
|
| 2 |
+
from .gym_wrapper import BioExperimentGymEnv
|
| 3 |
+
from .trajectory import Trajectory, TrajectoryDataset
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
"BioExperimentGymEnv",
|
| 7 |
+
"EvaluationSuite",
|
| 8 |
+
"PaperBenchmarkResult",
|
| 9 |
+
"Trajectory",
|
| 10 |
+
"TrajectoryDataset",
|
| 11 |
+
"run_paper_benchmark",
|
| 12 |
+
"select_literature_scenario",
|
| 13 |
+
]
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def __getattr__(name: str):
    """Lazily resolve the literature-benchmark exports (PEP 562).

    ``literature_benchmark`` is comparatively heavy to import, so its
    public names are only loaded the first time one of them is accessed
    through the package.
    """
    lazy_names = (
        "PaperBenchmarkResult",
        "run_paper_benchmark",
        "select_literature_scenario",
    )
    if name not in lazy_names:
        raise AttributeError(f"module 'training' has no attribute {name!r}")

    from . import literature_benchmark

    return getattr(literature_benchmark, name)
|
training/evaluation.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Evaluation suite for the bio-experiment planning environment.
|
| 2 |
+
|
| 3 |
+
Separates metrics into four families:
|
| 4 |
+
- online RL metrics (collected during training rollouts)
|
| 5 |
+
- offline benchmark metrics (computed on a fixed held-out set)
|
| 6 |
+
- expert review metrics (for human-in-the-loop evaluation)
|
| 7 |
+
- simulator fidelity metrics (how well the simulator matches reality)
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
from dataclasses import dataclass, field
|
| 13 |
+
from typing import Any, Dict, List, Optional
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
|
| 17 |
+
from .trajectory import Trajectory, TrajectoryDataset
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
class MetricResult:
    """A single named evaluation metric.

    Attributes:
        name: Metric identifier, e.g. ``"mean_return"``.
        value: Scalar metric value.
        details: Optional per-metric diagnostic payload.
    """

    name: str
    value: float
    details: Dict[str, Any] = field(default_factory=dict)


class EvaluationSuite:
    """Computes and aggregates evaluation metrics over trajectory datasets."""

    # ── online RL metrics ────────────────────────────────────────────────

    @staticmethod
    def online_metrics(trajectories: List[Trajectory]) -> List[MetricResult]:
        """Summarise training rollouts: return, length, and success stats.

        Returns an empty list when no trajectories are supplied.
        """
        if not trajectories:
            return []

        rewards = [t.total_reward for t in trajectories]
        lengths = [len(t.steps) for t in trajectories]
        successes = [t.success for t in trajectories]

        return [
            MetricResult("mean_return", float(np.mean(rewards))),
            MetricResult("median_return", float(np.median(rewards))),
            MetricResult("std_return", float(np.std(rewards))),
            MetricResult("mean_episode_length", float(np.mean(lengths))),
            MetricResult("success_rate", float(np.mean(successes))),
        ]

    # ── offline benchmark metrics ───────────────────────────────────────

    @staticmethod
    def benchmark_metrics(dataset: TrajectoryDataset) -> List[MetricResult]:
        """Compute held-out benchmark metrics on a fixed dataset.

        Returns an empty list for an empty dataset.
        """
        results: List[MetricResult] = []
        if len(dataset) == 0:
            return results

        results.append(MetricResult(
            "pipeline_validity_rate",
            EvaluationSuite._pipeline_validity_rate(dataset),
        ))
        results.append(MetricResult(
            "ordering_score",
            EvaluationSuite._ordering_score(dataset),
        ))
        results.append(MetricResult(
            "action_diversity",
            EvaluationSuite._action_diversity(dataset),
        ))
        results.append(MetricResult(
            "mean_conclusion_confidence",
            EvaluationSuite._mean_conclusion_confidence(dataset),
        ))
        return results

    # ── expert review metrics (stubs) ───────────────────────────────────

    @staticmethod
    def expert_review_metrics(
        trajectories: List[Trajectory],
        expert_scores: Optional[Dict[str, float]] = None,
    ) -> List[MetricResult]:
        """Placeholder for human expert review scores.

        In practice, each trajectory would be scored by a domain expert
        on axes such as scientific validity, creativity, and efficiency.
        """
        if not expert_scores:
            return [MetricResult("expert_review", 0.0, {"note": "no scores provided"})]
        avg = float(np.mean(list(expert_scores.values())))
        return [MetricResult("expert_review_mean", avg, expert_scores)]

    # ── simulator fidelity metrics (stubs) ──────────────────────────────

    @staticmethod
    def simulator_fidelity_metrics(
        simulated: TrajectoryDataset,
        real: Optional[TrajectoryDataset] = None,
    ) -> List[MetricResult]:
        """Compare simulated trajectories against real experimental data.

        When ``real`` is provided, computes distributional distances
        between simulated and real output statistics.
        """
        if real is None or len(real) == 0:
            return [MetricResult("fidelity", 0.0, {"note": "no real data"})]

        sim_rewards = [t.total_reward for t in simulated.trajectories]
        real_rewards = [t.total_reward for t in real.trajectories]

        # Absolute gap between mean returns: a crude but cheap fidelity proxy.
        reward_gap = abs(float(np.mean(sim_rewards)) - float(np.mean(real_rewards)))
        return [MetricResult("reward_distribution_gap", reward_gap)]

    # ── internal helpers ────────────────────────────────────────────────

    @staticmethod
    def _pipeline_validity_rate(ds: TrajectoryDataset) -> float:
        """Fraction of trajectories with no rule violations at any step.

        A step counts as violating when ``"rule_violations"`` is a
        non-empty collection; ``None`` and ``[]`` are both treated as
        clean (same semantics as the previous double-condition check,
        expressed via truthiness with a single lookup per step).
        """
        valid = 0
        for t in ds.trajectories:
            if not any(s.observation.get("rule_violations") for s in t.steps):
                valid += 1
        return valid / max(len(ds), 1)

    @staticmethod
    def _ordering_score(ds: TrajectoryDataset) -> float:
        """Mean of the per-step ``"ordering"`` reward component, if present."""
        scores: List[float] = []
        for t in ds.trajectories:
            step_scores = [
                s.reward_breakdown["ordering"]
                for s in t.steps
                if "ordering" in s.reward_breakdown
            ]
            if step_scores:
                scores.append(float(np.mean(step_scores)))
        return float(np.mean(scores)) if scores else 0.0

    @staticmethod
    def _action_diversity(ds: TrajectoryDataset) -> float:
        """Number of distinct action types used across the whole dataset."""
        all_types: set = set()
        for t in ds.trajectories:
            for s in t.steps:
                at = s.action.get("action_type")
                if at:
                    all_types.add(at)
        # Cast to float to honour the declared return type (previously
        # returned a bare int).
        return float(len(all_types))

    @staticmethod
    def _mean_conclusion_confidence(ds: TrajectoryDataset) -> float:
        """Mean ``confidence`` over all conclusion dicts in all observations."""
        confs: List[float] = [
            c["confidence"]
            for t in ds.trajectories
            for s in t.steps
            for c in s.observation.get("conclusions", [])
            if isinstance(c, dict) and "confidence" in c
        ]
        return float(np.mean(confs)) if confs else 0.0
|
training/gym_wrapper.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Gymnasium-compatible wrapper around ``BioExperimentEnvironment``.
|
| 2 |
+
|
| 3 |
+
Provides ``BioExperimentGymEnv`` which wraps the OpenEnv environment for
|
| 4 |
+
local in-process RL training (no HTTP/WebSocket overhead).
|
| 5 |
+
|
| 6 |
+
Observation and action spaces are represented as ``gymnasium.spaces.Dict``
|
| 7 |
+
so that standard RL libraries (SB3, CleanRL, etc.) can ingest them.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
from typing import Any, Dict, Optional, Tuple
|
| 13 |
+
|
| 14 |
+
import gymnasium as gym
|
| 15 |
+
import numpy as np
|
| 16 |
+
from gymnasium import spaces
|
| 17 |
+
|
| 18 |
+
from models import ActionType, ExperimentAction, ExperimentObservation
|
| 19 |
+
from server.hackathon_environment import BioExperimentEnvironment, MAX_STEPS
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Ordered list of all action types; the Discrete action index maps into it.
ACTION_TYPE_LIST = list(ActionType)
_N_ACTION_TYPES = len(ACTION_TYPE_LIST)

# Upper bounds used to size observation-space components.
_MAX_OUTPUTS = MAX_STEPS
_MAX_HISTORY = MAX_STEPS
_VEC_DIM = 64  # NOTE(review): not referenced in this class — presumably reserved for future embedding features; confirm
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class BioExperimentGymEnv(gym.Env):
    """Gymnasium ``Env`` backed by the in-process simulator.

    Observations are flattened into a dictionary of NumPy arrays suitable
    for RL policy networks. Actions are integer-indexed action types with
    a continuous confidence scalar.

    For LLM-based agents or planners that prefer structured
    ``ExperimentAction`` objects, use the underlying
    ``BioExperimentEnvironment`` directly instead.
    """

    metadata = {"render_modes": ["human"]}

    def __init__(self, render_mode: Optional[str] = None):
        super().__init__()
        # In-process environment — no HTTP/WebSocket transport involved.
        self._env = BioExperimentEnvironment()
        self.render_mode = render_mode

        # Action = which experiment step to take + how confident the agent is.
        self.action_space = spaces.Dict({
            "action_type": spaces.Discrete(_N_ACTION_TYPES),
            "confidence": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32),
        })

        # Flattened observation built by ``_vectorise`` below.
        self.observation_space = spaces.Dict({
            "step_index": spaces.Discrete(MAX_STEPS + 1),
            "budget_remaining_frac": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32),
            "time_remaining_frac": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32),
            # 18 boolean progress milestones — see ``flag_names`` in _vectorise.
            "progress_flags": spaces.MultiBinary(18),
            "latest_quality": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32),
            "latest_uncertainty": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32),
            "avg_quality": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32),
            "avg_uncertainty": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32),
            "n_violations": spaces.Discrete(20),
            "n_outputs": spaces.Discrete(_MAX_OUTPUTS + 1),
            "cumulative_reward": spaces.Box(-100.0, 100.0, shape=(), dtype=np.float32),
        })

        # Cached structured observation for ``render``.
        self._last_obs: Optional[ExperimentObservation] = None

    # ── Gymnasium interface ─────────────────────────────────────────────

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Start a new episode and return ``(vectorised_obs, info)``."""
        super().reset(seed=seed)
        obs = self._env.reset()
        self._last_obs = obs
        return self._vectorise(obs), self._info(obs)

    def step(
        self, action: Dict[str, Any]
    ) -> Tuple[Dict[str, Any], float, bool, bool, Dict[str, Any]]:
        """Translate a dict action into an ``ExperimentAction`` and advance.

        Returns the standard Gymnasium 5-tuple
        ``(obs, reward, terminated, truncated, info)``.
        """
        action_idx = int(action["action_type"])
        # Default confidence of 0.5 if the policy does not supply one.
        confidence = float(action.get("confidence", 0.5))

        experiment_action = ExperimentAction(
            action_type=ACTION_TYPE_LIST[action_idx],
            confidence=confidence,
        )
        obs = self._env.step(experiment_action)
        self._last_obs = obs

        terminated = obs.done
        # Truncation only when the step cap is hit without natural termination.
        truncated = obs.step_index >= MAX_STEPS and not terminated
        reward = obs.reward

        return (
            self._vectorise(obs),
            reward,
            terminated,
            truncated,
            self._info(obs),
        )

    def render(self) -> Optional[str]:
        """Print and return a short human-readable episode summary."""
        if self.render_mode != "human" or self._last_obs is None:
            return None
        obs = self._last_obs
        lines = [
            f"Step {obs.step_index}",
            f"  Task: {obs.task.problem_statement[:80]}",
            f"  Budget: ${obs.resource_usage.budget_remaining:,.0f} remaining",
            f"  Time: {obs.resource_usage.time_remaining_days:.0f} days remaining",
        ]
        if obs.latest_output:
            lines.append(f"  Latest: {obs.latest_output.summary}")
        if obs.rule_violations:
            lines.append(f"  Violations: {obs.rule_violations}")
        text = "\n".join(lines)
        print(text)
        return text

    # ── helpers ─────────────────────────────────────────────────────────

    def _vectorise(self, obs: ExperimentObservation) -> Dict[str, Any]:
        """Flatten a structured observation into observation_space arrays."""
        # NOTE(review): reaches into the wrapped environment's private
        # ``_latent`` state — progress flags are not exposed on the
        # observation itself.
        progress = self._env._latent.progress if self._env._latent else None
        flags = np.zeros(18, dtype=np.int8)
        if progress:
            # Order must match the MultiBinary(18) declaration above.
            flag_names = [
                "samples_collected", "cohort_selected", "cells_cultured",
                "library_prepared", "perturbation_applied", "cells_sequenced",
                "qc_performed", "data_filtered", "data_normalized",
                "batches_integrated", "cells_clustered", "de_performed",
                "trajectories_inferred", "pathways_analyzed",
                "networks_inferred", "markers_discovered",
                "markers_validated", "conclusion_reached",
            ]
            for i, f in enumerate(flag_names):
                flags[i] = int(getattr(progress, f, False))

        unc = obs.uncertainty_summary
        lo = obs.latest_output

        return {
            "step_index": obs.step_index,
            # max(..., 1) guards against division by zero on degenerate tasks.
            "budget_remaining_frac": np.float32(
                obs.resource_usage.budget_remaining
                / max(obs.task.budget_limit, 1)
            ),
            "time_remaining_frac": np.float32(
                obs.resource_usage.time_remaining_days
                / max(obs.task.time_limit_days, 1)
            ),
            "progress_flags": flags,
            "latest_quality": np.float32(lo.quality_score if lo else 0.0),
            "latest_uncertainty": np.float32(lo.uncertainty if lo else 0.0),
            "avg_quality": np.float32(unc.get("avg_quality", 0.0)),
            "avg_uncertainty": np.float32(unc.get("avg_uncertainty", 0.0)),
            # Clipped to stay inside the declared Discrete bounds.
            "n_violations": min(len(obs.rule_violations), 19),
            "n_outputs": min(len(obs.all_outputs), _MAX_OUTPUTS),
            "cumulative_reward": np.float32(
                obs.metadata.get("cumulative_reward", 0.0)
                if obs.metadata else 0.0
            ),
        }

    def _info(self, obs: ExperimentObservation) -> Dict[str, Any]:
        """Expose the raw structured observation for non-vector consumers."""
        return {
            "structured_obs": obs,
            "episode_id": obs.metadata.get("episode_id") if obs.metadata else None,
        }
|
training/literature_benchmark.py
ADDED
|
@@ -0,0 +1,557 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Literature-grounded experiment benchmark utilities.
|
| 2 |
+
|
| 3 |
+
This module lets the environment run a paper-backed experiment plan, then
|
| 4 |
+
compare the resulting simulated findings against curated expected findings
|
| 5 |
+
from the literature.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import argparse
|
| 11 |
+
import json
|
| 12 |
+
import re
|
| 13 |
+
from dataclasses import asdict, dataclass, field
|
| 14 |
+
from importlib.metadata import PackageNotFoundError, version
|
| 15 |
+
from typing import Any, Dict, List, Optional, Sequence
|
| 16 |
+
|
| 17 |
+
from models import (
|
| 18 |
+
ActionType,
|
| 19 |
+
ConclusionClaim,
|
| 20 |
+
ExperimentAction,
|
| 21 |
+
ExperimentObservation,
|
| 22 |
+
OutputType,
|
| 23 |
+
TaskSpec,
|
| 24 |
+
)
|
| 25 |
+
from server.hackathon_environment import BioExperimentEnvironment
|
| 26 |
+
from server.tasks.scenarios import SCENARIO_LIBRARY, Scenario
|
| 27 |
+
|
| 28 |
+
# Word-like tokens, including gene symbols with "+"/"-"/"_" characters.
TOKEN_RE = re.compile(r"[A-Za-z0-9_+\-]+")
# Common English function words excluded from token-overlap scoring.
STOPWORDS = {
    "a",
    "an",
    "and",
    "as",
    "by",
    "for",
    "from",
    "in",
    "into",
    "of",
    "on",
    "or",
    "the",
    "to",
    "using",
    "with",
}

# Maps the library alias used in tool specs to its installable
# distribution name (queried via importlib.metadata).
BIO_LIBRARY_DISTRIBUTIONS = {
    "scanpy": "scanpy",
    "gseapy": "gseapy",
    "biopython": "biopython",
}
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@dataclass
class PaperBenchmarkResult:
    """Outcome of one literature-backed benchmark episode.

    Captures which paper findings the simulated run reproduced, what it
    discovered along the way, and the episode's final reward/length.
    """

    scenario_name: str
    problem_statement: str
    matched_papers: List[str]
    bio_library_versions: Dict[str, Optional[str]]
    matched_findings: List[str] = field(default_factory=list)
    missed_findings: List[str] = field(default_factory=list)
    discovered_markers: List[str] = field(default_factory=list)
    candidate_mechanisms: List[str] = field(default_factory=list)
    conclusions: List[str] = field(default_factory=list)
    final_reward: float = 0.0
    total_steps: int = 0

    @property
    def match_ratio(self) -> float:
        """Fraction of expected findings reproduced (0.0 when none exist)."""
        n_matched = len(self.matched_findings)
        n_total = n_matched + len(self.missed_findings)
        if n_total == 0:
            return 0.0
        return n_matched / n_total

    def to_dict(self) -> Dict[str, Any]:
        """Serialise to a plain dict, including the derived match ratio."""
        return {**asdict(self), "match_ratio": self.match_ratio}
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def detect_bio_library_versions() -> Dict[str, Optional[str]]:
    """Report installed versions of the bio analysis libraries.

    Returns a mapping of library alias to its version string, or ``None``
    when the corresponding distribution is not installed.
    """

    def _probe(dist_name: str) -> Optional[str]:
        try:
            return version(dist_name)
        except PackageNotFoundError:
            return None

    return {
        alias: _probe(dist)
        for alias, dist in BIO_LIBRARY_DISTRIBUTIONS.items()
    }
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def select_literature_scenario(problem_statement: str) -> Scenario:
    """Pick the closest literature-backed scenario for a prompt.

    Scores each paper-referenced scenario by token overlap between the
    prompt and the scenario's text (problem statement, paper titles,
    expected findings, tissue, modality, conditions), with a bonus when
    the scenario's problem statement appears verbatim in the prompt.
    """

    prompt_tokens = set(_tokenize(problem_statement))
    best: Optional[Scenario] = None
    best_score = -1

    for scenario in SCENARIO_LIBRARY:
        task = scenario.task
        # Only scenarios grounded in actual papers are eligible.
        if not task.paper_references:
            continue

        corpus_parts = [
            task.problem_statement,
            *(ref.title for ref in task.paper_references),
            *(finding.finding for finding in task.expected_findings),
            task.tissue,
            task.modality,
            *task.conditions,
        ]
        overlap = len(prompt_tokens & set(_tokenize(" ".join(corpus_parts))))
        if task.problem_statement.lower() in problem_statement.lower():
            overlap += 4  # substring hit is a strong signal

        if overlap > best_score:
            best_score = overlap
            best = scenario

    if best is None:
        raise ValueError("No literature-backed scenarios are available.")
    return best
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def run_paper_benchmark(
    *,
    problem_statement: str,
    scenario_name: Optional[str] = None,
    domain_randomise: bool = False,
) -> PaperBenchmarkResult:
    """Run a literature-backed episode and compare outputs to paper results.

    Args:
        problem_statement: Free-text research question used to select a
            scenario (when ``scenario_name`` is not given).
        scenario_name: Optional explicit scenario override.
        domain_randomise: Forwarded to the environment constructor.

    Returns:
        A ``PaperBenchmarkResult`` summarising matched/missed findings,
        discovered markers, conclusions, final reward, and step count.
    """

    scenario = _resolve_scenario(problem_statement, scenario_name)
    env = BioExperimentEnvironment(
        scenario_name=scenario.name,
        domain_randomise=domain_randomise,
    )
    obs = env.reset()

    # Execute the scripted, modality-aligned analysis plan step by step.
    for action in build_paper_aligned_actions(obs.task):
        obs = env.step(action)

    # Derive claims from accumulated evidence, then close the episode
    # with a single SYNTHESIZE_CONCLUSION action carrying those claims.
    claims = infer_conclusion_claims(obs)
    obs = env.step(
        ExperimentAction(
            action_type=ActionType.SYNTHESIZE_CONCLUSION,
            parameters={"claims": [claim.model_dump() for claim in claims]},
            justification=(
                "Summarize the simulated experimental evidence and compare it "
                "with the paper-backed expected findings."
            ),
            confidence=0.8,
            tool_call_spec=_tool_context(
                obs.task,
                libraries=["biopython"],
                include_expected_findings=True,
            ),
        )
    )

    # Score the run: which curated expected findings were (not) reproduced.
    matched, missed = compare_expected_findings(obs.task, obs)
    return PaperBenchmarkResult(
        scenario_name=scenario.name,
        problem_statement=obs.task.problem_statement,
        matched_papers=[ref.title for ref in obs.task.paper_references],
        bio_library_versions=detect_bio_library_versions(),
        matched_findings=matched,
        missed_findings=missed,
        discovered_markers=list(obs.discovered_markers),
        candidate_mechanisms=list(obs.candidate_mechanisms),
        conclusions=[c.claim for c in obs.conclusions],
        final_reward=float(obs.metadata.get("cumulative_reward", 0.0)),
        total_steps=obs.step_index,
    )
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def build_paper_aligned_actions(task: TaskSpec) -> List[ExperimentAction]:
    """Construct a pragmatic analysis plan aligned to the task modality.

    Always starts with a common single-cell pipeline (collect → prep →
    sequence → QC → filter → normalize → cluster). Then, depending on the
    task's expected-finding categories, appends either a trajectory-focused
    branch or the default differential-expression branch.
    """

    # Shared upstream pipeline executed for every task.
    actions: List[ExperimentAction] = [
        ExperimentAction(
            action_type=ActionType.COLLECT_SAMPLE,
            parameters={"n_samples": 8},
            justification="Collect enough samples to support downstream analysis.",
            confidence=0.75,
            tool_call_spec=_tool_context(task, libraries=["biopython"]),
        ),
        ExperimentAction(
            action_type=ActionType.PREPARE_LIBRARY,
            method="10x_chromium",
            justification="Use a standard single-cell library prep workflow.",
            confidence=0.8,
            tool_call_spec=_tool_context(task, libraries=["scanpy"]),
        ),
        ExperimentAction(
            action_type=ActionType.SEQUENCE_CELLS,
            method="NovaSeq",
            justification="Generate sufficient single-cell read depth.",
            confidence=0.8,
            tool_call_spec=_tool_context(task, libraries=["scanpy"]),
        ),
        ExperimentAction(
            action_type=ActionType.RUN_QC,
            method="scanpy.pp.calculate_qc_metrics",
            justification="Check technical quality before downstream inference.",
            confidence=0.85,
            tool_call_spec=_tool_context(task, libraries=["scanpy"]),
        ),
        ExperimentAction(
            action_type=ActionType.FILTER_DATA,
            method="scanpy.pp.filter_cells",
            justification="Remove low-quality cells and reduce technical noise.",
            confidence=0.85,
            tool_call_spec=_tool_context(task, libraries=["scanpy"]),
        ),
        ExperimentAction(
            action_type=ActionType.NORMALIZE_DATA,
            method="scanpy.pp.normalize_total",
            justification="Normalize expression to prepare comparable profiles.",
            confidence=0.85,
            tool_call_spec=_tool_context(task, libraries=["scanpy"]),
        ),
        ExperimentAction(
            action_type=ActionType.CLUSTER_CELLS,
            method="scanpy.tl.leiden",
            justification="Resolve cell states before focused interpretation.",
            confidence=0.8,
            tool_call_spec=_tool_context(task, libraries=["scanpy"]),
        ),
    ]

    # Trajectory branch: pseudotime, regulators, then lineage markers.
    categories = {finding.category for finding in task.expected_findings}
    if "trajectory" in categories:
        actions.extend([
            ExperimentAction(
                action_type=ActionType.TRAJECTORY_ANALYSIS,
                method="scanpy.tl.dpt",
                justification="Recover pseudotime structure and lineage branches.",
                confidence=0.8,
                tool_call_spec=_tool_context(task, libraries=["scanpy"]),
            ),
            ExperimentAction(
                action_type=ActionType.REGULATORY_NETWORK_INFERENCE,
                method="pySCENIC",
                justification="Infer branch-associated regulators from the trajectory.",
                confidence=0.75,
                tool_call_spec=_tool_context(task, libraries=["scanpy"]),
            ),
            ExperimentAction(
                action_type=ActionType.MARKER_SELECTION,
                method="scanpy.tl.rank_genes_groups",
                justification="Summarize lineage markers and branch-state genes.",
                confidence=0.75,
                tool_call_spec=_tool_context(task, libraries=["scanpy"]),
            ),
        ])
        return actions

    # Default branch: DE → pathways → marker nomination → validation.
    actions.extend([
        ExperimentAction(
            action_type=ActionType.DIFFERENTIAL_EXPRESSION,
            method="scanpy.tl.rank_genes_groups",
            parameters={"comparison": _default_comparison_name(task)},
            justification="Identify genes associated with the focal phenotype.",
            confidence=0.85,
            tool_call_spec=_tool_context(task, libraries=["scanpy"]),
        ),
        ExperimentAction(
            action_type=ActionType.PATHWAY_ENRICHMENT,
            method="gseapy.prerank",
            justification="Translate DE hits into pathway-level interpretation.",
            confidence=0.8,
            tool_call_spec=_tool_context(task, libraries=["gseapy"]),
        ),
        ExperimentAction(
            action_type=ActionType.MARKER_SELECTION,
            method="scanpy.tl.rank_genes_groups",
            justification="Nominate candidate markers for follow-up validation.",
            confidence=0.8,
            tool_call_spec=_tool_context(task, libraries=["scanpy"]),
        ),
        ExperimentAction(
            action_type=ActionType.VALIDATE_MARKER,
            method="immunofluorescence",
            parameters={"marker": _preferred_marker(task)},
            justification="Check whether the leading marker reproduces in validation.",
            confidence=0.75,
            tool_call_spec=_tool_context(task, libraries=["biopython"]),
        ),
    ])
    return actions
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def infer_conclusion_claims(obs: ExperimentObservation) -> List[ConclusionClaim]:
    """Turn accumulated evidence into concise, paper-comparable claims.

    Scans the episode's discovered markers, candidate mechanisms, network
    regulators, and the latest trajectory output, emitting one hard-coded
    ``ConclusionClaim`` per recognised signal (SPP1/MERTK markers, the ECM
    pathway, HSC branching, GATA1/CEBPA/SPI1 regulators). Confidence values
    are fixed heuristic constants, not computed from the evidence. When no
    signal is recognised, falls back to echoing the task's top expected
    findings verbatim.
    """

    markers = set(obs.discovered_markers)
    mechanisms = set(obs.candidate_mechanisms)
    network_regulators = set(_extract_network_regulators(obs))
    # Data payload of the most recent trajectory output ({} when none exists).
    trajectory_output = _latest_output_data(obs, OutputType.TRAJECTORY_RESULT)

    claims: List[ConclusionClaim] = []

    # ── IPF / fibrosis marker claims ─────────────────────────────────────
    if "SPP1" in markers:
        claims.append(ConclusionClaim(
            claim="SPP1-positive macrophages are enriched in IPF fibrotic tissue.",
            confidence=0.84,
            claim_type="marker",
            # Cite every DE / marker / validation output as supporting evidence.
            evidence_steps=_evidence_steps(obs, {
                OutputType.DE_RESULT,
                OutputType.MARKER_RESULT,
                OutputType.VALIDATION_RESULT,
            }),
        ))
    if {"SPP1", "MERTK"} <= markers:
        claims.append(ConclusionClaim(
            claim="MERTK co-occurs with the SPP1-positive profibrotic macrophage state.",
            confidence=0.8,
            claim_type="marker",
            evidence_steps=_evidence_steps(obs, {
                OutputType.DE_RESULT,
                OutputType.MARKER_RESULT,
            }),
        ))
    if "extracellular_matrix_organisation" in mechanisms:
        claims.append(ConclusionClaim(
            claim=(
                "Extracellular matrix organization is a dominant fibrotic "
                "program in the IPF samples."
            ),
            confidence=0.78,
            claim_type="pathway",
            evidence_steps=_evidence_steps(obs, {OutputType.PATHWAY_RESULT}),
        ))

    # ── haematopoiesis trajectory / regulatory-network claims ────────────
    if trajectory_output.get("branching_detected"):
        claims.append(ConclusionClaim(
            claim=(
                "Trajectory analysis recovered branching blood lineages rooted "
                "in HSCs."
            ),
            confidence=0.82,
            claim_type="trajectory",
            evidence_steps=_evidence_steps(obs, {OutputType.TRAJECTORY_RESULT}),
        ))
    if "GATA1" in network_regulators:
        claims.append(ConclusionClaim(
            claim="GATA1 emerges as a driver of erythroid fate commitment.",
            confidence=0.8,
            claim_type="regulatory_network",
            evidence_steps=_evidence_steps(obs, {OutputType.NETWORK_RESULT}),
        ))
    # Either regulator alone is enough to support the myeloid-branch claim.
    if {"CEBPA", "SPI1"} & network_regulators:
        claims.append(ConclusionClaim(
            claim="CEBPA and SPI1 support myeloid branch decisions.",
            confidence=0.78,
            claim_type="regulatory_network",
            evidence_steps=_evidence_steps(obs, {OutputType.NETWORK_RESULT}),
        ))

    if claims:
        return claims

    # Fallback: preserve the strongest expected findings verbatim if the
    # heuristic extractors do not recover enough signal from the episode.
    return [
        ConclusionClaim(
            claim=finding.finding,
            confidence=0.65,
            claim_type=finding.category,
        )
        for finding in obs.task.expected_findings[:3]
    ]
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
def compare_expected_findings(
    task: TaskSpec,
    obs: ExperimentObservation,
) -> tuple[List[str], List[str]]:
    """Split the task's expected findings into matched vs. missed.

    A finding counts as matched when at least half (rounded up, floor one)
    of its keywords appear as substrings of the episode's evidence text.
    Findings without explicit keywords are matched on their tokenized text.
    """

    corpus = _evidence_text(obs)
    matched: List[str] = []
    missed: List[str] = []

    for expected in task.expected_findings:
        terms = [term.lower() for term in expected.keywords] or _tokenize(expected.finding)
        needed = max(1, (len(terms) + 1) // 2)
        found = sum(1 for term in terms if term in corpus)
        bucket = matched if found >= needed else missed
        bucket.append(expected.finding)

    return matched, missed
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
def _resolve_scenario(
    problem_statement: str,
    scenario_name: Optional[str],
) -> Scenario:
    """Look up a scenario by explicit name, or infer one from the problem text.

    Raises:
        ValueError: if ``scenario_name`` is given but not in the library.
    """
    if not scenario_name:
        return select_literature_scenario(problem_statement)
    chosen = next(
        (entry for entry in SCENARIO_LIBRARY if entry.name == scenario_name),
        None,
    )
    if chosen is None:
        raise ValueError(f"Unknown scenario_name '{scenario_name}'.")
    return chosen
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def _tool_context(
|
| 408 |
+
task: TaskSpec,
|
| 409 |
+
*,
|
| 410 |
+
libraries: Sequence[str],
|
| 411 |
+
include_expected_findings: bool = False,
|
| 412 |
+
) -> Dict[str, Any]:
|
| 413 |
+
context: Dict[str, Any] = {
|
| 414 |
+
"literature_query": task.problem_statement,
|
| 415 |
+
"paper_references": [
|
| 416 |
+
{
|
| 417 |
+
"title": ref.title,
|
| 418 |
+
"doi": ref.doi,
|
| 419 |
+
"pmid": ref.pmid,
|
| 420 |
+
"url": ref.url,
|
| 421 |
+
}
|
| 422 |
+
for ref in task.paper_references
|
| 423 |
+
],
|
| 424 |
+
"bioinformatics_libraries": list(libraries),
|
| 425 |
+
}
|
| 426 |
+
if include_expected_findings:
|
| 427 |
+
context["expected_findings"] = [
|
| 428 |
+
finding.finding for finding in task.expected_findings
|
| 429 |
+
]
|
| 430 |
+
return context
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
def _default_comparison_name(task: TaskSpec) -> str:
|
| 434 |
+
conditions = {condition.lower() for condition in task.conditions}
|
| 435 |
+
if {"healthy", "ipf"} <= conditions:
|
| 436 |
+
return "IPF_vs_healthy"
|
| 437 |
+
if any("treated" in condition for condition in conditions) and any(
|
| 438 |
+
"untreated" in condition for condition in conditions
|
| 439 |
+
):
|
| 440 |
+
return "treated_vs_untreated"
|
| 441 |
+
if any("healthy" in condition for condition in conditions):
|
| 442 |
+
return "disease_vs_healthy"
|
| 443 |
+
return "disease_vs_healthy"
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
def _preferred_marker(task: TaskSpec) -> str:
|
| 447 |
+
for finding in task.expected_findings:
|
| 448 |
+
for keyword in finding.keywords:
|
| 449 |
+
if keyword.isupper():
|
| 450 |
+
return keyword
|
| 451 |
+
return "SPP1"
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
def _latest_output_data(
|
| 455 |
+
obs: ExperimentObservation,
|
| 456 |
+
output_type: OutputType,
|
| 457 |
+
) -> Dict[str, Any]:
|
| 458 |
+
for output in reversed(obs.all_outputs):
|
| 459 |
+
if output.output_type == output_type:
|
| 460 |
+
return output.data
|
| 461 |
+
return {}
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
def _extract_network_regulators(obs: ExperimentObservation) -> List[str]:
    """Return ``top_regulators`` from the newest network output, else []."""
    regulators = next(
        (
            entry.data.get("top_regulators", [])
            for entry in reversed(obs.all_outputs)
            if entry.output_type == OutputType.NETWORK_RESULT
        ),
        [],
    )
    return regulators
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
def _evidence_steps(
|
| 472 |
+
obs: ExperimentObservation,
|
| 473 |
+
output_types: set[OutputType],
|
| 474 |
+
) -> List[int]:
|
| 475 |
+
return [
|
| 476 |
+
output.step_index
|
| 477 |
+
for output in obs.all_outputs
|
| 478 |
+
if output.output_type in output_types
|
| 479 |
+
]
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
def _evidence_text(obs: ExperimentObservation) -> str:
    """Flatten every piece of episode evidence into one lowercase string.

    Used as the substring-search corpus for keyword matching: markers,
    mechanisms, conclusion claims, output summaries, and per-output-type
    payload fields (genes, pathways, regulators, trajectory branching).
    """
    fragments: List[str] = list(obs.discovered_markers)
    fragments += obs.candidate_mechanisms
    fragments += [conclusion.claim for conclusion in obs.conclusions]

    for record in obs.all_outputs:
        fragments.append(record.summary)
        kind = record.output_type
        if kind == OutputType.DE_RESULT:
            fragments += [
                entry["gene"]
                for entry in record.data.get("top_genes", [])
                if isinstance(entry, dict) and "gene" in entry
            ]
        elif kind == OutputType.PATHWAY_RESULT:
            fragments += [
                entry["pathway"]
                for entry in record.data.get("top_pathways", [])
                if isinstance(entry, dict) and "pathway" in entry
            ]
        elif kind == OutputType.NETWORK_RESULT:
            fragments += record.data.get("top_regulators", [])
        elif kind == OutputType.TRAJECTORY_RESULT and record.data.get("branching_detected"):
            # Branching is encoded as a fixed phrase so keyword matching can hit it.
            fragments.append("branching lineage HSC trajectory")

    return " ".join(fragments).lower()
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
def _tokenize(text: str) -> List[str]:
    """Lowercased word tokens from ``text``, with stopwords removed."""
    tokens: List[str] = []
    for raw in TOKEN_RE.findall(text):
        lowered = raw.lower()
        if raw and lowered not in STOPWORDS:
            tokens.append(lowered)
    return tokens
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
def main() -> None:
    """CLI entry point: run one paper benchmark and print (or JSON-dump) it."""
    cli = argparse.ArgumentParser()
    default_problem = (
        "Design a follow-up validation experiment for candidate biomarker "
        "SPP1 in idiopathic pulmonary fibrosis."
    )
    cli.add_argument("--problem-statement", default=default_problem)
    cli.add_argument("--scenario-name", default=None)
    cli.add_argument("--domain-randomise", action="store_true")
    cli.add_argument("--json", action="store_true")
    opts = cli.parse_args()

    result = run_paper_benchmark(
        problem_statement=opts.problem_statement,
        scenario_name=opts.scenario_name,
        domain_randomise=opts.domain_randomise,
    )

    if opts.json:
        # Machine-readable mode: emit the full result and skip the summary.
        print(json.dumps(result.to_dict(), indent=2))
        return

    print(f"Scenario: {result.scenario_name}")
    print(f"Problem: {result.problem_statement}")
    print(f"Paper: {', '.join(result.matched_papers)}")
    print(f"Match ratio: {result.match_ratio:.2%}")
    print(f"Matched findings: {len(result.matched_findings)}")
    print(f"Missed findings: {len(result.missed_findings)}")
    print(f"Discovered markers: {', '.join(result.discovered_markers[:8])}")
    print(f"Candidate mechanisms: {', '.join(result.candidate_mechanisms[:5])}")
    print(f"Conclusions: {len(result.conclusions)}")
    print(f"Final reward: {result.final_reward:+.3f}")
    print(f"Bio libraries: {json.dumps(result.bio_library_versions, sort_keys=True)}")
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
# Script entry point: run the benchmark CLI defined in main().
if __name__ == "__main__":
    main()
|
training/trajectory.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Trajectory serialisation and dataset utilities.
|
| 2 |
+
|
| 3 |
+
A ``Trajectory`` stores the full history of one episode (task, actions,
|
| 4 |
+
observations, rewards, latent-state snapshots) in a format that supports:
|
| 5 |
+
- offline RL training
|
| 6 |
+
- imitation learning from expert demonstrations
|
| 7 |
+
- evaluation / replay
|
| 8 |
+
- simulator calibration
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import json
|
| 14 |
+
from dataclasses import dataclass, field
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from typing import Any, Dict, List, Optional
|
| 17 |
+
|
| 18 |
+
from models import (
|
| 19 |
+
ExperimentAction,
|
| 20 |
+
ExperimentObservation,
|
| 21 |
+
TaskSpec,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
class TrajectoryStep:
    """One (action, observation, reward) transition within an episode."""

    step_index: int  # 0-based position within the episode
    action: Dict[str, Any]  # serialized action (from ``model_dump()``)
    observation: Dict[str, Any]  # serialized observation (from ``model_dump()``)
    reward: float
    done: bool
    reward_breakdown: Dict[str, float] = field(default_factory=dict)
    latent_snapshot: Optional[Dict[str, Any]] = None


@dataclass
class Trajectory:
    """Complete record of one environment episode.

    Accumulates ``TrajectoryStep`` entries via :meth:`add_step` and supports
    round-tripping to/from JSON for offline training, replay, and evaluation.
    """

    episode_id: str
    task: Dict[str, Any]
    steps: List[TrajectoryStep] = field(default_factory=list)
    total_reward: float = 0.0
    success: bool = False
    metadata: Dict[str, Any] = field(default_factory=dict)

    # ── construction helpers ────────────────────────────────────────────

    def add_step(
        self,
        action: ExperimentAction,
        observation: ExperimentObservation,
        reward: float,
        done: bool,
        reward_breakdown: Optional[Dict[str, float]] = None,
        latent_snapshot: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Append one transition and update the running totals.

        Args:
            action: The action taken (serialized via ``model_dump()``).
            observation: The resulting observation (serialized likewise).
            reward: Scalar reward for this step; added to ``total_reward``.
            done: Whether the episode terminated on this step.
            reward_breakdown: Optional per-component reward decomposition.
            latent_snapshot: Optional simulator latent-state snapshot.
        """
        self.steps.append(TrajectoryStep(
            step_index=len(self.steps),
            action=action.model_dump(),
            observation=observation.model_dump(),
            reward=reward,
            done=done,
            reward_breakdown=reward_breakdown or {},
            latent_snapshot=latent_snapshot,
        ))
        self.total_reward += reward
        if done:
            # Episode success is judged by the sign of the FINAL step reward
            # only, not the cumulative total.
            self.success = reward > 0

    # ── serialisation ───────────────────────────────────────────────────

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict representation suitable for ``json.dump``."""
        return {
            "episode_id": self.episode_id,
            "task": self.task,
            "steps": [
                {
                    "step_index": s.step_index,
                    "action": s.action,
                    "observation": s.observation,
                    "reward": s.reward,
                    "done": s.done,
                    "reward_breakdown": s.reward_breakdown,
                    "latent_snapshot": s.latent_snapshot,
                }
                for s in self.steps
            ],
            "total_reward": self.total_reward,
            "success": self.success,
            "metadata": self.metadata,
        }

    def save(self, path: str | Path) -> None:
        """Write the trajectory as JSON, creating parent directories as needed."""
        p = Path(path)
        p.parent.mkdir(parents=True, exist_ok=True)
        # Fix: pin the encoding so trajectory files are portable — the default
        # is locale-dependent and can corrupt non-ASCII payloads on Windows.
        with open(p, "w", encoding="utf-8") as f:
            # ``default=str`` stringifies non-JSON values (e.g. datetimes)
            # rather than failing the dump.
            json.dump(self.to_dict(), f, indent=2, default=str)

    @classmethod
    def load(cls, path: str | Path) -> "Trajectory":
        """Reconstruct a trajectory previously written by :meth:`save`."""
        with open(path, encoding="utf-8") as f:
            d = json.load(f)
        traj = cls(
            episode_id=d["episode_id"],
            task=d["task"],
            total_reward=d.get("total_reward", 0.0),
            success=d.get("success", False),
            metadata=d.get("metadata", {}),
        )
        for s in d.get("steps", []):
            traj.steps.append(TrajectoryStep(**s))
        return traj
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class TrajectoryDataset:
    """In-memory collection of trajectories with convenience accessors."""

    def __init__(self, trajectories: Optional[List[Trajectory]] = None):
        self.trajectories: List[Trajectory] = trajectories or []

    def add(self, traj: Trajectory) -> None:
        """Append one trajectory to the collection."""
        self.trajectories.append(traj)

    def __len__(self) -> int:
        return len(self.trajectories)

    def __getitem__(self, idx: int) -> Trajectory:
        return self.trajectories[idx]

    def filter_successful(self) -> "TrajectoryDataset":
        """Return a new dataset containing only the successful episodes."""
        winners = [traj for traj in self.trajectories if traj.success]
        return TrajectoryDataset(winners)

    def save_dir(self, directory: str | Path) -> None:
        """Write each trajectory to ``<directory>/<episode_id>.json``."""
        target = Path(directory)
        target.mkdir(parents=True, exist_ok=True)
        for traj in self.trajectories:
            traj.save(target / f"{traj.episode_id}.json")

    @classmethod
    def load_dir(cls, directory: str | Path) -> "TrajectoryDataset":
        """Load every ``*.json`` trajectory in ``directory``, sorted by name."""
        files = sorted(Path(directory).glob("*.json"))
        return cls([Trajectory.load(path) for path in files])

    def summary(self) -> Dict[str, Any]:
        """Aggregate statistics; returns ``{"n": 0}`` for an empty dataset."""
        if not self.trajectories:
            return {"n": 0}
        count = len(self.trajectories)
        rewards = [traj.total_reward for traj in self.trajectories]
        episode_lengths = [len(traj.steps) for traj in self.trajectories]
        wins = sum(1 for traj in self.trajectories if traj.success)
        return {
            "n": count,
            "success_rate": wins / count,
            "mean_reward": sum(rewards) / len(rewards),
            "mean_length": sum(episode_lengths) / len(episode_lengths),
            "max_reward": max(rewards),
            "min_reward": min(rewards),
        }
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|