"""
Experiment specification for the generic lab environment.

Defines protocol presets, inventory, rewards, and outcome model so LabEnv can
simulate any experiment type (PCR, ELISA, etc.) from a single spec.
"""

from __future__ import annotations

import math
from dataclasses import dataclass, field
from itertools import product
from typing import Any, Callable

import numpy as np


@dataclass
class ExperimentSpec:
    """Specification for a single experiment type (PCR, ELISA, etc.).

    The environment uses this to build action/observation spaces and dynamics.
    Outcome logic is pluggable via sample_hidden_optimum and sample_assay_result.

    Discrete action indices are laid out contiguously in this order:
    one setup action per preset, run_assay, one order action per orderable
    item, wait, finish (see the ``action_*`` helpers below).
    """

    name: str
    """Short name for this experiment (e.g. 'pcr', 'elisa')."""
    presets: list[dict[str, Any]]
    """List of protocol presets the agent can choose (e.g. temp/cycles/ratio for PCR)."""
    inventory_items: list[str]
    """Ordered list of inventory item names (tips, buffer, polymerase, samples, ...)."""
    orderable_items: list[str]
    """Subset of inventory_items that can be reordered (each gets an order action)."""
    initial_inventory: dict[str, int]
    """Starting count per inventory item."""
    order_costs: dict[str, tuple[int, float]]
    """For each orderable item: (quantity_per_order, cost_per_order)."""
    result_labels: list[str]
    """Possible assay outcomes, e.g. ['none', 'success', 'partial', 'fail']."""

    # Limits
    max_steps: int = 30
    max_minutes: float = 240.0
    initial_budget: float = 500.0
    max_inventory: int = 20

    # Time costs
    assay_time_minutes: float = 20.0
    order_time_minutes: float = 5.0
    wait_minutes: float = 15.0

    # Rewards
    assay_penalty: float = -3.0
    time_penalty_per_min: float = -0.25
    no_success_penalty: float = -20.0
    immediate_result_reward: dict[str, float] = field(default_factory=dict)
    terminal_bonus: dict[str, float] = field(default_factory=dict)

    # Outcome model: callables that take (rng) or (hidden_state, preset_idx, presets, rng)
    sample_hidden_optimum: Callable[[np.random.Generator], dict[str, Any]] | None = None
    sample_assay_result: (
        Callable[
            [dict[str, Any], int, list[dict[str, Any]], np.random.Generator],
            str,
        ]
        | None
    ) = None

    # Custom protocol support: evaluate arbitrary protocol dict (for agent-generated protocols)
    evaluate_custom_protocol: (
        Callable[
            [dict[str, Any], dict[str, Any], np.random.Generator],
            str,
        ]
        | None
    ) = None
    """If set, (hidden_optimum, protocol_dict, rng) -> result label. Enables run_assay_with_protocol()."""
    protocol_param_schema: dict[str, Any] = field(default_factory=dict)
    """Schema describing protocol params for codegen/LLM: e.g. {"temp": {"type": "number", "description": "°C"}, ...}."""

    @property
    def num_presets(self) -> int:
        """Number of selectable protocol presets."""
        return len(self.presets)

    @property
    def num_actions(self) -> int:
        """Total size of the discrete action space."""
        return (
            self.num_presets
            + 1  # run_assay
            + len(self.orderable_items)
            + 2  # wait, finish
        )

    @property
    def obs_dim(self) -> int:
        """Length of the flat observation vector."""
        return (
            3  # step_index, elapsed_minutes, remaining_budget
            + len(self.inventory_items)
            + len(self.result_labels)
            + 3  # has_setup, current_preset_idx (norm), best_result_score
        )

    def action_setup_start(self) -> int:
        """First setup-preset action index (inclusive)."""
        return 0

    def action_setup_end(self) -> int:
        """End of the setup-preset action range (exclusive)."""
        return self.num_presets

    def action_run_assay(self) -> int:
        """Index of the run_assay action."""
        return self.num_presets

    def action_order_start(self) -> int:
        """First order action index (inclusive)."""
        return self.num_presets + 1

    def action_order_end(self) -> int:
        """End of the order action range (exclusive)."""
        return self.num_presets + 1 + len(self.orderable_items)

    def action_wait(self) -> int:
        """Index of the wait action (directly after the order range)."""
        return self.num_presets + 1 + len(self.orderable_items)

    def action_finish(self) -> int:
        """Index of the finish action (always the last action)."""
        return self.num_presets + 2 + len(self.orderable_items)


# ---------------------------------------------------------------------------
# Shared outcome-model helpers
# ---------------------------------------------------------------------------

_OUTCOME_LABELS = ("success", "partial", "fail")


def _linear_closeness(chosen: float, optimum: float, scale: float) -> float:
    """Return 1.0 at the optimum, falling linearly to 0.0 at distance >= scale.

    A NaN distance (e.g. a malformed agent-generated protocol) counts as
    maximally far instead of propagating NaN into the outcome probabilities,
    which would make ``Generator.choice`` raise.
    """
    distance = abs(chosen - optimum) / scale
    if math.isnan(distance):
        return 0.0
    return 1.0 - min(distance, 1.0)


def _sample_outcome(closeness: float, rng: np.random.Generator) -> str:
    """Draw 'success' / 'partial' / 'fail' from a closeness score in [0, 1].

    p_success = c**2, p_partial = 0.8 * c * (1 - c), p_fail = remainder.
    """
    p_success = closeness ** 2
    p_partial = closeness * (1.0 - closeness) * 0.8
    # Clamp so float rounding can never leave a negative remainder, which
    # Generator.choice rejects outright.
    p_fail = max(0.0, 1.0 - p_success - p_partial)
    return str(rng.choice(_OUTCOME_LABELS, p=[p_success, p_partial, p_fail]))


def _standard_lab_settings() -> dict[str, Any]:
    """Inventory, limit, timing and reward settings shared by PCR and ELISA.

    Sharing them keeps both specs at identical obs/action dimensions so one
    agent can run either. Fresh containers are built on every call so specs
    never alias each other's mutable state.
    """
    return {
        "inventory_items": ["tips", "buffer", "polymerase", "samples"],
        "orderable_items": ["tips", "buffer", "polymerase"],
        "initial_inventory": {"tips": 10, "buffer": 10, "polymerase": 5, "samples": 8},
        "order_costs": {
            "tips": (5, 10.0),
            "buffer": (5, 15.0),
            "polymerase": (3, 30.0),
        },
        "result_labels": ["none", "success", "partial", "fail"],
        "max_steps": 30,
        "max_minutes": 240.0,
        "initial_budget": 500.0,
        "max_inventory": 20,
        "assay_time_minutes": 20.0,
        "order_time_minutes": 5.0,
        "wait_minutes": 15.0,
        "assay_penalty": -3.0,
        "time_penalty_per_min": -0.25,
        "no_success_penalty": -20.0,
        "immediate_result_reward": {"success": 15.0, "partial": 5.0, "fail": 0.0},
        "terminal_bonus": {"success": 60.0, "partial": 25.0},
    }


# ---------------------------------------------------------------------------
# PCR experiment spec (default / backward compatibility)
# ---------------------------------------------------------------------------


def _pcr_sample_hidden_optimum(rng: np.random.Generator) -> dict[str, Any]:
    """Sample a hidden PCR optimum: jittered temp/cycles plus a ratio label."""
    temps = [55.0, 65.0, 72.0]
    cycles = [25, 35]
    ratios = ["conservative", "aggressive"]
    # Draw order matters for seeded reproducibility; keep it stable.
    opt_temp = float(rng.choice(temps, p=[0.2, 0.5, 0.3])) + rng.uniform(-3.0, 3.0)
    opt_cycles = float(rng.choice(cycles, p=[0.6, 0.4])) + rng.uniform(-2.0, 2.0)
    opt_ratio = str(rng.choice(ratios, p=[0.6, 0.4]))
    return {"temp": opt_temp, "cycles": opt_cycles, "ratio": opt_ratio}


def _pcr_closeness(
    temp: float, cycles: float, ratio: str, hidden: dict[str, Any]
) -> float:
    """Score in [0, 1] of how close a PCR protocol is to the hidden optimum."""
    temp_close = _linear_closeness(temp, hidden["temp"], 20.0)
    cycle_close = _linear_closeness(cycles, hidden["cycles"], 15.0)
    ratio_match = 1.0 if ratio == hidden["ratio"] else 0.3
    return temp_close * cycle_close * ratio_match


def _pcr_sample_assay_result(
    hidden: dict[str, Any],
    preset_idx: int,
    presets: list[dict[str, Any]],
    rng: np.random.Generator,
) -> str:
    """Sample an assay outcome for preset ``presets[preset_idx]``."""
    preset = presets[preset_idx]
    closeness = _pcr_closeness(
        float(preset["temp"]),
        float(preset["cycles"]),
        str(preset["ratio"]),
        hidden,
    )
    return _sample_outcome(closeness, rng)


def _pcr_evaluate_custom_protocol(
    hidden: dict[str, Any],
    protocol: dict[str, Any],
    rng: np.random.Generator,
) -> str:
    """Evaluate any protocol dict (temp, cycles, ratio) against hidden optimum.

    Missing keys fall back to temp=60.0, cycles=30, ratio='conservative'.
    Any ratio text containing 'conservative' (case-insensitive) counts as
    conservative; everything else is treated as aggressive.
    """
    chosen_temp = float(protocol.get("temp", 60.0))
    chosen_cycles = float(protocol.get("cycles", 30))
    ratio_text = str(protocol.get("ratio", "conservative")).strip().lower()
    chosen_ratio = "conservative" if "conservative" in ratio_text else "aggressive"
    closeness = _pcr_closeness(chosen_temp, chosen_cycles, chosen_ratio, hidden)
    return _sample_outcome(closeness, rng)


PCR_PROTOCOL_SCHEMA = {
    "temp": {"type": "number", "description": "Annealing temperature in °C (e.g. 55–72)"},
    "cycles": {"type": "integer", "description": "Number of PCR cycles (e.g. 25–40)"},
    "ratio": {"type": "string", "enum": ["conservative", "aggressive"], "description": "Reagent ratio"},
}


def pcr_experiment_spec() -> ExperimentSpec:
    """Build the default PCR experiment spec (same behaviour as original LabEnv)."""
    temps = [55.0, 65.0, 72.0]
    cycles = [25, 35]
    ratios = ["conservative", "aggressive"]
    presets = [
        {"temp": t, "cycles": c, "ratio": r}
        for t, c, r in product(temps, cycles, ratios)
    ]
    return ExperimentSpec(
        name="pcr",
        presets=presets,
        **_standard_lab_settings(),
        sample_hidden_optimum=_pcr_sample_hidden_optimum,
        sample_assay_result=_pcr_sample_assay_result,
        evaluate_custom_protocol=_pcr_evaluate_custom_protocol,
        protocol_param_schema=PCR_PROTOCOL_SCHEMA,
    )


# ---------------------------------------------------------------------------
# ELISA experiment spec (same obs/action shape as PCR for agent compatibility)
# ---------------------------------------------------------------------------


def _elisa_sample_hidden_optimum(rng: np.random.Generator) -> dict[str, Any]:
    """Sample a hidden ELISA optimum: jittered coating time/temp plus a block label."""
    coating_hrs = [1.0, 2.0, 3.0]
    temps = [4.0, 37.0]
    blocks = ["bsa", "casein"]
    # Draw order matters for seeded reproducibility; keep it stable.
    opt_coating = float(rng.choice(coating_hrs, p=[0.3, 0.5, 0.2])) + rng.uniform(-0.2, 0.2)
    opt_temp = float(rng.choice(temps, p=[0.5, 0.5])) + rng.uniform(-2.0, 2.0)
    opt_block = str(rng.choice(blocks, p=[0.6, 0.4]))
    return {"coating_hr": opt_coating, "temp": opt_temp, "block": opt_block}


def _elisa_closeness(
    coating_hr: float, temp: float, block: str, hidden: dict[str, Any]
) -> float:
    """Score in [0, 1] of how close an ELISA protocol is to the hidden optimum."""
    coat_close = _linear_closeness(coating_hr, hidden["coating_hr"], 2.0)
    temp_close = _linear_closeness(temp, hidden["temp"], 35.0)
    block_match = 1.0 if block == hidden["block"] else 0.3
    return coat_close * temp_close * block_match


def _elisa_sample_assay_result(
    hidden: dict[str, Any],
    preset_idx: int,
    presets: list[dict[str, Any]],
    rng: np.random.Generator,
) -> str:
    """Sample an assay outcome for preset ``presets[preset_idx]``."""
    preset = presets[preset_idx]
    closeness = _elisa_closeness(
        float(preset["coating_hr"]),
        float(preset["temp"]),
        str(preset["block"]),
        hidden,
    )
    return _sample_outcome(closeness, rng)


def _elisa_evaluate_custom_protocol(
    hidden: dict[str, Any],
    protocol: dict[str, Any],
    rng: np.random.Generator,
) -> str:
    """Evaluate any protocol dict (coating_hr, temp, block) against hidden optimum.

    Missing keys fall back to coating_hr=2.0, temp=25.0, block='bsa'. Any
    block text containing 'bsa' (case-insensitive) counts as BSA; everything
    else is treated as casein.
    """
    coating = float(protocol.get("coating_hr", 2.0))
    temp = float(protocol.get("temp", 25.0))
    block_text = str(protocol.get("block", "bsa")).strip().lower()
    block_clean = "bsa" if "bsa" in block_text else "casein"
    closeness = _elisa_closeness(coating, temp, block_clean, hidden)
    return _sample_outcome(closeness, rng)


ELISA_PROTOCOL_SCHEMA = {
    "coating_hr": {"type": "number", "description": "Coating time in hours (e.g. 1–3)"},
    "temp": {"type": "number", "description": "Incubation temperature °C (e.g. 4 or 37)"},
    "block": {"type": "string", "enum": ["bsa", "casein"], "description": "Blocking agent"},
}


def elisa_experiment_spec() -> ExperimentSpec:
    """ELISA readout: coating time (hr), temperature (°C), blocking type. Same obs/action dims as PCR."""
    coating_hrs = [1.0, 2.0, 3.0]
    temps = [4.0, 37.0]
    blocks = ["bsa", "casein"]
    presets = [
        {"coating_hr": ch, "temp": t, "block": bl}
        for ch, t, bl in product(coating_hrs, temps, blocks)
    ]
    return ExperimentSpec(
        name="elisa",
        presets=presets,
        **_standard_lab_settings(),
        sample_hidden_optimum=_elisa_sample_hidden_optimum,
        sample_assay_result=_elisa_sample_assay_result,
        evaluate_custom_protocol=_elisa_evaluate_custom_protocol,
        protocol_param_schema=ELISA_PROTOCOL_SCHEMA,
    )


# ---------------------------------------------------------------------------
# Workflow ID -> spec registry (for UI / API)
# ---------------------------------------------------------------------------

_WORKFLOW_REGISTRY: dict[str, Callable[[], ExperimentSpec]] = {
    "pcr-amplification": pcr_experiment_spec,
    "elisa-readout": elisa_experiment_spec,
}


def get_spec_for_workflow(workflow_id: str) -> ExperimentSpec:
    """Return the experiment spec for a given workflow ID. Unknown IDs default to PCR.

    A fresh spec is built on every call, so callers may mutate it freely.
    """
    factory = _WORKFLOW_REGISTRY.get(workflow_id, pcr_experiment_spec)
    return factory()