# bio-experiment / server / hackathon_environment.py
# Ev3Dev's picture
# Upload folder using huggingface_hub
# db03c40 verified
"""Bio-Experiment Planning Environment.
Implements the OpenEnv ``Environment`` interface as a POMDP where the
agent proposes one structured experiment / analysis step at a time and
receives simulated intermediate outputs from a latent biological world.
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
from models import (
ActionType,
ConclusionClaim,
ExperimentAction,
ExperimentObservation,
IntermediateOutput,
PipelineStepRecord,
ResourceUsage,
TaskSpec,
)
from server.rules.engine import RuleEngine
from server.rewards.reward import RewardBreakdown, RewardComputer
from server.simulator.latent_state import FullLatentState
from server.simulator.noise import NoiseModel
from server.simulator.transition import ACTION_COSTS, TransitionEngine, compute_action_cost
from server.tasks.generator import TaskGenerator
# Hard episode-length cap: step() forces `done` once step_count reaches this.
MAX_STEPS = 30
class BioExperimentEnvironment(Environment):
    """POMDP environment for iterative biological experiment planning.

    The agent observes ``ExperimentObservation`` (partial view) while the
    environment maintains a ``FullLatentState`` (hidden ground truth).
    An episode proceeds as: ``reset()`` draws a task and latent world from
    ``TaskGenerator``; each ``step()`` checks the action against the rule
    engine, advances the latent state through the transition engine,
    accumulates step (and, when done, terminal) rewards, and returns a
    fresh observation.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(
        self,
        scenario_name: Optional[str] = None,
        *,
        domain_randomise: bool = True,
    ) -> None:
        """Initialise engines and empty per-episode buffers.

        Args:
            scenario_name: Fixed scenario to use on every ``reset()``;
                ``None`` lets the task generator choose.
            domain_randomise: Forwarded to ``TaskGenerator``.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._latent: Optional[FullLatentState] = None
        self._task: Optional[TaskSpec] = None
        self._scenario_name = scenario_name
        self._noise = NoiseModel()
        self._engine = TransitionEngine(self._noise)
        self._rules = RuleEngine()
        self._rewards = RewardComputer()
        self._task_gen = TaskGenerator(domain_randomise=domain_randomise)
        # Per-episode buffers; all cleared on every reset().
        self._history: List[PipelineStepRecord] = []
        self._outputs: List[IntermediateOutput] = []
        self._conclusions: List[ConclusionClaim] = []
        self._subagent_outputs: List[Dict[str, Any]] = []
        self._discovered_markers: List[str] = []
        self._candidate_mechanisms: List[str] = []
        self._cumulative_reward: float = 0.0

    # ── Environment interface ───────────────────────────────────────────
    def reset(self, seed: Optional[int] = None) -> ExperimentObservation:
        """Start a new episode and return the initial observation.

        Args:
            seed: RNG seed for the noise model and task generator; a
                random 31-bit seed is drawn when ``None``.
        """
        seed = seed if seed is not None else hash(uuid4()) % (2**31)
        self._noise.reseed(seed)
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._task, self._latent = self._task_gen.generate(
            seed=seed,
            scenario_name=self._scenario_name,
        )
        self._latent.rng_seed = seed
        self._history.clear()
        self._outputs.clear()
        self._conclusions.clear()
        self._subagent_outputs.clear()
        self._discovered_markers.clear()
        self._candidate_mechanisms.clear()
        self._cumulative_reward = 0.0
        return self._build_observation(reward=0.0, done=False)

    def step(  # type: ignore[override]
        self, action: ExperimentAction
    ) -> ExperimentObservation:
        """Apply one experiment / analysis action and return the observation.

        Args:
            action: Structured action proposed by the agent.

        Raises:
            RuntimeError: If called before ``reset()``.
        """
        # Explicit guard instead of assert: still enforced under ``python -O``.
        if self._latent is None or self._task is None:
            raise RuntimeError("Call reset() before step()")
        self._state.step_count += 1
        # Snapshot the pre-transition state so the step reward can compare
        # before/after latent worlds.
        prev_state = self._latent.model_copy(deep=True)
        violations = self._rules.check(action, self._latent)
        hard_v = self._rules.hard_violations(violations)
        soft_v = self._rules.soft_violations(violations)
        result = self._engine.step(
            self._latent,
            action,
            hard_violations=hard_v,
            soft_violations=soft_v,
        )
        self._latent = result.next_state
        step_rb = self._rewards.step_reward(
            action, prev_state, self._latent, result.output, hard_v, soft_v,
        )
        cost_budget, cost_time = compute_action_cost(action)
        self._history.append(PipelineStepRecord(
            step_index=self._state.step_count,
            action_type=action.action_type,
            method=action.method,
            parameters=action.parameters,
            output_summary=result.output.summary,
            output_type=result.output.output_type,
            success=result.output.success,
            quality_score=result.output.quality_score,
            resource_cost=cost_budget,
            time_cost_days=cost_time,
        ))
        self._outputs.append(result.output)
        self._update_discoveries(action, result.output)
        # A successful SYNTHESIZE_CONCLUSION registers its structured claims.
        if action.action_type == ActionType.SYNTHESIZE_CONCLUSION and result.output.success:
            raw_claims = action.parameters.get("claims", [])
            for c in raw_claims:
                if isinstance(c, dict):
                    self._conclusions.append(ConclusionClaim(**c))
        done = result.done or self._state.step_count >= MAX_STEPS
        terminal_rb = RewardBreakdown()
        if done:
            terminal_rb = self._rewards.terminal_reward(
                self._latent,
                self._conclusions,
                self._task.success_criteria,
                discovered_markers=self._discovered_markers,
                candidate_mechanisms=self._candidate_mechanisms,
            )
        total_reward = step_rb.total + terminal_rb.total
        self._cumulative_reward += total_reward
        breakdown = step_rb.to_dict()
        # Terminal components are namespaced with "term_" so they cannot
        # collide with the step-reward keys.
        breakdown.update({f"term_{k}": v for k, v in terminal_rb.to_dict().items()})
        return self._build_observation(
            reward=total_reward,
            done=done,
            latest_output=result.output,
            rule_violations=hard_v + soft_v,
            reward_breakdown=breakdown,
            metadata_extra={"reward_breakdown": breakdown},
        )

    @property
    def state(self) -> State:
        """Current session state (episode id + step counter)."""
        return self._state

    def set_scenario(self, scenario_name: Optional[str]) -> None:
        """Set the scenario used on the next reset."""
        self._scenario_name = scenario_name

    # ── internal helpers ────────────────────────────────────────────────
    def _build_observation(
        self,
        *,
        reward: float,
        done: bool,
        latest_output: Optional[IntermediateOutput] = None,
        rule_violations: Optional[List[str]] = None,
        reward_breakdown: Optional[Dict[str, float]] = None,
        metadata_extra: Optional[Dict[str, Any]] = None,
    ) -> ExperimentObservation:
        """Assemble the agent-visible observation from internal state.

        Copies all list-valued fields so later internal mutation cannot
        leak into a previously returned observation.
        """
        assert self._task is not None
        assert self._latent is not None
        res = self._latent.resources
        meta: Dict[str, Any] = {
            "episode_id": self._state.episode_id,
            "step": self._state.step_count,
            "cumulative_reward": self._cumulative_reward,
        }
        if metadata_extra:
            meta.update(metadata_extra)
        return ExperimentObservation(
            task=self._task,
            step_index=self._state.step_count,
            pipeline_history=list(self._history),
            available_assays=list(self._task.available_assays),
            available_tools=list(self._task.available_tools),
            resource_usage=ResourceUsage(
                budget_used=res.budget_used,
                budget_remaining=res.budget_remaining,
                time_used_days=res.time_used_days,
                time_remaining_days=res.time_remaining_days,
                samples_consumed=res.samples_consumed,
                compute_hours_used=res.compute_hours_used,
            ),
            latest_output=latest_output,
            all_outputs=list(self._outputs),
            discovered_markers=list(self._discovered_markers),
            candidate_mechanisms=list(self._candidate_mechanisms),
            uncertainty_summary=self._compute_uncertainty_summary(),
            subagent_outputs=list(self._subagent_outputs),
            conclusions=list(self._conclusions),
            rule_violations=rule_violations or [],
            step_reward_breakdown=reward_breakdown or {},
            done=done,
            reward=reward,
            metadata=meta,
        )

    def _compute_uncertainty_summary(self) -> Dict[str, float]:
        """Average uncertainty/quality over the last (up to) five outputs."""
        if not self._outputs:
            return {}
        recent = self._outputs[-5:]
        avg_unc = sum(o.uncertainty for o in recent) / len(recent)
        avg_qual = sum(o.quality_score for o in recent) / len(recent)
        return {"avg_uncertainty": avg_unc, "avg_quality": avg_qual}

    @staticmethod
    def _merge_unique(target: List[str], items: List[str]) -> None:
        """Append ``items`` to ``target`` in order, skipping duplicates."""
        seen = set(target)
        for item in items:
            if item not in seen:
                target.append(item)
                seen.add(item)

    def _update_discoveries(
        self, action: ExperimentAction, output: IntermediateOutput
    ) -> None:
        """Fold newly discovered markers / mechanisms into episode state.

        The branches are mutually exclusive on ``action_type``. Pathway
        entries lacking a "pathway" key are skipped rather than raising
        KeyError (the isinstance guard already anticipates malformed data).
        """
        if action.action_type == ActionType.MARKER_SELECTION:
            self._merge_unique(
                self._discovered_markers, output.data.get("markers", [])
            )
        elif action.action_type == ActionType.REGULATORY_NETWORK_INFERENCE:
            self._merge_unique(
                self._candidate_mechanisms, output.data.get("top_regulators", [])
            )
        elif action.action_type == ActionType.PATHWAY_ENRICHMENT:
            names = [
                p["pathway"]
                for p in output.data.get("top_pathways", [])
                if isinstance(p, dict) and "pathway" in p
            ]
            self._merge_unique(self._candidate_mechanisms, names)