Spaces:
Sleeping
Sleeping
File size: 9,575 Bytes
c375045 2a79da1 c375045 fc7a995 c375045 bae8d6e c375045 ccea93f c375045 2a79da1 c375045 2a79da1 c375045 2a79da1 c375045 bae8d6e c375045 bae8d6e c375045 bae8d6e c375045 2a79da1 c375045 2a79da1 c375045 2a79da1 c375045 2a79da1 c375045 bae8d6e c375045 2a79da1 c375045 2a79da1 c375045 2a79da1 c375045 2a79da1 c375045 2a79da1 c375045 2a79da1 c375045 2a79da1 c375045 2a79da1 c375045 2a79da1 c375045 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 |
"""
MediGuard Disease Prediction API
FastAPI application for Hugging Face Spaces deployment
"""
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import List, Dict, Any
import numpy as np
import joblib
from pathlib import Path
import logging
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Directory holding the serialized model artifacts (*.pkl). This is a
# relative path, so it is resolved against the process working directory:
# the server must be started from the folder that contains the artifacts.
MODEL_DIR = Path(".")

# Model artifacts, populated once by the startup lifespan handler. They stay
# None until loading succeeds; the endpoints test for None and answer 503.
rf_model: Any = None        # base learner loaded from rf_model.pkl
nn_model: Any = None        # base learner loaded from nn_model.pkl
meta_model: Any = None      # stacking meta-learner loaded from meta_model.pkl
scaler: Any = None          # feature scaler applied before the base learners
label_encoder: Any = None   # maps encoded class ids <-> disease labels
feature_cols: Any = None    # ordered feature names the models expect

# Disease labels the backend accepts. At startup every class known to the
# label encoder must be in this set, otherwise model loading is aborted.
# Immutable so no request handler can accidentally widen it at runtime.
BACKEND_ALLOWED_DISEASES = frozenset({
    "Anemia", "Prediabetes", "Diabetes", "Severe Inflammation",
    "Thrombocytopenia", "Obesity", "IronDeficiencyAnemia",
    "ThalassemiaMajorLike", "ThalassemiaTrait", "KidneyImpairment",
    # NOTE(review): "Thromboc" looks like a truncated "Thrombocytopenia";
    # kept verbatim because the trained label encoder may contain it — confirm.
    "Thromboc", "MetabolicSyndrome", "HyperthyroidismLike",
    "CoronaryArteryDisease", "Hypertension", "ArrhythmiaRisk",
    "Dyslipidemia", "Hepatitis", "NAFLD", "InfectionInflammation",
    "Polycythemia", "ACS", "Healthy",
})
# Pydantic models
class PatientData(BaseModel):
    """Patient biomarker data for prediction.

    Every biomarker is parsed as a float. Fields whose wire name contains
    spaces or hyphens are declared with an ``alias`` (e.g.
    ``"White Blood Cells"``). Because ``populate_by_name`` is enabled,
    callers may send either the alias or the Python field name.
    """

    # Pydantic v2 configuration; replaces the deprecated inner ``class Config``.
    # A plain dict is equivalent to ``ConfigDict(populate_by_name=True)``.
    model_config = {"populate_by_name": True}

    Glucose: float
    Cholesterol: float
    Hemoglobin: float
    Platelets: float
    White_Blood_Cells: float = Field(..., alias="White Blood Cells")
    Red_Blood_Cells: float = Field(..., alias="Red Blood Cells")
    Hematocrit: float
    Mean_Corpuscular_Volume: float = Field(..., alias="Mean Corpuscular Volume")
    Mean_Corpuscular_Hemoglobin: float = Field(..., alias="Mean Corpuscular Hemoglobin")
    Mean_Corpuscular_Hemoglobin_Concentration: float = Field(..., alias="Mean Corpuscular Hemoglobin Concentration")
    Insulin: float
    BMI: float
    Systolic_Blood_Pressure: float = Field(..., alias="Systolic Blood Pressure")
    Diastolic_Blood_Pressure: float = Field(..., alias="Diastolic Blood Pressure")
    Triglycerides: float
    HbA1c: float
    LDL_Cholesterol: float = Field(..., alias="LDL Cholesterol")
    HDL_Cholesterol: float = Field(..., alias="HDL Cholesterol")
    ALT: float
    AST: float
    Heart_Rate: float = Field(..., alias="Heart Rate")
    Creatinine: float
    Troponin: float
    C_reactive_Protein: float = Field(..., alias="C-reactive Protein")
class TopPrediction(BaseModel):
    """One ranked candidate: a disease label and its meta-model probability."""

    disease: str
    probability: float


class PredictionResponse(BaseModel):
    """Response model for disease prediction.

    ``top_5_predictions`` entries are ``{"disease": <label>, "probability":
    <float>}``. They are typed with ``TopPrediction`` rather than
    ``Dict[str, float]`` because the disease label is a string and would fail
    float validation.
    """

    prediction: str
    confidence: float
    top_5_predictions: List[TopPrediction]
    raw_values: Dict[str, float]
    model_info: Dict[str, Any]


class HealthResponse(BaseModel):
    """Health check response returned by the ``/api/health`` endpoint."""

    status: str           # "healthy" once models are loaded, else "not_ready"
    model_loaded: bool    # True when the model artifacts are available
    feature_count: int    # number of expected features (0 before loading)
from contextlib import asynccontextmanager
from typing import AsyncGenerator
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Load every model artifact at startup and log on shutdown.

    Populates the module-level ``rf_model``, ``nn_model``, ``meta_model``,
    ``scaler``, ``label_encoder`` and ``feature_cols`` from ``MODEL_DIR``. It
    then checks that every label-encoder class is in
    ``BACKEND_ALLOWED_DISEASES``.

    Raises:
        Exception: any loading or validation error is logged with its
            traceback and re-raised, so the application refuses to start.
    """
    global rf_model, nn_model, meta_model, scaler, label_encoder, feature_cols

    # Startup
    try:
        logger.info("Loading models...")
        # NOTE: joblib.load unpickles arbitrary objects. Only trusted artifacts
        # may be placed in MODEL_DIR.
        rf_model = joblib.load(MODEL_DIR / "rf_model.pkl")
        nn_model = joblib.load(MODEL_DIR / "nn_model.pkl")
        meta_model = joblib.load(MODEL_DIR / "meta_model.pkl")
        scaler = joblib.load(MODEL_DIR / "scaler.pkl")
        label_encoder = joblib.load(MODEL_DIR / "label_encoder.pkl")
        feature_cols = joblib.load(MODEL_DIR / "feature_cols.pkl")

        logger.info("✓ Models loaded successfully!")
        logger.info("✓ Feature count: %d", len(feature_cols))
        logger.info("✓ Classes: %s", list(label_encoder.classes_))

        # Refuse to serve a model that can emit labels the backend rejects.
        invalid_classes = set(label_encoder.classes_) - BACKEND_ALLOWED_DISEASES
        if invalid_classes:
            logger.error("Invalid classes found: %s", invalid_classes)
            raise ValueError("Model contains invalid disease classes")
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("❌ Error loading models: %s", e)
        raise

    yield

    # Shutdown (cleanup if needed)
    logger.info("Shutting down...")
# Build the FastAPI application. The `lifespan` handler loads the models
# before requests are served and logs when the app shuts down.
app = FastAPI(
    lifespan=lifespan,
    title="MediGuard Disease Prediction API",
    version="1.0.0",
    description="AI-powered disease prediction using stacking ensemble",
)

# CORS: every origin, method and header is accepted.
# NOTE(review): with Starlette's CORSMiddleware, wildcard origins combined with
# allow_credentials=True make the server echo the caller's Origin back. No
# cookies/auth are visible in this file, so allow_credentials=False is likely
# safe — confirm with the frontend before changing.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
def predict_disease(patient_features: np.ndarray, top_k: int = 5):
    """Predict a disease with the stacking ensemble.

    The features are scaled and fed to the two base learners (``rf_model``
    and ``nn_model``). Their class-probability vectors are concatenated into
    meta-features, and ``meta_model`` makes the final decision.

    Args:
        patient_features: 1-D sequence of biomarker values ordered like
            ``feature_cols``.
        top_k: Number of ranked candidates to return (default 5).

    Returns:
        Tuple ``(disease, confidence, top_predictions)``:
        - ``disease`` is the decoded label of the meta-model prediction.
        - ``confidence`` is that label's probability.
        - ``top_predictions`` is a list of up to ``top_k``
          ``{"disease": label, "probability": p}`` dicts, most probable first.

    Raises:
        ValueError: if the number of features does not match ``feature_cols``.
    """
    if len(patient_features) != len(feature_cols):
        raise ValueError(
            f"Expected {len(feature_cols)} features, got {len(patient_features)}"
        )

    # Single-row batch, scaled and cast to float32 before the base learners.
    row = np.asarray(patient_features).reshape(1, -1)
    X_scaled = scaler.transform(row).astype(np.float32)

    rf_probs = rf_model.predict_proba(X_scaled)
    nn_probs = nn_model.predict_proba(X_scaled)

    # Meta-features: both base-learner probability vectors side by side.
    X_meta = np.hstack([rf_probs, nn_probs])

    y_pred = meta_model.predict(X_meta)[0]
    y_proba = meta_model.predict_proba(X_meta)[0]

    # predict_proba columns follow meta_model.classes_, which need not be
    # exactly 0..n-1. Map labels <-> columns explicitly instead of using the
    # label as a column index.
    classes = np.asarray(getattr(meta_model, "classes_", np.arange(len(y_proba))))
    pred_col = int(np.flatnonzero(classes == y_pred)[0])

    disease = label_encoder.inverse_transform([y_pred])[0]
    confidence = float(y_proba[pred_col])

    # Highest-probability columns first; decode all labels in one call.
    top_cols = np.argsort(y_proba)[::-1][:max(top_k, 0)]
    top_labels = label_encoder.inverse_transform(classes[top_cols])
    top_predictions = [
        {"disease": label, "probability": float(y_proba[col])}
        for label, col in zip(top_labels, top_cols)
    ]
    return disease, confidence, top_predictions
@app.get("/")
async def root():
    """Describe the service: its name, version and the routes it exposes."""
    endpoint_map = {
        "health": "/api/health",
        "predict": "/api/predict (POST)",
        "features": "/api/features",
        "diseases": "/api/diseases",
        "docs": "/docs",
    }
    return dict(
        message="MediGuard Disease Prediction API",
        version="1.0.0",
        endpoints=endpoint_map,
    )
@app.get("/api/health", response_model=HealthResponse)
async def health_check():
    """Report whether the models are loaded.

    ``status`` is ``"healthy"`` once ``rf_model`` has been loaded and
    ``"not_ready"`` otherwise. ``feature_count`` is 0 until ``feature_cols``
    is available.
    """
    loaded = rf_model is not None
    # Compare with None explicitly. feature_cols is whatever joblib unpickled
    # and may not be a plain list (e.g. a NumPy array), whose truth value is
    # ambiguous and would raise.
    feature_count = 0 if feature_cols is None else len(feature_cols)
    return HealthResponse(
        status="healthy" if loaded else "not_ready",
        model_loaded=loaded,
        feature_count=feature_count,
    )
@app.get("/api/features")
async def get_features():
    """List the feature names the model expects, in model order.

    Raises:
        HTTPException: 503 while the models are not loaded.
    """
    if feature_cols is None:
        raise HTTPException(status_code=503, detail="Models not loaded")
    # Plain ``str`` list so the payload stays JSON-serializable even if the
    # unpickled artifact is not a list of Python strings.
    names = [str(name) for name in feature_cols]
    return {
        "features": names,
        "count": len(names),
        "example": "Use /api/predict (POST) with one JSON field per feature name",
    }
@app.post("/api/predict", response_model=PredictionResponse)
async def predict(patient_data: PatientData):
    """Predict a disease from one patient's biomarker values.

    Args:
        patient_data: Validated biomarker payload. Either the aliases (e.g.
            ``"White Blood Cells"``) or the Python field names are accepted.

    Returns:
        PredictionResponse with the predicted disease, its confidence, the
        top-5 candidates, the echoed input values and ensemble metadata.

    Raises:
        HTTPException: 503 if the models are not loaded yet, 400 for invalid
            input (any ``ValueError`` during feature building or inference),
            500 for any other prediction failure.
    """
    # NOTE(review): inference is CPU-bound work inside an async handler and
    # blocks the event loop while it runs; consider offloading if latency matters.
    if rf_model is None:
        raise HTTPException(
            status_code=503,
            detail="Models not loaded. Please wait for startup to complete."
        )

    # Keys are the wire names (aliases); assumed to match feature_cols entries.
    input_dict = patient_data.model_dump(by_alias=True)

    try:
        missing = [name for name in feature_cols if name not in input_dict]
        if missing:
            raise ValueError(f"Missing feature: {missing[0]}")
        # Build the feature vector in the exact order the models were trained on.
        features = np.array(
            [float(input_dict[name]) for name in feature_cols], dtype=np.float32
        )
        disease, confidence, top_5 = predict_disease(features)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception("Prediction error: %s", e)
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e

    logger.info("Prediction: %s (%.2f%%)", disease, confidence * 100)

    # Built outside the try: a response-schema error is a server bug and must
    # not be reported to the client as bad input (pydantic's ValidationError
    # subclasses ValueError).
    n_classes = len(label_encoder.classes_)
    return PredictionResponse(
        prediction=disease,
        confidence=confidence,
        top_5_predictions=top_5,
        raw_values=input_dict,
        model_info={
            "base_models": 2,  # rf_model and nn_model
            "features_used": len(feature_cols),
            # One probability per class from each of the two base learners.
            "meta_input_shape": [1, n_classes * 2],
            "n_classes": n_classes,
        },
    )
@app.get("/api/diseases", response_model=Dict[str, Any])
@app.get("/diseases", response_model=Dict[str, Any])
async def get_diseases():
    """List every disease label the model can predict.

    Served at ``/api/diseases`` (the path advertised by the root endpoint)
    and at ``/diseases`` for backward compatibility. The response model is
    ``Dict[str, Any]`` because ``count`` is an int, which the previous
    ``Dict[str, List[str]]`` rejected.

    Raises:
        HTTPException: 503 while the models are not loaded.
    """
    if label_encoder is None:
        raise HTTPException(status_code=503, detail="Models not loaded")
    # Plain ``str`` labels so serialization does not depend on NumPy scalar types.
    diseases = [str(label) for label in label_encoder.classes_]
    return {
        "diseases": diseases,
        "count": len(diseases),
    }
# For local testing: serve the app directly with uvicorn on port 7860.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)