File size: 9,002 Bytes
14a5ab4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
"""
intervention_engine.py — CognitivePulse

Given a patient's SHAP-based risk contribution profile, ranks their modifiable
risk factors by combined impact and practical actionability, and maps each to
the relevant literature domain for downstream RAG retrieval.

The core logic:
  priority_score = |SHAP contribution| × actionability_weight

where actionability_weight reflects both medical tractability (e.g. hypertension
is very treatable) and evidence quality for brain-health outcomes.
"""

from __future__ import annotations

from data_loader import FEATURE_META

# Table of modifiable features. Every value is a 3-tuple:
#   (literature_domain, actionability_weight, norm_direction)
# norm_direction is either "lower_better" or "higher_better"; it tells the
# adverse-level check which side of the norm counts as risk-elevating.
# Entry order matters: the ranking code iterates this dict in insertion order.
MODIFIABLE_FEATURE_MAP = {
    "BMI": ("diet_exercise", 0.8, "lower_better"),
    "Smoking": ("smoking_cessation", 1.0, "lower_better"),
    "AlcoholConsumption": ("alcohol_moderation", 0.7, "lower_better"),
    "PhysicalActivity": ("exercise", 1.0, "higher_better"),
    "DietQuality": ("nutrition", 0.9, "higher_better"),
    "SleepQuality": ("sleep", 0.9, "higher_better"),
    "CardiovascularDisease": ("cardiovascular", 0.8, "lower_better"),
    "Diabetes": ("metabolic_health", 0.8, "lower_better"),
    "Depression": ("mental_health", 0.9, "lower_better"),
    "Hypertension": ("cardiovascular", 1.0, "lower_better"),
    "SystolicBP": ("cardiovascular", 1.0, "lower_better"),
    "DiastolicBP": ("cardiovascular", 0.9, "lower_better"),
    "CholesterolTotal": ("cardiovascular", 0.9, "lower_better"),
    "CholesterolLDL": ("cardiovascular", 1.0, "lower_better"),
    "CholesterolHDL": ("cardiovascular", 0.8, "higher_better"),
    "CholesterolTriglycerides": ("cardiovascular", 0.8, "lower_better"),
}

# Literature tags to retrieve for each intervention domain.
# The tags must match the domains used in the rag_engine.py corpus.
DOMAIN_TO_LITERATURE = {
    "exercise": ["exercise_cognitive_reserve"],
    "nutrition": ["diet_nutrition"],
    "sleep": ["sleep_glymphatic"],
    "cardiovascular": ["cardiovascular_risk"],
    "metabolic_health": ["metabolic_health"],
    "mental_health": ["mental_health_social"],
    "diet_exercise": ["diet_nutrition", "exercise_cognitive_reserve"],
    "smoking_cessation": ["cardiovascular_risk"],
    "alcohol_moderation": ["lifestyle_factors"],
}

# One-line, human-readable description of the intervention for each domain
# (displayed ahead of the RAG coaching text).
INTERVENTION_SUMMARY = {
    "exercise": "Increasing structured physical activity",
    "nutrition": "Improving diet quality (Mediterranean / MIND dietary patterns)",
    "sleep": "Improving sleep quality and duration",
    "cardiovascular": "Managing cardiovascular risk factors (BP / cholesterol)",
    "metabolic_health": "Managing metabolic health (blood glucose / insulin resistance)",
    "mental_health": "Addressing depression and social engagement",
    "diet_exercise": "Combined diet and exercise program",
    "smoking_cessation": "Smoking cessation",
    "alcohol_moderation": "Moderating alcohol consumption",
}


def _is_adverse(feature: str, value, norm_direction: str) -> bool:
    """
    Tell whether ``value`` sits on the risk-elevating side for ``feature``.

    Features listed in ``data_loader.REFERENCE_RANGES`` are compared with the
    bounds stored under that entry's "optimal" key: with "lower_better" the
    value is adverse when it exceeds element 1, otherwise it is adverse when
    it falls below element 0. Features without an entry are handled as binary
    flags using a 0.5 threshold in the same two directions.

    Any ``norm_direction`` other than "lower_better" is treated like
    "higher_better". ``value`` is converted with ``float()``, so a value that
    cannot be converted raises from that call.
    """
    from data_loader import REFERENCE_RANGES

    has_reference = feature in REFERENCE_RANGES
    lower_is_better = norm_direction == "lower_better"
    level = float(value)

    if has_reference:
        optimal = REFERENCE_RANGES[feature]["optimal"]
        return level > optimal[1] if lower_is_better else level < optimal[0]

    # No reference range: interpret the value as a 0/1 flag.
    return level > 0.5 if lower_is_better else level < 0.5


def rank_interventions(shap_contributions: dict, patient: dict, n: int = 4) -> list:
    """
    Return the top ``n`` prioritized, modifiable interventions for a patient.

    A feature from MODIFIABLE_FEATURE_MAP becomes a candidate only when:
      - it has an entry in ``shap_contributions`` whose value is strictly
        positive (risk-elevating), and
      - the patient's recorded value, if there is one, is adverse according to
        ``_is_adverse``. A feature whose patient value is missing or None is
        kept, because there is nothing to rule it out.

    Each candidate is scored as ``|SHAP contribution| * actionability_weight``.
    Only the best-scoring candidate per domain survives (e.g. SystolicBP and
    DiastolicBP both map to "cardiovascular"); on an exact tie the feature that
    comes first in MODIFIABLE_FEATURE_MAP is kept.

    Args:
        shap_contributions: feature name -> SHAP contribution.
        patient: feature name -> the patient's value for that feature.
        n: maximum number of interventions to return. Zero or a negative
           number yields an empty list; None returns every candidate.

    Returns:
        A list of dicts sorted by ``priority_score`` descending. Each entry contains:
          - feature: raw feature name
          - label: label from FEATURE_META, falling back to the feature name
          - domain: literature domain for RAG retrieval
          - literature_tags: tags from DOMAIN_TO_LITERATURE, falling back to [domain]
          - intervention_summary: text from INTERVENTION_SUMMARY, falling back to the domain
          - priority_score: combined impact x actionability, rounded to 4 places
          - shap_value: SHAP contribution, rounded to 4 places
          - patient_value: the patient's actual value for context (may be None)
    """
    # A negative n would otherwise slice "all but the last |n|" entries.
    if n is not None and n <= 0:
        return []

    # domain -> (unrounded priority, entry). Keeping the unrounded score lets
    # same-domain features be compared exactly instead of against a value that
    # has already been rounded for display.
    best_by_domain = {}

    for feature, (domain, actionability, norm_dir) in MODIFIABLE_FEATURE_MAP.items():
        if feature not in shap_contributions:
            continue

        shap_val = shap_contributions[feature]
        patient_val = patient.get(feature)

        # Only flag features that are both risk-elevating (positive SHAP) AND
        # at an adverse level — no point flagging e.g. "eat better" when diet is
        # already excellent.
        if shap_val <= 0:
            continue
        if patient_val is not None and not _is_adverse(feature, patient_val, norm_dir):
            continue

        priority = abs(shap_val) * actionability

        incumbent = best_by_domain.get(domain)
        if incumbent is not None:
            if priority <= incumbent[0]:
                continue
            # Remove before re-inserting so the replacement takes the newest
            # position; dict order is the tie-break order of the stable sort below.
            del best_by_domain[domain]

        best_by_domain[domain] = (priority, {
            "feature": feature,
            "label": FEATURE_META.get(feature, {}).get("label", feature),
            "domain": domain,
            "literature_tags": DOMAIN_TO_LITERATURE.get(domain, [domain]),
            "intervention_summary": INTERVENTION_SUMMARY.get(domain, domain),
            "priority_score": round(priority, 4),
            "shap_value": round(shap_val, 4),
            "patient_value": patient_val,
        })

    # Sort by the reported priority score, highest first.
    ranked = sorted(
        (entry for _, entry in best_by_domain.values()),
        key=lambda entry: entry["priority_score"],
        reverse=True,
    )
    return ranked[:n]


def build_coach_brief(patient: dict, risk_result: dict, interventions: list) -> str:
    """
    Assemble the plain-text pre-session brief for a BetterBrain-style health coach.

    The brief has four parts, joined with newlines: the headline risk numbers,
    up to five SHAP-identified top drivers, the numbered intervention
    priorities, and a fixed closing instruction. It is intended to be supplied
    as context to the RAG coaching generation step.

    Args:
        patient: accepted for interface compatibility; not read by this function.
        risk_result: must provide "risk_score", "risk_band" and
            "risk_probability"; "top_drivers" is optional and, when present,
            each driver needs "label", "shap_value", "direction" and "modifiable".
        interventions: entries providing "intervention_summary",
            "priority_score", "patient_value" and "label".

    Returns:
        The brief as a single newline-separated string.
    """
    brief = []
    emit = brief.append

    # Headline risk numbers.
    emit(f"PATIENT RISK SCORE: {risk_result['risk_score']}/100 ({risk_result['risk_band'].upper()} risk band)")
    emit(f"Risk probability: {risk_result['risk_probability']:.1%}")
    emit("")
    emit("TOP RISK DRIVERS (SHAP-identified):")

    # Only the first five drivers are listed, in the order they were supplied.
    drivers = risk_result.get("top_drivers", [])
    for driver in drivers[:5]:
        tag = "non-modifiable" if not driver["modifiable"] else "modifiable"
        emit(f"  • {driver['label']}: SHAP={driver['shap_value']:+.3f}{driver['direction']} ({tag})")

    emit("")
    emit("PRIORITIZED INTERVENTION AREAS:")
    for rank, item in enumerate(interventions, start=1):
        emit(f"  {rank}. {item['intervention_summary']} (priority score: {item['priority_score']:.3f})")
        shown_value = item["patient_value"]
        if shown_value is None:
            # Nothing recorded for this feature, so there is no detail line.
            continue
        emit(f"     Patient value: {shown_value} | Feature: {item['label']}")

    emit("")
    emit("COACHING SESSION FOCUS: Ground recommendations in the intervention areas above.")
    emit("All claims must cite retrieved research evidence. Do not make unsupported assertions.")
    return "\n".join(brief)


if __name__ == "__main__":
    # Smoke test: rank a hand-written sample patient, dump the ranking as JSON,
    # then print the coach brief built from it.
    import json

    sample_shap = {
        "SystolicBP": 0.845,
        "DietQuality": 0.626,
        "SleepQuality": 0.446,
        "CholesterolLDL": 0.460,
        "PhysicalActivity": -0.279,
        "MMSE": -0.940,
        "FamilyHistoryAlzheimers": 0.313,
        "Forgetfulness": 0.555,
        "Depression": 0.0,
        "Smoking": -0.025,
    }
    sample_patient = {
        "SystolicBP": 148,
        "DietQuality": 5.0,
        "SleepQuality": 6.0,
        "CholesterolLDL": 158,
        "PhysicalActivity": 2.5,
        "Depression": 0,
        "Smoking": 0,
    }
    sample_risk = {
        "risk_score": 85.1,
        "risk_band": "high",
        "risk_probability": 0.851,
        "top_drivers": [
            {
                "label": "MMSE Score",
                "shap_value": -0.94,
                "direction": "decreases risk",
                "modifiable": False,
            },
        ],
    }

    ivs = rank_interventions(sample_shap, sample_patient)
    print(json.dumps(ivs, indent=2))
    print("\n--- COACH BRIEF ---")
    print(build_coach_brief(sample_patient, sample_risk, ivs))