| """Literature-grounded experiment benchmark utilities. | |
| This module lets the environment run a paper-backed experiment plan, then | |
| compare the resulting simulated findings against curated expected findings | |
| from the literature. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import re | |
| from dataclasses import asdict, dataclass, field | |
| from importlib.metadata import PackageNotFoundError, version | |
| from typing import Any, Dict, List, Optional, Sequence | |
from models import (
    ActionType,
    ConclusionClaim,
    ExperimentAction,
    ExperimentObservation,
    OutputType,
    TaskSpec,
)
from server.hackathon_environment import BioExperimentEnvironment
from server.tasks.scenarios import SCENARIO_LIBRARY, Scenario

TOKEN_RE = re.compile(r"[A-Za-z0-9_+\-]+")

STOPWORDS = {
    "a", "an", "and", "as", "by", "for", "from", "in", "into",
    "of", "on", "or", "the", "to", "using", "with",
}

# Import name -> PyPI distribution name, used for version lookup below.
BIO_LIBRARY_DISTRIBUTIONS = {
    "scanpy": "scanpy",
    "gseapy": "gseapy",
    "biopython": "biopython",
}

@dataclass
class PaperBenchmarkResult:
    scenario_name: str
    problem_statement: str
    matched_papers: List[str]
    bio_library_versions: Dict[str, Optional[str]]
    matched_findings: List[str] = field(default_factory=list)
    missed_findings: List[str] = field(default_factory=list)
    discovered_markers: List[str] = field(default_factory=list)
    candidate_mechanisms: List[str] = field(default_factory=list)
    conclusions: List[str] = field(default_factory=list)
    final_reward: float = 0.0
    total_steps: int = 0

    @property
    def match_ratio(self) -> float:
        total = len(self.matched_findings) + len(self.missed_findings)
        return len(self.matched_findings) / max(total, 1)

    def to_dict(self) -> Dict[str, Any]:
        # asdict() only serializes fields, so add the computed ratio explicitly.
        data = asdict(self)
        data["match_ratio"] = self.match_ratio
        return data

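# Serialization sketch: to_dict() emits every dataclass field plus a computed
# "match_ratio" key, e.g. 2 matched / 1 missed -> match_ratio == 2 / 3.
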
def detect_bio_library_versions() -> Dict[str, Optional[str]]:
    versions: Dict[str, Optional[str]] = {}
    for name, dist_name in BIO_LIBRARY_DISTRIBUTIONS.items():
        try:
            versions[name] = version(dist_name)
        except PackageNotFoundError:
            versions[name] = None
    return versions

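# Illustrative output (version strings are hypothetical; actual values depend
# on the installed environment, with None for missing distributions):
#   detect_bio_library_versions()
#   -> {"scanpy": "1.10.2", "gseapy": None, "biopython": "1.83"}
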
def select_literature_scenario(problem_statement: str) -> Scenario:
    """Pick the closest literature-backed scenario for a prompt."""
    prompt_tokens = set(_tokenize(problem_statement))
    best_score = -1
    best_scenario: Optional[Scenario] = None
    for scenario in SCENARIO_LIBRARY:
        if not scenario.task.paper_references:
            continue
        corpus = [
            scenario.task.problem_statement,
            *(ref.title for ref in scenario.task.paper_references),
            *(finding.finding for finding in scenario.task.expected_findings),
            scenario.task.tissue,
            scenario.task.modality,
            *scenario.task.conditions,
        ]
        score = len(prompt_tokens & set(_tokenize(" ".join(corpus))))
        if scenario.task.problem_statement.lower() in problem_statement.lower():
            score += 4
        if score > best_score:
            best_score = score
            best_scenario = scenario
    if best_scenario is None:
        raise ValueError("No literature-backed scenarios are available.")
    return best_scenario

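# Scoring sketch (hypothetical counts): a prompt sharing 6 content tokens with
# a scenario's corpus scores 6; if the scenario's problem statement appears
# verbatim inside the prompt, the +4 bonus lets exact restatements outrank
# loosely related scenarios that merely share vocabulary.
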
def run_paper_benchmark(
    *,
    problem_statement: str,
    scenario_name: Optional[str] = None,
    domain_randomise: bool = False,
) -> PaperBenchmarkResult:
    """Run a literature-backed episode and compare outputs to paper results."""
    scenario = _resolve_scenario(problem_statement, scenario_name)
    env = BioExperimentEnvironment(
        scenario_name=scenario.name,
        domain_randomise=domain_randomise,
    )
    obs = env.reset()
    for action in build_paper_aligned_actions(obs.task):
        obs = env.step(action)
    claims = infer_conclusion_claims(obs)
    obs = env.step(
        ExperimentAction(
            action_type=ActionType.SYNTHESIZE_CONCLUSION,
            parameters={"claims": [claim.model_dump() for claim in claims]},
            justification=(
                "Summarize the simulated experimental evidence and compare it "
                "with the paper-backed expected findings."
            ),
            confidence=0.8,
            tool_call_spec=_tool_context(
                obs.task,
                libraries=["biopython"],
            ),
        )
    )
    matched, missed = compare_expected_findings(obs.task, obs)
    return PaperBenchmarkResult(
        scenario_name=scenario.name,
        problem_statement=obs.task.problem_statement,
        matched_papers=[ref.title for ref in obs.task.paper_references],
        bio_library_versions=detect_bio_library_versions(),
        matched_findings=matched,
        missed_findings=missed,
        discovered_markers=list(obs.discovered_markers),
        candidate_mechanisms=list(obs.candidate_mechanisms),
        conclusions=[c.claim for c in obs.conclusions],
        final_reward=float(obs.metadata.get("cumulative_reward", 0.0)),
        total_steps=obs.step_index,
    )

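# Minimal usage sketch (the prompt below is illustrative; it resolves to
# whichever paper-backed scenario in SCENARIO_LIBRARY scores highest):
#   result = run_paper_benchmark(
#       problem_statement="Validate SPP1 as a macrophage marker in IPF",
#   )
#   print(result.scenario_name, f"{result.match_ratio:.2%}")
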
def build_paper_aligned_actions(task: TaskSpec) -> List[ExperimentAction]:
    """Construct a pragmatic analysis plan aligned to the task modality."""
    actions: List[ExperimentAction] = [
        ExperimentAction(
            action_type=ActionType.COLLECT_SAMPLE,
            parameters={"n_samples": 8},
            justification="Collect enough samples to support downstream analysis.",
            confidence=0.75,
            tool_call_spec=_tool_context(task, libraries=["biopython"]),
        ),
        ExperimentAction(
            action_type=ActionType.PREPARE_LIBRARY,
            method="10x_chromium",
            justification="Use a standard single-cell library prep workflow.",
            confidence=0.8,
            tool_call_spec=_tool_context(task, libraries=["scanpy"]),
        ),
        ExperimentAction(
            action_type=ActionType.SEQUENCE_CELLS,
            method="NovaSeq",
            justification="Generate sufficient single-cell read depth.",
            confidence=0.8,
            tool_call_spec=_tool_context(task, libraries=["scanpy"]),
        ),
        ExperimentAction(
            action_type=ActionType.RUN_QC,
            method="scanpy.pp.calculate_qc_metrics",
            justification="Check technical quality before downstream inference.",
            confidence=0.85,
            tool_call_spec=_tool_context(task, libraries=["scanpy"]),
        ),
        ExperimentAction(
            action_type=ActionType.FILTER_DATA,
            method="scanpy.pp.filter_cells",
            justification="Remove low-quality cells and reduce technical noise.",
            confidence=0.85,
            tool_call_spec=_tool_context(task, libraries=["scanpy"]),
        ),
        ExperimentAction(
            action_type=ActionType.NORMALIZE_DATA,
            method="scanpy.pp.normalize_total",
            justification="Normalize expression to prepare comparable profiles.",
            confidence=0.85,
            tool_call_spec=_tool_context(task, libraries=["scanpy"]),
        ),
        ExperimentAction(
            action_type=ActionType.CLUSTER_CELLS,
            method="scanpy.tl.leiden",
            justification="Resolve cell states before focused interpretation.",
            confidence=0.8,
            tool_call_spec=_tool_context(task, libraries=["scanpy"]),
        ),
    ]
    categories = {finding.category for finding in task.expected_findings}
    if "trajectory" in categories:
        actions.extend([
            ExperimentAction(
                action_type=ActionType.TRAJECTORY_ANALYSIS,
                method="scanpy.tl.dpt",
                justification="Recover pseudotime structure and lineage branches.",
                confidence=0.8,
                tool_call_spec=_tool_context(task, libraries=["scanpy"]),
            ),
            ExperimentAction(
                action_type=ActionType.REGULATORY_NETWORK_INFERENCE,
                method="pySCENIC",
                justification="Infer branch-associated regulators from the trajectory.",
                confidence=0.75,
                tool_call_spec=_tool_context(task, libraries=["scanpy"]),
            ),
            ExperimentAction(
                action_type=ActionType.MARKER_SELECTION,
                method="scanpy.tl.rank_genes_groups",
                justification="Summarize lineage markers and branch-state genes.",
                confidence=0.75,
                tool_call_spec=_tool_context(task, libraries=["scanpy"]),
            ),
        ])
        return actions
    actions.extend([
        ExperimentAction(
            action_type=ActionType.DIFFERENTIAL_EXPRESSION,
            method="scanpy.tl.rank_genes_groups",
            parameters={"comparison": _default_comparison_name(task)},
            justification="Identify genes associated with the focal phenotype.",
            confidence=0.85,
            tool_call_spec=_tool_context(task, libraries=["scanpy"]),
        ),
        ExperimentAction(
            action_type=ActionType.PATHWAY_ENRICHMENT,
            method="gseapy.prerank",
            justification="Translate DE hits into pathway-level interpretation.",
            confidence=0.8,
            tool_call_spec=_tool_context(task, libraries=["gseapy"]),
        ),
        ExperimentAction(
            action_type=ActionType.MARKER_SELECTION,
            method="scanpy.tl.rank_genes_groups",
            justification="Nominate candidate markers for follow-up validation.",
            confidence=0.8,
            tool_call_spec=_tool_context(task, libraries=["scanpy"]),
        ),
        ExperimentAction(
            action_type=ActionType.VALIDATE_MARKER,
            method="immunofluorescence",
            parameters={"marker": _preferred_marker(task)},
            justification="Check whether the leading marker reproduces in validation.",
            confidence=0.75,
            tool_call_spec=_tool_context(task, libraries=["biopython"]),
        ),
    ])
    return actions

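# Plan shape: every task gets the seven shared collection/QC/clustering steps;
# tasks with a "trajectory" expected-finding category return early with the
# DPT, pySCENIC, and marker steps appended (10 actions total), while all other
# tasks receive the DE / pathway / marker / validation arm (11 actions total).
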
def infer_conclusion_claims(obs: ExperimentObservation) -> List[ConclusionClaim]:
    """Turn accumulated evidence into concise, paper-comparable claims."""
    markers = set(obs.discovered_markers)
    mechanisms = set(obs.candidate_mechanisms)
    network_regulators = set(_extract_network_regulators(obs))
    trajectory_output = _latest_output_data(obs, OutputType.TRAJECTORY_RESULT)
    claims: List[ConclusionClaim] = []
    if "SPP1" in markers:
        claims.append(ConclusionClaim(
            claim="SPP1-positive macrophages are enriched in IPF fibrotic tissue.",
            confidence=0.84,
            claim_type="marker",
            evidence_steps=_evidence_steps(obs, {
                OutputType.DE_RESULT,
                OutputType.MARKER_RESULT,
                OutputType.VALIDATION_RESULT,
            }),
        ))
    if {"SPP1", "MERTK"} <= markers:
        claims.append(ConclusionClaim(
            claim="MERTK co-occurs with the SPP1-positive profibrotic macrophage state.",
            confidence=0.8,
            claim_type="marker",
            evidence_steps=_evidence_steps(obs, {
                OutputType.DE_RESULT,
                OutputType.MARKER_RESULT,
            }),
        ))
    if "extracellular_matrix_organisation" in mechanisms:
        claims.append(ConclusionClaim(
            claim=(
                "Extracellular matrix organization is a dominant fibrotic "
                "program in the IPF samples."
            ),
            confidence=0.78,
            claim_type="pathway",
            evidence_steps=_evidence_steps(obs, {OutputType.PATHWAY_RESULT}),
        ))
    if trajectory_output.get("branching_detected"):
        claims.append(ConclusionClaim(
            claim=(
                "Trajectory analysis recovered branching blood lineages rooted "
                "in HSCs."
            ),
            confidence=0.82,
            claim_type="trajectory",
            evidence_steps=_evidence_steps(obs, {OutputType.TRAJECTORY_RESULT}),
        ))
    if "GATA1" in network_regulators:
        claims.append(ConclusionClaim(
            claim="GATA1 emerges as a driver of erythroid fate commitment.",
            confidence=0.8,
            claim_type="regulatory_network",
            evidence_steps=_evidence_steps(obs, {OutputType.NETWORK_RESULT}),
        ))
    if {"CEBPA", "SPI1"} & network_regulators:
        claims.append(ConclusionClaim(
            claim="CEBPA and SPI1 support myeloid branch decisions.",
            confidence=0.78,
            claim_type="regulatory_network",
            evidence_steps=_evidence_steps(obs, {OutputType.NETWORK_RESULT}),
        ))
    return claims

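# Illustrative mapping (deterministic given the observation): markers
# {"SPP1", "MERTK"} produce two marker claims (SPP1 enrichment plus the
# SPP1/MERTK co-occurrence); an observation with none of the recognized
# markers, mechanisms, branching, or regulators yields an empty claim list.
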
def compare_expected_findings(
    task: TaskSpec,
    obs: ExperimentObservation,
) -> tuple[List[str], List[str]]:
    """Compare the episode evidence against literature-backed findings."""
    evidence_text = _evidence_text(obs)
    matched: List[str] = []
    missed: List[str] = []
    for finding in task.expected_findings:
        keywords = [kw.lower() for kw in finding.keywords]
        if not keywords:
            keywords = _tokenize(finding.finding)
        # A finding counts as matched when a majority of its keywords appear.
        hits = sum(1 for kw in keywords if kw in evidence_text)
        threshold = max(1, (len(keywords) + 1) // 2)
        if hits >= threshold:
            matched.append(finding.finding)
        else:
            missed.append(finding.finding)
    return matched, missed

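# Worked example of the majority threshold (hypothetical finding):
#   keywords = ["spp1", "macrophage", "fibrosis"] -> threshold = (3 + 1) // 2 = 2
#   evidence containing "spp1" and "fibrosis" gives hits = 2 >= 2 -> matched.
# A single-keyword finding still requires that one keyword (threshold = 1).
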
def _resolve_scenario(
    problem_statement: str,
    scenario_name: Optional[str],
) -> Scenario:
    if scenario_name:
        for scenario in SCENARIO_LIBRARY:
            if scenario.name == scenario_name:
                return scenario
        raise ValueError(f"Unknown scenario_name '{scenario_name}'.")
    return select_literature_scenario(problem_statement)

def _tool_context(
    task: TaskSpec,
    *,
    libraries: Sequence[str],
    include_expected_findings: bool = False,
) -> Dict[str, Any]:
    context: Dict[str, Any] = {
        "literature_query": task.problem_statement,
        "paper_references": [
            {
                "title": ref.title,
                "doi": ref.doi,
                "pmid": ref.pmid,
                "url": ref.url,
            }
            for ref in task.paper_references
        ],
        "bioinformatics_libraries": list(libraries),
    }
    if include_expected_findings:
        context["expected_findings"] = [
            finding.finding for finding in task.expected_findings
        ]
    return context

def _default_comparison_name(task: TaskSpec) -> str:
    conditions = {condition.lower() for condition in task.conditions}
    if {"healthy", "ipf"} <= conditions:
        return "IPF_vs_healthy"
    # "treated" is a substring of "untreated", so require a condition that
    # mentions "treated" without "untreated" before naming this comparison.
    if any("untreated" in condition for condition in conditions) and any(
        "treated" in condition and "untreated" not in condition
        for condition in conditions
    ):
        return "treated_vs_untreated"
    return "disease_vs_healthy"

def _preferred_marker(task: TaskSpec) -> str:
    """Derive a candidate marker from the problem statement, not expected findings."""
    tokens = [
        t
        for t in TOKEN_RE.findall(task.problem_statement)
        if t.isupper() and len(t) >= 3
    ]
    if tokens:
        return tokens[0]
    return "unknown"

def _latest_output_data(
    obs: ExperimentObservation,
    output_type: OutputType,
) -> Dict[str, Any]:
    for output in reversed(obs.all_outputs):
        if output.output_type == output_type:
            return output.data
    return {}


def _extract_network_regulators(obs: ExperimentObservation) -> List[str]:
    for output in reversed(obs.all_outputs):
        if output.output_type == OutputType.NETWORK_RESULT:
            return output.data.get("top_regulators", [])
    return []


def _evidence_steps(
    obs: ExperimentObservation,
    output_types: set[OutputType],
) -> List[int]:
    return [
        output.step_index
        for output in obs.all_outputs
        if output.output_type in output_types
    ]

def _evidence_text(obs: ExperimentObservation) -> str:
    parts: List[str] = []
    parts.extend(obs.discovered_markers)
    parts.extend(obs.candidate_mechanisms)
    parts.extend(conclusion.claim for conclusion in obs.conclusions)
    for output in obs.all_outputs:
        parts.append(output.summary)
        if output.output_type == OutputType.DE_RESULT:
            parts.extend(
                gene["gene"]
                for gene in output.data.get("top_genes", [])
                if isinstance(gene, dict) and "gene" in gene
            )
        elif output.output_type == OutputType.PATHWAY_RESULT:
            parts.extend(
                pathway["pathway"]
                for pathway in output.data.get("top_pathways", [])
                if isinstance(pathway, dict) and "pathway" in pathway
            )
        elif output.output_type == OutputType.NETWORK_RESULT:
            parts.extend(output.data.get("top_regulators", []))
        elif output.output_type == OutputType.TRAJECTORY_RESULT:
            if output.data.get("branching_detected"):
                parts.append("branching lineage HSC trajectory")
    return " ".join(parts).lower()

def _tokenize(text: str) -> List[str]:
    return [
        token.lower()
        for token in TOKEN_RE.findall(text)
        if token.lower() not in STOPWORDS
    ]

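# Example (deterministic): _tokenize("Profiling of the IPF lung using scRNA-seq")
#   -> ["profiling", "ipf", "lung", "scrna-seq"]
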
def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--problem-statement",
        default=(
            "Design a follow-up validation experiment for candidate biomarker "
            "SPP1 in idiopathic pulmonary fibrosis."
        ),
    )
    parser.add_argument("--scenario-name", default=None)
    parser.add_argument("--domain-randomise", action="store_true")
    parser.add_argument("--json", action="store_true")
    args = parser.parse_args()
    result = run_paper_benchmark(
        problem_statement=args.problem_statement,
        scenario_name=args.scenario_name,
        domain_randomise=args.domain_randomise,
    )
    if args.json:
        print(json.dumps(result.to_dict(), indent=2))
        return
    print(f"Scenario: {result.scenario_name}")
    print(f"Problem: {result.problem_statement}")
    print(f"Papers: {', '.join(result.matched_papers)}")
    print(f"Match ratio: {result.match_ratio:.2%}")
    print(f"Matched findings: {len(result.matched_findings)}")
    print(f"Missed findings: {len(result.missed_findings)}")
    print(f"Discovered markers: {', '.join(result.discovered_markers[:8])}")
    print(f"Candidate mechanisms: {', '.join(result.candidate_mechanisms[:5])}")
    print(f"Conclusions: {len(result.conclusions)}")
    print(f"Final reward: {result.final_reward:+.3f}")
    print(f"Bio libraries: {json.dumps(result.bio_library_versions, sort_keys=True)}")


if __name__ == "__main__":
    main()