| """ |
| Experiment specification for the generic lab environment. |
| |
| Defines protocol presets, inventory, rewards, and outcome model so LabEnv can |
| simulate any experiment type (PCR, ELISA, etc.) from a single spec. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass, field |
| from typing import Any, Callable |
|
|
| import numpy as np |
|
|
|
|
@dataclass
class ExperimentSpec:
    """Declarative description of one experiment type (PCR, ELISA, ...).

    Bundles the protocol presets, inventory, limits, reward constants and the
    optional outcome callables for an experiment. The properties and the
    ``action_*`` helpers derive space sizes and flat action indices from
    these fields.

    Flat action index layout implied by the ``action_*`` helpers:

    * ``[0, num_presets)``: setup range, one index per preset
    * ``num_presets``: run assay
    * ``[num_presets + 1, num_presets + 1 + len(orderable_items))``: one
      order action per orderable item
    * the next index: wait
    * the last index: finish
    """

    # --- Required fields (their order is part of the constructor signature) ---

    # Short name for this experiment (e.g. 'pcr', 'elisa').
    name: str

    # Protocol presets the agent can choose from (e.g. temp/cycles/ratio for PCR).
    presets: list[dict[str, Any]]

    # Ordered inventory item names (tips, buffer, polymerase, samples, ...).
    inventory_items: list[str]

    # Subset of inventory_items that can be reordered; each gets an order action.
    orderable_items: list[str]

    # Starting count per inventory item.
    initial_inventory: dict[str, int]

    # For each orderable item: (quantity_per_order, cost_per_order).
    order_costs: dict[str, tuple[int, float]]

    # Possible assay outcomes, e.g. ['none', 'success', 'partial', 'fail'].
    result_labels: list[str]

    # --- Limits and starting budget ---
    max_steps: int = 30
    max_minutes: float = 240.0
    initial_budget: float = 500.0
    max_inventory: int = 20

    # --- Durations, in minutes ---
    assay_time_minutes: float = 20.0
    order_time_minutes: float = 5.0
    wait_minutes: float = 15.0

    # --- Reward constants; the two dicts are keyed by result label ---
    assay_penalty: float = -3.0
    time_penalty_per_min: float = -0.25
    no_success_penalty: float = -20.0
    immediate_result_reward: dict[str, float] = field(default_factory=dict)
    terminal_bonus: dict[str, float] = field(default_factory=dict)

    # --- Pluggable outcome model ---
    # rng -> hidden optimum dict.
    sample_hidden_optimum: Callable[[np.random.Generator], dict[str, Any]] | None = None
    # (hidden_optimum, preset_idx, presets, rng) -> result label.
    sample_assay_result: (
        Callable[
            [dict[str, Any], int, list[dict[str, Any]], np.random.Generator],
            str,
        ]
        | None
    ) = None

    # If set, (hidden_optimum, protocol_dict, rng) -> result label.
    # Enables run_assay_with_protocol().
    evaluate_custom_protocol: (
        Callable[
            [dict[str, Any], dict[str, Any], np.random.Generator],
            str,
        ]
        | None
    ) = None

    # Schema describing protocol params for codegen/LLM, e.g.
    # {"temp": {"type": "number", "description": "°C"}, ...}.
    protocol_param_schema: dict[str, Any] = field(default_factory=dict)

    @property
    def num_presets(self) -> int:
        """Number of protocol presets in this spec."""
        return len(self.presets)

    @property
    def num_actions(self) -> int:
        """Total number of flat action indices (see the layout in the class docstring)."""
        run_assay_slots = 1
        wait_and_finish_slots = 2
        order_slots = len(self.orderable_items)
        return self.num_presets + run_assay_slots + order_slots + wait_and_finish_slots

    @property
    def obs_dim(self) -> int:
        """Length of the flat observation vector.

        One entry per inventory item and per result label, plus three leading
        and three trailing entries whose meaning is defined by the
        environment rather than by this class.
        """
        leading = 3
        trailing = 3
        per_item = len(self.inventory_items)
        per_label = len(self.result_labels)
        return leading + per_item + per_label + trailing

    def action_setup_start(self) -> int:
        """Return the first index of the setup range (inclusive)."""
        return 0

    def action_setup_end(self) -> int:
        """Return the end of the setup range (exclusive)."""
        return self.num_presets

    def action_run_assay(self) -> int:
        """Return the index of the run-assay action, right after the setup range."""
        return self.num_presets

    def action_order_start(self) -> int:
        """Return the first index of the order range (inclusive)."""
        return 1 + self.num_presets

    def action_order_end(self) -> int:
        """Return the end of the order range (exclusive)."""
        order_slots = len(self.orderable_items)
        return 1 + self.num_presets + order_slots

    def action_wait(self) -> int:
        """Return the index of the wait action, which equals ``action_order_end()``."""
        order_slots = len(self.orderable_items)
        return 1 + self.num_presets + order_slots

    def action_finish(self) -> int:
        """Return the index of the finish action, the last flat action index."""
        order_slots = len(self.orderable_items)
        return 2 + self.num_presets + order_slots
|
|
|
|
| |
| |
| |
|
|
def _pcr_sample_hidden_optimum(rng: np.random.Generator) -> dict[str, Any]:
    """Draw a hidden PCR optimum: a jittered temperature, jittered cycle count and a ratio.

    Args:
        rng: Generator that supplies every random draw.

    Returns:
        Dict with float ``"temp"``, float ``"cycles"`` and str ``"ratio"``.
    """
    # The draws happen in a fixed order (temp, temp jitter, cycles, cycle
    # jitter, ratio) so a seeded generator always yields the same optimum.
    base_temp = rng.choice([55.0, 65.0, 72.0], p=[0.2, 0.5, 0.3])
    temp_jitter = rng.uniform(-3.0, 3.0)
    base_cycles = rng.choice([25, 35], p=[0.6, 0.4])
    cycle_jitter = rng.uniform(-2.0, 2.0)
    ratio = rng.choice(["conservative", "aggressive"], p=[0.6, 0.4])
    return {
        "temp": float(base_temp) + temp_jitter,
        "cycles": float(base_cycles) + cycle_jitter,
        "ratio": str(ratio),
    }
|
|
|
|
def _pcr_sample_assay_result(
    hidden: dict[str, Any],
    preset_idx: int,
    presets: list[dict[str, Any]],
    rng: np.random.Generator,
) -> str:
    """Sample an assay outcome for the chosen PCR preset.

    Args:
        hidden: Hidden optimum with keys ``"temp"``, ``"cycles"``, ``"ratio"``.
        preset_idx: Index into ``presets`` of the protocol being run.
        presets: Preset dicts with keys ``"temp"``, ``"cycles"``, ``"ratio"``.
        rng: Generator used for the single outcome draw.

    Returns:
        One of ``"success"``, ``"partial"`` or ``"fail"``.
    """
    chosen = presets[preset_idx]
    temp, n_cycles, ratio = (
        float(chosen["temp"]),
        float(chosen["cycles"]),
        str(chosen["ratio"]),
    )
    best_temp, best_cycles, best_ratio = hidden["temp"], hidden["cycles"], hidden["ratio"]

    # Each numeric score falls linearly from 1 (exact match) to 0 once the gap
    # reaches 20 degrees / 15 cycles; a ratio mismatch scales the product by 0.3.
    temp_score = 1.0 - min(abs(temp - best_temp) / 20.0, 1.0)
    cycle_score = 1.0 - min(abs(n_cycles - best_cycles) / 15.0, 1.0)
    ratio_score = 1.0 if ratio == best_ratio else 0.3
    closeness = temp_score * cycle_score * ratio_score

    success_prob = closeness ** 2
    partial_prob = closeness * (1.0 - closeness) * 0.8
    # "fail" takes whatever probability mass is left.
    weights = [success_prob, partial_prob, 1.0 - success_prob - partial_prob]
    outcome = rng.choice(["success", "partial", "fail"], p=weights)
    return str(outcome)
|
|
|
|
def _pcr_evaluate_custom_protocol(
    hidden: dict[str, Any],
    protocol: dict[str, Any],
    rng: np.random.Generator,
) -> str:
    """Score a free-form PCR protocol against the hidden optimum and sample an outcome.

    Args:
        hidden: Hidden optimum with keys ``"temp"``, ``"cycles"``, ``"ratio"``.
        protocol: Protocol settings. A key that is missing or explicitly
            ``None`` falls back to its default (temp 60.0, cycles 30, ratio
            "conservative"). Ratio text is matched case-insensitively by
            substring: anything containing "conservative" counts as
            conservative, every other text counts as aggressive.
        rng: Generator used for the single outcome draw.

    Returns:
        One of ``"success"``, ``"partial"`` or ``"fail"``.

    Raises:
        KeyError: If ``hidden`` lacks one of its three keys.
        TypeError, ValueError: If a numeric setting cannot be converted by
            ``float()``.
    """
    raw_temp = protocol.get("temp")
    raw_cycles = protocol.get("cycles")
    raw_ratio = protocol.get("ratio")

    # Treat an explicit None like a missing key; dict.get's default only
    # covers the missing case, so float(None) would otherwise raise.
    chosen_temp = 60.0 if raw_temp is None else float(raw_temp)
    chosen_cycles = 30.0 if raw_cycles is None else float(raw_cycles)
    ratio_text = "conservative" if raw_ratio is None else str(raw_ratio).strip().lower()
    chosen_ratio = "conservative" if "conservative" in ratio_text else "aggressive"

    opt_temp = hidden["temp"]
    opt_cycles = hidden["cycles"]
    opt_ratio = hidden["ratio"]

    temp_gap = abs(chosen_temp - opt_temp) / 20.0
    cycle_gap = abs(chosen_cycles - opt_cycles) / 15.0
    # min() would pass a NaN gap straight through and rng.choice would then
    # reject the NaN probabilities. Score NaN as maximally far instead, which
    # is how an infinite setting is already scored.
    temp_close = 0.0 if np.isnan(temp_gap) else 1.0 - min(temp_gap, 1.0)
    cycle_close = 0.0 if np.isnan(cycle_gap) else 1.0 - min(cycle_gap, 1.0)
    ratio_match = 1.0 if chosen_ratio == opt_ratio else 0.3
    closeness = temp_close * cycle_close * ratio_match

    p_success = closeness ** 2
    p_partial = closeness * (1.0 - closeness) * 0.8
    p_fail = 1.0 - p_success - p_partial  # remaining probability mass
    return str(
        rng.choice(["success", "partial", "fail"], p=[p_success, p_partial, p_fail])
    )
|
|
|
|
# Parameter schema for free-form PCR protocols: one entry per protocol key,
# each giving a type, a human-readable description and, for "ratio", the
# allowed enum values.
PCR_PROTOCOL_SCHEMA = {
    "temp": {
        "type": "number",
        "description": "Annealing temperature in °C (e.g. 55–72)",
    },
    "cycles": {
        "type": "integer",
        "description": "Number of PCR cycles (e.g. 25–40)",
    },
    "ratio": {
        "type": "string",
        "enum": ["conservative", "aggressive"],
        "description": "Reagent ratio",
    },
}
|
|
|
|
def pcr_experiment_spec() -> ExperimentSpec:
    """Assemble the default PCR spec.

    The preset grid is the cross product of three temperatures, two cycle
    counts and two reagent ratios (12 presets).

    Returns:
        A freshly built ``ExperimentSpec`` named ``"pcr"`` wired to the PCR
        outcome functions and ``PCR_PROTOCOL_SCHEMA``.
    """
    # The nested loops vary temp slowest and ratio fastest, the same order
    # itertools.product(temps, cycles, ratios) would give.
    preset_grid = [
        {"temp": temp, "cycles": n_cycles, "ratio": ratio}
        for temp in (55.0, 65.0, 72.0)
        for n_cycles in (25, 35)
        for ratio in ("conservative", "aggressive")
    ]

    starting_stock = {"tips": 10, "buffer": 10, "polymerase": 5, "samples": 8}
    # item -> (quantity_per_order, cost_per_order)
    reorder_terms = {
        "tips": (5, 10.0),
        "buffer": (5, 15.0),
        "polymerase": (3, 30.0),
    }
    per_result_reward = {"success": 15.0, "partial": 5.0, "fail": 0.0}
    end_of_episode_bonus = {"success": 60.0, "partial": 25.0}

    return ExperimentSpec(
        name="pcr",
        presets=preset_grid,
        inventory_items=["tips", "buffer", "polymerase", "samples"],
        orderable_items=["tips", "buffer", "polymerase"],
        initial_inventory=starting_stock,
        order_costs=reorder_terms,
        result_labels=["none", "success", "partial", "fail"],
        # Limits and starting budget.
        max_steps=30,
        max_minutes=240.0,
        initial_budget=500.0,
        max_inventory=20,
        # Durations, in minutes.
        assay_time_minutes=20.0,
        order_time_minutes=5.0,
        wait_minutes=15.0,
        # Reward constants.
        assay_penalty=-3.0,
        time_penalty_per_min=-0.25,
        no_success_penalty=-20.0,
        immediate_result_reward=per_result_reward,
        terminal_bonus=end_of_episode_bonus,
        # Outcome model and protocol schema.
        sample_hidden_optimum=_pcr_sample_hidden_optimum,
        sample_assay_result=_pcr_sample_assay_result,
        evaluate_custom_protocol=_pcr_evaluate_custom_protocol,
        protocol_param_schema=PCR_PROTOCOL_SCHEMA,
    )
|
|
|
|
| |
| |
| |
|
|
def _elisa_sample_hidden_optimum(rng: np.random.Generator) -> dict[str, Any]:
    """Draw a hidden ELISA optimum: jittered coating time, jittered temperature and a block.

    Args:
        rng: Generator that supplies every random draw.

    Returns:
        Dict with float ``"coating_hr"``, float ``"temp"`` and str ``"block"``.
    """
    # The draws happen in a fixed order (coating, coating jitter, temp, temp
    # jitter, block) so a seeded generator always yields the same optimum.
    base_coating = rng.choice([1.0, 2.0, 3.0], p=[0.3, 0.5, 0.2])
    coating_jitter = rng.uniform(-0.2, 0.2)
    base_temp = rng.choice([4.0, 37.0], p=[0.5, 0.5])
    temp_jitter = rng.uniform(-2.0, 2.0)
    block_agent = rng.choice(["bsa", "casein"], p=[0.6, 0.4])
    return {
        "coating_hr": float(base_coating) + coating_jitter,
        "temp": float(base_temp) + temp_jitter,
        "block": str(block_agent),
    }
|
|
|
|
def _elisa_sample_assay_result(
    hidden: dict[str, Any],
    preset_idx: int,
    presets: list[dict[str, Any]],
    rng: np.random.Generator,
) -> str:
    """Sample an assay outcome for the chosen ELISA preset.

    Args:
        hidden: Hidden optimum with keys ``"coating_hr"``, ``"temp"``, ``"block"``.
        preset_idx: Index into ``presets`` of the protocol being run.
        presets: Preset dicts with keys ``"coating_hr"``, ``"temp"``, ``"block"``.
        rng: Generator used for the single outcome draw.

    Returns:
        One of ``"success"``, ``"partial"`` or ``"fail"``.
    """
    chosen = presets[preset_idx]
    coating, temp, block_agent = (
        float(chosen["coating_hr"]),
        float(chosen["temp"]),
        str(chosen["block"]),
    )
    best_coating, best_temp, best_block = (
        hidden["coating_hr"],
        hidden["temp"],
        hidden["block"],
    )

    # Each numeric score falls linearly from 1 (exact match) to 0 once the gap
    # reaches 2 hours / 35 degrees; a block mismatch scales the product by 0.3.
    coating_score = 1.0 - min(abs(coating - best_coating) / 2.0, 1.0)
    temp_score = 1.0 - min(abs(temp - best_temp) / 35.0, 1.0)
    block_score = 1.0 if block_agent == best_block else 0.3
    closeness = coating_score * temp_score * block_score

    success_prob = closeness ** 2
    partial_prob = closeness * (1.0 - closeness) * 0.8
    # "fail" takes whatever probability mass is left.
    weights = [success_prob, partial_prob, 1.0 - success_prob - partial_prob]
    outcome = rng.choice(["success", "partial", "fail"], p=weights)
    return str(outcome)
|
|
|
|
def _elisa_evaluate_custom_protocol(
    hidden: dict[str, Any],
    protocol: dict[str, Any],
    rng: np.random.Generator,
) -> str:
    """Score a free-form ELISA protocol against the hidden optimum and sample an outcome.

    Args:
        hidden: Hidden optimum with keys ``"coating_hr"``, ``"temp"``, ``"block"``.
        protocol: Protocol settings. A key that is missing or explicitly
            ``None`` falls back to its default (coating_hr 2.0, temp 25.0,
            block "bsa"). Block text is matched case-insensitively by
            substring: anything containing "bsa" counts as bsa, every other
            text counts as casein.
        rng: Generator used for the single outcome draw.

    Returns:
        One of ``"success"``, ``"partial"`` or ``"fail"``.

    Raises:
        KeyError: If ``hidden`` lacks one of its three keys.
        TypeError, ValueError: If a numeric setting cannot be converted by
            ``float()``.
    """
    raw_coating = protocol.get("coating_hr")
    raw_temp = protocol.get("temp")
    raw_block = protocol.get("block")

    # Treat an explicit None like a missing key; dict.get's default only
    # covers the missing case, so float(None) would otherwise raise.
    c = 2.0 if raw_coating is None else float(raw_coating)
    t = 25.0 if raw_temp is None else float(raw_temp)
    b = "bsa" if raw_block is None else str(raw_block).strip().lower()
    block_clean = "bsa" if "bsa" in b else "casein"

    oc, ot, ob = hidden["coating_hr"], hidden["temp"], hidden["block"]

    coat_gap = abs(c - oc) / 2.0
    temp_gap = abs(t - ot) / 35.0
    # min() would pass a NaN gap straight through and rng.choice would then
    # reject the NaN probabilities. Score NaN as maximally far instead, which
    # is how an infinite setting is already scored.
    coat_close = 0.0 if np.isnan(coat_gap) else 1.0 - min(coat_gap, 1.0)
    temp_close = 0.0 if np.isnan(temp_gap) else 1.0 - min(temp_gap, 1.0)
    block_match = 1.0 if block_clean == ob else 0.3
    closeness = coat_close * temp_close * block_match

    p_success = closeness ** 2
    p_partial = closeness * (1.0 - closeness) * 0.8
    p_fail = 1.0 - p_success - p_partial  # remaining probability mass
    return str(
        rng.choice(["success", "partial", "fail"], p=[p_success, p_partial, p_fail])
    )
|
|
|
|
# Parameter schema for free-form ELISA protocols: one entry per protocol key,
# each giving a type, a human-readable description and, for "block", the
# allowed enum values.
ELISA_PROTOCOL_SCHEMA = {
    "coating_hr": {
        "type": "number",
        "description": "Coating time in hours (e.g. 1–3)",
    },
    "temp": {
        "type": "number",
        "description": "Incubation temperature °C (e.g. 4 or 37)",
    },
    "block": {
        "type": "string",
        "enum": ["bsa", "casein"],
        "description": "Blocking agent",
    },
}
|
|
|
|
def elisa_experiment_spec() -> ExperimentSpec:
    """Assemble the ELISA readout spec.

    Presets are the cross product of three coating times (hr), two
    temperatures (°C) and two blocking agents (12 presets). Inventory, limits
    and reward constants match the PCR spec so observation and action
    dimensions are the same.

    Returns:
        A freshly built ``ExperimentSpec`` named ``"elisa"`` wired to the
        ELISA outcome functions and ``ELISA_PROTOCOL_SCHEMA``.
    """
    # The nested loops vary coating time slowest and block fastest, the same
    # order itertools.product(coating_hrs, temps, blocks) would give.
    preset_grid = [
        {"coating_hr": coating, "temp": temp, "block": agent}
        for coating in (1.0, 2.0, 3.0)
        for temp in (4.0, 37.0)
        for agent in ("bsa", "casein")
    ]

    starting_stock = {"tips": 10, "buffer": 10, "polymerase": 5, "samples": 8}
    # item -> (quantity_per_order, cost_per_order)
    reorder_terms = {
        "tips": (5, 10.0),
        "buffer": (5, 15.0),
        "polymerase": (3, 30.0),
    }
    per_result_reward = {"success": 15.0, "partial": 5.0, "fail": 0.0}
    end_of_episode_bonus = {"success": 60.0, "partial": 25.0}

    return ExperimentSpec(
        name="elisa",
        presets=preset_grid,
        inventory_items=["tips", "buffer", "polymerase", "samples"],
        orderable_items=["tips", "buffer", "polymerase"],
        initial_inventory=starting_stock,
        order_costs=reorder_terms,
        result_labels=["none", "success", "partial", "fail"],
        # Limits and starting budget.
        max_steps=30,
        max_minutes=240.0,
        initial_budget=500.0,
        max_inventory=20,
        # Durations, in minutes.
        assay_time_minutes=20.0,
        order_time_minutes=5.0,
        wait_minutes=15.0,
        # Reward constants.
        assay_penalty=-3.0,
        time_penalty_per_min=-0.25,
        no_success_penalty=-20.0,
        immediate_result_reward=per_result_reward,
        terminal_bonus=end_of_episode_bonus,
        # Outcome model and protocol schema.
        sample_hidden_optimum=_elisa_sample_hidden_optimum,
        sample_assay_result=_elisa_sample_assay_result,
        evaluate_custom_protocol=_elisa_evaluate_custom_protocol,
        protocol_param_schema=ELISA_PROTOCOL_SCHEMA,
    )
|
|
|
|
| |
| |
| |
|
|
def get_spec_for_workflow(workflow_id: str) -> ExperimentSpec:
    """Build the experiment spec registered for ``workflow_id``.

    Args:
        workflow_id: Workflow identifier; ``"pcr-amplification"`` and
            ``"elisa-readout"`` are registered.

    Returns:
        A freshly built spec. Any unregistered ID yields the PCR spec.
    """
    spec_factories: dict[str, Callable[[], ExperimentSpec]] = {
        "pcr-amplification": pcr_experiment_spec,
        "elisa-readout": elisa_experiment_spec,
    }
    # Unknown IDs fall back to PCR rather than raising.
    build_spec = spec_factories.get(workflow_id, pcr_experiment_spec)
    return build_spec()
|
|