| | """Bio-Experiment Planning Environment.
|
| |
|
| | Implements the OpenEnv ``Environment`` interface as a POMDP where the
|
| | agent proposes one structured experiment / analysis step at a time and
|
| | receives simulated intermediate outputs from a latent biological world.
|
| | """
|
| |
|
| | from __future__ import annotations
|
| |
|
| | from typing import Any, Dict, List, Optional
|
| | from uuid import uuid4
|
| |
|
| | from openenv.core.env_server.interfaces import Environment
|
| | from openenv.core.env_server.types import State
|
| |
|
| | from models import (
|
| | ActionType,
|
| | ConclusionClaim,
|
| | ExperimentAction,
|
| | ExperimentObservation,
|
| | IntermediateOutput,
|
| | PipelineStepRecord,
|
| | ResourceUsage,
|
| | TaskSpec,
|
| | )
|
| |
|
| | from server.rules.engine import RuleEngine
|
| | from server.rewards.reward import RewardBreakdown, RewardComputer
|
| | from server.simulator.latent_state import FullLatentState
|
| | from server.simulator.noise import NoiseModel
|
| | from server.simulator.transition import ACTION_COSTS, TransitionEngine, compute_action_cost
|
| | from server.tasks.generator import TaskGenerator
|
| |
|
| |
|
# Hard cap on agent actions per episode; step() forces `done` once the
# episode's step_count reaches this value.
MAX_STEPS = 30
|
class BioExperimentEnvironment(Environment):
    """POMDP environment for iterative biological experiment planning.

    The agent observes ``ExperimentObservation`` (partial view) while the
    environment maintains a ``FullLatentState`` (hidden ground truth).
    """

    # Each instance owns its entire episode state, so independent instances
    # may serve concurrent sessions.
    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(
        self,
        scenario_name: Optional[str] = None,
        *,
        domain_randomise: bool = True,
    ) -> None:
        """Create a fresh environment.

        Args:
            scenario_name: Fixed scenario to use on every :meth:`reset`;
                ``None`` lets the task generator choose.
            domain_randomise: Forwarded to :class:`TaskGenerator`.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._latent: Optional[FullLatentState] = None
        self._task: Optional[TaskSpec] = None
        self._scenario_name = scenario_name
        self._noise = NoiseModel()
        self._engine = TransitionEngine(self._noise)
        self._rules = RuleEngine()
        self._rewards = RewardComputer()
        self._task_gen = TaskGenerator(domain_randomise=domain_randomise)

        # Per-episode accumulators; all cleared by reset().
        self._history: List[PipelineStepRecord] = []
        self._outputs: List[IntermediateOutput] = []
        self._conclusions: List[ConclusionClaim] = []
        self._subagent_outputs: List[Dict[str, Any]] = []
        self._discovered_markers: List[str] = []
        self._candidate_mechanisms: List[str] = []
        self._cumulative_reward: float = 0.0

    def reset(self, seed: Optional[int] = None) -> ExperimentObservation:
        """Start a new episode and return the initial observation.

        Args:
            seed: RNG seed for the noise model and task generator; a
                pseudo-random seed is drawn when ``None``.
        """
        seed = seed if seed is not None else hash(uuid4()) % (2**31)
        self._noise.reseed(seed)
        self._state = State(episode_id=str(uuid4()), step_count=0)

        self._task, self._latent = self._task_gen.generate(
            seed=seed,
            scenario_name=self._scenario_name,
        )
        self._latent.rng_seed = seed

        self._history.clear()
        self._outputs.clear()
        self._conclusions.clear()
        self._subagent_outputs.clear()
        self._discovered_markers.clear()
        self._candidate_mechanisms.clear()
        self._cumulative_reward = 0.0

        return self._build_observation(reward=0.0, done=False)

    def step(
        self, action: ExperimentAction
    ) -> ExperimentObservation:
        """Apply one agent action and return the resulting observation.

        Checks rules against the pre-transition state, advances the latent
        world, records the step, accumulates step + (if terminal) terminal
        reward, and terminates once ``MAX_STEPS`` is reached.
        """
        assert self._latent is not None, "Call reset() before step()"
        assert self._task is not None

        self._state.step_count += 1
        # Snapshot for reward shaping: step_reward compares prev vs. next.
        prev_state = self._latent.model_copy(deep=True)

        violations = self._rules.check(action, self._latent)
        hard_v = self._rules.hard_violations(violations)
        soft_v = self._rules.soft_violations(violations)

        result = self._engine.step(
            self._latent,
            action,
            hard_violations=hard_v,
            soft_violations=soft_v,
        )
        self._latent = result.next_state

        step_rb = self._rewards.step_reward(
            action, prev_state, self._latent, result.output, hard_v, soft_v,
        )

        cost_budget, cost_time = compute_action_cost(action)
        self._history.append(PipelineStepRecord(
            step_index=self._state.step_count,
            action_type=action.action_type,
            method=action.method,
            parameters=action.parameters,
            output_summary=result.output.summary,
            output_type=result.output.output_type,
            success=result.output.success,
            quality_score=result.output.quality_score,
            resource_cost=cost_budget,
            time_cost_days=cost_time,
        ))
        self._outputs.append(result.output)
        self._update_discoveries(action, result.output)

        # Conclusions are only recorded when the synthesis action succeeded;
        # non-dict claim entries are silently skipped.
        if action.action_type == ActionType.SYNTHESIZE_CONCLUSION and result.output.success:
            raw_claims = action.parameters.get("claims", [])
            for c in raw_claims:
                if isinstance(c, dict):
                    self._conclusions.append(ConclusionClaim(**c))

        done = result.done or self._state.step_count >= MAX_STEPS

        terminal_rb = RewardBreakdown()
        if done:
            terminal_rb = self._rewards.terminal_reward(
                self._latent,
                self._conclusions,
                self._task.success_criteria,
                discovered_markers=self._discovered_markers,
                candidate_mechanisms=self._candidate_mechanisms,
            )

        total_reward = step_rb.total + terminal_rb.total
        self._cumulative_reward += total_reward

        # Merge step and terminal components; terminal keys are prefixed so
        # they never collide with step keys.
        breakdown = step_rb.to_dict()
        breakdown.update({f"term_{k}": v for k, v in terminal_rb.to_dict().items()})

        return self._build_observation(
            reward=total_reward,
            done=done,
            latest_output=result.output,
            rule_violations=hard_v + soft_v,
            reward_breakdown=breakdown,
            metadata_extra={"reward_breakdown": breakdown},
        )

    @property
    def state(self) -> State:
        """Current episode bookkeeping (episode id and step count)."""
        return self._state

    def set_scenario(self, scenario_name: Optional[str]) -> None:
        """Set the scenario used on the next reset."""

        self._scenario_name = scenario_name

    def _build_observation(
        self,
        *,
        reward: float,
        done: bool,
        latest_output: Optional[IntermediateOutput] = None,
        rule_violations: Optional[List[str]] = None,
        reward_breakdown: Optional[Dict[str, float]] = None,
        metadata_extra: Optional[Dict[str, Any]] = None,
    ) -> ExperimentObservation:
        """Assemble the agent-visible observation from internal state.

        Lists are shallow-copied so later internal mutation cannot leak
        into an observation already handed to the agent.
        """
        assert self._task is not None
        assert self._latent is not None
        res = self._latent.resources
        meta: Dict[str, Any] = {
            "episode_id": self._state.episode_id,
            "step": self._state.step_count,
            "cumulative_reward": self._cumulative_reward,
        }
        if metadata_extra:
            meta.update(metadata_extra)
        return ExperimentObservation(
            task=self._task,
            step_index=self._state.step_count,
            pipeline_history=list(self._history),
            available_assays=list(self._task.available_assays),
            available_tools=list(self._task.available_tools),
            resource_usage=ResourceUsage(
                budget_used=res.budget_used,
                budget_remaining=res.budget_remaining,
                time_used_days=res.time_used_days,
                time_remaining_days=res.time_remaining_days,
                samples_consumed=res.samples_consumed,
                compute_hours_used=res.compute_hours_used,
            ),
            latest_output=latest_output,
            all_outputs=list(self._outputs),
            discovered_markers=list(self._discovered_markers),
            candidate_mechanisms=list(self._candidate_mechanisms),
            uncertainty_summary=self._compute_uncertainty_summary(),
            subagent_outputs=list(self._subagent_outputs),
            conclusions=list(self._conclusions),
            rule_violations=rule_violations or [],
            step_reward_breakdown=reward_breakdown or {},
            done=done,
            reward=reward,
            metadata=meta,
        )

    def _compute_uncertainty_summary(self) -> Dict[str, float]:
        """Average uncertainty/quality over the last (up to) 5 outputs.

        Returns an empty dict before any output exists.
        """
        if not self._outputs:
            return {}
        recent = self._outputs[-5:]
        avg_unc = sum(o.uncertainty for o in recent) / len(recent)
        avg_qual = sum(o.quality_score for o in recent) / len(recent)
        return {"avg_uncertainty": avg_unc, "avg_quality": avg_qual}

    def _update_discoveries(
        self, action: ExperimentAction, output: IntermediateOutput
    ) -> None:
        """Accumulate newly reported markers / candidate mechanisms.

        First-seen order is preserved and duplicates are ignored. The three
        action types are mutually exclusive, so the branches are ``elif``.
        """
        if action.action_type == ActionType.MARKER_SELECTION:
            self._merge_unique(
                self._discovered_markers, output.data.get("markers", [])
            )
        elif action.action_type == ActionType.REGULATORY_NETWORK_INFERENCE:
            self._merge_unique(
                self._candidate_mechanisms, output.data.get("top_regulators", [])
            )
        elif action.action_type == ActionType.PATHWAY_ENRICHMENT:
            # Fix: look the name up once with .get() and skip entries that
            # are not dicts or lack the key — the original indexed
            # p["pathway"] repeatedly and raised KeyError on malformed data.
            names = (
                p.get("pathway")
                for p in output.data.get("top_pathways", [])
                if isinstance(p, dict)
            )
            self._merge_unique(
                self._candidate_mechanisms, [n for n in names if n is not None]
            )

    @staticmethod
    def _merge_unique(target: List[str], items: Any) -> None:
        """Append elements of *items* not already in *target*, in order."""
        seen = set(target)
        for item in items:
            if item not in seen:
                target.append(item)
                seen.add(item)
|