Spaces:
Sleeping
Sleeping
File size: 5,513 Bytes
6e2eb3e 75e8c0c ab6fe11 96b173e 6e2eb3e 75e8c0c d1a95be 6e2eb3e 75e8c0c d1a95be 6e2eb3e ab6fe11 6e2eb3e 33f5254 ab6fe11 6e2eb3e 75e8c0c 6e2eb3e 96b173e 6e2eb3e 96b173e 47e0ff5 96b173e 6e2eb3e 47e0ff5 6e2eb3e 96b173e 6e2eb3e 96b173e 6e2eb3e 96b173e 6e2eb3e 96b173e 6e2eb3e 47e0ff5 6e2eb3e 96b173e 6e2eb3e 47e0ff5 6e2eb3e 96b173e 47e0ff5 96b173e 47e0ff5 96b173e 47e0ff5 96b173e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import xgboost as xgb
import numpy as np
import pickle
from huggingface_hub import hf_hub_download
import os
import sys
from typing import List, Union
# FastAPI application object; the startup hook and every route below register on it.
app = FastAPI(title="Headache Predictor API")
# Load model at startup
# Stays None until the startup hook assigns the loaded model; the endpoints
# treat None as "model not loaded" (503 from the predict routes, False in /health).
model = None
class _ThresholdedModel:
    """Classifier wrapper whose ``predict`` applies a stored decision threshold.

    The pickled artifact may carry an ``optimal_threshold`` next to the
    estimator. Keeping only the bare estimator would drop that value, so the
    estimator is wrapped and ``predict`` labels a row 1 when its class-1
    probability is at or above the threshold.
    """

    def __init__(self, estimator, threshold):
        self._estimator = estimator
        self.threshold = float(threshold)

    def predict_proba(self, features):
        """Return the wrapped estimator's class probabilities unchanged."""
        return self._estimator.predict_proba(features)

    def predict(self, features):
        """Return 0/1 labels obtained by thresholding the class-1 probability."""
        # Column 1 is read as the positive class, the same convention the
        # prediction endpoints use when they report ``probability``.
        positive_probability = self._estimator.predict_proba(features)[:, 1]
        return (positive_probability >= self.threshold).astype(int)


# NOTE(review): FastAPI documents ``on_event`` as deprecated in favour of
# lifespan handlers; kept here so the registration behaviour is unchanged.
@app.on_event("startup")
async def load_model():
    """Download the pickled model from the Hugging Face Hub and publish it in ``model``.

    The pickle may be either the estimator itself or a dict holding the
    estimator under ``'model'`` and, optionally, a tuned ``'optimal_threshold'``.
    When a threshold is present the estimator is wrapped so that its
    ``predict`` applies that threshold instead of the estimator's default.

    Loading is best-effort: any failure is printed with its traceback and
    ``model`` is left as ``None`` so the application still starts and the
    endpoints can report that no model is available.
    """
    global model
    try:
        # Set cache directory to writable location
        cache_dir = "/tmp/hf_cache"
        os.makedirs(cache_dir, exist_ok=True)
        # Get HF token from environment (set as Space secret)
        hf_token = os.environ.get("HF_TOKEN")
        model_path = hf_hub_download(
            repo_id="emp-admin/headache-predictor-xgboost",
            filename="model.pkl",
            cache_dir=cache_dir,
            token=hf_token,  # Use token for private repo access
        )
        with open(model_path, 'rb') as f:
            # SECURITY: unpickling executes code embedded in the file, so this
            # is only safe while the repository above is fully trusted.
            model_data = pickle.load(f)
        # Handle both dict format and raw model
        if isinstance(model_data, dict):
            estimator = model_data['model']
            threshold = model_data.get('optimal_threshold')
            if threshold is not None:
                # Keep the tuned threshold attached to the model; otherwise it
                # would only ever appear in the log line below.
                estimator = _ThresholdedModel(estimator, threshold)
            # Assign the global only once the object is fully built, so a
            # failure above never leaves a half-configured model behind.
            model = estimator
            print(f"✅ Model loaded successfully (threshold: {model_data.get('optimal_threshold', 0.5)})")
        else:
            model = model_data
            print("✅ Model loaded successfully")
    except Exception as e:
        # Deliberately broad: a failed load must not stop the app from starting.
        print(f"❌ Error loading model: {e}")
        import traceback
        traceback.print_exc()
# Request body for POST /predict: one flat feature vector for a single day.
# The example published on the root endpoint lists 37 values.
class SinglePredictionRequest(BaseModel):
    features: List[float]
# Request body for POST /predict/batch: one feature vector per day.
# The batch endpoint rejects input that does not form a 2-D array.
class BatchPredictionRequest(BaseModel):
    instances: List[List[float]]
# One entry of a batch response.
class DayPrediction(BaseModel):
    day: int  # 1-based position of the instance within the batch request
    prediction: int  # Predicted class label
    probability: float  # Probability of HEADACHE (class 1), regardless of prediction
# Response body of POST /predict.
class SinglePredictionResponse(BaseModel):
    prediction: int  # Predicted class label
    probability: float  # Probability of HEADACHE (class 1), regardless of prediction
# Response body of POST /predict/batch: one DayPrediction per submitted
# instance, in request order.
class BatchPredictionResponse(BaseModel):
    predictions: List[DayPrediction]
@app.get("/")
def read_root():
    """Return a static description of the API: status, routes and example payloads."""
    endpoint_summary = {
        "predict": "/predict - Single day prediction",
        "predict_batch": "/predict/batch - 7-day forecast",
        "health": "/health",
    }

    # Sample body for the single-day route.
    single_example = {
        "url": "/predict",
        "body": {"features": [1, 0, 0, 0, 1, 0, 1005.0, -9.5, 85.0, 15.5, 64.0, 5.5, 41.0, 0.0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 10, 40, 4, 7.0, 50.0, 60.0, 3.5, 1.5, 6.8]},
    }

    # Placeholder body for the batch route; the strings stand in for real vectors.
    batch_example = {
        "url": "/predict/batch",
        "body": {"instances": [["array of 37 features for day 1"], ["array for day 2"], "..."]},
    }

    return {
        "message": "Headache Predictor API",
        "status": "running",
        "endpoints": endpoint_summary,
        "examples": {
            "single": single_example,
            "batch": batch_example,
        },
    }
@app.get("/health")
def health_check():
    """Report liveness and whether the module-level model has been set."""
    model_is_loaded = model is not None
    # "status" is always "healthy"; only "model_loaded" reflects the model state.
    return {"status": "healthy", "model_loaded": model_is_loaded}
@app.post("/predict", response_model=SinglePredictionResponse)
def predict(request: SinglePredictionRequest):
    """Predict headache risk for a single day.

    Args:
        request: Body carrying one flat feature vector in ``features``.

    Returns:
        SinglePredictionResponse with the label from ``model.predict`` and the
        class-1 (headache) probability from ``model.predict_proba``.

    Raises:
        HTTPException: 503 when no model is loaded; 400 when converting the
            input or running the model fails.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        # The model expects a 2-D array, so present the vector as a single row.
        features = np.array(request.features).reshape(1, -1)
        # Always return probability of headache (class 1)
        headache_probability = float(model.predict_proba(features)[0][1])
        # The former `isinstance(model, dict)` threshold branch was unreachable:
        # `predict_proba` has just been called on `model`, so it cannot be a
        # dict. The label therefore always comes from the model itself.
        prediction = model.predict(features)[0]
        return SinglePredictionResponse(
            prediction=int(prediction),
            probability=headache_probability,
        )
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Prediction error: {str(e)}") from e
@app.post("/predict/batch", response_model=BatchPredictionResponse)
def predict_batch(request: BatchPredictionRequest):
    """Predict headache risk for multiple days (7-day forecast).

    Args:
        request: Body carrying one feature vector per day in ``instances``.

    Returns:
        BatchPredictionResponse with one DayPrediction per instance, numbered
        from 1 in request order. Each entry holds the label from
        ``model.predict`` and the class-1 (headache) probability.

    Raises:
        HTTPException: 503 when no model is loaded; 400 when the instances do
            not form a 2-D array or running the model fails.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        # Convert all instances to numpy array
        features = np.array(request.instances)
        if features.ndim != 2:
            # Also covers an empty `instances` list, which yields a 1-D array.
            raise ValueError(f"Expected 2D array, got shape {features.shape}")
        # Score the whole batch once instead of calling `predict` per row.
        # The former per-row `isinstance(model, dict)` threshold branch was
        # unreachable: `predict_proba` is called on `model` first, so it
        # cannot be a dict.
        probabilities = model.predict_proba(features)
        labels = model.predict(features)
        results = [
            DayPrediction(
                day=day,
                prediction=int(label),
                # Always use probability of headache (class 1)
                probability=float(prob_array[1]),
            )
            for day, (prob_array, label) in enumerate(zip(probabilities, labels), 1)
        ]
        return BatchPredictionResponse(predictions=results)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Batch prediction error: {str(e)}") from e
|