# Mutation_XAI / decision_engine.py — nileshhanotia, commit 44fb3c3 (verified)
"""decision_engine.py — PeVe v1.1 Deterministic Synthesis Engine"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Optional
import numpy as np
from config import (
PEVE_VERSION, THRESHOLD_VERSION,
SPLICE_PROB_HIGH, SPLICE_PROB_MODERATE, SPLICE_PROB_WEAK,
SPLICE_SIGNAL_MIN, SPLICE_DOMINANT_MIN,
ACTIVATION_NORM_HIGH, ACTIVATION_NORM_MODERATE, ACTIVATION_NORM_WEAK,
CONTEXT_ACTIVE_MIN, BIOCHEMICAL_RISK_ACTIVE,
AF_RARITY_THRESHOLD, AF_HIGH_CONFLICT,
BOUNDARY_TOLERANCE, WINDOW_BP, PEAK_OFF_CENTER_FRAC,
)
from prefilter import VariantClass
from af_handler import AFResult, AF_NUMERIC, AF_ZERO, AF_UNKNOWN, AF_UNCERTAIN
# ── Raw layer outputs ─────────────────────────────────────
@dataclass
class SpliceLayerOutput:
    """Raw output of the RNA splice layer, as consumed by the synthesis engine.

    Instances are built positionally in the field order below;
    ``model_available`` is the only field with a default (``True``).
    """

    # Compared against the SPLICE_PROB_* band thresholds and SPLICE_DOMINANT_MIN.
    splice_prob: float
    # Must reach SPLICE_SIGNAL_MIN for the RNA layer to count as active.
    splice_signal_strength: float
    # Not read by the decision logic in this module; carried through for callers.
    counterfactual_delta: float
    # Opaque attribution payload (any object or None); not read in this module.
    saliency_map: Optional[object]
    # Presumably False when the splice model could not run — confirm with producer.
    model_available: bool = field(default=True)
@dataclass
class ContextLayerOutput:
    """Raw output of the sequence-context layer.

    Instances are built positionally in the field order below;
    ``model_available`` defaults to ``True``.
    """

    # Not read by the decision logic in this module.
    context_pathogenic_prob: float
    # Banded by ACTIVATION_NORM_* and compared with CONTEXT_ACTIVE_MIN.
    activation_norm: float
    # Compared with the window centre (WINDOW_BP // 2) to flag off-centre peaks.
    activation_peak_position: int
    # Not read by the decision logic in this module.
    importance_score: float
    # Presumably False when the context model could not run — confirm with producer.
    model_available: bool = field(default=True)
@dataclass
class ProteinLayerOutput:
    """Raw output of the protein (L3 substitution) layer.

    Instances are built positionally in the field order below;
    ``model_available`` defaults to ``True``.
    """

    # Must reach BIOCHEMICAL_RISK_ACTIVE (with a valid substitution and a rare AF)
    # for the protein layer to count as active.
    biochemical_risk_score: float
    # Not read by the decision logic in this module.
    feature_pathogenic_prob: float
    # Per-feature attribution mapping; not read in this module.
    shap_feature_contributions: dict
    # Gates every protein-layer check; False means the substitution metrics do not apply.
    l3_substitution_valid: bool
    # Presumably False when the protein model could not run — confirm with producer.
    model_available: bool = field(default=True)
# ── Band classifiers ──────────────────────────────────────
def _splice_band(p):
    """Map a splice probability onto its qualitative band.

    Thresholds are tried from strongest to weakest, and the first one that
    ``p`` reaches (``>=``) names the band. Anything below every threshold,
    including NaN, is ``"Inactive"``.
    """
    ladder = (
        (SPLICE_PROB_HIGH, "High"),
        (SPLICE_PROB_MODERATE, "Moderate"),
        (SPLICE_PROB_WEAK, "Weak"),
    )
    for cutoff, label in ladder:
        if p >= cutoff:
            return label
    return "Inactive"
def _context_band(n):
    """Map a context activation norm onto its qualitative band.

    Thresholds are tried from strongest to weakest, and the first one that
    ``n`` reaches (``>=``) names the band. Anything below every threshold,
    including NaN, is ``"Inactive"``.
    """
    ladder = (
        (ACTIVATION_NORM_HIGH, "High"),
        (ACTIVATION_NORM_MODERATE, "Moderate"),
        (ACTIVATION_NORM_WEAK, "Weak"),
    )
    for cutoff, label in ladder:
        if n >= cutoff:
            return label
    return "Inactive"
def _near(val, thresh):
    """Return True when ``val`` lies within BOUNDARY_TOLERANCE of ``thresh`` (inclusive)."""
    distance = abs(val - thresh)
    return distance <= BOUNDARY_TOLERANCE
def _off_center(pos):
    """Return True when ``pos`` is strictly farther than the allowed offset from the window centre.

    The centre is ``WINDOW_BP // 2``. The allowed offset is
    ``int(WINDOW_BP * PEAK_OFF_CENTER_FRAC)``.
    """
    centre = WINDOW_BP // 2
    allowed = int(WINDOW_BP * PEAK_OFF_CENTER_FRAC)
    return abs(pos - centre) > allowed
# ── Activation levels ─────────────────────────────────────
@dataclass
class ActivationLevels:
    """Per-layer activation summary used by conflict detection and synthesis.

    Field order is significant because instances may be constructed positionally.
    """

    # RNA splice layer: band label plus active / dominant flags.
    splice_band: str
    rna_active: bool
    rna_dominant: bool
    # Sequence-context layer: band label plus active flag.
    context_band: str
    context_active: bool
    # Protein layer: active flag and whether L3 substitution metrics apply.
    protein_active: bool
    l3_valid: bool
    # Boundary flags: a layer's metric sits close to one of its decision thresholds.
    rna_boundary: bool
    context_boundary: bool
    protein_boundary: bool
def compute_activation_levels(splice, context, protein, af_result):
    """Derive band labels, active/dominant flags and boundary flags for each layer.

    Args:
        splice: splice-layer output (reads ``splice_prob``, ``splice_signal_strength``).
        context: context-layer output (reads ``activation_norm``).
        protein: protein-layer output (reads ``l3_substitution_valid``,
            ``biochemical_risk_score``).
        af_result: allele-frequency result. ``satisfies_rarity()`` is only
            called when the substitution is valid and the risk score reaches
            BIOCHEMICAL_RISK_ACTIVE.

    Returns:
        ActivationLevels.
    """
    prob = splice.splice_prob
    signal = splice.splice_signal_strength
    norm = context.activation_norm
    risk = protein.biochemical_risk_score
    l3_ok = protein.l3_substitution_valid

    # RNA is "active" only when probability AND signal strength clear their minimums;
    # "dominant" looks at probability alone.
    rna_active = prob >= SPLICE_PROB_MODERATE and signal >= SPLICE_SIGNAL_MIN
    rna_dominant = prob >= SPLICE_DOMINANT_MIN

    # Protein activity also requires a rare variant; the `and` chain keeps
    # the AF lookup lazy.
    protein_active = l3_ok and risk >= BIOCHEMICAL_RISK_ACTIVE and af_result.satisfies_rarity()

    # Boundary flags mark metrics close enough to a threshold that the call is fragile.
    rna_boundary = (
        _near(prob, SPLICE_PROB_MODERATE)
        or _near(prob, SPLICE_DOMINANT_MIN)
        or _near(signal, SPLICE_SIGNAL_MIN)
    )
    context_boundary = _near(norm, CONTEXT_ACTIVE_MIN) or _near(norm, ACTIVATION_NORM_HIGH)

    return ActivationLevels(
        splice_band=_splice_band(prob),
        rna_active=rna_active,
        rna_dominant=rna_dominant,
        context_band=_context_band(norm),
        context_active=norm >= CONTEXT_ACTIVE_MIN,
        protein_active=protein_active,
        l3_valid=l3_ok,
        rna_boundary=rna_boundary,
        context_boundary=context_boundary,
        protein_boundary=_near(risk, BIOCHEMICAL_RISK_ACTIVE),
    )
# ── Conflict detection ────────────────────────────────────
@dataclass
class ConflictReport:
    """Conflict messages plus the derived manual-review decision.

    The score fields and ``requires_manual_review`` are only refreshed by
    ``compute_review_flag``. Call it again after appending to either list.
    """

    major_conflicts: list = field(default_factory=list)
    minor_conflicts: list = field(default_factory=list)
    requires_manual_review: bool = False
    conflict_score_major: int = 0
    conflict_score_minor: int = 0

    def compute_review_flag(self):
        """Recount both conflict lists and set ``requires_manual_review``.

        Review is required for at least one major conflict, or for at least
        two minor ones.
        """
        majors = len(self.major_conflicts)
        minors = len(self.minor_conflicts)
        self.conflict_score_major, self.conflict_score_minor = majors, minors
        self.requires_manual_review = majors > 0 or minors > 1
def detect_conflicts(splice, context, protein, af_result, activation, variant_class):
    """Collect major and minor conflicts between layer outputs, AF and annotation.

    Major conflicts (any one forces manual review):
      * splice_prob >= SPLICE_PROB_HIGH while ``af_result`` reports a high-AF conflict;
      * a valid L3 substitution with risk >= BIOCHEMICAL_RISK_ACTIVE while
        ``af_result`` reports a high-AF conflict;
      * a ``canonical_splice`` variant class with an inactive RNA layer.

    Minor conflicts (two or more force manual review):
      * a metric flagged as near its threshold in ``activation``;
      * an activation peak away from the window centre;
      * an active context layer on a low-impact VEP consequence;
      * an AF state of AF_UNKNOWN or AF_UNCERTAIN.

    Returns:
        ConflictReport with ``compute_review_flag`` already applied.
    """
    report = ConflictReport()
    major = report.major_conflicts
    minor = report.minor_conflicts
    prob = splice.splice_prob
    risk = protein.biochemical_risk_score
    norm = context.activation_norm
    peak = context.activation_peak_position

    # --- Major: strong model signal contradicted by population frequency or annotation.
    if prob >= SPLICE_PROB_HIGH and af_result.triggers_high_af_conflict():
        major.append(
            f"MAJOR: High splice_prob ({prob:.3f}) + common variant (AF={af_result.global_af:.5f}). "
            "Splice-disrupting variant unlikely at this population frequency."
        )
    if (protein.l3_substitution_valid
            and risk >= BIOCHEMICAL_RISK_ACTIVE
            and af_result.triggers_high_af_conflict()):
        major.append(
            f"MAJOR: High biochemical risk ({risk:.3f}) + common variant "
            f"(AF={af_result.global_af:.5f}). Common biochemically disruptive variants are typically tolerated."
        )
    if variant_class.variant_class == "canonical_splice" and not activation.rna_active:
        major.append(
            f"MAJOR: Canonical splice site ({variant_class.raw_consequence}) but RNA model inactive "
            f"(splice_prob={prob:.3f}). Model/annotation disagreement."
        )

    # --- Minor: fragile or weakly-supported calls.
    # Values are only formatted for layers that were actually flagged.
    near_threshold = []
    if activation.rna_boundary:
        near_threshold.append(
            f"splice_prob({prob:.3f})/signal({splice.splice_signal_strength:.3f})"
        )
    if activation.context_boundary:
        near_threshold.append(f"activation_norm({norm:.3f})")
    if activation.protein_boundary:
        near_threshold.append(f"biochemical_risk({risk:.3f})")
    if near_threshold:
        joined = "; ".join(near_threshold)
        minor.append(f"MINOR: Boundary proximity — {joined} within ±{BOUNDARY_TOLERANCE}.")

    if _off_center(peak):
        offset = abs(peak - WINDOW_BP // 2)
        minor.append(f"MINOR: Activation peak {offset}bp from mutation centre (pos={peak}).")

    low_impact = frozenset({
        "synonymous_variant",
        "intron_variant",
        "upstream_gene_variant",
        "downstream_gene_variant",
    })
    if activation.context_active and variant_class.raw_consequence in low_impact:
        minor.append(
            f"MINOR: Context active (norm={norm:.3f}) but VEP='{variant_class.raw_consequence}' (low impact)."
        )

    if af_result.state in {AF_UNKNOWN, AF_UNCERTAIN}:
        minor.append(f"MINOR: AF state={af_result.state} — rarity unconfirmed.")

    report.compute_review_flag()
    return report
# ── Mechanism constants ───────────────────────────────────
# Values stored in SynthesisResult.dominant_mechanism. The first three also
# appear in SynthesisResult.supporting_mechanisms.
DOMINANT_RNA = "RNA_Splicing"                      # RNA splice layer drives the call
DOMINANT_PROTEIN = "Protein_Biochemical"           # protein substitution layer drives the call
DOMINANT_CONTEXT = "Sequence_Context"              # only the sequence-context layer is active
DOMINANT_AMBIGUITY = "Mechanism_Ambiguity"         # RNA and protein both active, neither dominant
DOMINANT_TRUNCATION = "Protein_Truncation"         # frameshift / stop_gained / start_lost
DOMINANT_INSUFFICIENT = "Insufficient_Evidence"    # no layer active
DOMINANT_OOS = "Out_Of_Scope"                      # variant class outside the engine's scope
DOMINANT_CONFLICT_REVIEW = "Conflict_Manual_Review"  # conflicts suppressed the classification
# ── Synthesis result ──────────────────────────────────────
@dataclass
class SynthesisResult:
    """Final output of the synthesis engine: the call, its evidence and its rule trace.

    The version fields default to the configured PEVE_VERSION and
    THRESHOLD_VERSION so every result records the rule/threshold set that produced it.
    """

    dominant_mechanism: str        # one of the DOMINANT_* labels
    final_classification: str      # human-readable classification text
    supporting_mechanisms: list    # secondary DOMINANT_* labels (may be empty)
    activation_levels: ActivationLevels
    conflict_report: ConflictReport
    reasoning_steps: list          # ordered, human-readable rule trace
    transcript_ambiguity: bool     # set from the pre-filter's transcript_conflict by _mkr
    af_uncertainty: bool           # set by _mkr when the AF state is unknown/uncertain
    version: str = field(default=PEVE_VERSION)
    threshold_version: str = field(default=THRESHOLD_VERSION)
def _mkr(dom, cls, sup, act, conf, steps, vc, af):
    """Assemble a SynthesisResult and derive its two ambiguity flags.

    ``transcript_ambiguity`` mirrors ``vc.transcript_conflict``.
    ``af_uncertainty`` is True when ``af.state`` is AF_UNKNOWN or AF_UNCERTAIN.
    """
    transcript_ambiguity = vc.transcript_conflict
    af_uncertainty = af.state in {AF_UNKNOWN, AF_UNCERTAIN}
    return SynthesisResult(
        dominant_mechanism=dom,
        final_classification=cls,
        supporting_mechanisms=sup,
        activation_levels=act,
        conflict_report=conf,
        reasoning_steps=steps,
        transcript_ambiguity=transcript_ambiguity,
        af_uncertainty=af_uncertainty,
    )
# ── Main synthesis ────────────────────────────────────────
def synthesize(splice, context, protein, af_result, variant_class):
    """Run the deterministic PeVe rule chain and return a SynthesisResult.

    Rules are evaluated in a fixed order, and the first one that returns wins:

    * Gates:
      - any major conflict -> Conflict_Manual_Review;
      - an out-of-scope variant class -> Out_Of_Scope;
      - frameshift / stop_gained / start_lost -> Protein_Truncation
        (RNA listed as supporting when active).
    * Rule 1: RNA dominant (splice_prob >= SPLICE_DOMINANT_MIN) -> RNA_Splicing,
      with any active protein/context layer listed as supporting.
    * Rule 1b: RNA active and protein active -> Mechanism_Ambiguity.
    * Rule 2: protein active -> Protein_Biochemical (context supporting if active).
    * Rule 3: context active -> Sequence_Context. For a ``substitution_synonymous``
      variant class it is recorded as blocked and falls through to Rule 4.
    * Rule 4: nothing active -> Insufficient_Evidence. This is upgraded to
      Conflict_Manual_Review when the minor-conflict threshold was reached.

    Each rule reads boolean activation flags; layer probabilities are never averaged.

    Args:
        splice, context, protein: raw layer outputs.
        af_result: allele-frequency result for the variant.
        variant_class: pre-filter classification. Reads ``variant_class``,
            ``out_of_scope``, ``transcript_conflict`` and, via
            ``detect_conflicts``, ``raw_consequence``.

    Returns:
        SynthesisResult whose ``reasoning_steps`` records the rule trace.
    """
    act = compute_activation_levels(splice, context, protein, af_result)
    conf = detect_conflicts(splice, context, protein, af_result, act, variant_class)
    steps = []
    sup = []

    # Gate: any major conflict suppresses classification entirely.
    if conf.requires_manual_review and conf.conflict_score_major >= 1:
        steps.append(f"CONFLICT OVERRIDE: {conf.conflict_score_major} major conflict(s). Classification suppressed.")
        return _mkr(DOMINANT_CONFLICT_REVIEW, "Conflict — Manual Review Required", [],
                    act, conf, steps, variant_class, af_result)

    # Gate: variant classes the pre-filter marked as out of scope.
    if variant_class.out_of_scope:
        steps.append(f"Variant class '{variant_class.variant_class}' is outside PeVe v1.1 scope.")
        return _mkr(DOMINANT_OOS, "Out of Scope — See Flags", [],
                    act, conf, steps, variant_class, af_result)

    # Gate: truncating variants bypass the substitution-based protein metrics.
    if variant_class.variant_class in {"frameshift", "stop_gained", "start_lost"}:
        steps.append(f"Variant class '{variant_class.variant_class}' — protein truncation. L3 substitution metrics excluded.")
        if act.rna_active:
            steps.append(f"RNA also active (splice_prob={splice.splice_prob:.3f}) — possible NMD-relevant splice signal.")
            sup.append(DOMINANT_RNA)
        return _mkr(DOMINANT_TRUNCATION, "Protein Truncation", sup,
                    act, conf, steps, variant_class, af_result)

    # Recorded in the trace only; it does not alter which rule fires below.
    if variant_class.transcript_conflict:
        steps.append("Transcript conflict: consequence differs across transcripts. Both mechanisms elevated.")

    # Rule 1: RNA High -> dominant.
    if act.rna_dominant:
        # Render the comparison explicitly. Previously the threshold was fused
        # onto the value (e.g. "0.9500.9").
        steps.append(
            f"RULE 1: RNA HIGH (splice_prob={splice.splice_prob:.3f} ≥ {SPLICE_DOMINANT_MIN}, "
            f"signal={splice.splice_signal_strength:.3f}). RNA dominant."
        )
        if act.protein_active:
            sup.append(DOMINANT_PROTEIN)
            steps.append(f" Supporting: Protein active (risk={protein.biochemical_risk_score:.3f}).")
        if act.context_active:
            sup.append(DOMINANT_CONTEXT)
            steps.append(f" Supporting: Context active (norm={context.activation_norm:.3f}).")
        return _mkr(DOMINANT_RNA, "Pathogenic — RNA Splice Mechanism", sup,
                    act, conf, steps, variant_class, af_result)

    # Rule 1b: RNA Moderate + Protein Active -> ambiguity.
    if act.rna_active and act.protein_active:
        steps.append(f"RULE 1b: RNA MODERATE (splice_prob={splice.splice_prob:.3f}) + Protein ACTIVE (risk={protein.biochemical_risk_score:.3f}). Mechanism Ambiguity.")
        return _mkr(DOMINANT_AMBIGUITY, "Mechanism Ambiguity — Manual Review Recommended",
                    [DOMINANT_RNA, DOMINANT_PROTEIN], act, conf, steps, variant_class, af_result)

    # Rule 2: Protein dominant. Rule 1b has already returned whenever RNA and
    # protein are both active, so RNA is necessarily inactive here. The former
    # "moderate RNA alongside protein" branch was unreachable and has been removed.
    if act.protein_active:
        steps.append(f"RULE 2: RNA inactive. Protein ACTIVE (risk={protein.biochemical_risk_score:.3f}, AF={af_result.global_af}).")
        if act.context_active:
            sup.append(DOMINANT_CONTEXT)
            steps.append(f" Supporting: Context active (norm={context.activation_norm:.3f}).")
        return _mkr(DOMINANT_PROTEIN, "Pathogenic — Protein Biochemical Mechanism", sup,
                    act, conf, steps, variant_class, af_result)

    # Rule 3: Context dominant (never for synonymous substitutions).
    if act.context_active:
        if variant_class.variant_class == "substitution_synonymous":
            # Falls through to Rule 4 after recording why Rule 3 did not fire.
            steps.append("RULE 3 BLOCKED: Context active but synonymous variant — context alone cannot classify pathogenic.")
        else:
            steps.append(f"RULE 3: RNA+Protein inactive. Context ACTIVE (norm={context.activation_norm:.3f}).")
            return _mkr(DOMINANT_CONTEXT, "Uncertain — Sequence Context Signal Only", [],
                        act, conf, steps, variant_class, af_result)

    # Rule 4: Insufficient evidence.
    steps.append(
        f"RULE 4: No mechanism active. RNA={act.splice_band} Context={act.context_band} "
        f"Protein active={act.protein_active} (L3 valid={act.l3_valid}, rare={af_result.satisfies_rarity()})."
    )
    # Major conflicts already returned at the gate, so review here means >= 2 minor conflicts.
    if conf.requires_manual_review:
        steps.append(f"Minor conflict threshold reached ({conf.conflict_score_minor} minor). Upgrading to Review.")
        return _mkr(DOMINANT_CONFLICT_REVIEW, "Conflict — Manual Review Required", [],
                    act, conf, steps, variant_class, af_result)
    return _mkr(DOMINANT_INSUFFICIENT, "Likely Benign or Insufficient Evidence", [],
                act, conf, steps, variant_class, af_result)
# ── Narrative builder ─────────────────────────────────────
def build_narrative(result, splice, context, protein, af_result, variant_class):
    """Render a SynthesisResult as a plain-text, line-oriented report.

    Sections appear in this order:
      1. header;
      2. per-layer summary (RNA, context, protein);
      3. dominant mechanism, classification and supporting mechanisms;
      4. major and minor conflicts;
      5. transcript-conflict warning and pre-filter flags;
      6. a manual-review banner when required;
      7. version footer.

    Returns:
        The report as one newline-joined string.
    """
    levels = result.activation_levels
    report = result.conflict_report
    rule = "=" * 60

    def status(flag):
        # Trailing verdict appended to each layer summary line.
        return "ACTIVE." if flag else "INACTIVE."

    def pretty(label):
        # Underscored identifiers are shown with spaces.
        return label.replace("_", " ")

    out = [
        f"PeVe v{PEVE_VERSION} Structured Reasoning Narrative",
        rule,
        f"Variant class: {pretty(variant_class.variant_class).title()}",
        f"RNA: splice_prob={splice.splice_prob:.3f} (band={levels.splice_band}), "
        f"signal={splice.splice_signal_strength:.3f}. " + status(levels.rna_active),
        f"Context: activation_norm={context.activation_norm:.3f} (band={levels.context_band}). "
        + status(levels.context_active),
    ]

    # Protein substitution metrics only make sense when L3 is valid for this variant.
    if not levels.l3_valid:
        out.append("Protein substitution metrics: NOT APPLICABLE for this variant class.")
    else:
        if af_result.global_af is None:
            af_str = f"AF_state={af_result.state}"
        else:
            af_str = f"AF={af_result.global_af:.6f}"
        out.append(
            f"Protein: biochemical_risk={protein.biochemical_risk_score:.3f}, {af_str}. "
            + status(levels.protein_active)
        )

    out.append("")
    out.append(f"Dominant mechanism: {pretty(result.dominant_mechanism)}")
    out.append(f"Final classification: {result.final_classification}")
    if result.supporting_mechanisms:
        out.append("Supporting: " + ", ".join(pretty(m) for m in result.supporting_mechanisms))

    if report.major_conflicts:
        out.append("\nMAJOR CONFLICTS:")
        out += [f" • {c}" for c in report.major_conflicts]
    if report.minor_conflicts:
        out.append("MINOR CONFLICTS / BOUNDARY FLAGS:")
        out += [f" • {c}" for c in report.minor_conflicts]
    if result.transcript_ambiguity:
        out.append("⚠ Transcript conflict: consequence differs across transcripts.")
    if variant_class.flags:
        out.append("\nPre-filter flags:")
        out += [f" • {flag}" for flag in variant_class.flags]
    if report.requires_manual_review:
        out.append("\n⛔ MANUAL REVIEW REQUIRED.")

    out.append(rule)
    out.append(f"PeVe v{PEVE_VERSION} | Thresholds {THRESHOLD_VERSION} | No probability averaging.")
    return "\n".join(out)