"""
intervention_engine.py — CognitivePulse
Given a patient's SHAP-based risk contribution profile, ranks their modifiable
risk factors by combined impact and practical actionability, and maps each to
the relevant literature domain for downstream RAG retrieval.
The core logic:
priority_score = |SHAP contribution| × actionability_weight
where actionability_weight reflects both medical tractability (e.g. hypertension
is very treatable) and evidence quality for brain-health outcomes.
"""
from __future__ import annotations
from data_loader import FEATURE_META
# Modifiable feature -> (literature_domain, actionability_weight, norm_direction)
#
#   literature_domain     key into DOMAIN_TO_LITERATURE / INTERVENTION_SUMMARY
#   actionability_weight  multiplier applied to |SHAP| when scoring a feature
#   norm_direction        "lower_better" or "higher_better"; tells _is_adverse
#                         which side of the population norm counts as adverse
#
# Entry order is significant: rank_interventions walks this dict in order, so
# when two features of the same domain score exactly the same, the earlier
# entry is the one that is kept.
MODIFIABLE_FEATURE_MAP = {
    "BMI":                      ("diet_exercise",      0.8, "lower_better"),
    "Smoking":                  ("smoking_cessation",  1.0, "lower_better"),
    "AlcoholConsumption":       ("alcohol_moderation", 0.7, "lower_better"),
    "PhysicalActivity":         ("exercise",           1.0, "higher_better"),
    "DietQuality":              ("nutrition",          0.9, "higher_better"),
    "SleepQuality":             ("sleep",              0.9, "higher_better"),
    "CardiovascularDisease":    ("cardiovascular",     0.8, "lower_better"),
    "Diabetes":                 ("metabolic_health",   0.8, "lower_better"),
    "Depression":               ("mental_health",      0.9, "lower_better"),
    "Hypertension":             ("cardiovascular",     1.0, "lower_better"),
    "SystolicBP":               ("cardiovascular",     1.0, "lower_better"),
    "DiastolicBP":              ("cardiovascular",     0.9, "lower_better"),
    "CholesterolTotal":         ("cardiovascular",     0.9, "lower_better"),
    "CholesterolLDL":           ("cardiovascular",     1.0, "lower_better"),
    "CholesterolHDL":           ("cardiovascular",     0.8, "higher_better"),
    "CholesterolTriglycerides": ("cardiovascular",     0.8, "lower_better"),
}
# Intervention domain -> list of literature tags to retrieve for that domain.
# The tag strings must match the domains used in the rag_engine.py corpus.
# A domain may fan out to several tags (see "diet_exercise").
DOMAIN_TO_LITERATURE = {
    "exercise":           ["exercise_cognitive_reserve"],
    "nutrition":          ["diet_nutrition"],
    "sleep":              ["sleep_glymphatic"],
    "cardiovascular":     ["cardiovascular_risk"],
    "metabolic_health":   ["metabolic_health"],
    "mental_health":      ["mental_health_social"],
    "diet_exercise":      ["diet_nutrition", "exercise_cognitive_reserve"],
    "smoking_cessation":  ["cardiovascular_risk"],
    "alcohol_moderation": ["lifestyle_factors"],
}
# Intervention domain -> one-line, human-readable description of the
# intervention. Shown ahead of the RAG-generated coaching text.
INTERVENTION_SUMMARY = {
    "exercise":
        "Increasing structured physical activity",
    "nutrition":
        "Improving diet quality (Mediterranean / MIND dietary patterns)",
    "sleep":
        "Improving sleep quality and duration",
    "cardiovascular":
        "Managing cardiovascular risk factors (BP / cholesterol)",
    "metabolic_health":
        "Managing metabolic health (blood glucose / insulin resistance)",
    "mental_health":
        "Addressing depression and social engagement",
    "diet_exercise":
        "Combined diet and exercise program",
    "smoking_cessation":
        "Smoking cessation",
    "alcohol_moderation":
        "Moderating alcohol consumption",
}
def _is_adverse(feature: str, value, norm_direction: str) -> bool:
    """
    Tell whether ``value`` is on the risk-elevating side for ``feature``.

    Callers use this to drop features whose current level is already
    protective.

    feature:        feature name, looked up in ``REFERENCE_RANGES``.
    value:          the patient's value; coerced with ``float()``.
    norm_direction: "lower_better" selects the "too high is adverse" rule;
                    any other string is handled as "higher_better".

    Returns True when the value is adverse, False otherwise.
    """
    # Resolved at call time rather than at module import.
    from data_loader import REFERENCE_RANGES

    level = float(value)
    lower_is_better = norm_direction == "lower_better"

    if feature in REFERENCE_RANGES:
        # Ranged feature: adverse once it leaves the optimal band on the bad
        # side. "optimal" is indexed as [0] = lower bound, [1] = upper bound.
        optimal = REFERENCE_RANGES[feature]["optimal"]
        return level > optimal[1] if lower_is_better else level < optimal[0]

    # No reference range: treat as a binary flag, split at 0.5. Adverse if
    # positive and lower_better, or zero and higher_better.
    return level > 0.5 if lower_is_better else level < 0.5
def rank_interventions(shap_contributions: dict, patient: dict, n: int = 4) -> list:
    """
    Return the top ``n`` prioritized, modifiable interventions for a patient.

    A feature from ``MODIFIABLE_FEATURE_MAP`` becomes a candidate when:
      * it has an entry in ``shap_contributions`` with a strictly positive
        (risk-elevating) value, and
      * the patient's value is either missing from ``patient`` or judged
        adverse by ``_is_adverse``.

    Each candidate is scored as ``|SHAP| * actionability_weight``. At most one
    candidate is kept per literature domain (the highest-scoring one; on an
    exact tie the feature listed first in ``MODIFIABLE_FEATURE_MAP`` wins).

    shap_contributions: feature name -> SHAP contribution.
    patient:            feature name -> the patient's value.
    n:                  maximum number of entries to return. Zero or a
                        negative number yields an empty list.

    Returns a list of dicts sorted by descending priority, each containing:
      - feature: raw feature name
      - label: human-readable label (from FEATURE_META, else the feature name)
      - domain: literature domain for RAG retrieval
      - literature_tags: list of corpus tags
      - intervention_summary: one-line description
      - priority_score: combined impact × actionability, rounded to 4 places
      - shap_value: SHAP contribution, rounded to 4 places
      - patient_value: the patient's actual value for context (may be None)
    """
    # A negative n would otherwise slice off the *tail* of the ranking.
    if n is not None and n <= 0:
        return []

    # domain -> (unrounded priority, output entry). The unrounded score is kept
    # alongside the entry so same-domain comparisons and the final sort are not
    # distorted by the 4-decimal rounding applied to the reported value.
    best_by_domain = {}

    for feature, (domain, actionability, norm_dir) in MODIFIABLE_FEATURE_MAP.items():
        if feature not in shap_contributions:
            continue
        shap_val = shap_contributions[feature]
        patient_val = patient.get(feature)

        # Only flag features that are both risk-elevating (positive SHAP) AND
        # at an adverse level — no point flagging e.g. "eat better" when diet
        # is already excellent. Written as "not > 0" so NaN is rejected too.
        if not shap_val > 0:
            continue
        if patient_val is not None and not _is_adverse(feature, patient_val, norm_dir):
            continue

        priority = abs(shap_val) * actionability

        # De-duplicate domains (no benefit listing SystolicBP + DiastolicBP
        # separately): keep whichever feature has the higher priority.
        incumbent = best_by_domain.get(domain)
        if incumbent is not None and priority <= incumbent[0]:
            continue

        # Pop before inserting so a replaced domain moves to the end of the
        # dict; insertion order is the tie-break order of the stable sort below.
        best_by_domain.pop(domain, None)
        best_by_domain[domain] = (
            priority,
            {
                "feature": feature,
                "label": FEATURE_META.get(feature, {}).get("label", feature),
                "domain": domain,
                "literature_tags": DOMAIN_TO_LITERATURE.get(domain, [domain]),
                "intervention_summary": INTERVENTION_SUMMARY.get(domain, domain),
                "priority_score": round(priority, 4),
                "shap_value": round(shap_val, 4),
                "patient_value": patient_val,
            },
        )

    # Sort by priority descending
    ranked = sorted(best_by_domain.values(), key=lambda scored: scored[0], reverse=True)
    return [entry for _, entry in ranked[:n]]
def build_coach_brief(patient: dict, risk_result: dict, interventions: list) -> str:
    """
    Assemble the plain-text pre-session brief for a BetterBrain-style health
    coach: the patient's risk headline, up to five top SHAP drivers, and the
    prioritized intervention areas. The result is handed to the RAG coaching
    generation step as context.

    patient:       accepted for interface compatibility; not read here.
    risk_result:   must provide "risk_score", "risk_band" and
                   "risk_probability"; "top_drivers" is optional and, when
                   present, is a list of dicts with "label", "shap_value",
                   "direction" and "modifiable".
    interventions: entries with "intervention_summary", "priority_score",
                   "patient_value" and "label" (the shape produced by
                   rank_interventions).

    Returns the brief as a single newline-joined string.
    """
    score = risk_result['risk_score']
    band = risk_result['risk_band'].upper()
    probability = risk_result['risk_probability']

    brief = []
    brief.append(f"PATIENT RISK SCORE: {score}/100 ({band} risk band)")
    brief.append(f"Risk probability: {probability:.1%}")
    brief.append("")
    brief.append("TOP RISK DRIVERS (SHAP-identified):")

    # Cap the driver list at five entries.
    for driver in risk_result.get("top_drivers", [])[:5]:
        if driver["modifiable"]:
            kind = "modifiable"
        else:
            kind = "non-modifiable"
        brief.append(
            f" • {driver['label']}: SHAP={driver['shap_value']:+.3f} — {driver['direction']} ({kind})"
        )

    brief.extend(["", "PRIORITIZED INTERVENTION AREAS:"])

    for rank, item in enumerate(interventions, 1):
        brief.append(
            f" {rank}. {item['intervention_summary']} (priority score: {item['priority_score']:.3f})"
        )
        # The detail line is only emitted when a patient value is known.
        if item["patient_value"] is None:
            continue
        brief.append(f" Patient value: {item['patient_value']} | Feature: {item['label']}")

    brief.extend([
        "",
        "COACHING SESSION FOCUS: Ground recommendations in the intervention areas above.",
        "All claims must cite retrieved research evidence. Do not make unsupported assertions.",
    ])
    return "\n".join(brief)
if __name__ == "__main__":
    # Smoke test: rank a hand-written sample patient, dump the ranking as
    # JSON, then print the coach brief built from it.
    import json

    sample_shap = {
        "SystolicBP": 0.845,
        "DietQuality": 0.626,
        "SleepQuality": 0.446,
        "CholesterolLDL": 0.460,
        "PhysicalActivity": -0.279,
        "MMSE": -0.940,
        "FamilyHistoryAlzheimers": 0.313,
        "Forgetfulness": 0.555,
        "Depression": 0.0,
        "Smoking": -0.025,
    }
    sample_patient = {
        "SystolicBP": 148,
        "DietQuality": 5.0,
        "SleepQuality": 6.0,
        "CholesterolLDL": 158,
        "PhysicalActivity": 2.5,
        "Depression": 0,
        "Smoking": 0,
    }
    sample_risk = {
        "risk_score": 85.1,
        "risk_band": "high",
        "risk_probability": 0.851,
        "top_drivers": [
            {
                "label": "MMSE Score",
                "shap_value": -0.94,
                "direction": "decreases risk",
                "modifiable": False,
            },
        ],
    }

    ivs = rank_interventions(sample_shap, sample_patient)
    print(json.dumps(ivs, indent=2))
    print("\n--- COACH BRIEF ---")
    print(build_coach_brief(sample_patient, sample_risk, ivs))
|