# Hugging Face Spaces page banner captured with this file: "Running on CPU Upgrade".
"""Bio-Experiment Planning Environment.

Implements the OpenEnv ``Environment`` interface as a POMDP where the
agent proposes one structured experiment / analysis step at a time and
receives simulated intermediate outputs from a latent biological world.
"""
from __future__ import annotations

from typing import Any, Dict, Iterable, List, Optional
from uuid import uuid4

from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State

from models import (
    ActionType,
    ConclusionClaim,
    ExperimentAction,
    ExperimentObservation,
    IntermediateOutput,
    PipelineStepRecord,
    ResourceUsage,
    TaskSpec,
)
from server.rules.engine import RuleEngine
from server.rewards.reward import RewardBreakdown, RewardComputer
from server.simulator.latent_state import FullLatentState
from server.simulator.noise import NoiseModel
from server.simulator.transition import ACTION_COSTS, TransitionEngine, compute_action_cost
from server.tasks.generator import TaskGenerator
# Hard cap on steps per episode: ``step()`` marks the episode done once the
# step counter reaches this value, even if the simulator has not finished.
MAX_STEPS = 30
class BioExperimentEnvironment(Environment):
    """POMDP environment for iterative biological experiment planning.

    The agent observes ``ExperimentObservation`` (partial view) while the
    environment maintains a ``FullLatentState`` (hidden ground truth).

    Per-episode bookkeeping (pipeline history, outputs, conclusions,
    discovered markers, candidate mechanisms, cumulative reward) is stored on
    the instance and cleared by ``reset()``.
    """

    # NOTE(review): presumably read by the OpenEnv server to decide whether
    # separate sessions may each get their own instance -- confirm.
    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(
        self,
        scenario_name: Optional[str] = None,
        *,
        domain_randomise: bool = True,
    ) -> None:
        """Create the environment and its collaborators.

        No task exists until ``reset()`` is called; ``step()`` refuses to run
        before that.

        Args:
            scenario_name: Scenario name handed to the task generator on each
                ``reset()``. ``None`` is passed through unchanged.
            domain_randomise: Forwarded to ``TaskGenerator``.
        """
        # Placeholder episode state; replaced by a fresh one on reset().
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._latent: Optional[FullLatentState] = None
        self._task: Optional[TaskSpec] = None
        self._scenario_name = scenario_name

        # Collaborators. The transition engine shares this noise model, so
        # reseeding it in reset() also reseeds the simulated transitions.
        self._noise = NoiseModel()
        self._engine = TransitionEngine(self._noise)
        self._rules = RuleEngine()
        self._rewards = RewardComputer()
        self._task_gen = TaskGenerator(domain_randomise=domain_randomise)

        # Per-episode accumulators (cleared in reset()).
        self._history: List[PipelineStepRecord] = []
        self._outputs: List[IntermediateOutput] = []
        self._conclusions: List[ConclusionClaim] = []
        self._subagent_outputs: List[Dict[str, Any]] = []
        self._discovered_markers: List[str] = []
        self._candidate_mechanisms: List[str] = []
        self._cumulative_reward: float = 0.0

    # -- Environment interface ---------------------------------------------

    def reset(self, seed: Optional[int] = None) -> ExperimentObservation:
        """Start a new episode and return its initial observation.

        Args:
            seed: Seed used for the noise model and for task generation. When
                ``None``, a seed in ``[0, 2**31)`` is derived from a random
                UUID.

        Returns:
            The first observation of the episode, with ``reward=0.0`` and
            ``done=False``.
        """
        if seed is None:
            seed = hash(uuid4()) % (2**31)

        self._noise.reseed(seed)
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._task, self._latent = self._task_gen.generate(
            seed=seed,
            scenario_name=self._scenario_name,
        )
        # Record the seed on the latent state itself.
        self._latent.rng_seed = seed

        # Clear in place so the list objects created in __init__ are reused.
        self._history.clear()
        self._outputs.clear()
        self._conclusions.clear()
        self._subagent_outputs.clear()
        self._discovered_markers.clear()
        self._candidate_mechanisms.clear()
        self._cumulative_reward = 0.0

        return self._build_observation(reward=0.0, done=False)

    def step(  # type: ignore[override]
        self, action: ExperimentAction
    ) -> ExperimentObservation:
        """Apply one agent action and return the resulting observation.

        The action is checked against the rule engine, executed by the
        transition engine, scored, and appended to the pipeline history. The
        episode ends when the transition engine reports ``done`` or when the
        step counter reaches ``MAX_STEPS``; on that final step the terminal
        reward is added to the step reward.

        Args:
            action: The experiment / analysis step proposed by the agent.

        Returns:
            The observation after the action. Its reward is the step reward
            plus, on the final step, the terminal reward.

        Raises:
            RuntimeError: If called before ``reset()``.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if self._latent is None or self._task is None:
            raise RuntimeError("Call reset() before step()")

        self._state.step_count += 1

        # Snapshot the pre-action latent state for the reward computation.
        prev_state = self._latent.model_copy(deep=True)

        violations = self._rules.check(action, self._latent)
        hard_v = self._rules.hard_violations(violations)
        soft_v = self._rules.soft_violations(violations)

        result = self._engine.step(
            self._latent,
            action,
            hard_violations=hard_v,
            soft_violations=soft_v,
        )
        self._latent = result.next_state

        step_rb = self._rewards.step_reward(
            action, prev_state, self._latent, result.output, hard_v, soft_v,
        )

        cost_budget, cost_time = compute_action_cost(action)
        self._history.append(PipelineStepRecord(
            step_index=self._state.step_count,
            action_type=action.action_type,
            method=action.method,
            parameters=action.parameters,
            output_summary=result.output.summary,
            output_type=result.output.output_type,
            success=result.output.success,
            quality_score=result.output.quality_score,
            resource_cost=cost_budget,
            time_cost_days=cost_time,
        ))
        self._outputs.append(result.output)
        self._update_discoveries(action, result.output)

        # A successful conclusion step records its claims; entries that are
        # not dicts are ignored.
        # NOTE(review): a malformed claim dict presumably makes
        # ``ConclusionClaim(**c)`` raise -- confirm whether that should abort
        # the step.
        if action.action_type == ActionType.SYNTHESIZE_CONCLUSION and result.output.success:
            raw_claims = action.parameters.get("claims", [])
            for c in raw_claims:
                if isinstance(c, dict):
                    self._conclusions.append(ConclusionClaim(**c))

        done = result.done or self._state.step_count >= MAX_STEPS

        # Default (empty) breakdown for non-terminal steps.
        terminal_rb = RewardBreakdown()
        if done:
            terminal_rb = self._rewards.terminal_reward(
                self._latent,
                self._conclusions,
                self._task.success_criteria,
                discovered_markers=self._discovered_markers,
                candidate_mechanisms=self._candidate_mechanisms,
            )

        total_reward = step_rb.total + terminal_rb.total
        self._cumulative_reward += total_reward

        # Terminal components are prefixed with ``term_`` so they cannot
        # collide with the step components in the merged breakdown.
        breakdown = step_rb.to_dict()
        breakdown.update({f"term_{k}": v for k, v in terminal_rb.to_dict().items()})

        return self._build_observation(
            reward=total_reward,
            done=done,
            latest_output=result.output,
            rule_violations=hard_v + soft_v,
            reward_breakdown=breakdown,
            metadata_extra={"reward_breakdown": breakdown},
        )

    # NOTE(review): OpenEnv's ``Environment`` base class appears to declare
    # ``state`` as a property; this is a plain method, so ``env.state``
    # yields a bound method. Left as-is to keep the current interface --
    # confirm against the installed OpenEnv version.
    def state(self) -> State:
        """Return the current episode state (episode id and step count)."""
        return self._state

    def set_scenario(self, scenario_name: Optional[str]) -> None:
        """Set the scenario used on the next reset."""
        self._scenario_name = scenario_name

    # -- internal helpers ----------------------------------------------------

    def _build_observation(
        self,
        *,
        reward: float,
        done: bool,
        latest_output: Optional[IntermediateOutput] = None,
        rule_violations: Optional[List[str]] = None,
        reward_breakdown: Optional[Dict[str, float]] = None,
        metadata_extra: Optional[Dict[str, Any]] = None,
    ) -> ExperimentObservation:
        """Assemble the agent-visible observation from the current state.

        List-valued fields are copied, so later changes to the environment's
        internal lists do not alter an observation already handed out.

        Args:
            reward: Reward to report for this step.
            done: Whether the episode has ended.
            latest_output: Output of the step just taken, if any.
            rule_violations: Violations for this step; ``None`` becomes ``[]``.
            reward_breakdown: Per-component rewards; ``None`` becomes ``{}``.
            metadata_extra: Extra entries merged into the base metadata
                (episode id, step, cumulative reward), overriding on key
                clashes.

        Returns:
            The populated ``ExperimentObservation``.
        """
        # Internal invariant: callers only reach this after reset().
        assert self._task is not None
        assert self._latent is not None

        res = self._latent.resources
        meta: Dict[str, Any] = {
            "episode_id": self._state.episode_id,
            "step": self._state.step_count,
            "cumulative_reward": self._cumulative_reward,
        }
        if metadata_extra:
            meta.update(metadata_extra)

        return ExperimentObservation(
            task=self._task,
            step_index=self._state.step_count,
            pipeline_history=list(self._history),
            available_assays=list(self._task.available_assays),
            available_tools=list(self._task.available_tools),
            resource_usage=ResourceUsage(
                budget_used=res.budget_used,
                budget_remaining=res.budget_remaining,
                time_used_days=res.time_used_days,
                time_remaining_days=res.time_remaining_days,
                samples_consumed=res.samples_consumed,
                compute_hours_used=res.compute_hours_used,
            ),
            latest_output=latest_output,
            all_outputs=list(self._outputs),
            discovered_markers=list(self._discovered_markers),
            candidate_mechanisms=list(self._candidate_mechanisms),
            uncertainty_summary=self._compute_uncertainty_summary(),
            subagent_outputs=list(self._subagent_outputs),
            conclusions=list(self._conclusions),
            rule_violations=rule_violations or [],
            step_reward_breakdown=reward_breakdown or {},
            done=done,
            reward=reward,
            metadata=meta,
        )

    def _compute_uncertainty_summary(self) -> Dict[str, float]:
        """Average uncertainty and quality over the five most recent outputs.

        Returns:
            ``{}`` when no outputs exist yet; otherwise a dict with
            ``avg_uncertainty`` and ``avg_quality``.
        """
        if not self._outputs:
            return {}
        recent = self._outputs[-5:]
        avg_unc = sum(o.uncertainty for o in recent) / len(recent)
        avg_qual = sum(o.quality_score for o in recent) / len(recent)
        return {"avg_uncertainty": avg_unc, "avg_quality": avg_qual}

    @staticmethod
    def _append_unique(target: List[str], candidates: Iterable[Any]) -> None:
        """Append to ``target`` each candidate it lacks, keeping first-seen order."""
        seen = set(target)
        for item in candidates:
            if item not in seen:
                target.append(item)
                seen.add(item)

    def _update_discoveries(
        self, action: ExperimentAction, output: IntermediateOutput
    ) -> None:
        """Fold findings from ``output`` into the running discovery lists.

        * ``MARKER_SELECTION`` adds ``output.data["markers"]`` to the
          discovered markers.
        * ``REGULATORY_NETWORK_INFERENCE`` adds
          ``output.data["top_regulators"]`` to the candidate mechanisms.
        * ``PATHWAY_ENRICHMENT`` adds the ``"pathway"`` value of each dict in
          ``output.data["top_pathways"]`` to the candidate mechanisms;
          entries that are not dicts, or that lack a ``"pathway"`` key, are
          skipped.

        Values already present are not added again. Other action types leave
        both lists unchanged.
        """
        kind = action.action_type
        if kind == ActionType.MARKER_SELECTION:
            self._append_unique(
                self._discovered_markers, output.data.get("markers", [])
            )
        elif kind == ActionType.REGULATORY_NETWORK_INFERENCE:
            self._append_unique(
                self._candidate_mechanisms, output.data.get("top_regulators", [])
            )
        elif kind == ActionType.PATHWAY_ENRICHMENT:
            # Skip entries without a "pathway" key rather than raising
            # KeyError on a malformed simulator output.
            pathway_names = [
                p["pathway"]
                for p in output.data.get("top_pathways", [])
                if isinstance(p, dict) and "pathway" in p
            ]
            self._append_unique(self._candidate_mechanisms, pathway_names)