""" decision_engine.py ================== Unified decision logic — probability, dominant mechanism, confidence level, and human-readable interpretation grounded in internal signals. Prediction is NEVER computed without the explainability analysis being present first. This module enforces that ordering. """ from __future__ import annotations import json from dataclasses import dataclass, field from typing import Optional import numpy as np from explainability_engine import ( SpliceSignals, V4Signals, ClassicSignals, _classify_risk_tier, ) # ═══════════════════════════════════════════════════════════════════════════════ # Result dataclass # ═══════════════════════════════════════════════════════════════════════════════ @dataclass class DecisionResult: variant: str unified_probability: float risk_tier: str tier_desc: str dominant_mechanism: str confidence: str # per-model splice_prob: float v4_prob: float classic_prob: float # XAI engine metrics mutation_peak_ratio: float counterfactual_magnitude:float cross_model_locality: float signal_concentration: float explainability_strength: float activation_pattern: str # interpretation splice_analysis: str protein_analysis: str agreement_analysis: str final_explanation: str # full structured JSON (as string) report_json: str = field(repr=False) # ═══════════════════════════════════════════════════════════════════════════════ # Dominant mechanism logic # ═══════════════════════════════════════════════════════════════════════════════ def _dominant_mechanism(splice: SpliceSignals, v4: V4Signals, classic: ClassicSignals, prob_std: float) -> str: """ Determine which mechanism drives the prediction. Splice-driven: splice model dominates AND aura / splice importance elevated Protein-driven: v4 + classic > splice probability by meaningful margin Consensus: all three models agree within 0.10 Ambiguous: high disagreement (prob_std > 0.12) """ probs = np.array([splice.probability, v4.probability, classic.probability]) p_max = float(probs.max()) p_min = float(probs.min()) if prob_std > 0.14: return "Ambiguous" # Splice dominance splice_leads = splice.probability > v4.probability + 0.05 high_aura = splice.splice_aura_score > 0.35 high_splice_i = float(splice.splice_imp.max()) > 0.50 if splice_leads and (high_aura or high_splice_i): return "Splice-driven" # Protein dominance protein_avg = (v4.probability + classic.probability) / 2 if protein_avg > splice.probability + 0.05 and float(splice.region_imp[0]) > 0.5: return "Protein-driven" # Consensus if p_max - p_min <= 0.10: return "Consensus" return "Ambiguous" # ═══════════════════════════════════════════════════════════════════════════════ # Confidence level # ═══════════════════════════════════════════════════════════════════════════════ def _confidence(unified_prob: float, prob_std: float, ess: float, cf_mag: float) -> str: """ High: strong signal from all three axes Moderate: partial support Low: conflicting or weak signals """ score = 0 # Model agreement if prob_std < 0.05: score += 2 elif prob_std < 0.12: score += 1 # Explainability strength if ess >= 0.65: score += 2 elif ess >= 0.35: score += 1 # Counterfactual magnitude if cf_mag >= 0.25: score += 2 elif cf_mag >= 0.10: score += 1 # Probability extremity if unified_prob >= 0.85 or unified_prob <= 0.15: score += 1 if score >= 5: return "High" if score >= 3: return "Moderate" return "Low" # ═══════════════════════════════════════════════════════════════════════════════ # Human-readable interpretation builders # ═══════════════════════════════════════════════════════════════════════════════ def _splice_analysis_text(s: SpliceSignals, variant: str) -> str: parts = [] tier, _ = _classify_risk_tier(s.probability) parts.append( f"The splice model classifies variant {variant} as " f"'{tier}' with probability {s.probability:.4f}." ) if s.mutation_peak_ratio >= 2.0: parts.append( f"The conv3 activation peak at the mutation site " f"(position {s.mutation_pos}) is {s.mutation_peak_ratio:.2f}× " f"the mean window activation, indicating the model is strongly " f"attending to the mutation location." ) elif s.mutation_peak_ratio >= 1.0: parts.append( f"The mutation position ({s.mutation_pos}) has above-average " f"conv3 activation (MPR={s.mutation_peak_ratio:.2f}×)." ) else: parts.append( f"The mutation position ({s.mutation_pos}) does not carry an " f"elevated conv3 activation peak (MPR={s.mutation_peak_ratio:.2f}×), " f"suggesting the pathogenic signal may arise from broader context." ) if s.splice_risk_donor in ("CRITICAL SPLICE SITE", "SPLICE REGION"): parts.append( f"The variant lies {s.dist_donor} bp from the nearest GT donor " f"dinucleotide (risk: {s.splice_risk_donor}). " f"Splice donor importance score = {float(s.splice_imp[0]):.3f}." ) if s.splice_risk_acceptor in ("CRITICAL SPLICE SITE", "SPLICE REGION"): parts.append( f"The variant lies {s.dist_acceptor} bp from the nearest AG acceptor " f"dinucleotide (risk: {s.splice_risk_acceptor}). " f"Splice acceptor importance score = {float(s.splice_imp[1]):.3f}." ) cf = s.counterfactual if cf.get("probability_range", 0) > 0.15: parts.append( f"Counterfactual analysis: swapping the alternate base changes " f"the pathogenicity probability by up to " f"{cf['probability_range']:.3f} " f"(range {cf['min_probability']:.3f}–{cf['max_probability']:.3f}), " f"confirming strong position-level causality." ) abl = s.ablation dom = abl.get("dominant_feature", "unknown") parts.append( f"Feature ablation: '{dom}' contributes the largest share " f"of the pathogenic signal " f"({abl.get(dom.lower().split()[0]+'_pct', '?')}%)." ) return " ".join(parts) def _protein_analysis_text(v4: V4Signals, classic: ClassicSignals) -> str: parts = [] v4_tier, _ = _classify_risk_tier(v4.probability) cl_tier, _ = _classify_risk_tier(classic.probability) parts.append( f"V4 model: '{v4_tier}' (prob={v4.probability:.4f}). " f"Classic model: '{cl_tier}' (prob={classic.probability:.4f})." ) avg_mpr = (v4.mutation_peak_ratio + classic.mutation_peak_ratio) / 2 parts.append( f"Average mutation-site activation ratio across protein models: " f"{avg_mpr:.2f}×." ) if float(classic.region_imp[0]) > 0.6: parts.append( "The classic model assigns high exon-region importance " f"({float(classic.region_imp[0]):.3f}), " "supporting a protein-coding disruption mechanism." ) elif float(classic.region_imp[1]) > 0.6: parts.append( "The classic model assigns high intron-region importance " f"({float(classic.region_imp[1]):.3f}), " "consistent with a regulatory or splicing mechanism." ) return " ".join(parts) def _agreement_text(splice: SpliceSignals, v4: V4Signals, classic: ClassicSignals, cross_locality: float, prob_std: float) -> str: probs = [splice.probability, v4.probability, classic.probability] p_max = max(probs) p_min = min(probs) lvl = "agree" if prob_std < 0.10 else ("partially agree" if prob_std < 0.18 else "disagree") lines = [ f"The three models {lvl} on pathogenicity " f"(probabilities: splice={splice.probability:.3f}, " f"v4={v4.probability:.3f}, classic={classic.probability:.3f}; " f"std={prob_std:.3f})." ] lines.append( f"Cross-model activation locality score = {cross_locality:.4f} " f"({'high' if cross_locality > 0.5 else 'moderate' if cross_locality > 0.0 else 'low'} " f"alignment of importance peaks across models)." ) if cross_locality > 0.5: lines.append( "All three models are attending to the same region of the 99-bp window, " "strongly supporting the identified mechanistic signal." ) else: lines.append( "The models are attending to different sequence regions, " "suggesting the signal may be mechanism-specific or context-dependent." ) return " ".join(lines) def _final_explanation(result: "DecisionResult") -> str: prob_label = ( "highly pathogenic" if result.unified_probability >= 0.85 else "likely pathogenic" if result.unified_probability >= 0.70 else "possibly pathogenic" if result.unified_probability >= 0.50 else "likely benign" if result.unified_probability >= 0.20 else "benign" ) mech = result.dominant_mechanism.lower() ess = result.explainability_strength lines = [ f"Variant {result.variant} is predicted {prob_label} " f"(unified probability = {result.unified_probability:.4f}; " f"risk tier: {result.risk_tier}) with {result.confidence.lower()} confidence.", f"The prediction is driven by a {mech} mechanism.", f"Explainability strength score = {result.explainability_strength:.4f} " f"({'high' if ess >= 0.65 else 'moderate' if ess >= 0.35 else 'low'} " f"overall evidence quality), based on a mutation peak ratio of " f"{result.mutation_peak_ratio:.2f}×, counterfactual magnitude of " f"{result.counterfactual_magnitude:.4f}, and cross-model locality of " f"{result.cross_model_locality:.4f}.", f"The conv3 activation pattern is classified as '{result.activation_pattern}', " f"and signal concentration at the mutation site is " f"{result.signal_concentration:.4f}.", f"Model agreement is {result.agreement_analysis.lower()}.", ] lines.append( "⚠ This system is for research use only and should not substitute " "for clinical diagnostic testing." ) return "\n\n".join(lines) # ═══════════════════════════════════════════════════════════════════════════════ # Main entry point # ═══════════════════════════════════════════════════════════════════════════════ def build_decision(chrom: str, pos: int, ref: str, alt: str, splice: SpliceSignals, v4: V4Signals, classic: ClassicSignals, cross: dict) -> DecisionResult: """ Build the unified DecisionResult. Called AFTER all XAI signals have been extracted — this ordering is mandatory. """ variant = f"chr{chrom}:{pos} {ref}>{alt}" # Weighted average probability — splice weighted 0.45 (most informative), # v4 0.30, classic 0.25 unified = (0.45 * splice.probability + 0.30 * v4.probability + 0.25 * classic.probability) unified = round(float(unified), 4) tier, tier_desc = _classify_risk_tier(unified) prob_std = cross["prob_std"] mech = _dominant_mechanism(splice, v4, classic, prob_std) ess = cross["explainability_strength_score"] cf_mag = cross["counterfactual_magnitude"] conf = _confidence(unified, prob_std, ess, cf_mag) splice_txt = _splice_analysis_text(splice, variant) protein_txt = _protein_analysis_text(v4, classic) agreement_txt= _agreement_text(splice, v4, classic, cross["cross_model_locality_score"], prob_std) # Build result skeleton so _final_explanation can reference it r = DecisionResult( variant=variant, unified_probability=unified, risk_tier=tier, tier_desc=tier_desc, dominant_mechanism=mech, confidence=conf, splice_prob=splice.probability, v4_prob=v4.probability, classic_prob=classic.probability, mutation_peak_ratio=cross["mutation_peak_ratio"], counterfactual_magnitude=cross["counterfactual_magnitude"], cross_model_locality=cross["cross_model_locality_score"], signal_concentration=cross["signal_concentration_index"], explainability_strength=ess, activation_pattern=cross["activation_pattern_type"], splice_analysis=splice_txt, protein_analysis=protein_txt, agreement_analysis=agreement_txt, final_explanation="", # filled below report_json="", # filled below ) r.final_explanation = _final_explanation(r) # ── Build structured JSON ───────────────────────────────────────────────── cf = splice.counterfactual abl= splice.ablation report = { "variant": variant, "prediction": { "unified_probability": unified, "risk_tier": tier, "tier_desc": tier_desc, "dominant_mechanism": mech, "confidence": conf, }, "model_outputs": { "splice": { "probability": splice.probability, "risk_tier": splice.risk_tier, "conv3_peak_at_mutation": ( round(float(splice.conv3_norm[splice.mutation_pos]), 4) if 0 <= splice.mutation_pos < 99 else None ), "splice_importance": { "donor": round(float(splice.splice_imp[0]), 4), "acceptor":round(float(splice.splice_imp[1]), 4), "region": round(float(splice.splice_imp[2]), 4), }, "region_importance": { "exon": round(float(splice.region_imp[0]), 4), "intron": round(float(splice.region_imp[1]), 4), }, "splice_aura_score": splice.splice_aura_score, "dist_donor": splice.dist_donor, "dist_acceptor": splice.dist_acceptor, "splice_risk_donor": splice.splice_risk_donor, "splice_risk_acceptor": splice.splice_risk_acceptor, "counterfactual": {k: v for k, v in cf.items() if k != "table"}, "counterfactual_table": cf.get("table", []), "feature_ablation": abl, }, "v4": { "probability": v4.probability, "mutation_peak_ratio": round(v4.mutation_peak_ratio, 4), "signal_concentration": round(v4.signal_concentration, 4), }, "classic": { "probability": classic.probability, "importance_head": round(classic.importance_head, 4), "region_importance": { "exon": round(float(classic.region_imp[0]), 4), "intron": round(float(classic.region_imp[1]), 4), }, "mutation_peak_ratio": round(classic.mutation_peak_ratio, 4), }, }, "explainability_analysis": { "mutation_peak_ratio": cross["mutation_peak_ratio"], "counterfactual_magnitude": cross["counterfactual_magnitude"], "cross_model_locality_score": cross["cross_model_locality_score"], "signal_concentration_index": cross["signal_concentration_index"], "explainability_strength_score": ess, "activation_pattern_type": cross["activation_pattern_type"], "model_agreement": cross["model_agreement"], "prob_std": cross["prob_std"], }, "interpretation": { "splice_analysis": splice_txt, "protein_analysis": protein_txt, "agreement_analysis":agreement_txt, "final_explanation": r.final_explanation, }, } r.report_json = json.dumps(report, indent=2, default=str) return r