| | """Transition dynamics engine β the heart of the biological simulator.
|
| |
|
| | Orchestrates latent-state updates, output generation, resource accounting,
|
| | and constraint propagation for every agent action.
|
| | """
|
| |
|
| | from __future__ import annotations
|
| |
|
| | from copy import deepcopy
|
| | from dataclasses import dataclass, field
|
| | from typing import Any, Dict, List, Optional, Tuple
|
| |
|
| | from models import (
|
| | ActionType,
|
| | ExperimentAction,
|
| | IntermediateOutput,
|
| | OutputType,
|
| | TOOL_REGISTRY,
|
| | )
|
| |
|
| | from .latent_state import FullLatentState
|
| | from .noise import NoiseModel
|
| | from .output_generator import OutputGenerator
|
| |
|
| |
|
| |
|
| | _BASE_ACTION_COSTS: Dict[ActionType, Tuple[float, float]] = {
|
| | ActionType.COLLECT_SAMPLE: (5_000, 7.0),
|
| | ActionType.SELECT_COHORT: ( 500, 1.0),
|
| | ActionType.PREPARE_LIBRARY: (8_000, 3.0),
|
| | ActionType.CULTURE_CELLS: (3_000, 14.0),
|
| | ActionType.PERTURB_GENE: (2_000, 3.0),
|
| | ActionType.PERTURB_COMPOUND: (1_000, 2.0),
|
| | ActionType.SEQUENCE_CELLS: (15_000, 5.0),
|
| | ActionType.RUN_QC: ( 100, 0.5),
|
| | ActionType.FILTER_DATA: ( 50, 0.25),
|
| | ActionType.NORMALIZE_DATA: ( 50, 0.25),
|
| | ActionType.INTEGRATE_BATCHES: ( 100, 0.5),
|
| | ActionType.CLUSTER_CELLS: ( 100, 0.5),
|
| | ActionType.DIFFERENTIAL_EXPRESSION: ( 100, 0.5),
|
| | ActionType.TRAJECTORY_ANALYSIS: ( 200, 1.0),
|
| | ActionType.PATHWAY_ENRICHMENT: ( 100, 0.5),
|
| | ActionType.REGULATORY_NETWORK_INFERENCE: ( 300, 1.0),
|
| | ActionType.MARKER_SELECTION: ( 100, 0.5),
|
| | ActionType.VALIDATE_MARKER: (5_000, 14.0),
|
| | ActionType.DESIGN_FOLLOWUP: ( 100, 0.5),
|
| | ActionType.REQUEST_SUBAGENT_REVIEW: ( 50, 0.25),
|
| | ActionType.SYNTHESIZE_CONCLUSION: ( 0, 0.5),
|
| | }
|
| |
|
| |
|
| | ACTION_COSTS = _BASE_ACTION_COSTS
|
| |
|
| |
|
| | def compute_action_cost(action: ExperimentAction) -> Tuple[float, float]:
|
| | """Return (budget_cost, time_cost_days) for an action.
|
| |
|
| | If the action specifies a ``method`` that exists in ``TOOL_REGISTRY``,
|
| | the tool's ``typical_cost_usd`` and ``typical_runtime_hours`` are used
|
| | (converted to days). Otherwise we fall back to the per-ActionType base
|
| | cost table.
|
| | """
|
| | tool_spec = TOOL_REGISTRY.get(action.method or "")
|
| | if tool_spec is not None:
|
| | budget = tool_spec.typical_cost_usd
|
| | time_days = tool_spec.typical_runtime_hours / 24.0
|
| | return (budget, time_days)
|
| | return _BASE_ACTION_COSTS.get(action.action_type, (0.0, 0.0))
|
| |
|
| |
|
| | @dataclass
|
| | class TransitionResult:
|
| | """Bundle returned by the transition engine after one step."""
|
| |
|
| | next_state: FullLatentState
|
| | output: IntermediateOutput
|
| | reward_components: Dict[str, float] = field(default_factory=dict)
|
| | hard_violations: List[str] = field(default_factory=list)
|
| | soft_violations: List[str] = field(default_factory=list)
|
| | done: bool = False
|
| |
|
| |
|
| | class TransitionEngine:
|
| | """Applies one action to the latent state, producing the next state
|
| | and a simulated intermediate output.
|
| |
|
| | The engine delegates output generation to ``OutputGenerator`` and
|
| | constraint checking to external rule engines (injected at call time).
|
| | """
|
| |
|
| | def __init__(self, noise: NoiseModel):
|
| | self.noise = noise
|
| | self.output_gen = OutputGenerator(noise)
|
| |
|
| | def step(
|
| | self,
|
| | state: FullLatentState,
|
| | action: ExperimentAction,
|
| | *,
|
| | hard_violations: Optional[List[str]] = None,
|
| | soft_violations: Optional[List[str]] = None,
|
| | ) -> TransitionResult:
|
| | s = deepcopy(state)
|
| | s.step_count += 1
|
| | step_idx = s.step_count
|
| |
|
| | hard_v = hard_violations or []
|
| | soft_v = soft_violations or []
|
| |
|
| | if hard_v:
|
| | output = IntermediateOutput(
|
| | output_type=OutputType.FAILURE_REPORT,
|
| | step_index=step_idx,
|
| | success=False,
|
| | summary=f"Action blocked: {'; '.join(hard_v)}",
|
| | )
|
| | return TransitionResult(
|
| | next_state=s,
|
| | output=output,
|
| | hard_violations=hard_v,
|
| | soft_violations=soft_v,
|
| | )
|
| |
|
| | self._apply_resource_cost(s, action)
|
| |
|
| | if s.resources.budget_exhausted or s.resources.time_exhausted:
|
| | output = IntermediateOutput(
|
| | output_type=OutputType.FAILURE_REPORT,
|
| | step_index=step_idx,
|
| | success=False,
|
| | summary="Resources exhausted",
|
| | )
|
| | return TransitionResult(
|
| | next_state=s, output=output, done=True,
|
| | hard_violations=["resources_exhausted"],
|
| | )
|
| |
|
| | self._update_progress(s, action)
|
| |
|
| | output = self.output_gen.generate(action, s, step_idx)
|
| |
|
| | if soft_v:
|
| | output.quality_score *= 0.5
|
| | output.warnings.extend(soft_v)
|
| |
|
| | self._propagate_artifacts(s, action, output)
|
| |
|
| | done = action.action_type == ActionType.SYNTHESIZE_CONCLUSION
|
| |
|
| | return TransitionResult(
|
| | next_state=s,
|
| | output=output,
|
| | soft_violations=soft_v,
|
| | done=done,
|
| | )
|
| |
|
| |
|
| |
|
| | def _apply_resource_cost(
|
| | self, s: FullLatentState, action: ExperimentAction
|
| | ) -> None:
|
| | budget_cost, time_cost = compute_action_cost(action)
|
| | s.resources.budget_used += budget_cost
|
| | s.resources.time_used_days += time_cost
|
| | if action.action_type in {
|
| | ActionType.RUN_QC, ActionType.FILTER_DATA,
|
| | ActionType.NORMALIZE_DATA, ActionType.INTEGRATE_BATCHES,
|
| | ActionType.CLUSTER_CELLS, ActionType.DIFFERENTIAL_EXPRESSION,
|
| | ActionType.TRAJECTORY_ANALYSIS, ActionType.PATHWAY_ENRICHMENT,
|
| | ActionType.REGULATORY_NETWORK_INFERENCE, ActionType.MARKER_SELECTION,
|
| | }:
|
| | s.resources.compute_hours_used += time_cost * 8
|
| |
|
| | def _update_progress(
|
| | self, s: FullLatentState, action: ExperimentAction
|
| | ) -> None:
|
| | at = action.action_type
|
| | p = s.progress
|
| | _MAP = {
|
| | ActionType.COLLECT_SAMPLE: "samples_collected",
|
| | ActionType.SELECT_COHORT: "cohort_selected",
|
| | ActionType.PREPARE_LIBRARY: "library_prepared",
|
| | ActionType.CULTURE_CELLS: "cells_cultured",
|
| | ActionType.PERTURB_GENE: "perturbation_applied",
|
| | ActionType.PERTURB_COMPOUND: "perturbation_applied",
|
| | ActionType.SEQUENCE_CELLS: "cells_sequenced",
|
| | ActionType.RUN_QC: "qc_performed",
|
| | ActionType.FILTER_DATA: "data_filtered",
|
| | ActionType.NORMALIZE_DATA: "data_normalized",
|
| | ActionType.INTEGRATE_BATCHES: "batches_integrated",
|
| | ActionType.CLUSTER_CELLS: "cells_clustered",
|
| | ActionType.DIFFERENTIAL_EXPRESSION: "de_performed",
|
| | ActionType.TRAJECTORY_ANALYSIS: "trajectories_inferred",
|
| | ActionType.PATHWAY_ENRICHMENT: "pathways_analyzed",
|
| | ActionType.REGULATORY_NETWORK_INFERENCE: "networks_inferred",
|
| | ActionType.MARKER_SELECTION: "markers_discovered",
|
| | ActionType.VALIDATE_MARKER: "markers_validated",
|
| | ActionType.DESIGN_FOLLOWUP: "followup_designed",
|
| | ActionType.REQUEST_SUBAGENT_REVIEW: "subagent_review_requested",
|
| | ActionType.SYNTHESIZE_CONCLUSION: "conclusion_reached",
|
| | }
|
| | flag = _MAP.get(at)
|
| | if flag:
|
| | setattr(p, flag, True)
|
| |
|
| | if at == ActionType.COLLECT_SAMPLE:
|
| | n = action.parameters.get("n_samples", 6)
|
| | s.resources.samples_available += n
|
| |
|
| | if at == ActionType.SEQUENCE_CELLS:
|
| | s.resources.sequencing_lanes_used += 1
|
| | p.n_cells_sequenced = self.noise.sample_count(
|
| | s.biology.n_true_cells * s.technical.capture_efficiency
|
| | )
|
| |
|
| | if at in {ActionType.PERTURB_GENE, ActionType.PERTURB_COMPOUND}:
|
| | self._apply_perturbation_effects(s, action)
|
| |
|
| | if at == ActionType.FILTER_DATA:
|
| | retain = self.noise.sample_qc_metric(0.85, 0.05, 0.5, 1.0)
|
| | base = p.n_cells_sequenced or s.biology.n_true_cells
|
| | p.n_cells_after_filter = max(100, int(base * retain))
|
| | s.last_retain_frac = retain
|
| |
|
| | if at == ActionType.CLUSTER_CELLS:
|
| | n_true = len(s.biology.cell_populations) or 5
|
| | p.n_clusters_found = self.noise.sample_cluster_count(n_true, 0.8)
|
| | s.last_n_clusters = p.n_clusters_found
|
| |
|
| | def _apply_perturbation_effects(
|
| | self, s: FullLatentState, action: ExperimentAction
|
| | ) -> None:
|
| | """Fold perturbation-specific gene effects into true_de_genes so
|
| | downstream DE analysis reflects the perturbed biology."""
|
| | target = action.parameters.get("target", "")
|
| | effects = s.biology.perturbation_effects.get(target, {})
|
| | if not effects:
|
| | return
|
| |
|
| |
|
| | if action.action_type == ActionType.PERTURB_GENE:
|
| | efficiency = self.noise.sample_qc_metric(0.80, 0.12, 0.0, 1.0)
|
| | else:
|
| | efficiency = self.noise.sample_qc_metric(0.70, 0.15, 0.0, 1.0)
|
| | s.last_perturbation_efficiency = efficiency
|
| | for gene_map in s.biology.true_de_genes.values():
|
| | for gene, delta in effects.items():
|
| | gene_map[gene] = gene_map.get(gene, 0.0) + delta * efficiency
|
| |
|
| | def _propagate_artifacts(
|
| | self,
|
| | s: FullLatentState,
|
| | action: ExperimentAction,
|
| | output: IntermediateOutput,
|
| | ) -> None:
|
| | if action.action_type == ActionType.DIFFERENTIAL_EXPRESSION:
|
| | top = output.data.get("top_genes", [])
|
| | s.discovered_de_genes = [g["gene"] for g in top[:20]]
|
| | s.progress.n_de_genes_found = output.data.get("n_significant", 0)
|
| |
|
| | if action.action_type == ActionType.CLUSTER_CELLS:
|
| | s.discovered_clusters = output.data.get("cluster_names", [])
|
| |
|
| | if action.action_type == ActionType.MARKER_SELECTION:
|
| | s.progress.n_markers_found = output.data.get("n_candidates", 0)
|
| |
|