# Hugging Face Space file header — uploaded by emp-admin ("Update app.py", commit 1b10340, verified)
"""
═══════════════════════════════════════════════════════════════════════
Phoebe Headache Predictor API v3.0
EmpedocLabs Β© 2025
═══════════════════════════════════════════════════════════════════════
Endpoints:
GET / β†’ API info & usage examples
GET /health β†’ Health + model status
POST /forecast β†’ 7-day headache forecast (DailySnapshotDTO)
POST /predict β†’ Single-day legacy (raw feature vector)
POST /predict/batch β†’ Batch legacy (raw feature vectors)
"""
import logging
import numpy as np
import pickle
import os
from typing import List
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from huggingface_hub import hf_hub_download
from models import (
DailySnapshotDTO, UserContextDTO, WeatherDataDTO,
PredictionRequest, PredictionResponse, DayPrediction,
SinglePredictionRequest, SinglePredictionResponse,
)
from feature_engineering import (
extract_features_for_day, extract_forecast_features,
get_risk_factors, FEATURE_NAMES, NUM_FEATURES,
)
# ── Logging ──────────────────────────────────────────────────────────
# Root logging config: "timestamp | LEVEL | message" for every record.
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")
logger = logging.getLogger("phoebe")  # module-wide logger used by all endpoints below
# ── App ──────────────────────────────────────────────────────────────
app = FastAPI(
    title="Phoebe Headache Predictor API",
    version="3.0.0",
    description="ML-powered headache risk forecasting for the Phoebe iOS app by EmpedocLabs.",
    docs_url="/docs",    # Swagger UI
    redoc_url="/redoc",  # ReDoc UI
)
# CORS: allow any origin/method/header so the iOS app (and browsers) can call
# the API directly.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# maximally permissive — confirm this is intended for a public Space.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ── Globals ──────────────────────────────────────────────────────────
# Populated by load_model() at startup; read by every endpoint.
clf = None                 # fitted classifier exposing predict_proba, or None if loading failed
threshold = 0.5            # decision threshold applied to P(headache); may be replaced by the pickle bundle
model_version = "3.0.0"    # overwritten by the version stored in the pickle bundle, if present
feature_importances = {}   # feature name -> importance; used to surface top risk factors
# ── Startup ──────────────────────────────────────────────────────────
@app.on_event("startup")
async def load_model():
    """Load the pickled classifier into the module globals at startup.

    Search order: model.pkl next to this file, then common HF Space
    fallback paths. Two pickle formats are supported:
      * dict bundle — {"model", "optimal_threshold", "model_version",
        "feature_importances", "test_metrics"}
      * legacy — the pickle is the bare estimator itself.

    Never raises: on any failure the error is logged and ``clf`` stays
    None, so the app keeps serving in degraded mode (see /health).
    """
    global clf, threshold, model_version, feature_importances
    try:
        # Prefer the model uploaded alongside app.py.
        model_path = os.path.join(os.path.dirname(__file__), "model.pkl")
        if not os.path.exists(model_path):
            # Fallback: check common HF Space paths.
            for p in ["/app/model.pkl", "model/model.pkl", "/app/model/model.pkl"]:
                if os.path.exists(p):
                    model_path = p
                    break
        logger.info(f"Loading model from {model_path}...")
        # NOTE(review): pickle.load is only safe on trusted artifacts — this
        # file ships with the Space, so that holds here.
        with open(model_path, "rb") as f:
            data = pickle.load(f)
        if isinstance(data, dict):
            # Bundle format: estimator plus tuning metadata.
            clf = data["model"]
            threshold = float(data.get("optimal_threshold", 0.5))
            model_version = data.get("model_version", "3.0.0")
            feature_importances = data.get("feature_importances", {})
            metrics = data.get("test_metrics", {})
            logger.info(
                f"βœ… Model v{model_version} loaded | "
                f"threshold={threshold:.3f} | "
                f"AUC={metrics.get('roc_auc', '?')} | "
                f"F1={metrics.get('f1', '?')}"
            )
        else:
            # Legacy format: the pickle is the bare estimator.
            clf = data
            threshold = 0.5
            logger.info("βœ… Model loaded (legacy format)")
    except Exception as e:
        # exc_info=True logs the full traceback through the logging system
        # (replaces the previous ad-hoc traceback.print_exc()), matching the
        # error-handling style used in forecast().
        logger.error(f"❌ Model load failed: {e}", exc_info=True)
# ── Helpers ──────────────────────────────────────────────────────────
def _risk_level(prob: float) -> str:
if prob < 0.20: return "low"
if prob < 0.40: return "moderate"
if prob < 0.65: return "high"
return "very_high"
# ── Root ─────────────────────────────────────────────────────────────
@app.get("/")
def root():
    """Landing payload: service info, status, endpoint map, and a sample
    /forecast request body clients can copy."""
    endpoint_map = {
        "/health": "GET β€” model status & metrics",
        "/forecast": "POST β€” 7-day headache risk forecast",
        "/predict": "POST β€” single prediction (legacy)",
        "/predict/batch": "POST β€” batch prediction (legacy)",
        "/docs": "GET β€” Swagger UI",
    }
    sample_snapshot = {
        "headache_log": {"severity": 0, "duration_hours": 0, "input_date": "2025-06-01", "mood": "good"},
        "health_kit_metrics": {
            "resting_heart_rate": 62,
            "sleep_analysis": {"total_duration_hours": 7.2, "deep_sleep_minutes": 85, "rem_sleep_minutes": 95},
            "hrv_summary": {"average_ms": 42},
            "workout_minutes": 30,
            "had_menstrual_flow": False,
        },
        "weather_data": {
            "barometric_pressure_mb": 1015.2, "pressure_change_24h_mb": -2.1,
            "humidity_percent": 65, "temperature_celsius": 22.5,
        },
    }
    service_status = "running" if clf is not None else "model_not_loaded"
    return {
        "service": "Phoebe Headache Predictor API",
        "version": model_version,
        "by": "EmpedocLabs",
        "status": service_status,
        "endpoints": endpoint_map,
        "example_forecast_body": {
            "user_context": {"age_range": "30-40", "location_region": "Balkan Peninsula, Europe"},
            "daily_snapshots": [sample_snapshot],
        },
    }
@app.get("/health")
def health():
return {
"status": "healthy" if clf is not None else "degraded",
"model_loaded": clf is not None,
"model_version": model_version,
"threshold": threshold,
"num_features": NUM_FEATURES,
"top_features": list(feature_importances.keys())[:5],
}
# ── /forecast — Main endpoint ───────────────────────────────────────
@app.post("/forecast", response_model=PredictionResponse)
def forecast(request: PredictionRequest):
    """
    7-day headache risk forecast.

    Send daily_snapshots[0] = today (full HealthKit + diary + weather),
    daily_snapshots[1..6] = future days (weather forecast only).
    Returns probability, risk level, and top risk factors per day.

    Raises:
        HTTPException 503: model not loaded yet.
        HTTPException 400: empty input, more than 14 days, or any
            feature-extraction / scoring failure.
    """
    if clf is None:
        raise HTTPException(status_code=503, detail="Model not loaded. Please retry shortly.")
    if not request.daily_snapshots:
        raise HTTPException(status_code=400, detail="daily_snapshots cannot be empty.")
    if len(request.daily_snapshots) > 14:
        raise HTTPException(status_code=400, detail="Maximum 14 days supported.")
    try:
        ctx = request.user_context
        snaps = request.daily_snapshots
        X = extract_forecast_features(snaps, ctx)
        # Score every day in one vectorized predict_proba call instead of
        # one call per row (same results; matches /predict/batch).
        probs = clf.predict_proba(X)[:, 1]
        predictions = []
        for i, snap in enumerate(snaps):
            prob = float(probs[i])
            pred = 1 if prob >= threshold else 0
            # Echo back the snapshot's date when the client supplied one.
            date_str = None
            if snap.headache_log and snap.headache_log.input_date:
                date_str = snap.headache_log.input_date
            risks = get_risk_factors(X[i], feature_importances, top_k=3)
            predictions.append(DayPrediction(
                day=i + 1,
                date=date_str,
                prediction=pred,
                probability=round(prob, 4),
                risk_level=_risk_level(prob),
                top_risk_factors=risks,
            ))
        logger.info(
            f"Forecast: {len(snaps)} days | "
            f"probs={[p.probability for p in predictions]}"
        )
        return PredictionResponse(
            predictions=predictions,
            model_version=model_version,
            threshold=round(threshold, 4),
        )
    except HTTPException:
        # Re-raise our own validation errors untouched.
        raise
    except Exception as e:
        logger.error(f"Forecast error: {e}", exc_info=True)
        raise HTTPException(status_code=400, detail=f"Forecast error: {str(e)}")
# ── Legacy endpoints ─────────────────────────────────────────────────
class BatchRequest(BaseModel):
    """Request body for /predict/batch."""
    instances: List[List[float]]  # one raw feature vector per row


class BatchDayPred(BaseModel):
    """One scored row of a /predict/batch response."""
    day: int           # 1-based index of the input row
    prediction: int    # 0/1 label after applying the decision threshold
    probability: float # P(headache), rounded to 4 decimal places


class BatchResponse(BaseModel):
    """Response body for /predict/batch."""
    predictions: List[BatchDayPred]
@app.post("/predict", response_model=SinglePredictionResponse)
def predict_single(request: SinglePredictionRequest):
"""Legacy: raw feature vector β†’ single prediction."""
if clf is None:
raise HTTPException(status_code=503, detail="Model not loaded")
try:
X = np.array(request.features, dtype=np.float32).reshape(1, -1)
if X.shape[1] != NUM_FEATURES:
raise ValueError(f"Expected {NUM_FEATURES} features, got {X.shape[1]}")
prob = float(clf.predict_proba(X)[0][1])
return SinglePredictionResponse(prediction=1 if prob >= threshold else 0, probability=round(prob, 4))
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@app.post("/predict/batch", response_model=BatchResponse)
def predict_batch(request: BatchRequest):
"""Legacy: batch raw feature vectors."""
if clf is None:
raise HTTPException(status_code=503, detail="Model not loaded")
try:
X = np.array(request.instances, dtype=np.float32)
if X.ndim != 2 or X.shape[1] != NUM_FEATURES:
raise ValueError(f"Expected shape (n, {NUM_FEATURES}), got {X.shape}")
probas = clf.predict_proba(X)[:, 1]
preds = (probas >= threshold).astype(int)
return BatchResponse(predictions=[
BatchDayPred(day=i + 1, prediction=int(preds[i]), probability=round(float(probas[i]), 4))
for i in range(len(probas))
])
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))