Spaces:
Sleeping
Sleeping
| import os | |
| import numpy as np | |
| import tensorflow as tf | |
| import joblib | |
| import shap | |
| import pandas as pd | |
| from fastapi import FastAPI, File, HTTPException, UploadFile | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from contextlib import asynccontextmanager | |
| from schemas import ( | |
| DiabetesOutput, | |
| DiabetesInput, | |
| SymptomsInput, | |
| SymptomsOutput, | |
| DetectionResult, | |
| ) | |
| from utils import preprocess_and_split, predict_split | |
# The lifespan handler is defined further down in this module, so the app is
# created with ``lifespan=None`` and the handler is attached afterwards through
# ``app.router.lifespan_context``.
app = FastAPI(title="ML Healthcare API", lifespan=None)

# =========================
# CORS
# =========================
# NOTE(review): every origin, method and header is allowed. That is convenient
# for a public demo; restrict ``allow_origins`` before exposing anything
# sensitive behind this API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
| # ========================= | |
| # GLOBAL VARIABLES | |
| # ========================= | |
# Module-level handles for the loaded models. They start as ``None`` and are
# assigned by the ``lifespan`` startup handler; request handlers read them.
retinopathy_model_1 = None  # Keras model loaded from "detection_model.keras"
retinopathy_model_2 = None  # Keras model loaded from "new_model.keras"
diabetes_model = None  # joblib artifact "diabetes_model_logistic.pkl"
symptoms_model = None  # joblib artifact "symptoms_diabetes_model.pkl"
symptoms_explainer = None  # SHAP explainer for symptoms_model, or None if init failed
| # ========================= | |
| # PATH SETUP | |
| # ========================= | |
# Absolute directory that contains this source file; model artifacts are
# resolved relative to it so the working directory does not matter.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Directory holding the serialized model files loaded at startup.
MODEL_DIR = os.path.join(BASE_DIR, "models")
| # ========================= | |
| # LIFESPAN | |
| # ========================= | |
# ``router.lifespan_context`` is entered with ``async with``, so the handler
# must return an async context manager; a bare async generator function does
# not, hence the decorator.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load and warm up every model at startup, then hand control to the app.

    Populates the module globals ``retinopathy_model_1``,
    ``retinopathy_model_2``, ``diabetes_model``, ``symptoms_model`` and
    ``symptoms_explainer``. A failure while building the SHAP explainer is
    reported and leaves ``symptoms_explainer`` as ``None``; a failure while
    loading a model propagates and aborts startup.

    Args:
        app: The FastAPI application being started (not used directly).

    Yields:
        None. Code after the ``yield`` runs on shutdown.
    """
    global retinopathy_model_1, retinopathy_model_2
    global diabetes_model, symptoms_model, symptoms_explainer

    print("Application starting...")

    # Load models
    retinopathy_model_1 = tf.keras.models.load_model(
        os.path.join(MODEL_DIR, "detection_model.keras")
    )
    retinopathy_model_2 = tf.keras.models.load_model(
        os.path.join(MODEL_DIR, "new_model.keras")
    )
    # NOTE: joblib.load unpickles the file, which can execute arbitrary code.
    # Only ship model artifacts that come from a trusted source.
    diabetes_model = joblib.load(
        os.path.join(MODEL_DIR, "diabetes_model_logistic.pkl")
    )
    symptoms_model = joblib.load(
        os.path.join(MODEL_DIR, "symptoms_diabetes_model.pkl")
    )

    # Warm-up CNN: one dummy forward pass per model so the first real request
    # does not pay the one-time initialisation cost.
    dummy = np.zeros((1, 224, 224, 3))
    retinopathy_model_1.predict(dummy, verbose=0)
    retinopathy_model_2.predict(dummy, verbose=0)

    # SHAP only for Logistic symptoms model. The background is a single
    # all-zero row with 16 columns, one per feature sent by predict_symptoms.
    try:
        background = np.zeros((1, 16))
        symptoms_explainer = shap.LinearExplainer(symptoms_model, background)
    except Exception as e:
        # Explanations are optional: keep serving predictions without them.
        print("SHAP init error (symptoms):", e)
        symptoms_explainer = None

    print("Models loaded and warmed up")

    yield

    print("Application shutting down...")


# The app was created with ``lifespan=None``; attach the handler now that it
# is defined.
app.router.lifespan_context = lifespan
| # ========================= | |
| # ROOT | |
| # ========================= | |
# NOTE(review): no route decorator (e.g. ``@app.get(...)``) is visible for this
# handler in the source as received; confirm it is registered on ``app``.
async def root():
    """Return the static welcome message of the API."""
    return {"message": "Welcome to Healthcare API!"}
# NOTE(review): no route decorator (e.g. ``@app.get(...)``) is visible for this
# handler in the source as received; confirm it is registered on ``app``.
async def health():
    """Return a constant payload signalling that the process is up.

    The response does not check whether the models are loaded.
    """
    return {"status": "healthy"}
| # ========================= | |
| # RETINOPATHY | |
| # ========================= | |
# NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible for
# this handler in the source as received; confirm it is registered on ``app``.
async def detect_retinopathy(file: UploadFile = File(...)):
    """Classify an uploaded image with the two retinopathy models.

    The raw bytes are passed to ``preprocess_and_split`` and its two outputs
    are fed, together with both models, to ``predict_split``.

    Args:
        file: The uploaded image file.

    Returns:
        A ``DetectionResult`` with ``id`` fixed to 1, the predicted label and
        the confidence rounded to four decimals.

    Raises:
        HTTPException: 503 when the models have not been loaded, 400 when the
            upload is empty, 500 for any failure while reading, preprocessing
            or predicting.
    """
    # Fail with a clear status instead of an AttributeError on ``None``.
    if retinopathy_model_1 is None or retinopathy_model_2 is None:
        raise HTTPException(
            status_code=503, detail="Retinopathy models are not loaded"
        )

    try:
        image_bytes = await file.read()
        if not image_bytes:
            raise HTTPException(status_code=400, detail="Uploaded file is empty")

        imgs1, imgs2 = preprocess_and_split(image_bytes)

        # NOTE(review): this call runs synchronously inside the coroutine; if
        # inference is slow it blocks the event loop for other requests.
        label, confidence = predict_split(
            retinopathy_model_1,
            retinopathy_model_2,
            imgs1,
            imgs2
        )

        return DetectionResult(
            id=1,
            retinopathy_level=label,
            # Convert before rounding so the result is a plain Python float;
            # presumably predict_split may return a NumPy scalar (TODO confirm).
            confidence=round(float(confidence), 4)
        )
    except HTTPException:
        # Deliberate HTTP errors raised above must keep their status code.
        raise
    except Exception as e:
        # NOTE(review): the raw exception text is returned to the client.
        raise HTTPException(status_code=500, detail=str(e)) from e
| # ========================= | |
| # DIABETES (NUMERIC - LOGISTIC) | |
| # ========================= | |
# Input fields of the numeric diabetes model, in the column order used to
# build the feature frame. Each name is both the column name and the
# attribute name on ``DiabetesInput``.
_DIABETES_FEATURES = (
    "Pregnancies",
    "Glucose",
    "BloodPressure",
    "SkinThickness",
    "Insulin",
    "BMI",
    "DiabetesPedigreeFunction",
    "Age",
)

# ``(model, explainer)`` pair for the last model an explainer was built for.
_diabetes_explainer_cache = None


def _get_diabetes_explainer():
    """Return a SHAP explainer for ``diabetes_model``, building it only once."""
    global _diabetes_explainer_cache
    cached = _diabetes_explainer_cache
    # Rebuild when nothing is cached yet or the model object was replaced.
    if cached is None or cached[0] is not diabetes_model:
        # Use neutral background instead of same input: one all-zero row with
        # one column per feature.
        background = np.zeros((1, len(_DIABETES_FEATURES)))
        cached = (diabetes_model, shap.LinearExplainer(diabetes_model, background))
        _diabetes_explainer_cache = cached
    return cached[1]


# NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible for
# this handler in the source as received; confirm it is registered on ``app``.
def predict(data: DiabetesInput):
    """Predict diabetes from numeric measurements and explain the result.

    Args:
        data: Request payload carrying the eight fields listed in
            ``_DIABETES_FEATURES``.

    Returns:
        A ``DiabetesOutput`` with the predicted class as ``int``, the
        probability taken from column 1 of ``predict_proba`` and
        ``feature_importance`` mapping feature names to SHAP values, ordered
        by decreasing absolute value. The mapping is empty when the SHAP
        computation fails.

    Raises:
        HTTPException: 503 when the model has not been loaded, 500 for any
            failure while predicting.
    """
    # Fail with a clear status instead of an AttributeError on ``None``.
    if diabetes_model is None:
        raise HTTPException(status_code=503, detail="Diabetes model is not loaded")

    try:
        features = pd.DataFrame(
            [{name: getattr(data, name) for name in _DIABETES_FEATURES}]
        )

        prediction = diabetes_model.predict(features)[0]
        probability = diabetes_model.predict_proba(features)[0][1]

        importance = {}
        try:
            shap_values = _get_diabetes_explainer()(features)
            values = shap_values.values[0]
            feature_names = features.columns.tolist()
            importance = dict(
                sorted(
                    zip(feature_names, values.tolist()),
                    key=lambda item: abs(item[1]),
                    reverse=True,
                )
            )
        except Exception as e:
            # Explanations are best-effort: report and return the prediction.
            print("SHAP error (diabetes):", e)

        return DiabetesOutput(
            prediction=int(prediction),
            probability=float(probability),
            feature_importance=importance
        )
    except Exception as e:
        # NOTE(review): the raw exception text is returned to the client.
        raise HTTPException(status_code=500, detail=str(e)) from e
| # ========================= | |
| # DIABETES (SYMPTOMS) | |
| # ========================= | |
# NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible for
# this handler in the source as received; confirm it is registered on ``app``.
def predict_symptoms(data: SymptomsInput):
    """Predict diabetes from reported symptoms and explain the result.

    Args:
        data: Request payload with ``Age``, ``Gender`` and the fourteen
            symptom fields read below.

    Returns:
        A ``SymptomsOutput`` with the predicted class as ``int``, the
        probability taken from column 1 of ``predict_proba`` and
        ``feature_importance`` mapping column names to SHAP values, ordered by
        decreasing absolute value. The mapping is empty when no explainer is
        available or the SHAP computation fails.

    Raises:
        HTTPException: 503 when the model has not been loaded, 500 for any
            failure while predicting.
    """
    # Fail with a clear status instead of an AttributeError on ``None``.
    if symptoms_model is None:
        raise HTTPException(status_code=503, detail="Symptoms model is not loaded")

    try:
        # Column names contain spaces while the payload attributes use
        # underscores, so the mapping is spelled out explicitly.
        features = pd.DataFrame([{
            "Age": data.Age,
            "Gender": data.Gender,
            "Polyuria": data.Polyuria,
            "Polydipsia": data.Polydipsia,
            "sudden weight loss": data.sudden_weight_loss,
            "weakness": data.weakness,
            "Polyphagia": data.Polyphagia,
            "Genital thrush": data.Genital_thrush,
            "visual blurring": data.visual_blurring,
            "Itching": data.Itching,
            "Irritability": data.Irritability,
            "delayed healing": data.delayed_healing,
            "partial paresis": data.partial_paresis,
            "muscle stiffness": data.muscle_stiffness,
            "Alopecia": data.Alopecia,
            "Obesity": data.Obesity
        }])

        prediction = symptoms_model.predict(features)[0]
        probability = symptoms_model.predict_proba(features)[0][1]

        importance = {}
        try:
            # Explicit None check: do not rely on the explainer's truthiness.
            if symptoms_explainer is not None:
                shap_values = symptoms_explainer(features)
                values = shap_values.values[0]
                feature_names = features.columns.tolist()
                importance = dict(
                    sorted(
                        zip(feature_names, values.tolist()),
                        key=lambda item: abs(item[1]),
                        reverse=True,
                    )
                )
        except Exception as e:
            # Explanations are best-effort: report and return the prediction.
            print("SHAP error (symptoms):", e)

        return SymptomsOutput(
            prediction=int(prediction),
            probability=float(probability),
            feature_importance=importance
        )
    except Exception as e:
        # NOTE(review): the raw exception text is returned to the client.
        raise HTTPException(status_code=500, detail=str(e)) from e