Upload folder using huggingface_hub
- H100_JUPYTER_SETUP.md +199 -0
- README.md +24 -1
- _dashboard_state.json +234 -91
- dashboard.html +26 -5
- debug-904eee.log +10 -0
- inference.ipynb +166 -0
- pyproject.toml +1 -0
- run_agent.py +138 -26
- run_agent_unsloth.py +294 -0
- server/hackathon_environment.py +2 -2
- server/rewards/reward.py +51 -2
- server/rules/engine.py +24 -0
- tests/test_environment.py +33 -1
- tests/test_rewards.py +50 -1
- tests/test_rules.py +35 -14
- tests/test_run_agent.py +59 -4
- tests/test_training_script.py +10 -0
- train.ipynb +141 -0
- training_script.py +243 -22
- training_unsloth.py +439 -0
- uv.lock +0 -0
H100_JUPYTER_SETUP.md
ADDED
@@ -0,0 +1,199 @@
# H100 Jupyter Notebook Setup

This guide walks you through setting up the OpenEnv Bio Experiment environment on an **NVIDIA H100** Jupyter notebook instance (e.g., JupyterLab, Lambda Labs, RunPod, or similar).

## Prerequisites

- **Python** ≥ 3.10 (3.10, 3.11, or 3.12 recommended)
- **uv** – fast Python package manager ([install instructions](#installing-uv))
- **NVIDIA driver** ≥ 535.104.05 (usually pre-installed on H100 instances)
- **CUDA** – H100 uses CUDA 12.x; PyTorch wheels bundle the runtime, so a separate CUDA Toolkit is not required
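
The driver floor above can be checked programmatically. A minimal sketch (the `meets_min_driver` helper is ours, not part of the repo; the version string would come from `nvidia-smi --query-gpu=driver_version --format=csv,noheader`):

```python
def meets_min_driver(version: str, minimum: str = "535.104.05") -> bool:
    """Compare dotted NVIDIA driver versions numerically, not lexically."""
    as_tuple = lambda v: tuple(int(part) for part in v.split("."))
    return as_tuple(version) >= as_tuple(minimum)

print(meets_min_driver("550.54.15"))  # newer driver: passes
print(meets_min_driver("535.86.10"))  # 86 < 104 in the second field: fails
```

Numeric tuple comparison matters here: a plain string comparison would rank "535.86" above "535.104".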

## Installing uv

If `uv` is not already installed:

```bash
# Unix/Linux (including Jupyter notebook terminals)
curl -LsSf https://astral.sh/uv/install.sh | sh

# Or with pip
pip install uv
```

Verify:

```bash
uv --version
```

## Quick Setup (Recommended)

### 1. Clone and enter the project

```bash
git clone <repository-url> OpenENV-Hackathon
cd OpenENV-Hackathon
```

### 2. Use uv's auto PyTorch backend

uv can detect your GPU and pick the right PyTorch build. For H100 (CUDA 12.x):

```bash
# Install everything: core + training (TRL, transformers, torch, unsloth) + Jupyter
UV_TORCH_BACKEND=cu128 uv sync --extra train

# Add Jupyter kernel support
uv add ipykernel jupyter --extra train
```

If `UV_TORCH_BACKEND=cu128` fails (e.g., cu128 wheels not available yet), try:

```bash
UV_TORCH_BACKEND=cu126 uv sync --extra train
```

### 3. Register the environment as a Jupyter kernel

```bash
uv run python -m ipykernel install --user --name openenv-bio --display-name "OpenEnv Bio (H100)"
```

### 4. Verify CUDA

In a new Jupyter notebook, select the **"OpenEnv Bio (H100)"** kernel and run:

```python
import torch
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'N/A'}")
```

Expected output (or similar):

```
PyTorch: 2.x.x+cu128
CUDA available: True
GPU: NVIDIA H100 ...
```

### 5. Sanity check the environment

```bash
uv run pytest tests/test_environment.py tests/test_literature_benchmark.py -q
```

## Manual PyTorch CUDA Configuration

If you need explicit control over the PyTorch index (e.g., for reproducibility), add the following to `pyproject.toml`:

### Add to `pyproject.toml`

```toml
# After [tool.uv], add:

[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true

[tool.uv.sources]
torch = [{ index = "pytorch-cu128" }]
torchvision = [{ index = "pytorch-cu128" }]
```

Then run:

```bash
uv sync --extra train
uv add ipykernel jupyter --extra train
```

For CUDA 12.6 instead of 12.8, use `cu126` in the index URL and source names.
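
Spelled out, the cu126 variant of that fragment looks like this (a sketch; only the index name and URL change):

```toml
[[tool.uv.index]]
name = "pytorch-cu126"
url = "https://download.pytorch.org/whl/cu126"
explicit = true

[tool.uv.sources]
torch = [{ index = "pytorch-cu126" }]
torchvision = [{ index = "pytorch-cu126" }]
```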

## Dependency Groups

| uv sync flag | Contents |
|-------------------|--------------------------------------------------------------------------|
| *(default)* | Core: `openenv-core`, `numpy`, `scipy`, `pydantic` |
| `--extra dev` | Testing: `pytest`, `pytest-cov` |
| `--extra train` | Training: `torch`, `transformers`, `trl`, `accelerate`, `peft`, `unsloth`, etc. |
| `--extra bio` | Bioinformatics: `scanpy`, `biopython`, `gseapy` |
| `--extra train --extra dev` | Combined for development + training |

## Preferred H100 Workflow

On H100, use the quantized Unsloth entrypoints:

```bash
uv run python training_unsloth.py --dry-run
uv run python training_unsloth.py --model-id Qwen/Qwen3.5-4B --output-dir training/grpo-unsloth-output
uv run python run_agent_unsloth.py
```

The checked-in `inference.ipynb` notebook now uses `training_unsloth.py` helpers with 4-bit loading and fast inference enabled by default.

## Running Training in a Jupyter Notebook

Example cell:

```python
# In a notebook with the OpenEnv Bio (H100) kernel
!uv run python training_unsloth.py --model-id Qwen/Qwen3.5-4B --dry-run
```

Or run interactively from Python:

```python
import subprocess

subprocess.run([
    "uv", "run", "python", "training_unsloth.py",
    "--model-id", "Qwen/Qwen3.5-4B",
    "--output-dir", "training/grpo-unsloth-output",
], check=True)
```
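
For longer experiments it can help to assemble the argument list in one place. A minimal sketch (the `build_train_cmd` helper is ours, not part of the repo; it only mirrors flag names documented here):

```python
import shlex

def build_train_cmd(model_id: str, output_dir: str,
                    dry_run: bool = False, **flags) -> list[str]:
    """Assemble a uv invocation of training_unsloth.py from Python options."""
    cmd = ["uv", "run", "python", "training_unsloth.py",
           "--model-id", model_id, "--output-dir", output_dir]
    if dry_run:
        cmd.append("--dry-run")
    for name, value in flags.items():
        # e.g. num_generations=2 becomes "--num-generations 2"
        cmd += [f"--{name.replace('_', '-')}", str(value)]
    return cmd

cmd = build_train_cmd("Qwen/Qwen3.5-4B", "training/grpo-unsloth-output",
                      num_generations=2)
print(shlex.join(cmd))
```

Passing keyword options keeps notebook cells short and makes sweeps over settings such as `num_generations` trivial.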

## Requirements Summary

| Component | Version / Notes |
|----------------|------------------------------------------------------|
| Python | ≥ 3.10 |
| uv | ≥ 0.5.3 (for PyTorch index support) |
| torch | ≥ 2.10.0 (cu128 or cu126 for H100) |
| transformers | ≥ 5.3.0 |
| trl | ≥ 0.29.0 |
| accelerate | ≥ 1.13.0 |
| Jupyter | Optional, for notebook workflows |

## Troubleshooting

### `torch.cuda.is_available()` is False

- Confirm the Jupyter kernel is the one where you ran `uv sync` (the one with `ipykernel`).
- Ensure no CPU-only PyTorch is overriding the CUDA build (e.g., from a different conda/pip env).
- Run `uv run python -c "import torch; print(torch.__file__)"` to verify PyTorch comes from your project venv.

### Flash Attention / causal-conv fallback warnings

These are common and usually harmless; execution continues with a slower path. For best H100 performance, ensure `transformers` and `torch` are recent versions that support Flash Attention 2.

### HuggingFace symlink warnings

Set:

```bash
export HF_HUB_DISABLE_SYMLINKS_WARNING=1
```
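
The same switch can be set from inside a running notebook, as long as it happens before `huggingface_hub` is first imported (the library reads the variable at import time; `HF_HUB_DISABLE_SYMLINKS_WARNING` is the name given in the warning itself):

```python
import os

# Must run before huggingface_hub is imported anywhere in this process.
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
print(os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"])
```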

### Out-of-memory during training

- Reduce `--num-generations` or `--rollout-steps`.
- Use a smaller model (e.g., `Qwen/Qwen3.5-0.8B`) for experiments.
- Keep `--disable-4bit` off unless you explicitly need full-precision weights.

## See Also

- Main [README.md](README.md) for project overview, APIs, and usage
- [uv PyTorch guide](https://docs.astral.sh/uv/guides/integration/pytorch/) for advanced PyTorch configuration
README.md
CHANGED
@@ -249,6 +249,8 @@ An episode ends when one of the following happens:
 
 Dependencies are managed with `uv`. The package requires Python ≥ 3.10.
 
+> **H100 Jupyter notebook setup:** See [H100_JUPYTER_SETUP.md](H100_JUPYTER_SETUP.md) for environment setup on NVIDIA H100 instances with Jupyter.
+
 ```bash
 # Core environment only
 uv sync
@@ -342,6 +344,13 @@ The environment class supports concurrent sessions, but the bundled server is cu
 uv run python run_agent.py
 ```
 
+For H100 and other large-GPU workflows, prefer the quantized Unsloth path:
+
+```bash
+uv sync --extra train
+uv run python run_agent_unsloth.py
+```
+
 Configuration is via environment variables:
 
 | Variable | Default | Description |
@@ -367,6 +376,14 @@ uv run python training_script.py --dry-run
 uv run python training_script.py --model-id Qwen/Qwen3.5-0.8B
 ```
 
+For H100, the preferred entrypoint is `training_unsloth.py`, which uses Unsloth 4-bit loading plus LoRA for faster quantized GRPO training:
+
+```bash
+uv sync --extra train
+uv run python training_unsloth.py --dry-run
+uv run python training_unsloth.py --model-id Qwen/Qwen3.5-4B
+```
+
 Key arguments:
 
 | Argument | Default | Description |
@@ -381,13 +398,15 @@ Key arguments:
 | `--scenario-name` | all | Repeatable; restricts which scenarios are used |
 | `--domain-randomise` | off | Enable domain randomisation |
 | `--num-generations` | `4` | GRPO generations per prompt |
-| `--max-completion-length` | `
+| `--max-completion-length` | `160` | Max tokens for model completions |
 | `--max-prompt-length` | `768` | Max tokens for prompts |
 | `--learning-rate` | `5e-6` | AdamW learning rate |
 | `--dry-run` | off | Build data and test reward without training |
 
 By default the reward function reconstructs prompt states locally so the prompt and reward stay aligned. Switch to a live server-backed reward loop with `--reward-backend remote --base-url http://localhost:8000`.
 
+`training_unsloth.py` adds H100-oriented options such as `--max-seq-length`, `--disable-4bit`, `--disable-fast-inference`, and LoRA settings (`--lora-r`, `--lora-alpha`, `--lora-dropout`).
+
 After training, the script saves plots to the output directory:
 
 - `training_loss.png`
@@ -413,7 +432,9 @@ This runs N episodes with a `random` or `heuristic` policy, saves JSON trajector
 - `training/literature_benchmark.py` runs paper-aligned action sequences and compares outcomes against curated expected findings
 - `training/rollout_collection.py` collects direct environment rollouts into trajectory files
 - `training_script.py` trains a GRPO policy with OpenEnv reward calls
+- `training_unsloth.py` trains a quantized GRPO policy with Unsloth on H100-class GPUs
 - `run_agent.py` runs a local language model planner against the environment
+- `run_agent_unsloth.py` runs the planner with Unsloth 4-bit loading for faster inference
 - `training/trajectory.py` stores trajectories for offline RL, imitation learning, replay, and evaluation
 - `training/evaluation.py` computes online, benchmark, expert-review, and fidelity-oriented metrics
@@ -488,6 +509,7 @@ That makes it suitable for:
 ├── openenv.yaml              # OpenEnv platform deployment config
 ├── pyproject.toml            # Package metadata and dependency groups
 ├── run_agent.py              # Single-episode interactive agent runner
+├── run_agent_unsloth.py      # Quantized Unsloth interactive agent runner
 ├── server/
 │   ├── app.py                # FastAPI/OpenEnv server entry point
 │   ├── Dockerfile            # Multi-stage Docker build
@@ -512,6 +534,7 @@ That makes it suitable for:
 │   ├── rollout_collection.py # Direct rollout collection helper
 │   └── trajectory.py         # Trajectory serialization and dataset utilities
 ├── training_script.py        # TRL GRPO training entry point
+├── training_unsloth.py       # Unsloth quantized GRPO training entry point
 └── tests/
     ├── test_environment.py
     ├── test_literature_benchmark.py
_dashboard_state.json
CHANGED
@@ -1,40 +1,40 @@
 {
-  "timestamp":
-  "step":
+  "timestamp": 1772965877.7057643,
+  "step": 20,
   "episode_done": false,
-  "cumulative_reward": 1.
-  "gen_time_s":
-  "model_response_raw": "{\n\"action_type\": \"
+  "cumulative_reward": -1.0267290925637838,
+  "gen_time_s": 28.83,
+  "model_response_raw": "{\n\"action_type\": \"marker_selection\",\n\"method\": \"\",\n\"parameters\": {\"\"diff_genes\"\": [],\"gene_lists_file\":\"DE_results.h5ad\"},\n\"Justification\": \"The 'violation' error indicates we cannot proceed directly from previous failed attempts due to incomplete filtering/normalization logic implied by the blockage condition ('without prior analysis'). However, looking strictly at the provided completed vs failing log chain where `design` was followed by `[OK] design`, but then immediate failures occurred (`synthesis_blocked`) which were resolved upon finding/deepening insights? Wait, re-evaluating strict fl",
   "model_thinking": "",
   "thinking_enabled": false,
   "task": {
-    "problem_statement": "
+    "problem_statement": "Validate candidate biomarker TREM2 for Alzheimer's disease in brain tissue using single-cell RNA sequencing.",
     "organism": "human",
-    "tissue": "
+    "tissue": "brain",
     "modality": "scRNA-seq",
     "conditions": [
-      "
-      "
+      "healthy",
+      "Alzheimer"
     ],
-    "budget_limit":
-    "time_limit_days":
+    "budget_limit": 115264.15272286453,
+    "time_limit_days": 111.8768783273548
   },
   "resources": {
-    "budget_used":
-    "budget_remaining":
-    "time_used_days":
-    "time_remaining_days":
+    "budget_used": 28900.0,
+    "budget_remaining": 86364.15,
+    "time_used_days": 18.0,
+    "time_remaining_days": 93.9,
     "samples_consumed": 0,
-    "compute_hours_used":
+    "compute_hours_used": 12.0
   },
   "pipeline_history": [
     {
       "step_index": 1,
       "action_type": "collect_sample",
       "method": null,
-      "output_summary": "Collected 6 samples (quality=0.
+      "output_summary": "Collected 6 samples (quality=0.57)",
       "success": true,
-      "quality_score": 0.
+      "quality_score": 0.575,
       "resource_cost": 5000.0,
       "time_cost_days": 7.0
     },
@@ -62,116 +62,259 @@
       "step_index": 4,
       "action_type": "prepare_library",
       "method": null,
-      "output_summary": "Library prepared (complexity=0.
+      "output_summary": "Library prepared (complexity=0.77)",
       "success": true,
-      "quality_score": 0.
+      "quality_score": 0.77,
       "resource_cost": 8000.0,
       "time_cost_days": 3.0
+    },
+    {
+      "step_index": 5,
+      "action_type": "sequence_cells",
+      "method": null,
+      "output_summary": "Sequenced 12178 cells, 9490 genes detected, median UMI=2966",
+      "success": true,
+      "quality_score": 0.457,
+      "resource_cost": 15000.0,
+      "time_cost_days": 5.0
+    },
+    {
+      "step_index": 6,
+      "action_type": "filter_data",
+      "method": null,
+      "output_summary": "Action blocked: Cannot filter data before QC",
+      "success": false,
+      "quality_score": 1.0,
+      "resource_cost": 50.0,
+      "time_cost_days": 0.2
+    },
+    {
+      "step_index": 7,
+      "action_type": "run_qc",
+      "method": null,
+      "output_summary": "QC metrics computed",
+      "success": true,
+      "quality_score": 0.842,
+      "resource_cost": 100.0,
+      "time_cost_days": 0.5
+    },
+    {
+      "step_index": 8,
+      "action_type": "filter_data",
+      "method": null,
+      "output_summary": "Filtered 12178 \u2192 9967 cells (82% retained)",
+      "success": true,
+      "quality_score": 0.818,
+      "resource_cost": 50.0,
+      "time_cost_days": 0.2
+    },
+    {
+      "step_index": 9,
+      "action_type": "normalize_data",
+      "method": null,
+      "output_summary": "Normalized with log_normalize",
+      "success": true,
+      "quality_score": 1.0,
+      "resource_cost": 50.0,
+      "time_cost_days": 0.2
+    },
+    {
+      "step_index": 10,
+      "action_type": "marker_selection",
+      "method": null,
+      "output_summary": "Action blocked: Cannot select markers without DE results",
+      "success": false,
+      "quality_score": 1.0,
+      "resource_cost": 100.0,
+      "time_cost_days": 0.5
+    },
+    {
+      "step_index": 11,
+      "action_type": "design_followup_experiment",
+      "method": null,
+      "output_summary": "Action blocked: Follow-up design without prior analysis is blocked; complete wet-lab and computational steps first",
+      "success": false,
+      "quality_score": 1.0,
+      "resource_cost": 100.0,
+      "time_cost_days": 0.5
+    },
+    {
+      "step_index": 12,
+      "action_type": "cluster_cells",
+      "method": null,
+      "output_summary": "Found 3 clusters",
+      "success": true,
+      "quality_score": 0.765,
+      "resource_cost": 100.0,
+      "time_cost_days": 0.5
+    },
+    {
+      "step_index": 13,
+      "action_type": "design_followup_experiment",
+      "method": null,
+      "output_summary": "Follow-up experiment design proposed (evidence_signals=1)",
+      "success": true,
+      "quality_score": 0.28,
+      "resource_cost": 100.0,
+      "time_cost_days": 0.5
+    },
+    {
+      "step_index": 14,
+      "action_type": "synthesize_conclusion",
+      "method": null,
+      "output_summary": "Action blocked: Cannot synthesise conclusion before discovering or validating markers; Cannot synthesise conclusion befo",
+      "success": false,
+      "quality_score": 1.0,
+      "resource_cost": 0.0,
+      "time_cost_days": 0.5
+    },
+    {
+      "step_index": 15,
+      "action_type": "marker_selection",
+      "method": null,
+      "output_summary": "Action blocked: Cannot select markers without DE results",
+      "success": false,
+      "quality_score": 1.0,
+      "resource_cost": 100.0,
+      "time_cost_days": 0.5
+    },
+    {
+      "step_index": 16,
+      "action_type": "synthesize_conclusion",
+      "method": null,
+      "output_summary": "Action blocked: Cannot synthesise conclusion before discovering or validating markers; Cannot synthesise conclusion befo",
+      "success": false,
+      "quality_score": 1.0,
+      "resource_cost": 0.0,
+      "time_cost_days": 0.5
     }
   ],
   "current_action": {
-    "action_type": "
+    "action_type": "synthesize_conclusion",
     "method": null,
     "parameters": {
-      "
+      "claims": [
+        {
+          "top_markers": [],
+          "causal_mechanisms": [],
+          "predicted_pathways": {},
+          "confidence": 0.5,
+          "claim_type": "correlational",
+          "claim": "forced terminal conclusion"
+        }
+      ]
     },
-    "justification": "
-    "confidence": 0.
+    "justification": "forced terminal conclusion",
+    "confidence": 0.5
   },
   "latest_output": {
-    "summary": "
-    "success":
-    "quality_score":
+    "summary": "Action blocked: Cannot synthesise conclusion before discovering or validating markers; Cannot synthesise conclusion before inferring pathways or mechanisms",
+    "success": false,
+    "quality_score": 1.0,
     "uncertainty": 0.0,
     "warnings": [],
-    "data_preview":
+    "data_preview": null
   },
   "discovered_markers": [],
   "candidate_mechanisms": [],
-  "rule_violations": [
+  "rule_violations": [
+    "Cannot synthesise conclusion before discovering or validating markers",
+    "Cannot synthesise conclusion before inferring pathways or mechanisms"
+  ],
   "uncertainty_summary": {
-    "avg_uncertainty": 0.
-    "avg_quality": 0.
+    "avg_uncertainty": 0.194,
+    "avg_quality": 0.809
   },
-  "reward_breakdown": {},
+  "reward_breakdown": {
+    "validity": -1.0,
+    "ordering": 0.0,
+    "info_gain": 0.0,
+    "efficiency": 0.0,
+    "novelty": 0.0,
+    "penalty": -1.0,
+    "shaping": 0.0,
+    "terminal": 0.0,
+    "total": -2.0,
+    "hard_violations": 2.0,
+    "term_validity": 0.0,
+    "term_ordering": 0.0,
+    "term_info_gain": 0.0,
+    "term_efficiency": 0.0,
+    "term_novelty": 0.0,
+    "term_penalty": 0.0,
+    "term_shaping": 0.0,
+    "term_terminal": 0.0,
+    "term_total": 0.0
+  },
   "latent": {
     "cell_populations": [
       {
-        "name": "
-        "proportion": 0.
-        "marker_genes": [
-          "COL1A1",
-          "FAP",
-          "THY1"
-        ],
-        "state": "activated"
-      },
-      {
-        "name": "CD4_T_cell",
-        "proportion": 0.179,
+        "name": "excitatory_neuron",
+        "proportion": 0.425,
         "marker_genes": [
-          "
-          "
-          "
+          "SLC17A7",
+          "CAMK2A",
+          "NRGN"
        ],
-        "state": "
+        "state": "stressed"
       },
       {
-        "name": "
-        "proportion": 0.
+        "name": "inhibitory_neuron",
+        "proportion": 0.346,
         "marker_genes": [
-          "
-          "
-          "
+          "GAD1",
+          "GAD2",
+          "SLC32A1"
         ],
-        "state": "
+        "state": "normal"
       },
       {
-        "name": "
-        "proportion": 0.
+        "name": "OPC",
+        "proportion": 0.093,
         "marker_genes": [
-          "
-          "
-          "
+          "PDGFRA",
+          "CSPG4",
+          "OLIG2"
        ],
-        "state": "
+        "state": "progenitor"
       },
       {
         "name": "endothelial",
-        "proportion": 0.
+        "proportion": 0.137,
         "marker_genes": [
-          "
-          "
+          "CLDN5",
+          "FLT1",
+          "PECAM1"
         ],
         "state": "quiescent"
       }
     ],
     "true_markers": [
-      "
-      "
-      "
-      "CXCL13"
+      "TREM2",
+      "APOE",
+      "GFAP"
     ],
     "causal_mechanisms": [
-      "
-      "
+      "TREM2-mediated microglial activation in amyloid clearance",
+      "complement-driven synaptic pruning",
+      "reactive astrogliosis amplifying neuroinflammation"
     ],
     "true_pathways": {
-      "
-      "
-      "
-      "
-      "
-    },
-    "true_de_genes_count":
-    "true_regulatory_network_size":
+      "complement_cascade": 0.827,
+      "neuroinflammation": 0.666,
+      "amyloid_processing": 0.733,
+      "synaptic_signalling": 0.438,
+      "lipid_metabolism": 0.616
+    },
+    "true_de_genes_count": 10,
+    "true_regulatory_network_size": 0,
     "confounders": {},
-    "n_true_cells":
+    "n_true_cells": 20321,
     "technical": {
-      "ambient_rna_fraction": 0.
-      "doublet_rate": 0.
-      "dropout_rate": 0.
-      "sample_quality": 0.
+      "ambient_rna_fraction": 0.050723618495539344,
+      "doublet_rate": 0.0546771548836933,
+      "dropout_rate": 0.05122168063297322,
+      "sample_quality": 0.937985596833521,
       "library_complexity": 0.8,
       "capture_efficiency": 0.6
     },
@@ -181,24 +324,24 @@
     "cells_cultured": false,
     "library_prepared": true,
     "perturbation_applied": false,
-    "cells_sequenced":
-    "qc_performed":
-    "data_filtered":
-    "data_normalized":
+    "cells_sequenced": true,
+    "qc_performed": true,
+    "data_filtered": true,
+    "data_normalized": true,
     "batches_integrated": false,
-    "cells_clustered":
+    "cells_clustered": true,
     "de_performed": false,
     "trajectories_inferred": false,
    "pathways_analyzed": false,
     "networks_inferred": false,
     "markers_discovered": false,
     "markers_validated": false,
-    "followup_designed":
+    "followup_designed": true,
|
| 340 |
"subagent_review_requested": false,
|
| 341 |
"conclusion_reached": false,
|
| 342 |
+
"n_cells_sequenced": 12178,
|
| 343 |
+
"n_cells_after_filter": 9967,
|
| 344 |
+
"n_clusters_found": "3",
|
| 345 |
"n_de_genes_found": null,
|
| 346 |
"n_markers_found": null
|
| 347 |
},
|
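As a quick consistency check on the state above (a standalone sketch, not part of the repository), the per-term values in `reward_breakdown` should sum to the logged `total` of `-2.0` for this blocked step:

```python
# Component terms from the reward_breakdown logged above.
breakdown = {
    "validity": -1.0,
    "ordering": 0.0,
    "info_gain": 0.0,
    "efficiency": 0.0,
    "novelty": 0.0,
    "penalty": -1.0,
    "shaping": 0.0,
    "terminal": 0.0,
}

# The step total should equal the sum of its components.
total = sum(breakdown.values())
print(total)  # -2.0, matching the logged "total"
```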
dashboard.html
CHANGED
@@ -304,6 +304,20 @@ function esc(s) { if (s == null) return '—'; const d = document.createElement(
 function pct(used, total) { if (!total) return 0; return Math.min(100, Math.max(0, (used / total) * 100)); }
 function gaugeColor(p) { return p < 50 ? 'var(--green)' : p < 80 ? 'var(--amber)' : 'var(--red)'; }
 function fmt(n) { if (n == null) return '0'; return Number(n).toLocaleString('en-US', { maximumFractionDigits: 0 }); }
+function uniqueItems(arr) {
+  const out = [];
+  const seen = new Set();
+  (arr || []).forEach(item => {
+    if (item == null) return;
+    const text = String(item).trim();
+    if (!text) return;
+    const key = text.toUpperCase();
+    if (seen.has(key)) return;
+    seen.add(key);
+    out.push(text);
+  });
+  return out;
+}
 function gauge(label, value, pctVal, inv) {
   let bar = '';
   if (pctVal != null) { const c = inv ? gaugeColor(100-pctVal) : gaugeColor(pctVal); bar = `<div class="gauge-bar"><div class="gauge-bar-fill" style="width:${pctVal.toFixed(1)}%;background:${c}"></div></div>`; }
@@ -360,7 +374,10 @@ function showReport() {
   const conc = s.conclusions || [];
   const trueM = lat.true_markers || [];
   const trueMech = lat.causal_mechanisms || [];
-  const …
+  const conclusionMarkers = uniqueItems(conc.flatMap(c => c.top_markers || []));
+  const conclusionMechanisms = uniqueItems(conc.flatMap(c => c.causal_mechanisms || []));
+  const agentM = uniqueItems((s.discovered_markers && s.discovered_markers.length) ? s.discovered_markers : conclusionMarkers);
+  const agentMechanisms = uniqueItems((s.candidate_mechanisms && s.candidate_mechanisms.length) ? s.candidate_mechanisms : conclusionMechanisms);
   const markerHits = agentM.filter(m => trueM.some(t => t.toUpperCase() === m.toUpperCase()));
   const r = s.resources || {};
@@ -393,7 +410,7 @@ function showReport() {
   html += `<div class="report-section"><h3>Ground Truth Comparison</h3>
   <div class="comparison-row"><div class="comparison-col"><h4>Agent's Markers</h4><div class="tag-list">${comparedTags(agentM, trueM, 'green')}</div></div>
   <div class="comparison-col"><h4>True Markers</h4><div class="tag-list">${tagsHTML(trueM,'green')}</div></div></div>
-  <div class="comparison-row"><div class="comparison-col"><h4>Agent's Mechanisms</h4><div class="tag-list">${comparedTags(…
+  <div class="comparison-row"><div class="comparison-col"><h4>Agent's Mechanisms</h4><div class="tag-list">${comparedTags(agentMechanisms, trueMech, 'pink')}</div></div>
   <div class="comparison-col"><h4>True Mechanisms</h4><div class="tag-list">${tagsHTML(trueMech,'pink')}</div></div></div>
   </div>`;
@@ -473,12 +490,16 @@ function renderState(s) {
   // Ground truth comparison (visible when done or has conclusions)
   const lat = s.latent;
   if ((s.episode_done || conc.length) && lat) {
+    const conclusionMarkers = uniqueItems(conc.flatMap(c => c.top_markers || []));
+    const conclusionMechanisms = uniqueItems(conc.flatMap(c => c.causal_mechanisms || []));
+    const comparisonMarkers = uniqueItems((s.discovered_markers && s.discovered_markers.length) ? s.discovered_markers : conclusionMarkers);
+    const comparisonMechanisms = uniqueItems((s.candidate_mechanisms && s.candidate_mechanisms.length) ? s.candidate_mechanisms : conclusionMechanisms);
     $('card-gt-comparison').style.display = '';
-    setHTML('gt-agent-markers', comparedTags(…
+    setHTML('gt-agent-markers', comparedTags(comparisonMarkers, lat.true_markers, 'green'));
     setHTML('gt-true-markers', tagsHTML(lat.true_markers, 'green'));
-    setHTML('gt-agent-mechs', comparedTags(…
+    setHTML('gt-agent-mechs', comparedTags(comparisonMechanisms, lat.causal_mechanisms, 'pink'));
     setHTML('gt-true-mechs', tagsHTML(lat.causal_mechanisms, 'pink'));
-    const hits = …
+    const hits = comparisonMarkers.filter(m => (lat.true_markers||[]).some(t => t.toUpperCase()===m.toUpperCase()));
     $('gt-score').innerHTML = `Marker accuracy: <span style="color:var(--accent)">${hits.length}</span> / ${(lat.true_markers||[]).length} true markers recovered`;
   } else { $('card-gt-comparison').style.display = 'none'; }
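The dashboard's new `uniqueItems` helper and the Python `_unique_nonempty` helper in `run_agent.py` implement the same pattern: trim entries, drop null/empty values, and deduplicate case-insensitively while keeping the first casing seen. A standalone Python sketch of that shared pattern (the `limit` parameter mirrors `_unique_nonempty`; the name `unique_items` is illustrative):

```python
def unique_items(items, limit=None):
    """Case-insensitive dedup keeping the first casing seen (simplified sketch)."""
    out, seen = [], set()
    for item in items or []:
        if item is None:
            continue
        text = str(item).strip()
        if not text:  # drop empty/whitespace-only entries
            continue
        key = text.upper()
        if key in seen:  # "trem2" is a duplicate of "TREM2"
            continue
        seen.add(key)
        out.append(text)
        if limit is not None and len(out) >= limit:
            break
    return out

print(unique_items(["TREM2", "trem2", " APOE ", None, ""]))  # ['TREM2', 'APOE']
```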
debug-904eee.log
ADDED
@@ -0,0 +1,10 @@
+{"sessionId": "904eee", "message": "repair_failed", "data": {"input_tail": "n studies. Sequencing ensures uniform coverage which aids robustness during subsequent filtering.\",\n\"Confidence\": 0.75\n}", "repaired_tail": "n studies. Sequencing ensures uniform coverage which aids robustness during subsequent filtering.\",\n\"Confidence\": 0.75\n}", "hypothesisId": "H1"}, "timestamp": 1772961980164}
+{"sessionId": "904eee", "message": "extract_failed", "data": {"normalized_tail": "ssues common in multi-condition studies. Sequencing ensures uniform coverage which aids robustness during subsequent filtering.\",\n\"Confidence\": 0.75\n}", "repair_returned": false, "last_json_err": "Expecting ':' delimiter: line 1 column 4 (char 3)", "has_python_none": false, "hypothesisId": "H4"}, "timestamp": 1772961980165}
+{"sessionId": "904eee", "message": "repair_failed", "data": {"input_tail": "ia unsupervised learning using dimension reduction techniques like PCA/UMAP from normalized data.\",\n\"Confidence\": 0.95\n}", "repaired_tail": "a unsupervised learning using dimension reduction techniques like PCA/UMAP from normalized data.\",\n\"Confidence\": 0.95\n}\"", "hypothesisId": "H1"}, "timestamp": 1772962029639}
+{"sessionId": "904eee", "message": "extract_failed", "data": {"normalized_tail": "inct cell types/states first via unsupervised learning using dimension reduction techniques like PCA/UMAP from normalized data.\",\n\"Confidence\": 0.95\n}", "repair_returned": false, "last_json_err": "Expecting ':' delimiter: line 1 column 4 (char 3)", "has_python_none": false, "hypothesisId": "H4"}, "timestamp": 1772962029640}
+{"sessionId": "904eee", "message": "repair_failed", "data": {"input_tail": "usters via downsampling/filtering based on read depth and genome coverage thresholds is mandatory.\",\n\"Confidence\": 0.9\n}", "repaired_tail": "usters via downsampling/filtering based on read depth and genome coverage thresholds is mandatory.\",\n\"Confidence\": 0.9\n}", "hypothesisId": "H1"}, "timestamp": 1772962075204}
+{"sessionId": "904eee", "message": "extract_failed", "data": {"normalized_tail": "ighly degraded/contaminated clusters via downsampling/filtering based on read depth and genome coverage thresholds is mandatory.\",\n\"Confidence\": 0.9\n}", "repair_returned": false, "last_json_err": "Expecting value: line 2 column 36 (char 37)", "has_python_none": false, "hypothesisId": "H4"}, "timestamp": 1772962075205}
+{"sessionId": "904eee", "message": "repair_failed", "data": {"input_tail": "factors rather than dynamic processes like maturation pathways required for trajectory inference.\",\n\"Confidence\": 0.95\n}", "repaired_tail": "factors rather than dynamic processes like maturation pathways required for trajectory inference.\",\n\"Confidence\": 0.95\n}", "hypothesisId": "H1"}, "timestamp": 1772962101951}
+{"sessionId": "904eee", "message": "extract_failed", "data": {"normalized_tail": "identify static transcription factors rather than dynamic processes like maturation pathways required for trajectory inference.\",\n\"Confidence\": 0.95\n}", "repair_returned": false, "last_json_err": "Expecting ':' delimiter: line 1 column 4 (char 3)", "has_python_none": false, "hypothesisId": "H4"}, "timestamp": 1772962101951}
+{"sessionId": "904eee", "message": "repair_failed", "data": {"input_tail": "s unique to intermediate progenitors versus terminally differentiated cells within those clusters.\",\n\"Confidence\": 0.9\n}", "repaired_tail": "s unique to intermediate progenitors versus terminally differentiated cells within those clusters.\",\n\"Confidence\": 0.9\n}", "hypothesisId": "H1"}, "timestamp": 1772962135915}
+{"sessionId": "904eee", "message": "extract_failed", "data": {"normalized_tail": "ndance transcriptional features unique to intermediate progenitors versus terminally differentiated cells within those clusters.\",\n\"Confidence\": 0.9\n}", "repair_returned": false, "last_json_err": "Expecting ':' delimiter: line 1 column 4 (char 3)", "has_python_none": false, "hypothesisId": "H4"}, "timestamp": 1772962135916}
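The `extract_failed` entries above record `json.JSONDecodeError` messages such as `Expecting ':' delimiter`, raised when only the truncated tail of a model's JSON object survives. A minimal standalone reproduction of that failure mode (the input string here is illustrative, not taken from the log):

```python
import json

# A fragment with a key but no ':' before the value, the kind of malformed
# tail the repair/extract pipeline logs above.
try:
    json.loads('{"Confidence" 0.75}')
except json.JSONDecodeError as err:
    print(err.msg)  # Expecting ':' delimiter
```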
inference.ipynb
ADDED
@@ -0,0 +1,166 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Self-Driving Lab Inference on H100 With Unsloth\n",
+    "\n",
+    "This notebook loads a quantized Unsloth model, builds the same self-driving lab observation prompt used during training, generates the next structured lab action, and steps the simulator in a short closed-loop rollout similar to `run_agent.py`, but with faster 4-bit inference on H100."
+   ],
+   "id": "a9d34036"
+  },
+  {
+   "cell_type": "code",
+   "metadata": {},
+   "source": [
+    "%pip install -q -U torch transformers unsloth"
+   ],
+   "execution_count": null,
+   "outputs": [],
+   "id": "20b36e01"
+  },
+  {
+   "cell_type": "code",
+   "metadata": {},
+   "source": [
+    "import json\n",
+    "\n",
+    "import torch\n",
+    "\n",
+    "from training_script import format_observation\n",
+    "from training_unsloth import generate_action_with_model, load_model_artifacts\n",
+    "from server.hackathon_environment import BioExperimentEnvironment\n",
+    "\n",
+    "print(\"CUDA available:\", torch.cuda.is_available())\n",
+    "if torch.cuda.is_available():\n",
+    "    print(\"GPU:\", torch.cuda.get_device_name(0))\n",
+    "    print(\"bf16 supported:\", torch.cuda.is_bf16_supported())"
+   ],
+   "execution_count": null,
+   "outputs": [],
+   "id": "bcf24a2e"
+  },
+  {
+   "cell_type": "code",
+   "metadata": {},
+   "source": [
+    "MODEL_PATH = \"artifacts/grpo-unsloth-output\"  # or a Hugging Face repo / base model id\n",
+    "SCENARIO_NAME = \"cardiac_disease_de\"\n",
+    "SEED = 42\n",
+    "\n",
+    "tokenizer, model = load_model_artifacts(\n",
+    "    MODEL_PATH,\n",
+    "    trust_remote_code=True,\n",
+    "    max_seq_length=2048,\n",
+    "    load_in_4bit=True,\n",
+    "    fast_inference=True,\n",
+    "    prepare_for_inference=True,\n",
+    ")\n",
+    "\n",
+    "env = BioExperimentEnvironment(scenario_name=SCENARIO_NAME, domain_randomise=False)\n",
+    "obs = env.reset(seed=SEED)\n",
+    "print(format_observation(obs)[:3000])"
+   ],
+   "execution_count": null,
+   "outputs": [],
+   "id": "c54f2cfd"
+  },
+  {
+   "cell_type": "code",
+   "metadata": {},
+   "source": [
+    "result = generate_action_with_model(\n",
+    "    model,\n",
+    "    tokenizer,\n",
+    "    obs,\n",
+    "    max_new_tokens=160,\n",
+    "    temperature=0.2,\n",
+    "    top_p=0.9,\n",
+    "    do_sample=True,\n",
+    ")\n",
+    "\n",
+    "print(\"Model response:\\n\")\n",
+    "print(result[\"response_text\"])\n",
+    "print(\"\\nParsed action:\\n\")\n",
+    "result[\"action\"].model_dump() if result[\"action\"] is not None else None"
+   ],
+   "execution_count": null,
+   "outputs": [],
+   "id": "f9b25208"
+  },
+  {
+   "cell_type": "code",
+   "metadata": {},
+   "source": [
+    "if result[\"action\"] is not None:\n",
+    "    next_obs = env.step(result[\"action\"])\n",
+    "    print(\"Reward:\", next_obs.reward)\n",
+    "    print(\"Done:\", next_obs.done)\n",
+    "    print(\"Violations:\", next_obs.rule_violations)\n",
+    "    print(\"Markers:\", next_obs.discovered_markers[:5])\n",
+    "    print(\"Mechanisms:\", next_obs.candidate_mechanisms[:5])\n",
+    "    if next_obs.latest_output is not None:\n",
+    "        print(\"Summary:\", next_obs.latest_output.summary)\n",
+    "        print(\"Latest data preview:\")\n",
+    "        print(json.dumps(next_obs.latest_output.data, indent=2)[:1200])\n",
+    "else:\n",
+    "    print(\"Model output did not parse into an ExperimentAction.\")"
+   ],
+   "execution_count": null,
+   "outputs": [],
+   "id": "c2408f52"
+  },
+  {
+   "cell_type": "code",
+   "metadata": {},
+   "source": [
+    "# Optional short closed-loop rollout.\n",
+    "obs = env.reset(seed=7)\n",
+    "trajectory = []\n",
+    "\n",
+    "for step_idx in range(8):\n",
+    "    result = generate_action_with_model(model, tokenizer, obs, max_new_tokens=160)\n",
+    "    action = result[\"action\"]\n",
+    "    record = {\n",
+    "        \"step\": step_idx + 1,\n",
+    "        \"response_text\": result[\"response_text\"],\n",
+    "        \"action\": action.model_dump() if action is not None else None,\n",
+    "    }\n",
+    "    trajectory.append(record)\n",
+    "    if action is None:\n",
+    "        break\n",
+    "\n",
+    "    next_obs = env.step(action)\n",
+    "    record.update({\n",
+    "        \"reward\": next_obs.reward,\n",
+    "        \"done\": next_obs.done,\n",
+    "        \"violations\": list(next_obs.rule_violations),\n",
+    "        \"latest_summary\": next_obs.latest_output.summary if next_obs.latest_output is not None else None,\n",
+    "        \"discovered_markers\": list(next_obs.discovered_markers[:5]),\n",
+    "        \"candidate_mechanisms\": list(next_obs.candidate_mechanisms[:5]),\n",
+    "    })\n",
+    "    obs = next_obs\n",
+    "    if obs.done:\n",
+    "        break\n",
+    "\n",
+    "trajectory"
+   ],
+   "execution_count": null,
+   "outputs": [],
+   "id": "8af34f32"
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
pyproject.toml
CHANGED
@@ -39,6 +39,7 @@ train = [
     "torch>=2.10.0",
     "transformers>=5.3.0",
     "trl>=0.29.0",
+    "unsloth",
 ]
 
 [project.scripts]
run_agent.py
CHANGED
@@ -16,6 +16,7 @@ from models import (
     ActionType,
     ExperimentAction,
     ExperimentObservation,
+    OutputType,
     build_agent_observation_context,
     build_agent_system_prompt,
 )
@@ -205,6 +206,13 @@ def _strip_js_comments(text: str) -> str:
 
 def extract_json_object(text: str) -> Optional[Dict[str, Any]]:
     stripped = _normalize_jsonish_text(text).strip()
+    if stripped.startswith('"') and stripped.endswith('"'):
+        try:
+            unwrapped = json.loads(stripped)
+        except json.JSONDecodeError:
+            unwrapped = None
+        if isinstance(unwrapped, str):
+            stripped = _normalize_jsonish_text(unwrapped).strip()
     fence_prefix = "```"
     if stripped.startswith(fence_prefix) and stripped.endswith(fence_prefix):
         lines = stripped.splitlines()
@@ -226,6 +234,7 @@ def extract_json_object(text: str) -> Optional[Dict[str, Any]]:
             break
         start = stripped.find("{", start + 1)
 
+    repaired = None
     first_brace = stripped.find("{")
     if first_brace != -1:
         repaired = _repair_truncated_json(stripped[first_brace:])
@@ -365,34 +374,34 @@ def parse_action(text: str) -> Optional[ExperimentAction]:
     if d is not None:
         action_type = normalize_action_type(get_payload_value(d, "action_type"))
         if action_type is None:
-            …
+            pass
+        else:
+            parameters = get_payload_value(d, "parameters", "params") or {}
+            if not isinstance(parameters, dict):
+                parameters = {}
+
+            confidence = get_payload_value(d, "confidence")
+            if confidence is None:
+                confidence = 0.5
+            try:
+                confidence = float(confidence)
+            except (TypeError, ValueError):
+                confidence = 0.5
+
+            justification = get_payload_value(
+                d, "justification", "justifyement", "reasoning", "rationale", "reason"
+            )
+            if justification is not None and not isinstance(justification, str):
+                justification = compact_preview(justification, 200)
+            method = normalize_optional_string(get_payload_value(d, "method"))
+
+            return ExperimentAction(
+                action_type=ActionType(action_type),
+                method=method,
+                parameters=parameters,
+                justification=justification,
+                confidence=min(1.0, max(0.0, confidence)),
+            )
 
     action_match = re.search(
         r'["\']action_type["\']\s*:\s*["\']([^"\']+)',
@@ -472,6 +481,107 @@ def should_force_terminal_conclusion(
     )
 
 
+def _unique_nonempty(items: List[str], limit: int = 5) -> List[str]:
+    seen: set[str] = set()
+    result: List[str] = []
+    for raw in items:
+        value = normalize_optional_string(raw)
+        if not value:
+            continue
+        key = value.upper()
+        if key in seen:
+            continue
+        seen.add(key)
+        result.append(value)
+        if len(result) >= limit:
+            break
+    return result
+
+
+def _infer_conclusion_evidence(
+    obs: ExperimentObservation,
+) -> tuple[List[str], List[str], Dict[str, float]]:
+    top_markers = _unique_nonempty(list(obs.discovered_markers), limit=5)
+    causal_mechanisms = _unique_nonempty(list(obs.candidate_mechanisms), limit=5)
+    predicted_pathways: Dict[str, float] = {}
+
+    for output in reversed(obs.all_outputs):
+        if not output.success:
+            continue
+
+        data = output.data or {}
+        if not top_markers:
+            if output.output_type == OutputType.MARKER_RESULT:
+                top_markers = _unique_nonempty(list(data.get("markers", [])), limit=5)
+            elif output.output_type == OutputType.DE_RESULT:
+                top_markers = _unique_nonempty(
+                    [item.get("gene") for item in data.get("top_genes", []) if isinstance(item, dict)],
+                    limit=5,
+                )
+
+        if output.output_type == OutputType.PATHWAY_RESULT and not predicted_pathways:
+            for item in data.get("top_pathways", []):
+                if not isinstance(item, dict):
+                    continue
+                pathway = normalize_optional_string(item.get("pathway"))
+                score = item.get("score")
+                if pathway and isinstance(score, (int, float)):
+                    predicted_pathways[pathway] = float(score)
+                if len(predicted_pathways) >= 5:
+                    break
+
+        if not causal_mechanisms:
+            if output.output_type == OutputType.PATHWAY_RESULT:
+                causal_mechanisms = _unique_nonempty(
+                    [item.get("pathway") for item in data.get("top_pathways", []) if isinstance(item, dict)],
+                    limit=5,
+                )
+            elif output.output_type == OutputType.NETWORK_RESULT:
+                causal_mechanisms = _unique_nonempty(
+                    list(data.get("top_regulators", [])),
+                    limit=5,
+                )
+
+        if top_markers and causal_mechanisms and predicted_pathways:
+            break
+
+    return top_markers, causal_mechanisms, predicted_pathways
+
+
+def ensure_conclusion_claims(
+    obs: ExperimentObservation,
+    action: ExperimentAction,
+) -> ExperimentAction:
+    if action.action_type != ActionType.SYNTHESIZE_CONCLUSION:
+        return action
+
+    parameters = dict(action.parameters or {})
+    raw_claims = parameters.get("claims")
+    if isinstance(raw_claims, list) and raw_claims:
+        normalized_claims = [claim for claim in raw_claims if isinstance(claim, dict)]
+        if normalized_claims:
+            parameters["claims"] = normalized_claims
+            if parameters != action.parameters:
+                return action.model_copy(update={"parameters": parameters})
+            return action
+
+    top_markers, causal_mechanisms, predicted_pathways = _infer_conclusion_evidence(obs)
+    claim_type = "causal" if causal_mechanisms else "correlational"
+    conditions = " vs ".join(obs.task.conditions[:2]) if obs.task.conditions else "the task conditions"
+    claim = action.justification or f"Final synthesis for {conditions}."
+
+    parameters["claims"] = [{
+        "top_markers": top_markers,
+        "causal_mechanisms": causal_mechanisms,
+        "predicted_pathways": predicted_pathways,
+        "confidence": action.confidence,
+        "claim_type": claim_type,
+        "claim": claim,
+    }]
+    if not action.justification:
+        action = action.model_copy(update={"justification": claim})
+    return action.model_copy(update={"parameters": parameters})
+
+
 def write_dashboard_state(
     env: BioExperimentEnvironment,
@@ -888,6 +998,8 @@ def main():
             confidence=action.confidence,
         )
 
+        action = ensure_conclusion_claims(obs, action)
+
         log(f"\nStep {step + 1}: {action.action_type.value} ({gen_time:.1f}s)")
         if thinking:
             log(f"  Thinking: {thinking[:200]}")
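One of the `run_agent.py` changes above teaches `extract_json_object` to unwrap responses that arrive as a JSON-encoded *string* that itself contains JSON (a common double-encoding failure from LLMs). A standalone sketch of just that unwrapping step, simplified relative to the real function (which also normalizes the text and repairs truncated objects); the helper name `unwrap_double_encoded` is illustrative:

```python
import json

def unwrap_double_encoded(text: str) -> str:
    """If text is a JSON string literal wrapping more JSON, decode one layer."""
    stripped = text.strip()
    if stripped.startswith('"') and stripped.endswith('"'):
        try:
            unwrapped = json.loads(stripped)
        except json.JSONDecodeError:
            unwrapped = None
        if isinstance(unwrapped, str):
            stripped = unwrapped.strip()  # now the inner JSON object text
    return stripped

# Simulate a model that serialized its action object twice.
payload = json.dumps(json.dumps({"action_type": "perform_qc", "confidence": 0.9}))
action = json.loads(unwrap_double_encoded(payload))
print(action["action_type"])  # perform_qc
```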
run_agent_unsloth.py
ADDED
@@ -0,0 +1,294 @@
+"""Run the bio-experiment environment with a quantized Unsloth model."""
+
+from __future__ import annotations
+
+import json
+import os
+import time
+from typing import Any, Dict, Optional
+
+from models import ActionType, ExperimentAction
+from server.hackathon_environment import BioExperimentEnvironment
+from training_unsloth import (
+    DEFAULT_MAX_SEQ_LENGTH,
+    generate_action_with_model,
+    load_model_artifacts,
+)
+from training_script import DEFAULT_COMPLETION_TOKEN_BUDGET
+
+import run_agent as base
+
+MODEL_ID = os.getenv("RUN_AGENT_UNSLOTH_MODEL_ID", "unsloth/Qwen3.5-2B-GGUF")
+MAX_EPISODE_STEPS = int(
+    os.getenv("RUN_AGENT_UNSLOTH_MAX_EPISODE_STEPS", str(base.MAX_EPISODE_STEPS))
+)
+MAX_NEW_TOKENS = int(
+    os.getenv(
+        "RUN_AGENT_UNSLOTH_MAX_NEW_TOKENS",
+        str(DEFAULT_COMPLETION_TOKEN_BUDGET),
+    )
+)
+MAX_SEQ_LENGTH = int(
+    os.getenv("RUN_AGENT_UNSLOTH_MAX_SEQ_LENGTH", str(DEFAULT_MAX_SEQ_LENGTH))
+)
+TRUST_REMOTE_CODE = (
+    os.getenv("RUN_AGENT_UNSLOTH_TRUST_REMOTE_CODE", "1").strip().lower()
+    not in {"0", "false", "off"}
+)
+LOAD_IN_4BIT = (
+    os.getenv("RUN_AGENT_UNSLOTH_LOAD_IN_4BIT", "1").strip().lower()
+    not in {"0", "false", "off"}
+)
+FAST_INFERENCE = (
+    os.getenv("RUN_AGENT_UNSLOTH_FAST_INFERENCE", "1").strip().lower()
+    not in {"0", "false", "off"}
+)
+
+
+def check_dashboard_command() -> Optional[Dict[str, Any]]:
+    try:
+        raw = base.DASHBOARD_CMD_PATH.read_text(encoding="utf-8")
+        base.DASHBOARD_CMD_PATH.unlink(missing_ok=True)
+        return json.loads(raw)
+    except (FileNotFoundError, json.JSONDecodeError):
+        return None
+
+
+def run_episode(
+    model: Any,
+    tokenizer: Any,
+    *,
+    scenario_name: Optional[str] = None,
+    custom_ground_truth: Optional[Dict[str, Any]] = None,
+) -> None:
+    env = BioExperimentEnvironment(scenario_name=scenario_name)
+    obs = env.reset()
+
+    if custom_ground_truth and env._latent:
+        gt = custom_ground_truth
+        bio = env._latent.biology
+        if gt.get("true_markers"):
+            bio.true_markers = gt["true_markers"]
+        if gt.get("causal_mechanisms"):
+            bio.causal_mechanisms = gt["causal_mechanisms"]
+        if gt.get("true_pathways"):
+            bio.true_pathways = {
+                key: float(value) for key, value in gt["true_pathways"].items()
+            }
+
+    base.log("\n" + "=" * 70)
+    base.log(f"TASK: {obs.task.problem_statement}")
+    base.log(f"Conditions: {obs.task.conditions}")
+    base.log(
+        f"Budget: ${obs.task.budget_limit:,.0f} | "
+        f"Time: {obs.task.time_limit_days:.0f} days"
+    )
+    base.log("Runtime: Unsloth quantized generation")
+    base.log("=" * 70)
+
+    cumulative_reward = 0.0
+    base.write_dashboard_state(env, obs, step=0, cumulative_reward=0.0)
+
+    for step in range(MAX_EPISODE_STEPS):
+        cmd = check_dashboard_command()
+        if cmd and cmd.get("action") == "restart":
+            base.log("\n[DASHBOARD] Restart requested - ending episode early.")
+            break
+
+        t0 = time.time()
+        result = generate_action_with_model(
+            model,
+            tokenizer,
+            obs,
+            max_new_tokens=MAX_NEW_TOKENS,
+            temperature=0.2,
+            top_p=0.9,
+            do_sample=True,
+        )
+        response = result["response_text"]
+        action = result["action"]
+        gen_time = time.time() - t0
+
+        is_last_step = step == MAX_EPISODE_STEPS - 1
+        if action is None:
+            if is_last_step:
+                base.log("\n  [!] Parse failed on final step - forcing synthesize_conclusion.")
+                action = ExperimentAction(
+                    action_type=ActionType.SYNTHESIZE_CONCLUSION,
+                    justification="forced terminal conclusion",
+                    confidence=0.5,
+                )
+            else:
+                base.log(
+                    f"\n  [!] Parse failed, skipping step. Raw: {response[:150]}"
+                )
+                continue
+
+        completed_types = {
+            record.action_type for record in obs.pipeline_history if record.success
+        }
+        failed_types = {
+            record.action_type for record in obs.pipeline_history if not record.success
+        }
+
+        if base.should_force_terminal_conclusion(action, completed_types):
+            base.log(
+                f"\n  [!] repeated completed meta step {action.action_type.value} "
+                f"- forcing synthesize_conclusion."
+            )
+            action = ExperimentAction(
+                action_type=ActionType.SYNTHESIZE_CONCLUSION,
+                justification="repeated completed meta step forced terminal conclusion",
+                confidence=action.confidence,
+            )
+            completed_types = {
+                record.action_type for record in obs.pipeline_history if record.success
+            }
+
+        skip_reason = None
+        if action.action_type in completed_types:
+            skip_reason = f"blocked repeat of completed step {action.action_type.value}"
+        elif action.action_type in failed_types:
+            if base.should_block_failed_reattempt(obs.pipeline_history, action.action_type):
+                skip_reason = (
+                    f"blocked re-attempt of failed step {action.action_type.value}"
+                )
+
+        if skip_reason:
+            if is_last_step:
+                base.log(
+                    f"\n  [!] {skip_reason} on final step - forcing synthesize_conclusion."
+                )
+                action = ExperimentAction(
+                    action_type=ActionType.SYNTHESIZE_CONCLUSION,
+                    justification="forced terminal conclusion",
+                    confidence=0.5,
+                )
+            else:
+                base.log(f"\n  [!] {skip_reason}, skipping step.")
+                continue
+
+        if is_last_step and action.action_type != ActionType.SYNTHESIZE_CONCLUSION:
+            base.log(
+                f"\n  [!] Final step - overriding {action.action_type.value} "
+                "with synthesize_conclusion."
+            )
+            action = ExperimentAction(
+                action_type=ActionType.SYNTHESIZE_CONCLUSION,
+                justification="forced terminal conclusion",
+                confidence=action.confidence,
+            )
+
+        action = base.ensure_conclusion_claims(obs, action)
+
+        base.log(f"\nStep {step + 1}: {action.action_type.value} ({gen_time:.1f}s)")
+        if action.justification:
+            base.log(f"  Rationale: {action.justification}")
+        else:
+            base.log("  Rationale: [model did not provide one]")
+        if action.parameters:
+            base.log(f"  Parameters: {base.compact_preview(action.parameters, 200)}")
+        elif response:
+            base.log(
+                "  Model response: "
+                f"{base.compact_preview(response, base.MODEL_RESPONSE_PREVIEW_CHARS)}"
+            )
+
+        obs = env.step(action)
+
+        if obs.latest_output:
+            latest_output = obs.latest_output
+            status = "OK" if latest_output.success else "FAIL"
+            base.log(f"  [{status}] {latest_output.summary}")
+            if latest_output.warnings:
+                base.log(f"  Warnings: {latest_output.warnings}")
+
+        step_reward = obs.reward
+        cumulative_reward += step_reward
+        base.log(f"  Reward: {step_reward:+.3f} (cum: {cumulative_reward:+.3f})")
+        base.log(
+            f"  Budget: ${obs.resource_usage.budget_remaining:,.0f} | "
+            f"Time: {obs.resource_usage.time_remaining_days:.0f}d"
+        )
+
+        base.write_dashboard_state(
+            env,
+            obs,
+            step=step + 1,
+            cumulative_reward=cumulative_reward,
+            model_response=response,
+            action=action,
+            gen_time=gen_time,
+            episode_done=obs.done,
+        )
+
+        if obs.rule_violations:
+            base.log(f"  Violations: {obs.rule_violations}")
+        if obs.done:
+            break
+
+    base.log(f"\n{'=' * 70}")
+    base.log("EPISODE COMPLETE" if obs.done else f"MAX STEPS ({MAX_EPISODE_STEPS})")
+    base.log(f"  Steps: {obs.step_index}")
+    base.log(f"  Total reward: {cumulative_reward:+.3f}")
+    base.log(f"  Budget used: ${obs.resource_usage.budget_used:,.0f}")
+    base.log(f"  Time used: {obs.resource_usage.time_used_days:.0f} days")
+    if obs.conclusions:
+        base.log("  Conclusions:")
+        for conclusion in obs.conclusions:
+            base.log(
+                f"    [{conclusion.claim_type}, conf={conclusion.confidence:.2f}] "
+                f"{conclusion.claim}"
+            )
+            if conclusion.top_markers:
+                base.log(f"      Markers: {conclusion.top_markers}")
+            if conclusion.causal_mechanisms:
+                base.log(f"      Mechanisms: {conclusion.causal_mechanisms}")
+            if conclusion.predicted_pathways:
+                base.log(f"      Pathways: {conclusion.predicted_pathways}")
+    base.log("=" * 70)
+
+
+def main() -> None:
+    runtime = base.resolve_torch_runtime()
+    base.log(
+        f"Using Unsloth runtime: device={runtime['device']} "
+        f"name={runtime['device_name']} dtype={runtime['dtype']}"
+    )
+    tokenizer, model = load_model_artifacts(
+        MODEL_ID,
+        trust_remote_code=TRUST_REMOTE_CODE,
+        max_seq_length=MAX_SEQ_LENGTH,
+        load_in_4bit=LOAD_IN_4BIT,
+        fast_inference=FAST_INFERENCE,
+        prepare_for_inference=True,
+    )
+    base.DASHBOARD_CMD_PATH.unlink(missing_ok=True)
+    run_episode(model, tokenizer)
+
+    while True:
+        base.log("\nWaiting for dashboard command (restart / new task) ...")
+        while True:
+            cmd = check_dashboard_command()
+            if cmd:
+                break
+            time.sleep(1.0)
+
+        action_type = cmd.get("action", "restart")
+        if action_type == "quit":
+            base.log("Quit requested.")
+            break
+
+        scenario = cmd.get("scenario_name")
+        ground_truth = cmd.get("ground_truth")
+        base.log(f"\n[DASHBOARD] {action_type} - scenario={scenario}")
+        run_episode(
+            model,
+            tokenizer,
+            scenario_name=scenario,
+            custom_ground_truth=ground_truth,
+        )
+
+
+if __name__ == "__main__":
+    main()
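The module constants above all parse boolean feature flags from environment variables with the same convention: any value other than `0`, `false`, or `off` (case-insensitive) enables the flag, and an unset variable falls back to the default string. A small sketch of that convention factored into a helper (the helper name is illustrative; the script inlines the expression):

```python
import os

FALSEY = {"0", "false", "off"}


def env_flag(name: str, default: str = "1") -> bool:
    # Any value other than 0/false/off (case-insensitive) enables the flag.
    return os.getenv(name, default).strip().lower() not in FALSEY


os.environ["RUN_AGENT_UNSLOTH_LOAD_IN_4BIT"] = "OFF"
print(env_flag("RUN_AGENT_UNSLOTH_LOAD_IN_4BIT"))    # False
print(env_flag("RUN_AGENT_UNSLOTH_FAST_INFERENCE"))  # True (unset, defaults to "1")
```

Treating only an explicit deny-list as false means typos like `"yes"` or `"TRUE"` still enable the feature, which matches the defaults-on behavior the script wants for 4-bit loading and fast inference.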
server/hackathon_environment.py
CHANGED
@@ -132,7 +132,7 @@ class BioExperimentEnvironment(Environment):
         self._outputs.append(result.output)
         self._update_discoveries(action, result.output)
 
-        if action.action_type == ActionType.SYNTHESIZE_CONCLUSION:
+        if action.action_type == ActionType.SYNTHESIZE_CONCLUSION and result.output.success:
             raw_claims = action.parameters.get("claims", [])
             for c in raw_claims:
                 if isinstance(c, dict):
@@ -218,7 +218,7 @@ class BioExperimentEnvironment(Environment):
             subagent_outputs=list(self._subagent_outputs),
             conclusions=list(self._conclusions),
             rule_violations=rule_violations or [],
-            step_reward_breakdown={},
+            step_reward_breakdown=reward_breakdown or {},
             done=done,
             reward=reward,
             metadata=meta,
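The first hunk's change is subtle but important: claims are only persisted into the environment's conclusion list when the synthesize step actually succeeded, so a rule-blocked conclusion leaves no trace. A toy illustration of the gate (names here are illustrative, not the environment's real attributes):

```python
from typing import Dict, List

conclusions: List[Dict[str, str]] = []


def apply_synthesize_step(success: bool, claims: List[Dict[str, str]]) -> None:
    # Mirrors the `... and result.output.success` guard: blocked steps
    # must not leak their claims into persistent episode state.
    if success:
        conclusions.extend(claims)


apply_synthesize_step(False, [{"claim": "premature"}])   # rule-blocked, dropped
apply_synthesize_step(True, [{"claim": "supported"}])    # accepted
print([c["claim"] for c in conclusions])  # ['supported']
```

Without the guard, a premature conclusion would both end up in `conclusions` and be scored by the terminal reward, which is exactly what the new `test_blocked_conclusion_does_not_persist_claims` test below pins down.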
server/rewards/reward.py
CHANGED
@@ -24,7 +24,7 @@ The terminal reward adds:
 from __future__ import annotations
 
 from dataclasses import dataclass, field
-from typing import …
+from typing import Dict, List, Optional
 
 from models import (
     ActionType,
@@ -214,10 +214,19 @@ class RewardComputer:
             discovered_markers,
             candidate_mechanisms,
         )
-        discovery_error_penalty = -…
+        discovery_error_penalty = -6.0 * (1.0 - discovery_alignment)
+        if discovery_alignment < 0.25:
+            discovery_error_penalty -= 2.0
         rb.components["discovery_alignment"] = discovery_alignment
         rb.components["discovery_error_penalty"] = discovery_error_penalty
 
+        conclusion_alignment = self._conclusion_alignment(state, conclusions)
+        conclusion_error_penalty = -4.0 * (1.0 - conclusion_alignment)
+        if conclusions and conclusion_alignment < 0.25:
+            conclusion_error_penalty -= 1.5
+        rb.components["conclusion_alignment"] = conclusion_alignment
+        rb.components["conclusion_error_penalty"] = conclusion_error_penalty
+
         eff_bonus = (budget_eff + time_eff) / 2.0 if completeness >= 0.3 else 0.0
         rb.terminal = (
             3.0 * completeness
@@ -225,6 +234,7 @@ class RewardComputer:
             + 1.0 * eff_bonus
             + overconf
             + discovery_error_penalty
+            + conclusion_error_penalty
         )
         return rb
 
@@ -470,3 +480,42 @@ class RewardComputer:
         if not components:
             return 1.0
         return sum(components) / len(components)
+
+    def _conclusion_alignment(
+        self,
+        s: FullLatentState,
+        conclusions: List[ConclusionClaim],
+    ) -> float:
+        if not conclusions:
+            return 0.0
+
+        pred_markers = [marker for conclusion in conclusions for marker in conclusion.top_markers]
+        pred_mechanisms = [
+            mechanism
+            for conclusion in conclusions
+            for mechanism in conclusion.causal_mechanisms
+        ]
+
+        if not pred_markers and not pred_mechanisms:
+            return self._legacy_calibration(s, conclusions)
+
+        components: List[float] = []
+        if s.biology.true_markers or pred_markers:
+            marker_recall = marker_set_score(pred_markers, s.biology.true_markers)
+            marker_precision = marker_set_score(s.biology.true_markers, pred_markers)
+            components.append((marker_recall + marker_precision) / 2.0)
+
+        if s.biology.causal_mechanisms or pred_mechanisms:
+            mechanism_recall = mechanism_set_score(
+                pred_mechanisms,
+                s.biology.causal_mechanisms,
+            )
+            mechanism_precision = mechanism_set_score(
+                s.biology.causal_mechanisms,
+                pred_mechanisms,
+            )
+            components.append((mechanism_recall + mechanism_precision) / 2.0)
+
+        if not components:
+            return 1.0
+        return sum(components) / len(components)
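The new conclusion penalty has the same shape as the discovery penalty: a linear term scaled by misalignment, plus a flat extra penalty once alignment falls below a threshold. A minimal sketch of the arithmetic using the constants from the diff (-4.0 slope, 0.25 threshold, -1.5 extra):

```python
def conclusion_error_penalty(alignment: float, has_conclusions: bool) -> float:
    # Linear penalty proportional to misalignment (alignment in [0, 1]),
    # plus a flat penalty when structured claims are badly misaligned.
    penalty = -4.0 * (1.0 - alignment)
    if has_conclusions and alignment < 0.25:
        penalty -= 1.5
    return penalty


print(conclusion_error_penalty(0.5, True))  # -2.0 (linear term only)
print(conclusion_error_penalty(0.0, True))  # -5.5 (linear -4.0, plus flat -1.5)
```

Perfect alignment costs nothing, while a fully wrong structured claim costs up to 5.5 reward, making confident fabrication strictly worse than staying silent.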
server/rules/engine.py
CHANGED
@@ -45,6 +45,16 @@ class RuleEngine:
             p.markers_validated,
         ])
 
+    @staticmethod
+    def _has_marker_evidence(s: FullLatentState) -> bool:
+        p = s.progress
+        return p.markers_discovered or p.markers_validated
+
+    @staticmethod
+    def _has_mechanism_evidence(s: FullLatentState) -> bool:
+        p = s.progress
+        return p.pathways_analyzed or p.networks_inferred
+
     def check(
         self, action: ExperimentAction, state: FullLatentState
     ) -> List[RuleViolation]:
@@ -238,6 +248,20 @@ class RuleEngine:
                 message="Cannot synthesise conclusion without substantive analysis",
             ))
 
+        if not self._has_marker_evidence(s):
+            vs.append(RuleViolation(
+                rule_id="conclusion_without_marker_evidence",
+                severity=Severity.HARD,
+                message="Cannot synthesise conclusion before discovering or validating markers",
+            ))
+
+        if not self._has_mechanism_evidence(s):
+            vs.append(RuleViolation(
+                rule_id="conclusion_without_mechanism_evidence",
+                severity=Severity.HARD,
+                message="Cannot synthesise conclusion before inferring pathways or mechanisms",
+            ))
+
         claims = action.parameters.get("claims", [])
         for claim in claims:
             if isinstance(claim, dict) and claim.get("claim_type") == "causal":
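The two new hard rules gate conclusions on independent lines of evidence: marker evidence (discovered or validated) and mechanism evidence (pathways analyzed or networks inferred). A self-contained sketch of that gating logic, with an illustrative `Progress` dataclass standing in for the project's latent-state model:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Progress:
    markers_discovered: bool = False
    markers_validated: bool = False
    pathways_analyzed: bool = False
    networks_inferred: bool = False


def conclusion_violations(p: Progress) -> List[str]:
    # A conclusion needs BOTH marker evidence and mechanism evidence;
    # each missing line of evidence yields its own hard violation id.
    vs = []
    if not (p.markers_discovered or p.markers_validated):
        vs.append("conclusion_without_marker_evidence")
    if not (p.pathways_analyzed or p.networks_inferred):
        vs.append("conclusion_without_mechanism_evidence")
    return vs


print(conclusion_violations(Progress()))
print(conclusion_violations(Progress(markers_discovered=True, pathways_analyzed=True)))
```

Emitting one violation per missing evidence type (rather than a single combined rule) lets the tests below assert on each gate separately.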
tests/test_environment.py
CHANGED
@@ -64,7 +64,7 @@ class TestEnvironmentLifecycle:
             parameters={"assay": "qPCR"},
         ))
         assert obs.latest_output is not None
-        assert obs.latest_output.success is …
+        assert obs.latest_output.success is False
         assert any("follow-up design" in msg.lower() for msg in obs.rule_violations)
 
     def test_conclusion_ends_episode(self):
@@ -81,6 +81,8 @@ class TestEnvironmentLifecycle:
             ExperimentAction(action_type=ActionType.CLUSTER_CELLS),
             ExperimentAction(action_type=ActionType.DIFFERENTIAL_EXPRESSION,
                              parameters={"comparison": "disease_vs_healthy"}),
+            ExperimentAction(action_type=ActionType.PATHWAY_ENRICHMENT),
+            ExperimentAction(action_type=ActionType.MARKER_SELECTION),
             ExperimentAction(
                 action_type=ActionType.SYNTHESIZE_CONCLUSION,
                 parameters={"claims": [
@@ -94,3 +96,33 @@ class TestEnvironmentLifecycle:
 
         assert obs.done is True
         assert obs.reward != 0.0
+
+    def test_blocked_conclusion_does_not_persist_claims(self):
+        env = BioExperimentEnvironment()
+        env.reset()
+
+        pipeline = [
+            ExperimentAction(action_type=ActionType.COLLECT_SAMPLE),
+            ExperimentAction(action_type=ActionType.PREPARE_LIBRARY),
+            ExperimentAction(action_type=ActionType.SEQUENCE_CELLS),
+            ExperimentAction(action_type=ActionType.RUN_QC),
+            ExperimentAction(action_type=ActionType.FILTER_DATA),
+            ExperimentAction(action_type=ActionType.NORMALIZE_DATA),
+            ExperimentAction(action_type=ActionType.CLUSTER_CELLS),
+        ]
+        for action in pipeline:
+            obs = env.step(action)
+            assert obs.latest_output is not None
+            assert obs.latest_output.success is True
+
+        obs = env.step(ExperimentAction(
+            action_type=ActionType.SYNTHESIZE_CONCLUSION,
+            parameters={"claims": [
+                {"claim": "Premature conclusion", "confidence": 0.9},
+            ]},
+        ))
+
+        assert obs.latest_output is not None
+        assert obs.latest_output.success is False
+        assert obs.conclusions == []
+        assert any("markers" in msg.lower() for msg in obs.rule_violations)
tests/test_rewards.py
CHANGED
@@ -108,7 +108,13 @@ class TestTerminalReward:
                 claim_type="causal",
             ),
         ]
-        rb = rc.terminal_reward(…)
+        rb = rc.terminal_reward(
+            state,
+            claims,
+            [],
+            discovered_markers=["NPPA"],
+            candidate_mechanisms=["TGF-beta-driven fibrosis"],
+        )
         assert rb.terminal > 0
 
     def test_overconfident_wrong_claim_penalised(self):
@@ -165,3 +171,46 @@ class TestTerminalReward:
         assert aligned.components["discovery_alignment"] > misaligned.components["discovery_alignment"]
         assert aligned.components["discovery_error_penalty"] > misaligned.components["discovery_error_penalty"]
         assert aligned.terminal > misaligned.terminal
+
+    def test_conclusion_error_penalizes_wrong_structured_claims(self):
+        rc = RewardComputer()
+        state = FullLatentState(
+            biology=LatentBiologicalState(
+                true_markers=["NPPA", "NPPB"],
+                causal_mechanisms=["TGF-beta-driven fibrosis"],
+            ),
+            progress=ExperimentProgress(
+                data_normalized=True,
+                de_performed=True,
+                markers_discovered=True,
+                pathways_analyzed=True,
+                conclusion_reached=True,
+            ),
+            resources=ResourceState(budget_total=100_000, budget_used=40_000),
+        )
+        aligned = rc.terminal_reward(
+            state,
+            [
+                ConclusionClaim(
+                    top_markers=["NPPA", "NPPB"],
+                    causal_mechanisms=["TGF-beta-driven fibrosis"],
+                    confidence=0.8,
+                ),
+            ],
+            [],
+        )
+        misaligned = rc.terminal_reward(
+            state,
+            [
+                ConclusionClaim(
+                    top_markers=["WRONG1"],
+                    causal_mechanisms=["unrelated process"],
+                    confidence=0.8,
+                ),
+            ],
+            [],
+        )
+
+        assert aligned.components["conclusion_alignment"] > misaligned.components["conclusion_alignment"]
+        assert aligned.components["conclusion_error_penalty"] > misaligned.components["conclusion_error_penalty"]
+        assert aligned.terminal > misaligned.terminal
tests/test_rules.py
CHANGED
@@ -55,47 +55,68 @@ class TestPrerequisites:
 
 
 class TestRedundancy:
-    def …
+    def test_double_qc_is_hard_blocked(self):
         engine = RuleEngine()
         violations = engine.check(
             ExperimentAction(action_type=ActionType.RUN_QC),
             _state(cells_sequenced=True, qc_performed=True),
         )
         hard = engine.hard_violations(violations)
-        …
-        assert not hard
-        assert any("redundant" in m.lower() for m in soft)
+        assert any("redundant" in m.lower() for m in hard)
 
-    def …
+    def test_repeated_followup_design_is_hard_blocked(self):
         engine = RuleEngine()
         violations = engine.check(
             ExperimentAction(action_type=ActionType.DESIGN_FOLLOWUP),
             _state(followup_designed=True, de_performed=True),
         )
         hard = engine.hard_violations(violations)
-        …
-        assert not hard
-        assert any("redundant" in m.lower() for m in soft)
+        assert any("redundant" in m.lower() for m in hard)
 
 
 class TestMetaActionTiming:
-    def …
+    def test_followup_design_without_analysis_is_hard_blocked(self):
         engine = RuleEngine()
         violations = engine.check(
             ExperimentAction(action_type=ActionType.DESIGN_FOLLOWUP),
             _state(),
         )
-        …
-        assert any("follow-up design" in m.lower() for m in …)
+        hard = engine.hard_violations(violations)
+        assert any("follow-up design" in m.lower() for m in hard)
 
-    def …
+    def test_subagent_review_without_analysis_is_hard_blocked(self):
         engine = RuleEngine()
         violations = engine.check(
             ExperimentAction(action_type=ActionType.REQUEST_SUBAGENT_REVIEW),
             _state(),
         )
-        …
-        assert any("subagent review" in m.lower() for m in …)
+        hard = engine.hard_violations(violations)
+        assert any("subagent review" in m.lower() for m in hard)
+
+    def test_conclusion_without_marker_or_mechanism_evidence_is_hard_blocked(self):
+        engine = RuleEngine()
+        violations = engine.check(
+            ExperimentAction(action_type=ActionType.SYNTHESIZE_CONCLUSION),
+            _state(data_normalized=True, cells_clustered=True),
+        )
+        hard = engine.hard_violations(violations)
+        assert any("markers" in m.lower() for m in hard)
+        assert any("pathways or mechanisms" in m.lower() for m in hard)
+
+    def test_conclusion_with_marker_and_mechanism_evidence_is_allowed(self):
+        engine = RuleEngine()
+        violations = engine.check(
+            ExperimentAction(action_type=ActionType.SYNTHESIZE_CONCLUSION),
+            _state(
+                data_normalized=True,
+                cells_clustered=True,
+                markers_discovered=True,
+                pathways_analyzed=True,
+            ),
+        )
+        hard = engine.hard_violations(violations)
+        assert not hard
+
 
 class TestResourceConstraints:
     def test_exhausted_budget_blocked(self):
tests/test_run_agent.py
CHANGED
@@ -1,7 +1,7 @@
 """Tests for run_agent parser and fallback helpers."""
 
 from models import ActionType, ExperimentAction
-from run_agent import …
+from run_agent import ensure_conclusion_claims, extract_json_object, parse_action, should_block_failed_reattempt
 from server.hackathon_environment import BioExperimentEnvironment
 
 
@@ -23,7 +23,23 @@ def test_parse_action_accepts_justifyement_typo():
     assert action.justification == "typo key"
 
 
-def …
+def test_extract_json_object_unwraps_quoted_json_string():
+    parsed = extract_json_object(
+        '"{\\"action_type\\": \\"run_qc\\", \\"method\\": \\"\\", \\"parameters\\": {}, \\"Justification\\": \\"check quality\\", \\"confidence\\": 0.8}"'
+    )
+    assert parsed is not None
+    assert parsed["action_type"] == "run_qc"
+
+
+def test_parse_action_falls_back_when_inner_object_lacks_action_type():
+    action = parse_action(
+        '"{\\"action_type\\": \\"design_followup_experiment\\", \\"method\\": \\"\\", \\"parameters\\": {\\"criterion_description\\": \\"\\"}, \\"Justification\\": \\"follow-up\\", \\"confidence\\": 0.6, \\"threshold_value\\": {\\"conditions\\": [], \\"gene_filter_criteria\\": \\"x\\", \\"sample_group_size\\": 3}}"'  # noqa: E501
+    )
+    assert action is not None
+    assert action.action_type == ActionType.DESIGN_FOLLOWUP
+
+
+def test_should_block_failed_reattempt_until_pipeline_progress():
     env = BioExperimentEnvironment(scenario_name="cardiac_disease_de", domain_randomise=False)
     obs = env.reset(seed=0)
     for action_type in (
@@ -32,5 +48,44 @@ def test_fallback_uses_observation_progress_not_step_index():
         ActionType.SEQUENCE_CELLS,
     ):
         obs = env.step(ExperimentAction(action_type=action_type))
-    …
-    assert …
+    assert should_block_failed_reattempt(obs.pipeline_history, ActionType.SEQUENCE_CELLS) is False
+    assert should_block_failed_reattempt(obs.pipeline_history, ActionType.RUN_QC) is False
+
+
+def test_ensure_conclusion_claims_infers_from_outputs_when_discoveries_empty():
|
| 56 |
+
env = BioExperimentEnvironment(scenario_name="cardiac_disease_de", domain_randomise=False)
|
| 57 |
+
obs = env.reset(seed=0)
|
| 58 |
+
pipeline = [
|
| 59 |
+
ExperimentAction(action_type=ActionType.COLLECT_SAMPLE),
|
| 60 |
+
ExperimentAction(action_type=ActionType.PREPARE_LIBRARY),
|
| 61 |
+
ExperimentAction(action_type=ActionType.SEQUENCE_CELLS),
|
| 62 |
+
ExperimentAction(action_type=ActionType.RUN_QC),
|
| 63 |
+
ExperimentAction(action_type=ActionType.FILTER_DATA),
|
| 64 |
+
ExperimentAction(action_type=ActionType.NORMALIZE_DATA),
|
| 65 |
+
ExperimentAction(action_type=ActionType.CLUSTER_CELLS),
|
| 66 |
+
ExperimentAction(
|
| 67 |
+
action_type=ActionType.DIFFERENTIAL_EXPRESSION,
|
| 68 |
+
parameters={"comparison": "disease_vs_healthy"},
|
| 69 |
+
),
|
| 70 |
+
ExperimentAction(action_type=ActionType.PATHWAY_ENRICHMENT),
|
| 71 |
+
]
|
| 72 |
+
for action in pipeline:
|
| 73 |
+
obs = env.step(action)
|
| 74 |
+
|
| 75 |
+
sparse_obs = obs.model_copy(update={
|
| 76 |
+
"discovered_markers": [],
|
| 77 |
+
"candidate_mechanisms": [],
|
| 78 |
+
})
|
| 79 |
+
action = ensure_conclusion_claims(
|
| 80 |
+
sparse_obs,
|
| 81 |
+
ExperimentAction(
|
| 82 |
+
action_type=ActionType.SYNTHESIZE_CONCLUSION,
|
| 83 |
+
confidence=0.9,
|
| 84 |
+
parameters={},
|
| 85 |
+
),
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
claims = action.parameters["claims"]
|
| 89 |
+
assert claims[0]["top_markers"]
|
| 90 |
+
assert claims[0]["causal_mechanisms"]
|
| 91 |
+
assert claims[0]["predicted_pathways"]
|
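The double-encoded completions exercised above (a JSON object wrapped inside a JSON string) can be unwrapped with two decode passes. This is a minimal sketch under that assumption, not the actual `extract_json_object` from `run_agent.py`:

```python
import json
from typing import Any, Optional


def extract_json_object_sketch(text: str) -> Optional[dict]:
    """Parse text as JSON; if the result is itself a JSON-encoded string
    (double-encoded), decode it one more time before returning a dict."""
    try:
        parsed: Any = json.loads(text)
    except json.JSONDecodeError:
        return None
    if isinstance(parsed, str):
        try:
            parsed = json.loads(parsed)
        except json.JSONDecodeError:
            return None
    return parsed if isinstance(parsed, dict) else None


payload = extract_json_object_sketch(
    '"{\\"action_type\\": \\"run_qc\\", \\"confidence\\": 0.8}"'
)
assert payload is not None and payload["action_type"] == "run_qc"
```

Model completions that quote the whole action object are common enough that both the helper and the parser fall back to this second decode before giving up.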
tests/test_training_script.py
CHANGED
|
@@ -43,6 +43,15 @@ def test_parse_action_completion_accepts_reasoning_alias():
|
|
| 43 |
assert action.justification == "Measure quality before filtering."
|
| 44 |
|
| 45 |
|
|
| 46 |
def test_build_prompt_examples_contains_reference_action():
|
| 47 |
examples = build_prompt_examples(
|
| 48 |
dataset_episodes=1,
|
|
@@ -55,6 +64,7 @@ def test_build_prompt_examples_contains_reference_action():
|
|
| 55 |
assert len(examples) == 2
|
| 56 |
assert examples[0]["scenario_name"] == "cardiac_disease_de"
|
| 57 |
assert '"action_type": "collect_sample"' in examples[0]["reference_action"]
|
|
|
|
| 58 |
|
| 59 |
|
| 60 |
def test_openenv_reward_penalizes_invalid_completion():
|
|
|
|
| 43 |
assert action.justification == "Measure quality before filtering."
|
| 44 |
|
| 45 |
|
| 46 |
+
def test_parse_action_completion_normalizes_run_agent_aliases():
|
| 47 |
+
action = parse_action_completion(
|
| 48 |
+
'{"action_type":"network_inference","method":"pySCENIC"}'
|
| 49 |
+
)
|
| 50 |
+
assert action is not None
|
| 51 |
+
assert action.action_type == ActionType.REGULATORY_NETWORK_INFERENCE
|
| 52 |
+
assert action.method == "pySCENIC"
|
| 53 |
+
|
| 54 |
+
|
| 55 |
def test_build_prompt_examples_contains_reference_action():
|
| 56 |
examples = build_prompt_examples(
|
| 57 |
dataset_episodes=1,
|
|
|
|
| 64 |
assert len(examples) == 2
|
| 65 |
assert examples[0]["scenario_name"] == "cardiac_disease_de"
|
| 66 |
assert '"action_type": "collect_sample"' in examples[0]["reference_action"]
|
| 67 |
+
assert '"action_type": "select_cohort"' in examples[1]["reference_action"]
|
| 68 |
|
| 69 |
|
| 70 |
def test_openenv_reward_penalizes_invalid_completion():
|
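The alias normalization tested above (e.g. `network_inference` resolving to `REGULATORY_NETWORK_INFERENCE`) boils down to canonical lookup, alias lookup, and a punctuation-collapsing retry. A minimal sketch, using a hypothetical subset of the canonical names and aliases from `training_script.py`:

```python
import re
from typing import Optional

# Hypothetical subset of the canonical action types and their aliases.
CANONICAL = {"run_qc", "cluster_cells", "regulatory_network_inference"}
ALIASES = {
    "qc": "run_qc",
    "cluster": "cluster_cells",
    "network_inference": "regulatory_network_inference",
}


def normalize(raw: str) -> Optional[str]:
    candidate = raw.strip().lower()
    # Collapse punctuation/whitespace to underscores before the lookups.
    candidate = re.sub(r"[^a-z0-9]+", "_", candidate).strip("_")
    if candidate in CANONICAL:
        return candidate
    return ALIASES.get(candidate)


assert normalize("Network Inference") == "regulatory_network_inference"
assert normalize("run_qc") == "run_qc"
assert normalize("unknown_step") is None
```

Normalizing before lookup means free-text variants like "Network Inference" or "QC" still resolve, while genuinely unknown labels return `None` and fall through to the parser's fallback path.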
train.ipynb
ADDED
|
@@ -0,0 +1,141 @@
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "cbde861c",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# Train A Self-Driving Lab Policy on H100\n",
|
| 9 |
+
"\n",
|
| 10 |
+
"This notebook is designed for Jupyter GPU nodes such as H100 clusters.\n",
|
| 11 |
+
"It uses the notebook-friendly helpers in `training_script.py` to build prompts from the same self-driving lab environment state used by `run_agent.py`, preview reference actions, and launch GRPO training without shelling out to the CLI."
|
| 12 |
+
]
|
| 13 |
+
},
|
| 14 |
+
{
|
| 15 |
+
"cell_type": "code",
|
| 16 |
+
"execution_count": null,
|
| 17 |
+
"id": "da2e770c",
|
| 18 |
+
"metadata": {},
|
| 19 |
+
"outputs": [],
|
| 20 |
+
"source": [
|
| 21 |
+
"%pip install -q -U torch transformers datasets trl accelerate matplotlib huggingface_hub\n",
|
| 22 |
+
"\n",
|
| 23 |
+
"# Optional extras used by some reward-scoring paths.\n",
|
| 24 |
+
"%pip install -q -U sentence-transformers gseapy"
|
| 25 |
+
]
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"cell_type": "code",
|
| 29 |
+
"execution_count": null,
|
| 30 |
+
"id": "f4444591",
|
| 31 |
+
"metadata": {},
|
| 32 |
+
"outputs": [],
|
| 33 |
+
"source": [
|
| 34 |
+
"from pathlib import Path\n",
|
| 35 |
+
"\n",
|
| 36 |
+
"import torch\n",
|
| 37 |
+
"\n",
|
| 38 |
+
"from training_script import build_prompt_examples, make_training_args, run_training\n",
|
| 39 |
+
"\n",
|
| 40 |
+
"print(\"CUDA available:\", torch.cuda.is_available())\n",
|
| 41 |
+
"if torch.cuda.is_available():\n",
|
| 42 |
+
" print(\"GPU:\", torch.cuda.get_device_name(0))\n",
|
| 43 |
+
" print(\"bf16 supported:\", torch.cuda.is_bf16_supported())\n",
|
| 44 |
+
"\n",
|
| 45 |
+
"Path(\"artifacts\").mkdir(exist_ok=True)"
|
| 46 |
+
]
|
| 47 |
+
},
|
| 48 |
+
{
|
| 49 |
+
"cell_type": "code",
|
| 50 |
+
"execution_count": null,
|
| 51 |
+
"id": "c9c472b3",
|
| 52 |
+
"metadata": {},
|
| 53 |
+
"outputs": [],
|
| 54 |
+
"source": [
|
| 55 |
+
"args = make_training_args(\n",
|
| 56 |
+
" model_id=\"Qwen/Qwen3.5-0.8B\",\n",
|
| 57 |
+
" output_dir=\"artifacts/grpo-h100\",\n",
|
| 58 |
+
" dataset_episodes=32,\n",
|
| 59 |
+
" rollout_steps=10,\n",
|
| 60 |
+
" collection_policy=\"heuristic\",\n",
|
| 61 |
+
" reward_backend=\"local\",\n",
|
| 62 |
+
" domain_randomise=True,\n",
|
| 63 |
+
" num_generations=4,\n",
|
| 64 |
+
" max_completion_length=160,\n",
|
| 65 |
+
" max_prompt_length=1280,\n",
|
| 66 |
+
" per_device_train_batch_size=4,\n",
|
| 67 |
+
" gradient_accumulation_steps=4,\n",
|
| 68 |
+
" learning_rate=5e-6,\n",
|
| 69 |
+
" num_train_epochs=1.0,\n",
|
| 70 |
+
" logging_steps=1,\n",
|
| 71 |
+
" save_steps=25,\n",
|
| 72 |
+
" trust_remote_code=True,\n",
|
| 73 |
+
" dry_run=False,\n",
|
| 74 |
+
" seed=42,\n",
|
| 75 |
+
")\n",
|
| 76 |
+
"\n",
|
| 77 |
+
"args"
|
| 78 |
+
]
|
| 79 |
+
},
|
| 80 |
+
{
|
| 81 |
+
"cell_type": "code",
|
| 82 |
+
"execution_count": null,
|
| 83 |
+
"id": "d4c3d9c4",
|
| 84 |
+
"metadata": {},
|
| 85 |
+
"outputs": [],
|
| 86 |
+
"source": [
|
| 87 |
+
"preview_examples = build_prompt_examples(\n",
|
| 88 |
+
" dataset_episodes=1,\n",
|
| 89 |
+
" rollout_steps=args.rollout_steps,\n",
|
| 90 |
+
" collection_policy=args.collection_policy,\n",
|
| 91 |
+
" scenario_names=[\"cardiac_disease_de\"],\n",
|
| 92 |
+
" seed=args.seed,\n",
|
| 93 |
+
" domain_randomise=args.domain_randomise,\n",
|
| 94 |
+
")\n",
|
| 95 |
+
"\n",
|
| 96 |
+
"print(preview_examples[0][\"prompt\"][:3500])\n",
|
| 97 |
+
"print(\"\\nReference action:\\n\", preview_examples[0][\"reference_action\"])\n"
|
| 98 |
+
]
|
| 99 |
+
},
|
| 100 |
+
{
|
| 101 |
+
"cell_type": "code",
|
| 102 |
+
"execution_count": null,
|
| 103 |
+
"id": "647663dd",
|
| 104 |
+
"metadata": {},
|
| 105 |
+
"outputs": [],
|
| 106 |
+
"source": [
|
| 107 |
+
"# Optional smoke test before a full run.\n",
|
| 108 |
+
"dry_run_args = make_training_args(**{**vars(args), \"dry_run\": True})\n",
|
| 109 |
+
"dry_run_result = run_training(dry_run_args)\n",
|
| 110 |
+
"len(dry_run_result[\"examples\"])"
|
| 111 |
+
]
|
| 112 |
+
},
|
| 113 |
+
{
|
| 114 |
+
"cell_type": "code",
|
| 115 |
+
"execution_count": null,
|
| 116 |
+
"id": "5f29f456",
|
| 117 |
+
"metadata": {},
|
| 118 |
+
"outputs": [],
|
| 119 |
+
"source": [
|
| 120 |
+
"from IPython.display import Image, display\n",
|
| 121 |
+
"\n",
|
| 122 |
+
"train_result = run_training(args)\n",
|
| 123 |
+
"for name, plot_path in train_result[\"plot_paths\"].items():\n",
|
| 124 |
+
" print(name, plot_path)\n",
|
| 125 |
+
" display(Image(filename=plot_path))"
|
| 126 |
+
]
|
| 127 |
+
}
|
| 128 |
+
],
|
| 129 |
+
"metadata": {
|
| 130 |
+
"kernelspec": {
|
| 131 |
+
"display_name": "Python 3",
|
| 132 |
+
"language": "python",
|
| 133 |
+
"name": "python3"
|
| 134 |
+
},
|
| 135 |
+
"language_info": {
|
| 136 |
+
"name": "python"
|
| 137 |
+
}
|
| 138 |
+
},
|
| 139 |
+
"nbformat": 4,
|
| 140 |
+
"nbformat_minor": 5
|
| 141 |
+
}
|
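The notebook's smoke-test cell relies on the `make_training_args(**{**vars(args), "dry_run": True})` idiom: clone every current argument as a dict, override one flag, and rebuild. A minimal sketch of that pattern, with `make_args_sketch` as a hypothetical stand-in for `make_training_args`:

```python
from argparse import Namespace


def make_args_sketch(**overrides) -> Namespace:
    """Stand-in for make_training_args: defaults merged with overrides."""
    defaults = {"model_id": "some/model", "dry_run": False, "seed": 42}
    return Namespace(**{**defaults, **overrides})


args = make_args_sketch(seed=7)
# The notebook's smoke-test idiom: clone all current args, flip one flag.
dry_args = make_args_sketch(**{**vars(args), "dry_run": True})

assert dry_args.dry_run is True
assert dry_args.seed == 7
```

Because `vars(args)` captures every field, the dry run sees exactly the configuration of the real run apart from the flipped flag, which is what makes it a faithful smoke test before committing GPU time.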
training_script.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""Train a planner with TRL GRPO and OpenEnv rewards."""
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
|
@@ -21,29 +21,53 @@ from models import (
|
|
| 21 |
from server.hackathon_environment import BioExperimentEnvironment
|
| 22 |
from server.tasks.scenarios import SCENARIO_LIBRARY
|
| 23 |
|
| 24 |
-
DEFAULT_MODEL_ID = "Qwen/Qwen3.5-
|
| 25 |
DEFAULT_OUTPUT_DIR = "training/grpo-output"
|
| 26 |
DEFAULT_BASE_URL = "http://localhost:8000"
|
|
|
|
| 27 |
INVALID_ACTION_PENALTY = -2.0
|
| 28 |
ENVIRONMENT_ERROR_PENALTY = -4.0
|
| 29 |
|
| 30 |
SYSTEM_PROMPT = build_agent_system_prompt()
|
| 31 |
|
|
| 32 |
HEURISTIC_SEQUENCE = [
|
| 33 |
ActionType.COLLECT_SAMPLE,
|
|
|
|
| 34 |
ActionType.PREPARE_LIBRARY,
|
| 35 |
ActionType.SEQUENCE_CELLS,
|
| 36 |
ActionType.RUN_QC,
|
| 37 |
ActionType.FILTER_DATA,
|
| 38 |
ActionType.NORMALIZE_DATA,
|
|
|
|
| 39 |
ActionType.CLUSTER_CELLS,
|
| 40 |
ActionType.DIFFERENTIAL_EXPRESSION,
|
| 41 |
ActionType.PATHWAY_ENRICHMENT,
|
| 42 |
ActionType.MARKER_SELECTION,
|
|
|
|
|
|
|
| 43 |
ActionType.SYNTHESIZE_CONCLUSION,
|
| 44 |
]
|
| 45 |
|
| 46 |
-
VALID_ACTION_TYPES =
|
| 47 |
|
| 48 |
|
| 49 |
def compact_preview(value: Any, max_chars: int = 160) -> str:
|
|
@@ -129,7 +153,11 @@ def build_argument_parser() -> argparse.ArgumentParser:
|
|
| 129 |
help="Enable domain randomisation while building prompts and local rewards.",
|
| 130 |
)
|
| 131 |
parser.add_argument("--num-generations", type=int, default=2)
|
| 132 |
-
parser.add_argument(
|
| 133 |
parser.add_argument("--max-prompt-length", type=int, default=768)
|
| 134 |
parser.add_argument("--per-device-train-batch-size", type=int, default=2)
|
| 135 |
parser.add_argument("--gradient-accumulation-steps", type=int, default=1)
|
|
@@ -197,23 +225,42 @@ def format_observation(obs: ExperimentObservation) -> str:
|
|
| 197 |
if context:
|
| 198 |
parts.append(context)
|
| 199 |
if obs.pipeline_history:
|
| 200 |
-
|
| 201 |
-
|
|
|
|
| 202 |
tag = "OK" if step.success else "FAIL"
|
| 203 |
-
line = f" [{tag}] {step.action_type.value}
|
| 204 |
-
if step.
|
| 205 |
-
line += f"
|
|
|
|
| 206 |
parts.append(line)
|
| 207 |
if obs.latest_output and obs.latest_output.data:
|
| 208 |
parts.append(
|
| 209 |
-
f"Latest
|
| 210 |
)
|
| 211 |
if obs.rule_violations:
|
| 212 |
-
parts.append(f"
|
| 213 |
if obs.discovered_markers:
|
| 214 |
-
parts.append(f"Markers: {obs.discovered_markers[:5]}")
|
| 215 |
if obs.candidate_mechanisms:
|
| 216 |
-
parts.append(f"
|
| 217 |
return "\n".join(parts)
|
| 218 |
|
| 219 |
|
|
@@ -251,6 +298,7 @@ def default_comparison_name(conditions: Sequence[str]) -> str:
|
|
| 251 |
def build_experiment_action(
|
| 252 |
action_type: ActionType,
|
| 253 |
discovered_markers: Sequence[str],
|
|
|
|
| 254 |
conditions: Sequence[str],
|
| 255 |
) -> ExperimentAction:
|
| 256 |
method = None
|
|
@@ -260,9 +308,27 @@ def build_experiment_action(
|
|
| 260 |
if action_type == ActionType.COLLECT_SAMPLE:
|
| 261 |
parameters = {"n_samples": 6}
|
| 262 |
justification = "Collect enough samples to start the experiment."
|
| 263 |
elif action_type == ActionType.PREPARE_LIBRARY:
|
| 264 |
method = "10x_chromium"
|
| 265 |
justification = "Prepare a single-cell library for sequencing."
|
| 266 |
elif action_type == ActionType.SEQUENCE_CELLS:
|
| 267 |
method = "NovaSeq"
|
| 268 |
justification = "Generate reads for downstream single-cell analysis."
|
|
@@ -275,6 +341,9 @@ def build_experiment_action(
|
|
| 275 |
elif action_type == ActionType.NORMALIZE_DATA:
|
| 276 |
method = "scanpy.pp.normalize_total"
|
| 277 |
justification = "Normalize counts for comparable expression profiles."
|
| 278 |
elif action_type == ActionType.CLUSTER_CELLS:
|
| 279 |
method = "scanpy.tl.leiden"
|
| 280 |
justification = "Resolve cell states before interpretation."
|
|
@@ -291,20 +360,34 @@ def build_experiment_action(
|
|
| 291 |
elif action_type == ActionType.MARKER_SELECTION:
|
| 292 |
method = "scanpy.tl.rank_genes_groups"
|
| 293 |
justification = "Nominate marker genes for validation."
|
| 294 |
elif action_type == ActionType.VALIDATE_MARKER:
|
| 295 |
method = "qPCR"
|
| 296 |
parameters = {"marker": discovered_markers[0] if discovered_markers else "SPP1"}
|
| 297 |
justification = "Validate the strongest discovered marker."
|
| 298 |
elif action_type == ActionType.SYNTHESIZE_CONCLUSION:
|
| 299 |
top = list(discovered_markers[:5]) if discovered_markers else []
|
| 300 |
parameters = {
|
| 301 |
"claims": [{
|
| 302 |
"top_markers": top,
|
| 303 |
-
"causal_mechanisms": [],
|
| 304 |
-
"predicted_pathways": {
|
| 305 |
"confidence": 0.6,
|
| 306 |
-
"claim_type": "correlational",
|
| 307 |
-
"claim": "",
|
| 308 |
}],
|
| 309 |
}
|
| 310 |
justification = "Summarize the current evidence into a conclusion."
|
|
@@ -375,6 +458,7 @@ def build_prompt_examples(
|
|
| 375 |
[action.action_type for action in history_actions],
|
| 376 |
),
|
| 377 |
discovered_markers=obs.discovered_markers,
|
|
|
|
| 378 |
conditions=obs.task.conditions,
|
| 379 |
)
|
| 380 |
examples.append({
|
|
@@ -543,11 +627,145 @@ def normalize_optional_string(value: Any) -> Optional[str]:
|
|
| 543 |
return compact_preview(value, 80)
|
| 544 |
|
| 545 |
|
| 546 |
def parse_action_completion(text: str) -> Optional[ExperimentAction]:
|
| 547 |
payload = extract_json_object(text)
|
| 548 |
if payload is not None:
|
| 549 |
-
action_type = get_payload_value(payload, "action_type")
|
| 550 |
-
if action_type
|
| 551 |
return None
|
| 552 |
|
| 553 |
parameters = get_payload_value(payload, "parameters", "params") or {}
|
|
@@ -584,8 +802,8 @@ def parse_action_completion(text: str) -> Optional[ExperimentAction]:
|
|
| 584 |
if not action_match:
|
| 585 |
return None
|
| 586 |
|
| 587 |
-
action_type = action_match.group(1)
|
| 588 |
-
if action_type
|
| 589 |
return None
|
| 590 |
|
| 591 |
method_match = re.search(
|
|
@@ -733,6 +951,7 @@ class OpenEnvReward:
|
|
| 733 |
obs = env.step(previous_action)
|
| 734 |
if obs.done:
|
| 735 |
return float(obs.reward)
|
|
|
|
| 736 |
obs = env.step(action)
|
| 737 |
return float(obs.reward)
|
| 738 |
|
|
@@ -1081,7 +1300,7 @@ def generate_action_with_model(
|
|
| 1081 |
tokenizer: Any,
|
| 1082 |
prompt_or_observation: str | ExperimentObservation,
|
| 1083 |
*,
|
| 1084 |
-
max_new_tokens: int =
|
| 1085 |
temperature: float = 0.2,
|
| 1086 |
top_p: float = 0.9,
|
| 1087 |
do_sample: bool = True,
|
|
@@ -1114,6 +1333,8 @@ def generate_action_with_model(
|
|
| 1114 |
new_tokens = output_ids[0][prompt_tokens:]
|
| 1115 |
response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
|
| 1116 |
action = parse_action_completion(response_text)
|
| 1117 |
return {
|
| 1118 |
"prompt": prompt,
|
| 1119 |
"response_text": response_text,
|
|
|
|
| 1 |
+
"""Train a self-driving lab planner with TRL GRPO and OpenEnv rewards."""
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
|
|
|
| 21 |
from server.hackathon_environment import BioExperimentEnvironment
|
| 22 |
from server.tasks.scenarios import SCENARIO_LIBRARY
|
| 23 |
|
| 24 |
+
DEFAULT_MODEL_ID = "Qwen/Qwen3.5-4B"
|
| 25 |
DEFAULT_OUTPUT_DIR = "training/grpo-output"
|
| 26 |
DEFAULT_BASE_URL = "http://localhost:8000"
|
| 27 |
+
DEFAULT_COMPLETION_TOKEN_BUDGET = 160
|
| 28 |
INVALID_ACTION_PENALTY = -2.0
|
| 29 |
ENVIRONMENT_ERROR_PENALTY = -4.0
|
| 30 |
|
| 31 |
SYSTEM_PROMPT = build_agent_system_prompt()
|
| 32 |
|
| 33 |
+
ACTION_TYPES = [action.value for action in ActionType]
|
| 34 |
+
ACTION_TYPE_ALIASES = {
|
| 35 |
+
"collect_samples": ActionType.COLLECT_SAMPLE.value,
|
| 36 |
+
"collect_sample_from_bone_marrow": ActionType.COLLECT_SAMPLE.value,
|
| 37 |
+
"collect_samples_from_bone_marrow": ActionType.COLLECT_SAMPLE.value,
|
| 38 |
+
"prepare_sc_library": ActionType.PREPARE_LIBRARY.value,
|
| 39 |
+
"sequence_single_cells": ActionType.SEQUENCE_CELLS.value,
|
| 40 |
+
"qc": ActionType.RUN_QC.value,
|
| 41 |
+
"run_quality_control": ActionType.RUN_QC.value,
|
| 42 |
+
"cluster": ActionType.CLUSTER_CELLS.value,
|
| 43 |
+
"de_analysis": ActionType.DIFFERENTIAL_EXPRESSION.value,
|
| 44 |
+
"differential_expression_analysis": ActionType.DIFFERENTIAL_EXPRESSION.value,
|
| 45 |
+
"trajectory_inference": ActionType.TRAJECTORY_ANALYSIS.value,
|
| 46 |
+
"infer_trajectory": ActionType.TRAJECTORY_ANALYSIS.value,
|
| 47 |
+
"network_inference": ActionType.REGULATORY_NETWORK_INFERENCE.value,
|
| 48 |
+
"select_markers": ActionType.MARKER_SELECTION.value,
|
| 49 |
+
"final_conclusion": ActionType.SYNTHESIZE_CONCLUSION.value,
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
HEURISTIC_SEQUENCE = [
|
| 53 |
ActionType.COLLECT_SAMPLE,
|
| 54 |
+
ActionType.SELECT_COHORT,
|
| 55 |
ActionType.PREPARE_LIBRARY,
|
| 56 |
ActionType.SEQUENCE_CELLS,
|
| 57 |
ActionType.RUN_QC,
|
| 58 |
ActionType.FILTER_DATA,
|
| 59 |
ActionType.NORMALIZE_DATA,
|
| 60 |
+
ActionType.INTEGRATE_BATCHES,
|
| 61 |
ActionType.CLUSTER_CELLS,
|
| 62 |
ActionType.DIFFERENTIAL_EXPRESSION,
|
| 63 |
ActionType.PATHWAY_ENRICHMENT,
|
| 64 |
ActionType.MARKER_SELECTION,
|
| 65 |
+
ActionType.TRAJECTORY_ANALYSIS,
|
| 66 |
+
ActionType.REGULATORY_NETWORK_INFERENCE,
|
| 67 |
ActionType.SYNTHESIZE_CONCLUSION,
|
| 68 |
]
|
| 69 |
|
| 70 |
+
VALID_ACTION_TYPES = set(ACTION_TYPES)
|
| 71 |
|
| 72 |
|
| 73 |
def compact_preview(value: Any, max_chars: int = 160) -> str:
|
|
|
|
| 153 |
help="Enable domain randomisation while building prompts and local rewards.",
|
| 154 |
)
|
| 155 |
parser.add_argument("--num-generations", type=int, default=2)
|
| 156 |
+
parser.add_argument(
|
| 157 |
+
"--max-completion-length",
|
| 158 |
+
type=int,
|
| 159 |
+
default=DEFAULT_COMPLETION_TOKEN_BUDGET,
|
| 160 |
+
)
|
| 161 |
parser.add_argument("--max-prompt-length", type=int, default=768)
|
| 162 |
parser.add_argument("--per-device-train-batch-size", type=int, default=2)
|
| 163 |
parser.add_argument("--gradient-accumulation-steps", type=int, default=1)
|
|
|
|
| 225 |
if context:
|
| 226 |
parts.append(context)
|
| 227 |
if obs.pipeline_history:
|
| 228 |
+
last5 = obs.pipeline_history[-5:]
|
| 229 |
+
parts.append("Recent history:")
|
| 230 |
+
for step in last5:
|
| 231 |
tag = "OK" if step.success else "FAIL"
|
| 232 |
+
line = f" [{tag}] {step.action_type.value}"
|
| 233 |
+
if step.method:
|
| 234 |
+
line += f" ({step.method})"
|
| 235 |
+
line += f": {step.output_summary[:80]}"
|
| 236 |
parts.append(line)
|
| 237 |
+
completed = {
|
| 238 |
+
step.action_type for step in obs.pipeline_history if step.success
|
| 239 |
+
}
|
| 240 |
+
if completed:
|
| 241 |
+
parts.append(
|
| 242 |
+
"Completed steps (do NOT repeat): "
|
| 243 |
+
+ ", ".join(sorted(action.value for action in completed))
|
| 244 |
+
)
|
| 245 |
+
remaining = [
|
| 246 |
+
action.value for action in HEURISTIC_SEQUENCE if action not in completed
|
| 247 |
+
]
|
| 248 |
+
if remaining:
|
| 249 |
+
parts.append(f"Remaining steps (choose one): {', '.join(remaining)}")
|
| 250 |
if obs.latest_output and obs.latest_output.data:
|
| 251 |
parts.append(
|
| 252 |
+
f"Latest data: {compact_preview(obs.latest_output.data, 200)}"
|
| 253 |
)
|
| 254 |
if obs.rule_violations:
|
| 255 |
+
parts.append(f"VIOLATIONS: {obs.rule_violations}")
|
| 256 |
if obs.discovered_markers:
|
| 257 |
+
parts.append(f"Markers found so far: {obs.discovered_markers[:5]}")
|
| 258 |
if obs.candidate_mechanisms:
|
| 259 |
+
parts.append(f"Candidate mechanisms: {obs.candidate_mechanisms[:5]}")
|
| 260 |
+
parts.append(
|
| 261 |
+
'Output ONLY a single JSON object with these exact keys, no comments, no extra text:\n'
|
| 262 |
+
'{"action_type": "<one of the remaining steps>", "method": null, "parameters": {}, "justification": "<why>", "confidence": 0.8}'
|
| 263 |
+
)
|
| 264 |
return "\n".join(parts)
|
| 265 |
|
| 266 |
|
|
|
|
| 298 |
def build_experiment_action(
|
| 299 |
action_type: ActionType,
|
| 300 |
discovered_markers: Sequence[str],
|
| 301 |
+
candidate_mechanisms: Sequence[str],
|
| 302 |
conditions: Sequence[str],
|
| 303 |
) -> ExperimentAction:
|
| 304 |
method = None
|
|
|
|
| 308 |
if action_type == ActionType.COLLECT_SAMPLE:
|
| 309 |
parameters = {"n_samples": 6}
|
| 310 |
justification = "Collect enough samples to start the experiment."
|
| 311 |
+
elif action_type == ActionType.SELECT_COHORT:
|
| 312 |
+
parameters = {
|
| 313 |
+
"comparison": default_comparison_name(conditions),
|
| 314 |
+
"conditions": list(conditions[:2]) or ["disease", "healthy"],
|
| 315 |
+
}
|
| 316 |
+
justification = "Define the cohort split before committing to downstream analysis."
|
| 317 |
elif action_type == ActionType.PREPARE_LIBRARY:
|
| 318 |
method = "10x_chromium"
|
| 319 |
justification = "Prepare a single-cell library for sequencing."
|
| 320 |
+
elif action_type == ActionType.CULTURE_CELLS:
|
| 321 |
+
method = "organoid_culture"
|
| 322 |
+
parameters = {"duration_days": 7}
|
| 323 |
+
justification = "Expand viable cells before a perturbation or profiling step."
|
| 324 |
+
elif action_type == ActionType.PERTURB_GENE:
|
| 325 |
+
method = "CRISPRi"
|
| 326 |
+
parameters = {"target_gene": candidate_mechanisms[0] if candidate_mechanisms else "STAT3"}
|
| 327 |
+
justification = "Test whether a candidate regulator causally shifts cell state."
|
| 328 |
+
elif action_type == ActionType.PERTURB_COMPOUND:
|
| 329 |
+
method = "small_molecule_screen"
|
| 330 |
+
parameters = {"compound": candidate_mechanisms[0] if candidate_mechanisms else "TGFb_inhibitor"}
|
| 331 |
+
justification = "Probe the pathway hypothesis with a targeted compound perturbation."
|
| 332 |
elif action_type == ActionType.SEQUENCE_CELLS:
|
| 333 |
method = "NovaSeq"
|
| 334 |
justification = "Generate reads for downstream single-cell analysis."
|
|
|
|
| 341 |
elif action_type == ActionType.NORMALIZE_DATA:
|
| 342 |
method = "scanpy.pp.normalize_total"
|
| 343 |
justification = "Normalize counts for comparable expression profiles."
|
| 344 |
+
elif action_type == ActionType.INTEGRATE_BATCHES:
|
| 345 |
+
method = "scanorama.integrate"
|
| 346 |
+
justification = "Correct batch effects before comparing cellular programs."
|
| 347 |
elif action_type == ActionType.CLUSTER_CELLS:
|
| 348 |
method = "scanpy.tl.leiden"
|
| 349 |
justification = "Resolve cell states before interpretation."
|
|
|
|
| 360 |
elif action_type == ActionType.MARKER_SELECTION:
|
| 361 |
method = "scanpy.tl.rank_genes_groups"
|
| 362 |
justification = "Nominate marker genes for validation."
|
| 363 |
+
elif action_type == ActionType.REGULATORY_NETWORK_INFERENCE:
|
| 364 |
+
method = "pySCENIC"
|
| 365 |
+
justification = "Infer upstream regulators behind the observed state changes."
|
| 366 |
elif action_type == ActionType.VALIDATE_MARKER:
|
| 367 |
method = "qPCR"
|
| 368 |
parameters = {"marker": discovered_markers[0] if discovered_markers else "SPP1"}
|
| 369 |
justification = "Validate the strongest discovered marker."
|
| 370 |
+
elif action_type == ActionType.DESIGN_FOLLOWUP:
|
| 371 |
+
method = "followup_plan"
|
| 372 |
+
parameters = {"priority_hypothesis": candidate_mechanisms[0] if candidate_mechanisms else "fibrotic_activation"}
|
| 373 |
+
justification = "Propose the next experiment to disambiguate remaining uncertainty."
|
| 374 |
+
elif action_type == ActionType.REQUEST_SUBAGENT_REVIEW:
|
| 375 |
+
method = "peer_review"
|
| 376 |
+
parameters = {"focus": "experimental_design"}
|
| 377 |
+
justification = "Request a review of the current self-driving lab plan."
|
| 378 |
elif action_type == ActionType.SYNTHESIZE_CONCLUSION:
|
| 379 |
top = list(discovered_markers[:5]) if discovered_markers else []
|
| 380 |
parameters = {
|
| 381 |
"claims": [{
|
| 382 |
"top_markers": top,
|
| 383 |
+
"causal_mechanisms": list(candidate_mechanisms[:5]),
|
| 384 |
+
"predicted_pathways": {
|
| 385 |
+
mechanism: 0.6
|
| 386 |
+
for mechanism in list(candidate_mechanisms[:3])
|
| 387 |
+
},
|
| 388 |
"confidence": 0.6,
|
| 389 |
+
"claim_type": "causal" if candidate_mechanisms else "correlational",
|
| 390 |
+
"claim": f"Synthesis for {default_comparison_name(conditions)}.",
|
| 391 |
}],
|
| 392 |
}
|
| 393 |
justification = "Summarize the current evidence into a conclusion."
|
|
|
|
| 458 |
[action.action_type for action in history_actions],
|
| 459 |
),
|
| 460 |
discovered_markers=obs.discovered_markers,
|
| 461 |
+
candidate_mechanisms=obs.candidate_mechanisms,
|
| 462 |
conditions=obs.task.conditions,
|
| 463 |
)
|
| 464 |
examples.append({
|
|
|
|
| 627 |
return compact_preview(value, 80)
|
| 628 |
|
| 629 |
|
| 630 |
+
def normalize_action_type(raw_action_type: Any) -> Optional[str]:
|
| 631 |
+
if not isinstance(raw_action_type, str):
|
| 632 |
+
return None
|
| 633 |
+
|
| 634 |
+
candidate = raw_action_type.strip().lower()
|
| 635 |
+
if candidate in ACTION_TYPES:
|
| 636 |
+
return candidate
|
| 637 |
+
if candidate in ACTION_TYPE_ALIASES:
|
| 638 |
+
return ACTION_TYPE_ALIASES[candidate]
|
| 639 |
+
|
| 640 |
+
candidate = re.sub(r"[^a-z0-9]+", "_", candidate).strip("_")
|
| 641 |
+
if candidate in ACTION_TYPES:
|
| 642 |
+
return candidate
|
| 643 |
+
if candidate in ACTION_TYPE_ALIASES:
|
| 644 |
+
return ACTION_TYPE_ALIASES[candidate]
|
| 645 |
+
|
| 646 |
+
heuristics = [
|
| 647 |
+
(("collect", "sample"), ActionType.COLLECT_SAMPLE.value),
|
| 648 |
+
(("cohort",), ActionType.SELECT_COHORT.value),
|
| 649 |
+
(("library",), ActionType.PREPARE_LIBRARY.value),
|
| 650 |
+
+        (("culture",), ActionType.CULTURE_CELLS.value),
+        (("perturb", "gene"), ActionType.PERTURB_GENE.value),
+        (("perturb", "compound"), ActionType.PERTURB_COMPOUND.value),
+        (("sequence",), ActionType.SEQUENCE_CELLS.value),
+        (("qc",), ActionType.RUN_QC.value),
+        (("quality", "control"), ActionType.RUN_QC.value),
+        (("filter",), ActionType.FILTER_DATA.value),
+        (("normal",), ActionType.NORMALIZE_DATA.value),
+        (("integrat", "batch"), ActionType.INTEGRATE_BATCHES.value),
+        (("cluster",), ActionType.CLUSTER_CELLS.value),
+        (("differential", "expression"), ActionType.DIFFERENTIAL_EXPRESSION.value),
+        (("pathway",), ActionType.PATHWAY_ENRICHMENT.value),
+        (("trajectory",), ActionType.TRAJECTORY_ANALYSIS.value),
+        (("network",), ActionType.REGULATORY_NETWORK_INFERENCE.value),
+        (("marker",), ActionType.MARKER_SELECTION.value),
+        (("validat", "marker"), ActionType.VALIDATE_MARKER.value),
+        (("followup",), ActionType.DESIGN_FOLLOWUP.value),
+        (("review",), ActionType.REQUEST_SUBAGENT_REVIEW.value),
+        (("conclusion",), ActionType.SYNTHESIZE_CONCLUSION.value),
+    ]
+    for fragments, normalized in heuristics:
+        if all(fragment in candidate for fragment in fragments):
+            return normalized
+    return None
+
+
+def _unique_nonempty(items: Sequence[Any], limit: int = 5) -> List[str]:
+    seen: set[str] = set()
+    result: List[str] = []
+    for raw in items:
+        value = normalize_optional_string(raw)
+        if not value:
+            continue
+        key = value.upper()
+        if key in seen:
+            continue
+        seen.add(key)
+        result.append(value)
+        if len(result) >= limit:
+            break
+    return result
+
+
+def _infer_conclusion_evidence(
+    obs: ExperimentObservation,
+) -> Tuple[List[str], List[str], Dict[str, float]]:
+    top_markers = _unique_nonempty(list(obs.discovered_markers), limit=5)
+    causal_mechanisms = _unique_nonempty(list(obs.candidate_mechanisms), limit=5)
+    predicted_pathways: Dict[str, float] = {}
+
+    for output in reversed(obs.all_outputs):
+        if not output.success:
+            continue
+        data = output.data or {}
+
+        if not top_markers:
+            markers = data.get("markers", [])
+            if isinstance(markers, list):
+                top_markers = _unique_nonempty(markers, limit=5)
+        if not causal_mechanisms:
+            regulators = data.get("top_regulators", [])
+            if isinstance(regulators, list):
+                causal_mechanisms = _unique_nonempty(regulators, limit=5)
+        if not predicted_pathways:
+            for item in data.get("top_pathways", []):
+                if not isinstance(item, dict):
+                    continue
+                pathway = normalize_optional_string(item.get("pathway"))
+                score = item.get("score")
+                if pathway and isinstance(score, (int, float)):
+                    predicted_pathways[pathway] = float(score)
+                if len(predicted_pathways) >= 5:
+                    break
+        if top_markers and causal_mechanisms and predicted_pathways:
+            break
+
+    return top_markers, causal_mechanisms, predicted_pathways
+
+
+def ensure_conclusion_claims(
+    obs: ExperimentObservation,
+    action: ExperimentAction,
+) -> ExperimentAction:
+    if action.action_type != ActionType.SYNTHESIZE_CONCLUSION:
+        return action
+
+    parameters = dict(action.parameters or {})
+    raw_claims = parameters.get("claims")
+    if isinstance(raw_claims, list):
+        normalized_claims = [claim for claim in raw_claims if isinstance(claim, dict)]
+        if normalized_claims:
+            parameters["claims"] = normalized_claims
+            if parameters != action.parameters:
+                return action.model_copy(update={"parameters": parameters})
+            return action
+
+    top_markers, causal_mechanisms, predicted_pathways = _infer_conclusion_evidence(obs)
+    claim_type = "causal" if causal_mechanisms else "correlational"
+    conditions = " vs ".join(obs.task.conditions[:2]) if obs.task.conditions else "the task conditions"
+    claim = action.justification or f"Final synthesis for {conditions}."
+
+    parameters["claims"] = [{
+        "top_markers": top_markers,
+        "causal_mechanisms": causal_mechanisms,
+        "predicted_pathways": predicted_pathways,
+        "confidence": action.confidence,
+        "claim_type": claim_type,
+        "claim": claim,
+    }]
+    if not action.justification:
+        action = action.model_copy(update={"justification": claim})
+    return action.model_copy(update={"parameters": parameters})
+
+
 def parse_action_completion(text: str) -> Optional[ExperimentAction]:
     payload = extract_json_object(text)
     if payload is not None:
+        action_type = normalize_action_type(get_payload_value(payload, "action_type"))
+        if action_type is None:
             return None

     parameters = get_payload_value(payload, "parameters", "params") or {}
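The fragment heuristics above match an action name when every fragment is a substring of the candidate. A minimal standalone sketch of that rule, using plain strings in place of the repo's `ActionType` enum values (the names below are illustrative, not from the source):

```python
from typing import Optional, Sequence, Tuple

# Hypothetical stand-ins for the ActionType values used in the diff.
HEURISTICS: Sequence[Tuple[Tuple[str, ...], str]] = [
    (("perturb", "gene"), "perturb_gene"),
    (("quality", "control"), "run_qc"),
    (("cluster",), "cluster_cells"),
]


def normalize(candidate: str) -> Optional[str]:
    """Return the first rule whose fragments all appear in the candidate."""
    candidate = candidate.lower()
    for fragments, normalized in HEURISTICS:
        if all(fragment in candidate for fragment in fragments):
            return normalized
    return None


print(normalize("Run quality control checks"))  # run_qc
print(normalize("gene perturbation screen"))    # perturb_gene
print(normalize("sequence cells"))              # None (no matching rule in this sketch)
```

Because partial stems like `"validat"` and `"integrat"` are used in the real table, the substring test also covers inflected forms ("validate", "validating") without extra rules.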
     if not action_match:
         return None

+    action_type = normalize_action_type(action_match.group(1))
+    if action_type is None:
         return None

     method_match = re.search(
     obs = env.step(previous_action)
     if obs.done:
         return float(obs.reward)
+    action = ensure_conclusion_claims(obs, action)
     obs = env.step(action)
     return float(obs.reward)
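The rollout tail above steps the environment with the previous action, short-circuits if the episode already ended, and otherwise applies the final action. A toy sketch of that control flow against a stub environment (the `ensure_conclusion_claims` patching step is omitted here; all names are illustrative):

```python
from dataclasses import dataclass


@dataclass
class Obs:
    reward: float
    done: bool


class StubEnv:
    """Toy environment: the episode ends on the second step."""

    def __init__(self) -> None:
        self.steps = 0

    def step(self, action: str) -> Obs:
        self.steps += 1
        return Obs(reward=1.0 if action == "conclude" else 0.1, done=self.steps >= 2)


def finish_rollout(env: StubEnv, previous_action: str, action: str) -> float:
    # Replay the previous action; if that ends the episode, its reward stands.
    obs = env.step(previous_action)
    if obs.done:
        return float(obs.reward)
    # Otherwise apply the final action and return its reward.
    obs = env.step(action)
    return float(obs.reward)


print(finish_rollout(StubEnv(), "analyze", "conclude"))  # 1.0
```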
     tokenizer: Any,
     prompt_or_observation: str | ExperimentObservation,
     *,
+    max_new_tokens: int = DEFAULT_COMPLETION_TOKEN_BUDGET,
     temperature: float = 0.2,
     top_p: float = 0.9,
     do_sample: bool = True,
     new_tokens = output_ids[0][prompt_tokens:]
     response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
     action = parse_action_completion(response_text)
+    if action is not None and isinstance(prompt_or_observation, ExperimentObservation):
+        action = ensure_conclusion_claims(prompt_or_observation, action)
     return {
         "prompt": prompt,
         "response_text": response_text,
training_unsloth.py
ADDED
|
@@ -0,0 +1,439 @@
+"""Train and run quantized self-driving lab models with Unsloth.
+
+This keeps the same OpenEnv prompt + reward wiring as `training_script.py`,
+but arranges the Unsloth path in the more typical pattern:
+1. patch GRPO support
+2. load a quantized model
+3. apply LoRA adapters
+4. train with an explicit OpenEnv reward function
+"""
+
+from __future__ import annotations
+
+import argparse
+import random
+from pathlib import Path
+from typing import Any, Dict, Optional, Sequence
+
+import training_script as base
+
+DEFAULT_OUTPUT_DIR = "training/grpo-unsloth-output"
+DEFAULT_MAX_SEQ_LENGTH = 2048
+DEFAULT_LORA_R = 16
+DEFAULT_LORA_ALPHA = 16
+DEFAULT_LORA_DROPOUT = 0.0
+LORA_TARGET_MODULES = [
+    "q_proj",
+    "k_proj",
+    "v_proj",
+    "o_proj",
+    "gate_proj",
+    "up_proj",
+    "down_proj",
+]
+
+
+def require_unsloth():
+    try:
+        from unsloth import FastLanguageModel, PatchFastRL
+    except ImportError as exc:  # pragma: no cover - depends on optional extra
+        raise RuntimeError(
+            "Unsloth is not installed. Run `uv sync --extra train` "
+            "to install the H100/quantized training dependencies."
+        ) from exc
+    return FastLanguageModel, PatchFastRL
+
+
+def _call_unsloth_from_pretrained(FastLanguageModel, **kwargs: Any):
+    for optional_key in ("fast_inference", "trust_remote_code"):
+        try:
+            return FastLanguageModel.from_pretrained(**kwargs)
+        except TypeError as exc:
+            if optional_key in kwargs and optional_key in str(exc):
+                kwargs = dict(kwargs)
+                kwargs.pop(optional_key, None)
+                continue
+            raise
+    return FastLanguageModel.from_pretrained(**kwargs)
+
+
+def build_argument_parser() -> argparse.ArgumentParser:
+    parser = base.build_argument_parser()
+    parser.description = (
+        "Train a GRPO policy with Unsloth quantized loading for faster H100 runs."
+    )
+    parser.set_defaults(output_dir=DEFAULT_OUTPUT_DIR)
+    parser.add_argument(
+        "--max-seq-length",
+        type=int,
+        default=DEFAULT_MAX_SEQ_LENGTH,
+        help="Context length passed to Unsloth model loading.",
+    )
+    parser.add_argument(
+        "--disable-4bit",
+        action="store_true",
+        help="Disable 4-bit quantized loading and use the wider base weights.",
+    )
+    parser.add_argument(
+        "--disable-fast-inference",
+        action="store_true",
+        help="Disable Unsloth fast inference kernels where supported.",
+    )
+    parser.add_argument(
+        "--lora-r",
+        type=int,
+        default=DEFAULT_LORA_R,
+        help="LoRA rank used for the quantized GRPO policy.",
+    )
+    parser.add_argument(
+        "--lora-alpha",
+        type=int,
+        default=DEFAULT_LORA_ALPHA,
+        help="LoRA alpha used for the quantized GRPO policy.",
+    )
+    parser.add_argument(
+        "--lora-dropout",
+        type=float,
+        default=DEFAULT_LORA_DROPOUT,
+        help="LoRA dropout used for the quantized GRPO policy.",
+    )
+    parser.add_argument(
+        "--save-merged-16bit",
+        action="store_true",
+        help="Also export a merged 16-bit model after training if supported.",
+    )
+    return parser
+
+
+def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
+    return build_argument_parser().parse_args(argv)
+
+
+def make_training_args(**overrides: Any) -> argparse.Namespace:
+    parser = build_argument_parser()
+    defaults = vars(parser.parse_args([]))
+    unknown = sorted(set(overrides) - set(defaults))
+    if unknown:
+        raise ValueError(f"Unknown training args: {', '.join(unknown)}")
+    defaults.update(overrides)
+    return argparse.Namespace(**defaults)
+
+
+def load_model_artifacts(
+    model_id: str,
+    *,
+    trust_remote_code: bool,
+    max_seq_length: int = DEFAULT_MAX_SEQ_LENGTH,
+    load_in_4bit: bool = True,
+    fast_inference: bool = True,
+    prepare_for_inference: bool = False,
+):
+    FastLanguageModel, _ = require_unsloth()
+    runtime = base.resolve_torch_runtime()
+
+    print(f"Loading Unsloth tokenizer+model for {model_id} ...")
+    model, tokenizer = _call_unsloth_from_pretrained(
+        FastLanguageModel,
+        model_name=model_id,
+        max_seq_length=max_seq_length,
+        dtype="auto",
+        load_in_4bit=load_in_4bit,
+        fast_inference=fast_inference,
+        trust_remote_code=trust_remote_code,
+    )
+    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    if prepare_for_inference:
+        try:
+            FastLanguageModel.for_inference(model)
+        except AttributeError:
+            pass
+
+    device = getattr(model, "device", None)
+    if device is None:
+        try:
+            device = next(model.parameters()).device
+        except StopIteration:
+            device = runtime["device"]
+    print(f"Loaded model on device: {device}")
+    return tokenizer, model
+
+
+def build_openenv_reward(args: argparse.Namespace) -> base.OpenEnvReward:
+    """Return the OpenEnv-compatible reward callable used by GRPO."""
+    return base.OpenEnvReward(
+        reward_backend=args.reward_backend,
+        base_url=args.base_url,
+        domain_randomise=args.domain_randomise,
+    )
+
+
+def prepare_prompt_examples(args: argparse.Namespace) -> Dict[str, Any]:
+    """Build the OpenEnv rollout states that seed GRPO prompts."""
+    scenario_names = base.selected_scenarios(args.scenario_name)
+    examples = base.build_prompt_examples(
+        dataset_episodes=args.dataset_episodes,
+        rollout_steps=args.rollout_steps,
+        collection_policy=args.collection_policy,
+        scenario_names=scenario_names,
+        seed=args.seed,
+        domain_randomise=args.domain_randomise,
+    )
+    return {
+        "scenario_names": scenario_names,
+        "examples": examples,
+    }
+
+
+def patch_unsloth_grpo():
+    """Patch TRL GRPO to use Unsloth's optimized kernels."""
+    FastLanguageModel, PatchFastRL = require_unsloth()
+    PatchFastRL("GRPO", FastLanguageModel)
+    return FastLanguageModel
+
+
+def apply_lora_adapters(FastLanguageModel, model: Any, args: argparse.Namespace) -> Any:
+    """Apply LoRA adapters in the usual Unsloth configuration style."""
+    return FastLanguageModel.get_peft_model(
+        model,
+        r=args.lora_r,
+        target_modules=LORA_TARGET_MODULES,
+        lora_alpha=args.lora_alpha,
+        lora_dropout=args.lora_dropout,
+        bias="none",
+        use_gradient_checkpointing=True,
+        random_state=args.seed,
+    )
+
+
+def build_grpo_config(
+    args: argparse.Namespace,
+    runtime: Dict[str, Any],
+):
+    from trl import GRPOConfig
+
+    return GRPOConfig(
+        output_dir=args.output_dir,
+        learning_rate=args.learning_rate,
+        per_device_train_batch_size=args.per_device_train_batch_size,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        num_generations=args.num_generations,
+        max_completion_length=args.max_completion_length,
+        num_train_epochs=args.num_train_epochs,
+        logging_steps=args.logging_steps,
+        save_steps=args.save_steps,
+        bf16=runtime["bf16"],
+        fp16=runtime["fp16"],
+        report_to="none",
+        remove_unused_columns=False,
+    )
+
+
+def build_unsloth_grpo_trainer(
+    *,
+    model: Any,
+    tokenizer: Any,
+    reward_func: Any,
+    train_dataset: Any,
+    args: argparse.Namespace,
+    runtime: Dict[str, Any],
+):
+    from trl import GRPOTrainer
+
+    config = build_grpo_config(args, runtime)
+    return GRPOTrainer(
+        model=model,
+        reward_funcs=reward_func,
+        args=config,
+        train_dataset=train_dataset,
+        processing_class=tokenizer,
+    )
+
+
+def generate_action_with_model(
+    model: Any,
+    tokenizer: Any,
+    prompt_or_observation: str | base.ExperimentObservation,
+    *,
+    max_new_tokens: int = base.DEFAULT_COMPLETION_TOKEN_BUDGET,
+    temperature: float = 0.2,
+    top_p: float = 0.9,
+    do_sample: bool = True,
+) -> Dict[str, Any]:
+    import torch
+
+    if isinstance(prompt_or_observation, base.ExperimentObservation):
+        prompt = base.build_training_prompt(prompt_or_observation)
+    else:
+        prompt = str(prompt_or_observation)
+
+    model_device = getattr(model, "device", None)
+    if model_device is None:
+        try:
+            model_device = next(model.parameters()).device
+        except StopIteration:
+            model_device = base.resolve_torch_runtime()["device"]
+
+    inputs = tokenizer(prompt, return_tensors="pt")
+    inputs = {key: value.to(model_device) for key, value in inputs.items()}
+    prompt_tokens = inputs["input_ids"].shape[1]
+
+    generation_kwargs = {
+        "max_new_tokens": max_new_tokens,
+        "do_sample": do_sample,
+        "temperature": temperature,
+        "top_p": top_p,
+        "pad_token_id": tokenizer.pad_token_id,
+    }
+    with torch.no_grad():
+        output_ids = model.generate(**inputs, **generation_kwargs)
+
+    new_tokens = output_ids[0][prompt_tokens:]
+    response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
+    action = base.parse_action_completion(response_text)
+    if action is not None and isinstance(prompt_or_observation, base.ExperimentObservation):
+        action = base.ensure_conclusion_claims(prompt_or_observation, action)
+    return {
+        "prompt": prompt,
+        "response_text": response_text,
+        "action": action,
+    }
+
+
+def run_training(args: argparse.Namespace) -> Dict[str, Any]:
+    random.seed(args.seed)
+    runtime = base.resolve_torch_runtime()
+
+    if args.load_model_only:
+        tokenizer, model = load_model_artifacts(
+            args.model_id,
+            trust_remote_code=args.trust_remote_code,
+            max_seq_length=args.max_seq_length,
+            load_in_4bit=not args.disable_4bit,
+            fast_inference=not args.disable_fast_inference,
+            prepare_for_inference=True,
+        )
+        device = getattr(model, "device", "unknown")
+        print(f"Unsloth model ready: {args.model_id}")
+        print(f"Tokenizer vocab size: {len(tokenizer)}")
+        print(f"Model device: {device}")
+        print(f"Runtime device name: {runtime['device_name']}")
+        return {
+            "args": args,
+            "runtime": runtime,
+            "tokenizer": tokenizer,
+            "model": model,
+        }
+
+    prompt_data = prepare_prompt_examples(args)
+    scenario_names = prompt_data["scenario_names"]
+    examples = prompt_data["examples"]
+    env_reward = build_openenv_reward(args)
+
+    if args.dry_run:
+        base.run_dry_run_preview(examples, env_reward, args.output_dir)
+        return {
+            "args": args,
+            "runtime": runtime,
+            "scenario_names": scenario_names,
+            "examples": examples,
+            "reward_fn": env_reward,
+        }
+
+    from datasets import Dataset
+
+    FastLanguageModel = patch_unsloth_grpo()
+    train_dataset = Dataset.from_list(examples)
+
+    # 1. Load model with Unsloth quantized loading.
+    tokenizer, model = load_model_artifacts(
+        args.model_id,
+        trust_remote_code=args.trust_remote_code,
+        max_seq_length=args.max_seq_length,
+        load_in_4bit=not args.disable_4bit,
+        fast_inference=not args.disable_fast_inference,
+    )
+    # 2. Apply LoRA adapters.
+    model = apply_lora_adapters(FastLanguageModel, model, args)
+
+    print(
+        f"Unsloth training runtime: device={runtime['device']} "
+        f"name={runtime['device_name']} "
+        f"dtype={runtime['dtype']} "
+        f"load_in_4bit={not args.disable_4bit}"
+    )
+    print(
+        "OpenEnv reward: "
+        f"backend={args.reward_backend} scenarios={len(scenario_names)} "
+        f"examples={len(examples)}"
+    )
+
+    # 3. Train with GRPO against the OpenEnv reward function.
+    trainer = build_unsloth_grpo_trainer(
+        model=model,
+        tokenizer=tokenizer,
+        reward_func=env_reward,
+        train_dataset=train_dataset,
+        args=args,
+        runtime=runtime,
+    )
+    trainer.train()
+    trainer.save_model(args.output_dir)
+    tokenizer.save_pretrained(args.output_dir)
+
+    if args.save_merged_16bit:
+        merged_dir = Path(args.output_dir) / "merged_16bit"
+        try:
+            model.save_pretrained_merged(
+                str(merged_dir),
+                tokenizer,
+                save_method="merged_16bit",
+            )
+            print(f"Saved merged 16-bit model to {merged_dir}")
+        except AttributeError:
+            print("Merged 16-bit export is not available in this Unsloth build; skipping.")
+
+    if args.push_to_hub:
+        from huggingface_hub import HfApi
+
+        api = HfApi()
+        api.create_repo(repo_id=args.push_to_hub, repo_type="model", exist_ok=True)
+        print(f"Pushing model to HuggingFace Hub: {args.push_to_hub}")
+        api.upload_folder(
+            folder_path=args.output_dir,
+            repo_id=args.push_to_hub,
+            repo_type="model",
+            create_pr=False,
+        )
+        print(f"Model pushed to https://huggingface.co/{args.push_to_hub}")
+
+    plot_paths = base.save_training_plots(
+        trainer.state.log_history,
+        args.output_dir,
+        metric_key=args.plot_metric_key,
+    )
+    print("Saved training plots:")
+    for plot_name, plot_path in plot_paths.items():
+        print(f"  - {plot_name}: {plot_path}")
+
+    return {
+        "args": args,
+        "runtime": runtime,
+        "scenario_names": scenario_names,
+        "examples": examples,
+        "reward_fn": env_reward,
+        "train_dataset": train_dataset,
+        "tokenizer": tokenizer,
+        "model": model,
+        "trainer": trainer,
+        "plot_paths": plot_paths,
+    }
+
+
+def main() -> None:
+    run_training(parse_args())
+
+
+if __name__ == "__main__":
+    main()
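`_call_unsloth_from_pretrained` retries the load while dropping optional keyword arguments that an older Unsloth build rejects by name. The pattern generalizes to any callable; a minimal sketch with a hypothetical legacy loader (names are illustrative):

```python
from typing import Any, Callable


def call_dropping_unsupported(fn: Callable[..., Any], optional: tuple, **kwargs: Any) -> Any:
    """Call fn, removing an optional kwarg each time fn rejects it by name."""
    for key in optional:
        try:
            return fn(**kwargs)
        except TypeError as exc:
            # Only swallow the error if it names the kwarg we are allowed to drop.
            if key in kwargs and key in str(exc):
                kwargs = dict(kwargs)
                kwargs.pop(key)
                continue
            raise
    return fn(**kwargs)


def legacy_loader(model_name: str) -> str:
    # Stand-in for an older API that accepts none of the optional flags.
    return f"loaded {model_name}"


print(call_dropping_unsupported(
    legacy_loader,
    ("fast_inference", "trust_remote_code"),
    model_name="demo",
    fast_inference=True,
    trust_remote_code=True,
))  # loaded demo
```

Matching on the exception message is deliberately conservative: an unrelated `TypeError` (one that does not mention a droppable kwarg) is re-raised instead of being silently retried.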
uv.lock
CHANGED
The diff for this file is too large to render.
See raw diff