| | """Decomposable reward function for the bio-experiment planning POMDP.
|
| |
|
| | Reward components
|
| | βββββββββββββββββ
|
| | r_validity β biological validity of the chosen action
|
| | r_ordering β correct ordering of experiment steps
|
| | r_info_gain β information gain from the step's output
|
| | r_efficiency β resource efficiency (budget & time normalised)
|
| | r_novelty β bonus for non-redundant, non-trivial actions
|
| | r_penalty β penalties for violations, redundancy, waste
|
| | r_terminal β terminal quality & calibration against hidden truth
|
| |
|
| | Potential-based shaping
|
| | Ο(s) β progress potential used for dense shaping signal
|
| |
|
| | The final step reward is:
|
| | R_t = r_validity + r_ordering + r_info_gain + r_efficiency
|
| | + r_novelty + r_penalty + [Ο(s_{t+1}) β Ο(s_t)]
|
| |
|
| | The terminal reward adds:
|
| | R_T += r_terminal
|
| | """
|
| |
|
| | from __future__ import annotations
|
| |
|
| | from dataclasses import dataclass, field
|
| | from typing import Dict, List, Optional
|
| |
|
| | from models import (
|
| | ActionType,
|
| | ConclusionClaim,
|
| | ExperimentAction,
|
| | IntermediateOutput,
|
| | META_ACTIONS,
|
| | TOOL_REGISTRY,
|
| | WET_LAB_ACTIONS,
|
| | )
|
| |
|
| | from server.biology.gene_index import (
|
| | marker_set_score,
|
| | mechanism_set_score,
|
| | score_pathways,
|
| | )
|
| | from server.simulator.latent_state import FullLatentState
|
| |
|
| |
|
| | @dataclass
|
| | class RewardBreakdown:
|
| | validity: float = 0.0
|
| | ordering: float = 0.0
|
| | info_gain: float = 0.0
|
| | efficiency: float = 0.0
|
| | novelty: float = 0.0
|
| | penalty: float = 0.0
|
| | shaping: float = 0.0
|
| | terminal: float = 0.0
|
| | components: Dict[str, float] = field(default_factory=dict)
|
| |
|
| | @property
|
| | def total(self) -> float:
|
| | return (
|
| | self.validity
|
| | + self.ordering
|
| | + self.info_gain
|
| | + self.efficiency
|
| | + self.novelty
|
| | + self.penalty
|
| | + self.shaping
|
| | + self.terminal
|
| | )
|
| |
|
| | def to_dict(self) -> Dict[str, float]:
|
| | d = {
|
| | "validity": self.validity,
|
| | "ordering": self.ordering,
|
| | "info_gain": self.info_gain,
|
| | "efficiency": self.efficiency,
|
| | "novelty": self.novelty,
|
| | "penalty": self.penalty,
|
| | "shaping": self.shaping,
|
| | "terminal": self.terminal,
|
| | "total": self.total,
|
| | }
|
| | d.update(self.components)
|
| | return d
|
| |
|
| |
|
| | class RewardComputer:
|
| | """Computes step-wise and terminal rewards.
|
| |
|
| | Parameters
|
| | ----------
|
| | efficiency_weight : float
|
| | Relative importance of resource efficiency.
|
| | """
|
| |
|
| | def __init__(
|
| | self,
|
| | efficiency_weight: float = 0.3,
|
| | info_gain_weight: float = 0.4,
|
| | validity_weight: float = 0.3,
|
| | ):
|
| | self.w_eff = efficiency_weight
|
| | self.w_ig = info_gain_weight
|
| | self.w_val = validity_weight
|
| |
|
| |
|
| |
|
| | def step_reward(
|
| | self,
|
| | action: ExperimentAction,
|
| | prev_state: FullLatentState,
|
| | next_state: FullLatentState,
|
| | output: IntermediateOutput,
|
| | hard_violations: List[str],
|
| | soft_violations: List[str],
|
| | ) -> RewardBreakdown:
|
| | rb = RewardBreakdown()
|
| |
|
| |
|
| | if hard_violations:
|
| | rb.validity = -1.0
|
| | rb.penalty = -0.5 * len(hard_violations)
|
| | rb.components["hard_violations"] = len(hard_violations)
|
| | return rb
|
| |
|
| | rb.validity = self.w_val * (1.0 if output.success else 0.0)
|
| |
|
| | ordering_score = self._ordering_score(action, prev_state)
|
| | rb.ordering = 0.2 * ordering_score
|
| | if ordering_score < 0:
|
| | rb.penalty += ordering_score * 0.3
|
| |
|
| |
|
| | rb.info_gain = self.w_ig * output.quality_score * (1.0 - output.uncertainty)
|
| | if action.action_type in META_ACTIONS and not (
|
| | prev_state.progress.de_performed
|
| | or prev_state.progress.cells_clustered
|
| | ):
|
| |
|
| | rb.info_gain *= 0.2
|
| |
|
| |
|
| | budget_frac = (
|
| | (next_state.resources.budget_used - prev_state.resources.budget_used)
|
| | / max(next_state.resources.budget_total, 1)
|
| | )
|
| | rb.efficiency = self.w_eff * max(0.0, 1.0 - 5.0 * budget_frac)
|
| |
|
| |
|
| | if not soft_violations:
|
| | rb.novelty = 0.1
|
| |
|
| |
|
| | tool_fit = self._tool_fit_score(action, prev_state)
|
| | rb.components["tool_fit"] = tool_fit
|
| | rb.validity += 0.15 * tool_fit
|
| |
|
| |
|
| | rb.penalty = -0.15 * len(soft_violations)
|
| | if action.action_type in META_ACTIONS and not (
|
| | prev_state.progress.de_performed
|
| | or prev_state.progress.cells_clustered
|
| | ):
|
| | rb.penalty -= 0.25
|
| | rb.components["premature_meta_action_penalty"] = -0.25
|
| |
|
| |
|
| |
|
| | phi_prev = self._potential(prev_state)
|
| | phi_next = self._potential(next_state)
|
| | rb.shaping = phi_next - phi_prev
|
| |
|
| | return rb
|
| |
|
| |
|
| |
|
| | def terminal_reward(
|
| | self,
|
| | state: FullLatentState,
|
| | conclusions: List[ConclusionClaim],
|
| | task_success_criteria: List[str],
|
| | discovered_markers: Optional[List[str]] = None,
|
| | candidate_mechanisms: Optional[List[str]] = None,
|
| | ) -> RewardBreakdown:
|
| | rb = RewardBreakdown()
|
| | discovered_markers = discovered_markers or []
|
| | candidate_mechanisms = candidate_mechanisms or []
|
| |
|
| |
|
| | completeness = self._completeness(state)
|
| | rb.components["completeness"] = completeness
|
| |
|
| |
|
| | calibration = self._calibration(state, conclusions)
|
| | rb.components["calibration"] = calibration
|
| |
|
| |
|
| | budget_eff = state.resources.budget_remaining / max(
|
| | state.resources.budget_total, 1
|
| | )
|
| | time_eff = state.resources.time_remaining_days / max(
|
| | state.resources.time_limit_days, 1
|
| | )
|
| | rb.components["budget_efficiency"] = budget_eff
|
| | rb.components["time_efficiency"] = time_eff
|
| |
|
| |
|
| | overconf = self._overconfidence_penalty(state, conclusions)
|
| | rb.components["overconfidence_penalty"] = overconf
|
| |
|
| | discovery_alignment = self._discovery_alignment(
|
| | state,
|
| | discovered_markers,
|
| | candidate_mechanisms,
|
| | )
|
| | discovery_error_penalty = -6.0 * (1.0 - discovery_alignment)
|
| | if discovery_alignment < 0.25:
|
| | discovery_error_penalty -= 2.0
|
| | rb.components["discovery_alignment"] = discovery_alignment
|
| | rb.components["discovery_error_penalty"] = discovery_error_penalty
|
| |
|
| | conclusion_alignment = self._conclusion_alignment(state, conclusions)
|
| | conclusion_error_penalty = -4.0 * (1.0 - conclusion_alignment)
|
| | if conclusions and conclusion_alignment < 0.25:
|
| | conclusion_error_penalty -= 1.5
|
| | rb.components["conclusion_alignment"] = conclusion_alignment
|
| | rb.components["conclusion_error_penalty"] = conclusion_error_penalty
|
| |
|
| | eff_bonus = (budget_eff + time_eff) / 2.0 if completeness >= 0.3 else 0.0
|
| | rb.terminal = (
|
| | 3.0 * completeness
|
| | + 4.0 * calibration
|
| | + 1.0 * eff_bonus
|
| | + overconf
|
| | + discovery_error_penalty
|
| | + conclusion_error_penalty
|
| | )
|
| | return rb
|
| |
|
| |
|
| |
|
| | def _ordering_score(
|
| | self, action: ExperimentAction, s: FullLatentState
|
| | ) -> float:
|
| | """Heuristic: 1.0 if natural next, 0.3 if acceptable, -1.0 if premature."""
|
| | at = action.action_type
|
| | p = s.progress
|
| | NATURAL_NEXT = {
|
| | ActionType.COLLECT_SAMPLE: not p.samples_collected,
|
| | ActionType.PREPARE_LIBRARY: p.samples_collected and not p.library_prepared,
|
| | ActionType.SEQUENCE_CELLS: p.library_prepared and not p.cells_sequenced,
|
| | ActionType.RUN_QC: p.cells_sequenced and not p.qc_performed,
|
| | ActionType.FILTER_DATA: p.qc_performed and not p.data_filtered,
|
| | ActionType.NORMALIZE_DATA: p.data_filtered and not p.data_normalized,
|
| | ActionType.CLUSTER_CELLS: p.data_normalized and not p.cells_clustered,
|
| | ActionType.DIFFERENTIAL_EXPRESSION: p.data_normalized and not p.de_performed,
|
| | ActionType.PATHWAY_ENRICHMENT: p.de_performed and not p.pathways_analyzed,
|
| | ActionType.MARKER_SELECTION: p.de_performed and not p.markers_discovered,
|
| | ActionType.VALIDATE_MARKER: p.markers_discovered and not p.markers_validated,
|
| | ActionType.SYNTHESIZE_CONCLUSION: (
|
| | p.de_performed or p.cells_clustered
|
| | ) and not p.conclusion_reached,
|
| | }
|
| | if NATURAL_NEXT.get(at, False):
|
| | return 1.0
|
| |
|
| | has_evidence = any([
|
| | p.cells_clustered, p.de_performed, p.trajectories_inferred,
|
| | p.pathways_analyzed, p.networks_inferred, p.markers_discovered,
|
| | ])
|
| | if at in META_ACTIONS and not has_evidence:
|
| | return -1.0
|
| |
|
| | return 0.3
|
| |
|
| | def _potential(self, s: FullLatentState) -> float:
|
| | """Progress potential Ο(s) β counts completed milestones.
|
| |
|
| | Returns 0.0 at terminal states so that the shaping signal
|
| | telescopes correctly over the episode.
|
| | """
|
| | if s.progress.conclusion_reached:
|
| | return 0.0
|
| | p = s.progress
|
| | milestones = [
|
| | p.samples_collected,
|
| | p.library_prepared,
|
| | p.cells_sequenced,
|
| | p.qc_performed,
|
| | p.data_filtered,
|
| | p.data_normalized,
|
| | p.cells_clustered,
|
| | p.de_performed,
|
| | p.pathways_analyzed,
|
| | p.markers_discovered,
|
| | p.markers_validated,
|
| | p.conclusion_reached,
|
| | ]
|
| | return sum(milestones) / len(milestones)
|
| |
|
| | def _completeness(self, s: FullLatentState) -> float:
|
| | p = s.progress
|
| | core = [
|
| | p.samples_collected,
|
| | p.cells_sequenced,
|
| | p.qc_performed,
|
| | p.data_filtered,
|
| | p.data_normalized,
|
| | p.de_performed or p.cells_clustered,
|
| | p.conclusion_reached,
|
| | ]
|
| | return sum(core) / len(core)
|
| |
|
| | def _calibration(
|
| | self, s: FullLatentState, conclusions: List[ConclusionClaim]
|
| | ) -> float:
|
| | """Structured set-similarity calibration against hidden ground truth.
|
| |
|
| | Uses pathway-weighted Gaussian similarity for markers, semantic
|
| | similarity for mechanisms, and activity-weighted matching for pathways.
|
| | Falls back to legacy substring matching when structured fields are empty.
|
| | """
|
| | if not conclusions:
|
| | return 0.0
|
| |
|
| | pred_markers = [g for c in conclusions for g in c.top_markers]
|
| | pred_mechs = [m for c in conclusions for m in c.causal_mechanisms]
|
| | pred_pathways = {
|
| | p: v for c in conclusions for p, v in c.predicted_pathways.items()
|
| | }
|
| |
|
| | has_structured = bool(pred_markers or pred_mechs or pred_pathways)
|
| |
|
| | if has_structured:
|
| | m_score = marker_set_score(pred_markers, s.biology.true_markers)
|
| | mech_score = mechanism_set_score(
|
| | pred_mechs, s.biology.causal_mechanisms
|
| | )
|
| | pw_score = score_pathways(pred_pathways, s.biology.true_pathways)
|
| | return 0.50 * m_score + 0.35 * mech_score + 0.15 * pw_score
|
| |
|
| | return self._legacy_calibration(s, conclusions)
|
| |
|
| | @staticmethod
|
| | def _legacy_calibration(
|
| | s: FullLatentState, conclusions: List[ConclusionClaim]
|
| | ) -> float:
|
| | """Substring-based calibration kept for backward compatibility."""
|
| | true_mechanisms = set(s.biology.causal_mechanisms)
|
| | true_markers = set(s.biology.true_markers)
|
| | score = 0.0
|
| | n = len(conclusions)
|
| |
|
| | for c in conclusions:
|
| | claim_lower = c.claim.lower()
|
| | match = any(m.lower() in claim_lower for m in true_mechanisms)
|
| | marker_match = any(m.lower() in claim_lower for m in true_markers)
|
| | if match or marker_match:
|
| | score += 1.0
|
| | else:
|
| | score -= 0.3
|
| | return max(0.0, min(1.0, score / max(n, 1)))
|
| |
|
| | _METHOD_TO_TOOL: Dict[str, str] = {
|
| | "scanpy.pp.calculate_qc_metrics": "Scanpy",
|
| | "scanpy.pp.filter_cells": "Scanpy",
|
| | "scanpy.pp.filter_genes": "Scanpy",
|
| | "scanpy.pp.normalize_total": "Scanpy",
|
| | "scanpy.pp.log1p": "Scanpy",
|
| | "scanpy.pp.highly_variable_genes": "Scanpy",
|
| | "scanpy.pp.neighbors": "Scanpy",
|
| | "scanpy.tl.leiden": "Leiden",
|
| | "scanpy.tl.louvain": "Louvain",
|
| | "scanpy.tl.rank_genes_groups": "Scanpy",
|
| | "scanpy.tl.paga": "PAGA",
|
| | "scanpy.tl.umap": "UMAP",
|
| | "gseapy.prerank": "Scanpy",
|
| | "gseapy.gsea": "Scanpy",
|
| | "10x_chromium": "CellRanger",
|
| | "NovaSeq": "CellRanger",
|
| | }
|
| |
|
| | @staticmethod
|
| | def _tool_fit_score(
|
| | action: ExperimentAction, s: FullLatentState
|
| | ) -> float:
|
| | """Score how well the chosen tool matches the task modality.
|
| |
|
| | Returns +1.0 for a perfect match, 0.0 if no tool specified,
|
| | -1.0 for a known tool used on an incompatible modality.
|
| | """
|
| | method = action.method
|
| | if not method:
|
| | return 0.0
|
| | resolved = RewardComputer._METHOD_TO_TOOL.get(method, method)
|
| | tool_spec = TOOL_REGISTRY.get(resolved)
|
| | if tool_spec is None:
|
| | return -0.5
|
| | modality = getattr(s, "task_modality", None)
|
| | if not modality or not tool_spec.modalities:
|
| | return 0.0
|
| | if modality in tool_spec.modalities:
|
| | return 1.0
|
| | return -1.0
|
| |
|
| | def _overconfidence_penalty(
|
| | self, s: FullLatentState, conclusions: List[ConclusionClaim]
|
| | ) -> float:
|
| | """Penalise high-confidence claims that disagree with ground truth.
|
| |
|
| | Checks structured fields (top_markers, causal_mechanisms) first;
|
| | falls back to claim substring matching for backward compatibility.
|
| | """
|
| | penalty = 0.0
|
| | true_markers_lower = {m.lower() for m in s.biology.true_markers}
|
| | true_mechs_lower = {m.lower() for m in s.biology.causal_mechanisms}
|
| | true_set = true_markers_lower | true_mechs_lower
|
| |
|
| | for c in conclusions:
|
| | if c.confidence <= 0.8:
|
| | continue
|
| |
|
| | has_structured = bool(c.top_markers or c.causal_mechanisms)
|
| | if has_structured:
|
| | marker_hit = any(
|
| | g.upper().strip() in {m.upper() for m in s.biology.true_markers}
|
| | for g in c.top_markers
|
| | )
|
| | mech_hit = any(
|
| | any(kw in m.lower() for kw in t.lower().split())
|
| | for m in c.causal_mechanisms
|
| | for t in s.biology.causal_mechanisms
|
| | )
|
| | is_correct = marker_hit or mech_hit
|
| | else:
|
| | is_correct = any(t in c.claim.lower() for t in true_set)
|
| |
|
| | if not is_correct:
|
| | penalty -= 0.5 * c.confidence
|
| |
|
| | return penalty
|
| |
|
| | def _discovery_alignment(
|
| | self,
|
| | s: FullLatentState,
|
| | discovered_markers: List[str],
|
| | candidate_mechanisms: List[str],
|
| | ) -> float:
|
| | """Symmetric end-of-episode similarity for discovered biology.
|
| |
|
| | Forward scoring measures recall against hidden truth. Reverse scoring
|
| | measures how well the agent's discoveries map back onto real biology,
|
| | which penalizes extra hallucinated markers or mechanisms.
|
| | """
|
| | components: List[float] = []
|
| |
|
| | if s.biology.true_markers or discovered_markers:
|
| | marker_recall = marker_set_score(
|
| | discovered_markers,
|
| | s.biology.true_markers,
|
| | )
|
| | marker_precision = marker_set_score(
|
| | s.biology.true_markers,
|
| | discovered_markers,
|
| | )
|
| | components.append((marker_recall + marker_precision) / 2.0)
|
| |
|
| | if s.biology.causal_mechanisms or candidate_mechanisms:
|
| | mechanism_recall = mechanism_set_score(
|
| | candidate_mechanisms,
|
| | s.biology.causal_mechanisms,
|
| | )
|
| | mechanism_precision = mechanism_set_score(
|
| | s.biology.causal_mechanisms,
|
| | candidate_mechanisms,
|
| | )
|
| | components.append((mechanism_recall + mechanism_precision) / 2.0)
|
| |
|
| | if not components:
|
| | return 1.0
|
| | return sum(components) / len(components)
|
| |
|
| | def _conclusion_alignment(
|
| | self,
|
| | s: FullLatentState,
|
| | conclusions: List[ConclusionClaim],
|
| | ) -> float:
|
| | if not conclusions:
|
| | return 0.0
|
| |
|
| | pred_markers = [marker for conclusion in conclusions for marker in conclusion.top_markers]
|
| | pred_mechanisms = [
|
| | mechanism
|
| | for conclusion in conclusions
|
| | for mechanism in conclusion.causal_mechanisms
|
| | ]
|
| |
|
| | if not pred_markers and not pred_mechanisms:
|
| | return self._legacy_calibration(s, conclusions)
|
| |
|
| | components: List[float] = []
|
| | if s.biology.true_markers or pred_markers:
|
| | marker_recall = marker_set_score(pred_markers, s.biology.true_markers)
|
| | marker_precision = marker_set_score(s.biology.true_markers, pred_markers)
|
| | components.append((marker_recall + marker_precision) / 2.0)
|
| |
|
| | if s.biology.causal_mechanisms or pred_mechanisms:
|
| | mechanism_recall = mechanism_set_score(
|
| | pred_mechanisms,
|
| | s.biology.causal_mechanisms,
|
| | )
|
| | mechanism_precision = mechanism_set_score(
|
| | s.biology.causal_mechanisms,
|
| | pred_mechanisms,
|
| | )
|
| | components.append((mechanism_recall + mechanism_precision) / 2.0)
|
| |
|
| | if not components:
|
| | return 1.0
|
| | return sum(components) / len(components)
|
| |
|