# ML_MODEL_PLEASE / app.py
# Uploaded by Harshilforworks ("Upload app.py", commit 2a79da1)
"""
MediGuard Disease Prediction API
FastAPI application for Hugging Face Spaces deployment
"""
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import List, Dict, Any
import numpy as np
import joblib
from pathlib import Path
import logging
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Directory holding the joblib artifacts (rf_model.pkl, scaler.pkl, ...).
# Anchored to this file instead of the process CWD so the models are found
# no matter which directory the server is launched from.
MODEL_DIR = Path(__file__).resolve().parent

# Model artifacts; None until the startup lifespan hook loads them.
rf_model: Any = None
nn_model: Any = None
meta_model: Any = None
scaler: Any = None
label_encoder: Any = None
feature_cols: Any = None

# Disease labels the backend accepts. Startup fails if the label encoder
# contains any class outside this set. frozenset: read-only constant.
BACKEND_ALLOWED_DISEASES = frozenset({
    "Anemia", "Prediabetes", "Diabetes", "Severe Inflammation",
    "Thrombocytopenia", "Obesity", "IronDeficiencyAnemia",
    "ThalassemiaMajorLike", "ThalassemiaTrait", "KidneyImpairment",
    # NOTE(review): "Thromboc" looks like a truncated duplicate of
    # "Thrombocytopenia"; presumably it matches a label in the trained
    # encoder, so it is kept — confirm before removing.
    "Thromboc", "MetabolicSyndrome", "HyperthyroidismLike",
    "CoronaryArteryDisease", "Hypertension", "ArrhythmiaRisk",
    "Dyslipidemia", "Hepatitis", "NAFLD", "InfectionInflammation",
    "Polycythemia", "ACS", "Healthy",
})
# Pydantic models
class PatientData(BaseModel):
    """
    Patient biomarker data for prediction.

    Multi-word biomarkers are declared with a spaced alias
    (e.g. "White Blood Cells"). Because ``populate_by_name`` is enabled, a
    request may use either the alias or the Python field name.
    """

    # Pydantic v2 configuration; the nested ``class Config`` form is
    # deprecated since v2 and emits a deprecation warning.
    model_config = {"populate_by_name": True}

    Glucose: float
    Cholesterol: float
    Hemoglobin: float
    Platelets: float
    White_Blood_Cells: float = Field(..., alias="White Blood Cells")
    Red_Blood_Cells: float = Field(..., alias="Red Blood Cells")
    Hematocrit: float
    Mean_Corpuscular_Volume: float = Field(..., alias="Mean Corpuscular Volume")
    Mean_Corpuscular_Hemoglobin: float = Field(..., alias="Mean Corpuscular Hemoglobin")
    Mean_Corpuscular_Hemoglobin_Concentration: float = Field(..., alias="Mean Corpuscular Hemoglobin Concentration")
    Insulin: float
    BMI: float
    Systolic_Blood_Pressure: float = Field(..., alias="Systolic Blood Pressure")
    Diastolic_Blood_Pressure: float = Field(..., alias="Diastolic Blood Pressure")
    Triglycerides: float
    HbA1c: float
    LDL_Cholesterol: float = Field(..., alias="LDL Cholesterol")
    HDL_Cholesterol: float = Field(..., alias="HDL Cholesterol")
    ALT: float
    AST: float
    Heart_Rate: float = Field(..., alias="Heart Rate")
    Creatinine: float
    Troponin: float
    C_reactive_Protein: float = Field(..., alias="C-reactive Protein")
class TopPrediction(BaseModel):
    """One ranked entry of ``PredictionResponse.top_5_predictions``."""

    disease: str
    probability: float


class PredictionResponse(BaseModel):
    """
    Response model for disease prediction.

    ``top_5_predictions`` entries pair a string (``disease``) with a float
    (``probability``). They are typed with ``TopPrediction`` rather than
    ``Dict[str, float]``, which would reject the disease name and fail
    response validation. The JSON shape is unchanged.
    """

    prediction: str
    confidence: float
    top_5_predictions: List[TopPrediction]
    raw_values: Dict[str, float]
    model_info: Dict[str, Any]


class HealthResponse(BaseModel):
    """Health check response: service status, model-load flag and feature count."""

    status: str
    model_loaded: bool
    feature_count: int
from contextlib import asynccontextmanager
from typing import AsyncGenerator
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """
    Load the model artifacts on startup and log on shutdown.

    Each artifact is read from ``MODEL_DIR`` with ``joblib.load`` into the
    module-level globals used by the request handlers. Then the label
    encoder's classes are checked against ``BACKEND_ALLOWED_DISEASES``.

    Raises:
        ValueError: if the label encoder contains classes that are not in
            ``BACKEND_ALLOWED_DISEASES``.
        Exception: any loading error is logged with its traceback and
            re-raised, which aborts application startup.
    """
    global rf_model, nn_model, meta_model, scaler, label_encoder, feature_cols
    # Startup
    try:
        logger.info("Loading models...")
        # joblib.load unpickles: only ever point it at artifacts shipped with
        # the app, never at user-supplied paths.
        rf_model = joblib.load(MODEL_DIR / "rf_model.pkl")
        nn_model = joblib.load(MODEL_DIR / "nn_model.pkl")
        meta_model = joblib.load(MODEL_DIR / "meta_model.pkl")
        scaler = joblib.load(MODEL_DIR / "scaler.pkl")
        label_encoder = joblib.load(MODEL_DIR / "label_encoder.pkl")
        feature_cols = joblib.load(MODEL_DIR / "feature_cols.pkl")
        logger.info("✓ Models loaded successfully!")
        logger.info("✓ Feature count: %d", len(feature_cols))
        logger.info("✓ Classes: %s", list(label_encoder.classes_))

        # Refuse to start if the encoder can emit a label the backend rejects.
        invalid_classes = set(label_encoder.classes_) - BACKEND_ALLOWED_DISEASES
        if invalid_classes:
            logger.error("Invalid classes found: %s", invalid_classes)
            raise ValueError(
                f"Model contains invalid disease classes: {sorted(invalid_classes)}"
            )
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("❌ Error loading models: %s", e)
        raise
    yield
    # Shutdown (cleanup if needed)
    logger.info("Shutting down...")
# Initialize FastAPI app with lifespan
app = FastAPI(
    title="MediGuard Disease Prediction API",
    description="AI-powered disease prediction using stacking ensemble",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS middleware.
# allow_credentials is False because no endpoint in this app reads cookies
# or auth headers. Combining credentials with a wildcard origin is
# disallowed by the CORS spec. Starlette handles that combination by echoing
# back whatever Origin the request sent, which would give every site
# credentialed access.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
def predict_disease(patient_features: np.ndarray):
    """
    Predict disease using the stacking ensemble.

    Both base learners (``rf_model`` and ``nn_model``) score the scaled
    feature row. Their class-probability rows are concatenated and passed to
    ``meta_model``, whose output gives the final prediction.

    Args:
        patient_features: 1-D array of biomarker values, ordered like
            ``feature_cols``.

    Returns:
        Tuple of (disease, confidence, top_5_predictions):
        - disease: the predicted disease name.
        - confidence: the meta-model probability of that disease.
        - top_5_predictions: up to five
          ``{"disease": str, "probability": float}`` dicts, most likely first.

    Raises:
        ValueError: if the number of features differs from ``feature_cols``.
    """
    # Validate features
    if len(patient_features) != len(feature_cols):
        raise ValueError(
            f"Expected {len(feature_cols)} features, got {len(patient_features)}"
        )

    # Scale features
    X_scaled = scaler.transform([patient_features]).astype(np.float32)

    # Meta-features: the base learners' probability rows side by side
    X_meta = np.hstack([
        rf_model.predict_proba(X_scaled),
        nn_model.predict_proba(X_scaled),
    ])

    # Final prediction from the meta-learner (an encoded label, as the
    # inverse_transform call below shows)
    y_pred = meta_model.predict(X_meta)[0]
    y_proba = meta_model.predict_proba(X_meta)[0]

    # predict_proba columns follow meta_model.classes_ (the encoded labels it
    # was fitted on). Those equal the column positions only if every class
    # appeared in its training data. So map columns to labels explicitly
    # instead of using a label as a column index (or vice versa).
    fitted_labels = getattr(meta_model, "classes_", None)
    col_labels = (
        np.arange(len(y_proba)) if fitted_labels is None else np.asarray(fitted_labels)
    )
    pred_col = int(np.flatnonzero(col_labels == y_pred)[0])

    disease = str(label_encoder.inverse_transform([y_pred])[0])
    confidence = float(y_proba[pred_col])

    # Top 5 columns by probability, highest first; decode all names at once
    top_cols = np.argsort(y_proba)[-5:][::-1]
    top_names = label_encoder.inverse_transform(col_labels[top_cols])
    top_5 = [
        {"disease": str(name), "probability": float(y_proba[col])}
        for name, col in zip(top_names, top_cols)
    ]
    return disease, confidence, top_5
@app.get("/")
async def root():
    """Describe the service: its name, version and a map of its endpoints."""
    endpoint_map = {
        "health": "/api/health",
        "predict": "/api/predict (POST)",
        "features": "/api/features",
        "diseases": "/api/diseases",
        "docs": "/docs",
    }
    return {
        "message": "MediGuard Disease Prediction API",
        "version": "1.0.0",
        "endpoints": endpoint_map,
    }
@app.get("/api/health", response_model=HealthResponse)
async def health_check():
    """
    Report whether the models are loaded.

    ``status`` is "healthy" once ``rf_model`` has been loaded and
    "not_ready" before that. ``feature_count`` is ``len(feature_cols)``, or 0
    while it has not been loaded.
    """
    models_ready = rf_model is not None
    return HealthResponse(
        status="healthy" if models_ready else "not_ready",
        model_loaded=models_ready,
        # Explicit None test: feature_cols comes from joblib and, if it was
        # saved as an ndarray or pandas Index, its truth value is ambiguous
        # and raises.
        feature_count=len(feature_cols) if feature_cols is not None else 0,
    )
@app.get("/api/features")
async def get_features():
    """
    List the biomarker names the model expects, in model input order.

    Raises:
        HTTPException: 503 if the models have not been loaded yet.
    """
    if feature_cols is None:
        raise HTTPException(status_code=503, detail="Models not loaded")
    # Plain list of str so the payload serialises even if feature_cols was
    # pickled as a NumPy array or pandas Index.
    features = [str(name) for name in feature_cols]
    return {
        "features": features,
        "count": len(features),
        "example": "POST a JSON object keyed by these feature names to /api/predict",
    }
@app.post("/api/predict", response_model=PredictionResponse)
def predict(patient_data: PatientData):
    """
    Predict disease from patient biomarker data.

    Declared with plain ``def`` because the body is blocking, CPU-bound model
    inference. FastAPI runs sync path operations in its threadpool, so
    inference no longer stalls the event loop for other requests.

    Args:
        patient_data: PatientData object with biomarker key-value pairs.

    Returns:
        PredictionResponse with predicted disease and confidence.

    Raises:
        HTTPException:
            503 if the models are not loaded.
            400 for invalid input (missing feature, non-finite value, wrong
                feature count).
            500 for any other prediction failure.
    """
    # Check if models are loaded
    if rf_model is None:
        raise HTTPException(
            status_code=503,
            detail="Models not loaded. Please wait for startup to complete."
        )
    try:
        # Keys are the field aliases (e.g. "White Blood Cells"). The lookup
        # below expects feature_cols to use those same names and reports any
        # mismatch as a missing feature.
        input_dict = patient_data.model_dump(by_alias=True)

        # Build the feature vector in the order expected by feature_cols
        features = []
        for feature_name in feature_cols:
            if feature_name not in input_dict:
                raise ValueError(f"Missing feature: {feature_name}")
            features.append(float(input_dict[feature_name]))
        features = np.array(features, dtype=np.float32)

        # Python's json module and pydantic floats accept NaN/Infinity by
        # default, and float32 conversion overflows huge values to inf.
        # Reject them as a client error before they reach the models.
        if not np.all(np.isfinite(features)):
            raise ValueError("All biomarker values must be finite numbers")

        disease, confidence, top_5 = predict_disease(features)
        logger.info("Prediction: %s (%.2f%%)", disease, confidence * 100)

        return PredictionResponse(
            prediction=disease,
            confidence=confidence,
            top_5_predictions=top_5,
            raw_values=input_dict,
            model_info={
                "base_models": 2,  # rf_model and nn_model
                "features_used": len(feature_cols),
                "meta_input_shape": [1, len(label_encoder.classes_) * 2],
                "n_classes": len(label_encoder.classes_)
            }
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # logger.exception records the traceback for server-side debugging
        logger.exception("Prediction error: %s", e)
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e
@app.get("/api/diseases", response_model=Dict[str, Any])
@app.get("/diseases", response_model=Dict[str, Any])
async def get_diseases():
    """
    Get list of all possible diseases the model can predict.

    Served at ``/api/diseases`` (the path listed by the root endpoint) and,
    for existing clients, at the original ``/diseases``. The response model
    is ``Dict[str, Any]`` because the payload mixes a list (``diseases``)
    with an int (``count``).

    Raises:
        HTTPException: 503 if the models have not been loaded yet.
    """
    if label_encoder is None:
        raise HTTPException(status_code=503, detail="Models not loaded")
    diseases = [str(name) for name in label_encoder.classes_]
    return {
        "diseases": diseases,
        "count": len(diseases),
    }
# For local testing
if __name__ == "__main__":
    from uvicorn import run

    # Listen on every interface at port 7860 (presumably the port the
    # Hugging Face Space expects — confirm before changing).
    run(app, port=7860, host="0.0.0.0")