"""FastAPI application serving retinopathy detection and diabetes prediction endpoints."""

import os
from contextlib import asynccontextmanager

import numpy as np
import tensorflow as tf
import joblib
import shap
import pandas as pd
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware

from schemas import (
    DiabetesOutput,
    DiabetesInput,
    SymptomsInput,
    SymptomsOutput,
    DetectionResult,
)
from utils import preprocess_and_split, predict_split

# =========================
# GLOBAL VARIABLES
# =========================
# All of these are populated by `lifespan` at startup and stay None until then.
retinopathy_model_1 = None
retinopathy_model_2 = None
diabetes_model = None
symptoms_model = None
symptoms_explainer = None
diabetes_explainer = None

# =========================
# PATH SETUP
# =========================
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_DIR = os.path.join(BASE_DIR, "models")

# Number of columns in the feature frames built by the two tabular endpoints;
# used to size the all-zero SHAP background rows.
_DIABETES_FEATURE_COUNT = 8
_SYMPTOMS_FEATURE_COUNT = 16


# =========================
# HELPERS
# =========================
def _build_linear_explainer(model, n_features, label):
    """Return a SHAP LinearExplainer over a single all-zero background row.

    Explanations are a best-effort extra, so any failure is printed and
    ``None`` is returned instead of aborting application startup.
    """
    try:
        background = np.zeros((1, n_features))
        return shap.LinearExplainer(model, background)
    except Exception as e:
        print(f"SHAP init error ({label}):", e)
        return None


def _ranked_importance(explainer, features, label):
    """Map each feature name to its SHAP value, largest magnitude first.

    Returns an empty dict when no explainer is available or when SHAP fails;
    the failure is printed so the prediction itself can still be returned.
    """
    if explainer is None:
        return {}
    try:
        explanation = explainer(features)
        contributions = explanation.values[0].tolist()
        pairs = zip(features.columns.tolist(), contributions)
        return dict(sorted(pairs, key=lambda item: abs(item[1]), reverse=True))
    except Exception as e:
        print(f"SHAP error ({label}):", e)
        return {}


def _run_retinopathy_inference(image_bytes):
    """Preprocess the raw image bytes and run both retinopathy models on them.

    Kept synchronous on purpose: the async endpoint hands it to a worker
    thread so inference does not stall the event loop.
    """
    imgs1, imgs2 = preprocess_and_split(image_bytes)
    return predict_split(
        retinopathy_model_1,
        retinopathy_model_2,
        imgs1,
        imgs2,
    )


# =========================
# LIFESPAN
# =========================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load and warm up every model before the application starts serving.

    Populates the module-level model and explainer globals. Nothing is
    released on shutdown; only a message is printed.
    """
    global retinopathy_model_1, retinopathy_model_2
    global diabetes_model, symptoms_model
    global symptoms_explainer, diabetes_explainer

    print("Application starting...")

    # Load models
    retinopathy_model_1 = tf.keras.models.load_model(
        os.path.join(MODEL_DIR, "detection_model.keras")
    )
    retinopathy_model_2 = tf.keras.models.load_model(
        os.path.join(MODEL_DIR, "new_model.keras")
    )
    diabetes_model = joblib.load(
        os.path.join(MODEL_DIR, "diabetes_model_logistic.pkl")
    )
    symptoms_model = joblib.load(
        os.path.join(MODEL_DIR, "symptoms_diabetes_model.pkl")
    )

    # Warm-up CNN so the first real request does not pay the one-time cost.
    dummy = np.zeros((1, 224, 224, 3))
    retinopathy_model_1.predict(dummy, verbose=0)
    retinopathy_model_2.predict(dummy, verbose=0)

    # Build both SHAP explainers once here instead of on every request.
    symptoms_explainer = _build_linear_explainer(
        symptoms_model, _SYMPTOMS_FEATURE_COUNT, "symptoms"
    )
    diabetes_explainer = _build_linear_explainer(
        diabetes_model, _DIABETES_FEATURE_COUNT, "diabetes"
    )

    print("Models loaded and warmed up")
    yield
    print("Application shutting down...")


# Register the lifespan handler through the constructor rather than patching
# `app.router.lifespan_context` after the fact.
app = FastAPI(title="ML Healthcare API", lifespan=lifespan)

# =========================
# CORS
# =========================
# NOTE(review): every origin, method and header is allowed here — confirm this
# is intended for the deployed environment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


# =========================
# ROOT
# =========================
@app.get("/")
async def root():
    """Return a static welcome message."""
    return {"message": "Welcome to Healthcare API!"}


@app.get("/health")
async def health():
    """Return a static health status payload."""
    return {"status": "healthy"}


# =========================
# RETINOPATHY
# =========================
@app.post("/detect", response_model=DetectionResult)
async def detect_retinopathy(file: UploadFile = File(...)) -> DetectionResult:
    """Classify an uploaded image with the two retinopathy models.

    Raises:
        HTTPException: 400 when the upload is empty, 500 when preprocessing
            or inference fails.
    """
    try:
        image_bytes = await file.read()
        if not image_bytes:
            raise HTTPException(status_code=400, detail="Uploaded file is empty")

        # Model inference is blocking, so run it off the event loop.
        label, confidence = await run_in_threadpool(
            _run_retinopathy_inference, image_bytes
        )

        return DetectionResult(
            id=1,
            retinopathy_level=label,
            confidence=round(confidence, 4),
        )
    except HTTPException:
        # Keep deliberate HTTP errors (such as the 400 above) intact.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


# =========================
# DIABETES (NUMERIC - LOGISTIC)
# =========================
@app.post("/predict", response_model=DiabetesOutput)
def predict(data: DiabetesInput) -> DiabetesOutput:
    """Predict diabetes from numeric measurements and rank feature contributions.

    Raises:
        HTTPException: 500 when building the features or running the model fails.
    """
    try:
        features = pd.DataFrame([{
            "Pregnancies": data.Pregnancies,
            "Glucose": data.Glucose,
            "BloodPressure": data.BloodPressure,
            "SkinThickness": data.SkinThickness,
            "Insulin": data.Insulin,
            "BMI": data.BMI,
            "DiabetesPedigreeFunction": data.DiabetesPedigreeFunction,
            "Age": data.Age,
        }])

        prediction = diabetes_model.predict(features)[0]
        probability = diabetes_model.predict_proba(features)[0][1]
        importance = _ranked_importance(diabetes_explainer, features, "diabetes")

        return DiabetesOutput(
            prediction=int(prediction),
            probability=float(probability),
            feature_importance=importance,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


# =========================
# DIABETES (SYMPTOMS)
# =========================
@app.post("/predict-symptoms", response_model=SymptomsOutput)
def predict_symptoms(data: SymptomsInput) -> SymptomsOutput:
    """Predict diabetes from reported symptoms and rank feature contributions.

    Raises:
        HTTPException: 500 when building the features or running the model fails.
    """
    try:
        features = pd.DataFrame([{
            "Age": data.Age,
            "Gender": data.Gender,
            "Polyuria": data.Polyuria,
            "Polydipsia": data.Polydipsia,
            "sudden weight loss": data.sudden_weight_loss,
            "weakness": data.weakness,
            "Polyphagia": data.Polyphagia,
            "Genital thrush": data.Genital_thrush,
            "visual blurring": data.visual_blurring,
            "Itching": data.Itching,
            "Irritability": data.Irritability,
            "delayed healing": data.delayed_healing,
            "partial paresis": data.partial_paresis,
            "muscle stiffness": data.muscle_stiffness,
            "Alopecia": data.Alopecia,
            "Obesity": data.Obesity,
        }])

        prediction = symptoms_model.predict(features)[0]
        probability = symptoms_model.predict_proba(features)[0][1]
        importance = _ranked_importance(symptoms_explainer, features, "symptoms")

        return SymptomsOutput(
            prediction=int(prediction),
            probability=float(probability),
            feature_importance=importance,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e