File size: 7,914 Bytes
56f192b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
"""
Feature engineering v3.0 — Leak-free extraction from DailySnapshotDTO.

Predicts day T headache using:
  - Day T weather forecast (WeatherKit)
  - Day T-1 HealthKit + diary (lag)
  - Day T-2 headache history
  - Temporal + user context + interactions

Total: 38 features.
"""

from __future__ import annotations
import math
import numpy as np
from datetime import datetime
from typing import List, Optional

from models import (
    DailySnapshotDTO, UserContextDTO,
    HeadacheLogSnapshotDTO, HealthKitMetricsDTO, WeatherDataDTO,
    SleepAnalysisDTO, HRVSummaryDTO,
)

# Diary mood label (looked up lower-cased by extract_features_for_day) -> ordinal
# score, 5 = best, 1 = worst.
MOOD_MAP = dict(great=5, good=4, okay=3, bad=2, terrible=1)

# Column order of every feature vector/matrix built in this module.
# extract_features_for_day appends its values in exactly this order, so a name's
# position here is its column index.
FEATURE_NAMES = [
    # Target-day weather (6)
    "pressure_mb",
    "pressure_change_24h",
    "pressure_volatility",
    "humidity_pct",
    "temperature_c",
    "is_pressure_drop",
    # Day T-1 HealthKit (7)
    "sleep_total_hours",
    "deep_sleep_min",
    "rem_sleep_min",
    "resting_hr",
    "hrv_avg_ms",
    "workout_min",
    "menstrual_flow_flag",
    # Day T-1 headache diary (6)
    "had_headache_1d",
    "severity_1d",
    "duration_1d",
    "mood_1d",
    "symptom_count_1d",
    "trigger_count_1d",
    # Day T-2 headache diary (3)
    "had_headache_2d",
    "severity_2d",
    "duration_2d",
    # Calendar encodings of the target date (7)
    "dow_sin",
    "dow_cos",
    "month_sin",
    "month_cos",
    "doy_sin",
    "doy_cos",
    "is_weekend",
    # User context (3)
    "age_midpoint",
    "is_europe",
    "is_tropical",
    # Derived interactions / short-term history (6)
    "sleep_x_pressure",
    "low_hrv_flag",
    "sleep_deficit",
    "high_humidity_flag",
    "headache_streak_2d",
    "consecutive_headache_days",
]
NUM_FEATURES = len(FEATURE_NAMES)  # 38

# Feature name -> human-readable risk-factor label returned in the API response
# (looked up by get_risk_factors; unknown names fall back to the raw name there).
RISK_LABELS = dict(
    had_headache_1d="recent_headache",
    pressure_change_24h="barometric_pressure_drop",
    consecutive_headache_days="headache_streak",
    hrv_avg_ms="low_hrv_stress",
    headache_streak_2d="multi_day_pattern",
    humidity_pct="high_humidity",
    menstrual_flow_flag="menstrual_phase",
    temperature_c="temperature_extreme",
    sleep_total_hours="poor_sleep",
    is_weekend="weekend_pattern",
    sleep_deficit="sleep_deficit",
    low_hrv_flag="stress_indicator",
    is_pressure_drop="pressure_front",
)


def _safe(val, default=0.0) -> float:
    """Return ``val`` as a float, or ``default`` (also as a float) when ``val`` is None.

    Only ``None`` is treated as missing: ``0``/``False`` are real values and are
    converted, not replaced.  Non-numeric values propagate ``float()``'s
    ``ValueError``/``TypeError``.
    """
    if val is None:
        # Convert the default too, so callers passing an int default (e.g. 0)
        # still get the float the annotation promises.
        return float(default)
    return float(val)

def _cyclic(value: float, period: float):
    """Encode ``value`` as a point on a circle whose circumference is ``period``.

    Returns the pair ``(sin, cos)`` of the angle ``2*pi*value/period``; values a
    whole period apart map to the same point (up to float rounding), so e.g. the
    last weekday ends up adjacent to the first.
    """
    theta = 2 * math.pi * value / period
    return tuple(fn(theta) for fn in (math.sin, math.cos))

def _parse_age_range(age_range: Optional[str]) -> float:
    """Return the numeric midpoint of an age-range label such as ``"25-34"``.

    Accepted forms (spaces ignored; en/em dashes treated as ``-``):
      * ``"A-B"`` -> ``(A + B) / 2`` (extra ``-C`` parts are ignored)
      * ``"A+"``  -> ``A``  (open-ended top bucket; its lower bound is used)
      * ``"A"``   -> ``A``
    ``None``/empty input, non-string input, or an unparseable label falls back
    to 35.0.
    """
    if not age_range:
        return 35.0
    try:
        text = (
            age_range.replace(" ", "")
            .replace("\u2013", "-")  # en dash
            .replace("\u2014", "-")  # em dash
        )
        if text.endswith("+"):
            # Previously "65+" split into one part and hit IndexError -> 35.0,
            # placing the oldest bucket in the middle of the age range.
            return float(text[:-1])
        parts = text.split("-")
        if len(parts) == 1:
            return float(parts[0])
        return (float(parts[0]) + float(parts[1])) / 2.0
    except (AttributeError, TypeError, ValueError):
        # Non-string input (no .replace) or non-numeric label: neutral default.
        return 35.0


def extract_features_for_day(
    target_weather: WeatherDataDTO,
    target_date: str,
    yesterday_snapshot: Optional[DailySnapshotDTO],
    two_days_ago_snapshot: Optional[DailySnapshotDTO],
    user_ctx: Optional[UserContextDTO] = None,
    consecutive_headache_days: int = 0,
) -> np.ndarray:
    """Build the 38-feature vector for predicting a headache on ``target_date``.

    Values are appended in exactly the order of ``FEATURE_NAMES``; the position
    of every ``f.append``/``f.extend`` below is therefore the column index the
    model sees.  Missing DTOs are replaced by empty DTOs, and ``None`` numeric
    fields by the fixed fallbacks passed to ``_safe``.

    Args:
        target_weather: Weather for the target day; ``None`` is tolerated.
        target_date: Target day as ``"YYYY-MM-DD"``.  If it cannot be parsed
            (``ValueError``/``TypeError``), the current local time is used for
            the calendar features.
        yesterday_snapshot: Day T-1 snapshot; its HealthKit metrics and headache
            log are read.
        two_days_ago_snapshot: Day T-2 snapshot; only its headache log is read.
        user_ctx: Optional user context (age range, location region).
        consecutive_headache_days: Headache streak length before the target day;
            capped at 7 in the output.

    Returns:
        A 1-D ``float32`` array with one value per feature (38).
    """
    f: List[float] = []

    # Substitute empty DTOs for missing inputs so the attribute reads below
    # never hit None; None fields inside them are then defaulted via _safe().
    w = target_weather or WeatherDataDTO()
    yest = yesterday_snapshot or DailySnapshotDTO()
    twod = two_days_ago_snapshot or DailySnapshotDTO()
    ctx = user_ctx or UserContextDTO()

    yest_hk = yest.health_kit_metrics or HealthKitMetricsDTO()
    yest_sl = yest_hk.sleep_analysis or SleepAnalysisDTO()
    yest_hrv = yest_hk.hrv_summary or HRVSummaryDTO()
    yest_log = yest.headache_log or HeadacheLogSnapshotDTO()
    twod_log = twod.headache_log or HeadacheLogSnapshotDTO()

    # --- Target-day weather (6) ---
    pc = _safe(w.pressure_change_24h_mb, 0.0)
    hum = _safe(w.humidity_percent, 50.0)
    f.append(_safe(w.barometric_pressure_mb, 1013.25))  # presumably standard sea-level pressure as fallback
    f.append(pc)
    f.append(abs(pc))  # "volatility": magnitude of the 24h change regardless of sign
    f.append(hum)
    f.append(_safe(w.temperature_celsius, 15.0))
    f.append(1.0 if pc < -5 else 0.0)  # 24h pressure change below -5 (mb per field name)

    # --- Day T-1 HealthKit (7) ---
    slp = _safe(yest_sl.total_duration_hours, 7.0)
    hrv = _safe(yest_hrv.average_ms, 40.0)
    f.append(slp)
    f.append(_safe(yest_sl.deep_sleep_minutes, 80.0))
    f.append(_safe(yest_sl.rem_sleep_minutes, 90.0))
    f.append(_safe(yest_hk.resting_heart_rate, 65.0))
    f.append(hrv)
    f.append(_safe(yest_hk.workout_minutes, 0))
    f.append(1.0 if yest_hk.had_menstrual_flow else 0.0)

    # --- Day T-1 headache diary (6) ---
    # NOTE(review): assumes severity/duration_hours are numeric (never None) and
    # symptoms/triggers are non-None wrappers on HeadacheLogSnapshotDTO — TODO confirm.
    yh = 1.0 if yest_log.severity > 0 else 0.0  # any positive severity counts as a headache day
    f.append(yh)
    f.append(float(yest_log.severity))
    f.append(float(yest_log.duration_hours))
    f.append(float(MOOD_MAP.get(str(yest_log.mood).lower(), 3)))  # unknown or None mood -> 3
    f.append(float(len(yest_log.symptoms.symptoms)))
    f.append(float(len(yest_log.triggers.triggers)))

    # --- Day T-2 headache diary (3) ---
    th = 1.0 if twod_log.severity > 0 else 0.0
    f.append(th)
    f.append(float(twod_log.severity))
    f.append(float(twod_log.duration_hours))

    # --- Calendar encodings of the target date (7) ---
    try:
        dt = datetime.strptime(target_date, "%Y-%m-%d")
    except (ValueError, TypeError):
        # Missing/unparseable date: fall back to the current local time.
        dt = datetime.now()
    dw_s, dw_c = _cyclic(dt.weekday(), 7)  # weekday(): Monday = 0
    mn_s, mn_c = _cyclic(dt.month - 1, 12)  # month shifted to 0..11
    # NOTE(review): tm_yday is 1-based (1..366), unlike the 0-based weekday and
    # month above, so day 366 of a leap year encodes (up to rounding) the same
    # as day 1.  Left unchanged: fixing it would alter feature values that any
    # already-trained model expects.
    dy_s, dy_c = _cyclic(dt.timetuple().tm_yday, 365)
    f.extend([dw_s, dw_c, mn_s, mn_c, dy_s, dy_c])
    f.append(1.0 if dt.weekday() >= 5 else 0.0)  # Saturday or Sunday

    # --- User context (3) ---
    f.append(_parse_age_range(ctx.age_range))
    reg = str(ctx.location_region or "").lower()
    f.append(1.0 if "europe" in reg else 0.0)  # case-insensitive substring match
    f.append(1.0 if "tropic" in reg else 0.0)

    # --- Derived interactions / short-term history (6) ---
    f.append(slp * abs(pc))  # sleep_x_pressure
    f.append(1.0 if hrv < 25 else 0.0)  # low_hrv_flag
    f.append(max(0.0, 6.0 - slp))  # sleep_deficit: hours short of 6, never negative
    f.append(1.0 if hum > 80 else 0.0)  # high_humidity_flag
    f.append(yh + th)  # headache_streak_2d: 0, 1 or 2
    f.append(float(min(consecutive_headache_days, 7)))  # consecutive_headache_days, capped at 7

    return np.array(f, dtype=np.float32)


def extract_forecast_features(
    snapshots: List[DailySnapshotDTO],
    user_ctx: Optional[UserContextDTO] = None,
) -> np.ndarray:
    """
    Build feature matrix for 7-day forecast.
    snapshots[0] = today (full data), [1..6] = future (weather only).

    Row ``i`` predicts the day of ``snapshots[i]`` from that snapshot's weather,
    with ``snapshots[i-1]`` / ``snapshots[i-2]`` as the T-1 / T-2 history and the
    number of consecutive headache days immediately before ``i``.

    Each row's target date is ``headache_log.input_date`` when that parses as
    ``"YYYY-MM-DD"``.  Otherwise (e.g. weather-only future days) the date is
    derived from the nearest earlier dated snapshot, assuming one snapshot per
    consecutive day.  If no earlier snapshot is dated, it is today + ``i`` days.

    Args:
        snapshots: Consecutive daily snapshots, today first.
        user_ctx: Optional user context forwarded to every row.

    Returns:
        ``float32`` array of shape ``(len(snapshots), NUM_FEATURES)``; an empty
        ``(0, NUM_FEATURES)`` array when ``snapshots`` is empty.
    """
    from datetime import timedelta

    if not snapshots:
        # np.vstack([]) would raise; an empty matrix is the natural result.
        return np.empty((0, NUM_FEATURES), dtype=np.float32)

    today = datetime.now()
    anchor: Optional[datetime] = None  # most recent successfully parsed target date
    anchor_idx = 0  # index of the snapshot that produced `anchor`
    streak = 0  # consecutive headache days ending at snapshot i-1
    rows = []
    for i, snap in enumerate(snapshots):
        tw = snap.weather_data or WeatherDataDTO()
        td = ""
        if snap.headache_log and snap.headache_log.input_date:
            td = snap.headache_log.input_date

        # Undated rows must still get their own calendar day.  Without this,
        # extract_features_for_day falls back to "now" for every undated row,
        # and all future days share today's weekday/month/day-of-year features.
        try:
            anchor = datetime.strptime(td, "%Y-%m-%d")
            anchor_idx = i
        except (ValueError, TypeError):
            if anchor is None:
                inferred = today + timedelta(days=i)
            else:
                inferred = anchor + timedelta(days=i - anchor_idx)
            td = inferred.strftime("%Y-%m-%d")

        yest = snapshots[i - 1] if i > 0 else None
        twod = snapshots[i - 2] if i > 1 else None
        rows.append(extract_features_for_day(tw, td, yest, twod, user_ctx, streak))

        # Extend or reset the streak for the next row (replaces an O(n^2) recount).
        log = snap.headache_log
        streak = streak + 1 if (log and log.severity > 0) else 0

    return np.vstack(rows)


def get_risk_factors(
    features: np.ndarray,
    feature_importances: dict,
    top_k: int = 3,
) -> List[str]:
    """Identify top risk factors from feature values + learned importances.

    Each candidate feature has a simple "is this value concerning?" rule.  The
    rules are tried in descending order of ``feature_importances``.  Features
    missing from the mapping count as importance 0, and ties keep the order
    listed below because ``sorted`` is stable.  Labels of rules that fire are
    collected until ``top_k`` distinct labels have been found.

    Args:
        features: Feature vector indexed like ``FEATURE_NAMES`` (the layout
            produced by ``extract_features_for_day``).
        feature_importances: Mapping of feature name -> importance score;
            ``None`` or empty is accepted and keeps the listed rule order.
        top_k: Maximum number of labels to return; ``<= 0`` returns ``[]``.

    Returns:
        Up to ``top_k`` distinct labels from ``RISK_LABELS``, most important first.
    """
    # Check up front: the loop only tests the limit right after appending, so
    # without this guard top_k=0 could still return one label.
    if top_k <= 0:
        return []

    importances = feature_importances or {}

    # (feature name, predicate on that feature's value)
    checks = [
        ("had_headache_1d", lambda v: v > 0),
        ("pressure_change_24h", lambda v: v < -3),
        ("consecutive_headache_days", lambda v: v >= 2),
        ("hrv_avg_ms", lambda v: v < 30),
        ("headache_streak_2d", lambda v: v >= 1),
        ("humidity_pct", lambda v: v > 75),
        ("menstrual_flow_flag", lambda v: v > 0),
        ("temperature_c", lambda v: v > 30 or v < -5),
        ("sleep_total_hours", lambda v: v < 6),
        ("sleep_deficit", lambda v: v > 0),
        ("low_hrv_flag", lambda v: v > 0),
        ("is_pressure_drop", lambda v: v > 0),
        ("is_weekend", lambda v: v > 0),
    ]

    # Name -> column index, built once instead of a list.index() per check.
    column_of = {name: idx for idx, name in enumerate(FEATURE_NAMES)}

    sorted_checks = sorted(
        checks,
        key=lambda item: importances.get(item[0], 0),
        reverse=True,
    )

    risks: List[str] = []
    for fname, condition in sorted_checks:
        idx = column_of.get(fname)
        if idx is None or not condition(features[idx]):
            continue
        label = RISK_LABELS.get(fname, fname)
        if label not in risks:
            risks.append(label)
            if len(risks) >= top_k:
                break

    return risks