""" ═══════════════════════════════════════════════════════════════════════ Phoebe Headache Predictor API v3.0 EmpedocLabs © 2025 ═══════════════════════════════════════════════════════════════════════ Endpoints: GET / → API info & usage examples GET /health → Health + model status POST /forecast → 7-day headache forecast (DailySnapshotDTO) POST /predict → Single-day legacy (raw feature vector) POST /predict/batch → Batch legacy (raw feature vectors) """ import logging import numpy as np import pickle import os from typing import List from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from huggingface_hub import hf_hub_download from models import ( DailySnapshotDTO, UserContextDTO, WeatherDataDTO, PredictionRequest, PredictionResponse, DayPrediction, SinglePredictionRequest, SinglePredictionResponse, ) from feature_engineering import ( extract_features_for_day, extract_forecast_features, get_risk_factors, FEATURE_NAMES, NUM_FEATURES, ) # ── Logging ────────────────────────────────────────────────────────── logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s") logger = logging.getLogger("phoebe") # ── App ────────────────────────────────────────────────────────────── app = FastAPI( title="Phoebe Headache Predictor API", version="3.0.0", description="ML-powered headache risk forecasting for the Phoebe iOS app by EmpedocLabs.", docs_url="/docs", redoc_url="/redoc", ) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # ── Globals ────────────────────────────────────────────────────────── clf = None threshold = 0.5 model_version = "3.0.0" feature_importances = {} # ── Startup ────────────────────────────────────────────────────────── @app.on_event("startup") async def load_model(): global clf, threshold, model_version, feature_importances try: # Load from the Space's own files (uploaded alongside app.py) model_path = os.path.join(os.path.dirname(__file__), "model.pkl") if not os.path.exists(model_path): # Fallback: check common HF Space paths for p in ["/app/model.pkl", "model/model.pkl", "/app/model/model.pkl"]: if os.path.exists(p): model_path = p break logger.info(f"Loading model from {model_path}...") with open(model_path, "rb") as f: data = pickle.load(f) if isinstance(data, dict): clf = data["model"] threshold = float(data.get("optimal_threshold", 0.5)) model_version = data.get("model_version", "3.0.0") feature_importances = data.get("feature_importances", {}) metrics = data.get("test_metrics", {}) logger.info( f"✅ Model v{model_version} loaded | " f"threshold={threshold:.3f} | " f"AUC={metrics.get('roc_auc', '?')} | " f"F1={metrics.get('f1', '?')}" ) else: clf = data threshold = 0.5 logger.info("✅ Model loaded (legacy format)") except Exception as e: logger.error(f"❌ Model load failed: {e}") import traceback traceback.print_exc() # ── Helpers ────────────────────────────────────────────────────────── def _risk_level(prob: float) -> str: if prob < 0.20: return "low" if prob < 0.40: return "moderate" if prob < 0.65: return "high" return "very_high" # ── Root ───────────────────────────────────────────────────────────── @app.get("/") def root(): return { "service": "Phoebe Headache Predictor API", "version": model_version, "by": "EmpedocLabs", "status": "running" if clf is not None else "model_not_loaded", "endpoints": { "/health": "GET — model status & metrics", "/forecast": "POST — 7-day headache risk forecast", "/predict": "POST — single prediction (legacy)", "/predict/batch": "POST — batch prediction (legacy)", "/docs": "GET — Swagger UI", }, "example_forecast_body": { "user_context": {"age_range": "30-40", "location_region": "Balkan Peninsula, Europe"}, "daily_snapshots": [ { "headache_log": {"severity": 0, "duration_hours": 0, "input_date": "2025-06-01", "mood": "good"}, "health_kit_metrics": { "resting_heart_rate": 62, "sleep_analysis": {"total_duration_hours": 7.2, "deep_sleep_minutes": 85, "rem_sleep_minutes": 95}, "hrv_summary": {"average_ms": 42}, "workout_minutes": 30, "had_menstrual_flow": False, }, "weather_data": { "barometric_pressure_mb": 1015.2, "pressure_change_24h_mb": -2.1, "humidity_percent": 65, "temperature_celsius": 22.5, }, }, ], }, } @app.get("/health") def health(): return { "status": "healthy" if clf is not None else "degraded", "model_loaded": clf is not None, "model_version": model_version, "threshold": threshold, "num_features": NUM_FEATURES, "top_features": list(feature_importances.keys())[:5], } # ── /forecast — Main endpoint ─────────────────────────────────────── @app.post("/forecast", response_model=PredictionResponse) def forecast(request: PredictionRequest): """ 7-day headache risk forecast. Send daily_snapshots[0] = today (full HealthKit + diary + weather), daily_snapshots[1..6] = future days (weather forecast only). Returns probability, risk level, and top risk factors per day. """ if clf is None: raise HTTPException(status_code=503, detail="Model not loaded. Please retry shortly.") if not request.daily_snapshots: raise HTTPException(status_code=400, detail="daily_snapshots cannot be empty.") if len(request.daily_snapshots) > 14: raise HTTPException(status_code=400, detail="Maximum 14 days supported.") try: ctx = request.user_context snaps = request.daily_snapshots X = extract_forecast_features(snaps, ctx) predictions = [] for i in range(len(snaps)): prob_arr = clf.predict_proba(X[i:i + 1])[0] prob = float(prob_arr[1]) pred = 1 if prob >= threshold else 0 date_str = None if snaps[i].headache_log and snaps[i].headache_log.input_date: date_str = snaps[i].headache_log.input_date risks = get_risk_factors(X[i], feature_importances, top_k=3) predictions.append(DayPrediction( day=i + 1, date=date_str, prediction=pred, probability=round(prob, 4), risk_level=_risk_level(prob), top_risk_factors=risks, )) logger.info( f"Forecast: {len(snaps)} days | " f"probs={[p.probability for p in predictions]}" ) return PredictionResponse( predictions=predictions, model_version=model_version, threshold=round(threshold, 4), ) except HTTPException: raise except Exception as e: logger.error(f"Forecast error: {e}", exc_info=True) raise HTTPException(status_code=400, detail=f"Forecast error: {str(e)}") # ── Legacy endpoints ───────────────────────────────────────────────── class BatchRequest(BaseModel): instances: List[List[float]] class BatchDayPred(BaseModel): day: int prediction: int probability: float class BatchResponse(BaseModel): predictions: List[BatchDayPred] @app.post("/predict", response_model=SinglePredictionResponse) def predict_single(request: SinglePredictionRequest): """Legacy: raw feature vector → single prediction.""" if clf is None: raise HTTPException(status_code=503, detail="Model not loaded") try: X = np.array(request.features, dtype=np.float32).reshape(1, -1) if X.shape[1] != NUM_FEATURES: raise ValueError(f"Expected {NUM_FEATURES} features, got {X.shape[1]}") prob = float(clf.predict_proba(X)[0][1]) return SinglePredictionResponse(prediction=1 if prob >= threshold else 0, probability=round(prob, 4)) except HTTPException: raise except Exception as e: raise HTTPException(status_code=400, detail=str(e)) @app.post("/predict/batch", response_model=BatchResponse) def predict_batch(request: BatchRequest): """Legacy: batch raw feature vectors.""" if clf is None: raise HTTPException(status_code=503, detail="Model not loaded") try: X = np.array(request.instances, dtype=np.float32) if X.ndim != 2 or X.shape[1] != NUM_FEATURES: raise ValueError(f"Expected shape (n, {NUM_FEATURES}), got {X.shape}") probas = clf.predict_proba(X)[:, 1] preds = (probas >= threshold).astype(int) return BatchResponse(predictions=[ BatchDayPred(day=i + 1, prediction=int(preds[i]), probability=round(float(probas[i]), 4)) for i in range(len(probas)) ]) except HTTPException: raise except Exception as e: raise HTTPException(status_code=400, detail=str(e))