File size: 5,513 Bytes
6e2eb3e
 
 
 
 
 
 
75e8c0c
ab6fe11
96b173e
6e2eb3e
 
 
 
 
 
 
 
 
 
75e8c0c
 
 
 
d1a95be
 
 
6e2eb3e
 
75e8c0c
d1a95be
 
6e2eb3e
ab6fe11
6e2eb3e
33f5254
 
 
 
 
 
 
 
 
ab6fe11
6e2eb3e
 
75e8c0c
 
6e2eb3e
96b173e
 
6e2eb3e
96b173e
 
 
 
 
 
47e0ff5
96b173e
 
6e2eb3e
47e0ff5
6e2eb3e
96b173e
 
 
6e2eb3e
 
 
 
 
 
96b173e
 
6e2eb3e
96b173e
 
 
 
 
 
 
 
 
 
6e2eb3e
 
 
 
 
 
 
 
 
 
96b173e
 
 
6e2eb3e
 
 
 
 
 
 
47e0ff5
 
 
 
 
 
 
 
 
 
 
 
6e2eb3e
96b173e
6e2eb3e
47e0ff5
6e2eb3e
 
 
96b173e
 
 
 
 
 
 
 
 
 
 
 
 
 
47e0ff5
96b173e
 
 
 
47e0ff5
 
 
 
 
 
 
 
 
 
 
96b173e
 
47e0ff5
 
96b173e
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import xgboost as xgb
import numpy as np
import pickle
from huggingface_hub import hf_hub_download
import os
import sys
from typing import List, Union

app = FastAPI(title="Headache Predictor API")

# Populated by the startup handler below; stays None until the model has
# finished downloading/unpickling (the predict endpoints return 503 meanwhile).
model = None

@app.on_event("startup")
async def load_model():
    """Download and unpickle the XGBoost model from the Hugging Face Hub.

    Runs once at application startup. On success the module-level ``model``
    is set; on any failure it stays ``None`` so the app still starts and
    /health can report ``model_loaded: false`` instead of crash-looping.

    NOTE(review): ``@app.on_event`` is deprecated in favor of FastAPI
    lifespan handlers; kept here to avoid changing the app's interface.
    """
    global model
    try:
        # HF Spaces only guarantee /tmp is writable, so cache downloads there.
        cache_dir = "/tmp/hf_cache"
        os.makedirs(cache_dir, exist_ok=True)

        # Token (set as a Space secret) is required for private-repo access.
        hf_token = os.environ.get("HF_TOKEN")

        model_path = hf_hub_download(
            repo_id="emp-admin/headache-predictor-xgboost",
            filename="model.pkl",
            cache_dir=cache_dir,
            token=hf_token,  # Use token for private repo access
        )

        # NOTE(security): pickle.load executes arbitrary code from the file;
        # acceptable only because the repo is ours and access is token-gated.
        with open(model_path, 'rb') as f:
            model_data = pickle.load(f)

        # The artifact is either {'model': ..., 'optimal_threshold': ...}
        # or the raw estimator itself.
        if isinstance(model_data, dict):
            model = model_data['model']
            # Bug fix: the tuned threshold used to be discarded here (only
            # logged), so the `optimal_threshold` checks in the predict
            # endpoints could never fire. Carry it on the model object so
            # prediction code can find it via getattr().
            model.optimal_threshold = model_data.get('optimal_threshold', 0.5)
            print(f"✅ Model loaded successfully (threshold: {model_data.get('optimal_threshold', 0.5)})")
        else:
            model = model_data
            print("✅ Model loaded successfully")

    except Exception as e:
        # Deliberate broad catch: log and keep the server up; /health
        # surfaces the failure via model_loaded=False.
        print(f"❌ Error loading model: {e}")
        import traceback
        traceback.print_exc()

class SinglePredictionRequest(BaseModel):
    """Request body for /predict: one day's feature vector."""
    # Ordered feature values for a single day; the root-endpoint example
    # shows a 37-element vector — length is not validated here.
    features: List[float]

class BatchPredictionRequest(BaseModel):
    """Request body for /predict/batch: one feature vector per day."""
    # Outer list = days (e.g. a 7-day forecast); inner list = that day's features.
    instances: List[List[float]]

class DayPrediction(BaseModel):
    """Per-day result inside a batch response."""
    # 1-based day index within the submitted batch.
    day: int
    # Hard class label: 1 = headache predicted, 0 = no headache.
    prediction: int
    probability: float  # Probability of HEADACHE (class 1), regardless of prediction

class SinglePredictionResponse(BaseModel):
    """Response body for /predict."""
    # Hard class label: 1 = headache predicted, 0 = no headache.
    prediction: int
    probability: float  # Probability of HEADACHE (class 1), regardless of prediction

class BatchPredictionResponse(BaseModel):
    """Response body for /predict/batch: one entry per submitted day."""
    predictions: List[DayPrediction]

@app.get("/")
def read_root():
    """Landing endpoint: service status plus a quick map of the API surface."""
    endpoint_map = {
        "predict": "/predict - Single day prediction",
        "predict_batch": "/predict/batch - 7-day forecast",
        "health": "/health",
    }
    sample_features = [
        1, 0, 0, 0, 1, 0, 1005.0, -9.5, 85.0, 15.5, 64.0, 5.5, 41.0, 0.0,
        1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 10, 40, 4, 7.0, 50.0,
        60.0, 3.5, 1.5, 6.8,
    ]
    example_calls = {
        "single": {
            "url": "/predict",
            "body": {"features": sample_features},
        },
        "batch": {
            "url": "/predict/batch",
            "body": {"instances": [["array of 37 features for day 1"], ["array for day 2"], "..."]},
        },
    }
    return {
        "message": "Headache Predictor API",
        "status": "running",
        "endpoints": endpoint_map,
        "examples": example_calls,
    }

@app.get("/health")
def health_check():
    """Liveness probe: reports whether the model has finished loading."""
    loaded = model is not None
    return {"status": "healthy", "model_loaded": loaded}

@app.post("/predict", response_model=SinglePredictionResponse)
def predict(request: SinglePredictionRequest):
    """Predict headache risk for a single day.

    Returns the hard label (0/1) and the probability of class 1 (headache),
    regardless of which class is predicted.

    Raises:
        HTTPException 503: model not loaded yet.
        HTTPException 400: malformed feature vector or model error.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # The model expects a 2D array: one row of features.
        features = np.array(request.features).reshape(1, -1)

        # predict_proba yields [P(class 0), P(class 1)] for the single row;
        # always report the class-1 (headache) probability.
        prob_array = model.predict_proba(features)[0]
        headache_probability = float(prob_array[1])

        # Bug fix: this used to test `isinstance(model, dict)`, which was
        # always False because startup unwraps the dict — the tuned
        # threshold was never applied. The threshold now travels on the
        # model object; fall back to the model's default predict otherwise.
        threshold = getattr(model, 'optimal_threshold', None)
        if threshold is not None:
            prediction = 1 if headache_probability >= threshold else 0
        else:
            prediction = model.predict(features)[0]

        return SinglePredictionResponse(
            prediction=int(prediction),
            probability=headache_probability
        )
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Prediction error: {str(e)}")

@app.post("/predict/batch", response_model=BatchPredictionResponse)
def predict_batch(request: BatchPredictionRequest):
    """Predict headache risk for multiple days (e.g. a 7-day forecast).

    Returns one DayPrediction per input row, numbered from day 1, each with
    the hard label and the probability of class 1 (headache).

    Raises:
        HTTPException 503: model not loaded yet.
        HTTPException 400: malformed instances (not 2D) or model error.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # All rows must form a rectangular 2D matrix (days x features).
        features = np.array(request.instances)
        if features.ndim != 2:
            raise ValueError(f"Expected 2D array, got shape {features.shape}")

        # Class probabilities for every day in one vectorized call.
        probabilities = model.predict_proba(features)

        # Bug fix: the old `isinstance(model, dict)` check was always False
        # (startup unwraps the dict), so a tuned threshold was never used.
        # The threshold lookup is also hoisted out of the per-day loop.
        threshold = getattr(model, 'optimal_threshold', None)
        if threshold is None:
            # One vectorized predict for all rows, instead of the previous
            # per-row model.predict(features[i-1:i]) inside the loop.
            hard_predictions = model.predict(features)

        results = []
        for i, prob_array in enumerate(probabilities, 1):
            # Always report the class-1 (headache) probability.
            headache_probability = float(prob_array[1])

            if threshold is not None:
                prediction = 1 if headache_probability >= threshold else 0
            else:
                prediction = hard_predictions[i - 1]

            results.append(DayPrediction(
                day=i,
                prediction=int(prediction),
                probability=headache_probability
            ))

        return BatchPredictionResponse(predictions=results)

    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Batch prediction error: {str(e)}")