biosim / lab_env /spec.py
arminfg's picture
SimLab: lab automation RL env, OpenEnv adapter, Training UI, agents
da63ca8
"""
Experiment specification for the generic lab environment.
Defines protocol presets, inventory, rewards, and outcome model so LabEnv can
simulate any experiment type (PCR, ELISA, etc.) from a single spec.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Callable
import numpy as np
@dataclass
class ExperimentSpec:
    """Declarative description of one experiment type (PCR, ELISA, etc.).

    The fields describe presets, inventory, limits, time costs and rewards;
    the properties and methods derive the action/observation layout from them.
    Outcome logic is pluggable via ``sample_hidden_optimum`` and
    ``sample_assay_result``.

    Discrete action indices are laid out in this order:
    ``[setup presets ...][run_assay][order items ...][wait][finish]``.
    """

    name: str
    """Short identifier of the experiment (e.g. 'pcr', 'elisa')."""
    presets: list[dict[str, Any]]
    """Protocol presets the agent can pick from (e.g. temp/cycles/ratio for PCR)."""
    inventory_items: list[str]
    """Ordered inventory item names (tips, buffer, polymerase, samples, ...)."""
    orderable_items: list[str]
    """Subset of inventory_items that can be reordered (one order action each)."""
    initial_inventory: dict[str, int]
    """Starting count for each inventory item."""
    order_costs: dict[str, tuple[int, float]]
    """Per orderable item: (quantity_per_order, cost_per_order)."""
    result_labels: list[str]
    """Possible assay outcomes, e.g. ['none', 'success', 'partial', 'fail']."""

    # Limits
    max_steps: int = 30
    max_minutes: float = 240.0
    initial_budget: float = 500.0
    max_inventory: int = 20

    # Time costs
    assay_time_minutes: float = 20.0
    order_time_minutes: float = 5.0
    wait_minutes: float = 15.0

    # Rewards
    assay_penalty: float = -3.0
    time_penalty_per_min: float = -0.25
    no_success_penalty: float = -20.0
    immediate_result_reward: dict[str, float] = field(default_factory=dict)
    terminal_bonus: dict[str, float] = field(default_factory=dict)

    # Outcome model: callables taking (rng) or (hidden_state, preset_idx, presets, rng)
    sample_hidden_optimum: Callable[[np.random.Generator], dict[str, Any]] | None = None
    sample_assay_result: (
        Callable[
            [dict[str, Any], int, list[dict[str, Any]], np.random.Generator],
            str,
        ]
        | None
    ) = None

    # Custom protocol support: score an arbitrary protocol dict (agent-generated protocols)
    evaluate_custom_protocol: (
        Callable[
            [dict[str, Any], dict[str, Any], np.random.Generator],
            str,
        ]
        | None
    ) = None
    """If set, (hidden_optimum, protocol_dict, rng) -> result label. Enables run_assay_with_protocol()."""
    protocol_param_schema: dict[str, Any] = field(default_factory=dict)
    """Schema describing protocol params for codegen/LLM: e.g. {"temp": {"type": "number", "description": "°C"}, ...}."""

    @property
    def num_presets(self) -> int:
        """Number of protocol presets (one setup action per preset)."""
        return len(self.presets)

    @property
    def num_actions(self) -> int:
        """Total size of the discrete action space."""
        setup_actions = self.num_presets
        order_actions = len(self.orderable_items)
        run_assay_actions = 1
        wait_and_finish_actions = 2
        return (
            setup_actions
            + run_assay_actions
            + order_actions
            + wait_and_finish_actions
        )

    @property
    def obs_dim(self) -> int:
        """Length of the flat observation vector."""
        clock_and_budget = 3  # step_index, elapsed_minutes, remaining_budget
        setup_and_score = 3  # has_setup, current_preset_idx (norm), best_result_score
        return (
            clock_and_budget
            + len(self.inventory_items)
            + len(self.result_labels)
            + setup_and_score
        )

    def action_setup_start(self) -> int:
        """Index of the first setup (choose-preset) action."""
        return 0

    def action_setup_end(self) -> int:
        """Exclusive end of the setup action range (same index as run_assay)."""
        return self.num_presets

    def action_run_assay(self) -> int:
        """Index of the run_assay action, directly after the setup actions."""
        return self.action_setup_end()

    def action_order_start(self) -> int:
        """Index of the first order action, directly after run_assay."""
        return self.action_run_assay() + 1

    def action_order_end(self) -> int:
        """Exclusive end of the order action range (same index as wait)."""
        return self.action_order_start() + len(self.orderable_items)

    def action_wait(self) -> int:
        """Index of the wait action, directly after the order actions."""
        return self.action_order_end()

    def action_finish(self) -> int:
        """Index of the finish action (the last action)."""
        return self.action_wait() + 1
# ---------------------------------------------------------------------------
# PCR experiment spec (default / backward compatibility)
# ---------------------------------------------------------------------------
def _pcr_sample_hidden_optimum(rng: np.random.Generator) -> dict[str, Any]:
temps = [55.0, 65.0, 72.0]
cycles = [25, 35]
ratios = ["conservative", "aggressive"]
opt_temp = float(rng.choice(temps, p=[0.2, 0.5, 0.3])) + rng.uniform(-3.0, 3.0)
opt_cycles = float(rng.choice(cycles, p=[0.6, 0.4])) + rng.uniform(-2.0, 2.0)
opt_ratio = str(rng.choice(ratios, p=[0.6, 0.4]))
return {"temp": opt_temp, "cycles": opt_cycles, "ratio": opt_ratio}
def _pcr_sample_assay_result(
hidden: dict[str, Any],
preset_idx: int,
presets: list[dict[str, Any]],
rng: np.random.Generator,
) -> str:
preset = presets[preset_idx]
chosen_temp = float(preset["temp"])
chosen_cycles = float(preset["cycles"])
chosen_ratio = str(preset["ratio"])
opt_temp = hidden["temp"]
opt_cycles = hidden["cycles"]
opt_ratio = hidden["ratio"]
temp_close = 1.0 - min(abs(chosen_temp - opt_temp) / 20.0, 1.0)
cycle_close = 1.0 - min(abs(chosen_cycles - opt_cycles) / 15.0, 1.0)
ratio_match = 1.0 if chosen_ratio == opt_ratio else 0.3
closeness = temp_close * cycle_close * ratio_match
p_success = closeness ** 2
p_partial = closeness * (1.0 - closeness) * 0.8
p_fail = 1.0 - p_success - p_partial
return str(
rng.choice(["success", "partial", "fail"], p=[p_success, p_partial, p_fail])
)
def _pcr_evaluate_custom_protocol(
hidden: dict[str, Any],
protocol: dict[str, Any],
rng: np.random.Generator,
) -> str:
"""Evaluate any protocol dict (temp, cycles, ratio) against hidden optimum."""
chosen_temp = float(protocol.get("temp", 60.0))
chosen_cycles = float(protocol.get("cycles", 30))
r = str(protocol.get("ratio", "conservative")).strip().lower()
chosen_ratio = "conservative" if "conservative" in r else "aggressive"
opt_temp = hidden["temp"]
opt_cycles = hidden["cycles"]
opt_ratio = hidden["ratio"]
temp_close = 1.0 - min(abs(chosen_temp - opt_temp) / 20.0, 1.0)
cycle_close = 1.0 - min(abs(chosen_cycles - opt_cycles) / 15.0, 1.0)
ratio_match = 1.0 if chosen_ratio == opt_ratio else 0.3
closeness = temp_close * cycle_close * ratio_match
p_success = closeness ** 2
p_partial = closeness * (1.0 - closeness) * 0.8
p_fail = 1.0 - p_success - p_partial
return str(
rng.choice(["success", "partial", "fail"], p=[p_success, p_partial, p_fail])
)
# Parameter schema for custom PCR protocols; passed to ExperimentSpec as
# protocol_param_schema by pcr_experiment_spec().
PCR_PROTOCOL_SCHEMA = {
    "temp": {
        "type": "number",
        "description": "Annealing temperature in °C (e.g. 55–72)",
    },
    "cycles": {
        "type": "integer",
        "description": "Number of PCR cycles (e.g. 25–40)",
    },
    "ratio": {
        "type": "string",
        "enum": ["conservative", "aggressive"],
        "description": "Reagent ratio",
    },
}
def pcr_experiment_spec() -> ExperimentSpec:
    """Build the default PCR experiment spec (same behaviour as original LabEnv).

    Returns:
        A fresh ``ExperimentSpec`` named ``"pcr"`` with 12 presets (3 temps x
        2 cycle counts x 2 ratios) and the PCR outcome callables wired in.
    """
    # Full grid of presets; temperature varies slowest, ratio fastest.
    preset_grid = [
        {"temp": temp, "cycles": n_cycles, "ratio": ratio}
        for temp in (55.0, 65.0, 72.0)
        for n_cycles in (25, 35)
        for ratio in ("conservative", "aggressive")
    ]
    starting_stock = {"tips": 10, "buffer": 10, "polymerase": 5, "samples": 8}
    # item -> (quantity_per_order, cost_per_order); "samples" cannot be reordered.
    reorder_terms = {
        "tips": (5, 10.0),
        "buffer": (5, 15.0),
        "polymerase": (3, 30.0),
    }

    return ExperimentSpec(
        name="pcr",
        presets=preset_grid,
        inventory_items=list(starting_stock),
        orderable_items=list(reorder_terms),
        initial_inventory=starting_stock,
        order_costs=reorder_terms,
        result_labels=["none", "success", "partial", "fail"],
        # Limits
        max_steps=30,
        max_minutes=240.0,
        initial_budget=500.0,
        max_inventory=20,
        # Time costs
        assay_time_minutes=20.0,
        order_time_minutes=5.0,
        wait_minutes=15.0,
        # Rewards
        assay_penalty=-3.0,
        time_penalty_per_min=-0.25,
        no_success_penalty=-20.0,
        immediate_result_reward={"success": 15.0, "partial": 5.0, "fail": 0.0},
        terminal_bonus={"success": 60.0, "partial": 25.0},
        # Outcome model
        sample_hidden_optimum=_pcr_sample_hidden_optimum,
        sample_assay_result=_pcr_sample_assay_result,
        evaluate_custom_protocol=_pcr_evaluate_custom_protocol,
        protocol_param_schema=PCR_PROTOCOL_SCHEMA,
    )
# ---------------------------------------------------------------------------
# ELISA experiment spec (same obs/action shape as PCR for agent compatibility)
# ---------------------------------------------------------------------------
def _elisa_sample_hidden_optimum(rng: np.random.Generator) -> dict[str, Any]:
coating_hrs = [1.0, 2.0, 3.0]
temps = [4.0, 37.0]
blocks = ["bsa", "casein"]
opt_coating = float(rng.choice(coating_hrs, p=[0.3, 0.5, 0.2])) + rng.uniform(-0.2, 0.2)
opt_temp = float(rng.choice(temps, p=[0.5, 0.5])) + rng.uniform(-2.0, 2.0)
opt_block = str(rng.choice(blocks, p=[0.6, 0.4]))
return {"coating_hr": opt_coating, "temp": opt_temp, "block": opt_block}
def _elisa_sample_assay_result(
hidden: dict[str, Any],
preset_idx: int,
presets: list[dict[str, Any]],
rng: np.random.Generator,
) -> str:
preset = presets[preset_idx]
c = float(preset["coating_hr"])
t = float(preset["temp"])
b = str(preset["block"])
oc = hidden["coating_hr"]
ot = hidden["temp"]
ob = hidden["block"]
coat_close = 1.0 - min(abs(c - oc) / 2.0, 1.0)
temp_close = 1.0 - min(abs(t - ot) / 35.0, 1.0)
block_match = 1.0 if b == ob else 0.3
closeness = coat_close * temp_close * block_match
p_success = closeness ** 2
p_partial = closeness * (1.0 - closeness) * 0.8
p_fail = 1.0 - p_success - p_partial
return str(
rng.choice(["success", "partial", "fail"], p=[p_success, p_partial, p_fail])
)
def _elisa_evaluate_custom_protocol(
hidden: dict[str, Any],
protocol: dict[str, Any],
rng: np.random.Generator,
) -> str:
"""Evaluate any protocol dict (coating_hr, temp, block) against hidden optimum."""
c = float(protocol.get("coating_hr", 2.0))
t = float(protocol.get("temp", 25.0))
b = str(protocol.get("block", "bsa")).strip().lower()
block_clean = "bsa" if "bsa" in b else "casein"
oc, ot, ob = hidden["coating_hr"], hidden["temp"], hidden["block"]
coat_close = 1.0 - min(abs(c - oc) / 2.0, 1.0)
temp_close = 1.0 - min(abs(t - ot) / 35.0, 1.0)
block_match = 1.0 if block_clean == ob else 0.3
closeness = coat_close * temp_close * block_match
p_success = closeness ** 2
p_partial = closeness * (1.0 - closeness) * 0.8
p_fail = 1.0 - p_success - p_partial
return str(
rng.choice(["success", "partial", "fail"], p=[p_success, p_partial, p_fail])
)
# Parameter schema for custom ELISA protocols; passed to ExperimentSpec as
# protocol_param_schema by elisa_experiment_spec().
ELISA_PROTOCOL_SCHEMA = {
    "coating_hr": {
        "type": "number",
        "description": "Coating time in hours (e.g. 1–3)",
    },
    "temp": {
        "type": "number",
        "description": "Incubation temperature °C (e.g. 4 or 37)",
    },
    "block": {
        "type": "string",
        "enum": ["bsa", "casein"],
        "description": "Blocking agent",
    },
}
def elisa_experiment_spec() -> ExperimentSpec:
    """ELISA readout: coating time (hr), temperature (°C), blocking type. Same obs/action dims as PCR.

    Returns:
        A fresh ``ExperimentSpec`` named ``"elisa"`` with 12 presets (3 coating
        times x 2 temperatures x 2 blocking agents) and the ELISA outcome
        callables wired in.
    """
    # Full grid of presets; coating time varies slowest, blocking agent fastest.
    preset_grid = [
        {"coating_hr": coating, "temp": temp, "block": blocker}
        for coating in (1.0, 2.0, 3.0)
        for temp in (4.0, 37.0)
        for blocker in ("bsa", "casein")
    ]
    # Inventory, order terms, limits and rewards use the same values as the PCR spec.
    starting_stock = {"tips": 10, "buffer": 10, "polymerase": 5, "samples": 8}
    # item -> (quantity_per_order, cost_per_order); "samples" cannot be reordered.
    reorder_terms = {
        "tips": (5, 10.0),
        "buffer": (5, 15.0),
        "polymerase": (3, 30.0),
    }

    return ExperimentSpec(
        name="elisa",
        presets=preset_grid,
        inventory_items=list(starting_stock),
        orderable_items=list(reorder_terms),
        initial_inventory=starting_stock,
        order_costs=reorder_terms,
        result_labels=["none", "success", "partial", "fail"],
        # Limits
        max_steps=30,
        max_minutes=240.0,
        initial_budget=500.0,
        max_inventory=20,
        # Time costs
        assay_time_minutes=20.0,
        order_time_minutes=5.0,
        wait_minutes=15.0,
        # Rewards
        assay_penalty=-3.0,
        time_penalty_per_min=-0.25,
        no_success_penalty=-20.0,
        immediate_result_reward={"success": 15.0, "partial": 5.0, "fail": 0.0},
        terminal_bonus={"success": 60.0, "partial": 25.0},
        # Outcome model
        sample_hidden_optimum=_elisa_sample_hidden_optimum,
        sample_assay_result=_elisa_sample_assay_result,
        evaluate_custom_protocol=_elisa_evaluate_custom_protocol,
        protocol_param_schema=ELISA_PROTOCOL_SCHEMA,
    )
# ---------------------------------------------------------------------------
# Workflow ID -> spec registry (for UI / API)
# ---------------------------------------------------------------------------
def get_spec_for_workflow(workflow_id: str) -> ExperimentSpec:
    """Return the experiment spec for a given workflow ID. Unknown IDs default to PCR.

    Args:
        workflow_id: Workflow identifier, e.g. ``"pcr-amplification"`` or
            ``"elisa-readout"``.

    Returns:
        A newly built ``ExperimentSpec``; the factory is invoked on every call.
    """
    factories: dict[str, Callable[[], ExperimentSpec]] = {
        "pcr-amplification": pcr_experiment_spec,
        "elisa-readout": elisa_experiment_spec,
    }
    # Unrecognised workflow IDs fall back to the PCR factory.
    build_spec = factories.get(workflow_id, pcr_experiment_spec)
    return build_spec()