"""
MediGuard Disease Prediction API.

FastAPI application for Hugging Face Spaces deployment. Serves a stacking
ensemble (two base learners feeding a meta-learner) that maps a fixed set of
patient biomarkers to a predicted disease label.
"""

import logging
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, AsyncGenerator, Dict, List, Tuple

import joblib
import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, ConfigDict, Field

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Directory the serialized model artifacts are loaded from
MODEL_DIR = Path(".")

# Model artifacts; populated once by `lifespan` at application startup
rf_model = None
nn_model = None
meta_model = None
scaler = None
label_encoder = None
feature_cols = None

# Backend's allowed diseases (for validation of the label encoder's classes)
# NOTE(review): "Thromboc" looks like a truncated name - confirm against the
# backend before removing or renaming it.
BACKEND_ALLOWED_DISEASES = {
    "Anemia",
    "Prediabetes",
    "Diabetes",
    "Severe Inflammation",
    "Thrombocytopenia",
    "Obesity",
    "IronDeficiencyAnemia",
    "ThalassemiaMajorLike",
    "ThalassemiaTrait",
    "KidneyImpairment",
    "Thromboc",
    "MetabolicSyndrome",
    "HyperthyroidismLike",
    "CoronaryArteryDisease",
    "Hypertension",
    "ArrhythmiaRisk",
    "Dyslipidemia",
    "Hepatitis",
    "NAFLD",
    "InfectionInflammation",
    "Polycythemia",
    "ACS",
    "Healthy",
}


class PatientData(BaseModel):
    """Patient biomarker data for prediction.

    Fields whose public name contains spaces or hyphens are exposed through
    an alias; `populate_by_name` lets clients send either the alias or the
    Python attribute name.
    """

    model_config = ConfigDict(populate_by_name=True)

    Glucose: float
    Cholesterol: float
    Hemoglobin: float
    Platelets: float
    White_Blood_Cells: float = Field(..., alias="White Blood Cells")
    Red_Blood_Cells: float = Field(..., alias="Red Blood Cells")
    Hematocrit: float
    Mean_Corpuscular_Volume: float = Field(..., alias="Mean Corpuscular Volume")
    Mean_Corpuscular_Hemoglobin: float = Field(
        ..., alias="Mean Corpuscular Hemoglobin"
    )
    Mean_Corpuscular_Hemoglobin_Concentration: float = Field(
        ..., alias="Mean Corpuscular Hemoglobin Concentration"
    )
    Insulin: float
    BMI: float
    Systolic_Blood_Pressure: float = Field(..., alias="Systolic Blood Pressure")
    Diastolic_Blood_Pressure: float = Field(..., alias="Diastolic Blood Pressure")
    Triglycerides: float
    HbA1c: float
    LDL_Cholesterol: float = Field(..., alias="LDL Cholesterol")
    HDL_Cholesterol: float = Field(..., alias="HDL Cholesterol")
    ALT: float
    AST: float
    Heart_Rate: float = Field(..., alias="Heart Rate")
    Creatinine: float
    Troponin: float
    C_reactive_Protein: float = Field(..., alias="C-reactive Protein")


class PredictionResponse(BaseModel):
    """Response model for disease prediction."""

    # `model_info` starts with pydantic's protected "model_" prefix; clearing
    # the protected namespaces avoids a spurious warning at import time.
    model_config = ConfigDict(protected_namespaces=())

    prediction: str
    confidence: float
    # Each entry is {"disease": <str>, "probability": <float>}, so the values
    # cannot be typed as float only.
    top_5_predictions: List[Dict[str, Any]]
    raw_values: Dict[str, float]
    model_info: Dict[str, Any]


class HealthResponse(BaseModel):
    """Health check response."""

    # See PredictionResponse: `model_loaded` uses the protected prefix too.
    model_config = ConfigDict(protected_namespaces=())

    status: str
    model_loaded: bool
    feature_count: int


def _models_ready() -> bool:
    """Return True once every artifact needed for a prediction is loaded."""
    artifacts = (
        rf_model, nn_model, meta_model, scaler, label_encoder, feature_cols
    )
    return all(artifact is not None for artifact in artifacts)


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Load all model artifacts on startup and log on shutdown.

    Raises:
        ValueError: if the label encoder contains a class that is not in
            BACKEND_ALLOWED_DISEASES.
        Exception: any error raised while loading an artifact is logged and
            re-raised so the application fails to start.
    """
    global rf_model, nn_model, meta_model, scaler, label_encoder, feature_cols

    # Startup
    try:
        logger.info("Loading models...")

        # NOTE: joblib.load unpickles the files; only ship trusted artifacts.
        rf_model = joblib.load(MODEL_DIR / "rf_model.pkl")
        nn_model = joblib.load(MODEL_DIR / "nn_model.pkl")
        meta_model = joblib.load(MODEL_DIR / "meta_model.pkl")
        scaler = joblib.load(MODEL_DIR / "scaler.pkl")
        label_encoder = joblib.load(MODEL_DIR / "label_encoder.pkl")
        feature_cols = joblib.load(MODEL_DIR / "feature_cols.pkl")

        logger.info("✓ Models loaded successfully!")
        logger.info(f"✓ Feature count: {len(feature_cols)}")
        logger.info(f"✓ Classes: {list(label_encoder.classes_)}")

        # Refuse to serve a model that can emit labels the backend rejects
        invalid_classes = set(label_encoder.classes_) - BACKEND_ALLOWED_DISEASES
        if invalid_classes:
            logger.error(f"Invalid classes found: {invalid_classes}")
            raise ValueError("Model contains invalid disease classes")

    except Exception as e:
        logger.error(f"❌ Error loading models: {e}")
        raise

    yield

    # Shutdown (cleanup if needed)
    logger.info("Shutting down...")


# Initialize FastAPI app with lifespan
app = FastAPI(
    title="MediGuard Disease Prediction API",
    description="AI-powered disease prediction using stacking ensemble",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS middleware
# NOTE(review): a wildcard origin combined with allow_credentials=True lets
# any site make credentialed requests. The API uses no cookies or auth here,
# so consider allow_credentials=False or an explicit origin list. Left
# unchanged to avoid breaking existing browser clients.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def predict_disease(
    patient_features: np.ndarray,
) -> Tuple[str, float, List[Dict[str, Any]]]:
    """Predict a disease using the stacking ensemble.

    Args:
        patient_features: 1-D array of biomarker values, ordered like
            `feature_cols`.

    Returns:
        Tuple of (disease, confidence, top_5_predictions), where
        top_5_predictions is a list of
        {"disease": str, "probability": float} dicts sorted by descending
        probability (fewer than five if the model has fewer classes).

    Raises:
        ValueError: if the number of features does not match `feature_cols`.
    """
    # Validate features
    if len(patient_features) != len(feature_cols):
        raise ValueError(
            f"Expected {len(feature_cols)} features, got {len(patient_features)}"
        )

    # Scale features (single sample -> shape (1, n_features))
    sample = np.asarray(patient_features).reshape(1, -1)
    X_scaled = scaler.transform(sample).astype(np.float32)

    # Base learner class probabilities become the meta-learner's input
    rf_probs = rf_model.predict_proba(X_scaled)
    nn_probs = nn_model.predict_proba(X_scaled)
    X_meta = np.hstack([rf_probs, nn_probs])

    # Final prediction from meta-learner
    y_pred = meta_model.predict(X_meta)[0]
    y_proba = meta_model.predict_proba(X_meta)[0]

    # Column i of y_proba belongs to meta_model.classes_[i]. That equals i
    # only when the encoded labels are exactly 0..n-1, so map through
    # classes_ instead of indexing with the label itself.
    encoded_labels = getattr(meta_model, "classes_", None)
    if encoded_labels is None:
        encoded_labels = np.arange(len(y_proba))
    encoded_labels = np.asarray(encoded_labels)

    matches = np.flatnonzero(encoded_labels == y_pred)
    # Fall back to the label-as-index behaviour if classes_ is unusable
    pred_col = int(matches[0]) if matches.size else int(y_pred)

    disease = str(label_encoder.inverse_transform([y_pred])[0])
    confidence = float(y_proba[pred_col])

    # Top 5 predictions, highest probability first
    top_5_idx = np.argsort(y_proba)[-5:][::-1]
    top_5_names = label_encoder.inverse_transform(encoded_labels[top_5_idx])
    top_5 = [
        {"disease": str(name), "probability": float(y_proba[idx])}
        for name, idx in zip(top_5_names, top_5_idx)
    ]

    return disease, confidence, top_5


@app.get("/")
async def root():
    """Root endpoint listing the available routes."""
    return {
        "message": "MediGuard Disease Prediction API",
        "version": "1.0.0",
        "endpoints": {
            "health": "/api/health",
            "predict": "/api/predict (POST)",
            "features": "/api/features",
            "diseases": "/api/diseases",
            "docs": "/docs",
        },
    }


@app.get("/api/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint reporting whether the models are loaded."""
    ready = _models_ready()
    return HealthResponse(
        status="healthy" if ready else "not_ready",
        model_loaded=ready,
        # `is not None` rather than truthiness: feature_cols may be an
        # array-like whose truth value is ambiguous.
        feature_count=len(feature_cols) if feature_cols is not None else 0,
    )


@app.get("/api/features")
async def get_features():
    """Get the list of required features.

    Raises:
        HTTPException: 503 if the models are not loaded yet.
    """
    if feature_cols is None:
        raise HTTPException(status_code=503, detail="Models not loaded")

    # Plain list of str so the payload is JSON-serialisable whatever
    # container type the artifact was saved as.
    names = [str(name) for name in feature_cols]
    return {
        "features": names,
        "count": len(names),
        "example": "Use /api/predict endpoint with biomarker values in this exact order",
    }


@app.post("/api/predict", response_model=PredictionResponse)
async def predict(patient_data: PatientData):
    """Predict disease from patient biomarker data.

    Args:
        patient_data: PatientData object with biomarker key-value pairs.

    Returns:
        PredictionResponse with predicted disease and confidence.

    Raises:
        HTTPException: 503 if models are not loaded, 400 for invalid input
            (ValueError during feature assembly or prediction), 500 for any
            other prediction failure.
    """
    # Check if models are loaded
    if not _models_ready():
        raise HTTPException(
            status_code=503,
            detail="Models not loaded. Please wait for startup to complete.",
        )

    # Keys are the public aliases (e.g. "White Blood Cells")
    input_dict = patient_data.model_dump(by_alias=True)

    try:
        # Build features array in the order expected by feature_cols
        ordered_values = []
        for feature_name in feature_cols:
            if feature_name not in input_dict:
                raise ValueError(f"Missing feature: {feature_name}")
            ordered_values.append(float(input_dict[feature_name]))
        features = np.array(ordered_values, dtype=np.float32)

        disease, confidence, top_5 = predict_disease(features)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception(f"Prediction error: {e}")
        raise HTTPException(
            status_code=500, detail=f"Prediction failed: {str(e)}"
        ) from e

    logger.info(f"Prediction: {disease} ({confidence*100:.2f}%)")

    # Built outside the try block: a response-schema error is a server bug
    # and must not be reported to the client as a 400.
    n_classes = len(label_encoder.classes_)
    return PredictionResponse(
        prediction=disease,
        confidence=confidence,
        top_5_predictions=top_5,
        raw_values=input_dict,
        model_info={
            "base_models": 2,  # rf_model and nn_model
            "features_used": len(feature_cols),
            "meta_input_shape": [1, n_classes * 2],
            "n_classes": n_classes,
        },
    )


# "/api/diseases" is the path advertised by the root endpoint; "/diseases" is
# kept as an undocumented alias for existing clients.
@app.get("/api/diseases", response_model=Dict[str, Any])
@app.get("/diseases", response_model=Dict[str, Any], include_in_schema=False)
async def get_diseases():
    """Get the list of all diseases the model can predict.

    Raises:
        HTTPException: 503 if the models are not loaded yet.
    """
    if label_encoder is None:
        raise HTTPException(status_code=503, detail="Models not loaded")

    diseases = [str(name) for name in label_encoder.classes_]
    return {
        "diseases": diseases,
        "count": len(diseases),
    }


# For local testing
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)