"""Gymnasium-compatible wrapper around ``BioExperimentEnvironment``.
Provides ``BioExperimentGymEnv`` which wraps the OpenEnv environment for
local in-process RL training (no HTTP/WebSocket overhead).
Observation and action spaces are represented as ``gymnasium.spaces.Dict``
so that standard RL libraries (SB3, CleanRL, etc.) can ingest them.
"""
from __future__ import annotations
from typing import Any, Dict, Optional, Tuple
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from models import ActionType, ExperimentAction, ExperimentObservation
from server.hackathon_environment import BioExperimentEnvironment, MAX_STEPS
# Ordered list of discrete actions: index i of the Gym Discrete action space
# maps to ACTION_TYPE_LIST[i].
ACTION_TYPE_LIST = list(ActionType)
_N_ACTION_TYPES = len(ACTION_TYPE_LIST)
# An episode runs at most MAX_STEPS simulator steps, so the number of outputs
# and history entries is bounded by the same constant.
_MAX_OUTPUTS = MAX_STEPS
_MAX_HISTORY = MAX_STEPS
# NOTE(review): _VEC_DIM is not referenced anywhere in this chunk — presumably
# reserved for a fixed-size text/feature embedding; confirm before removing.
_VEC_DIM = 64
class BioExperimentGymEnv(gym.Env):
    """Gymnasium ``Env`` backed by the in-process simulator.

    Observations are flattened into a dictionary of NumPy scalars/arrays
    suitable for RL policy networks. Actions are integer-indexed action
    types with a continuous confidence scalar.

    For LLM-based agents or planners that prefer structured
    ``ExperimentAction`` objects, use the underlying
    ``BioExperimentEnvironment`` directly instead.
    """

    metadata = {"render_modes": ["human"]}

    def __init__(self, render_mode: Optional[str] = None):
        """Create the wrapped simulator and declare the Gym spaces.

        Args:
            render_mode: ``"human"`` enables text rendering via ``render()``;
                any other value (including ``None``) disables it.
        """
        super().__init__()
        self._env = BioExperimentEnvironment()
        self.render_mode = render_mode
        self.action_space = spaces.Dict({
            "action_type": spaces.Discrete(_N_ACTION_TYPES),
            "confidence": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32),
        })
        self.observation_space = spaces.Dict({
            "step_index": spaces.Discrete(MAX_STEPS + 1),
            "budget_remaining_frac": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32),
            "time_remaining_frac": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32),
            # One bit per pipeline milestone; see flag_names in _vectorise.
            "progress_flags": spaces.MultiBinary(18),
            "latest_quality": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32),
            "latest_uncertainty": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32),
            "avg_quality": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32),
            "avg_uncertainty": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32),
            "n_violations": spaces.Discrete(20),
            "n_outputs": spaces.Discrete(_MAX_OUTPUTS + 1),
            "cumulative_reward": spaces.Box(-100.0, 100.0, shape=(), dtype=np.float32),
        })
        # Most recent structured observation, retained for render().
        self._last_obs: Optional[ExperimentObservation] = None

    # ── Gymnasium interface ─────────────────────────────────────────────
    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Reset the simulator and return ``(observation, info)``.

        ``seed`` is forwarded to ``gym.Env.reset`` to seed ``self.np_random``;
        ``options`` is accepted for API compatibility but unused here.
        """
        super().reset(seed=seed)
        obs = self._env.reset()
        self._last_obs = obs
        return self._vectorise(obs), self._info(obs)

    def step(
        self, action: Dict[str, Any]
    ) -> Tuple[Dict[str, Any], float, bool, bool, Dict[str, Any]]:
        """Apply one action and return the standard 5-tuple.

        Args:
            action: Dict with ``"action_type"`` (int index into
                ``ACTION_TYPE_LIST``) and optional ``"confidence"``
                (float, defaults to 0.5).

        Raises:
            ValueError: If ``action_type`` is outside the discrete space.
        """
        action_idx = int(action["action_type"])
        # Fail loudly with a clear message instead of an opaque IndexError.
        if not 0 <= action_idx < _N_ACTION_TYPES:
            raise ValueError(
                f"action_type index {action_idx} outside [0, {_N_ACTION_TYPES})"
            )
        # Clamp into the declared Box(0, 1) bounds: raw policy outputs can
        # stray slightly outside the space.
        confidence = min(max(float(action.get("confidence", 0.5)), 0.0), 1.0)
        experiment_action = ExperimentAction(
            action_type=ACTION_TYPE_LIST[action_idx],
            confidence=confidence,
        )
        obs = self._env.step(experiment_action)
        self._last_obs = obs
        terminated = obs.done
        # Truncation = horizon exhausted without the simulator signalling done.
        truncated = obs.step_index >= MAX_STEPS and not terminated
        reward = float(obs.reward)  # Gym API expects a plain float
        return (
            self._vectorise(obs),
            reward,
            terminated,
            truncated,
            self._info(obs),
        )

    def render(self) -> Optional[str]:
        """Print and return a human-readable summary, or ``None`` when
        ``render_mode != "human"`` or no observation exists yet."""
        if self.render_mode != "human" or self._last_obs is None:
            return None
        obs = self._last_obs
        lines = [
            f"Step {obs.step_index}",
            f"  Task: {obs.task.problem_statement[:80]}",
            f"  Budget: ${obs.resource_usage.budget_remaining:,.0f} remaining",
            f"  Time: {obs.resource_usage.time_remaining_days:.0f} days remaining",
        ]
        if obs.latest_output:
            lines.append(f"  Latest: {obs.latest_output.summary}")
        if obs.rule_violations:
            lines.append(f"  Violations: {obs.rule_violations}")
        text = "\n".join(lines)
        print(text)
        return text

    # ── helpers ─────────────────────────────────────────────────────────
    def _vectorise(self, obs: ExperimentObservation) -> Dict[str, Any]:
        """Flatten a structured observation into ``observation_space`` form."""
        # NOTE(review): reads the simulator's private _latent state — no
        # public progress accessor is visible in this file; confirm upstream.
        progress = self._env._latent.progress if self._env._latent else None
        flags = np.zeros(18, dtype=np.int8)
        if progress:
            flag_names = [
                "samples_collected", "cohort_selected", "cells_cultured",
                "library_prepared", "perturbation_applied", "cells_sequenced",
                "qc_performed", "data_filtered", "data_normalized",
                "batches_integrated", "cells_clustered", "de_performed",
                "trajectories_inferred", "pathways_analyzed",
                "networks_inferred", "markers_discovered",
                "markers_validated", "conclusion_reached",
            ]
            for i, f in enumerate(flag_names):
                flags[i] = int(getattr(progress, f, False))
        unc = obs.uncertainty_summary
        lo = obs.latest_output
        # Overspending can drive the remaining budget/time negative; clamp the
        # fractions so the emitted observation stays inside the [0, 1] Boxes
        # declared in observation_space.
        budget_frac = (
            obs.resource_usage.budget_remaining
            / max(obs.task.budget_limit, 1)
        )
        time_frac = (
            obs.resource_usage.time_remaining_days
            / max(obs.task.time_limit_days, 1)
        )
        return {
            "step_index": obs.step_index,
            "budget_remaining_frac": np.float32(min(max(budget_frac, 0.0), 1.0)),
            "time_remaining_frac": np.float32(min(max(time_frac, 0.0), 1.0)),
            "progress_flags": flags,
            "latest_quality": np.float32(lo.quality_score if lo else 0.0),
            "latest_uncertainty": np.float32(lo.uncertainty if lo else 0.0),
            "avg_quality": np.float32(unc.get("avg_quality", 0.0)),
            "avg_uncertainty": np.float32(unc.get("avg_uncertainty", 0.0)),
            # Saturate at the top of Discrete(20)/Discrete(_MAX_OUTPUTS + 1)
            # rather than emit a value outside the declared space.
            "n_violations": min(len(obs.rule_violations), 19),
            "n_outputs": min(len(obs.all_outputs), _MAX_OUTPUTS),
            "cumulative_reward": np.float32(
                obs.metadata.get("cumulative_reward", 0.0)
                if obs.metadata else 0.0
            ),
        }

    def _info(self, obs: ExperimentObservation) -> Dict[str, Any]:
        """Build the Gym ``info`` dict: the full structured observation plus
        the episode id (when the simulator supplies metadata)."""
        return {
            "structured_obs": obs,
            "episode_id": obs.metadata.get("episode_id") if obs.metadata else None,
        }