# Mutation-XAI / decision_engine.py
# Last commit 23a5204 by nileshhanotia (verified):
# "Rename decision_engine (2).py to decision_engine.py"
"""
decision_engine.py
==================
Unified decision logic — probability, dominant mechanism, confidence level,
and human-readable interpretation grounded in internal signals.
Prediction is NEVER computed without the explainability analysis being
present first. This module enforces that ordering.
"""
from __future__ import annotations
import json
from dataclasses import dataclass, field
from typing import Optional
import numpy as np
from explainability_engine import (
SpliceSignals,
V4Signals,
ClassicSignals,
_classify_risk_tier,
)
# ═══════════════════════════════════════════════════════════════════════════════
# Result dataclass
# ═══════════════════════════════════════════════════════════════════════════════
@dataclass
class DecisionResult:
    """
    Container for one variant's unified decision and its supporting evidence.

    Instances are created by ``build_decision``. That function first builds the
    object with empty ``final_explanation`` / ``report_json`` and then fills
    those two fields in.
    """

    # ── Headline decision ────────────────────────────────────────────────────
    variant: str
    unified_probability: float
    risk_tier: str
    tier_desc: str
    dominant_mechanism: str
    confidence: str

    # ── Probability reported by each individual model ───────────────────────
    splice_prob: float
    v4_prob: float
    classic_prob: float

    # ── Metrics taken from the explainability (XAI) engine ──────────────────
    mutation_peak_ratio: float
    counterfactual_magnitude: float
    cross_model_locality: float
    signal_concentration: float
    explainability_strength: float
    activation_pattern: str

    # ── Human-readable interpretation paragraphs ────────────────────────────
    splice_analysis: str
    protein_analysis: str
    agreement_analysis: str
    final_explanation: str

    # ── Full structured report serialized to a JSON string ──────────────────
    # Excluded from repr() via field(repr=False).
    report_json: str = field(repr=False)
# ═══════════════════════════════════════════════════════════════════════════════
# Dominant mechanism logic
# ═══════════════════════════════════════════════════════════════════════════════
def _dominant_mechanism(splice: SpliceSignals, v4: V4Signals,
                        classic: ClassicSignals,
                        prob_std: float) -> str:
    """
    Determine which mechanism drives the prediction.

    Rules are evaluated in order; the first one that matches wins:

    1. Ambiguous      - high model disagreement: ``prob_std > 0.14``.
    2. Splice-driven  - the splice probability exceeds the *v4* probability
                        by more than 0.05 AND either ``splice_aura_score``
                        > 0.35 or the largest ``splice_imp`` value > 0.50.
    3. Protein-driven - the mean of the v4 and classic probabilities exceeds
                        the splice probability by more than 0.05 AND the
                        splice model's exon importance (``region_imp[0]``)
                        > 0.5.
    4. Consensus      - all three probabilities lie within 0.10 of each
                        other.
    5. Ambiguous      - fallback when nothing above matched.

    Parameters
    ----------
    splice : SpliceSignals
        Supplies ``probability``, ``splice_aura_score``, ``splice_imp`` and
        ``region_imp``.
    v4, classic : V4Signals, ClassicSignals
        Only ``probability`` is read from these.
    prob_std : float
        Spread of the three model probabilities, supplied by the caller
        (``build_decision`` passes ``cross["prob_std"]``).

    Returns
    -------
    str
        One of "Ambiguous", "Splice-driven", "Protein-driven", "Consensus".
    """
    probs = (splice.probability, v4.probability, classic.probability)
    p_max = float(max(probs))
    p_min = float(min(probs))

    # Strong disagreement overrides every mechanism-specific rule.
    if prob_std > 0.14:
        return "Ambiguous"

    # Splice dominance.
    # NOTE(review): the lead is measured against v4 only; classic is not
    # consulted for this rule.
    splice_leads = splice.probability > v4.probability + 0.05
    high_aura = splice.splice_aura_score > 0.35
    high_splice_i = float(np.max(splice.splice_imp)) > 0.50
    if splice_leads and (high_aura or high_splice_i):
        return "Splice-driven"

    # Protein dominance.
    # NOTE(review): exon importance is read from the *splice* model's region
    # head, not from classic.region_imp — confirm this is intended.
    protein_avg = (v4.probability + classic.probability) / 2
    if (protein_avg > splice.probability + 0.05
            and float(splice.region_imp[0]) > 0.5):
        return "Protein-driven"

    # Consensus: the full probability range is tight.
    if p_max - p_min <= 0.10:
        return "Consensus"

    return "Ambiguous"
# ═══════════════════════════════════════════════════════════════════════════════
# Confidence level
# ═══════════════════════════════════════════════════════════════════════════════
def _confidence(unified_prob: float, prob_std: float,
                ess: float, cf_mag: float) -> str:
    """
    Grade how much the unified prediction can be trusted.

    Four pieces of evidence are turned into points and summed (max 7):

    * model agreement       - 2 pts if prob_std < 0.05, 1 pt if < 0.12
    * explainability (ess)  - 2 pts if ess >= 0.65, 1 pt if >= 0.35
    * counterfactual effect - 2 pts if cf_mag >= 0.25, 1 pt if >= 0.10
    * extreme probability   - 1 pt if unified_prob >= 0.85 or <= 0.15

    Returns "High" for a total of 5 or more, "Moderate" for 3 or 4, and
    "Low" otherwise.
    """
    agreement_pts = 2 if prob_std < 0.05 else 1 if prob_std < 0.12 else 0
    evidence_pts = 2 if ess >= 0.65 else 1 if ess >= 0.35 else 0
    causal_pts = 2 if cf_mag >= 0.25 else 1 if cf_mag >= 0.10 else 0
    extreme_pts = 1 if (unified_prob >= 0.85 or unified_prob <= 0.15) else 0

    total = agreement_pts + evidence_pts + causal_pts + extreme_pts
    if total < 3:
        return "Low"
    return "High" if total >= 5 else "Moderate"
# ═══════════════════════════════════════════════════════════════════════════════
# Human-readable interpretation builders
# ═══════════════════════════════════════════════════════════════════════════════
def _splice_analysis_text(s: SpliceSignals, variant: str) -> str:
    """
    Compose the narrative paragraph describing the splice model's evidence.

    Sentences are emitted in this order; conditional ones appear only when
    their condition holds:

    * risk tier and probability of the splice model;
    * conv3 activation at the mutation site, worded by the mutation peak
      ratio (>= 2.0 strong, >= 1.0 above average, otherwise not elevated);
    * distance to the nearest GT donor / AG acceptor when the matching risk
      label is "CRITICAL SPLICE SITE" or "SPLICE REGION";
    * the counterfactual probability range, when it exceeds 0.15;
    * the dominant ablation feature and its ``<first word>_pct`` share
      ("?" when that key is absent).

    Parameters
    ----------
    s : SpliceSignals
        Splice-model signals from the explainability engine.
    variant : str
        Display label inserted into the first sentence.

    Returns
    -------
    str
        The sentences joined by single spaces.
    """
    tier, _ = _classify_risk_tier(s.probability)
    parts = [
        f"The splice model classifies variant {variant} as "
        f"'{tier}' with probability {s.probability:.4f}."
    ]

    mpr = s.mutation_peak_ratio
    if mpr >= 2.0:
        parts.append(
            f"The conv3 activation peak at the mutation site "
            f"(position {s.mutation_pos}) is {mpr:.2f}× "
            f"the mean window activation, indicating the model is strongly "
            f"attending to the mutation location."
        )
    elif mpr >= 1.0:
        parts.append(
            f"The mutation position ({s.mutation_pos}) has above-average "
            f"conv3 activation (MPR={mpr:.2f}×)."
        )
    else:
        parts.append(
            f"The mutation position ({s.mutation_pos}) does not carry an "
            f"elevated conv3 activation peak (MPR={mpr:.2f}×), "
            f"suggesting the pathogenic signal may arise from broader context."
        )

    risky_labels = ("CRITICAL SPLICE SITE", "SPLICE REGION")
    if s.splice_risk_donor in risky_labels:
        parts.append(
            f"The variant lies {s.dist_donor} bp from the nearest GT donor "
            f"dinucleotide (risk: {s.splice_risk_donor}). "
            f"Splice donor importance score = {float(s.splice_imp[0]):.3f}."
        )
    if s.splice_risk_acceptor in risky_labels:
        parts.append(
            f"The variant lies {s.dist_acceptor} bp from the nearest AG acceptor "
            f"dinucleotide (risk: {s.splice_risk_acceptor}). "
            f"Splice acceptor importance score = {float(s.splice_imp[1]):.3f}."
        )

    cf = s.counterfactual
    if cf.get("probability_range", 0) > 0.15:
        parts.append(
            f"Counterfactual analysis: swapping the alternate base changes "
            f"the pathogenicity probability by up to "
            f"{cf['probability_range']:.3f} "
            f"(range {cf['min_probability']:.3f}–{cf['max_probability']:.3f}), "
            f"confirming strong position-level causality."
        )

    abl = s.ablation
    # A missing, empty or None feature name falls back to "unknown". The
    # percentage key is "<first word>_pct"; it is guarded so a blank name
    # cannot raise IndexError the way `dom.lower().split()[0]` did.
    dom = abl.get("dominant_feature") or "unknown"
    dom_words = str(dom).lower().split()
    pct_key = f"{dom_words[0]}_pct" if dom_words else "unknown_pct"
    parts.append(
        f"Feature ablation: '{dom}' contributes the largest share "
        f"of the pathogenic signal "
        f"({abl.get(pct_key, '?')}%)."
    )
    return " ".join(parts)
def _protein_analysis_text(v4: V4Signals, classic: ClassicSignals) -> str:
    """
    Compose the narrative paragraph for the two protein-level models.

    Reports each model's risk tier and probability and the mean of their
    mutation peak ratios. It then adds at most one region sentence: exon
    importance of the classic model above 0.6 is checked first, otherwise
    intron importance above 0.6.

    Parameters
    ----------
    v4 : V4Signals
        Reads ``probability`` and ``mutation_peak_ratio``.
    classic : ClassicSignals
        Reads ``probability``, ``mutation_peak_ratio`` and ``region_imp``
        (index 0 = exon, index 1 = intron, per the labels used in the text).

    Returns
    -------
    str
        The sentences joined by single spaces.
    """
    v4_tier, _ = _classify_risk_tier(v4.probability)
    cl_tier, _ = _classify_risk_tier(classic.probability)
    avg_mpr = (v4.mutation_peak_ratio + classic.mutation_peak_ratio) / 2

    parts = [
        f"V4 model: '{v4_tier}' (prob={v4.probability:.4f}). "
        f"Classic model: '{cl_tier}' (prob={classic.probability:.4f}).",
        f"Average mutation-site activation ratio across protein models: "
        f"{avg_mpr:.2f}×.",
    ]

    exon_imp = float(classic.region_imp[0])
    if exon_imp > 0.6:
        parts.append(
            "The classic model assigns high exon-region importance "
            f"({exon_imp:.3f}), "
            "supporting a protein-coding disruption mechanism."
        )
    elif float(classic.region_imp[1]) > 0.6:
        parts.append(
            "The classic model assigns high intron-region importance "
            f"({float(classic.region_imp[1]):.3f}), "
            "consistent with a regulatory or splicing mechanism."
        )
    return " ".join(parts)
def _agreement_text(splice: SpliceSignals, v4: V4Signals,
                    classic: ClassicSignals, cross_locality: float,
                    prob_std: float) -> str:
    """
    Describe how closely the three models agree, in probability and locality.

    The agreement verb is chosen from *prob_std*: "agree" below 0.10,
    "partially agree" below 0.18, otherwise "disagree". The locality label
    is "high" above 0.5, "moderate" above 0.0 and "low" otherwise. Only a
    locality above 0.5 produces the "same region" sentence; anything else
    produces the "different sequence regions" sentence.

    Parameters
    ----------
    splice, v4, classic
        Per-model signals; only ``probability`` is read.
    cross_locality : float
        Cross-model activation locality score.
    prob_std : float
        Spread of the three model probabilities.

    Returns
    -------
    str
        Three sentences joined by single spaces.
    """
    if prob_std < 0.10:
        lvl = "agree"
    elif prob_std < 0.18:
        lvl = "partially agree"
    else:
        lvl = "disagree"

    if cross_locality > 0.5:
        loc_label = "high"
    elif cross_locality > 0.0:
        loc_label = "moderate"
    else:
        loc_label = "low"

    lines = [
        f"The three models {lvl} on pathogenicity "
        f"(probabilities: splice={splice.probability:.3f}, "
        f"v4={v4.probability:.3f}, classic={classic.probability:.3f}; "
        f"std={prob_std:.3f}).",
        f"Cross-model activation locality score = {cross_locality:.4f} "
        f"({loc_label} "
        f"alignment of importance peaks across models).",
    ]
    if cross_locality > 0.5:
        lines.append(
            "All three models are attending to the same region of the 99-bp window, "
            "strongly supporting the identified mechanistic signal."
        )
    else:
        lines.append(
            "The models are attending to different sequence regions, "
            "suggesting the signal may be mechanism-specific or context-dependent."
        )
    return " ".join(lines)
def _final_explanation(result: "DecisionResult") -> str:
    """
    Assemble the closing multi-paragraph summary for a decision.

    Reads these attributes of *result*:
    - variant, unified_probability, risk_tier, confidence
    - dominant_mechanism, explainability_strength
    - mutation_peak_ratio, counterfactual_magnitude, cross_model_locality
    - activation_pattern, signal_concentration
    - agreement_analysis

    Paragraphs are separated by blank lines. The last paragraph is a
    research-use disclaimer.

    Returns
    -------
    str
        The joined summary text.
    """
    p = result.unified_probability
    prob_label = (
        "highly pathogenic" if p >= 0.85 else
        "likely pathogenic" if p >= 0.70 else
        "possibly pathogenic" if p >= 0.50 else
        "likely benign" if p >= 0.20 else
        "benign"
    )
    mech = result.dominant_mechanism.lower()
    # Pick the article so e.g. "Ambiguous" reads "an ambiguous mechanism".
    article = "an" if mech.startswith(("a", "e", "i", "o", "u")) else "a"
    ess = result.explainability_strength
    ess_label = "high" if ess >= 0.65 else "moderate" if ess >= 0.35 else "low"

    lines = [
        f"Variant {result.variant} is predicted {prob_label} "
        f"(unified probability = {p:.4f}; "
        f"risk tier: {result.risk_tier}) with {result.confidence.lower()} confidence.",
        f"The prediction is driven by {article} {mech} mechanism.",
        f"Explainability strength score = {ess:.4f} "
        f"({ess_label} "
        f"overall evidence quality), based on a mutation peak ratio of "
        f"{result.mutation_peak_ratio:.2f}×, counterfactual magnitude of "
        f"{result.counterfactual_magnitude:.4f}, and cross-model locality of "
        f"{result.cross_model_locality:.4f}.",
        f"The conv3 activation pattern is classified as '{result.activation_pattern}', "
        f"and signal concentration at the mutation site is "
        f"{result.signal_concentration:.4f}.",
        # build_decision fills agreement_analysis with the full text from
        # _agreement_text, which already ends in a period. It is therefore
        # quoted as-is instead of being lowercased into "Model agreement is ..."
        # (that produced broken grammar and a double period).
        f"Model agreement: {result.agreement_analysis}",
        "⚠ This system is for research use only and should not substitute "
        "for clinical diagnostic testing.",
    ]
    return "\n\n".join(lines)
# ═══════════════════════════════════════════════════════════════════════════════
# Main entry point
# ═══════════════════════════════════════════════════════════════════════════════
def _json_default(obj):
    """
    Fallback encoder passed to ``json.dumps`` by ``build_decision``.

    Objects exposing a callable ``tolist()`` are converted to native Python
    lists/scalars, so they stay numeric in the report. Examples are NumPy
    arrays and NumPy scalars such as ``np.float32``, which ``json`` cannot
    encode directly. Anything else is rendered with ``str()``, the same as
    the previous ``default=str`` behaviour.
    """
    to_list = getattr(obj, "tolist", None)
    if callable(to_list):
        return to_list()
    return str(obj)


def build_decision(chrom: str, pos: int, ref: str, alt: str,
                   splice: SpliceSignals,
                   v4: V4Signals,
                   classic: ClassicSignals,
                   cross: dict) -> DecisionResult:
    """
    Build the unified DecisionResult. Called AFTER all XAI signals have been
    extracted — this ordering is mandatory.

    Parameters
    ----------
    chrom, pos, ref, alt
        Variant coordinates, rendered as ``"chr{chrom}:{pos} {ref}>{alt}"``.
        A leading ``"chr"`` on *chrom* (any case) is stripped first, so
        ``"chr1"`` and ``"1"`` give the same label.
    splice, v4, classic
        Per-model signal objects from the explainability engine.
    cross : dict
        Cross-model summary. Keys read: ``prob_std``,
        ``explainability_strength_score``, ``counterfactual_magnitude``,
        ``cross_model_locality_score``, ``mutation_peak_ratio``,
        ``signal_concentration_index``, ``activation_pattern_type`` and
        ``model_agreement``.

    Returns
    -------
    DecisionResult
        Fully populated, including ``final_explanation`` and ``report_json``.

    Raises
    ------
    KeyError
        If *cross* lacks one of the keys listed above.
    """
    # Normalise the chromosome so callers passing "chr1" don't get "chrchr1".
    chrom_label = str(chrom)
    if chrom_label[:3].lower() == "chr":
        chrom_label = chrom_label[3:]
    variant = f"chr{chrom_label}:{pos} {ref}>{alt}"

    # Weighted average probability — splice weighted 0.45 (most informative),
    # v4 0.30, classic 0.25
    unified = (0.45 * splice.probability +
               0.30 * v4.probability +
               0.25 * classic.probability)
    unified = round(float(unified), 4)
    tier, tier_desc = _classify_risk_tier(unified)

    prob_std = cross["prob_std"]
    mech = _dominant_mechanism(splice, v4, classic, prob_std)
    ess = cross["explainability_strength_score"]
    cf_mag = cross["counterfactual_magnitude"]
    conf = _confidence(unified, prob_std, ess, cf_mag)

    splice_txt = _splice_analysis_text(splice, variant)
    protein_txt = _protein_analysis_text(v4, classic)
    locality = cross["cross_model_locality_score"]
    agreement_txt = _agreement_text(splice, v4, classic, locality, prob_std)

    # Build the result first so _final_explanation can read its fields;
    # final_explanation and report_json are filled in afterwards.
    r = DecisionResult(
        variant=variant,
        unified_probability=unified,
        risk_tier=tier, tier_desc=tier_desc,
        dominant_mechanism=mech, confidence=conf,
        splice_prob=splice.probability,
        v4_prob=v4.probability,
        classic_prob=classic.probability,
        mutation_peak_ratio=cross["mutation_peak_ratio"],
        counterfactual_magnitude=cf_mag,
        cross_model_locality=locality,
        signal_concentration=cross["signal_concentration_index"],
        explainability_strength=ess,
        activation_pattern=cross["activation_pattern_type"],
        splice_analysis=splice_txt,
        protein_analysis=protein_txt,
        agreement_analysis=agreement_txt,
        final_explanation="",   # filled below
        report_json="",         # filled below
    )
    r.final_explanation = _final_explanation(r)

    # ── Build structured JSON ─────────────────────────────────────────────────
    cf = splice.counterfactual
    abl = splice.ablation

    # conv3 value at the mutation site. The bound comes from the activation
    # track itself rather than a hard-coded 99-bp window, so a shorter track
    # cannot raise IndexError and a longer one is not silently reported as None.
    conv3 = splice.conv3_norm
    mut_pos = splice.mutation_pos
    conv3_at_mut = (
        round(float(conv3[mut_pos]), 4)
        if conv3 is not None and 0 <= mut_pos < len(conv3) else None
    )

    report = {
        "variant": variant,
        "prediction": {
            "unified_probability": unified,
            "risk_tier": tier,
            "tier_desc": tier_desc,
            "dominant_mechanism": mech,
            "confidence": conf,
        },
        "model_outputs": {
            "splice": {
                "probability": splice.probability,
                "risk_tier": splice.risk_tier,
                "conv3_peak_at_mutation": conv3_at_mut,
                "splice_importance": {
                    "donor": round(float(splice.splice_imp[0]), 4),
                    "acceptor": round(float(splice.splice_imp[1]), 4),
                    "region": round(float(splice.splice_imp[2]), 4),
                },
                "region_importance": {
                    "exon": round(float(splice.region_imp[0]), 4),
                    "intron": round(float(splice.region_imp[1]), 4),
                },
                "splice_aura_score": splice.splice_aura_score,
                "dist_donor": splice.dist_donor,
                "dist_acceptor": splice.dist_acceptor,
                "splice_risk_donor": splice.splice_risk_donor,
                "splice_risk_acceptor": splice.splice_risk_acceptor,
                "counterfactual": {k: v for k, v in cf.items()
                                   if k != "table"},
                "counterfactual_table": cf.get("table", []),
                "feature_ablation": abl,
            },
            "v4": {
                "probability": v4.probability,
                "mutation_peak_ratio": round(v4.mutation_peak_ratio, 4),
                "signal_concentration": round(v4.signal_concentration, 4),
            },
            "classic": {
                "probability": classic.probability,
                "importance_head": round(classic.importance_head, 4),
                "region_importance": {
                    "exon": round(float(classic.region_imp[0]), 4),
                    "intron": round(float(classic.region_imp[1]), 4),
                },
                "mutation_peak_ratio": round(classic.mutation_peak_ratio, 4),
            },
        },
        "explainability_analysis": {
            "mutation_peak_ratio": cross["mutation_peak_ratio"],
            "counterfactual_magnitude": cross["counterfactual_magnitude"],
            "cross_model_locality_score": cross["cross_model_locality_score"],
            "signal_concentration_index": cross["signal_concentration_index"],
            "explainability_strength_score": ess,
            "activation_pattern_type": cross["activation_pattern_type"],
            "model_agreement": cross["model_agreement"],
            "prob_std": cross["prob_std"],
        },
        "interpretation": {
            "splice_analysis": splice_txt,
            "protein_analysis": protein_txt,
            "agreement_analysis": agreement_txt,
            "final_explanation": r.final_explanation,
        },
    }
    # NumPy values are converted to native numbers/lists rather than strings.
    r.report_json = json.dumps(report, indent=2, default=_json_default)
    return r