"""Gymnasium-compatible wrapper around ``BioExperimentEnvironment``. Provides ``BioExperimentGymEnv`` which wraps the OpenEnv environment for local in-process RL training (no HTTP/WebSocket overhead). Observation and action spaces are represented as ``gymnasium.spaces.Dict`` so that standard RL libraries (SB3, CleanRL, etc.) can ingest them. """ from __future__ import annotations from typing import Any, Dict, Optional, Tuple import gymnasium as gym import numpy as np from gymnasium import spaces from models import ActionType, ExperimentAction, ExperimentObservation from server.hackathon_environment import BioExperimentEnvironment, MAX_STEPS ACTION_TYPE_LIST = list(ActionType) _N_ACTION_TYPES = len(ACTION_TYPE_LIST) _MAX_OUTPUTS = MAX_STEPS _MAX_HISTORY = MAX_STEPS _VEC_DIM = 64 class BioExperimentGymEnv(gym.Env): """Gymnasium ``Env`` backed by the in-process simulator. Observations are flattened into a dictionary of NumPy arrays suitable for RL policy networks. Actions are integer-indexed action types with a continuous confidence scalar. For LLM-based agents or planners that prefer structured ``ExperimentAction`` objects, use the underlying ``BioExperimentEnvironment`` directly instead. """ metadata = {"render_modes": ["human"]} def __init__(self, render_mode: Optional[str] = None): super().__init__() self._env = BioExperimentEnvironment() self.render_mode = render_mode self.action_space = spaces.Dict({ "action_type": spaces.Discrete(_N_ACTION_TYPES), "confidence": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32), }) self.observation_space = spaces.Dict({ "step_index": spaces.Discrete(MAX_STEPS + 1), "budget_remaining_frac": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32), "time_remaining_frac": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32), "progress_flags": spaces.MultiBinary(18), "latest_quality": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32), "latest_uncertainty": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32), "avg_quality": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32), "avg_uncertainty": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32), "n_violations": spaces.Discrete(20), "n_outputs": spaces.Discrete(_MAX_OUTPUTS + 1), "cumulative_reward": spaces.Box(-100.0, 100.0, shape=(), dtype=np.float32), }) self._last_obs: Optional[ExperimentObservation] = None # ── Gymnasium interface ───────────────────────────────────────────── def reset( self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None, ) -> Tuple[Dict[str, Any], Dict[str, Any]]: super().reset(seed=seed) obs = self._env.reset() self._last_obs = obs return self._vectorise(obs), self._info(obs) def step( self, action: Dict[str, Any] ) -> Tuple[Dict[str, Any], float, bool, bool, Dict[str, Any]]: action_idx = int(action["action_type"]) confidence = float(action.get("confidence", 0.5)) experiment_action = ExperimentAction( action_type=ACTION_TYPE_LIST[action_idx], confidence=confidence, ) obs = self._env.step(experiment_action) self._last_obs = obs terminated = obs.done truncated = obs.step_index >= MAX_STEPS and not terminated reward = obs.reward return ( self._vectorise(obs), reward, terminated, truncated, self._info(obs), ) def render(self) -> Optional[str]: if self.render_mode != "human" or self._last_obs is None: return None obs = self._last_obs lines = [ f"Step {obs.step_index}", f" Task: {obs.task.problem_statement[:80]}", f" Budget: ${obs.resource_usage.budget_remaining:,.0f} remaining", f" Time: {obs.resource_usage.time_remaining_days:.0f} days remaining", ] if obs.latest_output: lines.append(f" Latest: {obs.latest_output.summary}") if obs.rule_violations: lines.append(f" Violations: {obs.rule_violations}") text = "\n".join(lines) print(text) return text # ── helpers ───────────────────────────────────────────────────────── def _vectorise(self, obs: ExperimentObservation) -> Dict[str, Any]: progress = self._env._latent.progress if self._env._latent else None flags = np.zeros(18, dtype=np.int8) if progress: flag_names = [ "samples_collected", "cohort_selected", "cells_cultured", "library_prepared", "perturbation_applied", "cells_sequenced", "qc_performed", "data_filtered", "data_normalized", "batches_integrated", "cells_clustered", "de_performed", "trajectories_inferred", "pathways_analyzed", "networks_inferred", "markers_discovered", "markers_validated", "conclusion_reached", ] for i, f in enumerate(flag_names): flags[i] = int(getattr(progress, f, False)) unc = obs.uncertainty_summary lo = obs.latest_output return { "step_index": obs.step_index, "budget_remaining_frac": np.float32( obs.resource_usage.budget_remaining / max(obs.task.budget_limit, 1) ), "time_remaining_frac": np.float32( obs.resource_usage.time_remaining_days / max(obs.task.time_limit_days, 1) ), "progress_flags": flags, "latest_quality": np.float32(lo.quality_score if lo else 0.0), "latest_uncertainty": np.float32(lo.uncertainty if lo else 0.0), "avg_quality": np.float32(unc.get("avg_quality", 0.0)), "avg_uncertainty": np.float32(unc.get("avg_uncertainty", 0.0)), "n_violations": min(len(obs.rule_violations), 19), "n_outputs": min(len(obs.all_outputs), _MAX_OUTPUTS), "cumulative_reward": np.float32( obs.metadata.get("cumulative_reward", 0.0) if obs.metadata else 0.0 ), } def _info(self, obs: ExperimentObservation) -> Dict[str, Any]: return { "structured_obs": obs, "episode_id": obs.metadata.get("episode_id") if obs.metadata else None, }