# Hugging Face Spaces page banner captured with this file (not program text): "Spaces: Sleeping"
"""
MediGuard Disease Prediction API
FastAPI application for Hugging Face Spaces deployment
"""
import logging
from pathlib import Path
from typing import Any, Dict, List

import joblib
import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, ConfigDict, Field
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Directory the serialized model artifacts are loaded from. The relative
# path is resolved against the process working directory.
MODEL_DIR = Path(".")

# Model artifacts. Each one stays None until the lifespan startup hook has
# loaded it, which is what the endpoints test for before serving a request.
rf_model: Any = None
nn_model: Any = None
meta_model: Any = None
scaler: Any = None
label_encoder: Any = None
feature_cols: Any = None
# Backend's allowed diseases (for validation). The startup hook rejects a
# label encoder that contains any class name missing from this set.
# NOTE(review): "Thromboc" looks like a truncated label; confirm it against
# the backend's list before removing or renaming it.
BACKEND_ALLOWED_DISEASES = {
    "Anemia", "Prediabetes", "Diabetes", "Severe Inflammation",
    "Thrombocytopenia", "Obesity", "IronDeficiencyAnemia",
    "ThalassemiaMajorLike", "ThalassemiaTrait", "KidneyImpairment",
    "Thromboc", "MetabolicSyndrome", "HyperthyroidismLike",
    "CoronaryArteryDisease", "Hypertension", "ArrhythmiaRisk",
    "Dyslipidemia", "Hepatitis", "NAFLD", "InfectionInflammation",
    "Polycythemia", "ACS", "Healthy",
}
# Pydantic models
class PatientData(BaseModel):
    """Patient biomarker data for prediction.

    Every field is a required float. Biomarkers whose name contains spaces or
    a hyphen are declared under a Python-safe attribute name and carry the
    original name as an alias; `populate_by_name` lets a request use either
    spelling.
    """

    # Pydantic v2 configuration (the file already relies on the v2
    # `model_dump` API); replaces the deprecated nested `class Config`.
    model_config = ConfigDict(populate_by_name=True)

    Glucose: float
    Cholesterol: float
    Hemoglobin: float
    Platelets: float
    White_Blood_Cells: float = Field(..., alias="White Blood Cells")
    Red_Blood_Cells: float = Field(..., alias="Red Blood Cells")
    Hematocrit: float
    Mean_Corpuscular_Volume: float = Field(..., alias="Mean Corpuscular Volume")
    Mean_Corpuscular_Hemoglobin: float = Field(..., alias="Mean Corpuscular Hemoglobin")
    Mean_Corpuscular_Hemoglobin_Concentration: float = Field(
        ..., alias="Mean Corpuscular Hemoglobin Concentration"
    )
    Insulin: float
    BMI: float
    Systolic_Blood_Pressure: float = Field(..., alias="Systolic Blood Pressure")
    Diastolic_Blood_Pressure: float = Field(..., alias="Diastolic Blood Pressure")
    Triglycerides: float
    HbA1c: float
    LDL_Cholesterol: float = Field(..., alias="LDL Cholesterol")
    HDL_Cholesterol: float = Field(..., alias="HDL Cholesterol")
    ALT: float
    AST: float
    Heart_Rate: float = Field(..., alias="Heart Rate")
    Creatinine: float
    Troponin: float
    C_reactive_Protein: float = Field(..., alias="C-reactive Protein")
class PredictionResponse(BaseModel):
    """Response model for disease prediction."""

    # `model_info` starts with pydantic's protected "model_" prefix; clearing
    # the protected namespaces keeps the public field name without a warning.
    model_config = ConfigDict(protected_namespaces=())

    prediction: str
    confidence: float
    # Each entry is {"disease": <name>, "probability": <float>}. The values
    # are mixed str/float, so the value type cannot be `float`: that made
    # response construction fail validation on the disease name.
    top_5_predictions: List[Dict[str, Any]]
    raw_values: Dict[str, float]
    model_info: Dict[str, Any]
class HealthResponse(BaseModel):
    """Health check response."""

    # `model_loaded` starts with pydantic's protected "model_" prefix;
    # clearing the protected namespaces keeps the field name without a warning.
    model_config = ConfigDict(protected_namespaces=())

    status: str
    model_loaded: bool
    feature_count: int
from contextlib import asynccontextmanager
from typing import AsyncGenerator


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Lifespan context manager for startup and shutdown events.

    Startup loads every model artifact from MODEL_DIR into the module-level
    globals and checks that all label-encoder classes are in
    BACKEND_ALLOWED_DISEASES. Control is then yielded to the application;
    the code after the yield runs on shutdown.

    The function is wrapped with `asynccontextmanager` so that the object
    handed to `FastAPI(lifespan=...)` is an async context manager rather than
    a bare async generator function.

    Raises:
        ValueError: if the label encoder contains a class outside
            BACKEND_ALLOWED_DISEASES.
        Exception: any error raised while loading an artifact is logged and
            re-raised, which aborts startup.
    """
    global rf_model, nn_model, meta_model, scaler, label_encoder, feature_cols
    # Startup
    try:
        logger.info("Loading models...")
        # NOTE: joblib.load unpickles the files, which can execute arbitrary
        # code; only artifacts shipped with the application belong here.
        rf_model = joblib.load(MODEL_DIR / "rf_model.pkl")
        nn_model = joblib.load(MODEL_DIR / "nn_model.pkl")
        meta_model = joblib.load(MODEL_DIR / "meta_model.pkl")
        scaler = joblib.load(MODEL_DIR / "scaler.pkl")
        label_encoder = joblib.load(MODEL_DIR / "label_encoder.pkl")
        feature_cols = joblib.load(MODEL_DIR / "feature_cols.pkl")
        logger.info("✓ Models loaded successfully!")
        logger.info(f"✓ Feature count: {len(feature_cols)}")
        logger.info(f"✓ Classes: {list(label_encoder.classes_)}")
        # Validate classes
        invalid_classes = set(label_encoder.classes_) - BACKEND_ALLOWED_DISEASES
        if invalid_classes:
            logger.error(f"Invalid classes found: {invalid_classes}")
            raise ValueError("Model contains invalid disease classes")
    except Exception as e:
        logger.error(f"❌ Error loading models: {e}")
        raise
    yield
    # Shutdown (cleanup if needed)
    logger.info("Shutting down...")
# Initialize FastAPI app with lifespan
app = FastAPI(
    title="MediGuard Disease Prediction API",
    description="AI-powered disease prediction using stacking ensemble",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS middleware
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site issue credentialed requests. Behaviour is kept as-is here; confirm
# whether the frontend actually sends credentials before tightening it.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def predict_disease(patient_features: np.ndarray):
    """
    Predict disease using stacking ensemble

    The features are scaled, both base learners produce class probabilities,
    and the two probability vectors are concatenated as input to the
    meta-learner, whose output is the final prediction.

    Args:
        patient_features: Array of biomarker values, one per entry of
            feature_cols.
    Returns:
        Tuple of (disease, confidence, top_5_predictions), where
        top_5_predictions is a list of {"disease", "probability"} dicts
        ordered from most to least probable (shorter when the meta-learner
        reports fewer than five classes).
    Raises:
        ValueError: if the number of values differs from len(feature_cols).
    """
    # Validate features
    if len(patient_features) != len(feature_cols):
        raise ValueError(
            f"Expected {len(feature_cols)} features, got {len(patient_features)}"
        )
    # Scale features (wrapped in a list to form a single-row batch)
    X_scaled = scaler.transform([patient_features]).astype(np.float32)
    # Get base learner predictions
    rf_probs = rf_model.predict_proba(X_scaled)
    nn_probs = nn_model.predict_proba(X_scaled)
    # Create meta-features: base-learner probabilities side by side
    X_meta = np.hstack([rf_probs, nn_probs])
    # Get final prediction from meta-learner
    y_pred = meta_model.predict(X_meta)[0]
    y_proba = meta_model.predict_proba(X_meta)[0]
    # Get disease name; str() turns the numpy string scalar into a plain str
    disease = str(label_encoder.inverse_transform([y_pred])[0])
    # NOTE(review): indexing y_proba with the predicted label assumes the
    # meta-learner's probability columns are ordered by encoded label.
    confidence = float(y_proba[y_pred])
    # Get top 5 predictions, highest probability first; decode all labels
    # with a single inverse_transform call.
    top_5_idx = np.argsort(y_proba)[-5:][::-1]
    top_5_names = label_encoder.inverse_transform(top_5_idx)
    top_5 = [
        {"disease": str(name), "probability": float(y_proba[idx])}
        for name, idx in zip(top_5_names, top_5_idx)
    ]
    return disease, confidence, top_5
# The handler was not registered on `app`, so "/" was unreachable.
@app.get("/")
async def root():
    """Root endpoint: return the API name, version and an endpoint map."""
    return {
        "message": "MediGuard Disease Prediction API",
        "version": "1.0.0",
        "endpoints": {
            "health": "/api/health",
            "predict": "/api/predict (POST)",
            "features": "/api/features",
            "diseases": "/api/diseases",
            "docs": "/docs",
        },
    }
# The handler was not registered on `app`; the path is the one the root
# endpoint advertises for "health".
@app.get("/api/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint.

    Reports "healthy" once rf_model has been loaded and "not_ready" before
    that, together with the number of feature columns (0 while unloaded).
    """
    models_ready = rf_model is not None
    return HealthResponse(
        status="healthy" if models_ready else "not_ready",
        model_loaded=models_ready,
        # Compare against None explicitly: truth-testing an array-like
        # feature list raises instead of returning a bool.
        feature_count=len(feature_cols) if feature_cols is not None else 0,
    )
# The handler was not registered on `app`; the path is the one the root
# endpoint advertises for "features".
@app.get("/api/features")
async def get_features():
    """Get list of required features.

    Raises:
        HTTPException: 503 while the feature list has not been loaded.
    """
    if feature_cols is None:
        raise HTTPException(status_code=503, detail="Models not loaded")
    return {
        # Copy into a plain list so an array-like feature list still
        # serialises as a JSON array.
        "features": list(feature_cols),
        "count": len(feature_cols),
        "example": "Use /predict endpoint with biomarker values in this exact order",
    }
# The handler was not registered on `app`; the path is the one the root
# endpoint advertises for "predict".
@app.post("/api/predict", response_model=PredictionResponse)
async def predict(patient_data: PatientData):
    """
    Predict disease from patient biomarker data

    Args:
        patient_data: PatientData object with biomarker key-value pairs
    Returns:
        PredictionResponse with predicted disease and confidence
    Raises:
        HTTPException: 503 while models are not loaded, 400 when a
            ValueError is raised while building features or predicting,
            500 for any other failure.
    """
    # Check if models are loaded
    if rf_model is None:
        raise HTTPException(
            status_code=503,
            detail="Models not loaded. Please wait for startup to complete."
        )
    try:
        # Dump with aliases so the keys use the biomarker names that contain
        # spaces or hyphens (e.g. "White Blood Cells").
        input_dict = patient_data.model_dump(by_alias=True)
        # Build the feature vector in the order expected by feature_cols
        missing = next(
            (name for name in feature_cols if name not in input_dict), None
        )
        if missing is not None:
            raise ValueError(f"Missing feature: {missing}")
        features = np.array(
            [float(input_dict[name]) for name in feature_cols],
            dtype=np.float32,
        )
        # Predict
        disease, confidence, top_5 = predict_disease(features)
        logger.info("Prediction: %s (%.2f%%)", disease, confidence * 100)
        n_classes = len(label_encoder.classes_)
        return PredictionResponse(
            prediction=disease,
            confidence=confidence,
            top_5_predictions=top_5,
            raw_values=input_dict,
            model_info={
                "base_models": 2,  # rf_model and nn_model
                "features_used": len(feature_cols),
                # Two base learners, one probability per class each
                "meta_input_shape": [1, n_classes * 2],
                "n_classes": n_classes,
            },
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # logger.exception keeps the traceback, which the 500 response hides
        logger.exception("Prediction error: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Prediction failed: {str(e)}"
        ) from e
# The handler was not registered on `app`; the path is the one the root
# endpoint advertises for "diseases".
@app.get("/api/diseases")
async def get_diseases():
    """Get list of all possible diseases the model can predict.

    Raises:
        HTTPException: 503 while the label encoder has not been loaded.
    """
    if label_encoder is None:
        raise HTTPException(status_code=503, detail="Models not loaded")
    # str() converts each class label to a plain Python string
    diseases = [str(name) for name in label_encoder.classes_]
    return {
        "diseases": diseases,
        "count": len(diseases),
    }
# For local testing: serve the app on all interfaces, port 7860
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)