""" LabEnv — A Gymnasium-style simulated wet-lab environment for RL training. Simulates a single experiment workflow (e.g. PCR, ELISA) where the agent must discover a hidden optimal protocol under time and budget constraints. The experiment type is defined by an ExperimentSpec so any protocol-discovery experiment can be modelled. Designed for compatibility with OpenEnv's sandboxed execution model: the reset/step/close interface can be served over HTTP via the adapter in ``openenv_adapter.py`` and uploaded to the OpenEnv hub on Hugging Face as a standardized agentic environment for lab-automation research. """ from __future__ import annotations from typing import Any import gymnasium as gym import numpy as np from gymnasium import spaces from lab_env.spec import ExperimentSpec, pcr_experiment_spec # --------------------------------------------------------------------------- # Backward compatibility: expose constants for default (PCR) spec # --------------------------------------------------------------------------- _DEFAULT_SPEC = pcr_experiment_spec() NUM_PRESETS: int = _DEFAULT_SPEC.num_presets ACTION_SETUP_START: int = _DEFAULT_SPEC.action_setup_start() ACTION_SETUP_END: int = _DEFAULT_SPEC.action_setup_end() ACTION_RUN_ASSAY: int = _DEFAULT_SPEC.action_run_assay() ACTION_ORDER_TIPS: int = _DEFAULT_SPEC.action_order_start() + 0 ACTION_ORDER_BUFFER: int = _DEFAULT_SPEC.action_order_start() + 1 ACTION_ORDER_POLYMERASE: int = _DEFAULT_SPEC.action_order_start() + 2 ACTION_WAIT: int = _DEFAULT_SPEC.action_wait() ACTION_FINISH: int = _DEFAULT_SPEC.action_finish() NUM_ACTIONS: int = _DEFAULT_SPEC.num_actions OBS_DIM: int = _DEFAULT_SPEC.obs_dim # Legacy constants used by scripts INITIAL_BUDGET: float = _DEFAULT_SPEC.initial_budget RESULT_LABELS = _DEFAULT_SPEC.result_labels RESULT_TO_IDX = {label: i for i, label in enumerate(RESULT_LABELS)} class LabEnv(gym.Env): """Simulated wet-lab environment for any experiment type. The experiment (protocol presets, inventory, rewards, outcome model) is defined by an ExperimentSpec. Use LabEnv() for default PCR; use LabEnv(spec=my_spec) for custom experiments. Observation (Box, shape from spec): [0] step_index (normalised) [1] elapsed_minutes (normalised) [2] remaining_budget (normalised) [3..] inventory (one slot per inventory_items, normalised) [...] last_result one-hot (len(result_labels)) [...] has_setup, current_preset_idx (norm), best_result_score Actions (Discrete, from spec): 0 .. num_presets-1 setup_reaction(preset_index) num_presets run_assay num_presets+1 .. order_reagents(item) for each orderable item ... wait, finish """ metadata = {"render_modes": []} def __init__( self, spec: ExperimentSpec | None = None, render_mode: str | None = None, ) -> None: super().__init__() self.spec = spec if spec is not None else pcr_experiment_spec() self.observation_space = spaces.Box( low=0.0, high=1.0, shape=(self.spec.obs_dim,), dtype=np.float32 ) self.action_space = spaces.Discrete(self.spec.num_actions) self._rng: np.random.Generator | None = None self._current_protocol_override: dict[str, Any] | None = None self._reset_state() # ------------------------------------------------------------------ # Gymnasium API # ------------------------------------------------------------------ def reset( self, *, seed: int | None = None, options: dict[str, Any] | None = None, ) -> tuple[np.ndarray, dict[str, Any]]: super().reset(seed=seed) self._rng = np.random.default_rng(seed) self._reset_state() self._sample_hidden_optimum() return self._obs(), self._info() def step( self, action: int ) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]: if self._terminated or self._truncated: raise RuntimeError("Episode already done — call reset().") reward = 0.0 self._step_index += 1 if self.spec.action_setup_start() <= action < self.spec.action_setup_end(): reward += self._do_setup(action) elif action == self.spec.action_run_assay(): reward += self._do_run_assay() elif self.spec.action_order_start() <= action < self.spec.action_order_end(): reward += self._do_order(action) elif action == self.spec.action_wait(): reward += self._do_wait() elif action == self.spec.action_finish(): reward += self._do_finish() else: raise ValueError(f"Invalid action {action}") self._check_forced_termination() if self._terminated or self._truncated: reward += self._terminal_reward() return self._obs(), reward, self._terminated, self._truncated, self._info() def run_assay_with_protocol( self, protocol: dict[str, Any] ) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]: """Run one assay with an arbitrary protocol dict (no preset). The spec must have evaluate_custom_protocol set (e.g. PCR/ELISA). Consumes inventory and time like a normal assay; outcome is from the spec's outcome model. Use this for agent-generated protocols. """ if self._terminated or self._truncated: raise RuntimeError("Episode already done — call reset().") if self.spec.evaluate_custom_protocol is None: raise ValueError( "This spec does not support custom protocols; evaluate_custom_protocol is not set." ) self._step_index += 1 self._current_protocol_override = dict(protocol) self._has_setup = True reward = self._do_run_assay() self._check_forced_termination() if self._terminated or self._truncated: reward += self._terminal_reward() return self._obs(), reward, self._terminated, self._truncated, self._info() def close(self) -> None: pass # ------------------------------------------------------------------ # Action implementations # ------------------------------------------------------------------ def _do_setup(self, action: int) -> float: preset_idx = action - self.spec.action_setup_start() self._current_preset_idx = preset_idx self._has_setup = True self._elapsed_minutes += 1.0 return 0.0 def _fail_result_label(self) -> str: if "fail" in self.spec.result_labels: return "fail" return self.spec.result_labels[-1] if self.spec.result_labels else "fail" def _do_run_assay(self) -> float: if not self._has_setup: self._last_result = self._fail_result_label() self._elapsed_minutes += self.spec.assay_time_minutes return self.spec.assay_penalty inv = self._inventory for item in self.spec.inventory_items: if inv.get(item, 0) < 1: self._last_result = self._fail_result_label() return self.spec.assay_penalty for item in self.spec.inventory_items: inv[item] = inv.get(item, 0) - 1 inv[item] = max(0, inv[item]) self._elapsed_minutes += self.spec.assay_time_minutes result = self._sample_assay_result() self._last_result = result self._update_best(result) imm = self.spec.immediate_result_reward.get(result, 0.0) return self.spec.assay_penalty + imm def _do_order(self, action: int) -> float: idx = action - self.spec.action_order_start() if idx < 0 or idx >= len(self.spec.orderable_items): return 0.0 item = self.spec.orderable_items[idx] if item not in self.spec.order_costs: return 0.0 qty, cost = self.spec.order_costs[item] if self._remaining_budget < cost: return 0.0 self._remaining_budget -= cost self._inventory[item] = min( self._inventory.get(item, 0) + qty, self.spec.max_inventory ) self._elapsed_minutes += self.spec.order_time_minutes return 0.0 def _do_wait(self) -> float: self._elapsed_minutes += self.spec.wait_minutes return 0.0 def _do_finish(self) -> float: self._terminated = True return 0.0 # ------------------------------------------------------------------ # Termination # ------------------------------------------------------------------ def _check_forced_termination(self) -> None: if self._terminated: return if self._elapsed_minutes >= self.spec.max_minutes: self._truncated = True return if self._remaining_budget <= 0: self._truncated = True return if self._step_index >= self.spec.max_steps: self._truncated = True return inv = self._inventory can_run = all(inv.get(item, 0) >= 1 for item in self.spec.inventory_items) can_order = any( self._remaining_budget >= self.spec.order_costs.get(k, (0, float("inf")))[1] for k in self.spec.orderable_items ) if not can_run and not can_order: self._truncated = True def _terminal_reward(self) -> float: bonus = self.spec.terminal_bonus.get(self._best_result, 0.0) time_penalty = self.spec.time_penalty_per_min * self._elapsed_minutes no_success = ( self.spec.no_success_penalty if self._best_result in ("none", "fail") or self._best_result not in self.spec.terminal_bonus else 0.0 ) return bonus + time_penalty + no_success # ------------------------------------------------------------------ # Outcome model (delegated to spec) # ------------------------------------------------------------------ def _sample_hidden_optimum(self) -> None: if self._rng is None: return if self.spec.sample_hidden_optimum is not None: self._hidden_optimum = self.spec.sample_hidden_optimum(self._rng) else: self._hidden_optimum = {} def _sample_assay_result(self) -> str: if self._rng is None: return self.spec.result_labels[1] if len(self.spec.result_labels) > 1 else "fail" if self._current_protocol_override is not None and self.spec.evaluate_custom_protocol is not None: result = self.spec.evaluate_custom_protocol( self._hidden_optimum, self._current_protocol_override, self._rng, ) self._current_protocol_override = None return result if self.spec.sample_assay_result is not None: return self.spec.sample_assay_result( self._hidden_optimum, self._current_preset_idx, self.spec.presets, self._rng, ) # Default: random non-none result choices = [r for r in self.spec.result_labels if r != "none"] if not choices: return "fail" return str(self._rng.choice(choices)) def _update_best(self, result: str) -> None: rank = {"fail": 0, "none": 0, "partial": 1, "success": 2} if rank.get(result, 0) > rank.get(self._best_result, 0): self._best_result = result # ------------------------------------------------------------------ # Observation helpers # ------------------------------------------------------------------ def _result_to_onehot(self, result: str) -> list[float]: out = [0.0] * len(self.spec.result_labels) for i, label in enumerate(self.spec.result_labels): if label == result: out[i] = 1.0 break return out def _obs(self) -> np.ndarray: inv = self._inventory result_onehot = self._result_to_onehot(self._last_result) best_score = {"none": 0.0, "fail": 0.0, "partial": 0.5, "success": 1.0}.get( self._best_result, 0.0 ) inv_slice = [ inv.get(item, 0) / self.spec.max_inventory for item in self.spec.inventory_items ] obs = np.array( [ self._step_index / self.spec.max_steps, self._elapsed_minutes / self.spec.max_minutes, self._remaining_budget / self.spec.initial_budget, *inv_slice, *result_onehot, float(self._has_setup), (self._current_preset_idx / self.spec.num_presets) if self._has_setup else 0.0, best_score, ], dtype=np.float32, ) return obs def _info(self) -> dict[str, Any]: return { "step_index": self._step_index, "elapsed_minutes": self._elapsed_minutes, "remaining_budget": self._remaining_budget, "inventory": dict(self._inventory), "last_result": self._last_result, "best_result": self._best_result, } # ------------------------------------------------------------------ # Internal state management # ------------------------------------------------------------------ def _reset_state(self) -> None: self._step_index = 0 self._elapsed_minutes = 0.0 self._remaining_budget = self.spec.initial_budget self._inventory = dict(self.spec.initial_inventory) self._last_result = self.spec.result_labels[0] if self.spec.result_labels else "none" self._best_result = self._last_result self._has_setup = False self._current_preset_idx = 0 self._terminated = False self._truncated = False self._hidden_optimum: dict[str, Any] = {}