# biosim / lab_env/env.py
# arminfg — commit da63ca8:
#   "SimLab: lab automation RL env, OpenEnv adapter, Training UI, agents"
"""
LabEnv — A Gymnasium-style simulated wet-lab environment for RL training.
Simulates a single experiment workflow (e.g. PCR, ELISA) where the agent must
discover a hidden optimal protocol under time and budget constraints. The
experiment type is defined by an ExperimentSpec so any protocol-discovery
experiment can be modelled.
Designed for compatibility with OpenEnv's sandboxed execution model:
the reset/step/close interface can be served over HTTP via the adapter in
``openenv_adapter.py`` and uploaded to the OpenEnv hub on Hugging Face as a
standardized agentic environment for lab-automation research.
"""
from __future__ import annotations

from collections.abc import Callable
from typing import Any

import gymnasium as gym
import numpy as np
from gymnasium import spaces

from lab_env.spec import ExperimentSpec, pcr_experiment_spec
# ---------------------------------------------------------------------------
# Backward compatibility: module-level constants mirroring the default (PCR)
# experiment spec, so scripts that import them from here keep working.
# ---------------------------------------------------------------------------
_DEFAULT_SPEC = pcr_experiment_spec()

# Space sizes of the default spec.
NUM_PRESETS: int = _DEFAULT_SPEC.num_presets
NUM_ACTIONS: int = _DEFAULT_SPEC.num_actions
OBS_DIM: int = _DEFAULT_SPEC.obs_dim

# Action-index layout of the default spec.
ACTION_SETUP_START: int = _DEFAULT_SPEC.action_setup_start()
ACTION_SETUP_END: int = _DEFAULT_SPEC.action_setup_end()
ACTION_RUN_ASSAY: int = _DEFAULT_SPEC.action_run_assay()
# The three order actions are consecutive offsets from action_order_start().
# NOTE(review): the tips/buffer/polymerase naming presumably matches the
# order of the PCR spec's orderable_items — confirm against lab_env.spec.
ACTION_ORDER_TIPS: int = _DEFAULT_SPEC.action_order_start()
ACTION_ORDER_BUFFER: int = ACTION_ORDER_TIPS + 1
ACTION_ORDER_POLYMERASE: int = ACTION_ORDER_TIPS + 2
ACTION_WAIT: int = _DEFAULT_SPEC.action_wait()
ACTION_FINISH: int = _DEFAULT_SPEC.action_finish()

# Legacy constants used by scripts.
INITIAL_BUDGET: float = _DEFAULT_SPEC.initial_budget
RESULT_LABELS = _DEFAULT_SPEC.result_labels
# Map each result label to its position in RESULT_LABELS (last wins on dupes).
RESULT_TO_IDX = dict(zip(RESULT_LABELS, range(len(RESULT_LABELS))))
class LabEnv(gym.Env):
    """Simulated wet-lab environment for any experiment type.

    The experiment (protocol presets, inventory, rewards, outcome model) is
    defined by an ExperimentSpec. Use ``LabEnv()`` for the default PCR spec;
    use ``LabEnv(spec=my_spec)`` for custom experiments.

    Observation (Box in [0, 1], shape ``(spec.obs_dim,)``):
        [0]    step_index (normalised by max_steps)
        [1]    elapsed_minutes (normalised by max_minutes)
        [2]    remaining_budget (normalised by initial_budget)
        [3..]  inventory (one slot per inventory_items, normalised by
               max_inventory)
        [...]  last_result one-hot (len(result_labels))
        [...]  has_setup, current_preset_idx (normalised), best_result_score
    Values are clipped into [0, 1] so every observation lies inside
    ``observation_space``.

    Actions (Discrete, from spec):
        0 .. num_presets-1   setup_reaction(preset_index)
        num_presets          run_assay
        num_presets+1 ..     order_reagents(item) for each orderable item
        ...                  wait, finish
    """

    metadata = {"render_modes": []}

    def __init__(
        self,
        spec: ExperimentSpec | None = None,
        render_mode: str | None = None,
    ) -> None:
        """Create the environment.

        Args:
            spec: Experiment definition; defaults to ``pcr_experiment_spec()``.
            render_mode: Stored per the Gymnasium constructor convention. No
                render modes are implemented (``metadata`` lists none).
        """
        super().__init__()
        self.spec = spec if spec is not None else pcr_experiment_spec()
        self.render_mode = render_mode
        self.observation_space = spaces.Box(
            low=0.0, high=1.0, shape=(self.spec.obs_dim,), dtype=np.float32
        )
        self.action_space = spaces.Discrete(self.spec.num_actions)
        # Outcome RNG; stays None until the first reset(), in which case
        # outcome sampling falls back to a fixed label.
        self._rng: np.random.Generator | None = None
        # Protocol dict used by a single run_assay_with_protocol() call.
        self._current_protocol_override: dict[str, Any] | None = None
        self._reset_state()

    # ------------------------------------------------------------------
    # Gymnasium API
    # ------------------------------------------------------------------

    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[np.ndarray, dict[str, Any]]:
        """Start a new episode and sample a new hidden optimum.

        Args:
            seed: When given, re-seeds the environment's generator. When None,
                the generator from earlier resets keeps advancing, so a run
                started with ``reset(seed=s)`` is reproducible across episodes.
            options: Unused; accepted for Gymnasium API compatibility.

        Returns:
            ``(observation, info)`` for the initial state.
        """
        super().reset(seed=seed)
        # Use Gymnasium's managed generator. A fresh default_rng(seed) here
        # was unseeded on every reset(seed=None), breaking reproducibility
        # after the first episode; for an explicit seed the stream is the same.
        self._rng = self.np_random
        self._reset_state()
        self._sample_hidden_optimum()
        return self._obs(), self._info()

    def step(
        self, action: int
    ) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]:
        """Apply one discrete action.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.

        Raises:
            RuntimeError: If the episode is already over.
            ValueError: If ``action`` maps to no action of the spec. No state
                (including the step counter) is changed in that case.
        """
        if self._terminated or self._truncated:
            raise RuntimeError("Episode already done — call reset().")
        do_action = self._resolve_action(action)
        if do_action is None:
            raise ValueError(f"Invalid action {action}")
        self._step_index += 1
        return self._conclude_step(float(do_action()))

    def run_assay_with_protocol(
        self, protocol: dict[str, Any]
    ) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]:
        """Run one assay with an arbitrary protocol dict (no preset).

        The spec must have ``evaluate_custom_protocol`` set (e.g. PCR/ELISA).
        Consumes inventory and time like a normal assay; the outcome comes
        from the spec's ``evaluate_custom_protocol``. Use this for
        agent-generated protocols. Counts as one step and marks the
        environment as set up.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.

        Raises:
            RuntimeError: If the episode is already over.
            ValueError: If the spec does not support custom protocols.
        """
        if self._terminated or self._truncated:
            raise RuntimeError("Episode already done — call reset().")
        if self.spec.evaluate_custom_protocol is None:
            raise ValueError(
                "This spec does not support custom protocols; evaluate_custom_protocol is not set."
            )
        self._step_index += 1
        self._current_protocol_override = dict(protocol)
        self._has_setup = True
        try:
            reward = self._do_run_assay()
        finally:
            # The override applies to this call only. _do_run_assay can return
            # early (missing inventory) without consuming it, which would
            # otherwise leak it into the next preset-based run_assay action.
            self._current_protocol_override = None
        return self._conclude_step(float(reward))

    def close(self) -> None:
        pass

    def _resolve_action(self, action: int) -> Callable[[], float] | None:
        """Map a discrete action to a zero-arg handler, or None if invalid."""
        spec = self.spec
        if spec.action_setup_start() <= action < spec.action_setup_end():
            return lambda: self._do_setup(action)
        if action == spec.action_run_assay():
            return self._do_run_assay
        if spec.action_order_start() <= action < spec.action_order_end():
            return lambda: self._do_order(action)
        if action == spec.action_wait():
            return self._do_wait
        if action == spec.action_finish():
            return self._do_finish
        return None

    def _conclude_step(
        self, reward: float
    ) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]:
        """Apply forced termination, add terminal reward, build step tuple."""
        self._check_forced_termination()
        if self._terminated or self._truncated:
            reward += self._terminal_reward()
        return self._obs(), reward, self._terminated, self._truncated, self._info()

    # ------------------------------------------------------------------
    # Action implementations
    # ------------------------------------------------------------------

    def _do_setup(self, action: int) -> float:
        """Select a preset protocol; advances the clock by 1 minute."""
        preset_idx = action - self.spec.action_setup_start()
        self._current_preset_idx = preset_idx
        self._has_setup = True
        self._elapsed_minutes += 1.0
        return 0.0

    def _fail_result_label(self) -> str:
        """Label used for failed assays: "fail" if defined, else the last label."""
        if "fail" in self.spec.result_labels:
            return "fail"
        return self.spec.result_labels[-1] if self.spec.result_labels else "fail"

    def _do_run_assay(self) -> float:
        """Run the current protocol and return ``assay_penalty`` + result reward.

        Without a setup, the assay fails but still costs ``assay_time_minutes``.
        If any inventory item is missing, it fails without consuming time or
        inventory. Otherwise one unit of every inventory item is consumed.
        """
        if not self._has_setup:
            self._last_result = self._fail_result_label()
            self._elapsed_minutes += self.spec.assay_time_minutes
            return self.spec.assay_penalty
        inv = self._inventory
        for item in self.spec.inventory_items:
            if inv.get(item, 0) < 1:
                self._last_result = self._fail_result_label()
                return self.spec.assay_penalty
        for item in self.spec.inventory_items:
            # max() guards against duplicate entries in inventory_items.
            inv[item] = max(0, inv.get(item, 0) - 1)
        self._elapsed_minutes += self.spec.assay_time_minutes
        result = self._sample_assay_result()
        self._last_result = result
        self._update_best(result)
        imm = self.spec.immediate_result_reward.get(result, 0.0)
        return self.spec.assay_penalty + imm

    def _do_order(self, action: int) -> float:
        """Buy one batch of an orderable item; always returns 0.0 reward.

        Does nothing (no cost, no time) if the index is out of range, the item
        has no ``order_costs`` entry, or the remaining budget is too small.
        Stock is capped at ``max_inventory``.
        """
        idx = action - self.spec.action_order_start()
        if idx < 0 or idx >= len(self.spec.orderable_items):
            return 0.0
        item = self.spec.orderable_items[idx]
        if item not in self.spec.order_costs:
            return 0.0
        qty, cost = self.spec.order_costs[item]
        if self._remaining_budget < cost:
            return 0.0
        self._remaining_budget -= cost
        self._inventory[item] = min(
            self._inventory.get(item, 0) + qty, self.spec.max_inventory
        )
        self._elapsed_minutes += self.spec.order_time_minutes
        return 0.0

    def _do_wait(self) -> float:
        """Advance the clock by ``wait_minutes``."""
        self._elapsed_minutes += self.spec.wait_minutes
        return 0.0

    def _do_finish(self) -> float:
        """End the episode voluntarily (terminated, not truncated)."""
        self._terminated = True
        return 0.0

    # ------------------------------------------------------------------
    # Termination
    # ------------------------------------------------------------------

    def _check_forced_termination(self) -> None:
        """Truncate on time, budget, or step limits, or when the agent can
        neither run an assay (inventory) nor afford any order."""
        if self._terminated:
            return
        if self._elapsed_minutes >= self.spec.max_minutes:
            self._truncated = True
            return
        if self._remaining_budget <= 0:
            self._truncated = True
            return
        if self._step_index >= self.spec.max_steps:
            self._truncated = True
            return
        inv = self._inventory
        can_run = all(inv.get(item, 0) >= 1 for item in self.spec.inventory_items)
        # Items without a cost entry are treated as unaffordable.
        can_order = any(
            self._remaining_budget >= self.spec.order_costs.get(k, (0, float("inf")))[1]
            for k in self.spec.orderable_items
        )
        if not can_run and not can_order:
            self._truncated = True

    def _terminal_reward(self) -> float:
        """Bonus for the best result + per-minute time term + no-success term."""
        bonus = self.spec.terminal_bonus.get(self._best_result, 0.0)
        time_penalty = self.spec.time_penalty_per_min * self._elapsed_minutes
        no_success = (
            self.spec.no_success_penalty
            if self._best_result in ("none", "fail") or self._best_result not in self.spec.terminal_bonus
            else 0.0
        )
        return bonus + time_penalty + no_success

    # ------------------------------------------------------------------
    # Outcome model (delegated to spec)
    # ------------------------------------------------------------------

    def _sample_hidden_optimum(self) -> None:
        """Draw the episode's hidden optimum from the spec (empty if unset)."""
        if self._rng is None:
            return
        if self.spec.sample_hidden_optimum is not None:
            self._hidden_optimum = self.spec.sample_hidden_optimum(self._rng)
        else:
            self._hidden_optimum = {}

    def _sample_assay_result(self) -> str:
        """Sample an assay outcome label.

        Priority: custom protocol (if pending), then the spec's preset model,
        then a uniform choice among non-"none" labels. Before the first
        reset() (no RNG), returns a fixed label.
        """
        if self._rng is None:
            return self.spec.result_labels[1] if len(self.spec.result_labels) > 1 else "fail"
        if self._current_protocol_override is not None and self.spec.evaluate_custom_protocol is not None:
            result = self.spec.evaluate_custom_protocol(
                self._hidden_optimum,
                self._current_protocol_override,
                self._rng,
            )
            self._current_protocol_override = None
            return result
        if self.spec.sample_assay_result is not None:
            return self.spec.sample_assay_result(
                self._hidden_optimum,
                self._current_preset_idx,
                self.spec.presets,
                self._rng,
            )
        # Default: random non-none result
        choices = [r for r in self.spec.result_labels if r != "none"]
        if not choices:
            return "fail"
        return str(self._rng.choice(choices))

    def _update_best(self, result: str) -> None:
        """Keep the highest-ranked result seen so far.

        NOTE(review): the ranking is hard-coded to fail/none/partial/success;
        labels outside that set rank 0 and never become the best result.
        """
        rank = {"fail": 0, "none": 0, "partial": 1, "success": 2}
        if rank.get(result, 0) > rank.get(self._best_result, 0):
            self._best_result = result

    # ------------------------------------------------------------------
    # Observation helpers
    # ------------------------------------------------------------------

    def _result_to_onehot(self, result: str) -> list[float]:
        """One-hot over ``result_labels`` (all zeros if the label is unknown)."""
        out = [0.0] * len(self.spec.result_labels)
        for i, label in enumerate(self.spec.result_labels):
            if label == result:
                out[i] = 1.0
                break
        return out

    def _obs(self) -> np.ndarray:
        """Build the float32 observation vector, clipped to [0, 1]."""
        inv = self._inventory
        result_onehot = self._result_to_onehot(self._last_result)
        best_score = {"none": 0.0, "fail": 0.0, "partial": 0.5, "success": 1.0}.get(
            self._best_result, 0.0
        )
        inv_slice = [
            inv.get(item, 0) / self.spec.max_inventory
            for item in self.spec.inventory_items
        ]
        # Guard num_presets == 0: has_setup can be set by a custom protocol
        # even for a spec without presets.
        preset_frac = (
            self._current_preset_idx / self.spec.num_presets
            if self._has_setup and self.spec.num_presets > 0
            else 0.0
        )
        obs = np.array(
            [
                self._step_index / self.spec.max_steps,
                self._elapsed_minutes / self.spec.max_minutes,
                self._remaining_budget / self.spec.initial_budget,
                *inv_slice,
                *result_onehot,
                float(self._has_setup),
                preset_frac,
                best_score,
            ],
            dtype=np.float32,
        )
        # elapsed_minutes can overshoot max_minutes on the step that truncates
        # the episode; clip so the observation stays inside observation_space.
        np.clip(obs, 0.0, 1.0, out=obs)
        return obs

    def _info(self) -> dict[str, Any]:
        """Un-normalised state snapshot for logging/debugging."""
        return {
            "step_index": self._step_index,
            "elapsed_minutes": self._elapsed_minutes,
            "remaining_budget": self._remaining_budget,
            "inventory": dict(self._inventory),
            "last_result": self._last_result,
            "best_result": self._best_result,
        }

    # ------------------------------------------------------------------
    # Internal state management
    # ------------------------------------------------------------------

    def _reset_state(self) -> None:
        """Reset all per-episode counters, inventory, and results from the spec."""
        self._step_index = 0
        self._elapsed_minutes = 0.0
        self._remaining_budget = self.spec.initial_budget
        self._inventory = dict(self.spec.initial_inventory)
        self._last_result = self.spec.result_labels[0] if self.spec.result_labels else "none"
        self._best_result = self._last_result
        self._has_setup = False
        self._current_preset_idx = 0
        self._terminated = False
        self._truncated = False
        self._hidden_optimum: dict[str, Any] = {}