Ev3Dev's picture
Upload folder using huggingface_hub
5c3cfae verified
"""Transition dynamics engine — the heart of the biological simulator.
Orchestrates latent-state updates, output generation, resource accounting,
and constraint propagation for every agent action.
"""
from __future__ import annotations
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
from models import (
ActionType,
ExperimentAction,
IntermediateOutput,
OutputType,
TOOL_REGISTRY,
)
from .latent_state import FullLatentState
from .noise import NoiseModel
from .output_generator import OutputGenerator
# Default price list, keyed by ActionType: (budget cost, duration in days).
# ``compute_action_cost`` consults it only when the action's ``method`` is
# not a key of ``TOOL_REGISTRY``.
_BASE_ACTION_COSTS: Dict[ActionType, Tuple[float, float]] = {
    ActionType.COLLECT_SAMPLE: (5_000, 7.0),
    ActionType.SELECT_COHORT: (500, 1.0),
    ActionType.PREPARE_LIBRARY: (8_000, 3.0),
    ActionType.CULTURE_CELLS: (3_000, 14.0),
    ActionType.PERTURB_GENE: (2_000, 3.0),
    ActionType.PERTURB_COMPOUND: (1_000, 2.0),
    ActionType.SEQUENCE_CELLS: (15_000, 5.0),
    ActionType.RUN_QC: (100, 0.5),
    ActionType.FILTER_DATA: (50, 0.25),
    ActionType.NORMALIZE_DATA: (50, 0.25),
    ActionType.INTEGRATE_BATCHES: (100, 0.5),
    ActionType.CLUSTER_CELLS: (100, 0.5),
    ActionType.DIFFERENTIAL_EXPRESSION: (100, 0.5),
    ActionType.TRAJECTORY_ANALYSIS: (200, 1.0),
    ActionType.PATHWAY_ENRICHMENT: (100, 0.5),
    ActionType.REGULATORY_NETWORK_INFERENCE: (300, 1.0),
    ActionType.MARKER_SELECTION: (100, 0.5),
    ActionType.VALIDATE_MARKER: (5_000, 14.0),
    ActionType.DESIGN_FOLLOWUP: (100, 0.5),
    ActionType.REQUEST_SUBAGENT_REVIEW: (50, 0.25),
    ActionType.SYNTHESIZE_CONCLUSION: (0, 0.5),
}

# Public alias of the very same dict object, retained so that existing
# imports (e.g. hackathon_environment) still work.
ACTION_COSTS = _BASE_ACTION_COSTS
def compute_action_cost(action: ExperimentAction) -> Tuple[float, float]:
    """Look up what an action costs as ``(budget_cost, time_cost_days)``.

    When ``action.method`` (or ``""`` if it is falsy) is a key of
    ``TOOL_REGISTRY``, that tool's ``typical_cost_usd`` is the budget cost
    and its ``typical_runtime_hours`` divided by 24 is the time cost.
    Otherwise the per-``ActionType`` base table is used; an action type
    that is absent from the table costs ``(0.0, 0.0)``.
    """
    method_name = action.method or ""
    spec = TOOL_REGISTRY.get(method_name)

    # Unknown tool: fall back to the flat per-action-type price list.
    if spec is None:
        return _BASE_ACTION_COSTS.get(action.action_type, (0.0, 0.0))

    # Registry runtimes are expressed in hours; callers account in days.
    return (spec.typical_cost_usd, spec.typical_runtime_hours / 24.0)
@dataclass
class TransitionResult:
    """Everything the transition engine hands back after a single step.

    Only ``next_state`` and ``output`` are required; every other field has
    an empty / ``False`` default, and each instance gets its own fresh
    containers.
    """

    # Latent state after the step.
    next_state: FullLatentState
    # Simulated observation (or failure report) produced by the step.
    output: IntermediateOutput
    # Named reward terms; empty unless the creator of the result fills it.
    reward_components: Dict[str, float] = field(default_factory=dict)
    # Violation messages attached to the step, split by severity.
    hard_violations: List[str] = field(default_factory=list)
    soft_violations: List[str] = field(default_factory=list)
    # True when the step ends the episode.
    done: bool = False
class TransitionEngine:
    """Applies one action to the latent state, producing the next state
    and a simulated intermediate output.

    The engine delegates output generation to ``OutputGenerator`` and
    constraint checking to external rule engines; their findings are
    handed to :meth:`step` as ``hard_violations`` / ``soft_violations``.
    """

    def __init__(self, noise: NoiseModel):
        """Keep the noise model and build an ``OutputGenerator`` from it."""
        self.noise = noise
        self.output_gen = OutputGenerator(noise)

    def step(
        self,
        state: FullLatentState,
        action: ExperimentAction,
        *,
        hard_violations: Optional[List[str]] = None,
        soft_violations: Optional[List[str]] = None,
    ) -> TransitionResult:
        """Apply ``action`` to a deep copy of ``state`` and return the outcome.

        ``state`` itself is never mutated.  The copy's ``step_count`` is
        incremented on every call, including blocked and failed steps.

        Args:
            state: Latent state before the action.
            action: The action to simulate.
            hard_violations: Messages that block the action.  When
                non-empty, nothing is charged and a failure report is
                returned.
            soft_violations: Messages that do not block the action.  They
                halve the output's ``quality_score`` and are appended to
                its ``warnings``.

        Returns:
            A ``TransitionResult``.  ``done`` is true when resources ran
            out or the action is ``SYNTHESIZE_CONCLUSION``.  Every return
            path carries the soft violations that were passed in.
        """
        s = deepcopy(state)
        s.step_count += 1
        step_idx = s.step_count

        # Copy the lists so the result never aliases caller-owned objects;
        # later mutation on either side must not leak to the other.
        hard_v = list(hard_violations or ())
        soft_v = list(soft_violations or ())

        if hard_v:
            output = IntermediateOutput(
                output_type=OutputType.FAILURE_REPORT,
                step_index=step_idx,
                success=False,
                summary=f"Action blocked: {'; '.join(hard_v)}",
            )
            return TransitionResult(
                next_state=s,
                output=output,
                hard_violations=hard_v,
                soft_violations=soft_v,
            )

        # The cost is charged first; exhaustion is evaluated afterwards, so
        # the action that overshoots the budget is the one that fails.
        self._apply_resource_cost(s, action)
        if s.resources.budget_exhausted or s.resources.time_exhausted:
            output = IntermediateOutput(
                output_type=OutputType.FAILURE_REPORT,
                step_index=step_idx,
                success=False,
                summary="Resources exhausted",
            )
            return TransitionResult(
                next_state=s,
                output=output,
                done=True,
                hard_violations=["resources_exhausted"],
                # Forwarded like on every other return path, so soft
                # violations are not silently lost on the terminal step.
                soft_violations=soft_v,
            )

        self._update_progress(s, action)
        output = self.output_gen.generate(action, s, step_idx)

        if soft_v:
            output.quality_score *= 0.5
            output.warnings.extend(soft_v)

        self._propagate_artifacts(s, action, output)

        done = action.action_type == ActionType.SYNTHESIZE_CONCLUSION
        return TransitionResult(
            next_state=s,
            output=output,
            soft_violations=soft_v,
            done=done,
        )

    # ── internals ───────────────────────────────────────────────────────

    def _apply_resource_cost(
        self, s: FullLatentState, action: ExperimentAction
    ) -> None:
        """Charge the action's budget and time cost to ``s.resources``.

        The listed analysis action types additionally accrue compute
        hours at 8 hours per day of time cost.
        """
        budget_cost, time_cost = compute_action_cost(action)
        s.resources.budget_used += budget_cost
        s.resources.time_used_days += time_cost
        if action.action_type in {
            ActionType.RUN_QC, ActionType.FILTER_DATA,
            ActionType.NORMALIZE_DATA, ActionType.INTEGRATE_BATCHES,
            ActionType.CLUSTER_CELLS, ActionType.DIFFERENTIAL_EXPRESSION,
            ActionType.TRAJECTORY_ANALYSIS, ActionType.PATHWAY_ENRICHMENT,
            ActionType.REGULATORY_NETWORK_INFERENCE,
            ActionType.MARKER_SELECTION,
        }:
            s.resources.compute_hours_used += time_cost * 8

    def _update_progress(
        self, s: FullLatentState, action: ExperimentAction
    ) -> None:
        """Set the progress flag for the action and update derived counters.

        Besides flipping one boolean on ``s.progress``, some action types
        also adjust resources or record sampled quantities (sequenced cell
        count, post-filter cell count, number of clusters).
        """
        at = action.action_type
        p = s.progress

        # Action type -> name of the boolean attribute on ``s.progress``.
        flag_by_action = {
            ActionType.COLLECT_SAMPLE: "samples_collected",
            ActionType.SELECT_COHORT: "cohort_selected",
            ActionType.PREPARE_LIBRARY: "library_prepared",
            ActionType.CULTURE_CELLS: "cells_cultured",
            ActionType.PERTURB_GENE: "perturbation_applied",
            ActionType.PERTURB_COMPOUND: "perturbation_applied",
            ActionType.SEQUENCE_CELLS: "cells_sequenced",
            ActionType.RUN_QC: "qc_performed",
            ActionType.FILTER_DATA: "data_filtered",
            ActionType.NORMALIZE_DATA: "data_normalized",
            ActionType.INTEGRATE_BATCHES: "batches_integrated",
            ActionType.CLUSTER_CELLS: "cells_clustered",
            ActionType.DIFFERENTIAL_EXPRESSION: "de_performed",
            ActionType.TRAJECTORY_ANALYSIS: "trajectories_inferred",
            ActionType.PATHWAY_ENRICHMENT: "pathways_analyzed",
            ActionType.REGULATORY_NETWORK_INFERENCE: "networks_inferred",
            ActionType.MARKER_SELECTION: "markers_discovered",
            ActionType.VALIDATE_MARKER: "markers_validated",
            ActionType.DESIGN_FOLLOWUP: "followup_designed",
            ActionType.REQUEST_SUBAGENT_REVIEW: "subagent_review_requested",
            ActionType.SYNTHESIZE_CONCLUSION: "conclusion_reached",
        }
        flag = flag_by_action.get(at)
        if flag:
            setattr(p, flag, True)

        if at == ActionType.COLLECT_SAMPLE:
            n = action.parameters.get("n_samples", 6)
            s.resources.samples_available += n

        if at == ActionType.SEQUENCE_CELLS:
            s.resources.sequencing_lanes_used += 1
            p.n_cells_sequenced = self.noise.sample_count(
                s.biology.n_true_cells * s.technical.capture_efficiency
            )

        if at in {ActionType.PERTURB_GENE, ActionType.PERTURB_COMPOUND}:
            self._apply_perturbation_effects(s, action)

        if at == ActionType.FILTER_DATA:
            retain = self.noise.sample_qc_metric(0.85, 0.05, 0.5, 1.0)
            # Fall back to the true cell count when nothing was sequenced.
            base = p.n_cells_sequenced or s.biology.n_true_cells
            # NOTE(review): the floor of 100 can exceed ``base`` when fewer
            # than 100 cells are available -- confirm this is intended.
            p.n_cells_after_filter = max(100, int(base * retain))
            s.last_retain_frac = retain

        if at == ActionType.CLUSTER_CELLS:
            # Assume 5 populations when none are defined.
            n_true = len(s.biology.cell_populations) or 5
            p.n_clusters_found = self.noise.sample_cluster_count(n_true, 0.8)
            s.last_n_clusters = p.n_clusters_found

    def _apply_perturbation_effects(
        self, s: FullLatentState, action: ExperimentAction
    ) -> None:
        """Fold perturbation-specific gene effects into true_de_genes so
        downstream DE analysis reflects the perturbed biology.

        Does nothing when the action's ``target`` parameter has no entry
        (or an empty entry) in ``s.biology.perturbation_effects``.
        Otherwise a sampled efficiency is stored on
        ``s.last_perturbation_efficiency`` and every gene map in
        ``s.biology.true_de_genes`` gains ``delta * efficiency`` per
        affected gene.
        """
        target = action.parameters.get("target", "")
        effects = s.biology.perturbation_effects.get(target, {})
        if not effects:
            return

        # Efficiency drawn from the same distribution as the output handler
        # so latent state and observable output are coherent.
        if action.action_type == ActionType.PERTURB_GENE:
            efficiency = self.noise.sample_qc_metric(0.80, 0.12, 0.0, 1.0)
        else:
            efficiency = self.noise.sample_qc_metric(0.70, 0.15, 0.0, 1.0)
        s.last_perturbation_efficiency = efficiency

        for gene_map in s.biology.true_de_genes.values():
            for gene, delta in effects.items():
                # Genes not yet present start from a baseline of 0.0.
                gene_map[gene] = gene_map.get(gene, 0.0) + delta * efficiency

    def _propagate_artifacts(
        self,
        s: FullLatentState,
        action: ExperimentAction,
        output: IntermediateOutput,
    ) -> None:
        """Copy selected fields of ``output.data`` into the latent state.

        Only differential-expression, clustering and marker-selection
        actions record anything; other action types are a no-op.
        """
        if action.action_type == ActionType.DIFFERENTIAL_EXPRESSION:
            top = output.data.get("top_genes", [])
            # Only the first 20 reported genes are remembered.
            s.discovered_de_genes = [g["gene"] for g in top[:20]]
            s.progress.n_de_genes_found = output.data.get("n_significant", 0)

        if action.action_type == ActionType.CLUSTER_CELLS:
            s.discovered_clusters = output.data.get("cluster_names", [])

        if action.action_type == ActionType.MARKER_SELECTION:
            s.progress.n_markers_found = output.data.get("n_candidates", 0)