Spaces:
Runtime error
Runtime error
| """ | |
| decision_engine.py | |
| ================== | |
| Unified decision logic — probability, dominant mechanism, confidence level, | |
| and human-readable interpretation grounded in internal signals. | |
| Prediction is NEVER computed without the explainability analysis being | |
| present first. This module enforces that ordering. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| from dataclasses import dataclass, field | |
| from typing import Optional | |
| import numpy as np | |
| from explainability_engine import ( | |
| SpliceSignals, | |
| V4Signals, | |
| ClassicSignals, | |
| _classify_risk_tier, | |
| ) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Result dataclass | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
@dataclass
class DecisionResult:
    """
    Unified decision for one variant, as assembled by ``build_decision``.

    The instance is deliberately mutable: ``build_decision`` creates it first
    and fills in ``final_explanation`` and ``report_json`` afterwards (both
    therefore default to ``""``). ``report_json`` is excluded from ``repr``
    because it holds the full serialised report.
    """

    variant: str                    # formatted as "chr{chrom}:{pos} {ref}>{alt}"
    unified_probability: float      # weighted blend of the three model probabilities
    risk_tier: str
    tier_desc: str
    dominant_mechanism: str         # "Splice-driven" / "Protein-driven" / "Consensus" / "Ambiguous"
    confidence: str                 # "High" / "Moderate" / "Low"
    # per-model probabilities
    splice_prob: float
    v4_prob: float
    classic_prob: float
    # XAI engine metrics (copied from the cross-model analysis dict)
    mutation_peak_ratio: float
    counterfactual_magnitude: float
    cross_model_locality: float
    signal_concentration: float
    explainability_strength: float
    activation_pattern: str
    # human-readable interpretation
    splice_analysis: str
    protein_analysis: str
    agreement_analysis: str
    final_explanation: str = ""
    # full structured JSON report, serialised as a string
    report_json: str = field(default="", repr=False)
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Dominant mechanism logic | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def _dominant_mechanism(splice: SpliceSignals, v4: V4Signals,
                        classic: ClassicSignals,
                        prob_std: float) -> str:
    """
    Label the mechanism that appears to drive the prediction.

    Rules are evaluated in order; the first match wins:

    1. "Ambiguous"      - ``prob_std`` is above 0.14 (models disagree).
    2. "Splice-driven"  - the splice probability exceeds the v4 probability
       by more than 0.05 AND either ``splice_aura_score`` > 0.35 or the
       largest ``splice_imp`` entry > 0.50.
    3. "Protein-driven" - the mean of the v4 and classic probabilities
       exceeds the splice probability by more than 0.05 AND the splice
       model's exon region importance (``region_imp[0]``) > 0.5.
    4. "Consensus"      - max minus min of the three probabilities <= 0.10.
    5. "Ambiguous"      - fallback when nothing above applies.
    """
    model_probs = np.array([splice.probability, v4.probability,
                            classic.probability])
    spread = float(model_probs.max()) - float(model_probs.min())

    if prob_std > 0.14:
        return "Ambiguous"

    # Both splice indicators are computed up front, as before, so malformed
    # splice signals fail at the same point regardless of the lead check.
    lead_over_v4 = splice.probability > v4.probability + 0.05
    aura_elevated = splice.splice_aura_score > 0.35
    importance_elevated = float(splice.splice_imp.max()) > 0.50
    if lead_over_v4 and (aura_elevated or importance_elevated):
        return "Splice-driven"

    mean_protein = (v4.probability + classic.probability) / 2
    if (mean_protein > splice.probability + 0.05
            and float(splice.region_imp[0]) > 0.5):
        return "Protein-driven"

    return "Consensus" if spread <= 0.10 else "Ambiguous"
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Confidence level | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def _confidence(unified_prob: float, prob_std: float,
                ess: float, cf_mag: float) -> str:
    """
    Grade how far the prediction can be trusted.

    Points are summed over four components:

    * model agreement          - prob_std < 0.05 -> 2, < 0.12 -> 1
    * explainability strength  - ess >= 0.65 -> 2, >= 0.35 -> 1
    * counterfactual magnitude - cf_mag >= 0.25 -> 2, >= 0.10 -> 1
    * probability extremity    - unified_prob >= 0.85 or <= 0.15 -> 1

    Returns "High" for 5+ points, "Moderate" for 3-4, otherwise "Low".
    """
    agreement_pts = 2 if prob_std < 0.05 else (1 if prob_std < 0.12 else 0)
    strength_pts = 2 if ess >= 0.65 else (1 if ess >= 0.35 else 0)
    cf_pts = 2 if cf_mag >= 0.25 else (1 if cf_mag >= 0.10 else 0)
    extreme_pt = 1 if (unified_prob >= 0.85 or unified_prob <= 0.15) else 0

    total = agreement_pts + strength_pts + cf_pts + extreme_pt
    for floor, label in ((5, "High"), (3, "Moderate")):
        if total >= floor:
            return label
    return "Low"
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Human-readable interpretation builders | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def _splice_analysis_text(s: SpliceSignals, variant: str) -> str:
    """
    Build the splice-model paragraph of the interpretation.

    Covers, in order: the model's risk tier and probability; how strongly the
    conv3 activation peaks at the mutation (mutation peak ratio, MPR);
    proximity to a GT donor / AG acceptor dinucleotide when that site is
    flagged "CRITICAL SPLICE SITE" or "SPLICE REGION"; a counterfactual
    sentence when the counterfactual probability range exceeds 0.15; and the
    dominant ablation feature.

    Parameters
    ----------
    s : SpliceSignals
        Splice-model signals. ``s.counterfactual`` and ``s.ablation`` are
        read as dicts. When ``counterfactual["probability_range"] > 0.15``
        the dict must also hold ``min_probability`` and ``max_probability``.
    variant : str
        Display label for the variant.

    Returns
    -------
    str
        Sentences joined by single spaces.
    """
    parts = []

    tier, _ = _classify_risk_tier(s.probability)
    parts.append(
        f"The splice model classifies variant {variant} as "
        f"'{tier}' with probability {s.probability:.4f}."
    )

    # MPR: conv3 activation at the mutation site relative to the window mean.
    mpr = s.mutation_peak_ratio
    if mpr >= 2.0:
        parts.append(
            f"The conv3 activation peak at the mutation site "
            f"(position {s.mutation_pos}) is {mpr:.2f}× "
            f"the mean window activation, indicating the model is strongly "
            f"attending to the mutation location."
        )
    elif mpr >= 1.0:
        parts.append(
            f"The mutation position ({s.mutation_pos}) has above-average "
            f"conv3 activation (MPR={mpr:.2f}×)."
        )
    else:
        parts.append(
            f"The mutation position ({s.mutation_pos}) does not carry an "
            f"elevated conv3 activation peak (MPR={mpr:.2f}×), "
            f"suggesting the pathogenic signal may arise from broader context."
        )

    flagged = ("CRITICAL SPLICE SITE", "SPLICE REGION")
    if s.splice_risk_donor in flagged:
        parts.append(
            f"The variant lies {s.dist_donor} bp from the nearest GT donor "
            f"dinucleotide (risk: {s.splice_risk_donor}). "
            f"Splice donor importance score = {float(s.splice_imp[0]):.3f}."
        )
    if s.splice_risk_acceptor in flagged:
        parts.append(
            f"The variant lies {s.dist_acceptor} bp from the nearest AG acceptor "
            f"dinucleotide (risk: {s.splice_risk_acceptor}). "
            f"Splice acceptor importance score = {float(s.splice_imp[1]):.3f}."
        )

    cf = s.counterfactual
    if cf.get("probability_range", 0) > 0.15:
        parts.append(
            f"Counterfactual analysis: swapping the alternate base changes "
            f"the pathogenicity probability by up to "
            f"{cf['probability_range']:.3f} "
            f"(range {cf['min_probability']:.3f}–{cf['max_probability']:.3f}), "
            f"confirming strong position-level causality."
        )

    abl = s.ablation
    # A missing, None or blank "dominant_feature" previously crashed on
    # `.lower()` / `.split()[0]`; fall back to "unknown" and "?" instead.
    dom = abl.get("dominant_feature") or "unknown"
    dom_words = str(dom).lower().split()
    # The percentage key is "<first word of the feature, lowercased>_pct".
    pct = abl.get(f"{dom_words[0]}_pct", "?") if dom_words else "?"
    parts.append(
        f"Feature ablation: '{dom}' contributes the largest share "
        f"of the pathogenic signal "
        f"({pct}%)."
    )
    return " ".join(parts)
def _protein_analysis_text(v4: V4Signals, classic: ClassicSignals) -> str:
    """
    Build the protein-model paragraph of the interpretation.

    Reports each protein model's risk tier and probability and the mean of
    their mutation peak ratios. If the classic model's exon importance
    (``region_imp[0]``) exceeds 0.6, one sentence about a protein-coding
    mechanism is added. Otherwise, if its intron importance
    (``region_imp[1]``) exceeds 0.6, one sentence about a regulatory or
    splicing mechanism is added.

    Returns
    -------
    str
        Sentences joined by single spaces.
    """
    v4_tier, _ = _classify_risk_tier(v4.probability)
    cl_tier, _ = _classify_risk_tier(classic.probability)
    parts = [
        f"V4 model: '{v4_tier}' (prob={v4.probability:.4f}). "
        f"Classic model: '{cl_tier}' (prob={classic.probability:.4f}).",
    ]

    avg_mpr = (v4.mutation_peak_ratio + classic.mutation_peak_ratio) / 2
    parts.append(
        f"Average mutation-site activation ratio across protein models: "
        f"{avg_mpr:.2f}×."
    )

    exon_imp = float(classic.region_imp[0])
    if exon_imp > 0.6:
        parts.append(
            "The classic model assigns high exon-region importance "
            f"({exon_imp:.3f}), "
            "supporting a protein-coding disruption mechanism."
        )
    else:
        # The intron entry is read only when the exon branch did not fire.
        intron_imp = float(classic.region_imp[1])
        if intron_imp > 0.6:
            parts.append(
                "The classic model assigns high intron-region importance "
                f"({intron_imp:.3f}), "
                "consistent with a regulatory or splicing mechanism."
            )
    return " ".join(parts)
def _agreement_text(splice: SpliceSignals, v4: V4Signals,
                    classic: ClassicSignals, cross_locality: float,
                    prob_std: float) -> str:
    """
    Describe how closely the three models agree, in three sentences.

    Parameters
    ----------
    splice, v4, classic
        Per-model signal objects; only ``.probability`` is read.
    cross_locality : float
        Cross-model activation locality score. Labelled "high" above 0.5,
        "moderate" above 0.0, otherwise "low".
    prob_std : float
        Spread of the three model probabilities as supplied by the caller
        (presumably their standard deviation). Below 0.10 reads as "agree",
        below 0.18 as "partially agree", otherwise "disagree".

    Returns
    -------
    str
        Sentences joined by single spaces.
    """
    if prob_std < 0.10:
        agreement = "agree"
    elif prob_std < 0.18:
        agreement = "partially agree"
    else:
        agreement = "disagree"

    if cross_locality > 0.5:
        locality_label = "high"
    elif cross_locality > 0.0:
        locality_label = "moderate"
    else:
        locality_label = "low"

    sentences = [
        f"The three models {agreement} on pathogenicity "
        f"(probabilities: splice={splice.probability:.3f}, "
        f"v4={v4.probability:.3f}, classic={classic.probability:.3f}; "
        f"std={prob_std:.3f}).",
        f"Cross-model activation locality score = {cross_locality:.4f} "
        f"({locality_label} "
        f"alignment of importance peaks across models).",
    ]
    if cross_locality > 0.5:
        sentences.append(
            "All three models are attending to the same region of the 99-bp window, "
            "strongly supporting the identified mechanistic signal."
        )
    else:
        sentences.append(
            "The models are attending to different sequence regions, "
            "suggesting the signal may be mechanism-specific or context-dependent."
        )
    return " ".join(sentences)
def _final_explanation(result: "DecisionResult") -> str:
    """
    Compose the plain-language summary for a decision.

    Only attributes of ``result`` are read; ``final_explanation`` and
    ``report_json`` themselves are not used, so this can run before they
    are filled in.

    Returns
    -------
    str
        Paragraphs separated by blank lines, ending with a research-use
        disclaimer.
    """
    p = result.unified_probability
    if p >= 0.85:
        prob_label = "highly pathogenic"
    elif p >= 0.70:
        prob_label = "likely pathogenic"
    elif p >= 0.50:
        prob_label = "possibly pathogenic"
    elif p >= 0.20:
        prob_label = "likely benign"
    else:
        prob_label = "benign"

    mech = result.dominant_mechanism.lower()
    # "an ambiguous mechanism" vs "a consensus mechanism".
    article = "an" if mech.startswith(("a", "e", "i", "o", "u")) else "a"

    ess = result.explainability_strength
    ess_label = "high" if ess >= 0.65 else "moderate" if ess >= 0.35 else "low"

    paragraphs = [
        f"Variant {result.variant} is predicted {prob_label} "
        f"(unified probability = {p:.4f}; "
        f"risk tier: {result.risk_tier}) with {result.confidence.lower()} confidence.",
        f"The prediction is driven by {article} {mech} mechanism.",
        f"Explainability strength score = {ess:.4f} "
        f"({ess_label} "
        f"overall evidence quality), based on a mutation peak ratio of "
        f"{result.mutation_peak_ratio:.2f}×, counterfactual magnitude of "
        f"{result.counterfactual_magnitude:.4f}, and cross-model locality of "
        f"{result.cross_model_locality:.4f}.",
        f"The conv3 activation pattern is classified as '{result.activation_pattern}', "
        f"and signal concentration at the mutation site is "
        f"{result.signal_concentration:.4f}.",
        # agreement_analysis is already complete, punctuated sentences, so it
        # is embedded verbatim. Lower-casing it into "Model agreement is ..."
        # garbled the text and doubled the final period.
        f"Model agreement: {result.agreement_analysis}",
        "⚠ This system is for research use only and should not substitute "
        "for clinical diagnostic testing.",
    ]
    return "\n\n".join(paragraphs)
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Main entry point | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def _json_default(obj):
    """
    ``json.dumps`` fallback for values the stdlib encoder cannot handle.

    NumPy scalars and arrays are unwrapped to native Python values so they
    serialise as JSON numbers/lists instead of strings. Anything else still
    falls back to ``str(obj)``.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return str(obj)


def _conv3_peak_at_mutation(splice: SpliceSignals) -> Optional[float]:
    """
    Return ``splice.conv3_norm`` at the mutation position, rounded to 4 d.p.

    Returns None when the position lies outside the track. The bound comes
    from the track's own length rather than a hard-coded window size, so a
    track of a different length cannot raise IndexError.
    """
    pos = splice.mutation_pos
    track = splice.conv3_norm
    if 0 <= pos < len(track):
        return round(float(track[pos]), 4)
    return None


def _build_report(r: DecisionResult, splice: SpliceSignals, v4: V4Signals,
                  classic: ClassicSignals, cross: dict) -> dict:
    """Assemble the structured report that is serialised into ``report_json``."""
    cf = splice.counterfactual
    return {
        "variant": r.variant,
        "prediction": {
            "unified_probability": r.unified_probability,
            "risk_tier": r.risk_tier,
            "tier_desc": r.tier_desc,
            "dominant_mechanism": r.dominant_mechanism,
            "confidence": r.confidence,
        },
        "model_outputs": {
            "splice": {
                "probability": splice.probability,
                "risk_tier": splice.risk_tier,
                "conv3_peak_at_mutation": _conv3_peak_at_mutation(splice),
                "splice_importance": {
                    "donor": round(float(splice.splice_imp[0]), 4),
                    "acceptor": round(float(splice.splice_imp[1]), 4),
                    "region": round(float(splice.splice_imp[2]), 4),
                },
                "region_importance": {
                    "exon": round(float(splice.region_imp[0]), 4),
                    "intron": round(float(splice.region_imp[1]), 4),
                },
                "splice_aura_score": splice.splice_aura_score,
                "dist_donor": splice.dist_donor,
                "dist_acceptor": splice.dist_acceptor,
                "splice_risk_donor": splice.splice_risk_donor,
                "splice_risk_acceptor": splice.splice_risk_acceptor,
                # The per-base table is emitted separately below.
                "counterfactual": {k: v for k, v in cf.items() if k != "table"},
                "counterfactual_table": cf.get("table", []),
                "feature_ablation": splice.ablation,
            },
            "v4": {
                "probability": v4.probability,
                "mutation_peak_ratio": round(v4.mutation_peak_ratio, 4),
                "signal_concentration": round(v4.signal_concentration, 4),
            },
            "classic": {
                "probability": classic.probability,
                "importance_head": round(classic.importance_head, 4),
                "region_importance": {
                    "exon": round(float(classic.region_imp[0]), 4),
                    "intron": round(float(classic.region_imp[1]), 4),
                },
                "mutation_peak_ratio": round(classic.mutation_peak_ratio, 4),
            },
        },
        "explainability_analysis": {
            "mutation_peak_ratio": cross["mutation_peak_ratio"],
            "counterfactual_magnitude": cross["counterfactual_magnitude"],
            "cross_model_locality_score": cross["cross_model_locality_score"],
            "signal_concentration_index": cross["signal_concentration_index"],
            "explainability_strength_score": r.explainability_strength,
            "activation_pattern_type": cross["activation_pattern_type"],
            "model_agreement": cross["model_agreement"],
            "prob_std": cross["prob_std"],
        },
        "interpretation": {
            "splice_analysis": r.splice_analysis,
            "protein_analysis": r.protein_analysis,
            "agreement_analysis": r.agreement_analysis,
            "final_explanation": r.final_explanation,
        },
    }


def build_decision(chrom: str, pos: int, ref: str, alt: str,
                   splice: SpliceSignals,
                   v4: V4Signals,
                   classic: ClassicSignals,
                   cross: dict) -> DecisionResult:
    """
    Build the unified DecisionResult for one variant.

    Must be called AFTER all XAI signals have been extracted — this ordering
    is mandatory.

    Parameters
    ----------
    chrom, pos, ref, alt
        Variant coordinates, formatted as ``"chr{chrom}:{pos} {ref}>{alt}"``.
    splice, v4, classic
        Per-model signal objects.
    cross : dict
        Cross-model analysis. Must provide ``prob_std``,
        ``explainability_strength_score``, ``counterfactual_magnitude``,
        ``cross_model_locality_score``, ``mutation_peak_ratio``,
        ``signal_concentration_index``, ``activation_pattern_type`` and
        ``model_agreement``.

    Returns
    -------
    DecisionResult
        Fully populated, including ``final_explanation`` and ``report_json``.

    Raises
    ------
    KeyError
        If ``cross`` lacks one of the keys listed above.
    """
    variant = f"chr{chrom}:{pos} {ref}>{alt}"

    # Weighted blend (weights sum to 1): splice 0.45 (treated as the most
    # informative model), v4 0.30, classic 0.25.
    unified = (0.45 * splice.probability +
               0.30 * v4.probability +
               0.25 * classic.probability)
    unified = round(float(unified), 4)
    tier, tier_desc = _classify_risk_tier(unified)

    prob_std = cross["prob_std"]
    mech = _dominant_mechanism(splice, v4, classic, prob_std)
    ess = cross["explainability_strength_score"]
    cf_mag = cross["counterfactual_magnitude"]
    conf = _confidence(unified, prob_std, ess, cf_mag)
    locality = cross["cross_model_locality_score"]

    splice_txt = _splice_analysis_text(splice, variant)
    protein_txt = _protein_analysis_text(v4, classic)
    agreement_txt = _agreement_text(splice, v4, classic, locality, prob_std)

    # Build the result first so _final_explanation can read its fields;
    # final_explanation and report_json are filled in below.
    r = DecisionResult(
        variant=variant,
        unified_probability=unified,
        risk_tier=tier, tier_desc=tier_desc,
        dominant_mechanism=mech, confidence=conf,
        splice_prob=splice.probability,
        v4_prob=v4.probability,
        classic_prob=classic.probability,
        mutation_peak_ratio=cross["mutation_peak_ratio"],
        counterfactual_magnitude=cf_mag,
        cross_model_locality=locality,
        signal_concentration=cross["signal_concentration_index"],
        explainability_strength=ess,
        activation_pattern=cross["activation_pattern_type"],
        splice_analysis=splice_txt,
        protein_analysis=protein_txt,
        agreement_analysis=agreement_txt,
        final_explanation="",
        report_json="",
    )
    r.final_explanation = _final_explanation(r)
    r.report_json = json.dumps(
        _build_report(r, splice, v4, classic, cross),
        indent=2,
        default=_json_default,
    )
    return r