"""Bio-Experiment Planning Environment. Implements the OpenEnv ``Environment`` interface as a POMDP where the agent proposes one structured experiment / analysis step at a time and receives simulated intermediate outputs from a latent biological world. """ from __future__ import annotations from typing import Any, Dict, List, Optional from uuid import uuid4 from openenv.core.env_server.interfaces import Environment from openenv.core.env_server.types import State from models import ( ActionType, ConclusionClaim, ExperimentAction, ExperimentObservation, IntermediateOutput, PipelineStepRecord, ResourceUsage, TaskSpec, ) from server.rules.engine import RuleEngine from server.rewards.reward import RewardBreakdown, RewardComputer from server.simulator.latent_state import FullLatentState from server.simulator.noise import NoiseModel from server.simulator.transition import ACTION_COSTS, TransitionEngine, compute_action_cost from server.tasks.generator import TaskGenerator MAX_STEPS = 30 class BioExperimentEnvironment(Environment): """POMDP environment for iterative biological experiment planning. The agent observes ``ExperimentObservation`` (partial view) while the environment maintains a ``FullLatentState`` (hidden ground truth). """ SUPPORTS_CONCURRENT_SESSIONS: bool = True def __init__( self, scenario_name: Optional[str] = None, *, domain_randomise: bool = True, ) -> None: self._state = State(episode_id=str(uuid4()), step_count=0) self._latent: Optional[FullLatentState] = None self._task: Optional[TaskSpec] = None self._scenario_name = scenario_name self._noise = NoiseModel() self._engine = TransitionEngine(self._noise) self._rules = RuleEngine() self._rewards = RewardComputer() self._task_gen = TaskGenerator(domain_randomise=domain_randomise) self._history: List[PipelineStepRecord] = [] self._outputs: List[IntermediateOutput] = [] self._conclusions: List[ConclusionClaim] = [] self._subagent_outputs: List[Dict[str, Any]] = [] self._discovered_markers: List[str] = [] self._candidate_mechanisms: List[str] = [] self._cumulative_reward: float = 0.0 # ── Environment interface ─────────────────────────────────────────── def reset(self, seed: Optional[int] = None) -> ExperimentObservation: seed = seed if seed is not None else hash(uuid4()) % (2**31) self._noise.reseed(seed) self._state = State(episode_id=str(uuid4()), step_count=0) self._task, self._latent = self._task_gen.generate( seed=seed, scenario_name=self._scenario_name, ) self._latent.rng_seed = seed self._history.clear() self._outputs.clear() self._conclusions.clear() self._subagent_outputs.clear() self._discovered_markers.clear() self._candidate_mechanisms.clear() self._cumulative_reward = 0.0 return self._build_observation(reward=0.0, done=False) def step( # type: ignore[override] self, action: ExperimentAction ) -> ExperimentObservation: assert self._latent is not None, "Call reset() before step()" assert self._task is not None self._state.step_count += 1 prev_state = self._latent.model_copy(deep=True) violations = self._rules.check(action, self._latent) hard_v = self._rules.hard_violations(violations) soft_v = self._rules.soft_violations(violations) result = self._engine.step( self._latent, action, hard_violations=hard_v, soft_violations=soft_v, ) self._latent = result.next_state step_rb = self._rewards.step_reward( action, prev_state, self._latent, result.output, hard_v, soft_v, ) cost_budget, cost_time = compute_action_cost(action) self._history.append(PipelineStepRecord( step_index=self._state.step_count, action_type=action.action_type, method=action.method, parameters=action.parameters, output_summary=result.output.summary, output_type=result.output.output_type, success=result.output.success, quality_score=result.output.quality_score, resource_cost=cost_budget, time_cost_days=cost_time, )) self._outputs.append(result.output) self._update_discoveries(action, result.output) if action.action_type == ActionType.SYNTHESIZE_CONCLUSION: raw_claims = action.parameters.get("claims", []) for c in raw_claims: if isinstance(c, dict): self._conclusions.append(ConclusionClaim(**c)) done = result.done or self._state.step_count >= MAX_STEPS terminal_rb = RewardBreakdown() if done: terminal_rb = self._rewards.terminal_reward( self._latent, self._conclusions, self._task.success_criteria, discovered_markers=self._discovered_markers, candidate_mechanisms=self._candidate_mechanisms, ) total_reward = step_rb.total + terminal_rb.total self._cumulative_reward += total_reward breakdown = step_rb.to_dict() breakdown.update({f"term_{k}": v for k, v in terminal_rb.to_dict().items()}) return self._build_observation( reward=total_reward, done=done, latest_output=result.output, rule_violations=hard_v + soft_v, reward_breakdown=breakdown, metadata_extra={"reward_breakdown": breakdown}, ) @property def state(self) -> State: return self._state def set_scenario(self, scenario_name: Optional[str]) -> None: """Set the scenario used on the next reset.""" self._scenario_name = scenario_name # ── internal helpers ──────────────────────────────────────────────── def _build_observation( self, *, reward: float, done: bool, latest_output: Optional[IntermediateOutput] = None, rule_violations: Optional[List[str]] = None, reward_breakdown: Optional[Dict[str, float]] = None, metadata_extra: Optional[Dict[str, Any]] = None, ) -> ExperimentObservation: assert self._task is not None assert self._latent is not None res = self._latent.resources meta: Dict[str, Any] = { "episode_id": self._state.episode_id, "step": self._state.step_count, "cumulative_reward": self._cumulative_reward, } if metadata_extra: meta.update(metadata_extra) return ExperimentObservation( task=self._task, step_index=self._state.step_count, pipeline_history=list(self._history), available_assays=list(self._task.available_assays), available_tools=list(self._task.available_tools), resource_usage=ResourceUsage( budget_used=res.budget_used, budget_remaining=res.budget_remaining, time_used_days=res.time_used_days, time_remaining_days=res.time_remaining_days, samples_consumed=res.samples_consumed, compute_hours_used=res.compute_hours_used, ), latest_output=latest_output, all_outputs=list(self._outputs), discovered_markers=list(self._discovered_markers), candidate_mechanisms=list(self._candidate_mechanisms), uncertainty_summary=self._compute_uncertainty_summary(), subagent_outputs=list(self._subagent_outputs), conclusions=list(self._conclusions), rule_violations=rule_violations or [], step_reward_breakdown={}, done=done, reward=reward, metadata=meta, ) def _compute_uncertainty_summary(self) -> Dict[str, float]: if not self._outputs: return {} recent = self._outputs[-5:] avg_unc = sum(o.uncertainty for o in recent) / len(recent) avg_qual = sum(o.quality_score for o in recent) / len(recent) return {"avg_uncertainty": avg_unc, "avg_quality": avg_qual} def _update_discoveries( self, action: ExperimentAction, output: IntermediateOutput ) -> None: if action.action_type == ActionType.MARKER_SELECTION: markers = output.data.get("markers", []) existing = set(self._discovered_markers) for m in markers: if m not in existing: self._discovered_markers.append(m) existing.add(m) if action.action_type == ActionType.REGULATORY_NETWORK_INFERENCE: regs = output.data.get("top_regulators", []) existing = set(self._candidate_mechanisms) for r in regs: if r not in existing: self._candidate_mechanisms.append(r) existing.add(r) if action.action_type == ActionType.PATHWAY_ENRICHMENT: pathways = output.data.get("top_pathways", []) existing = set(self._candidate_mechanisms) for p in pathways: if isinstance(p, dict) and p["pathway"] not in existing: self._candidate_mechanisms.append(p["pathway"]) existing.add(p["pathway"])