# headache-predictor-xgboost / feature_engineering.py
# Uploaded by emp-admin ("Upload 9 files"), commit 56f192b (verified).
"""
Feature engineering v3.0 — Leak-free extraction from DailySnapshotDTO.
Predicts day T headache using:
- Day T weather forecast (WeatherKit)
- Day T-1 HealthKit + diary (lag)
- Day T-2 headache history
- Temporal + user context + interactions
Total: 38 features.
"""
from __future__ import annotations
import math
import numpy as np
from datetime import datetime
from typing import List, Optional
from models import (
DailySnapshotDTO, UserContextDTO,
HeadacheLogSnapshotDTO, HealthKitMetricsDTO, WeatherDataDTO,
SleepAnalysisDTO, HRVSummaryDTO,
)
# Diary mood label -> ordinal score (5 = best, 1 = worst). Keys are lowercase;
# extract_features_for_day lowercases the logged mood before the lookup and
# uses 3 for labels that are not listed here.
MOOD_MAP = {"great": 5, "good": 4, "okay": 3, "bad": 2, "terrible": 1}
# Column names of the feature vector, in the exact order in which
# extract_features_for_day appends the values. get_risk_factors resolves a
# name to its position in this list, so reordering entries here without
# changing the extractor would mislabel columns.
FEATURE_NAMES = [
    # Weather on the target day (6)
    "pressure_mb", "pressure_change_24h", "pressure_volatility",
    "humidity_pct", "temperature_c", "is_pressure_drop",
    # HealthKit metrics from the previous day (7)
    "sleep_total_hours", "deep_sleep_min", "rem_sleep_min",
    "resting_hr", "hrv_avg_ms", "workout_min", "menstrual_flow_flag",
    # Headache diary from the previous day (6)
    "had_headache_1d", "severity_1d", "duration_1d",
    "mood_1d", "symptom_count_1d", "trigger_count_1d",
    # Headache diary from two days before (3)
    "had_headache_2d", "severity_2d", "duration_2d",
    # Temporal encoding of the target date (7)
    "dow_sin", "dow_cos", "month_sin", "month_cos",
    "doy_sin", "doy_cos", "is_weekend",
    # User context (3)
    "age_midpoint", "is_europe", "is_tropical",
    # Interaction / derived features (6)
    "sleep_x_pressure", "low_hrv_flag", "sleep_deficit",
    "high_humidity_flag", "headache_streak_2d", "consecutive_headache_days",
]
# Width of the feature vector, derived from the list so the two cannot drift.
NUM_FEATURES = len(FEATURE_NAMES) # 38
# Human-readable risk factor labels for the API response.
# Maps a feature name from FEATURE_NAMES to the label reported for it;
# get_risk_factors falls back to the raw feature name when a feature has no
# entry here.
RISK_LABELS = {
    "had_headache_1d": "recent_headache",
    "pressure_change_24h": "barometric_pressure_drop",
    "consecutive_headache_days": "headache_streak",
    "hrv_avg_ms": "low_hrv_stress",
    "headache_streak_2d": "multi_day_pattern",
    "humidity_pct": "high_humidity",
    "menstrual_flow_flag": "menstrual_phase",
    "temperature_c": "temperature_extreme",
    "sleep_total_hours": "poor_sleep",
    "is_weekend": "weekend_pattern",
    "sleep_deficit": "sleep_deficit",
    "low_hrv_flag": "stress_indicator",
    "is_pressure_drop": "pressure_front",
}
def _safe(val, default=0.0) -> float:
    """Coerce an optional numeric reading to ``float``.

    Args:
        val: Value to convert; ``None`` means the reading is missing.
        default: Fallback used when ``val`` is ``None``.

    Returns:
        ``float(val)``, or ``float(default)`` when ``val`` is ``None``.
    """
    # The fallback is coerced as well, so an integer default such as
    # ``_safe(x, 0)`` still honours the ``-> float`` contract.
    return float(default if val is None else val)
def _cyclic(value: float, period: float):
    """Encode a periodic quantity as a ``(sin, cos)`` pair.

    ``value`` is mapped onto a circle with circumference ``period`` so that
    values one full period apart receive the same encoding.

    Returns:
        Tuple ``(sin(angle), cos(angle))`` with
        ``angle = 2 * pi * value / period``.
    """
    # math.tau is exactly 2 * math.pi, so the angle is bit-identical.
    angle = math.tau * value / period
    return (math.sin(angle), math.cos(angle))
def _parse_age_range(age_range: Optional[str]) -> float:
    """Convert an age-range label into a single representative age.

    Supported forms (spaces are ignored):
        * ``"25-34"``: midpoint of the two bounds (en/em dashes accepted).
        * ``"65+"``: the lower bound, the only number available.
        * ``"40"``: the value itself.

    Args:
        age_range: Label to parse; may be ``None`` or empty.

    Returns:
        The representative age, or ``35.0`` when the label is missing,
        malformed, or does not evaluate to a finite number.
    """
    default_age = 35.0
    if not age_range:
        return default_age

    text = str(age_range).replace(" ", "")
    # Typographic dashes are common in UI-provided labels ("18–24").
    for dash in ("\u2013", "\u2014"):
        text = text.replace(dash, "-")

    try:
        if text.endswith("+"):
            # Open-ended range: only the lower bound is known.
            age = float(text[:-1])
        else:
            parts = text.split("-")
            if len(parts) == 1:
                age = float(parts[0])
            else:
                age = (float(parts[0]) + float(parts[1])) / 2.0
    except (ValueError, IndexError):
        return default_age

    # Guard against labels such as "nan" or "inf" leaking into the features.
    return age if math.isfinite(age) else default_age
def extract_features_for_day(
    target_weather: WeatherDataDTO,
    target_date: str,
    yesterday_snapshot: Optional[DailySnapshotDTO],
    two_days_ago_snapshot: Optional[DailySnapshotDTO],
    user_ctx: Optional[UserContextDTO] = None,
    consecutive_headache_days: int = 0,
) -> np.ndarray:
    """Build the 38-feature vector for predicting a headache on ``target_date``.

    Values are appended in exactly the order of ``FEATURE_NAMES``. Only the
    target day's weather and data from earlier days are read.

    Args:
        target_weather: Weather for the target day. A falsy value is replaced
            by a default-constructed ``WeatherDataDTO``.
        target_date: Target day as ``"YYYY-MM-DD"``. If it cannot be parsed,
            the current local date is used for the temporal features.
        yesterday_snapshot: Snapshot of day T-1 (HealthKit + diary), or
            ``None`` when unavailable.
        two_days_ago_snapshot: Snapshot of day T-2 (diary only is read), or
            ``None`` when unavailable.
        user_ctx: Optional user context (age range, region).
        consecutive_headache_days: Length of the headache streak ending on
            day T-1; capped at 7 in the feature.

    Returns:
        1-D ``float32`` array with one entry per name in ``FEATURE_NAMES``.
    """
    f: List[float] = []

    # Replace every missing DTO by an empty one so the attribute reads below
    # never hit ``None``.
    w = target_weather or WeatherDataDTO()
    yest = yesterday_snapshot or DailySnapshotDTO()
    twod = two_days_ago_snapshot or DailySnapshotDTO()
    ctx = user_ctx or UserContextDTO()
    yest_hk = yest.health_kit_metrics or HealthKitMetricsDTO()
    yest_sl = yest_hk.sleep_analysis or SleepAnalysisDTO()
    yest_hrv = yest_hk.hrv_summary or HRVSummaryDTO()
    yest_log = yest.headache_log or HeadacheLogSnapshotDTO()
    twod_log = twod.headache_log or HeadacheLogSnapshotDTO()

    # Weather target (6)
    pc = _safe(w.pressure_change_24h_mb, 0.0)
    hum = _safe(w.humidity_percent, 50.0)
    f.append(_safe(w.barometric_pressure_mb, 1013.25))
    f.append(pc)
    f.append(abs(pc))  # pressure_volatility: magnitude of the 24h change
    f.append(hum)
    f.append(_safe(w.temperature_celsius, 15.0))
    f.append(1.0 if pc < -5 else 0.0)

    # HealthKit yesterday (7)
    slp = _safe(yest_sl.total_duration_hours, 7.0)
    hrv = _safe(yest_hrv.average_ms, 40.0)
    f.append(slp)
    f.append(_safe(yest_sl.deep_sleep_minutes, 80.0))
    f.append(_safe(yest_sl.rem_sleep_minutes, 90.0))
    f.append(_safe(yest_hk.resting_heart_rate, 65.0))
    f.append(hrv)
    f.append(_safe(yest_hk.workout_minutes, 0))
    f.append(1.0 if yest_hk.had_menstrual_flow else 0.0)

    # Headache yesterday (6)
    # Diary fields go through _safe / getattr like the weather and HealthKit
    # fields, so a log with unset (None) fields counts as "no headache"
    # instead of raising.
    sev_1d = _safe(yest_log.severity)
    yh = 1.0 if sev_1d > 0 else 0.0
    f.append(yh)
    f.append(sev_1d)
    f.append(_safe(yest_log.duration_hours))
    f.append(float(MOOD_MAP.get(str(yest_log.mood).lower(), 3)))
    f.append(float(len(getattr(yest_log.symptoms, "symptoms", None) or [])))
    f.append(float(len(getattr(yest_log.triggers, "triggers", None) or [])))

    # Headache 2d ago (3)
    sev_2d = _safe(twod_log.severity)
    th = 1.0 if sev_2d > 0 else 0.0
    f.append(th)
    f.append(sev_2d)
    f.append(_safe(twod_log.duration_hours))

    # Temporal (7)
    try:
        dt = datetime.strptime(target_date, "%Y-%m-%d")
    except (ValueError, TypeError):
        # NOTE: best-effort fallback; the temporal features then depend on
        # the wall clock rather than on the input.
        dt = datetime.now()
    dw_s, dw_c = _cyclic(dt.weekday(), 7)
    mn_s, mn_c = _cyclic(dt.month - 1, 12)
    dy_s, dy_c = _cyclic(dt.timetuple().tm_yday, 365)
    f.extend([dw_s, dw_c, mn_s, mn_c, dy_s, dy_c])
    f.append(1.0 if dt.weekday() >= 5 else 0.0)  # Saturday / Sunday

    # User context (3)
    f.append(_parse_age_range(ctx.age_range))
    reg = str(ctx.location_region or "").lower()
    f.append(1.0 if "europe" in reg else 0.0)
    f.append(1.0 if "tropic" in reg else 0.0)

    # Interactions (6)
    f.append(slp * abs(pc))
    f.append(1.0 if hrv < 25 else 0.0)
    f.append(max(0.0, 6.0 - slp))  # hours of sleep missing below 6h
    f.append(1.0 if hum > 80 else 0.0)
    f.append(yh + th)  # 0, 1 or 2 headache days among the last two
    f.append(float(min(consecutive_headache_days, 7)))

    return np.array(f, dtype=np.float32)
def extract_forecast_features(
    snapshots: List[DailySnapshotDTO],
    user_ctx: Optional[UserContextDTO] = None,
) -> np.ndarray:
    """Build the feature matrix for a multi-day forecast.

    ``snapshots[0]`` is today (full data) and the following entries are the
    upcoming days in order (weather only). Row ``i`` is the feature vector
    for ``snapshots[i]``, using ``snapshots[i - 1]`` and ``snapshots[i - 2]``
    as the lagged days when they exist.

    The target date of a snapshot is taken from its headache log when that
    holds a parseable ``"YYYY-MM-DD"`` ``input_date``. Otherwise it is
    inferred from the closest earlier dated snapshot plus the day offset, or
    from today's date plus the index when no snapshot so far carried a date.
    This keeps weekday/month/day-of-year features advancing over the horizon.

    Args:
        snapshots: Consecutive daily snapshots, oldest first.
        user_ctx: Optional user context shared by all rows.

    Returns:
        ``float32`` array of shape ``(len(snapshots), NUM_FEATURES)``; an
        empty ``(0, NUM_FEATURES)`` array when ``snapshots`` is empty.
    """
    # Function-scope import: the module itself only imports ``datetime``.
    from datetime import timedelta

    if not snapshots:
        # np.vstack raises ValueError for an empty sequence.
        return np.empty((0, NUM_FEATURES), dtype=np.float32)

    date_fmt = "%Y-%m-%d"
    today = datetime.now()
    anchor_date = None  # most recent explicitly dated snapshot
    anchor_idx = 0
    streak = 0  # consecutive headache days ending at the previous snapshot
    rows = []

    for i, snap in enumerate(snapshots):
        log = snap.headache_log
        raw_date = log.input_date if log and log.input_date else ""
        try:
            parsed = datetime.strptime(raw_date, date_fmt)
        except (ValueError, TypeError):
            parsed = None

        if parsed is not None:
            anchor_date, anchor_idx = parsed, i
            target_date = raw_date
        elif anchor_date is not None:
            inferred = anchor_date + timedelta(days=i - anchor_idx)
            target_date = inferred.strftime(date_fmt)
        else:
            target_date = (today + timedelta(days=i)).strftime(date_fmt)

        yest = snapshots[i - 1] if i > 0 else None
        twod = snapshots[i - 2] if i > 1 else None
        rows.append(
            extract_features_for_day(
                snap.weather_data or WeatherDataDTO(),
                target_date,
                yest,
                twod,
                user_ctx,
                streak,
            )
        )

        # Update the running streak for the next row: it grows while each
        # day has a logged headache and resets on the first day without one.
        has_headache = bool(log) and _safe(log.severity) > 0
        streak = streak + 1 if has_headache else 0

    return np.vstack(rows)
def get_risk_factors(
    features: np.ndarray,
    feature_importances: dict,
    top_k: int = 3,
) -> List[str]:
    """Identify the top risk factors from feature values and importances.

    Each candidate feature has a threshold rule; candidates are visited in
    descending order of learned importance (ties keep the order of the rule
    table) and the labels of the first ``top_k`` triggered rules are returned.

    Args:
        features: Single feature vector indexed like ``FEATURE_NAMES``.
        feature_importances: Mapping feature name -> importance. Missing
            names count as 0; ``None`` is treated as an empty mapping.
        top_k: Maximum number of labels to return. Values <= 0 yield ``[]``.

    Returns:
        Up to ``top_k`` distinct labels from ``RISK_LABELS`` (or the raw
        feature name when no label is defined), most important first.
    """
    # Without this guard the cap was only tested after appending, so a
    # non-positive top_k could still return one label.
    if top_k <= 0:
        return []

    importances = feature_importances or {}

    # (feature name, predicate that flags a concerning value)
    checks = [
        ("had_headache_1d", lambda v: v > 0),
        ("pressure_change_24h", lambda v: v < -3),
        ("consecutive_headache_days", lambda v: v >= 2),
        ("hrv_avg_ms", lambda v: v < 30),
        ("headache_streak_2d", lambda v: v >= 1),
        ("humidity_pct", lambda v: v > 75),
        ("menstrual_flow_flag", lambda v: v > 0),
        ("temperature_c", lambda v: v > 30 or v < -5),
        ("sleep_total_hours", lambda v: v < 6),
        ("sleep_deficit", lambda v: v > 0),
        ("low_hrv_flag", lambda v: v > 0),
        ("is_pressure_drop", lambda v: v > 0),
        ("is_weekend", lambda v: v > 0),
    ]

    # sorted() is stable, so equally important rules keep the table order.
    ranked = sorted(
        checks,
        key=lambda item: importances.get(item[0], 0),
        reverse=True,
    )

    # One pass over FEATURE_NAMES instead of a list scan per rule.
    index_of = {name: pos for pos, name in enumerate(FEATURE_NAMES)}

    risks: List[str] = []
    for fname, is_concerning in ranked:
        idx = index_of.get(fname)
        if idx is None or not is_concerning(features[idx]):
            continue
        label = RISK_LABELS.get(fname, fname)
        if label not in risks:
            risks.append(label)
            if len(risks) >= top_k:
                break
    return risks