| """ |
| LabEnv — A Gymnasium-style simulated wet-lab environment for RL training. |
| |
| Simulates a single experiment workflow (e.g. PCR, ELISA) where the agent must |
| discover a hidden optimal protocol under time and budget constraints. The |
| experiment type is defined by an ExperimentSpec so any protocol-discovery |
| experiment can be modelled. |
| |
| Designed for compatibility with OpenEnv's sandboxed execution model: |
| the reset/step/close interface can be served over HTTP via the adapter in |
| ``openenv_adapter.py`` and uploaded to the OpenEnv hub on Hugging Face as a |
| standardized agentic environment for lab-automation research. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from typing import Any |
|
|
| import gymnasium as gym |
| import numpy as np |
| from gymnasium import spaces |
|
|
| from lab_env.spec import ExperimentSpec, pcr_experiment_spec |
|
|
|
|
| |
| |
| |
|
|
# Convenience constants derived from the default PCR spec.  They describe the
# default experiment only; a LabEnv built with a custom spec should query its
# own ``env.spec`` instead of relying on these values.
_DEFAULT_SPEC = pcr_experiment_spec()

# --- Discrete action layout -------------------------------------------------
NUM_PRESETS: int = _DEFAULT_SPEC.num_presets
ACTION_SETUP_START: int = _DEFAULT_SPEC.action_setup_start()
ACTION_SETUP_END: int = _DEFAULT_SPEC.action_setup_end()
ACTION_RUN_ASSAY: int = _DEFAULT_SPEC.action_run_assay()

# The three order actions occupy consecutive slots starting at
# action_order_start().
# NOTE(review): the names assume the default spec lists its orderable items as
# tips, buffer, polymerase in that order — confirm against lab_env.spec.
ACTION_ORDER_TIPS: int
ACTION_ORDER_BUFFER: int
ACTION_ORDER_POLYMERASE: int
ACTION_ORDER_TIPS, ACTION_ORDER_BUFFER, ACTION_ORDER_POLYMERASE = (
    _DEFAULT_SPEC.action_order_start() + slot for slot in range(3)
)

ACTION_WAIT: int = _DEFAULT_SPEC.action_wait()
ACTION_FINISH: int = _DEFAULT_SPEC.action_finish()
NUM_ACTIONS: int = _DEFAULT_SPEC.num_actions
OBS_DIM: int = _DEFAULT_SPEC.obs_dim

# --- Budget and result vocabulary -------------------------------------------
INITIAL_BUDGET: float = _DEFAULT_SPEC.initial_budget
RESULT_LABELS = _DEFAULT_SPEC.result_labels
# Maps each result label to its position in RESULT_LABELS.
RESULT_TO_IDX = dict(zip(RESULT_LABELS, range(len(RESULT_LABELS))))
|
|
|
|
class LabEnv(gym.Env):
    """Simulated wet-lab environment for any experiment type.

    The experiment (protocol presets, inventory, rewards, outcome model) is
    defined by an ExperimentSpec. Use LabEnv() for default PCR; use
    LabEnv(spec=my_spec) for custom experiments.

    Observation (Box in [0, 1], shape from spec; values are clipped to the box):
        [0]   step_index        (normalised by spec.max_steps)
        [1]   elapsed_minutes   (normalised by spec.max_minutes)
        [2]   remaining_budget  (normalised by spec.initial_budget)
        [3..] inventory         (one slot per inventory_items, normalised)
        [...] last_result one-hot (len(result_labels))
        [...] has_setup, current_preset_idx (norm), best_result_score

    Actions (Discrete, from spec):
        0 .. num_presets-1      setup_reaction(preset_index)
        num_presets             run_assay
        num_presets+1 ..        order_reagents(item) for each orderable item
        ...                     wait, finish
    """

    metadata = {"render_modes": []}

    def __init__(
        self,
        spec: ExperimentSpec | None = None,
        render_mode: str | None = None,
    ) -> None:
        """Build the environment.

        Args:
            spec: Experiment definition; the default PCR spec is used when None.
            render_mode: Stored on the instance as the Gymnasium API expects.
                ``metadata`` declares no render modes, so nothing is rendered.
        """
        super().__init__()
        self.spec = spec if spec is not None else pcr_experiment_spec()
        self.render_mode = render_mode
        self.observation_space = spaces.Box(
            low=0.0, high=1.0, shape=(self.spec.obs_dim,), dtype=np.float32
        )
        self.action_space = spaces.Discrete(self.spec.num_actions)

        # The generator is created lazily by reset(); until then assay
        # sampling falls back to a fixed label (see _sample_assay_result).
        self._rng: np.random.Generator | None = None
        self._current_protocol_override: dict[str, Any] | None = None
        self._reset_state()

    # ------------------------------------------------------------------
    # Public Gymnasium-style API
    # ------------------------------------------------------------------

    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[np.ndarray, dict[str, Any]]:
        """Start a new episode and return ``(observation, info)``.

        Args:
            seed: When given, the internal generator is re-created from it.
                When None, an existing generator keeps running so that a run
                seeded once stays reproducible across episodes; a generator is
                only created from fresh entropy if none exists yet.
            options: Accepted for API compatibility; currently unused.
        """
        super().reset(seed=seed)
        if seed is not None or self._rng is None:
            self._rng = np.random.default_rng(seed)
        self._reset_state()
        self._sample_hidden_optimum()
        return self._obs(), self._info()

    def step(
        self, action: int
    ) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]:
        """Apply one discrete action.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.

        Raises:
            RuntimeError: If the episode has already ended.
            ValueError: If ``action`` maps to no action range of the spec.
                The episode state is left untouched in that case.
        """
        if self._terminated or self._truncated:
            raise RuntimeError("Episode already done — call reset().")

        # Resolve the handler BEFORE mutating any state so that an invalid
        # action does not consume a step.
        spec = self.spec
        if spec.action_setup_start() <= action < spec.action_setup_end():
            handler, handler_args = self._do_setup, (action,)
        elif action == spec.action_run_assay():
            handler, handler_args = self._do_run_assay, ()
        elif spec.action_order_start() <= action < spec.action_order_end():
            handler, handler_args = self._do_order, (action,)
        elif action == spec.action_wait():
            handler, handler_args = self._do_wait, ()
        elif action == spec.action_finish():
            handler, handler_args = self._do_finish, ()
        else:
            raise ValueError(f"Invalid action {action}")

        self._step_index += 1
        reward = 0.0
        reward += handler(*handler_args)
        return self._finalize_step(reward)

    def run_assay_with_protocol(
        self, protocol: dict[str, Any]
    ) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]:
        """Run one assay with an arbitrary protocol dict (no preset).

        The spec must have evaluate_custom_protocol set (e.g. PCR/ELISA). Consumes
        inventory and time like a normal assay; outcome is from the spec's outcome
        model. Use this for agent-generated protocols.

        The protocol applies to this call only: it is discarded afterwards even
        when the assay could not be run (e.g. a reagent is missing), so it can
        never leak into a later preset-based ``run_assay`` step. Note that the
        call leaves ``has_setup`` set to True.

        Raises:
            RuntimeError: If the episode has already ended.
            ValueError: If the spec does not define evaluate_custom_protocol.
        """
        if self._terminated or self._truncated:
            raise RuntimeError("Episode already done — call reset().")
        if self.spec.evaluate_custom_protocol is None:
            raise ValueError(
                "This spec does not support custom protocols; evaluate_custom_protocol is not set."
            )
        # Copy first: a bad ``protocol`` argument then fails before any state
        # is modified.
        override = dict(protocol)

        self._step_index += 1
        self._current_protocol_override = override
        self._has_setup = True  # bypasses the "no setup" guard in _do_run_assay
        try:
            reward = self._do_run_assay()
        finally:
            # _do_run_assay may return before the override is consumed;
            # always drop it here.
            self._current_protocol_override = None
        return self._finalize_step(reward)

    def close(self) -> None:
        """Release resources (nothing to release)."""
        pass

    def _finalize_step(
        self, reward: float
    ) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]:
        """Apply forced-termination checks, add the terminal reward if the
        episode just ended, and assemble the step tuple."""
        self._check_forced_termination()
        if self._terminated or self._truncated:
            reward += self._terminal_reward()
        return self._obs(), reward, self._terminated, self._truncated, self._info()

    # ------------------------------------------------------------------
    # Action handlers (each returns the immediate reward)
    # ------------------------------------------------------------------

    def _do_setup(self, action: int) -> float:
        """Select the preset encoded by ``action``; costs one minute."""
        self._current_preset_idx = action - self.spec.action_setup_start()
        self._has_setup = True
        self._elapsed_minutes += 1.0
        return 0.0

    def _fail_result_label(self) -> str:
        """Return the label used for a failed assay.

        Prefers the literal ``"fail"``; otherwise the last result label, or
        ``"fail"`` when the spec defines no labels at all.
        """
        labels = self.spec.result_labels
        if "fail" in labels:
            return "fail"
        return labels[-1] if labels else "fail"

    def _do_run_assay(self) -> float:
        """Run an assay with the current setup (or protocol override).

        * No setup yet: result is the fail label, assay time is charged and
          the assay penalty is returned.
        * Any inventory item below 1: result is the fail label and the assay
          penalty is returned; no time or inventory is consumed.
        * Otherwise one unit of every inventory item and the assay time are
          consumed, a result is sampled, and the assay penalty plus the
          result's immediate reward is returned.
        """
        spec = self.spec
        if not self._has_setup:
            self._last_result = self._fail_result_label()
            self._elapsed_minutes += spec.assay_time_minutes
            return spec.assay_penalty

        stock = self._inventory
        if any(stock.get(item, 0) < 1 for item in spec.inventory_items):
            self._last_result = self._fail_result_label()
            return spec.assay_penalty

        for item in spec.inventory_items:
            # Floor at zero in case an item is listed more than once.
            stock[item] = max(0, stock.get(item, 0) - 1)

        self._elapsed_minutes += spec.assay_time_minutes

        outcome = self._sample_assay_result()
        self._last_result = outcome
        self._update_best(outcome)

        return spec.assay_penalty + spec.immediate_result_reward.get(outcome, 0.0)

    def _do_order(self, action: int) -> float:
        """Order the item encoded by ``action``.

        The order is a no-op (no cost, no time) when the index is out of
        range, the item has no entry in ``order_costs``, or the budget is
        insufficient. Stock is capped at ``spec.max_inventory``.
        """
        spec = self.spec
        slot = action - spec.action_order_start()
        if not 0 <= slot < len(spec.orderable_items):
            return 0.0
        item = spec.orderable_items[slot]
        if item not in spec.order_costs:
            return 0.0
        quantity, price = spec.order_costs[item]
        if self._remaining_budget < price:
            return 0.0

        self._remaining_budget -= price
        self._inventory[item] = min(
            self._inventory.get(item, 0) + quantity, spec.max_inventory
        )
        self._elapsed_minutes += spec.order_time_minutes
        return 0.0

    def _do_wait(self) -> float:
        """Let ``spec.wait_minutes`` pass."""
        self._elapsed_minutes += self.spec.wait_minutes
        return 0.0

    def _do_finish(self) -> float:
        """Terminate the episode voluntarily."""
        self._terminated = True
        return 0.0

    # ------------------------------------------------------------------
    # Termination and terminal reward
    # ------------------------------------------------------------------

    def _check_forced_termination(self) -> None:
        """Set ``_truncated`` when the episode cannot or must not continue.

        Truncation triggers: time limit reached, budget exhausted, step limit
        reached, or a dead end where no assay can be run and no orderable
        item is affordable. Does nothing if already terminated.
        """
        if self._terminated:
            return

        spec = self.spec
        if (
            self._elapsed_minutes >= spec.max_minutes
            or self._remaining_budget <= 0
            or self._step_index >= spec.max_steps
        ):
            self._truncated = True
            return

        stock = self._inventory
        can_run = all(stock.get(item, 0) >= 1 for item in spec.inventory_items)
        # Items without a cost entry are treated as infinitely expensive.
        can_order = any(
            self._remaining_budget >= spec.order_costs.get(k, (0, float("inf")))[1]
            for k in spec.orderable_items
        )
        if not can_run and not can_order:
            self._truncated = True

    def _terminal_reward(self) -> float:
        """Return the end-of-episode reward.

        Sum of the terminal bonus for the best result, the per-minute time
        term, and ``no_success_penalty`` when the best result is ``"none"``,
        ``"fail"`` or has no terminal bonus entry.
        """
        spec = self.spec
        best = self._best_result
        bonus = spec.terminal_bonus.get(best, 0.0)
        time_penalty = spec.time_penalty_per_min * self._elapsed_minutes
        unsuccessful = best in ("none", "fail") or best not in spec.terminal_bonus
        no_success = spec.no_success_penalty if unsuccessful else 0.0
        return bonus + time_penalty + no_success

    # ------------------------------------------------------------------
    # Outcome model
    # ------------------------------------------------------------------

    def _sample_hidden_optimum(self) -> None:
        """Draw the episode's hidden optimum via the spec hook (``{}`` if the
        spec has none). No-op while no generator exists."""
        if self._rng is None:
            return
        sampler = self.spec.sample_hidden_optimum
        self._hidden_optimum = sampler(self._rng) if sampler is not None else {}

    def _sample_assay_result(self) -> str:
        """Return the result label for the assay being run.

        Priority: fixed fallback label when no generator exists; the spec's
        custom-protocol evaluator when an override is pending; the spec's
        preset outcome model; finally a uniform draw over all labels except
        ``"none"``.
        """
        spec = self.spec
        labels = spec.result_labels
        if self._rng is None:
            return labels[1] if len(labels) > 1 else "fail"

        if (
            self._current_protocol_override is not None
            and spec.evaluate_custom_protocol is not None
        ):
            result = spec.evaluate_custom_protocol(
                self._hidden_optimum,
                self._current_protocol_override,
                self._rng,
            )
            self._current_protocol_override = None  # single use
            return result

        if spec.sample_assay_result is not None:
            return spec.sample_assay_result(
                self._hidden_optimum,
                self._current_preset_idx,
                spec.presets,
                self._rng,
            )

        choices = [r for r in labels if r != "none"]
        if not choices:
            return "fail"
        return str(self._rng.choice(choices))

    def _update_best(self, result: str) -> None:
        """Keep the best result seen so far.

        Ranking is fixed to fail/none < partial < success; labels outside
        that set rank lowest.
        """
        rank = {"fail": 0, "none": 0, "partial": 1, "success": 2}
        if rank.get(result, 0) > rank.get(self._best_result, 0):
            self._best_result = result

    # ------------------------------------------------------------------
    # Observation / info
    # ------------------------------------------------------------------

    def _result_to_onehot(self, result: str) -> list[float]:
        """One-hot encode ``result`` over ``spec.result_labels`` (all zeros if
        the label is unknown; only the first match is set)."""
        labels = self.spec.result_labels
        onehot = [0.0] * len(labels)
        for i, label in enumerate(labels):
            if label == result:
                onehot[i] = 1.0
                break
        return onehot

    def _obs(self) -> np.ndarray:
        """Build the float32 observation vector described in the class docstring."""
        spec = self.spec
        best_score = {"none": 0.0, "fail": 0.0, "partial": 0.5, "success": 1.0}.get(
            self._best_result, 0.0
        )
        inv_slice = [
            self._inventory.get(item, 0) / spec.max_inventory
            for item in spec.inventory_items
        ]
        preset_norm = (
            self._current_preset_idx / spec.num_presets if self._has_setup else 0.0
        )
        obs = np.array(
            [
                self._step_index / spec.max_steps,
                self._elapsed_minutes / spec.max_minutes,
                self._remaining_budget / spec.initial_budget,
                *inv_slice,
                *self._result_to_onehot(self._last_result),
                float(self._has_setup),
                preset_norm,
                best_score,
            ],
            dtype=np.float32,
        )
        # Elapsed time can overshoot max_minutes on the final step and initial
        # stock may exceed max_inventory; clip so the observation always lies
        # inside the declared Box(0, 1) space.
        np.clip(obs, 0.0, 1.0, out=obs)
        return obs

    def _info(self) -> dict[str, Any]:
        """Return un-normalised episode state for logging and debugging."""
        return {
            "step_index": self._step_index,
            "elapsed_minutes": self._elapsed_minutes,
            "remaining_budget": self._remaining_budget,
            "inventory": dict(self._inventory),
            "last_result": self._last_result,
            "best_result": self._best_result,
        }

    # ------------------------------------------------------------------
    # State management
    # ------------------------------------------------------------------

    def _reset_state(self) -> None:
        """Reset all per-episode state to the spec's initial values."""
        spec = self.spec
        self._step_index = 0
        self._elapsed_minutes = 0.0
        self._remaining_budget = spec.initial_budget
        self._inventory = dict(spec.initial_inventory)
        # The first result label doubles as the "nothing run yet" state.
        self._last_result = spec.result_labels[0] if spec.result_labels else "none"
        self._best_result = self._last_result
        self._has_setup = False
        self._current_preset_idx = 0
        self._terminated = False
        self._truncated = False
        self._hidden_optimum: dict[str, Any] = {}
        # A custom protocol never survives into a new episode.
        self._current_protocol_override = None
|
|