"""Biological rule engine — hard and soft constraint checking. Hard constraints block action execution entirely. Soft constraints allow execution but degrade output quality and incur penalties. """ from __future__ import annotations from dataclasses import dataclass from enum import Enum from typing import List from models import ActionType, ExperimentAction, TOOL_REGISTRY from server.simulator.latent_state import FullLatentState class Severity(str, Enum): HARD = "hard" SOFT = "soft" @dataclass class RuleViolation: rule_id: str severity: Severity message: str class RuleEngine: """Evaluates biological and resource constraints against the current latent state before each action is applied. """ @staticmethod def _has_analysis_evidence(s: FullLatentState) -> bool: p = s.progress return any([ p.cells_clustered, p.de_performed, p.trajectories_inferred, p.pathways_analyzed, p.networks_inferred, p.markers_discovered, p.markers_validated, ]) @staticmethod def _has_marker_evidence(s: FullLatentState) -> bool: p = s.progress return p.markers_discovered or p.markers_validated @staticmethod def _has_mechanism_evidence(s: FullLatentState) -> bool: p = s.progress return p.pathways_analyzed or p.networks_inferred def check( self, action: ExperimentAction, state: FullLatentState ) -> List[RuleViolation]: violations: List[RuleViolation] = [] violations.extend(self._check_prerequisites(action, state)) violations.extend(self._check_resource_constraints(action, state)) violations.extend(self._check_redundancy(action, state)) violations.extend(self._check_causal_validity(action, state)) violations.extend(self._check_tool_compatibility(action, state)) return violations def hard_violations(self, violations: List[RuleViolation]) -> List[str]: return [v.message for v in violations if v.severity == Severity.HARD] def soft_violations(self, violations: List[RuleViolation]) -> List[str]: return [v.message for v in violations if v.severity == Severity.SOFT] # ── prerequisite rules ────────────────────────────────────────────── def _check_prerequisites( self, action: ExperimentAction, s: FullLatentState ) -> List[RuleViolation]: vs: List[RuleViolation] = [] at = action.action_type p = s.progress REQUIRES = { ActionType.PREPARE_LIBRARY: [ ("samples_collected", "Cannot prepare library without collected samples"), ], ActionType.SEQUENCE_CELLS: [ ("library_prepared", "Cannot sequence without library preparation"), ], ActionType.RUN_QC: [ ("cells_sequenced", "Cannot run QC before sequencing"), ], ActionType.FILTER_DATA: [ ("qc_performed", "Cannot filter data before QC"), ], ActionType.NORMALIZE_DATA: [ ("data_filtered", "Cannot normalise before filtering"), ], ActionType.INTEGRATE_BATCHES: [ ("data_normalized", "Cannot integrate batches before normalisation"), ], ActionType.CLUSTER_CELLS: [ ("data_normalized", "Cannot cluster before normalisation"), ], ActionType.DIFFERENTIAL_EXPRESSION: [ ("data_normalized", "Cannot run DE before normalisation"), ], ActionType.TRAJECTORY_ANALYSIS: [ ("data_normalized", "Cannot infer trajectories before normalisation"), ], ActionType.PATHWAY_ENRICHMENT: [ ("de_performed", "Cannot run pathway enrichment without DE results"), ], ActionType.REGULATORY_NETWORK_INFERENCE: [ ("data_normalized", "Cannot infer networks before normalisation"), ], ActionType.MARKER_SELECTION: [ ("de_performed", "Cannot select markers without DE results"), ], ActionType.VALIDATE_MARKER: [ ("markers_discovered", "Cannot validate markers before discovery"), ], ActionType.PERTURB_GENE: [ ("samples_collected", "Cannot perturb without samples"), ], ActionType.PERTURB_COMPOUND: [ ("samples_collected", "Cannot perturb without samples"), ], ActionType.CULTURE_CELLS: [ ("samples_collected", "Cannot culture without samples"), ], ActionType.SYNTHESIZE_CONCLUSION: [ ("data_normalized", "Cannot synthesize conclusions before data normalization"), ], } for flag, msg in REQUIRES.get(at, []): if not getattr(p, flag, False): vs.append(RuleViolation( rule_id=f"prereq_{at.value}_{flag}", severity=Severity.HARD, message=msg, )) return vs # ── resource constraints ──────────────────────────────────────────── def _check_resource_constraints( self, action: ExperimentAction, s: FullLatentState ) -> List[RuleViolation]: vs: List[RuleViolation] = [] if s.resources.budget_exhausted: vs.append(RuleViolation( rule_id="budget_exhausted", severity=Severity.HARD, message="Budget exhausted - no further actions possible", )) if s.resources.time_exhausted: vs.append(RuleViolation( rule_id="time_exhausted", severity=Severity.HARD, message="Time limit reached - no further actions possible", )) remaining = s.resources.budget_remaining from server.simulator.transition import compute_action_cost cost, _ = compute_action_cost(action) if cost > remaining and remaining > 0: vs.append(RuleViolation( rule_id="budget_insufficient", severity=Severity.HARD, message=f"Action costs ${cost:,.0f} but only ${remaining:,.0f} remains", )) return vs # ── redundancy checks ─────────────────────────────────────────────── def _check_redundancy( self, action: ExperimentAction, s: FullLatentState ) -> List[RuleViolation]: vs: List[RuleViolation] = [] at = action.action_type p = s.progress REDUNDANT = { ActionType.COLLECT_SAMPLE: "samples_collected", ActionType.PREPARE_LIBRARY: "library_prepared", ActionType.SEQUENCE_CELLS: "cells_sequenced", ActionType.RUN_QC: "qc_performed", ActionType.FILTER_DATA: "data_filtered", ActionType.NORMALIZE_DATA: "data_normalized", ActionType.CLUSTER_CELLS: "cells_clustered", ActionType.DIFFERENTIAL_EXPRESSION: "de_performed", ActionType.TRAJECTORY_ANALYSIS: "trajectories_inferred", ActionType.PATHWAY_ENRICHMENT: "pathways_analyzed", ActionType.REGULATORY_NETWORK_INFERENCE: "networks_inferred", ActionType.MARKER_SELECTION: "markers_discovered", ActionType.VALIDATE_MARKER: "markers_validated", ActionType.DESIGN_FOLLOWUP: "followup_designed", ActionType.REQUEST_SUBAGENT_REVIEW: "subagent_review_requested", ActionType.SYNTHESIZE_CONCLUSION: "conclusion_reached", } flag = REDUNDANT.get(at) if flag and getattr(p, flag, False): vs.append(RuleViolation( rule_id=f"redundant_{at.value}", severity=Severity.HARD, message=f"Step '{at.value}' already completed — redundant action blocked", )) return vs # ── causal validity ───────────────────────────────────────────────── def _check_causal_validity( self, action: ExperimentAction, s: FullLatentState ) -> List[RuleViolation]: vs: List[RuleViolation] = [] has_analysis_evidence = self._has_analysis_evidence(s) if action.action_type == ActionType.DESIGN_FOLLOWUP: if not has_analysis_evidence: vs.append(RuleViolation( rule_id="premature_followup_design", severity=Severity.HARD, message=( "Follow-up design without prior analysis is blocked; " "complete wet-lab and computational steps first" ), )) if action.action_type == ActionType.REQUEST_SUBAGENT_REVIEW: if not has_analysis_evidence: vs.append(RuleViolation( rule_id="premature_subagent_review", severity=Severity.HARD, message=( "Subagent review without prior analysis is blocked; " "generate evidence first" ), )) if action.action_type == ActionType.SYNTHESIZE_CONCLUSION: if not s.progress.de_performed and not s.progress.cells_clustered: vs.append(RuleViolation( rule_id="premature_conclusion", severity=Severity.HARD, message="Cannot synthesise conclusion without substantive analysis", )) if not self._has_marker_evidence(s): vs.append(RuleViolation( rule_id="conclusion_without_marker_evidence", severity=Severity.HARD, message="Cannot synthesise conclusion before discovering or validating markers", )) if not self._has_mechanism_evidence(s): vs.append(RuleViolation( rule_id="conclusion_without_mechanism_evidence", severity=Severity.HARD, message="Cannot synthesise conclusion before inferring pathways or mechanisms", )) claims = action.parameters.get("claims", []) for claim in claims: if isinstance(claim, dict) and claim.get("claim_type") == "causal": if not s.progress.markers_validated and not s.progress.networks_inferred: vs.append(RuleViolation( rule_id="unsupported_causal_claim", severity=Severity.SOFT, message="Causal claim without validation or network evidence", )) break if action.action_type == ActionType.PATHWAY_ENRICHMENT: if not s.progress.de_performed: vs.append(RuleViolation( rule_id="pathway_without_de", severity=Severity.SOFT, message="Pathway enrichment without DE may yield unreliable results", )) return vs # ── tool / modality compatibility ──────────────────────────────────── _KNOWN_METHODS = { "scanpy.pp.calculate_qc_metrics", "scanpy.pp.filter_cells", "scanpy.pp.filter_genes", "scanpy.pp.normalize_total", "scanpy.pp.log1p", "scanpy.pp.highly_variable_genes", "scanpy.pp.neighbors", "scanpy.tl.leiden", "scanpy.tl.louvain", "scanpy.tl.rank_genes_groups", "scanpy.tl.paga", "scanpy.tl.umap", "gseapy.prerank", "gseapy.gsea", "10x_chromium", "NovaSeq", } _METHOD_TO_TOOL = { "scanpy.pp.calculate_qc_metrics": "Scanpy", "scanpy.pp.filter_cells": "Scanpy", "scanpy.pp.filter_genes": "Scanpy", "scanpy.pp.normalize_total": "Scanpy", "scanpy.pp.log1p": "Scanpy", "scanpy.pp.highly_variable_genes": "Scanpy", "scanpy.pp.neighbors": "Scanpy", "scanpy.tl.leiden": "Leiden", "scanpy.tl.louvain": "Louvain", "scanpy.tl.rank_genes_groups": "Scanpy", "scanpy.tl.paga": "PAGA", "scanpy.tl.umap": "UMAP", "gseapy.prerank": "Scanpy", "gseapy.gsea": "Scanpy", "10x_chromium": "CellRanger", "NovaSeq": "CellRanger", } def _check_tool_compatibility( self, action: ExperimentAction, s: FullLatentState ) -> List[RuleViolation]: """Warn when the chosen tool is incompatible with the task modality.""" vs: List[RuleViolation] = [] method = action.method if not method: return vs resolved = self._METHOD_TO_TOOL.get(method, method) tool_spec = TOOL_REGISTRY.get(resolved) if tool_spec is None and method not in self._KNOWN_METHODS: vs.append(RuleViolation( rule_id="unknown_tool", severity=Severity.SOFT, message=f"Tool '{method}' is not in the registry — results may be unreliable", )) return vs if tool_spec is None: return vs # Check modality compatibility (modality lives on the task, which is # stored in the latent state's associated TaskSpec — but the latent # state doesn't carry the TaskSpec directly. We can still check via # the action's own context or fall back gracefully). task_modality = getattr(s, "task_modality", None) if task_modality and tool_spec.modalities: if task_modality not in tool_spec.modalities: vs.append(RuleViolation( rule_id="tool_modality_mismatch", severity=Severity.SOFT, message=( f"Tool '{method}' is designed for " f"{', '.join(tool_spec.modalities)} but task modality " f"is '{task_modality}'" ), )) return vs