Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
| """Procedural scenario generator. | |
| Composes biologically coherent ``Scenario`` objects from the curated | |
| palette in ``bio_palette``, producing fully populated | |
| ``LatentBiologicalState`` instances that drive every simulator tool | |
| (clustering, DE, pathway enrichment, trajectory, regulatory networks, | |
| marker validation) with realistic intermediate outputs. | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import numpy as np | |
| from models import TaskSpec | |
| from server.simulator.latent_state import ( | |
| CellPopulation, | |
| LatentBiologicalState, | |
| TechnicalState, | |
| ) | |
| from .bio_palette import ( | |
| DISEASE_PROFILES, | |
| HIDDEN_FAILURE_TEMPLATES, | |
| PATHWAY_LIBRARY, | |
| PERTURBATION_TEMPLATES, | |
| REGULATORY_TEMPLATES, | |
| TISSUE_CELL_TYPES, | |
| TRAJECTORY_TEMPLATES, | |
| CellTypeTemplate, | |
| DiseaseProfile, | |
| ) | |
| from .scenarios import Scenario | |
| logger = logging.getLogger(__name__) | |
| SCENARIO_TYPES = ("de", "trajectory", "perturbation", "biomarker") | |
| _DIFFICULTY_PARAMS = { | |
| "easy": { | |
| "n_pops": (4, 5), | |
| "de_scale": (1.2, 1.6), | |
| "noise_dropout": (0.05, 0.10), | |
| "noise_doublet": (0.03, 0.06), | |
| "noise_ambient": (0.02, 0.05), | |
| "noise_batch_strength": (0.05, 0.12), | |
| "n_batches": (1, 2), | |
| "budget_range": (70_000, 100_000), | |
| "time_range": (100, 150), | |
| "sample_quality": (0.85, 0.95), | |
| "include_trajectory": False, | |
| "include_perturbation": False, | |
| "include_network": False, | |
| "include_failure_conditions": False, | |
| }, | |
| "medium": { | |
| "n_pops": (5, 7), | |
| "de_scale": (0.9, 1.3), | |
| "noise_dropout": (0.08, 0.14), | |
| "noise_doublet": (0.04, 0.08), | |
| "noise_ambient": (0.03, 0.07), | |
| "noise_batch_strength": (0.08, 0.18), | |
| "n_batches": (1, 3), | |
| "budget_range": (80_000, 120_000), | |
| "time_range": (120, 180), | |
| "sample_quality": (0.78, 0.92), | |
| "include_trajectory": True, | |
| "include_perturbation": False, | |
| "include_network": True, | |
| "include_failure_conditions": False, | |
| }, | |
| "hard": { | |
| "n_pops": (6, 8), | |
| "de_scale": (0.6, 1.0), | |
| "noise_dropout": (0.10, 0.20), | |
| "noise_doublet": (0.06, 0.12), | |
| "noise_ambient": (0.05, 0.10), | |
| "noise_batch_strength": (0.12, 0.25), | |
| "n_batches": (2, 4), | |
| "budget_range": (90_000, 140_000), | |
| "time_range": (140, 200), | |
| "sample_quality": (0.65, 0.85), | |
| "include_trajectory": True, | |
| "include_perturbation": True, | |
| "include_network": True, | |
| "include_failure_conditions": True, | |
| }, | |
| } | |
| def generate_scenario( | |
| seed: int, | |
| difficulty: str = "medium", | |
| scenario_type: Optional[str] = None, | |
| ) -> Scenario: | |
| """Generate a single procedural scenario with complete latent state. | |
| Parameters | |
| ---------- | |
| seed | |
| RNG seed for reproducibility. | |
| difficulty | |
| One of ``"easy"``, ``"medium"``, ``"hard"``. | |
| scenario_type | |
| One of ``"de"``, ``"trajectory"``, ``"perturbation"``, | |
| ``"biomarker"``, or ``None`` for random selection. | |
| """ | |
| rng = np.random.default_rng(seed) | |
| params = _DIFFICULTY_PARAMS[difficulty] | |
| if scenario_type is None: | |
| scenario_type = rng.choice(SCENARIO_TYPES) | |
| disease_key = rng.choice(list(DISEASE_PROFILES.keys())) | |
| disease = DISEASE_PROFILES[disease_key] | |
| tissue = disease.tissue | |
| cell_templates = TISSUE_CELL_TYPES.get(tissue, []) | |
| if not cell_templates: | |
| tissue = rng.choice(list(TISSUE_CELL_TYPES.keys())) | |
| cell_templates = TISSUE_CELL_TYPES[tissue] | |
| populations = _sample_populations(rng, cell_templates, disease, params) | |
| de_genes = _build_de_genes(rng, disease, params) | |
| pathways = _build_pathways(rng, disease) | |
| markers = _derive_markers(rng, de_genes, disease) | |
| mechanisms = list(disease.mechanism_templates) | |
| n_cells = int(rng.integers(8_000, 22_000)) | |
| trajectory = None | |
| if scenario_type == "trajectory" or ( | |
| params["include_trajectory"] and rng.random() < 0.4 | |
| ): | |
| trajectory = _build_trajectory(rng, tissue, populations) | |
| reg_network: Dict[str, List[str]] = {} | |
| if scenario_type == "trajectory" or ( | |
| params["include_network"] and rng.random() < 0.5 | |
| ): | |
| reg_network = _build_regulatory_network(rng, tissue, populations) | |
| perturbation_effects: Dict[str, Dict[str, float]] = {} | |
| if scenario_type == "perturbation" or ( | |
| params["include_perturbation"] and rng.random() < 0.5 | |
| ): | |
| perturbation_effects = _build_perturbation(rng, disease) | |
| technical = _build_technical(rng, params) | |
| hidden_failures: List[str] = [] | |
| if params["include_failure_conditions"] and rng.random() < 0.6: | |
| n_failures = int(rng.integers(1, 3)) | |
| indices = rng.choice( | |
| len(HIDDEN_FAILURE_TEMPLATES), size=min(n_failures, len(HIDDEN_FAILURE_TEMPLATES)), replace=False, | |
| ) | |
| hidden_failures = [HIDDEN_FAILURE_TEMPLATES[i] for i in indices] | |
| task = _build_task(rng, disease, tissue, scenario_type, params, perturbation_effects) | |
| biology = LatentBiologicalState( | |
| cell_populations=populations, | |
| true_de_genes=de_genes, | |
| true_pathways=pathways, | |
| true_trajectory=trajectory, | |
| true_regulatory_network=reg_network, | |
| perturbation_effects=perturbation_effects, | |
| true_markers=markers, | |
| causal_mechanisms=mechanisms, | |
| n_true_cells=n_cells, | |
| ) | |
| name = f"proc_{disease.name}_{scenario_type}_{seed}" | |
| tags = [scenario_type, "scRNA-seq", tissue, disease.name, difficulty] | |
| return Scenario( | |
| name=name, | |
| task=task, | |
| biology=biology, | |
| technical=technical, | |
| hidden_failure_conditions=hidden_failures, | |
| difficulty=difficulty, | |
| tags=tags, | |
| ) | |
| def generate_procedural_scenarios( | |
| n: int = 20, | |
| seed: int = 42, | |
| ) -> List[Scenario]: | |
| """Pre-generate a pool of procedural scenarios across difficulties.""" | |
| rng = np.random.default_rng(seed) | |
| scenarios: List[Scenario] = [] | |
| difficulties = ["easy", "medium", "hard"] | |
| for i in range(n): | |
| diff = difficulties[i % len(difficulties)] | |
| child_seed = int(rng.integers(0, 2**31)) | |
| scenario = generate_scenario( | |
| seed=child_seed, | |
| difficulty=diff, | |
| scenario_type=None, | |
| ) | |
| scenarios.append(scenario) | |
| logger.info("Generated %d procedural scenarios.", len(scenarios)) | |
| return scenarios | |
| # ── Internal builders ─────────────────────────────────────────────────────── | |
| def _sample_populations( | |
| rng: np.random.Generator, | |
| templates: List[CellTypeTemplate], | |
| disease: DiseaseProfile, | |
| params: dict, | |
| ) -> List[CellPopulation]: | |
| lo, hi = params["n_pops"] | |
| n_pops = int(rng.integers(lo, hi + 1)) | |
| n_pops = min(n_pops, len(templates)) | |
| indices = rng.choice(len(templates), size=n_pops, replace=False) | |
| selected = [templates[i] for i in sorted(indices)] | |
| responding_names = set(disease.responding_cell_types) | |
| populations: List[CellPopulation] = [] | |
| for tmpl in selected: | |
| prop = float(rng.uniform(*tmpl.proportion_range)) | |
| state = rng.choice(tmpl.states) | |
| condition_response: Dict[str, float] = {} | |
| if tmpl.disease_responsive and tmpl.name in responding_names: | |
| condition_response[disease.condition_name] = float( | |
| rng.uniform(*tmpl.response_range) | |
| ) | |
| populations.append(CellPopulation( | |
| name=tmpl.name, | |
| proportion=prop, | |
| marker_genes=list(tmpl.marker_genes), | |
| state=state, | |
| condition_response=condition_response, | |
| )) | |
| total = sum(p.proportion for p in populations) | |
| if total > 0: | |
| for p in populations: | |
| p.proportion = round(p.proportion / total, 4) | |
| return populations | |
| def _build_de_genes( | |
| rng: np.random.Generator, | |
| disease: DiseaseProfile, | |
| params: dict, | |
| ) -> Dict[str, Dict[str, float]]: | |
| comparison = f"{disease.condition_name}_vs_healthy" | |
| scale_lo, scale_hi = params["de_scale"] | |
| effects: Dict[str, float] = {} | |
| for gene, (lo, hi) in disease.de_genes.items(): | |
| base = float(rng.uniform(lo, hi)) | |
| scale = float(rng.uniform(scale_lo, scale_hi)) | |
| if base > 0: | |
| effects[gene] = round(base * scale, 3) | |
| else: | |
| effects[gene] = round(base * scale, 3) | |
| return {comparison: effects} | |
| def _build_pathways( | |
| rng: np.random.Generator, | |
| disease: DiseaseProfile, | |
| ) -> Dict[str, float]: | |
| pathways: Dict[str, float] = {} | |
| for pw, (lo, hi) in disease.pathways.items(): | |
| pathways[pw] = round(float(rng.uniform(lo, hi)), 3) | |
| return pathways | |
| def _derive_markers( | |
| rng: np.random.Generator, | |
| de_genes: Dict[str, Dict[str, float]], | |
| disease: DiseaseProfile, | |
| ) -> List[str]: | |
| markers = list(disease.markers) | |
| all_effects: Dict[str, float] = {} | |
| for effects in de_genes.values(): | |
| all_effects.update(effects) | |
| for gene in markers: | |
| if gene not in all_effects: | |
| all_effects[gene] = float(rng.uniform(1.0, 2.5)) | |
| for comp_effects in de_genes.values(): | |
| comp_effects[gene] = all_effects[gene] | |
| n_markers = min(len(markers), int(rng.integers(3, 7))) | |
| return markers[:n_markers] | |
| def _build_trajectory( | |
| rng: np.random.Generator, | |
| tissue: str, | |
| populations: List[CellPopulation], | |
| ) -> Optional[Dict[str, Any]]: | |
| pop_names = {p.name for p in populations} | |
| for tmpl in TRAJECTORY_TEMPLATES: | |
| if tmpl.tissue == tissue: | |
| valid_branches = [ | |
| branch for branch in tmpl.branches | |
| if all(node in pop_names for node in branch) | |
| ] | |
| if valid_branches: | |
| return { | |
| "root": tmpl.root_population, | |
| "n_lineages": len(valid_branches), | |
| "branching": len(valid_branches) > 1, | |
| "branches": valid_branches, | |
| } | |
| if len(populations) >= 3: | |
| root = populations[0].name | |
| branches = [[root, p.name] for p in populations[1:]] | |
| selected = branches[:int(rng.integers(2, min(4, len(branches)) + 1))] | |
| return { | |
| "root": root, | |
| "n_lineages": len(selected), | |
| "branching": len(selected) > 1, | |
| "branches": selected, | |
| } | |
| return None | |
| def _build_regulatory_network( | |
| rng: np.random.Generator, | |
| tissue: str, | |
| populations: List[CellPopulation], | |
| ) -> Dict[str, List[str]]: | |
| all_genes = set() | |
| for p in populations: | |
| all_genes.update(p.marker_genes) | |
| network: Dict[str, List[str]] = {} | |
| tissue_to_programs = { | |
| "bone_marrow": ["erythroid", "myeloid", "stem_cell"], | |
| "thymus": ["lymphoid"], | |
| "blood": ["lymphoid", "myeloid"], | |
| "spleen": ["lymphoid"], | |
| "brain": ["neuronal", "inflammatory"], | |
| "heart": ["fibrotic", "inflammatory"], | |
| "lung": ["fibrotic", "inflammatory"], | |
| "liver": ["fibrotic", "inflammatory"], | |
| "kidney": ["fibrotic", "inflammatory"], | |
| "colon": ["inflammatory", "stem_cell"], | |
| "pancreas": ["inflammatory"], | |
| "skin": ["inflammatory"], | |
| "breast": ["inflammatory"], | |
| "synovium": ["inflammatory", "lymphoid"], | |
| "aorta": ["inflammatory"], | |
| } | |
| programs = tissue_to_programs.get(tissue, ["inflammatory"]) | |
| for prog_name in programs: | |
| prog = REGULATORY_TEMPLATES.get(prog_name, {}) | |
| for tf, targets in prog.items(): | |
| network[tf] = list(targets) | |
| if not network: | |
| for p in populations[:2]: | |
| if len(p.marker_genes) >= 2: | |
| tf = p.marker_genes[0] | |
| network[tf] = p.marker_genes[1:] | |
| return network | |
| def _build_perturbation( | |
| rng: np.random.Generator, | |
| disease: DiseaseProfile, | |
| ) -> Dict[str, Dict[str, float]]: | |
| disease_pathways = set(disease.pathways.keys()) | |
| matching = [ | |
| (name, tmpl) for name, tmpl in PERTURBATION_TEMPLATES.items() | |
| if tmpl.target_pathway in disease_pathways | |
| ] | |
| if matching: | |
| name, tmpl = matching[int(rng.integers(0, len(matching)))] | |
| else: | |
| name = rng.choice(list(PERTURBATION_TEMPLATES.keys())) | |
| tmpl = PERTURBATION_TEMPLATES[name] | |
| scaled: Dict[str, float] = {} | |
| for gene, effect in tmpl.gene_effects.items(): | |
| scale = float(rng.uniform(0.7, 1.3)) | |
| scaled[gene] = round(effect * scale, 3) | |
| return {name: scaled} | |
| def _build_technical( | |
| rng: np.random.Generator, | |
| params: dict, | |
| ) -> TechnicalState: | |
| n_batches = int(rng.integers(*params["n_batches"])) | |
| batch_effects: Dict[str, float] = {} | |
| for i in range(max(1, n_batches)): | |
| strength = float(rng.uniform(*params["noise_batch_strength"])) | |
| batch_effects[f"batch_{i}"] = round(strength, 3) | |
| return TechnicalState( | |
| batch_effects=batch_effects, | |
| dropout_rate=round(float(rng.uniform(*params["noise_dropout"])), 3), | |
| doublet_rate=round(float(rng.uniform(*params["noise_doublet"])), 3), | |
| ambient_rna_fraction=round(float(rng.uniform(*params["noise_ambient"])), 3), | |
| sample_quality=round(float(rng.uniform(*params["sample_quality"])), 3), | |
| ) | |
| def _build_task( | |
| rng: np.random.Generator, | |
| disease: DiseaseProfile, | |
| tissue: str, | |
| scenario_type: str, | |
| params: dict, | |
| perturbation_effects: Dict[str, Dict[str, float]], | |
| ) -> TaskSpec: | |
| budget = float(rng.integers(*params["budget_range"])) | |
| time_days = float(rng.integers(*params["time_range"])) | |
| if scenario_type == "de": | |
| problem = ( | |
| f"Identify differentially expressed genes between " | |
| f"{disease.display_name} and healthy {tissue} tissue " | |
| f"using single-cell RNA sequencing." | |
| ) | |
| criteria = [ | |
| f"Identify DE genes between {disease.condition_name} and healthy", | |
| "Validate at least one candidate marker", | |
| ] | |
| elif scenario_type == "trajectory": | |
| problem = ( | |
| f"Infer the developmental trajectory of cell populations " | |
| f"in {tissue} tissue in the context of {disease.display_name}." | |
| ) | |
| criteria = [ | |
| "Reconstruct branching lineage structure", | |
| "Identify key transcription factors driving fate decisions", | |
| ] | |
| elif scenario_type == "perturbation": | |
| pert_name = next(iter(perturbation_effects), "treatment") | |
| pert_tmpl = PERTURBATION_TEMPLATES.get(pert_name) | |
| pert_desc = pert_tmpl.description if pert_tmpl else pert_name | |
| problem = ( | |
| f"Determine the effect of {pert_desc} on cell states " | |
| f"in {tissue} tissue affected by {disease.display_name}." | |
| ) | |
| criteria = [ | |
| "Quantify shift in cell activation states", | |
| f"Identify pathways modulated by {pert_name}", | |
| "Propose validation strategy", | |
| ] | |
| else: | |
| top_marker = disease.markers[0] if disease.markers else "candidate" | |
| problem = ( | |
| f"Validate candidate biomarker {top_marker} for " | |
| f"{disease.display_name} in {tissue} tissue using " | |
| f"single-cell RNA sequencing." | |
| ) | |
| criteria = [ | |
| f"Validate {top_marker} as a disease marker", | |
| "Confirm expression specificity across cell types", | |
| ] | |
| conditions = ["healthy", disease.condition_name] | |
| if scenario_type == "perturbation" and perturbation_effects: | |
| pert_name = next(iter(perturbation_effects)) | |
| conditions = [f"untreated_{disease.condition_name}", f"{pert_name}_treated"] | |
| return TaskSpec( | |
| problem_statement=problem, | |
| modality="scRNA-seq", | |
| organism="human", | |
| tissue=tissue, | |
| conditions=conditions, | |
| budget_limit=budget, | |
| time_limit_days=time_days, | |
| success_criteria=criteria, | |
| ) | |