Ev3Dev committed on
Commit
db03c40
·
verified ·
1 Parent(s): 5c3cfae

Upload folder using huggingface_hub

H100_JUPYTER_SETUP.md ADDED
@@ -0,0 +1,199 @@
# H100 Jupyter Notebook Setup

This guide walks you through setting up the OpenEnv Bio Experiment environment on an **NVIDIA H100** Jupyter notebook instance (e.g., JupyterLab on Lambda Labs, RunPod, or similar).

## Prerequisites

- **Python** ≥ 3.10 (3.10, 3.11, or 3.12 recommended)
- **uv** – fast Python package manager ([install instructions](#installing-uv))
- **NVIDIA driver** ≥ 535.104.05 (usually pre-installed on H100 instances)
- **CUDA** – H100 uses CUDA 12.x; PyTorch wheels bundle the runtime, so a separate CUDA Toolkit is not required

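To confirm the driver prerequisite from a notebook cell, you can parse `nvidia-smi` output. This is a small stdlib-only sketch (not part of the project); it returns `None` when no NVIDIA driver is present:

```python
import re
import subprocess

def nvidia_driver_version():
    """Return the NVIDIA driver version string, or None if unavailable."""
    try:
        out = subprocess.run(
            ["nvidia-smi"], capture_output=True, text=True, check=True
        ).stdout
    except (FileNotFoundError, subprocess.CalledProcessError):
        return None  # no NVIDIA driver, or nvidia-smi not on PATH
    match = re.search(r"Driver Version:\s*([\d.]+)", out)
    return match.group(1) if match else None

print(nvidia_driver_version())
```

On an H100 instance this should print something like `535.104.05` or newer.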
## Installing uv

If `uv` is not already installed:

```bash
# Unix/Linux (including Jupyter notebook terminals)
curl -LsSf https://astral.sh/uv/install.sh | sh

# Or with pip
pip install uv
```

Verify:

```bash
uv --version
```

## Quick Setup (Recommended)

### 1. Clone and enter the project

```bash
git clone <repository-url> OpenENV-Hackathon
cd OpenENV-Hackathon
```

### 2. Use uv's auto PyTorch backend

uv can detect your GPU and pick the right PyTorch build. For H100 (CUDA 12.x):

```bash
# Install everything: core + training (TRL, transformers, torch, unsloth) + Jupyter
UV_TORCH_BACKEND=cu128 uv sync --extra train

# Add Jupyter kernel support
uv add ipykernel jupyter --extra train
```

If `UV_TORCH_BACKEND=cu128` fails (e.g., cu128 wheels not available yet), try:

```bash
UV_TORCH_BACKEND=cu126 uv sync --extra train
```

### 3. Register the environment as a Jupyter kernel

```bash
uv run python -m ipykernel install --user --name openenv-bio --display-name "OpenEnv Bio (H100)"
```

### 4. Verify CUDA

In a new Jupyter notebook, select the **"OpenEnv Bio (H100)"** kernel and run:

```python
import torch
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'N/A'}")
```

Expected output (or similar):

```
PyTorch: 2.x.x+cu128
CUDA available: True
GPU: NVIDIA H100 ...
```

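As an extra sanity check, the H100 reports CUDA compute capability 9.0 (`sm_90`). A defensive sketch (not from the project) that degrades gracefully when PyTorch or a GPU is missing:

```python
def gpu_compute_capability():
    """Return (major, minor) compute capability of GPU 0, or None."""
    try:
        import torch
    except ImportError:
        return None  # PyTorch not installed in this kernel
    if not torch.cuda.is_available():
        return None
    return torch.cuda.get_device_capability(0)

cc = gpu_compute_capability()
print(cc)  # expect (9, 0) on an H100
```

If this prints a capability below `(9, 0)`, the kernel is not actually running on an H100.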
### 5. Sanity check the environment

```bash
uv run pytest tests/test_environment.py tests/test_literature_benchmark.py -q
```

## Manual PyTorch CUDA Configuration

If you need explicit control over the PyTorch index (e.g., for reproducibility), add the following to `pyproject.toml`:

### Add to `pyproject.toml`

```toml
# After [tool.uv], add:

[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true

[tool.uv.sources]
torch = [{ index = "pytorch-cu128" }]
torchvision = [{ index = "pytorch-cu128" }]
```

Then run:

```bash
uv sync --extra train
uv add ipykernel jupyter --extra train
```

For CUDA 12.6 instead of 12.8, use `cu126` in the index URL and source names.

## Dependency Groups

| uv sync flag | Contents |
|-----------------------------|--------------------------------------------------------------------------------|
| *(default)* | Core: `openenv-core`, `numpy`, `scipy`, `pydantic` |
| `--extra dev` | Testing: `pytest`, `pytest-cov` |
| `--extra train` | Training: `torch`, `transformers`, `trl`, `accelerate`, `peft`, `unsloth`, etc. |
| `--extra bio` | Bioinformatics: `scanpy`, `biopython`, `gseapy` |
| `--extra train --extra dev` | Combined for development + training |

## Preferred H100 Workflow

On H100, use the quantized Unsloth entrypoints:

```bash
uv run python training_unsloth.py --dry-run
uv run python training_unsloth.py --model-id Qwen/Qwen3.5-4B --output-dir training/grpo-unsloth-output
uv run python run_agent_unsloth.py
```

The checked-in `inference.ipynb` notebook now uses `training_unsloth.py` helpers with 4-bit loading and fast inference enabled by default.

## Running Training in a Jupyter Notebook

Example cell:

```python
# In a notebook with the OpenEnv Bio (H100) kernel
!uv run python training_unsloth.py --model-id Qwen/Qwen3.5-4B --dry-run
```

Or run interactively from Python:

```python
import subprocess
subprocess.run([
    "uv", "run", "python", "training_unsloth.py",
    "--model-id", "Qwen/Qwen3.5-4B",
    "--output-dir", "training/grpo-unsloth-output",
], check=True)
```

157
+
158
+ ## Requirements Summary
159
+
160
+ | Component | Version / Notes |
161
+ |----------------|------------------------------------------------------|
162
+ | Python | ≥ 3.10 |
163
+ | uv | ≥ 0.5.3 (for PyTorch index support) |
164
+ | torch | ≥ 2.10.0 (cu128 or cu126 for H100) |
165
+ | transformers | ≥ 5.3.0 |
166
+ | trl | ≥ 0.29.0 |
167
+ | accelerate | ≥ 1.13.0 |
168
+ | Jupyter | Optional, for notebook workflows |
169
+
## Troubleshooting

### `torch.cuda.is_available()` is False

- Confirm the Jupyter kernel is the one where you ran `uv sync` (the one with `ipykernel`).
- Ensure no CPU-only PyTorch is overriding the CUDA build (e.g., from a different conda/pip env).
- Run `uv run python -c "import torch; print(torch.__file__)"` to verify PyTorch comes from your project venv.

### Flash Attention / causal-conv fallback warnings

These are common and usually harmless; execution continues with a slower path. For best H100 performance, ensure `transformers` and `torch` are recent versions that support Flash Attention 2.

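To see which scaled-dot-product-attention backends PyTorch has enabled in your kernel, you can query the `torch.backends.cuda` flags (available in PyTorch ≥ 2.0). A guarded sketch, not from the project, that also runs when PyTorch is absent:

```python
def sdp_backends():
    """Report which scaled-dot-product-attention backends are enabled."""
    try:
        import torch
    except ImportError:
        return None  # PyTorch not installed in this kernel
    return {
        "flash": torch.backends.cuda.flash_sdp_enabled(),
        "mem_efficient": torch.backends.cuda.mem_efficient_sdp_enabled(),
        "math": torch.backends.cuda.math_sdp_enabled(),
    }

print(sdp_backends())
```

A `False` flash entry explains the fallback warnings; the math backend is the slow but always-available path.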
### HuggingFace symlink warnings

Set:

```bash
export HF_HUB_DISABLE_SYMLINKS_WARNING=1
```

### Out-of-memory during training

- Reduce `--num-generations` or `--rollout-steps`.
- Use a smaller model (e.g., `Qwen/Qwen3.5-0.8B`) for experiments.
- Keep `--disable-4bit` off unless you explicitly need full-precision weights.

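A rough rule of thumb behind the 4-bit recommendation: weight memory scales linearly with parameter count and bit width. An illustrative stdlib sketch (weights only; real usage adds activations, optimizer state, and KV cache on top):

```python
def model_weight_memory_gb(n_params, bits=4):
    """Approximate GPU memory (GiB) for model weights alone."""
    return n_params * bits / 8 / 1024**3

# A 4B-parameter model: ~7.5 GiB in 16-bit vs ~1.9 GiB in 4-bit
for bits in (16, 4):
    print(bits, round(model_weight_memory_gb(4e9, bits), 1))
```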
## See Also

- Main [README.md](README.md) for project overview, APIs, and usage
- [uv PyTorch guide](https://docs.astral.sh/uv/guides/integration/pytorch/) for advanced PyTorch configuration
README.md CHANGED
@@ -249,6 +249,8 @@ An episode ends when one of the following happens:
249
 
250
  Dependencies are managed with `uv`. The package requires Python ≥ 3.10.
251
 
252
  ```bash
253
  # Core environment only
254
  uv sync
@@ -342,6 +344,13 @@ The environment class supports concurrent sessions, but the bundled server is cu
342
  uv run python run_agent.py
343
  ```
344
 
345
  Configuration is via environment variables:
346
 
347
  | Variable | Default | Description |
@@ -367,6 +376,14 @@ uv run python training_script.py --dry-run
367
  uv run python training_script.py --model-id Qwen/Qwen3.5-0.8B
368
  ```
369
 
370
  Key arguments:
371
 
372
  | Argument | Default | Description |
@@ -381,13 +398,15 @@ Key arguments:
381
  | `--scenario-name` | all | Repeatable; restricts which scenarios are used |
382
  | `--domain-randomise` | off | Enable domain randomisation |
383
  | `--num-generations` | `4` | GRPO generations per prompt |
384
- | `--max-completion-length` | `220` | Max tokens for model completions |
385
  | `--max-prompt-length` | `768` | Max tokens for prompts |
386
  | `--learning-rate` | `5e-6` | AdamW learning rate |
387
  | `--dry-run` | off | Build data and test reward without training |
388
 
389
  By default the reward function reconstructs prompt states locally so the prompt and reward stay aligned. Switch to a live server-backed reward loop with `--reward-backend remote --base-url http://localhost:8000`.
390
 
391
  After training, the script saves plots to the output directory:
392
 
393
  - `training_loss.png`
@@ -413,7 +432,9 @@ This runs N episodes with a `random` or `heuristic` policy, saves JSON trajector
413
  - `training/literature_benchmark.py` runs paper-aligned action sequences and compares outcomes against curated expected findings
414
  - `training/rollout_collection.py` collects direct environment rollouts into trajectory files
415
  - `training_script.py` trains a GRPO policy with OpenEnv reward calls
 
416
  - `run_agent.py` runs a local language model planner against the environment
 
417
  - `training/trajectory.py` stores trajectories for offline RL, imitation learning, replay, and evaluation
418
  - `training/evaluation.py` computes online, benchmark, expert-review, and fidelity-oriented metrics
419
 
@@ -488,6 +509,7 @@ That makes it suitable for:
488
  ├── openenv.yaml # OpenEnv platform deployment config
489
  ├── pyproject.toml # Package metadata and dependency groups
490
  ├── run_agent.py # Single-episode interactive agent runner
 
491
  ├── server/
492
  │ ├── app.py # FastAPI/OpenEnv server entry point
493
  │ ├── Dockerfile # Multi-stage Docker build
@@ -512,6 +534,7 @@ That makes it suitable for:
512
  │ ├── rollout_collection.py # Direct rollout collection helper
513
  │ └── trajectory.py # Trajectory serialization and dataset utilities
514
  ├── training_script.py # TRL GRPO training entry point
 
515
  └── tests/
516
  ├── test_environment.py
517
  ├── test_literature_benchmark.py
 
249
 
250
  Dependencies are managed with `uv`. The package requires Python ≥ 3.10.
251
 
252
+ > **H100 Jupyter notebook setup:** See [H100_JUPYTER_SETUP.md](H100_JUPYTER_SETUP.md) for environment setup on NVIDIA H100 instances with Jupyter.
253
+
254
  ```bash
255
  # Core environment only
256
  uv sync
 
344
  uv run python run_agent.py
345
  ```
346
 
347
+ For H100 and other large-GPU workflows, prefer the quantized Unsloth path:
348
+
349
+ ```bash
350
+ uv sync --extra train
351
+ uv run python run_agent_unsloth.py
352
+ ```
353
+
354
  Configuration is via environment variables:
355
 
356
  | Variable | Default | Description |
 
376
  uv run python training_script.py --model-id Qwen/Qwen3.5-0.8B
377
  ```
378
 
379
+ For H100, the preferred entrypoint is `training_unsloth.py`, which uses Unsloth 4-bit loading plus LoRA for faster quantized GRPO training:
380
+
381
+ ```bash
382
+ uv sync --extra train
383
+ uv run python training_unsloth.py --dry-run
384
+ uv run python training_unsloth.py --model-id Qwen/Qwen3.5-4B
385
+ ```
386
+
387
  Key arguments:
388
 
389
  | Argument | Default | Description |
 
398
  | `--scenario-name` | all | Repeatable; restricts which scenarios are used |
399
  | `--domain-randomise` | off | Enable domain randomisation |
400
  | `--num-generations` | `4` | GRPO generations per prompt |
401
+ | `--max-completion-length` | `160` | Max tokens for model completions |
402
  | `--max-prompt-length` | `768` | Max tokens for prompts |
403
  | `--learning-rate` | `5e-6` | AdamW learning rate |
404
  | `--dry-run` | off | Build data and test reward without training |
405
 
406
  By default the reward function reconstructs prompt states locally so the prompt and reward stay aligned. Switch to a live server-backed reward loop with `--reward-backend remote --base-url http://localhost:8000`.
407
 
408
+ `training_unsloth.py` adds H100-oriented options such as `--max-seq-length`, `--disable-4bit`, `--disable-fast-inference`, and LoRA settings (`--lora-r`, `--lora-alpha`, `--lora-dropout`).
409
+
410
  After training, the script saves plots to the output directory:
411
 
412
  - `training_loss.png`
 
432
  - `training/literature_benchmark.py` runs paper-aligned action sequences and compares outcomes against curated expected findings
433
  - `training/rollout_collection.py` collects direct environment rollouts into trajectory files
434
  - `training_script.py` trains a GRPO policy with OpenEnv reward calls
435
+ - `training_unsloth.py` trains a quantized GRPO policy with Unsloth on H100-class GPUs
436
  - `run_agent.py` runs a local language model planner against the environment
437
+ - `run_agent_unsloth.py` runs the planner with Unsloth 4-bit loading for faster inference
438
  - `training/trajectory.py` stores trajectories for offline RL, imitation learning, replay, and evaluation
439
  - `training/evaluation.py` computes online, benchmark, expert-review, and fidelity-oriented metrics
440
 
 
509
  ├── openenv.yaml # OpenEnv platform deployment config
510
  ├── pyproject.toml # Package metadata and dependency groups
511
  ├── run_agent.py # Single-episode interactive agent runner
512
+ ├── run_agent_unsloth.py # Quantized Unsloth interactive agent runner
513
  ├── server/
514
  │ ├── app.py # FastAPI/OpenEnv server entry point
515
  │ ├── Dockerfile # Multi-stage Docker build
 
534
  │ ├── rollout_collection.py # Direct rollout collection helper
535
  │ └── trajectory.py # Trajectory serialization and dataset utilities
536
  ├── training_script.py # TRL GRPO training entry point
537
+ ├── training_unsloth.py # Unsloth quantized GRPO training entry point
538
  └── tests/
539
  ├── test_environment.py
540
  ├── test_literature_benchmark.py
_dashboard_state.json CHANGED
@@ -1,40 +1,40 @@
1
  {
2
- "timestamp": 1772961529.009918,
3
- "step": 4,
4
  "episode_done": false,
5
- "cumulative_reward": 1.8781281693998408,
6
- "gen_time_s": 15.02,
7
- "model_response_raw": "{\n\"action_type\": \"prepare_library\",\n\"method\": \"\",\n\"parameters\": {\n\"cutoff_value\": \"\"\n},\n\"Justification\": \"The error 'Cannot sequence without library preparation' indicates that while sample collection was completed successfully via dry-polymerization technology ('collectible'), mature cDNA libraries have yet to be constructed from those frozen tissues. Library prep requires reverse transcription followed by amplification/cleanup using polymerases/ligation enzymes. Since all other required conditions met but materials remain unprocessed, preparing high-throughput Illumina libraries first allows",
8
  "model_thinking": "",
9
  "thinking_enabled": false,
10
  "task": {
11
- "problem_statement": "Determine the effect of anti-TNF biologic therapy on cell states in synovium tissue affected by rheumatoid arthritis.",
12
  "organism": "human",
13
- "tissue": "synovium",
14
  "modality": "scRNA-seq",
15
  "conditions": [
16
- "untreated_rheumatoid_arthritis",
17
- "anti_TNF_treated"
18
  ],
19
- "budget_limit": 80340.57829173621,
20
- "time_limit_days": 140.67464277202689
21
  },
22
  "resources": {
23
- "budget_used": 13500.0,
24
- "budget_remaining": 66840.58,
25
- "time_used_days": 11.0,
26
- "time_remaining_days": 129.7,
27
  "samples_consumed": 0,
28
- "compute_hours_used": 0.0
29
  },
30
  "pipeline_history": [
31
  {
32
  "step_index": 1,
33
  "action_type": "collect_sample",
34
  "method": null,
35
- "output_summary": "Collected 6 samples (quality=0.54)",
36
  "success": true,
37
- "quality_score": 0.541,
38
  "resource_cost": 5000.0,
39
  "time_cost_days": 7.0
40
  },
@@ -62,116 +62,259 @@
62
  "step_index": 4,
63
  "action_type": "prepare_library",
64
  "method": null,
65
- "output_summary": "Library prepared (complexity=0.72)",
66
  "success": true,
67
- "quality_score": 0.718,
68
  "resource_cost": 8000.0,
69
  "time_cost_days": 3.0
 
70
  }
71
  ],
72
  "current_action": {
73
- "action_type": "prepare_library",
74
  "method": null,
75
  "parameters": {
76
- "cutoff_value": ""
77
  },
78
- "justification": "The error 'Cannot sequence without library preparation' indicates that while sample collection was completed successfully via dry-polymerization technology ('collectible'), mature cDNA libraries have yet to be constructed from those frozen tissues. Library prep requires reverse transcription followed by amplification/cleanup using polymerases/ligation enzymes. Since all other required conditions met but materials remain unprocessed, preparing high-throughput Illumina libraries first allows conversion of input DNA/RNA directly into sequencable formats.",
79
- "confidence": 0.95
80
  },
81
  "latest_output": {
82
- "summary": "Library prepared (complexity=0.72)",
83
- "success": true,
84
- "quality_score": 0.718,
85
  "uncertainty": 0.0,
86
  "warnings": [],
87
- "data_preview": "{\"library_complexity\": 0.7177962559040082, \"method\": \"10x_chromium\"}"
88
  },
89
  "discovered_markers": [],
90
  "candidate_mechanisms": [],
91
- "rule_violations": [],
 
  "uncertainty_summary": {
93
- "avg_uncertainty": 0.0,
94
- "avg_quality": 0.815
 
95
  },
96
- "reward_breakdown": {},
97
  "latent": {
98
  "cell_populations": [
99
  {
100
- "name": "fibroblast",
101
- "proportion": 0.445,
102
- "marker_genes": [
103
- "COL1A1",
104
- "FAP",
105
- "THY1"
106
- ],
107
- "state": "activated"
108
- },
109
- {
110
- "name": "CD4_T_cell",
111
- "proportion": 0.179,
112
  "marker_genes": [
113
- "CD3D",
114
- "CD4",
115
- "IL7R"
116
  ],
117
- "state": "quiescent"
118
  },
119
  {
120
- "name": "CD8_T_cell",
121
- "proportion": 0.139,
122
  "marker_genes": [
123
- "CD3D",
124
- "CD8A",
125
- "GZMB"
126
  ],
127
- "state": "activated"
128
  },
129
  {
130
- "name": "B_cell",
131
- "proportion": 0.142,
132
  "marker_genes": [
133
- "CD19",
134
- "MS4A1",
135
- "CD79A"
136
  ],
137
- "state": "quiescent"
138
  },
139
  {
140
  "name": "endothelial",
141
- "proportion": 0.096,
142
  "marker_genes": [
143
- "PECAM1",
144
- "VWF"
 
145
  ],
146
  "state": "quiescent"
147
  }
148
  ],
149
  "true_markers": [
150
- "TNF",
151
- "IL6",
152
- "MMP3",
153
- "CXCL13"
154
  ],
155
  "causal_mechanisms": [
156
- "TNF/NF-kB-driven synovial inflammation",
157
- "Th17-mediated cartilage destruction via MMPs"
 
158
  ],
159
  "true_pathways": {
160
- "JAK_STAT_signalling": 0.785,
161
- "TNF_signalling": 0.723,
162
- "Th17_differentiation": 0.633,
163
- "NF_kB_signalling": 0.826,
164
- "matrix_metalloproteinase_activity": 0.847
165
- },
166
- "true_de_genes_count": 9,
167
- "true_regulatory_network_size": 16,
168
  "confounders": {},
169
- "n_true_cells": 15873,
170
  "technical": {
171
- "ambient_rna_fraction": 0.037873267501661645,
172
- "doublet_rate": 0.03797665930677535,
173
- "dropout_rate": 0.14738025069803395,
174
- "sample_quality": 0.9068064354870293,
175
  "library_complexity": 0.8,
176
  "capture_efficiency": 0.6
177
  },
@@ -181,24 +324,24 @@
181
  "cells_cultured": false,
182
  "library_prepared": true,
183
  "perturbation_applied": false,
184
- "cells_sequenced": false,
185
- "qc_performed": false,
186
- "data_filtered": false,
187
- "data_normalized": false,
188
  "batches_integrated": false,
189
- "cells_clustered": false,
190
  "de_performed": false,
191
  "trajectories_inferred": false,
192
  "pathways_analyzed": false,
193
  "networks_inferred": false,
194
  "markers_discovered": false,
195
  "markers_validated": false,
196
- "followup_designed": false,
197
  "subagent_review_requested": false,
198
  "conclusion_reached": false,
199
- "n_cells_sequenced": null,
200
- "n_cells_after_filter": null,
201
- "n_clusters_found": null,
202
  "n_de_genes_found": null,
203
  "n_markers_found": null
204
  },
 
1
  {
2
+ "timestamp": 1772965877.7057643,
3
+ "step": 20,
4
  "episode_done": false,
5
+ "cumulative_reward": -1.0267290925637838,
6
+ "gen_time_s": 28.83,
7
+ "model_response_raw": "{\n\"action_type\": \"marker_selection\",\n\"method\": \"\",\n\"parameters\": {\"\"diff_genes\"\": [],\"gene_lists_file\":\"DE_results.h5ad\"},\n\"Justification\": \"The 'violation' error indicates we cannot proceed directly from previous failed attempts due to incomplete filtering/normalization logic implied by the blockage condition ('without prior analysis'). However, looking strictly at the provided completed vs failing log chain where `design` was followed by `[OK] design`, but then immediate failures occurred (`synthesis_blocked`) which were resolved upon finding/deepening insights? Wait, re-evaluating strict fl",
8
  "model_thinking": "",
9
  "thinking_enabled": false,
10
  "task": {
11
+ "problem_statement": "Validate candidate biomarker TREM2 for Alzheimer's disease in brain tissue using single-cell RNA sequencing.",
12
  "organism": "human",
13
+ "tissue": "brain",
14
  "modality": "scRNA-seq",
15
  "conditions": [
16
+ "healthy",
17
+ "Alzheimer"
18
  ],
19
+ "budget_limit": 115264.15272286453,
20
+ "time_limit_days": 111.8768783273548
21
  },
22
  "resources": {
23
+ "budget_used": 28900.0,
24
+ "budget_remaining": 86364.15,
25
+ "time_used_days": 18.0,
26
+ "time_remaining_days": 93.9,
27
  "samples_consumed": 0,
28
+ "compute_hours_used": 12.0
29
  },
30
  "pipeline_history": [
31
  {
32
  "step_index": 1,
33
  "action_type": "collect_sample",
34
  "method": null,
35
+ "output_summary": "Collected 6 samples (quality=0.57)",
36
  "success": true,
37
+ "quality_score": 0.575,
38
  "resource_cost": 5000.0,
39
  "time_cost_days": 7.0
40
  },
 
62
  "step_index": 4,
63
  "action_type": "prepare_library",
64
  "method": null,
65
+ "output_summary": "Library prepared (complexity=0.77)",
66
  "success": true,
67
+ "quality_score": 0.77,
68
  "resource_cost": 8000.0,
69
  "time_cost_days": 3.0
70
+ },
71
+ {
72
+ "step_index": 5,
73
+ "action_type": "sequence_cells",
74
+ "method": null,
75
+ "output_summary": "Sequenced 12178 cells, 9490 genes detected, median UMI=2966",
76
+ "success": true,
77
+ "quality_score": 0.457,
78
+ "resource_cost": 15000.0,
79
+ "time_cost_days": 5.0
80
+ },
81
+ {
82
+ "step_index": 6,
83
+ "action_type": "filter_data",
84
+ "method": null,
85
+ "output_summary": "Action blocked: Cannot filter data before QC",
86
+ "success": false,
87
+ "quality_score": 1.0,
88
+ "resource_cost": 50.0,
89
+ "time_cost_days": 0.2
90
+ },
91
+ {
92
+ "step_index": 7,
93
+ "action_type": "run_qc",
94
+ "method": null,
95
+ "output_summary": "QC metrics computed",
96
+ "success": true,
97
+ "quality_score": 0.842,
98
+ "resource_cost": 100.0,
99
+ "time_cost_days": 0.5
100
+ },
101
+ {
102
+ "step_index": 8,
103
+ "action_type": "filter_data",
104
+ "method": null,
105
+ "output_summary": "Filtered 12178 \u2192 9967 cells (82% retained)",
106
+ "success": true,
107
+ "quality_score": 0.818,
108
+ "resource_cost": 50.0,
109
+ "time_cost_days": 0.2
110
+ },
111
+ {
112
+ "step_index": 9,
113
+ "action_type": "normalize_data",
114
+ "method": null,
115
+ "output_summary": "Normalized with log_normalize",
116
+ "success": true,
117
+ "quality_score": 1.0,
118
+ "resource_cost": 50.0,
119
+ "time_cost_days": 0.2
120
+ },
121
+ {
122
+ "step_index": 10,
123
+ "action_type": "marker_selection",
124
+ "method": null,
125
+ "output_summary": "Action blocked: Cannot select markers without DE results",
126
+ "success": false,
127
+ "quality_score": 1.0,
128
+ "resource_cost": 100.0,
129
+ "time_cost_days": 0.5
130
+ },
131
+ {
132
+ "step_index": 11,
133
+ "action_type": "design_followup_experiment",
134
+ "method": null,
135
+ "output_summary": "Action blocked: Follow-up design without prior analysis is blocked; complete wet-lab and computational steps first",
136
+ "success": false,
137
+ "quality_score": 1.0,
138
+ "resource_cost": 100.0,
139
+ "time_cost_days": 0.5
140
+ },
141
+ {
142
+ "step_index": 12,
143
+ "action_type": "cluster_cells",
144
+ "method": null,
145
+ "output_summary": "Found 3 clusters",
146
+ "success": true,
147
+ "quality_score": 0.765,
148
+ "resource_cost": 100.0,
149
+ "time_cost_days": 0.5
150
+ },
151
+ {
152
+ "step_index": 13,
153
+ "action_type": "design_followup_experiment",
154
+ "method": null,
155
+ "output_summary": "Follow-up experiment design proposed (evidence_signals=1)",
156
+ "success": true,
157
+ "quality_score": 0.28,
158
+ "resource_cost": 100.0,
159
+ "time_cost_days": 0.5
160
+ },
161
+ {
162
+ "step_index": 14,
163
+ "action_type": "synthesize_conclusion",
164
+ "method": null,
165
+ "output_summary": "Action blocked: Cannot synthesise conclusion before discovering or validating markers; Cannot synthesise conclusion befo",
166
+ "success": false,
167
+ "quality_score": 1.0,
168
+ "resource_cost": 0.0,
169
+ "time_cost_days": 0.5
170
+ },
171
+ {
172
+ "step_index": 15,
173
+ "action_type": "marker_selection",
174
+ "method": null,
175
+ "output_summary": "Action blocked: Cannot select markers without DE results",
176
+ "success": false,
177
+ "quality_score": 1.0,
178
+ "resource_cost": 100.0,
179
+ "time_cost_days": 0.5
180
+ },
181
+ {
182
+ "step_index": 16,
183
+ "action_type": "synthesize_conclusion",
184
+ "method": null,
185
+ "output_summary": "Action blocked: Cannot synthesise conclusion before discovering or validating markers; Cannot synthesise conclusion befo",
186
+ "success": false,
187
+ "quality_score": 1.0,
188
+ "resource_cost": 0.0,
189
+ "time_cost_days": 0.5
190
  }
191
  ],
192
  "current_action": {
193
+ "action_type": "synthesize_conclusion",
194
  "method": null,
195
  "parameters": {
196
+ "claims": [
197
+ {
198
+ "top_markers": [],
199
+ "causal_mechanisms": [],
200
+ "predicted_pathways": {},
201
+ "confidence": 0.5,
202
+ "claim_type": "correlational",
203
+ "claim": "forced terminal conclusion"
204
+ }
205
+ ]
206
  },
207
+ "justification": "forced terminal conclusion",
208
+ "confidence": 0.5
209
  },
210
  "latest_output": {
211
+ "summary": "Action blocked: Cannot synthesise conclusion before discovering or validating markers; Cannot synthesise conclusion before inferring pathways or mechanisms",
212
+ "success": false,
213
+ "quality_score": 1.0,
214
  "uncertainty": 0.0,
215
  "warnings": [],
216
+ "data_preview": null
217
  },
218
  "discovered_markers": [],
219
  "candidate_mechanisms": [],
220
+ "rule_violations": [
221
+ "Cannot synthesise conclusion before discovering or validating markers",
222
+ "Cannot synthesise conclusion before inferring pathways or mechanisms"
223
+ ],
224
  "uncertainty_summary": {
225
+ "avg_uncertainty": 0.194,
226
+ "avg_quality": 0.809
227
+ },
228
+ "reward_breakdown": {
229
+ "validity": -1.0,
230
+ "ordering": 0.0,
231
+ "info_gain": 0.0,
232
+ "efficiency": 0.0,
233
+ "novelty": 0.0,
234
+ "penalty": -1.0,
235
+ "shaping": 0.0,
236
+ "terminal": 0.0,
237
+ "total": -2.0,
238
+ "hard_violations": 2.0,
239
+ "term_validity": 0.0,
240
+ "term_ordering": 0.0,
241
+ "term_info_gain": 0.0,
242
+ "term_efficiency": 0.0,
243
+ "term_novelty": 0.0,
244
+ "term_penalty": 0.0,
245
+ "term_shaping": 0.0,
246
+ "term_terminal": 0.0,
247
+ "term_total": 0.0
248
  },
 
249
  "latent": {
250
  "cell_populations": [
251
  {
252
+ "name": "excitatory_neuron",
253
+ "proportion": 0.425,
254
  "marker_genes": [
255
+ "SLC17A7",
256
+ "CAMK2A",
257
+ "NRGN"
258
  ],
259
+ "state": "stressed"
260
  },
261
  {
262
+ "name": "inhibitory_neuron",
263
+ "proportion": 0.346,
264
  "marker_genes": [
265
+ "GAD1",
266
+ "GAD2",
267
+ "SLC32A1"
268
  ],
269
+ "state": "normal"
270
  },
271
  {
272
+ "name": "OPC",
273
+ "proportion": 0.093,
274
  "marker_genes": [
275
+ "PDGFRA",
276
+ "CSPG4",
277
+ "OLIG2"
278
  ],
279
+ "state": "progenitor"
280
  },
281
  {
282
  "name": "endothelial",
283
+ "proportion": 0.137,
284
  "marker_genes": [
285
+ "CLDN5",
286
+ "FLT1",
287
+ "PECAM1"
288
  ],
289
  "state": "quiescent"
290
  }
291
  ],
292
  "true_markers": [
293
+ "TREM2",
294
+ "APOE",
295
+ "GFAP"
 
296
  ],
297
  "causal_mechanisms": [
298
+ "TREM2-mediated microglial activation in amyloid clearance",
299
+ "complement-driven synaptic pruning",
300
+ "reactive astrogliosis amplifying neuroinflammation"
301
  ],
302
  "true_pathways": {
303
+ "complement_cascade": 0.827,
304
+ "neuroinflammation": 0.666,
305
+ "amyloid_processing": 0.733,
306
+ "synaptic_signalling": 0.438,
307
+ "lipid_metabolism": 0.616
308
+ },
309
+ "true_de_genes_count": 10,
310
+ "true_regulatory_network_size": 0,
311
  "confounders": {},
312
+ "n_true_cells": 20321,
313
  "technical": {
314
+ "ambient_rna_fraction": 0.050723618495539344,
315
+ "doublet_rate": 0.0546771548836933,
316
+ "dropout_rate": 0.05122168063297322,
317
+ "sample_quality": 0.937985596833521,
318
  "library_complexity": 0.8,
319
  "capture_efficiency": 0.6
320
  },
 
324
  "cells_cultured": false,
325
  "library_prepared": true,
326
  "perturbation_applied": false,
327
+ "cells_sequenced": true,
328
+ "qc_performed": true,
329
+ "data_filtered": true,
330
+ "data_normalized": true,
331
  "batches_integrated": false,
332
+ "cells_clustered": true,
333
  "de_performed": false,
334
  "trajectories_inferred": false,
335
  "pathways_analyzed": false,
336
  "networks_inferred": false,
337
  "markers_discovered": false,
338
  "markers_validated": false,
339
+ "followup_designed": true,
340
  "subagent_review_requested": false,
341
  "conclusion_reached": false,
342
+ "n_cells_sequenced": 12178,
343
+ "n_cells_after_filter": 9967,
344
+ "n_clusters_found": "3",
345
  "n_de_genes_found": null,
346
  "n_markers_found": null
347
  },
dashboard.html CHANGED
@@ -304,6 +304,20 @@ function esc(s) { if (s == null) return '—'; const d = document.createElement(
304
  function pct(used, total) { if (!total) return 0; return Math.min(100, Math.max(0, (used / total) * 100)); }
305
  function gaugeColor(p) { return p < 50 ? 'var(--green)' : p < 80 ? 'var(--amber)' : 'var(--red)'; }
306
  function fmt(n) { if (n == null) return '0'; return Number(n).toLocaleString('en-US', { maximumFractionDigits: 0 }); }
 
307
  function gauge(label, value, pctVal, inv) {
308
  let bar = '';
309
  if (pctVal != null) { const c = inv ? gaugeColor(100-pctVal) : gaugeColor(pctVal); bar = `<div class="gauge-bar"><div class="gauge-bar-fill" style="width:${pctVal.toFixed(1)}%;background:${c}"></div></div>`; }
@@ -360,7 +374,10 @@ function showReport() {
360
  const conc = s.conclusions || [];
361
  const trueM = lat.true_markers || [];
362
  const trueMech = lat.causal_mechanisms || [];
363
- const agentM = s.discovered_markers || [];
 
364
  const markerHits = agentM.filter(m => trueM.some(t => t.toUpperCase() === m.toUpperCase()));
365
  const r = s.resources || {};
366
 
@@ -393,7 +410,7 @@ function showReport() {
393
  html += `<div class="report-section"><h3>Ground Truth Comparison</h3>
394
  <div class="comparison-row"><div class="comparison-col"><h4>Agent's Markers</h4><div class="tag-list">${comparedTags(agentM, trueM, 'green')}</div></div>
395
  <div class="comparison-col"><h4>True Markers</h4><div class="tag-list">${tagsHTML(trueM,'green')}</div></div></div>
396
- <div class="comparison-row"><div class="comparison-col"><h4>Agent's Mechanisms</h4><div class="tag-list">${comparedTags(s.candidate_mechanisms, trueMech, 'pink')}</div></div>
397
  <div class="comparison-col"><h4>True Mechanisms</h4><div class="tag-list">${tagsHTML(trueMech,'pink')}</div></div></div>
398
  </div>`;
399
 
@@ -473,12 +490,16 @@ function renderState(s) {
473
  // Ground truth comparison (visible when done or has conclusions)
474
  const lat = s.latent;
475
  if ((s.episode_done || conc.length) && lat) {
 
476
  $('card-gt-comparison').style.display = '';
477
- setHTML('gt-agent-markers', comparedTags(s.discovered_markers, lat.true_markers, 'green'));
478
  setHTML('gt-true-markers', tagsHTML(lat.true_markers, 'green'));
479
- setHTML('gt-agent-mechs', comparedTags(s.candidate_mechanisms, lat.causal_mechanisms, 'pink'));
480
  setHTML('gt-true-mechs', tagsHTML(lat.causal_mechanisms, 'pink'));
481
- const hits = (s.discovered_markers||[]).filter(m => (lat.true_markers||[]).some(t => t.toUpperCase()===m.toUpperCase()));
482
  $('gt-score').innerHTML = `Marker accuracy: <span style="color:var(--accent)">${hits.length}</span> / ${(lat.true_markers||[]).length} true markers recovered`;
483
  } else { $('card-gt-comparison').style.display = 'none'; }
484
 
 
304
  function pct(used, total) { if (!total) return 0; return Math.min(100, Math.max(0, (used / total) * 100)); }
305
  function gaugeColor(p) { return p < 50 ? 'var(--green)' : p < 80 ? 'var(--amber)' : 'var(--red)'; }
306
  function fmt(n) { if (n == null) return '0'; return Number(n).toLocaleString('en-US', { maximumFractionDigits: 0 }); }
307
+ function uniqueItems(arr) {
308
+ const out = [];
309
+ const seen = new Set();
310
+ (arr || []).forEach(item => {
311
+ if (item == null) return;
312
+ const text = String(item).trim();
313
+ if (!text) return;
314
+ const key = text.toUpperCase();
315
+ if (seen.has(key)) return;
316
+ seen.add(key);
317
+ out.push(text);
318
+ });
319
+ return out;
320
+ }
321
  function gauge(label, value, pctVal, inv) {
322
  let bar = '';
323
  if (pctVal != null) { const c = inv ? gaugeColor(100-pctVal) : gaugeColor(pctVal); bar = `<div class="gauge-bar"><div class="gauge-bar-fill" style="width:${pctVal.toFixed(1)}%;background:${c}"></div></div>`; }
 
374
  const conc = s.conclusions || [];
375
  const trueM = lat.true_markers || [];
376
  const trueMech = lat.causal_mechanisms || [];
377
+ const conclusionMarkers = uniqueItems(conc.flatMap(c => c.top_markers || []));
378
+ const conclusionMechanisms = uniqueItems(conc.flatMap(c => c.causal_mechanisms || []));
379
+ const agentM = uniqueItems((s.discovered_markers && s.discovered_markers.length) ? s.discovered_markers : conclusionMarkers);
380
+ const agentMechanisms = uniqueItems((s.candidate_mechanisms && s.candidate_mechanisms.length) ? s.candidate_mechanisms : conclusionMechanisms);
381
  const markerHits = agentM.filter(m => trueM.some(t => t.toUpperCase() === m.toUpperCase()));
382
  const r = s.resources || {};
383
 
 
410
  html += `<div class="report-section"><h3>Ground Truth Comparison</h3>
411
  <div class="comparison-row"><div class="comparison-col"><h4>Agent's Markers</h4><div class="tag-list">${comparedTags(agentM, trueM, 'green')}</div></div>
412
  <div class="comparison-col"><h4>True Markers</h4><div class="tag-list">${tagsHTML(trueM,'green')}</div></div></div>
413
+ <div class="comparison-row"><div class="comparison-col"><h4>Agent's Mechanisms</h4><div class="tag-list">${comparedTags(agentMechanisms, trueMech, 'pink')}</div></div>
414
  <div class="comparison-col"><h4>True Mechanisms</h4><div class="tag-list">${tagsHTML(trueMech,'pink')}</div></div></div>
415
  </div>`;
416
 
 
490
  // Ground truth comparison (visible when done or has conclusions)
491
  const lat = s.latent;
492
  if ((s.episode_done || conc.length) && lat) {
493
+ const conclusionMarkers = uniqueItems(conc.flatMap(c => c.top_markers || []));
494
+ const conclusionMechanisms = uniqueItems(conc.flatMap(c => c.causal_mechanisms || []));
495
+ const comparisonMarkers = uniqueItems((s.discovered_markers && s.discovered_markers.length) ? s.discovered_markers : conclusionMarkers);
496
+ const comparisonMechanisms = uniqueItems((s.candidate_mechanisms && s.candidate_mechanisms.length) ? s.candidate_mechanisms : conclusionMechanisms);
497
  $('card-gt-comparison').style.display = '';
498
+ setHTML('gt-agent-markers', comparedTags(comparisonMarkers, lat.true_markers, 'green'));
499
  setHTML('gt-true-markers', tagsHTML(lat.true_markers, 'green'));
500
+ setHTML('gt-agent-mechs', comparedTags(comparisonMechanisms, lat.causal_mechanisms, 'pink'));
501
  setHTML('gt-true-mechs', tagsHTML(lat.causal_mechanisms, 'pink'));
502
+ const hits = comparisonMarkers.filter(m => (lat.true_markers||[]).some(t => t.toUpperCase()===m.toUpperCase()));
503
  $('gt-score').innerHTML = `Marker accuracy: <span style="color:var(--accent)">${hits.length}</span> / ${(lat.true_markers||[]).length} true markers recovered`;
504
  } else { $('card-gt-comparison').style.display = 'none'; }
505
 
debug-904eee.log ADDED
@@ -0,0 +1,10 @@
+ {"sessionId": "904eee", "message": "repair_failed", "data": {"input_tail": "n studies. Sequencing ensures uniform coverage which aids robustness during subsequent filtering.\",\n\"Confidence\": 0.75\n}", "repaired_tail": "n studies. Sequencing ensures uniform coverage which aids robustness during subsequent filtering.\",\n\"Confidence\": 0.75\n}", "hypothesisId": "H1"}, "timestamp": 1772961980164}
+ {"sessionId": "904eee", "message": "extract_failed", "data": {"normalized_tail": "ssues common in multi-condition studies. Sequencing ensures uniform coverage which aids robustness during subsequent filtering.\",\n\"Confidence\": 0.75\n}", "repair_returned": false, "last_json_err": "Expecting ':' delimiter: line 1 column 4 (char 3)", "has_python_none": false, "hypothesisId": "H4"}, "timestamp": 1772961980165}
+ {"sessionId": "904eee", "message": "repair_failed", "data": {"input_tail": "ia unsupervised learning using dimension reduction techniques like PCA/UMAP from normalized data.\",\n\"Confidence\": 0.95\n}", "repaired_tail": "a unsupervised learning using dimension reduction techniques like PCA/UMAP from normalized data.\",\n\"Confidence\": 0.95\n}\"", "hypothesisId": "H1"}, "timestamp": 1772962029639}
+ {"sessionId": "904eee", "message": "extract_failed", "data": {"normalized_tail": "inct cell types/states first via unsupervised learning using dimension reduction techniques like PCA/UMAP from normalized data.\",\n\"Confidence\": 0.95\n}", "repair_returned": false, "last_json_err": "Expecting ':' delimiter: line 1 column 4 (char 3)", "has_python_none": false, "hypothesisId": "H4"}, "timestamp": 1772962029640}
+ {"sessionId": "904eee", "message": "repair_failed", "data": {"input_tail": "usters via downsampling/filtering based on read depth and genome coverage thresholds is mandatory.\",\n\"Confidence\": 0.9\n}", "repaired_tail": "usters via downsampling/filtering based on read depth and genome coverage thresholds is mandatory.\",\n\"Confidence\": 0.9\n}", "hypothesisId": "H1"}, "timestamp": 1772962075204}
+ {"sessionId": "904eee", "message": "extract_failed", "data": {"normalized_tail": "ighly degraded/contaminated clusters via downsampling/filtering based on read depth and genome coverage thresholds is mandatory.\",\n\"Confidence\": 0.9\n}", "repair_returned": false, "last_json_err": "Expecting value: line 2 column 36 (char 37)", "has_python_none": false, "hypothesisId": "H4"}, "timestamp": 1772962075205}
+ {"sessionId": "904eee", "message": "repair_failed", "data": {"input_tail": "factors rather than dynamic processes like maturation pathways required for trajectory inference.\",\n\"Confidence\": 0.95\n}", "repaired_tail": "factors rather than dynamic processes like maturation pathways required for trajectory inference.\",\n\"Confidence\": 0.95\n}", "hypothesisId": "H1"}, "timestamp": 1772962101951}
+ {"sessionId": "904eee", "message": "extract_failed", "data": {"normalized_tail": "identify static transcription factors rather than dynamic processes like maturation pathways required for trajectory inference.\",\n\"Confidence\": 0.95\n}", "repair_returned": false, "last_json_err": "Expecting ':' delimiter: line 1 column 4 (char 3)", "has_python_none": false, "hypothesisId": "H4"}, "timestamp": 1772962101951}
+ {"sessionId": "904eee", "message": "repair_failed", "data": {"input_tail": "s unique to intermediate progenitors versus terminally differentiated cells within those clusters.\",\n\"Confidence\": 0.9\n}", "repaired_tail": "s unique to intermediate progenitors versus terminally differentiated cells within those clusters.\",\n\"Confidence\": 0.9\n}", "hypothesisId": "H1"}, "timestamp": 1772962135915}
+ {"sessionId": "904eee", "message": "extract_failed", "data": {"normalized_tail": "ndance transcriptional features unique to intermediate progenitors versus terminally differentiated cells within those clusters.\",\n\"Confidence\": 0.9\n}", "repair_returned": false, "last_json_err": "Expecting ':' delimiter: line 1 column 4 (char 3)", "has_python_none": false, "hypothesisId": "H4"}, "timestamp": 1772962135916}
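The `repair_failed` / `extract_failed` entries above all record generation tails that stop mid-JSON-object, which `json.loads` rejects outright. A minimal sketch of the failure mode and a crude brace-balancing repair (my own illustration, far simpler than the repo's `_repair_truncated_json`):

```python
import json

def crude_repair(fragment):
    """Close an unterminated string and any unbalanced braces (illustrative only)."""
    if fragment.count('"') % 2:   # odd quote count -> a string literal was cut off
        fragment += '"'
    fragment += "}" * (fragment.count("{") - fragment.count("}"))
    return fragment

tail = '{"Confidence": 0.75'          # truncated mid-object, like the logged tails
try:
    json.loads(tail)
except json.JSONDecodeError as err:
    print("unparseable:", err.msg)
print(json.loads(crude_repair(tail)))  # {'Confidence': 0.75}
```

Real repairs are messier (escaped quotes, nested arrays), which is exactly why the log also records a `repaired_tail` and the `last_json_err` when the repair still fails to parse.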
inference.ipynb ADDED
@@ -0,0 +1,166 @@
+ {
+  "cells": [
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "# Self-Driving Lab Inference on H100 With Unsloth\n",
+     "\n",
+     "This notebook loads a quantized Unsloth model, builds the same self-driving lab observation prompt used during training, generates the next structured lab action, and steps the simulator in a short closed-loop rollout similar to `run_agent.py`, but with faster 4-bit inference on H100."
+    ],
+    "id": "a9d34036"
+   },
+   {
+    "cell_type": "code",
+    "metadata": {},
+    "source": [
+     "%pip install -q -U torch transformers unsloth"
+    ],
+    "execution_count": null,
+    "outputs": [],
+    "id": "20b36e01"
+   },
+   {
+    "cell_type": "code",
+    "metadata": {},
+    "source": [
+     "import json\n",
+     "\n",
+     "import torch\n",
+     "\n",
+     "from training_script import format_observation\n",
+     "from training_unsloth import generate_action_with_model, load_model_artifacts\n",
+     "from server.hackathon_environment import BioExperimentEnvironment\n",
+     "\n",
+     "print(\"CUDA available:\", torch.cuda.is_available())\n",
+     "if torch.cuda.is_available():\n",
+     "    print(\"GPU:\", torch.cuda.get_device_name(0))\n",
+     "    print(\"bf16 supported:\", torch.cuda.is_bf16_supported())"
+    ],
+    "execution_count": null,
+    "outputs": [],
+    "id": "bcf24a2e"
+   },
+   {
+    "cell_type": "code",
+    "metadata": {},
+    "source": [
+     "MODEL_PATH = \"artifacts/grpo-unsloth-output\"  # or a Hugging Face repo / base model id\n",
+     "SCENARIO_NAME = \"cardiac_disease_de\"\n",
+     "SEED = 42\n",
+     "\n",
+     "tokenizer, model = load_model_artifacts(\n",
+     "    MODEL_PATH,\n",
+     "    trust_remote_code=True,\n",
+     "    max_seq_length=2048,\n",
+     "    load_in_4bit=True,\n",
+     "    fast_inference=True,\n",
+     "    prepare_for_inference=True,\n",
+     ")\n",
+     "\n",
+     "env = BioExperimentEnvironment(scenario_name=SCENARIO_NAME, domain_randomise=False)\n",
+     "obs = env.reset(seed=SEED)\n",
+     "print(format_observation(obs)[:3000])"
+    ],
+    "execution_count": null,
+    "outputs": [],
+    "id": "c54f2cfd"
+   },
+   {
+    "cell_type": "code",
+    "metadata": {},
+    "source": [
+     "result = generate_action_with_model(\n",
+     "    model,\n",
+     "    tokenizer,\n",
+     "    obs,\n",
+     "    max_new_tokens=160,\n",
+     "    temperature=0.2,\n",
+     "    top_p=0.9,\n",
+     "    do_sample=True,\n",
+     ")\n",
+     "\n",
+     "print(\"Model response:\\n\")\n",
+     "print(result[\"response_text\"])\n",
+     "print(\"\\nParsed action:\\n\")\n",
+     "result[\"action\"].model_dump() if result[\"action\"] is not None else None"
+    ],
+    "execution_count": null,
+    "outputs": [],
+    "id": "f9b25208"
+   },
+   {
+    "cell_type": "code",
+    "metadata": {},
+    "source": [
+     "if result[\"action\"] is not None:\n",
+     "    next_obs = env.step(result[\"action\"])\n",
+     "    print(\"Reward:\", next_obs.reward)\n",
+     "    print(\"Done:\", next_obs.done)\n",
+     "    print(\"Violations:\", next_obs.rule_violations)\n",
+     "    print(\"Markers:\", next_obs.discovered_markers[:5])\n",
+     "    print(\"Mechanisms:\", next_obs.candidate_mechanisms[:5])\n",
+     "    if next_obs.latest_output is not None:\n",
+     "        print(\"Summary:\", next_obs.latest_output.summary)\n",
+     "        print(\"Latest data preview:\")\n",
+     "        print(json.dumps(next_obs.latest_output.data, indent=2)[:1200])\n",
+     "else:\n",
+     "    print(\"Model output did not parse into an ExperimentAction.\")"
+    ],
+    "execution_count": null,
+    "outputs": [],
+    "id": "c2408f52"
+   },
+   {
+    "cell_type": "code",
+    "metadata": {},
+    "source": [
+     "# Optional short closed-loop rollout.\n",
+     "obs = env.reset(seed=7)\n",
+     "trajectory = []\n",
+     "\n",
+     "for step_idx in range(8):\n",
+     "    result = generate_action_with_model(model, tokenizer, obs, max_new_tokens=160)\n",
+     "    action = result[\"action\"]\n",
+     "    record = {\n",
+     "        \"step\": step_idx + 1,\n",
+     "        \"response_text\": result[\"response_text\"],\n",
+     "        \"action\": action.model_dump() if action is not None else None,\n",
+     "    }\n",
+     "    trajectory.append(record)\n",
+     "    if action is None:\n",
+     "        break\n",
+     "\n",
+     "    next_obs = env.step(action)\n",
+     "    record.update({\n",
+     "        \"reward\": next_obs.reward,\n",
+     "        \"done\": next_obs.done,\n",
+     "        \"violations\": list(next_obs.rule_violations),\n",
+     "        \"latest_summary\": next_obs.latest_output.summary if next_obs.latest_output is not None else None,\n",
+     "        \"discovered_markers\": list(next_obs.discovered_markers[:5]),\n",
+     "        \"candidate_mechanisms\": list(next_obs.candidate_mechanisms[:5]),\n",
+     "    })\n",
+     "    obs = next_obs\n",
+     "    if obs.done:\n",
+     "        break\n",
+     "\n",
+     "trajectory"
+    ],
+    "execution_count": null,
+    "outputs": [],
+    "id": "8af34f32"
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "Python 3",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "name": "python"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
pyproject.toml CHANGED
@@ -39,6 +39,7 @@ train = [
  "torch>=2.10.0",
  "transformers>=5.3.0",
  "trl>=0.29.0",
+ "unsloth",
  ]
 
  [project.scripts]
run_agent.py CHANGED
@@ -16,6 +16,7 @@ from models import (
      ActionType,
      ExperimentAction,
      ExperimentObservation,
+     OutputType,
      build_agent_observation_context,
      build_agent_system_prompt,
  )
@@ -205,6 +206,13 @@ def _strip_js_comments(text: str) -> str:
 
  def extract_json_object(text: str) -> Optional[Dict[str, Any]]:
      stripped = _normalize_jsonish_text(text).strip()
+     if stripped.startswith('"') and stripped.endswith('"'):
+         try:
+             unwrapped = json.loads(stripped)
+         except json.JSONDecodeError:
+             unwrapped = None
+         if isinstance(unwrapped, str):
+             stripped = _normalize_jsonish_text(unwrapped).strip()
      fence_prefix = "```"
      if stripped.startswith(fence_prefix) and stripped.endswith(fence_prefix):
          lines = stripped.splitlines()
@@ -226,6 +234,7 @@ def extract_json_object(text: str) -> Optional[Dict[str, Any]]:
              break
          start = stripped.find("{", start + 1)
 
+     repaired = None
      first_brace = stripped.find("{")
      if first_brace != -1:
          repaired = _repair_truncated_json(stripped[first_brace:])
@@ -365,34 +374,34 @@ def parse_action(text: str) -> Optional[ExperimentAction]:
      if d is not None:
          action_type = normalize_action_type(get_payload_value(d, "action_type"))
          if action_type is None:
-             return None
-
-         parameters = get_payload_value(d, "parameters", "params") or {}
-         if not isinstance(parameters, dict):
-             parameters = {}
+             pass
+         else:
+             parameters = get_payload_value(d, "parameters", "params") or {}
+             if not isinstance(parameters, dict):
+                 parameters = {}
 
-         confidence = get_payload_value(d, "confidence")
-         if confidence is None:
-             confidence = 0.5
-         try:
-             confidence = float(confidence)
-         except (TypeError, ValueError):
-             confidence = 0.5
+             confidence = get_payload_value(d, "confidence")
+             if confidence is None:
+                 confidence = 0.5
+             try:
+                 confidence = float(confidence)
+             except (TypeError, ValueError):
+                 confidence = 0.5
 
-         justification = get_payload_value(
-             d, "justification", "reasoning", "rationale", "reason"
-         )
-         if justification is not None and not isinstance(justification, str):
-             justification = compact_preview(justification, 200)
-         method = normalize_optional_string(get_payload_value(d, "method"))
-
-         return ExperimentAction(
-             action_type=ActionType(action_type),
-             method=method,
-             parameters=parameters,
-             justification=justification,
-             confidence=min(1.0, max(0.0, confidence)),
-         )
+             justification = get_payload_value(
+                 d, "justification", "justifyement", "reasoning", "rationale", "reason"
+             )
+             if justification is not None and not isinstance(justification, str):
+                 justification = compact_preview(justification, 200)
+             method = normalize_optional_string(get_payload_value(d, "method"))
+
+             return ExperimentAction(
+                 action_type=ActionType(action_type),
+                 method=method,
+                 parameters=parameters,
+                 justification=justification,
+                 confidence=min(1.0, max(0.0, confidence)),
+             )
 
      action_match = re.search(
          r'["\']action_type["\']\s*:\s*["\']([^"\']+)',
@@ -472,6 +481,107 @@ def should_force_terminal_conclusion(
      )
 
 
+ def _unique_nonempty(items: List[str], limit: int = 5) -> List[str]:
+     seen: set[str] = set()
+     result: List[str] = []
+     for raw in items:
+         value = normalize_optional_string(raw)
+         if not value:
+             continue
+         key = value.upper()
+         if key in seen:
+             continue
+         seen.add(key)
+         result.append(value)
+         if len(result) >= limit:
+             break
+     return result
+
+
+ def _infer_conclusion_evidence(
+     obs: ExperimentObservation,
+ ) -> tuple[List[str], List[str], Dict[str, float]]:
+     top_markers = _unique_nonempty(list(obs.discovered_markers), limit=5)
+     causal_mechanisms = _unique_nonempty(list(obs.candidate_mechanisms), limit=5)
+     predicted_pathways: Dict[str, float] = {}
+
+     for output in reversed(obs.all_outputs):
+         if not output.success:
+             continue
+
+         data = output.data or {}
+         if not top_markers:
+             if output.output_type == OutputType.MARKER_RESULT:
+                 top_markers = _unique_nonempty(list(data.get("markers", [])), limit=5)
+             elif output.output_type == OutputType.DE_RESULT:
+                 top_markers = _unique_nonempty(
+                     [item.get("gene") for item in data.get("top_genes", []) if isinstance(item, dict)],
+                     limit=5,
+                 )
+
+         if output.output_type == OutputType.PATHWAY_RESULT and not predicted_pathways:
+             for item in data.get("top_pathways", []):
+                 if not isinstance(item, dict):
+                     continue
+                 pathway = normalize_optional_string(item.get("pathway"))
+                 score = item.get("score")
+                 if pathway and isinstance(score, (int, float)):
+                     predicted_pathways[pathway] = float(score)
+                 if len(predicted_pathways) >= 5:
+                     break
+
+         if not causal_mechanisms:
+             if output.output_type == OutputType.PATHWAY_RESULT:
+                 causal_mechanisms = _unique_nonempty(
+                     [item.get("pathway") for item in data.get("top_pathways", []) if isinstance(item, dict)],
+                     limit=5,
+                 )
+             elif output.output_type == OutputType.NETWORK_RESULT:
+                 causal_mechanisms = _unique_nonempty(
+                     list(data.get("top_regulators", [])),
+                     limit=5,
+                 )
+
+         if top_markers and causal_mechanisms and predicted_pathways:
+             break
+
+     return top_markers, causal_mechanisms, predicted_pathways
+
+
+ def ensure_conclusion_claims(
+     obs: ExperimentObservation,
+     action: ExperimentAction,
+ ) -> ExperimentAction:
+     if action.action_type != ActionType.SYNTHESIZE_CONCLUSION:
+         return action
+
+     parameters = dict(action.parameters or {})
+     raw_claims = parameters.get("claims")
+     if isinstance(raw_claims, list) and raw_claims:
+         normalized_claims = [claim for claim in raw_claims if isinstance(claim, dict)]
+         if normalized_claims:
+             parameters["claims"] = normalized_claims
+             if parameters != action.parameters:
+                 return action.model_copy(update={"parameters": parameters})
+             return action
+
+     top_markers, causal_mechanisms, predicted_pathways = _infer_conclusion_evidence(obs)
+     claim_type = "causal" if causal_mechanisms else "correlational"
+     conditions = " vs ".join(obs.task.conditions[:2]) if obs.task.conditions else "the task conditions"
+     claim = action.justification or f"Final synthesis for {conditions}."
+
+     parameters["claims"] = [{
+         "top_markers": top_markers,
+         "causal_mechanisms": causal_mechanisms,
+         "predicted_pathways": predicted_pathways,
+         "confidence": action.confidence,
+         "claim_type": claim_type,
+         "claim": claim,
+     }]
+     if not action.justification:
+         action = action.model_copy(update={"justification": claim})
+     return action.model_copy(update={"parameters": parameters})
+
 
  def write_dashboard_state(
      env: BioExperimentEnvironment,
@@ -888,6 +998,8 @@ def main():
              confidence=action.confidence,
          )
 
+         action = ensure_conclusion_claims(obs, action)
+
          log(f"\nStep {step + 1}: {action.action_type.value} ({gen_time:.1f}s)")
          if thinking:
              log(f"    Thinking: {thinking[:200]}")
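The new unwrap branch in `extract_json_object` handles double-encoded output: a JSON object serialized inside a JSON string, which a plain `json.loads`-on-braces strategy never finds. A standalone sketch of the idea (simplified; the function name is mine, not the repo's):

```python
import json

def unwrap_double_encoded(text):
    """Decode an outer JSON string layer, then parse the inner object.

    Mirrors (in simplified form) the unwrap branch: if the payload is a
    quoted JSON string, peel that layer before looking for the object.
    """
    stripped = text.strip()
    if stripped.startswith('"') and stripped.endswith('"'):
        try:
            inner = json.loads(stripped)  # decode the outer string layer
        except json.JSONDecodeError:
            inner = None
        if isinstance(inner, str):
            stripped = inner.strip()
    try:
        return json.loads(stripped)
    except json.JSONDecodeError:
        return None

payload = json.dumps(json.dumps({"action_type": "synthesize_conclusion"}))
print(unwrap_double_encoded(payload))  # {'action_type': 'synthesize_conclusion'}
```

Models produce this shape when they quote their entire answer, so peeling one string layer before brace-scanning recovers actions that would otherwise fail to parse.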
run_agent_unsloth.py ADDED
@@ -0,0 +1,294 @@
+ """Run the bio-experiment environment with a quantized Unsloth model."""
+
+ from __future__ import annotations
+
+ import json
+ import os
+ import time
+ from typing import Any, Dict, Optional
+
+ from models import ActionType, ExperimentAction
+ from server.hackathon_environment import BioExperimentEnvironment
+ from training_unsloth import (
+     DEFAULT_MAX_SEQ_LENGTH,
+     generate_action_with_model,
+     load_model_artifacts,
+ )
+ from training_script import DEFAULT_COMPLETION_TOKEN_BUDGET
+
+ import run_agent as base
+
+ MODEL_ID = os.getenv("RUN_AGENT_UNSLOTH_MODEL_ID", "unsloth/Qwen3.5-2B-GGUF")
+ MAX_EPISODE_STEPS = int(
+     os.getenv("RUN_AGENT_UNSLOTH_MAX_EPISODE_STEPS", str(base.MAX_EPISODE_STEPS))
+ )
+ MAX_NEW_TOKENS = int(
+     os.getenv(
+         "RUN_AGENT_UNSLOTH_MAX_NEW_TOKENS",
+         str(DEFAULT_COMPLETION_TOKEN_BUDGET),
+     )
+ )
+ MAX_SEQ_LENGTH = int(
+     os.getenv("RUN_AGENT_UNSLOTH_MAX_SEQ_LENGTH", str(DEFAULT_MAX_SEQ_LENGTH))
+ )
+ TRUST_REMOTE_CODE = (
+     os.getenv("RUN_AGENT_UNSLOTH_TRUST_REMOTE_CODE", "1").strip().lower()
+     not in {"0", "false", "off"}
+ )
+ LOAD_IN_4BIT = (
+     os.getenv("RUN_AGENT_UNSLOTH_LOAD_IN_4BIT", "1").strip().lower()
+     not in {"0", "false", "off"}
+ )
+ FAST_INFERENCE = (
+     os.getenv("RUN_AGENT_UNSLOTH_FAST_INFERENCE", "1").strip().lower()
+     not in {"0", "false", "off"}
+ )
+
+
+ def check_dashboard_command() -> Optional[Dict[str, Any]]:
+     try:
+         raw = base.DASHBOARD_CMD_PATH.read_text(encoding="utf-8")
+         base.DASHBOARD_CMD_PATH.unlink(missing_ok=True)
+         return json.loads(raw)
+     except (FileNotFoundError, json.JSONDecodeError):
+         return None
+
+
+ def run_episode(
+     model: Any,
+     tokenizer: Any,
+     *,
+     scenario_name: Optional[str] = None,
+     custom_ground_truth: Optional[Dict[str, Any]] = None,
+ ) -> None:
+     env = BioExperimentEnvironment(scenario_name=scenario_name)
+     obs = env.reset()
+
+     if custom_ground_truth and env._latent:
+         gt = custom_ground_truth
+         bio = env._latent.biology
+         if gt.get("true_markers"):
+             bio.true_markers = gt["true_markers"]
+         if gt.get("causal_mechanisms"):
+             bio.causal_mechanisms = gt["causal_mechanisms"]
+         if gt.get("true_pathways"):
+             bio.true_pathways = {
+                 key: float(value) for key, value in gt["true_pathways"].items()
+             }
+
+     base.log("\n" + "=" * 70)
+     base.log(f"TASK: {obs.task.problem_statement}")
+     base.log(f"Conditions: {obs.task.conditions}")
+     base.log(
+         f"Budget: ${obs.task.budget_limit:,.0f} | "
+         f"Time: {obs.task.time_limit_days:.0f} days"
+     )
+     base.log("Runtime: Unsloth quantized generation")
+     base.log("=" * 70)
+
+     cumulative_reward = 0.0
+     base.write_dashboard_state(env, obs, step=0, cumulative_reward=0.0)
+
+     for step in range(MAX_EPISODE_STEPS):
+         cmd = check_dashboard_command()
+         if cmd and cmd.get("action") == "restart":
+             base.log("\n[DASHBOARD] Restart requested - ending episode early.")
+             break
+
+         t0 = time.time()
+         result = generate_action_with_model(
+             model,
+             tokenizer,
+             obs,
+             max_new_tokens=MAX_NEW_TOKENS,
+             temperature=0.2,
+             top_p=0.9,
+             do_sample=True,
+         )
+         response = result["response_text"]
+         action = result["action"]
+         gen_time = time.time() - t0
+
+         is_last_step = step == MAX_EPISODE_STEPS - 1
+         if action is None:
+             if is_last_step:
+                 base.log("\n  [!] Parse failed on final step - forcing synthesize_conclusion.")
+                 action = ExperimentAction(
+                     action_type=ActionType.SYNTHESIZE_CONCLUSION,
+                     justification="forced terminal conclusion",
+                     confidence=0.5,
+                 )
+             else:
+                 base.log(
+                     f"\n  [!] Parse failed, skipping step. Raw: {response[:150]}"
+                 )
+                 continue
+
+         completed_types = {
+             record.action_type for record in obs.pipeline_history if record.success
+         }
+         failed_types = {
+             record.action_type for record in obs.pipeline_history if not record.success
+         }
+
+         if base.should_force_terminal_conclusion(action, completed_types):
+             base.log(
+                 f"\n  [!] repeated completed meta step {action.action_type.value} "
+                 f"- forcing synthesize_conclusion."
+             )
+             action = ExperimentAction(
+                 action_type=ActionType.SYNTHESIZE_CONCLUSION,
+                 justification="repeated completed meta step forced terminal conclusion",
+                 confidence=action.confidence,
+             )
+             completed_types = {
+                 record.action_type for record in obs.pipeline_history if record.success
+             }
+
+         skip_reason = None
+         if action.action_type in completed_types:
+             skip_reason = f"blocked repeat of completed step {action.action_type.value}"
+         elif action.action_type in failed_types:
+             if base.should_block_failed_reattempt(obs.pipeline_history, action.action_type):
+                 skip_reason = (
+                     f"blocked re-attempt of failed step {action.action_type.value}"
+                 )
+
+         if skip_reason:
+             if is_last_step:
+                 base.log(
+                     f"\n  [!] {skip_reason} on final step - forcing synthesize_conclusion."
+                 )
+                 action = ExperimentAction(
+                     action_type=ActionType.SYNTHESIZE_CONCLUSION,
+                     justification="forced terminal conclusion",
+                     confidence=0.5,
+                 )
+             else:
+                 base.log(f"\n  [!] {skip_reason}, skipping step.")
+                 continue
+
+         if is_last_step and action.action_type != ActionType.SYNTHESIZE_CONCLUSION:
+             base.log(
+                 f"\n  [!] Final step - overriding {action.action_type.value} "
+                 "with synthesize_conclusion."
+             )
+             action = ExperimentAction(
+                 action_type=ActionType.SYNTHESIZE_CONCLUSION,
+                 justification="forced terminal conclusion",
+                 confidence=action.confidence,
+             )
+
+         action = base.ensure_conclusion_claims(obs, action)
+
+         base.log(f"\nStep {step + 1}: {action.action_type.value} ({gen_time:.1f}s)")
+         if action.justification:
+             base.log(f"  Rationale: {action.justification}")
+         else:
+             base.log("  Rationale: [model did not provide one]")
+         if action.parameters:
+             base.log(f"  Parameters: {base.compact_preview(action.parameters, 200)}")
+         elif response:
+             base.log(
+                 "  Model response: "
+                 f"{base.compact_preview(response, base.MODEL_RESPONSE_PREVIEW_CHARS)}"
+             )
+
+         obs = env.step(action)
+
+         if obs.latest_output:
+             latest_output = obs.latest_output
+             status = "OK" if latest_output.success else "FAIL"
+             base.log(f"  [{status}] {latest_output.summary}")
+             if latest_output.warnings:
+                 base.log(f"  Warnings: {latest_output.warnings}")
+
+         step_reward = obs.reward
+         cumulative_reward += step_reward
+         base.log(f"  Reward: {step_reward:+.3f} (cum: {cumulative_reward:+.3f})")
+         base.log(
+             f"  Budget: ${obs.resource_usage.budget_remaining:,.0f} | "
+             f"Time: {obs.resource_usage.time_remaining_days:.0f}d"
+         )
+
+         base.write_dashboard_state(
+             env,
+             obs,
+             step=step + 1,
+             cumulative_reward=cumulative_reward,
+             model_response=response,
+             action=action,
+             gen_time=gen_time,
+             episode_done=obs.done,
+         )
+
+         if obs.rule_violations:
+             base.log(f"  Violations: {obs.rule_violations}")
+         if obs.done:
+             break
+
+     base.log(f"\n{'=' * 70}")
+     base.log("EPISODE COMPLETE" if obs.done else f"MAX STEPS ({MAX_EPISODE_STEPS})")
+     base.log(f"  Steps: {obs.step_index}")
+     base.log(f"  Total reward: {cumulative_reward:+.3f}")
+     base.log(f"  Budget used: ${obs.resource_usage.budget_used:,.0f}")
+     base.log(f"  Time used: {obs.resource_usage.time_used_days:.0f} days")
+     if obs.conclusions:
+         base.log("  Conclusions:")
+         for conclusion in obs.conclusions:
+             base.log(
+                 f"    [{conclusion.claim_type}, conf={conclusion.confidence:.2f}] "
+                 f"{conclusion.claim}"
+             )
+             if conclusion.top_markers:
+                 base.log(f"      Markers: {conclusion.top_markers}")
+             if conclusion.causal_mechanisms:
+                 base.log(f"      Mechanisms: {conclusion.causal_mechanisms}")
+             if conclusion.predicted_pathways:
+                 base.log(f"      Pathways: {conclusion.predicted_pathways}")
+     base.log("=" * 70)
+
+
+ def main() -> None:
+     runtime = base.resolve_torch_runtime()
+     base.log(
+         f"Using Unsloth runtime: device={runtime['device']} "
+         f"name={runtime['device_name']} dtype={runtime['dtype']}"
+     )
+     tokenizer, model = load_model_artifacts(
+         MODEL_ID,
+         trust_remote_code=TRUST_REMOTE_CODE,
+         max_seq_length=MAX_SEQ_LENGTH,
+         load_in_4bit=LOAD_IN_4BIT,
+         fast_inference=FAST_INFERENCE,
+         prepare_for_inference=True,
+     )
+     base.DASHBOARD_CMD_PATH.unlink(missing_ok=True)
+     run_episode(model, tokenizer)
+
+     while True:
+         base.log("\nWaiting for dashboard command (restart / new task) ...")
+         while True:
+             cmd = check_dashboard_command()
+             if cmd:
+                 break
+             time.sleep(1.0)
+
+         action_type = cmd.get("action", "restart")
+         if action_type == "quit":
+             base.log("Quit requested.")
+             break
+
+         scenario = cmd.get("scenario_name")
+         ground_truth = cmd.get("ground_truth")
+         base.log(f"\n[DASHBOARD] {action_type} - scenario={scenario}")
+         run_episode(
+             model,
+             tokenizer,
+             scenario_name=scenario,
+             custom_ground_truth=ground_truth,
+         )
+
+
+ if __name__ == "__main__":
+     main()
server/hackathon_environment.py CHANGED
@@ -132,7 +132,7 @@ class BioExperimentEnvironment(Environment):
         self._outputs.append(result.output)
         self._update_discoveries(action, result.output)
 
-        if action.action_type == ActionType.SYNTHESIZE_CONCLUSION:
+        if action.action_type == ActionType.SYNTHESIZE_CONCLUSION and result.output.success:
             raw_claims = action.parameters.get("claims", [])
             for c in raw_claims:
                 if isinstance(c, dict):
@@ -218,7 +218,7 @@ class BioExperimentEnvironment(Environment):
             subagent_outputs=list(self._subagent_outputs),
             conclusions=list(self._conclusions),
             rule_violations=rule_violations or [],
-            step_reward_breakdown={},
+            step_reward_breakdown=reward_breakdown or {},
             done=done,
             reward=reward,
             metadata=meta,
server/rewards/reward.py CHANGED
@@ -24,7 +24,7 @@ The terminal reward adds:
 from __future__ import annotations
 
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional
+from typing import Dict, List, Optional
 
 from models import (
     ActionType,
@@ -214,10 +214,19 @@ class RewardComputer:
             discovered_markers,
             candidate_mechanisms,
         )
-        discovery_error_penalty = -2.5 * (1.0 - discovery_alignment)
+        discovery_error_penalty = -6.0 * (1.0 - discovery_alignment)
+        if discovery_alignment < 0.25:
+            discovery_error_penalty -= 2.0
         rb.components["discovery_alignment"] = discovery_alignment
         rb.components["discovery_error_penalty"] = discovery_error_penalty
 
+        conclusion_alignment = self._conclusion_alignment(state, conclusions)
+        conclusion_error_penalty = -4.0 * (1.0 - conclusion_alignment)
+        if conclusions and conclusion_alignment < 0.25:
+            conclusion_error_penalty -= 1.5
+        rb.components["conclusion_alignment"] = conclusion_alignment
+        rb.components["conclusion_error_penalty"] = conclusion_error_penalty
+
         eff_bonus = (budget_eff + time_eff) / 2.0 if completeness >= 0.3 else 0.0
         rb.terminal = (
             3.0 * completeness
@@ -225,6 +234,7 @@ class RewardComputer:
             + 1.0 * eff_bonus
             + overconf
             + discovery_error_penalty
+            + conclusion_error_penalty
         )
         return rb
 
@@ -470,3 +480,42 @@ class RewardComputer:
         if not components:
             return 1.0
         return sum(components) / len(components)
+
+    def _conclusion_alignment(
+        self,
+        s: FullLatentState,
+        conclusions: List[ConclusionClaim],
+    ) -> float:
+        if not conclusions:
+            return 0.0
+
+        pred_markers = [marker for conclusion in conclusions for marker in conclusion.top_markers]
+        pred_mechanisms = [
+            mechanism
+            for conclusion in conclusions
+            for mechanism in conclusion.causal_mechanisms
+        ]
+
+        if not pred_markers and not pred_mechanisms:
+            return self._legacy_calibration(s, conclusions)
+
+        components: List[float] = []
+        if s.biology.true_markers or pred_markers:
+            marker_recall = marker_set_score(pred_markers, s.biology.true_markers)
+            marker_precision = marker_set_score(s.biology.true_markers, pred_markers)
+            components.append((marker_recall + marker_precision) / 2.0)
+
+        if s.biology.causal_mechanisms or pred_mechanisms:
+            mechanism_recall = mechanism_set_score(
+                pred_mechanisms,
+                s.biology.causal_mechanisms,
+            )
+            mechanism_precision = mechanism_set_score(
+                s.biology.causal_mechanisms,
+                pred_mechanisms,
+            )
+            components.append((mechanism_recall + mechanism_precision) / 2.0)
+
+        if not components:
+            return 1.0
+        return sum(components) / len(components)
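The `_conclusion_alignment` helper added above averages a symmetric precision/recall score per evidence type (markers, mechanisms). A minimal standalone sketch of that averaging, with a hypothetical exact-match `set_score` standing in for the repository's `marker_set_score`/`mechanism_set_score` (which may use fuzzier matching):

```python
from typing import List


def set_score(predicted: List[str], truth: List[str]) -> float:
    # Hypothetical stand-in: fraction of `predicted` entries present in `truth`.
    if not predicted:
        return 0.0
    return sum(1 for p in predicted if p in truth) / len(predicted)


def alignment(pred_markers: List[str], true_markers: List[str],
              pred_mechs: List[str], true_mechs: List[str]) -> float:
    components: List[float] = []
    if true_markers or pred_markers:
        # Symmetric: score predictions against truth and truth against predictions.
        components.append((set_score(pred_markers, true_markers)
                           + set_score(true_markers, pred_markers)) / 2.0)
    if true_mechs or pred_mechs:
        components.append((set_score(pred_mechs, true_mechs)
                           + set_score(true_mechs, pred_mechs)) / 2.0)
    return sum(components) / len(components) if components else 1.0


print(alignment(["NPPA", "NPPB"], ["NPPA", "NPPB"], [], []))  # -> 1.0
print(alignment(["WRONG1"], ["NPPA", "NPPB"], [], []))        # -> 0.0
```

The symmetric average is what makes both missing true markers (low recall) and spurious predicted markers (low precision) pull the alignment, and therefore the terminal penalty, in the same direction.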
server/rules/engine.py CHANGED
@@ -45,6 +45,16 @@ class RuleEngine:
             p.markers_validated,
         ])
 
+    @staticmethod
+    def _has_marker_evidence(s: FullLatentState) -> bool:
+        p = s.progress
+        return p.markers_discovered or p.markers_validated
+
+    @staticmethod
+    def _has_mechanism_evidence(s: FullLatentState) -> bool:
+        p = s.progress
+        return p.pathways_analyzed or p.networks_inferred
+
     def check(
         self, action: ExperimentAction, state: FullLatentState
     ) -> List[RuleViolation]:
@@ -238,6 +248,20 @@ class RuleEngine:
                     message="Cannot synthesise conclusion without substantive analysis",
                 ))
 
+            if not self._has_marker_evidence(s):
+                vs.append(RuleViolation(
+                    rule_id="conclusion_without_marker_evidence",
+                    severity=Severity.HARD,
+                    message="Cannot synthesise conclusion before discovering or validating markers",
+                ))
+
+            if not self._has_mechanism_evidence(s):
+                vs.append(RuleViolation(
+                    rule_id="conclusion_without_mechanism_evidence",
+                    severity=Severity.HARD,
+                    message="Cannot synthesise conclusion before inferring pathways or mechanisms",
+                ))
+
             claims = action.parameters.get("claims", [])
             for claim in claims:
                 if isinstance(claim, dict) and claim.get("claim_type") == "causal":
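The two new hard rules above gate conclusions on both marker and mechanism evidence. A minimal sketch of that gating logic, using a hypothetical simplified `Progress` record in place of the environment's `FullLatentState.progress`:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Progress:
    # Hypothetical simplified stand-in for the environment's progress state.
    markers_discovered: bool = False
    markers_validated: bool = False
    pathways_analyzed: bool = False
    networks_inferred: bool = False


def conclusion_blockers(p: Progress) -> List[str]:
    # Mirrors the two new hard rules: a conclusion needs marker evidence
    # AND mechanism evidence before it is allowed to proceed.
    blockers: List[str] = []
    if not (p.markers_discovered or p.markers_validated):
        blockers.append("conclusion_without_marker_evidence")
    if not (p.pathways_analyzed or p.networks_inferred):
        blockers.append("conclusion_without_mechanism_evidence")
    return blockers


print(conclusion_blockers(Progress()))  # both gates fire
print(conclusion_blockers(Progress(markers_discovered=True, pathways_analyzed=True)))  # -> []
```

Either discovery *or* validation satisfies the marker gate, and either pathway analysis *or* network inference satisfies the mechanism gate, so an agent has two routes past each rule.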
tests/test_environment.py CHANGED
@@ -64,7 +64,7 @@ class TestEnvironmentLifecycle:
             parameters={"assay": "qPCR"},
         ))
         assert obs.latest_output is not None
-        assert obs.latest_output.success is True
+        assert obs.latest_output.success is False
         assert any("follow-up design" in msg.lower() for msg in obs.rule_violations)
 
     def test_conclusion_ends_episode(self):
@@ -81,6 +81,8 @@ class TestEnvironmentLifecycle:
             ExperimentAction(action_type=ActionType.CLUSTER_CELLS),
             ExperimentAction(action_type=ActionType.DIFFERENTIAL_EXPRESSION,
                              parameters={"comparison": "disease_vs_healthy"}),
+            ExperimentAction(action_type=ActionType.PATHWAY_ENRICHMENT),
+            ExperimentAction(action_type=ActionType.MARKER_SELECTION),
             ExperimentAction(
                 action_type=ActionType.SYNTHESIZE_CONCLUSION,
                 parameters={"claims": [
@@ -94,3 +96,33 @@ class TestEnvironmentLifecycle:
 
         assert obs.done is True
         assert obs.reward != 0.0
+
+    def test_blocked_conclusion_does_not_persist_claims(self):
+        env = BioExperimentEnvironment()
+        env.reset()
+
+        pipeline = [
+            ExperimentAction(action_type=ActionType.COLLECT_SAMPLE),
+            ExperimentAction(action_type=ActionType.PREPARE_LIBRARY),
+            ExperimentAction(action_type=ActionType.SEQUENCE_CELLS),
+            ExperimentAction(action_type=ActionType.RUN_QC),
+            ExperimentAction(action_type=ActionType.FILTER_DATA),
+            ExperimentAction(action_type=ActionType.NORMALIZE_DATA),
+            ExperimentAction(action_type=ActionType.CLUSTER_CELLS),
+        ]
+        for action in pipeline:
+            obs = env.step(action)
+            assert obs.latest_output is not None
+            assert obs.latest_output.success is True
+
+        obs = env.step(ExperimentAction(
+            action_type=ActionType.SYNTHESIZE_CONCLUSION,
+            parameters={"claims": [
+                {"claim": "Premature conclusion", "confidence": 0.9},
+            ]},
+        ))
+
+        assert obs.latest_output is not None
+        assert obs.latest_output.success is False
+        assert obs.conclusions == []
+        assert any("markers" in msg.lower() for msg in obs.rule_violations)
tests/test_rewards.py CHANGED
@@ -108,7 +108,13 @@ class TestTerminalReward:
                 claim_type="causal",
             ),
         ]
-        rb = rc.terminal_reward(state, claims, [])
+        rb = rc.terminal_reward(
+            state,
+            claims,
+            [],
+            discovered_markers=["NPPA"],
+            candidate_mechanisms=["TGF-beta-driven fibrosis"],
+        )
         assert rb.terminal > 0
 
     def test_overconfident_wrong_claim_penalised(self):
@@ -165,3 +171,46 @@ class TestTerminalReward:
         assert aligned.components["discovery_alignment"] > misaligned.components["discovery_alignment"]
         assert aligned.components["discovery_error_penalty"] > misaligned.components["discovery_error_penalty"]
         assert aligned.terminal > misaligned.terminal
+
+    def test_conclusion_error_penalizes_wrong_structured_claims(self):
+        rc = RewardComputer()
+        state = FullLatentState(
+            biology=LatentBiologicalState(
+                true_markers=["NPPA", "NPPB"],
+                causal_mechanisms=["TGF-beta-driven fibrosis"],
+            ),
+            progress=ExperimentProgress(
+                data_normalized=True,
+                de_performed=True,
+                markers_discovered=True,
+                pathways_analyzed=True,
+                conclusion_reached=True,
+            ),
+            resources=ResourceState(budget_total=100_000, budget_used=40_000),
+        )
+        aligned = rc.terminal_reward(
+            state,
+            [
+                ConclusionClaim(
+                    top_markers=["NPPA", "NPPB"],
+                    causal_mechanisms=["TGF-beta-driven fibrosis"],
+                    confidence=0.8,
+                ),
+            ],
+            [],
+        )
+        misaligned = rc.terminal_reward(
+            state,
+            [
+                ConclusionClaim(
+                    top_markers=["WRONG1"],
+                    causal_mechanisms=["unrelated process"],
+                    confidence=0.8,
+                ),
+            ],
+            [],
+        )
+
+        assert aligned.components["conclusion_alignment"] > misaligned.components["conclusion_alignment"]
+        assert aligned.components["conclusion_error_penalty"] > misaligned.components["conclusion_error_penalty"]
+        assert aligned.terminal > misaligned.terminal
tests/test_rules.py CHANGED
@@ -55,47 +55,68 @@ class TestPrerequisites:
 
 
 class TestRedundancy:
-    def test_double_qc_is_soft(self):
+    def test_double_qc_is_hard_blocked(self):
         engine = RuleEngine()
         violations = engine.check(
             ExperimentAction(action_type=ActionType.RUN_QC),
             _state(cells_sequenced=True, qc_performed=True),
         )
         hard = engine.hard_violations(violations)
-        soft = engine.soft_violations(violations)
-        assert not hard
-        assert any("redundant" in m.lower() for m in soft)
+        assert any("redundant" in m.lower() for m in hard)
 
-    def test_repeated_followup_design_is_soft(self):
+    def test_repeated_followup_design_is_hard_blocked(self):
         engine = RuleEngine()
         violations = engine.check(
             ExperimentAction(action_type=ActionType.DESIGN_FOLLOWUP),
             _state(followup_designed=True, de_performed=True),
         )
         hard = engine.hard_violations(violations)
-        soft = engine.soft_violations(violations)
-        assert not hard
-        assert any("redundant" in m.lower() for m in soft)
+        assert any("redundant" in m.lower() for m in hard)
 
 
 class TestMetaActionTiming:
-    def test_followup_design_without_analysis_is_soft(self):
+    def test_followup_design_without_analysis_is_hard_blocked(self):
         engine = RuleEngine()
         violations = engine.check(
             ExperimentAction(action_type=ActionType.DESIGN_FOLLOWUP),
             _state(),
         )
-        soft = engine.soft_violations(violations)
-        assert any("follow-up design" in m.lower() for m in soft)
+        hard = engine.hard_violations(violations)
+        assert any("follow-up design" in m.lower() for m in hard)
 
-    def test_subagent_review_without_analysis_is_soft(self):
+    def test_subagent_review_without_analysis_is_hard_blocked(self):
         engine = RuleEngine()
         violations = engine.check(
             ExperimentAction(action_type=ActionType.REQUEST_SUBAGENT_REVIEW),
             _state(),
         )
-        soft = engine.soft_violations(violations)
-        assert any("subagent review" in m.lower() for m in soft)
+        hard = engine.hard_violations(violations)
+        assert any("subagent review" in m.lower() for m in hard)
+
+    def test_conclusion_without_marker_or_mechanism_evidence_is_hard_blocked(self):
+        engine = RuleEngine()
+        violations = engine.check(
+            ExperimentAction(action_type=ActionType.SYNTHESIZE_CONCLUSION),
+            _state(data_normalized=True, cells_clustered=True),
+        )
+        hard = engine.hard_violations(violations)
+        assert any("markers" in m.lower() for m in hard)
+        assert any("pathways or mechanisms" in m.lower() for m in hard)
+
+    def test_conclusion_with_marker_and_mechanism_evidence_is_allowed(self):
+        engine = RuleEngine()
+        violations = engine.check(
+            ExperimentAction(action_type=ActionType.SYNTHESIZE_CONCLUSION),
+            _state(
+                data_normalized=True,
+                cells_clustered=True,
+                markers_discovered=True,
+                pathways_analyzed=True,
+            ),
+        )
+        hard = engine.hard_violations(violations)
+        assert not hard
+
 
 class TestResourceConstraints:
     def test_exhausted_budget_blocked(self):
tests/test_run_agent.py CHANGED
@@ -1,7 +1,7 @@
 """Tests for run_agent parser and fallback helpers."""
 
 from models import ActionType, ExperimentAction
-from run_agent import fallback_action, parse_action
+from run_agent import ensure_conclusion_claims, extract_json_object, parse_action, should_block_failed_reattempt
 from server.hackathon_environment import BioExperimentEnvironment
 
 
@@ -23,7 +23,23 @@ def test_parse_action_accepts_justifyement_typo():
     assert action.justification == "typo key"
 
 
-def test_fallback_uses_observation_progress_not_step_index():
+def test_extract_json_object_unwraps_quoted_json_string():
+    parsed = extract_json_object(
+        '"{\\"action_type\\": \\"run_qc\\", \\"method\\": \\"\\", \\"parameters\\": {}, \\"Justification\\": \\"check quality\\", \\"confidence\\": 0.8}"'
+    )
+    assert parsed is not None
+    assert parsed["action_type"] == "run_qc"
+
+
+def test_parse_action_falls_back_when_inner_object_lacks_action_type():
+    action = parse_action(
+        '"{\\"action_type\\": \\"design_followup_experiment\\", \\"method\\": \\"\\", \\"parameters\\": {\\"criterion_description\\": \\"\\"}, \\"Justification\\": \\"follow-up\\", \\"confidence\\": 0.6, \\"threshold_value\\": {\\"conditions\\": [], \\"gene_filter_criteria\\": \\"x\\", \\"sample_group_size\\": 3}}"'  # noqa: E501
+    )
+    assert action is not None
+    assert action.action_type == ActionType.DESIGN_FOLLOWUP
+
+
+def test_should_block_failed_reattempt_until_pipeline_progress():
     env = BioExperimentEnvironment(scenario_name="cardiac_disease_de", domain_randomise=False)
     obs = env.reset(seed=0)
     for action_type in (
@@ -32,5 +48,44 @@ def test_should_block_failed_reattempt_until_pipeline_progress():
         ActionType.SEQUENCE_CELLS,
     ):
         obs = env.step(ExperimentAction(action_type=action_type))
-    action = fallback_action(obs)
-    assert action.action_type == ActionType.RUN_QC
+    assert should_block_failed_reattempt(obs.pipeline_history, ActionType.SEQUENCE_CELLS) is False
+    assert should_block_failed_reattempt(obs.pipeline_history, ActionType.RUN_QC) is False
+
+
+def test_ensure_conclusion_claims_infers_from_outputs_when_discoveries_empty():
+    env = BioExperimentEnvironment(scenario_name="cardiac_disease_de", domain_randomise=False)
+    obs = env.reset(seed=0)
+    pipeline = [
+        ExperimentAction(action_type=ActionType.COLLECT_SAMPLE),
+        ExperimentAction(action_type=ActionType.PREPARE_LIBRARY),
+        ExperimentAction(action_type=ActionType.SEQUENCE_CELLS),
+        ExperimentAction(action_type=ActionType.RUN_QC),
+        ExperimentAction(action_type=ActionType.FILTER_DATA),
+        ExperimentAction(action_type=ActionType.NORMALIZE_DATA),
+        ExperimentAction(action_type=ActionType.CLUSTER_CELLS),
+        ExperimentAction(
+            action_type=ActionType.DIFFERENTIAL_EXPRESSION,
+            parameters={"comparison": "disease_vs_healthy"},
+        ),
+        ExperimentAction(action_type=ActionType.PATHWAY_ENRICHMENT),
+    ]
+    for action in pipeline:
+        obs = env.step(action)
+
+    sparse_obs = obs.model_copy(update={
+        "discovered_markers": [],
+        "candidate_mechanisms": [],
+    })
+    action = ensure_conclusion_claims(
+        sparse_obs,
+        ExperimentAction(
+            action_type=ActionType.SYNTHESIZE_CONCLUSION,
+            confidence=0.9,
+            parameters={},
+        ),
+    )
+
+    claims = action.parameters["claims"]
+    assert claims[0]["top_markers"]
+    assert claims[0]["causal_mechanisms"]
+    assert claims[0]["predicted_pathways"]
tests/test_training_script.py CHANGED
@@ -43,6 +43,15 @@ def test_parse_action_completion_accepts_reasoning_alias():
     assert action.justification == "Measure quality before filtering."
 
 
+def test_parse_action_completion_normalizes_run_agent_aliases():
+    action = parse_action_completion(
+        '{"action_type":"network_inference","method":"pySCENIC"}'
+    )
+    assert action is not None
+    assert action.action_type == ActionType.REGULATORY_NETWORK_INFERENCE
+    assert action.method == "pySCENIC"
+
+
 def test_build_prompt_examples_contains_reference_action():
     examples = build_prompt_examples(
         dataset_episodes=1,
@@ -55,6 +64,7 @@ def test_build_prompt_examples_contains_reference_action():
     assert len(examples) == 2
     assert examples[0]["scenario_name"] == "cardiac_disease_de"
     assert '"action_type": "collect_sample"' in examples[0]["reference_action"]
+    assert '"action_type": "select_cohort"' in examples[1]["reference_action"]
 
 
 def test_openenv_reward_penalizes_invalid_completion():
train.ipynb ADDED
@@ -0,0 +1,141 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "cbde861c",
+   "metadata": {},
+   "source": [
+    "# Train A Self-Driving Lab Policy on H100\n",
+    "\n",
+    "This notebook is designed for Jupyter GPU nodes such as H100 clusters.\n",
+    "It uses the notebook-friendly helpers in `training_script.py` to build prompts from the same self-driving lab environment state used by `run_agent.py`, preview reference actions, and launch GRPO training without shelling out to the CLI."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "da2e770c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%pip install -q -U torch transformers datasets trl accelerate matplotlib huggingface_hub\n",
+    "\n",
+    "# Optional extras used by some reward-scoring paths.\n",
+    "%pip install -q -U sentence-transformers gseapy"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f4444591",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from pathlib import Path\n",
+    "\n",
+    "import torch\n",
+    "\n",
+    "from training_script import build_prompt_examples, make_training_args, run_training\n",
+    "\n",
+    "print(\"CUDA available:\", torch.cuda.is_available())\n",
+    "if torch.cuda.is_available():\n",
+    "    print(\"GPU:\", torch.cuda.get_device_name(0))\n",
+    "    print(\"bf16 supported:\", torch.cuda.is_bf16_supported())\n",
+    "\n",
+    "Path(\"artifacts\").mkdir(exist_ok=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c9c472b3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "args = make_training_args(\n",
+    "    model_id=\"Qwen/Qwen3.5-0.8B\",\n",
+    "    output_dir=\"artifacts/grpo-h100\",\n",
+    "    dataset_episodes=32,\n",
+    "    rollout_steps=10,\n",
+    "    collection_policy=\"heuristic\",\n",
+    "    reward_backend=\"local\",\n",
+    "    domain_randomise=True,\n",
+    "    num_generations=4,\n",
+    "    max_completion_length=160,\n",
+    "    max_prompt_length=1280,\n",
+    "    per_device_train_batch_size=4,\n",
+    "    gradient_accumulation_steps=4,\n",
+    "    learning_rate=5e-6,\n",
+    "    num_train_epochs=1.0,\n",
+    "    logging_steps=1,\n",
+    "    save_steps=25,\n",
+    "    trust_remote_code=True,\n",
+    "    dry_run=False,\n",
+    "    seed=42,\n",
+    ")\n",
+    "\n",
+    "args"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d4c3d9c4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "preview_examples = build_prompt_examples(\n",
+    "    dataset_episodes=1,\n",
+    "    rollout_steps=args.rollout_steps,\n",
+    "    collection_policy=args.collection_policy,\n",
+    "    scenario_names=[\"cardiac_disease_de\"],\n",
+    "    seed=args.seed,\n",
+    "    domain_randomise=args.domain_randomise,\n",
+    ")\n",
+    "\n",
+    "print(preview_examples[0][\"prompt\"][:3500])\n",
+    "print(\"\\nReference action:\\n\", preview_examples[0][\"reference_action\"])\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "647663dd",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optional smoke test before a full run.\n",
+    "dry_run_args = make_training_args(**{**vars(args), \"dry_run\": True})\n",
+    "dry_run_result = run_training(dry_run_args)\n",
+    "len(dry_run_result[\"examples\"])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5f29f456",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from IPython.display import Image, display\n",
+    "\n",
+    "train_result = run_training(args)\n",
+    "for name, plot_path in train_result[\"plot_paths\"].items():\n",
+    "    print(name, plot_path)\n",
+    "    display(Image(filename=plot_path))"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
training_script.py CHANGED
@@ -1,4 +1,4 @@
1
- """Train a planner with TRL GRPO and OpenEnv rewards."""
2
 
3
  from __future__ import annotations
4
 
@@ -21,29 +21,53 @@ from models import (
21
  from server.hackathon_environment import BioExperimentEnvironment
22
  from server.tasks.scenarios import SCENARIO_LIBRARY
23
 
24
- DEFAULT_MODEL_ID = "Qwen/Qwen3.5-0.8B"
25
  DEFAULT_OUTPUT_DIR = "training/grpo-output"
26
  DEFAULT_BASE_URL = "http://localhost:8000"
 
27
  INVALID_ACTION_PENALTY = -2.0
28
  ENVIRONMENT_ERROR_PENALTY = -4.0
29
 
30
  SYSTEM_PROMPT = build_agent_system_prompt()
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  HEURISTIC_SEQUENCE = [
33
  ActionType.COLLECT_SAMPLE,
 
34
  ActionType.PREPARE_LIBRARY,
35
  ActionType.SEQUENCE_CELLS,
36
  ActionType.RUN_QC,
37
  ActionType.FILTER_DATA,
38
  ActionType.NORMALIZE_DATA,
 
39
  ActionType.CLUSTER_CELLS,
40
  ActionType.DIFFERENTIAL_EXPRESSION,
41
  ActionType.PATHWAY_ENRICHMENT,
42
  ActionType.MARKER_SELECTION,
 
 
43
  ActionType.SYNTHESIZE_CONCLUSION,
44
  ]
45
 
46
- VALID_ACTION_TYPES = {action.value for action in ActionType}
47
 
48
 
49
  def compact_preview(value: Any, max_chars: int = 160) -> str:
@@ -129,7 +153,11 @@ def build_argument_parser() -> argparse.ArgumentParser:
129
  help="Enable domain randomisation while building prompts and local rewards.",
130
  )
131
  parser.add_argument("--num-generations", type=int, default=2)
132
- parser.add_argument("--max-completion-length", type=int, default=220)
 
 
 
 
133
  parser.add_argument("--max-prompt-length", type=int, default=768)
134
  parser.add_argument("--per-device-train-batch-size", type=int, default=2)
135
  parser.add_argument("--gradient-accumulation-steps", type=int, default=1)
@@ -197,23 +225,42 @@ def format_observation(obs: ExperimentObservation) -> str:
197
  if context:
198
  parts.append(context)
199
  if obs.pipeline_history:
200
- parts.append("History:")
201
- for step in obs.pipeline_history[-5:]:
 
202
  tag = "OK" if step.success else "FAIL"
203
- line = f" [{tag}] {step.action_type.value}: {step.output_summary[:100]}"
204
- if step.parameters:
205
- line += f" | params={compact_preview(step.parameters, 120)}"
 
206
  parts.append(line)
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  if obs.latest_output and obs.latest_output.data:
208
  parts.append(
209
- f"Latest output data: {compact_preview(obs.latest_output.data, 200)}"
210
  )
211
  if obs.rule_violations:
212
- parts.append(f"Violations: {obs.rule_violations}")
213
  if obs.discovered_markers:
214
- parts.append(f"Markers: {obs.discovered_markers[:5]}")
215
  if obs.candidate_mechanisms:
216
- parts.append(f"Mechanisms: {obs.candidate_mechanisms[:5]}")
 
 
 
 
217
  return "\n".join(parts)
218
 
219
 
@@ -251,6 +298,7 @@ def default_comparison_name(conditions: Sequence[str]) -> str:
251
  def build_experiment_action(
252
  action_type: ActionType,
253
  discovered_markers: Sequence[str],
 
254
  conditions: Sequence[str],
255
  ) -> ExperimentAction:
256
  method = None
@@ -260,9 +308,27 @@ def build_experiment_action(
260
  if action_type == ActionType.COLLECT_SAMPLE:
261
  parameters = {"n_samples": 6}
262
  justification = "Collect enough samples to start the experiment."
 
 
 
 
 
 
263
  elif action_type == ActionType.PREPARE_LIBRARY:
264
  method = "10x_chromium"
265
  justification = "Prepare a single-cell library for sequencing."
 
 
 
 
 
 
 
 
 
 
 
 
266
  elif action_type == ActionType.SEQUENCE_CELLS:
267
  method = "NovaSeq"
268
  justification = "Generate reads for downstream single-cell analysis."
@@ -275,6 +341,9 @@ def build_experiment_action(
275
  elif action_type == ActionType.NORMALIZE_DATA:
276
  method = "scanpy.pp.normalize_total"
277
  justification = "Normalize counts for comparable expression profiles."
 
 
 
278
  elif action_type == ActionType.CLUSTER_CELLS:
279
  method = "scanpy.tl.leiden"
280
  justification = "Resolve cell states before interpretation."
@@ -291,20 +360,34 @@ def build_experiment_action(
291
  elif action_type == ActionType.MARKER_SELECTION:
292
  method = "scanpy.tl.rank_genes_groups"
293
  justification = "Nominate marker genes for validation."
294
  elif action_type == ActionType.VALIDATE_MARKER:
295
  method = "qPCR"
296
  parameters = {"marker": discovered_markers[0] if discovered_markers else "SPP1"}
297
  justification = "Validate the strongest discovered marker."
298
  elif action_type == ActionType.SYNTHESIZE_CONCLUSION:
299
  top = list(discovered_markers[:5]) if discovered_markers else []
300
  parameters = {
301
  "claims": [{
302
  "top_markers": top,
303
- "causal_mechanisms": [],
304
- "predicted_pathways": {},
305
  "confidence": 0.6,
306
- "claim_type": "correlational",
307
- "claim": "",
308
  }],
309
  }
310
  justification = "Summarize the current evidence into a conclusion."
@@ -375,6 +458,7 @@ def build_prompt_examples(
375
  [action.action_type for action in history_actions],
376
  ),
377
  discovered_markers=obs.discovered_markers,
 
378
  conditions=obs.task.conditions,
379
  )
380
  examples.append({
@@ -543,11 +627,145 @@ def normalize_optional_string(value: Any) -> Optional[str]:
543
  return compact_preview(value, 80)
544
 
545
 
546
  def parse_action_completion(text: str) -> Optional[ExperimentAction]:
547
  payload = extract_json_object(text)
548
  if payload is not None:
549
- action_type = get_payload_value(payload, "action_type")
550
- if action_type not in VALID_ACTION_TYPES:
551
  return None
552
 
553
  parameters = get_payload_value(payload, "parameters", "params") or {}
@@ -584,8 +802,8 @@ def parse_action_completion(text: str) -> Optional[ExperimentAction]:
584
  if not action_match:
585
  return None
586
 
587
- action_type = action_match.group(1).strip()
588
- if action_type not in VALID_ACTION_TYPES:
589
  return None
590
 
591
  method_match = re.search(
@@ -733,6 +951,7 @@ class OpenEnvReward:
733
  obs = env.step(previous_action)
734
  if obs.done:
735
  return float(obs.reward)
 
736
  obs = env.step(action)
737
  return float(obs.reward)
738
 
@@ -1081,7 +1300,7 @@ def generate_action_with_model(
1081
  tokenizer: Any,
1082
  prompt_or_observation: str | ExperimentObservation,
1083
  *,
1084
- max_new_tokens: int = 220,
1085
  temperature: float = 0.2,
1086
  top_p: float = 0.9,
1087
  do_sample: bool = True,
@@ -1114,6 +1333,8 @@ def generate_action_with_model(
1114
  new_tokens = output_ids[0][prompt_tokens:]
1115
  response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
1116
  action = parse_action_completion(response_text)
 
 
1117
  return {
1118
  "prompt": prompt,
1119
  "response_text": response_text,
 
1
+ """Train a self-driving lab planner with TRL GRPO and OpenEnv rewards."""
2
 
3
  from __future__ import annotations
4
 
 
21
  from server.hackathon_environment import BioExperimentEnvironment
22
  from server.tasks.scenarios import SCENARIO_LIBRARY
23
 
24
+ DEFAULT_MODEL_ID = "Qwen/Qwen3.5-4B"
25
  DEFAULT_OUTPUT_DIR = "training/grpo-output"
26
  DEFAULT_BASE_URL = "http://localhost:8000"
27
+ DEFAULT_COMPLETION_TOKEN_BUDGET = 160
28
  INVALID_ACTION_PENALTY = -2.0
29
  ENVIRONMENT_ERROR_PENALTY = -4.0
30
 
31
  SYSTEM_PROMPT = build_agent_system_prompt()
32
 
33
+ ACTION_TYPES = [action.value for action in ActionType]
34
+ ACTION_TYPE_ALIASES = {
35
+ "collect_samples": ActionType.COLLECT_SAMPLE.value,
36
+ "collect_sample_from_bone_marrow": ActionType.COLLECT_SAMPLE.value,
37
+ "collect_samples_from_bone_marrow": ActionType.COLLECT_SAMPLE.value,
38
+ "prepare_sc_library": ActionType.PREPARE_LIBRARY.value,
39
+ "sequence_single_cells": ActionType.SEQUENCE_CELLS.value,
40
+ "qc": ActionType.RUN_QC.value,
41
+ "run_quality_control": ActionType.RUN_QC.value,
42
+ "cluster": ActionType.CLUSTER_CELLS.value,
43
+ "de_analysis": ActionType.DIFFERENTIAL_EXPRESSION.value,
44
+ "differential_expression_analysis": ActionType.DIFFERENTIAL_EXPRESSION.value,
45
+ "trajectory_inference": ActionType.TRAJECTORY_ANALYSIS.value,
46
+ "infer_trajectory": ActionType.TRAJECTORY_ANALYSIS.value,
47
+ "network_inference": ActionType.REGULATORY_NETWORK_INFERENCE.value,
48
+ "select_markers": ActionType.MARKER_SELECTION.value,
49
+ "final_conclusion": ActionType.SYNTHESIZE_CONCLUSION.value,
50
+ }
51
+
52
  HEURISTIC_SEQUENCE = [
53
  ActionType.COLLECT_SAMPLE,
54
+ ActionType.SELECT_COHORT,
55
  ActionType.PREPARE_LIBRARY,
56
  ActionType.SEQUENCE_CELLS,
57
  ActionType.RUN_QC,
58
  ActionType.FILTER_DATA,
59
  ActionType.NORMALIZE_DATA,
60
+ ActionType.INTEGRATE_BATCHES,
61
  ActionType.CLUSTER_CELLS,
62
  ActionType.DIFFERENTIAL_EXPRESSION,
63
  ActionType.PATHWAY_ENRICHMENT,
64
  ActionType.MARKER_SELECTION,
65
+ ActionType.TRAJECTORY_ANALYSIS,
66
+ ActionType.REGULATORY_NETWORK_INFERENCE,
67
  ActionType.SYNTHESIZE_CONCLUSION,
68
  ]
69
 
70
+ VALID_ACTION_TYPES = set(ACTION_TYPES)
71
 
72
 
73
  def compact_preview(value: Any, max_chars: int = 160) -> str:
 
153
  help="Enable domain randomisation while building prompts and local rewards.",
154
  )
155
  parser.add_argument("--num-generations", type=int, default=2)
156
+ parser.add_argument(
157
+ "--max-completion-length",
158
+ type=int,
159
+ default=DEFAULT_COMPLETION_TOKEN_BUDGET,
160
+ )
161
  parser.add_argument("--max-prompt-length", type=int, default=768)
162
  parser.add_argument("--per-device-train-batch-size", type=int, default=2)
163
  parser.add_argument("--gradient-accumulation-steps", type=int, default=1)
 
225
  if context:
226
  parts.append(context)
227
  if obs.pipeline_history:
228
+ last5 = obs.pipeline_history[-5:]
229
+ parts.append("Recent history:")
230
+ for step in last5:
231
  tag = "OK" if step.success else "FAIL"
232
+ line = f" [{tag}] {step.action_type.value}"
233
+ if step.method:
234
+ line += f" ({step.method})"
235
+ line += f": {step.output_summary[:80]}"
236
  parts.append(line)
237
+ completed = {
238
+ step.action_type for step in obs.pipeline_history if step.success
239
+ }
240
+ if completed:
241
+ parts.append(
242
+ "Completed steps (do NOT repeat): "
243
+ + ", ".join(sorted(action.value for action in completed))
244
+ )
245
+ remaining = [
246
+ action.value for action in HEURISTIC_SEQUENCE if action not in completed
247
+ ]
248
+ if remaining:
249
+ parts.append(f"Remaining steps (choose one): {', '.join(remaining)}")
250
  if obs.latest_output and obs.latest_output.data:
251
  parts.append(
252
+ f"Latest data: {compact_preview(obs.latest_output.data, 200)}"
253
  )
254
  if obs.rule_violations:
255
+ parts.append(f"VIOLATIONS: {obs.rule_violations}")
256
  if obs.discovered_markers:
257
+ parts.append(f"Markers found so far: {obs.discovered_markers[:5]}")
258
  if obs.candidate_mechanisms:
259
+ parts.append(f"Candidate mechanisms: {obs.candidate_mechanisms[:5]}")
260
+ parts.append(
261
+ 'Output ONLY a single JSON object with these exact keys, no comments, no extra text:\n'
262
+ '{"action_type": "<one of the remaining steps>", "method": null, "parameters": {}, "justification": "<why>", "confidence": 0.8}'
263
+ )
264
  return "\n".join(parts)
265
 
266
 
 
298
  def build_experiment_action(
299
  action_type: ActionType,
300
  discovered_markers: Sequence[str],
301
+ candidate_mechanisms: Sequence[str],
302
  conditions: Sequence[str],
303
  ) -> ExperimentAction:
304
  method = None
 
308
  if action_type == ActionType.COLLECT_SAMPLE:
309
  parameters = {"n_samples": 6}
310
  justification = "Collect enough samples to start the experiment."
311
+ elif action_type == ActionType.SELECT_COHORT:
312
+ parameters = {
313
+ "comparison": default_comparison_name(conditions),
314
+ "conditions": list(conditions[:2]) or ["disease", "healthy"],
315
+ }
316
+ justification = "Define the cohort split before committing to downstream analysis."
317
  elif action_type == ActionType.PREPARE_LIBRARY:
318
  method = "10x_chromium"
319
  justification = "Prepare a single-cell library for sequencing."
320
+ elif action_type == ActionType.CULTURE_CELLS:
321
+ method = "organoid_culture"
322
+ parameters = {"duration_days": 7}
323
+ justification = "Expand viable cells before a perturbation or profiling step."
324
+ elif action_type == ActionType.PERTURB_GENE:
325
+ method = "CRISPRi"
326
+ parameters = {"target_gene": candidate_mechanisms[0] if candidate_mechanisms else "STAT3"}
327
+ justification = "Test whether a candidate regulator causally shifts cell state."
328
+ elif action_type == ActionType.PERTURB_COMPOUND:
329
+ method = "small_molecule_screen"
330
+ parameters = {"compound": candidate_mechanisms[0] if candidate_mechanisms else "TGFb_inhibitor"}
331
+ justification = "Probe the pathway hypothesis with a targeted compound perturbation."
332
  elif action_type == ActionType.SEQUENCE_CELLS:
333
  method = "NovaSeq"
334
  justification = "Generate reads for downstream single-cell analysis."
 
341
  elif action_type == ActionType.NORMALIZE_DATA:
342
  method = "scanpy.pp.normalize_total"
343
  justification = "Normalize counts for comparable expression profiles."
344
+ elif action_type == ActionType.INTEGRATE_BATCHES:
345
+ method = "scanorama.integrate"
346
+ justification = "Correct batch effects before comparing cellular programs."
347
  elif action_type == ActionType.CLUSTER_CELLS:
348
  method = "scanpy.tl.leiden"
349
  justification = "Resolve cell states before interpretation."
 
360
  elif action_type == ActionType.MARKER_SELECTION:
361
  method = "scanpy.tl.rank_genes_groups"
362
  justification = "Nominate marker genes for validation."
363
+ elif action_type == ActionType.REGULATORY_NETWORK_INFERENCE:
364
+ method = "pySCENIC"
365
+ justification = "Infer upstream regulators behind the observed state changes."
366
  elif action_type == ActionType.VALIDATE_MARKER:
367
  method = "qPCR"
368
  parameters = {"marker": discovered_markers[0] if discovered_markers else "SPP1"}
369
  justification = "Validate the strongest discovered marker."
370
+ elif action_type == ActionType.DESIGN_FOLLOWUP:
371
+ method = "followup_plan"
372
+ parameters = {"priority_hypothesis": candidate_mechanisms[0] if candidate_mechanisms else "fibrotic_activation"}
373
+ justification = "Propose the next experiment to disambiguate remaining uncertainty."
374
+ elif action_type == ActionType.REQUEST_SUBAGENT_REVIEW:
375
+ method = "peer_review"
376
+ parameters = {"focus": "experimental_design"}
377
+ justification = "Request a review of the current self-driving lab plan."
378
  elif action_type == ActionType.SYNTHESIZE_CONCLUSION:
379
  top = list(discovered_markers[:5]) if discovered_markers else []
380
  parameters = {
381
  "claims": [{
382
  "top_markers": top,
383
+ "causal_mechanisms": list(candidate_mechanisms[:5]),
384
+ "predicted_pathways": {
385
+ mechanism: 0.6
386
+ for mechanism in list(candidate_mechanisms[:3])
387
+ },
388
  "confidence": 0.6,
389
+ "claim_type": "causal" if candidate_mechanisms else "correlational",
390
+ "claim": f"Synthesis for {default_comparison_name(conditions)}.",
391
  }],
392
  }
393
  justification = "Summarize the current evidence into a conclusion."
 
458
  [action.action_type for action in history_actions],
459
  ),
460
  discovered_markers=obs.discovered_markers,
461
+ candidate_mechanisms=obs.candidate_mechanisms,
462
  conditions=obs.task.conditions,
463
  )
464
  examples.append({
 
627
  return compact_preview(value, 80)
628
 
629
 
630
+ def normalize_action_type(raw_action_type: Any) -> Optional[str]:
631
+ if not isinstance(raw_action_type, str):
632
+ return None
633
+
634
+ candidate = raw_action_type.strip().lower()
635
+ if candidate in ACTION_TYPES:
636
+ return candidate
637
+ if candidate in ACTION_TYPE_ALIASES:
638
+ return ACTION_TYPE_ALIASES[candidate]
639
+
640
+ candidate = re.sub(r"[^a-z0-9]+", "_", candidate).strip("_")
641
+ if candidate in ACTION_TYPES:
642
+ return candidate
643
+ if candidate in ACTION_TYPE_ALIASES:
644
+ return ACTION_TYPE_ALIASES[candidate]
645
+
646
+ heuristics = [
647
+ (("collect", "sample"), ActionType.COLLECT_SAMPLE.value),
648
+ (("cohort",), ActionType.SELECT_COHORT.value),
649
+ (("library",), ActionType.PREPARE_LIBRARY.value),
650
+ (("culture",), ActionType.CULTURE_CELLS.value),
651
+ (("perturb", "gene"), ActionType.PERTURB_GENE.value),
652
+ (("perturb", "compound"), ActionType.PERTURB_COMPOUND.value),
653
+ (("sequence",), ActionType.SEQUENCE_CELLS.value),
654
+ (("qc",), ActionType.RUN_QC.value),
655
+ (("quality", "control"), ActionType.RUN_QC.value),
656
+ (("filter",), ActionType.FILTER_DATA.value),
657
+ (("normal",), ActionType.NORMALIZE_DATA.value),
658
+ (("integrat", "batch"), ActionType.INTEGRATE_BATCHES.value),
659
+ (("cluster",), ActionType.CLUSTER_CELLS.value),
660
+ (("differential", "expression"), ActionType.DIFFERENTIAL_EXPRESSION.value),
661
+ (("pathway",), ActionType.PATHWAY_ENRICHMENT.value),
662
+ (("trajectory",), ActionType.TRAJECTORY_ANALYSIS.value),
663
+ (("network",), ActionType.REGULATORY_NETWORK_INFERENCE.value),
664
+ (("marker",), ActionType.MARKER_SELECTION.value),
665
+ (("validat", "marker"), ActionType.VALIDATE_MARKER.value),
666
+ (("followup",), ActionType.DESIGN_FOLLOWUP.value),
667
+ (("review",), ActionType.REQUEST_SUBAGENT_REVIEW.value),
668
+ (("conclusion",), ActionType.SYNTHESIZE_CONCLUSION.value),
669
+ ]
670
+ for fragments, normalized in heuristics:
671
+ if all(fragment in candidate for fragment in fragments):
672
+ return normalized
673
+ return None
674
+
675
+
676
+ def _unique_nonempty(items: Sequence[Any], limit: int = 5) -> List[str]:
677
+ seen: set[str] = set()
678
+ result: List[str] = []
679
+ for raw in items:
680
+ value = normalize_optional_string(raw)
681
+ if not value:
682
+ continue
683
+ key = value.upper()
684
+ if key in seen:
685
+ continue
686
+ seen.add(key)
687
+ result.append(value)
688
+ if len(result) >= limit:
689
+ break
690
+ return result
691
+
692
+
693
+ def _infer_conclusion_evidence(
694
+ obs: ExperimentObservation,
695
+ ) -> Tuple[List[str], List[str], Dict[str, float]]:
696
+ top_markers = _unique_nonempty(list(obs.discovered_markers), limit=5)
697
+ causal_mechanisms = _unique_nonempty(list(obs.candidate_mechanisms), limit=5)
698
+ predicted_pathways: Dict[str, float] = {}
699
+
700
+ for output in reversed(obs.all_outputs):
701
+ if not output.success:
702
+ continue
703
+ data = output.data or {}
704
+
705
+ if not top_markers:
706
+ markers = data.get("markers", [])
707
+ if isinstance(markers, list):
708
+ top_markers = _unique_nonempty(markers, limit=5)
709
+ if not causal_mechanisms:
710
+ regulators = data.get("top_regulators", [])
711
+ if isinstance(regulators, list):
712
+ causal_mechanisms = _unique_nonempty(regulators, limit=5)
713
+ if not predicted_pathways:
714
+ for item in data.get("top_pathways", []):
715
+ if not isinstance(item, dict):
716
+ continue
717
+ pathway = normalize_optional_string(item.get("pathway"))
718
+ score = item.get("score")
719
+ if pathway and isinstance(score, (int, float)):
720
+ predicted_pathways[pathway] = float(score)
721
+ if len(predicted_pathways) >= 5:
722
+ break
723
+ if top_markers and causal_mechanisms and predicted_pathways:
724
+ break
725
+
726
+ return top_markers, causal_mechanisms, predicted_pathways
727
+
728
+
729
+ def ensure_conclusion_claims(
730
+ obs: ExperimentObservation,
731
+ action: ExperimentAction,
732
+ ) -> ExperimentAction:
733
+ if action.action_type != ActionType.SYNTHESIZE_CONCLUSION:
734
+ return action
735
+
736
+ parameters = dict(action.parameters or {})
737
+ raw_claims = parameters.get("claims")
738
+ if isinstance(raw_claims, list):
739
+ normalized_claims = [claim for claim in raw_claims if isinstance(claim, dict)]
740
+ if normalized_claims:
741
+ parameters["claims"] = normalized_claims
742
+ if parameters != action.parameters:
743
+ return action.model_copy(update={"parameters": parameters})
744
+ return action
745
+
746
+ top_markers, causal_mechanisms, predicted_pathways = _infer_conclusion_evidence(obs)
747
+ claim_type = "causal" if causal_mechanisms else "correlational"
748
+ conditions = " vs ".join(obs.task.conditions[:2]) if obs.task.conditions else "the task conditions"
749
+ claim = action.justification or f"Final synthesis for {conditions}."
750
+
751
+ parameters["claims"] = [{
752
+ "top_markers": top_markers,
753
+ "causal_mechanisms": causal_mechanisms,
754
+ "predicted_pathways": predicted_pathways,
755
+ "confidence": action.confidence,
756
+ "claim_type": claim_type,
757
+ "claim": claim,
758
+ }]
759
+ if not action.justification:
760
+ action = action.model_copy(update={"justification": claim})
761
+ return action.model_copy(update={"parameters": parameters})
762
+
763
+
764
  def parse_action_completion(text: str) -> Optional[ExperimentAction]:
765
  payload = extract_json_object(text)
766
  if payload is not None:
767
+ action_type = normalize_action_type(get_payload_value(payload, "action_type"))
768
+ if action_type is None:
769
  return None
770
 
771
  parameters = get_payload_value(payload, "parameters", "params") or {}
 
802
  if not action_match:
803
  return None
804
 
805
+ action_type = normalize_action_type(action_match.group(1))
806
+ if action_type is None:
807
  return None
808
 
809
  method_match = re.search(
 
951
  obs = env.step(previous_action)
952
  if obs.done:
953
  return float(obs.reward)
954
+ action = ensure_conclusion_claims(obs, action)
955
  obs = env.step(action)
956
  return float(obs.reward)
957
 
 
1300
  tokenizer: Any,
1301
  prompt_or_observation: str | ExperimentObservation,
1302
  *,
1303
+ max_new_tokens: int = DEFAULT_COMPLETION_TOKEN_BUDGET,
1304
  temperature: float = 0.2,
1305
  top_p: float = 0.9,
1306
  do_sample: bool = True,
 
1333
  new_tokens = output_ids[0][prompt_tokens:]
1334
  response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
1335
  action = parse_action_completion(response_text)
1336
+ if action is not None and isinstance(prompt_or_observation, ExperimentObservation):
1337
+ action = ensure_conclusion_claims(prompt_or_observation, action)
1338
  return {
1339
  "prompt": prompt,
1340
  "response_text": response_text,
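The alias-and-heuristic normalization added to `training_script.py` above can be exercised standalone. The sketch below is illustrative only: `CANONICAL` and `ALIASES` are tiny stand-in tables, not the full `ActionType` sets from the script, and `normalize` mirrors only the lookup-then-sanitize idea of `normalize_action_type`.

```python
import re

# Hypothetical stand-in tables (the real script derives these from ActionType).
CANONICAL = {"collect_sample", "run_qc", "cluster_cells"}
ALIASES = {"collect_samples": "collect_sample", "qc": "run_qc", "cluster": "cluster_cells"}

def normalize(raw):
    """Return a canonical action name, trying exact/alias lookup first,
    then a punctuation-insensitive retry, mirroring normalize_action_type."""
    if not isinstance(raw, str):
        return None
    candidate = raw.strip().lower()
    for name in (candidate, re.sub(r"[^a-z0-9]+", "_", candidate).strip("_")):
        if name in CANONICAL:
            return name
        if name in ALIASES:
            return ALIASES[name]
    return None

print(normalize("Collect Samples"))  # -> collect_sample
```

The second pass (regex sanitization) is what lets free-form model output like `"Run QC!"` still resolve to a valid action type.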
training_unsloth.py ADDED
@@ -0,0 +1,439 @@
1
+ """Train and run quantized self-driving lab models with Unsloth.
2
+
3
+ This keeps the same OpenEnv prompt + reward wiring as `training_script.py`,
4
+ but arranges the Unsloth path in the more typical pattern:
5
+ 1. patch GRPO support
6
+ 2. load a quantized model
7
+ 3. apply LoRA adapters
8
+ 4. train with an explicit OpenEnv reward function
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import argparse
14
+ import random
15
+ from pathlib import Path
16
+ from typing import Any, Dict, Optional, Sequence
17
+
18
+ import training_script as base
19
+
20
+ DEFAULT_OUTPUT_DIR = "training/grpo-unsloth-output"
21
+ DEFAULT_MAX_SEQ_LENGTH = 2048
22
+ DEFAULT_LORA_R = 16
23
+ DEFAULT_LORA_ALPHA = 16
24
+ DEFAULT_LORA_DROPOUT = 0.0
25
+ LORA_TARGET_MODULES = [
26
+ "q_proj",
27
+ "k_proj",
28
+ "v_proj",
29
+ "o_proj",
30
+ "gate_proj",
31
+ "up_proj",
32
+ "down_proj",
33
+ ]
34
+
35
+
36
+ def require_unsloth():
37
+ try:
38
+ from unsloth import FastLanguageModel, PatchFastRL
39
+ except ImportError as exc: # pragma: no cover - depends on optional extra
40
+ raise RuntimeError(
41
+ "Unsloth is not installed. Run `uv sync --extra train` "
42
+ "to install the H100/quantized training dependencies."
43
+ ) from exc
44
+ return FastLanguageModel, PatchFastRL
45
+
46
+
47
+ def _call_unsloth_from_pretrained(FastLanguageModel, **kwargs: Any):
48
+ for optional_key in ("fast_inference", "trust_remote_code"):
49
+ try:
50
+ return FastLanguageModel.from_pretrained(**kwargs)
51
+ except TypeError as exc:
52
+ if optional_key in kwargs and optional_key in str(exc):
53
+ kwargs = dict(kwargs)
54
+ kwargs.pop(optional_key, None)
55
+ continue
56
+ raise
57
+ return FastLanguageModel.from_pretrained(**kwargs)
58
+
59
+
60
+ def build_argument_parser() -> argparse.ArgumentParser:
61
+ parser = base.build_argument_parser()
62
+ parser.description = (
63
+ "Train a GRPO policy with Unsloth quantized loading for faster H100 runs."
64
+ )
65
+ parser.set_defaults(output_dir=DEFAULT_OUTPUT_DIR)
66
+ parser.add_argument(
67
+ "--max-seq-length",
68
+ type=int,
69
+ default=DEFAULT_MAX_SEQ_LENGTH,
70
+ help="Context length passed to Unsloth model loading.",
71
+ )
72
+ parser.add_argument(
73
+ "--disable-4bit",
74
+ action="store_true",
75
+ help="Disable 4-bit quantized loading and load the full-precision base weights instead.",
76
+ )
77
+ parser.add_argument(
78
+ "--disable-fast-inference",
79
+ action="store_true",
80
+ help="Disable Unsloth fast inference kernels where supported.",
81
+ )
82
+ parser.add_argument(
83
+ "--lora-r",
84
+ type=int,
85
+ default=DEFAULT_LORA_R,
86
+ help="LoRA rank used for the quantized GRPO policy.",
87
+ )
88
+ parser.add_argument(
89
+ "--lora-alpha",
90
+ type=int,
91
+ default=DEFAULT_LORA_ALPHA,
92
+ help="LoRA alpha used for the quantized GRPO policy.",
93
+ )
94
+ parser.add_argument(
95
+ "--lora-dropout",
96
+ type=float,
97
+ default=DEFAULT_LORA_DROPOUT,
98
+ help="LoRA dropout used for the quantized GRPO policy.",
99
+ )
100
+ parser.add_argument(
101
+ "--save-merged-16bit",
102
+ action="store_true",
103
+ help="Also export a merged 16-bit model after training if supported.",
104
+ )
105
+ return parser
106
+
107
+
108
+ def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
109
+ return build_argument_parser().parse_args(argv)
110
+
111
+
112
+ def make_training_args(**overrides: Any) -> argparse.Namespace:
113
+ parser = build_argument_parser()
114
+ defaults = vars(parser.parse_args([]))
115
+ unknown = sorted(set(overrides) - set(defaults))
116
+ if unknown:
117
+ raise ValueError(f"Unknown training args: {', '.join(unknown)}")
118
+ defaults.update(overrides)
119
+ return argparse.Namespace(**defaults)
120
+
121
+
122
+ def load_model_artifacts(
123
+ model_id: str,
124
+ *,
125
+ trust_remote_code: bool,
126
+ max_seq_length: int = DEFAULT_MAX_SEQ_LENGTH,
127
+ load_in_4bit: bool = True,
128
+ fast_inference: bool = True,
129
+ prepare_for_inference: bool = False,
130
+ ):
131
+ FastLanguageModel, _ = require_unsloth()
132
+ runtime = base.resolve_torch_runtime()
133
+
134
+ print(f"Loading Unsloth tokenizer+model for {model_id} ...")
135
+ model, tokenizer = _call_unsloth_from_pretrained(
136
+ FastLanguageModel,
137
+ model_name=model_id,
138
+ max_seq_length=max_seq_length,
139
+ dtype="auto",
140
+ load_in_4bit=load_in_4bit,
141
+ fast_inference=fast_inference,
142
+ trust_remote_code=trust_remote_code,
143
+ )
144
+ if tokenizer.pad_token is None and tokenizer.eos_token is not None:
145
+ tokenizer.pad_token = tokenizer.eos_token
146
+
147
+ if prepare_for_inference:
148
+ try:
149
+ FastLanguageModel.for_inference(model)
150
+ except AttributeError:
151
+ pass
152
+
153
+ device = getattr(model, "device", None)
154
+ if device is None:
155
+ try:
156
+ device = next(model.parameters()).device
157
+ except StopIteration:
158
+ device = runtime["device"]
159
+ print(f"Loaded model on device: {device}")
160
+ return tokenizer, model
161
+
162
+
163
+ def build_openenv_reward(args: argparse.Namespace) -> base.OpenEnvReward:
164
+ """Return the OpenEnv-compatible reward callable used by GRPO."""
165
+ return base.OpenEnvReward(
166
+ reward_backend=args.reward_backend,
167
+ base_url=args.base_url,
168
+ domain_randomise=args.domain_randomise,
169
+ )
170
+
171
+
172
+ def prepare_prompt_examples(args: argparse.Namespace) -> Dict[str, Any]:
173
+ """Build the OpenEnv rollout states that seed GRPO prompts."""
174
+ scenario_names = base.selected_scenarios(args.scenario_name)
175
+ examples = base.build_prompt_examples(
176
+ dataset_episodes=args.dataset_episodes,
177
+ rollout_steps=args.rollout_steps,
178
+ collection_policy=args.collection_policy,
179
+ scenario_names=scenario_names,
180
+ seed=args.seed,
181
+ domain_randomise=args.domain_randomise,
182
+ )
183
+ return {
184
+ "scenario_names": scenario_names,
185
+ "examples": examples,
186
+ }
187
+
188
+
189
+ def patch_unsloth_grpo():
190
+ """Patch TRL GRPO to use Unsloth's optimized kernels."""
191
+ FastLanguageModel, PatchFastRL = require_unsloth()
192
+ PatchFastRL("GRPO", FastLanguageModel)
193
+ return FastLanguageModel
194
+
195
+
196
+ def apply_lora_adapters(FastLanguageModel, model: Any, args: argparse.Namespace) -> Any:
197
+ """Apply LoRA adapters in the usual Unsloth configuration style."""
198
+ return FastLanguageModel.get_peft_model(
199
+ model,
200
+ r=args.lora_r,
201
+ target_modules=LORA_TARGET_MODULES,
202
+ lora_alpha=args.lora_alpha,
203
+ lora_dropout=args.lora_dropout,
204
+ bias="none",
205
+ use_gradient_checkpointing=True,
206
+ random_state=args.seed,
207
+ )
208
+
209
+
210
+ def build_grpo_config(
211
+ args: argparse.Namespace,
212
+ runtime: Dict[str, Any],
213
+ ):
214
+ from trl import GRPOConfig
215
+
216
+ return GRPOConfig(
217
+ output_dir=args.output_dir,
218
+ learning_rate=args.learning_rate,
219
+ per_device_train_batch_size=args.per_device_train_batch_size,
220
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
221
+ num_generations=args.num_generations,
222
+ max_completion_length=args.max_completion_length,
223
+ num_train_epochs=args.num_train_epochs,
224
+ logging_steps=args.logging_steps,
225
+ save_steps=args.save_steps,
226
+ bf16=runtime["bf16"],
227
+ fp16=runtime["fp16"],
228
+ report_to="none",
229
+ remove_unused_columns=False,
230
+ )
231
+
232
+
233
+ def build_unsloth_grpo_trainer(
234
+ *,
235
+ model: Any,
236
+ tokenizer: Any,
237
+ reward_func: Any,
238
+ train_dataset: Any,
239
+ args: argparse.Namespace,
240
+ runtime: Dict[str, Any],
241
+ ):
242
+ from trl import GRPOTrainer
243
+
244
+ config = build_grpo_config(args, runtime)
245
+ return GRPOTrainer(
246
+ model=model,
247
+ reward_funcs=reward_func,
248
+ args=config,
249
+ train_dataset=train_dataset,
250
+ processing_class=tokenizer,
251
+ )
252
+
253
+
254
+ def generate_action_with_model(
255
+ model: Any,
256
+ tokenizer: Any,
257
+ prompt_or_observation: str | base.ExperimentObservation,
258
+ *,
259
+ max_new_tokens: int = base.DEFAULT_COMPLETION_TOKEN_BUDGET,
260
+ temperature: float = 0.2,
261
+ top_p: float = 0.9,
262
+ do_sample: bool = True,
263
+ ) -> Dict[str, Any]:
264
+ import torch
265
+
266
+ if isinstance(prompt_or_observation, base.ExperimentObservation):
267
+ prompt = base.build_training_prompt(prompt_or_observation)
268
+ else:
269
+ prompt = str(prompt_or_observation)
270
+
271
+ model_device = getattr(model, "device", None)
272
+ if model_device is None:
273
+ try:
274
+ model_device = next(model.parameters()).device
275
+ except StopIteration:
276
+ model_device = base.resolve_torch_runtime()["device"]
277
+
278
+ inputs = tokenizer(prompt, return_tensors="pt")
279
+ inputs = {key: value.to(model_device) for key, value in inputs.items()}
280
+ prompt_tokens = inputs["input_ids"].shape[1]
281
+
282
+ generation_kwargs = {
283
+ "max_new_tokens": max_new_tokens,
284
+ "do_sample": do_sample,
285
+ "temperature": temperature,
286
+ "top_p": top_p,
287
+ "pad_token_id": tokenizer.pad_token_id,
288
+ }
289
+ with torch.no_grad():
290
+ output_ids = model.generate(**inputs, **generation_kwargs)
291
+
292
+ new_tokens = output_ids[0][prompt_tokens:]
293
+ response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
294
+ action = base.parse_action_completion(response_text)
295
+ if action is not None and isinstance(prompt_or_observation, base.ExperimentObservation):
296
+ action = base.ensure_conclusion_claims(prompt_or_observation, action)
297
+ return {
298
+ "prompt": prompt,
299
+ "response_text": response_text,
300
+ "action": action,
301
+ }
302
+
303
+
304
+ def run_training(args: argparse.Namespace) -> Dict[str, Any]:
305
+ random.seed(args.seed)
306
+ runtime = base.resolve_torch_runtime()
307
+
308
+ if args.load_model_only:
309
+ tokenizer, model = load_model_artifacts(
310
+ args.model_id,
311
+ trust_remote_code=args.trust_remote_code,
312
+ max_seq_length=args.max_seq_length,
313
+ load_in_4bit=not args.disable_4bit,
314
+ fast_inference=not args.disable_fast_inference,
315
+ prepare_for_inference=True,
316
+ )
317
+ device = getattr(model, "device", "unknown")
318
+ print(f"Unsloth model ready: {args.model_id}")
319
+ print(f"Tokenizer vocab size: {len(tokenizer)}")
320
+ print(f"Model device: {device}")
321
+ print(f"Runtime device name: {runtime['device_name']}")
322
+ return {
323
+ "args": args,
324
+ "runtime": runtime,
325
+ "tokenizer": tokenizer,
326
+ "model": model,
327
+ }
328
+
329
+ prompt_data = prepare_prompt_examples(args)
330
+ scenario_names = prompt_data["scenario_names"]
331
+ examples = prompt_data["examples"]
332
+ env_reward = build_openenv_reward(args)
333
+
334
+ if args.dry_run:
335
+ base.run_dry_run_preview(examples, env_reward, args.output_dir)
336
+ return {
337
+ "args": args,
338
+ "runtime": runtime,
339
+ "scenario_names": scenario_names,
340
+ "examples": examples,
341
+ "reward_fn": env_reward,
342
+ }
343
+
344
+ from datasets import Dataset
345
+
346
+ FastLanguageModel = patch_unsloth_grpo()
347
+ train_dataset = Dataset.from_list(examples)
348
+
349
+ # 1. Load model with Unsloth quantized loading.
350
+ tokenizer, model = load_model_artifacts(
351
+ args.model_id,
352
+ trust_remote_code=args.trust_remote_code,
353
+ max_seq_length=args.max_seq_length,
354
+ load_in_4bit=not args.disable_4bit,
355
+ fast_inference=not args.disable_fast_inference,
356
+ )
357
+ # 2. Apply LoRA adapters.
358
+ model = apply_lora_adapters(FastLanguageModel, model, args)
359
+
360
+ print(
361
+ f"Unsloth training runtime: device={runtime['device']} "
362
+ f"name={runtime['device_name']} "
363
+ f"dtype={runtime['dtype']} "
364
+ f"load_in_4bit={not args.disable_4bit}"
365
+ )
366
+ print(
367
+ "OpenEnv reward: "
368
+ f"backend={args.reward_backend} scenarios={len(scenario_names)} "
369
+ f"examples={len(examples)}"
370
+ )
371
+
372
+ # 3. Train with GRPO against the OpenEnv reward function.
373
+ trainer = build_unsloth_grpo_trainer(
374
+ model=model,
375
+ tokenizer=tokenizer,
376
+ reward_func=env_reward,
377
+ train_dataset=train_dataset,
378
+ args=args,
379
+ runtime=runtime,
380
+ )
381
+ trainer.train()
382
+ trainer.save_model(args.output_dir)
383
+ tokenizer.save_pretrained(args.output_dir)
384
+
385
+ if args.save_merged_16bit:
386
+ merged_dir = Path(args.output_dir) / "merged_16bit"
387
+ try:
388
+ model.save_pretrained_merged(
389
+ str(merged_dir),
390
+ tokenizer,
391
+ save_method="merged_16bit",
392
+ )
393
+ print(f"Saved merged 16-bit model to {merged_dir}")
394
+ except AttributeError:
395
+ print("Merged 16-bit export is not available in this Unsloth build; skipping.")
396
+
397
+ if args.push_to_hub:
398
+ from huggingface_hub import HfApi
399
+
400
+ api = HfApi()
401
+ api.create_repo(repo_id=args.push_to_hub, repo_type="model", exist_ok=True)
402
+ print(f"Pushing model to HuggingFace Hub: {args.push_to_hub}")
403
+ api.upload_folder(
404
+ folder_path=args.output_dir,
405
+ repo_id=args.push_to_hub,
406
+ repo_type="model",
407
+ create_pr=False,
408
+ )
409
+ print(f"Model pushed to https://huggingface.co/{args.push_to_hub}")
410
+
411
+ plot_paths = base.save_training_plots(
412
+ trainer.state.log_history,
413
+ args.output_dir,
414
+ metric_key=args.plot_metric_key,
415
+ )
416
+ print("Saved training plots:")
417
+ for plot_name, plot_path in plot_paths.items():
418
+ print(f" - {plot_name}: {plot_path}")
419
+
420
+ return {
421
+ "args": args,
422
+ "runtime": runtime,
423
+ "scenario_names": scenario_names,
424
+ "examples": examples,
425
+ "reward_fn": env_reward,
426
+ "train_dataset": train_dataset,
427
+ "tokenizer": tokenizer,
428
+ "model": model,
429
+ "trainer": trainer,
430
+ "plot_paths": plot_paths,
431
+ }
432
+
433
+
434
+ def main() -> None:
435
+ run_training(parse_args())
436
+
437
+
438
+ if __name__ == "__main__":
439
+ main()
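The `_call_unsloth_from_pretrained` helper in `training_unsloth.py` retries model loading after dropping optional keyword arguments that older Unsloth builds reject. That fallback pattern can be sketched standalone; here `fake_from_pretrained` is a hypothetical stand-in for `FastLanguageModel.from_pretrained`, used only to demonstrate the retry behaviour.

```python
def fake_from_pretrained(**kwargs):
    # Stand-in loader: pretend this library version does not accept fast_inference.
    if "fast_inference" in kwargs:
        raise TypeError("unexpected keyword argument 'fast_inference'")
    return ("model", kwargs)

def call_with_fallback(loader, **kwargs):
    """Retry the loader, dropping an optional kwarg when a TypeError names it."""
    for optional_key in ("fast_inference", "trust_remote_code"):
        try:
            return loader(**kwargs)
        except TypeError as exc:
            if optional_key in kwargs and optional_key in str(exc):
                kwargs = dict(kwargs)
                kwargs.pop(optional_key, None)
                continue
            raise
    return loader(**kwargs)

model, used = call_with_fallback(fake_from_pretrained, model_name="m", fast_inference=True)
print(sorted(used))  # -> ['model_name']
```

Note the fallback only forgives the key currently under consideration; a TypeError naming a different kwarg is re-raised, which matches the behaviour of the helper in the diff.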
uv.lock CHANGED
The diff for this file is too large to render. See raw diff