File size: 7,040 Bytes
4db0438
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
"""Gymnasium-compatible wrapper around ``BioExperimentEnvironment``.



Provides ``BioExperimentGymEnv`` which wraps the OpenEnv environment for

local in-process RL training (no HTTP/WebSocket overhead).



Observation and action spaces are represented as ``gymnasium.spaces.Dict``

so that standard RL libraries (SB3, CleanRL, etc.) can ingest them.

"""

from __future__ import annotations

from typing import Any, Dict, Optional, Tuple

import gymnasium as gym
import numpy as np
from gymnasium import spaces

from models import ActionType, ExperimentAction, ExperimentObservation
from server.hackathon_environment import BioExperimentEnvironment, MAX_STEPS


# All action types in enum-declaration order; the Discrete action index in
# ``BioExperimentGymEnv.action_space`` maps into this list by position.
ACTION_TYPE_LIST = list(ActionType)
_N_ACTION_TYPES = len(ACTION_TYPE_LIST)

# Upper bounds used to size Discrete observation components; an episode can
# produce at most one output / history entry per step.
_MAX_OUTPUTS = MAX_STEPS
_MAX_HISTORY = MAX_STEPS
# Fixed embedding width — not referenced in this file; presumably reserved
# for vector features elsewhere. NOTE(review): confirm external use before
# removing.
_VEC_DIM = 64


class BioExperimentGymEnv(gym.Env):
    """Gymnasium ``Env`` backed by the in-process simulator.

    Observations are flattened into a dictionary of NumPy arrays suitable
    for RL policy networks.  Actions are integer-indexed action types with
    a continuous confidence scalar.

    For LLM-based agents or planners that prefer structured
    ``ExperimentAction`` objects, use the underlying
    ``BioExperimentEnvironment`` directly instead.
    """

    metadata = {"render_modes": ["human"]}

    # Ordered names of the boolean progress flags packed into the
    # ``progress_flags`` MultiBinary(18) observation.  Hoisted to a class
    # constant so ``_vectorise`` does not rebuild the list on every step.
    _PROGRESS_FLAG_NAMES: Tuple[str, ...] = (
        "samples_collected", "cohort_selected", "cells_cultured",
        "library_prepared", "perturbation_applied", "cells_sequenced",
        "qc_performed", "data_filtered", "data_normalized",
        "batches_integrated", "cells_clustered", "de_performed",
        "trajectories_inferred", "pathways_analyzed",
        "networks_inferred", "markers_discovered",
        "markers_validated", "conclusion_reached",
    )

    def __init__(self, render_mode: Optional[str] = None):
        """Create the wrapper and declare its action/observation spaces.

        Args:
            render_mode: ``"human"`` to enable ``render()`` text output,
                or ``None`` (default) to disable rendering.
        """
        super().__init__()
        self._env = BioExperimentEnvironment()
        self.render_mode = render_mode

        # Discrete action index into ACTION_TYPE_LIST plus a scalar
        # confidence in [0, 1].
        self.action_space = spaces.Dict({
            "action_type": spaces.Discrete(_N_ACTION_TYPES),
            "confidence": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32),
        })

        # Flattened numeric summary of the structured observation; the full
        # ExperimentObservation is always available via info["structured_obs"].
        self.observation_space = spaces.Dict({
            "step_index": spaces.Discrete(MAX_STEPS + 1),
            "budget_remaining_frac": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32),
            "time_remaining_frac": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32),
            "progress_flags": spaces.MultiBinary(18),
            "latest_quality": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32),
            "latest_uncertainty": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32),
            "avg_quality": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32),
            "avg_uncertainty": spaces.Box(0.0, 1.0, shape=(), dtype=np.float32),
            "n_violations": spaces.Discrete(20),
            "n_outputs": spaces.Discrete(_MAX_OUTPUTS + 1),
            "cumulative_reward": spaces.Box(-100.0, 100.0, shape=(), dtype=np.float32),
        })

        # Most recent structured observation, kept for render().
        self._last_obs: Optional[ExperimentObservation] = None

    # ── Gymnasium interface ─────────────────────────────────────────────

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Start a new episode.

        Args:
            seed: forwarded to ``gym.Env.reset`` for RNG seeding.
            options: accepted for API compatibility; currently unused.

        Returns:
            ``(observation, info)`` per the Gymnasium API.
        """
        super().reset(seed=seed)
        obs = self._env.reset()
        self._last_obs = obs
        return self._vectorise(obs), self._info(obs)

    def step(
        self, action: Dict[str, Any]
    ) -> Tuple[Dict[str, Any], float, bool, bool, Dict[str, Any]]:
        """Apply one action and advance the simulator.

        Args:
            action: dict with ``"action_type"`` (int index into
                ``ACTION_TYPE_LIST``) and optional ``"confidence"``
                (float, defaults to 0.5).

        Returns:
            ``(observation, reward, terminated, truncated, info)``.

        Raises:
            IndexError: if ``action["action_type"]`` is out of range.
        """
        action_idx = int(action["action_type"])
        # Clamp into the declared Box(0, 1) bounds so out-of-range values
        # from external callers cannot violate the env's own action space.
        confidence = min(max(float(action.get("confidence", 0.5)), 0.0), 1.0)

        experiment_action = ExperimentAction(
            action_type=ACTION_TYPE_LIST[action_idx],
            confidence=confidence,
        )
        obs = self._env.step(experiment_action)
        self._last_obs = obs

        terminated = obs.done
        # Truncation = hit the step cap without the simulator declaring done.
        truncated = obs.step_index >= MAX_STEPS and not terminated
        # Gymnasium expects a plain float reward, not whatever numeric type
        # the simulator happens to return.
        reward = float(obs.reward)

        return (
            self._vectorise(obs),
            reward,
            terminated,
            truncated,
            self._info(obs),
        )

    def render(self) -> Optional[str]:
        """Print a short human-readable episode summary.

        Returns:
            The rendered text, or ``None`` when ``render_mode`` is not
            ``"human"`` or no observation exists yet.
        """
        if self.render_mode != "human" or self._last_obs is None:
            return None
        obs = self._last_obs
        lines = [
            f"Step {obs.step_index}",
            f"  Task: {obs.task.problem_statement[:80]}",
            f"  Budget: ${obs.resource_usage.budget_remaining:,.0f} remaining",
            f"  Time: {obs.resource_usage.time_remaining_days:.0f} days remaining",
        ]
        if obs.latest_output:
            lines.append(f"  Latest: {obs.latest_output.summary}")
        if obs.rule_violations:
            lines.append(f"  Violations: {obs.rule_violations}")
        text = "\n".join(lines)
        print(text)
        return text

    # ── helpers ─────────────────────────────────────────────────────────

    def _vectorise(self, obs: ExperimentObservation) -> Dict[str, Any]:
        """Flatten a structured observation into observation-space arrays."""
        # NOTE(review): reaches into the simulator's private ``_latent``
        # state for progress flags — confirm no public accessor exists.
        progress = self._env._latent.progress if self._env._latent else None
        flags = np.zeros(18, dtype=np.int8)
        if progress:
            for i, name in enumerate(self._PROGRESS_FLAG_NAMES):
                flags[i] = int(getattr(progress, name, False))

        unc = obs.uncertainty_summary
        lo = obs.latest_output

        return {
            "step_index": obs.step_index,
            # max(..., 1) guards against a zero limit dividing by zero.
            "budget_remaining_frac": np.float32(
                obs.resource_usage.budget_remaining
                / max(obs.task.budget_limit, 1)
            ),
            "time_remaining_frac": np.float32(
                obs.resource_usage.time_remaining_days
                / max(obs.task.time_limit_days, 1)
            ),
            "progress_flags": flags,
            "latest_quality": np.float32(lo.quality_score if lo else 0.0),
            "latest_uncertainty": np.float32(lo.uncertainty if lo else 0.0),
            "avg_quality": np.float32(unc.get("avg_quality", 0.0)),
            "avg_uncertainty": np.float32(unc.get("avg_uncertainty", 0.0)),
            # Clamped to stay inside Discrete(20) / Discrete(_MAX_OUTPUTS+1).
            "n_violations": min(len(obs.rule_violations), 19),
            "n_outputs": min(len(obs.all_outputs), _MAX_OUTPUTS),
            "cumulative_reward": np.float32(
                obs.metadata.get("cumulative_reward", 0.0)
                if obs.metadata else 0.0
            ),
        }

    def _info(self, obs: ExperimentObservation) -> Dict[str, Any]:
        """Build the Gymnasium ``info`` dict, exposing the raw observation."""
        return {
            "structured_obs": obs,
            "episode_id": obs.metadata.get("episode_id") if obs.metadata else None,
        }