# CognitivePulse / intervention_engine.py
# Kshamaa S
# Initial deployment: CognitivePulse biomarker intelligence and coaching assistant
# 14a5ab4
# Raw
# History Blame Contribute Delete
# 9 kB
"""
intervention_engine.py — CognitivePulse
Given a patient's SHAP-based risk contribution profile, ranks their modifiable
risk factors by combined impact and practical actionability, and maps each to
the relevant literature domain for downstream RAG retrieval.
The core logic:
priority_score = |SHAP contribution| × actionability_weight
where actionability_weight reflects both medical tractability (e.g. hypertension
is very treatable) and evidence quality for brain-health outcomes.
"""
from __future__ import annotations
from data_loader import FEATURE_META
# Modifiable features the engine may recommend acting on.
#
# Each value is a 3-tuple:
#   (literature_domain, actionability_weight, norm_direction)
#
#   literature_domain    - key into DOMAIN_TO_LITERATURE / INTERVENTION_SUMMARY
#   actionability_weight - multiplier applied to |SHAP contribution|
#   norm_direction       - "lower_better" or "higher_better"; used to decide
#                          whether a patient's value is adverse or protective
#                          relative to population norms
#
# Entry order matters: when two features of the same domain end up with an
# equal priority, the one listed first here is the one that is kept.
MODIFIABLE_FEATURE_MAP = {
    # Lifestyle
    "BMI": ("diet_exercise", 0.8, "lower_better"),
    "Smoking": ("smoking_cessation", 1.0, "lower_better"),
    "AlcoholConsumption": ("alcohol_moderation", 0.7, "lower_better"),
    "PhysicalActivity": ("exercise", 1.0, "higher_better"),
    "DietQuality": ("nutrition", 0.9, "higher_better"),
    "SleepQuality": ("sleep", 0.9, "higher_better"),
    # Diagnoses
    "CardiovascularDisease": ("cardiovascular", 0.8, "lower_better"),
    "Diabetes": ("metabolic_health", 0.8, "lower_better"),
    "Depression": ("mental_health", 0.9, "lower_better"),
    "Hypertension": ("cardiovascular", 1.0, "lower_better"),
    # Blood pressure
    "SystolicBP": ("cardiovascular", 1.0, "lower_better"),
    "DiastolicBP": ("cardiovascular", 0.9, "lower_better"),
    # Lipid panel
    "CholesterolTotal": ("cardiovascular", 0.9, "lower_better"),
    "CholesterolLDL": ("cardiovascular", 1.0, "lower_better"),
    "CholesterolHDL": ("cardiovascular", 0.8, "higher_better"),
    "CholesterolTriglycerides": ("cardiovascular", 0.8, "lower_better"),
}
# Literature tags to retrieve for each intervention domain.
# The tags must match the domains used in the rag_engine.py corpus.
DOMAIN_TO_LITERATURE = {
    "exercise": ["exercise_cognitive_reserve"],
    "nutrition": ["diet_nutrition"],
    "sleep": ["sleep_glymphatic"],
    "cardiovascular": ["cardiovascular_risk"],
    "metabolic_health": ["metabolic_health"],
    "mental_health": ["mental_health_social"],
    # The combined domain pulls from both of its component tag sets.
    "diet_exercise": ["diet_nutrition", "exercise_cognitive_reserve"],
    "smoking_cessation": ["cardiovascular_risk"],
    "alcohol_moderation": ["lifestyle_factors"],
}
# One-line, human-readable description of each intervention domain
# (shown before the RAG coaching text).
INTERVENTION_SUMMARY = {
    "exercise": "Increasing structured physical activity",
    "nutrition": "Improving diet quality (Mediterranean / MIND dietary patterns)",
    "sleep": "Improving sleep quality and duration",
    "cardiovascular": "Managing cardiovascular risk factors (BP / cholesterol)",
    "metabolic_health": "Managing metabolic health (blood glucose / insulin resistance)",
    "mental_health": "Addressing depression and social engagement",
    "diet_exercise": "Combined diet and exercise program",
    "smoking_cessation": "Smoking cessation",
    "alcohol_moderation": "Moderating alcohol consumption",
}
def _is_adverse(feature: str, value, norm_direction: str) -> bool:
    """
    Decide whether ``value`` sits at an adverse (risk-elevating) level for
    ``feature``, so that features already at protective levels can be
    filtered out.

    Args:
        feature: feature name, looked up in ``data_loader.REFERENCE_RANGES``.
        value: the patient's value; it is converted with ``float()``.
        norm_direction: ``"lower_better"``; any other string is handled as
            "higher is better".

    Returns:
        True when the value lies on the unfavourable side of the feature's
        optimal range, or of the 0.5 cut-off when no range is defined.
    """
    from data_loader import REFERENCE_RANGES

    reading = float(value)
    wants_lower = norm_direction == "lower_better"

    if feature in REFERENCE_RANGES:
        optimal = REFERENCE_RANGES[feature]["optimal"]
        # Only the bound on the unfavourable side is consulted: the upper
        # bound when lower is better, the lower bound otherwise.
        return reading > optimal[1] if wants_lower else reading < optimal[0]

    # No reference range: the feature is treated as a binary flag split at
    # 0.5, adverse when set (lower is better) or unset (higher is better).
    return reading > 0.5 if wants_lower else reading < 0.5
def rank_interventions(shap_contributions: dict, patient: dict, n: int = 4) -> list:
    """
    Return the top ``n`` prioritized, modifiable interventions for a patient.

    A feature listed in MODIFIABLE_FEATURE_MAP becomes a candidate only when
    its SHAP contribution is strictly positive and, if the patient has a
    recorded value for it, ``_is_adverse`` reports that value as adverse.
    Features with no recorded patient value are kept on their SHAP
    contribution alone. At most one candidate is kept per literature domain
    (the one with the highest priority), where

        priority = |SHAP contribution| × actionability_weight

    Args:
        shap_contributions: mapping of feature name -> SHAP contribution.
        patient: mapping of feature name -> patient value; features may be
            missing or mapped to None.
        n: maximum number of interventions to return. Zero or a negative
            number yields an empty list.

    Returns:
        A list of dicts sorted by descending priority. Each entry contains:
        - feature: raw feature name
        - label: human-readable label (falls back to the feature name)
        - domain: literature domain for RAG retrieval
        - literature_tags: list of corpus tags
        - intervention_summary: one-line description
        - priority_score: combined impact × actionability, rounded to 4 places
        - shap_value: SHAP contribution, rounded to 4 places
        - patient_value: the patient's actual value for context (may be None)
    """
    # domain -> (unrounded priority, entry). The unrounded priority is kept
    # next to the entry so same-domain comparisons and the final ordering are
    # not distorted by the 4-decimal rounding applied for display.
    best_by_domain = {}

    for feature, (domain, actionability, norm_dir) in MODIFIABLE_FEATURE_MAP.items():
        if feature not in shap_contributions:
            continue
        shap_val = shap_contributions[feature]
        patient_val = patient.get(feature)

        # Only flag features that are both risk-elevating (positive SHAP) AND
        # at an adverse level; there is no point suggesting e.g. "eat better"
        # when diet is already excellent.
        if shap_val <= 0:
            continue
        if patient_val is not None and not _is_adverse(feature, patient_val, norm_dir):
            continue

        priority = abs(shap_val) * actionability

        # De-duplicate domains (no benefit listing SystolicBP + DiastolicBP
        # separately): a later feature only displaces the incumbent when its
        # priority is strictly higher, so ties keep the earlier feature.
        incumbent = best_by_domain.get(domain)
        if incumbent is not None:
            if priority > incumbent[0]:
                # Drop the key so the replacement is re-inserted at the end.
                del best_by_domain[domain]
            else:
                continue

        best_by_domain[domain] = (priority, {
            "feature": feature,
            "label": FEATURE_META.get(feature, {}).get("label", feature),
            "domain": domain,
            "literature_tags": DOMAIN_TO_LITERATURE.get(domain, [domain]),
            "intervention_summary": INTERVENTION_SUMMARY.get(domain, domain),
            "priority_score": round(priority, 4),
            "shap_value": round(shap_val, 4),
            "patient_value": patient_val,
        })

    # Stable sort, highest priority first; equal priorities keep insertion order.
    ranked = sorted(best_by_domain.values(), key=lambda item: item[0], reverse=True)

    # A negative slice bound would silently drop items from the END of the
    # ranking instead of returning nothing, so clamp it to zero. None is left
    # alone so that it keeps slicing to the full list.
    limit = n if n is None or n >= 0 else 0
    return [entry for _, entry in ranked][:limit]
def build_coach_brief(patient: dict, risk_result: dict, interventions: list) -> str:
    """
    Build a structured pre-session brief for a BetterBrain-style health coach,
    summarising the patient's risk profile and the top intervention priorities.
    The brief is passed as context to the RAG coaching generation step.

    Args:
        patient: the patient's feature dict. Currently not read by this
            function; kept in the signature for callers.
        risk_result: dict with ``risk_score``, ``risk_band`` and
            ``risk_probability``, plus an optional ``top_drivers`` list whose
            items carry ``label``, ``shap_value``, ``direction`` and
            ``modifiable``. Only the first five drivers are listed.
        interventions: list of dicts with ``intervention_summary``,
            ``priority_score``, ``patient_value`` and ``label``.

    Returns:
        The brief as a single newline-joined string.

    Raises:
        KeyError: if a required key is missing from ``risk_result``, a driver
            or an intervention.
    """
    lines = [
        f"PATIENT RISK SCORE: {risk_result['risk_score']}/100 ({risk_result['risk_band'].upper()} risk band)",
        f"Risk probability: {risk_result['risk_probability']:.1%}",
        "",
        "TOP RISK DRIVERS (SHAP-identified):",
    ]

    # `or []` also covers a "top_drivers" key that is present but set to None.
    for d in (risk_result.get("top_drivers") or [])[:5]:
        mod = "modifiable" if d["modifiable"] else "non-modifiable"
        # A space separates the signed SHAP value from the direction text;
        # without it the two run together (e.g. "-0.940decreases risk").
        lines.append(f" • {d['label']}: SHAP={d['shap_value']:+.3f} {d['direction']} ({mod})")

    lines += ["", "PRIORITIZED INTERVENTION AREAS:"]
    for i, iv in enumerate(interventions, 1):
        lines.append(f" {i}. {iv['intervention_summary']} (priority score: {iv['priority_score']:.3f})")
        # The value line is skipped when no patient value was recorded.
        if iv["patient_value"] is not None:
            lines.append(f" Patient value: {iv['patient_value']} | Feature: {iv['label']}")

    lines += [
        "",
        "COACHING SESSION FOCUS: Ground recommendations in the intervention areas above.",
        "All claims must cite retrieved research evidence. Do not make unsupported assertions.",
    ]
    return "\n".join(lines)
if __name__ == "__main__":
    import json

    # Smoke test: rank the interventions for one hand-written patient, then
    # print the ranking as JSON followed by the coach brief built from it.
    demo_shap = {
        "SystolicBP": 0.845,
        "DietQuality": 0.626,
        "SleepQuality": 0.446,
        "CholesterolLDL": 0.460,
        "PhysicalActivity": -0.279,
        "MMSE": -0.940,
        "FamilyHistoryAlzheimers": 0.313,
        "Forgetfulness": 0.555,
        "Depression": 0.0,
        "Smoking": -0.025,
    }
    demo_patient = {
        "SystolicBP": 148,
        "DietQuality": 5.0,
        "SleepQuality": 6.0,
        "CholesterolLDL": 158,
        "PhysicalActivity": 2.5,
        "Depression": 0,
        "Smoking": 0,
    }
    demo_driver = {
        "label": "MMSE Score",
        "shap_value": -0.94,
        "direction": "decreases risk",
        "modifiable": False,
    }
    demo_risk = {
        "risk_score": 85.1,
        "risk_band": "high",
        "risk_probability": 0.851,
        "top_drivers": [demo_driver],
    }

    ranked = rank_interventions(demo_shap, demo_patient)
    print(json.dumps(ranked, indent=2))
    print("\n--- COACH BRIEF ---")
    print(build_coach_brief(demo_patient, demo_risk, ranked))