# Sriomdash's picture
# Update main.py
# b1b2e8c verified
import os
import numpy as np
import tensorflow as tf
import joblib
import shap
import pandas as pd
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from schemas import (
DiabetesOutput,
DiabetesInput,
SymptomsInput,
SymptomsOutput,
DetectionResult,
)
from utils import preprocess_and_split, predict_split
# Application object. No lifespan handler is supplied at construction time;
# one is attached to the router further down in this file.
app = FastAPI(lifespan=None, title="ML Healthcare API")
# =========================
# CORS
# =========================
# Wildcard policy: every origin, HTTP method and request header is allowed.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# =========================
# GLOBAL VARIABLES
# =========================
# Model handles start out empty; the lifespan handler binds them at startup.
retinopathy_model_1 = retinopathy_model_2 = None
diabetes_model = symptoms_model = None
# SHAP explainer for the symptoms model; remains None if it cannot be built.
symptoms_explainer = None
# =========================
# PATH SETUP
# =========================
# Absolute directory that contains this module.
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]
# Serialized model files are looked up in the "models" folder beneath it.
MODEL_DIR = os.path.join(BASE_DIR, "models")
# =========================
# LIFESPAN
# =========================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load every model before the app serves requests and log on shutdown.

    On entry this binds the module-level model globals from files under
    ``MODEL_DIR``, runs one prediction through each Keras model on an
    all-zero batch, and tries to build a SHAP ``LinearExplainer`` for the
    symptoms model. If the explainer cannot be built the error is printed
    and ``symptoms_explainer`` is left as ``None``. Control is then yielded
    to the application; code after the ``yield`` runs at shutdown.
    """
    global retinopathy_model_1, retinopathy_model_2
    global diabetes_model, symptoms_model, symptoms_explainer

    def model_path(filename):
        # Resolve a model file name against the models directory.
        return os.path.join(MODEL_DIR, filename)

    print("Application starting...")

    # Keras models first, then the joblib-serialized models.
    retinopathy_model_1 = tf.keras.models.load_model(
        model_path("detection_model.keras")
    )
    retinopathy_model_2 = tf.keras.models.load_model(
        model_path("new_model.keras")
    )
    diabetes_model = joblib.load(model_path("diabetes_model_logistic.pkl"))
    symptoms_model = joblib.load(model_path("symptoms_diabetes_model.pkl"))

    # Warm-up: push a single all-zero 224x224x3 batch through both CNNs.
    warmup_batch = np.zeros((1, 224, 224, 3))
    for cnn in (retinopathy_model_1, retinopathy_model_2):
        cnn.predict(warmup_batch, verbose=0)

    # The explainer is optional: a failure here must not stop startup.
    try:
        zero_background = np.zeros((1, 16))
        symptoms_explainer = shap.LinearExplainer(symptoms_model, zero_background)
    except Exception as exc:
        print("SHAP init error (symptoms):", exc)
        symptoms_explainer = None

    print("Models loaded and warmed up")
    yield
    print("Application shutting down...")


# The handler is defined after the app object exists, so it is attached to
# the app's router here rather than passed to the FastAPI constructor.
app.router.lifespan_context = lifespan
# =========================
# ROOT
# =========================
@app.get("/")
async def root():
    """Return the API's welcome message as a JSON object."""
    greeting = {"message": "Welcome to Healthcare API!"}
    return greeting
@app.get("/health")
async def health():
    """Return a fixed status payload; no model or dependency is checked."""
    status_payload = {"status": "healthy"}
    return status_payload
# =========================
# RETINOPATHY
# =========================
@app.post("/detect", response_model=DetectionResult)
async def detect_retinopathy(file: UploadFile = File(...)):
    """Classify an uploaded image with the two retinopathy models.

    The uploaded bytes are passed to ``preprocess_and_split`` and the two
    resulting inputs are fed, together with both Keras models, to
    ``predict_split``; its two return values are used as the label and the
    confidence (rounded to 4 decimals) of the response.

    Args:
        file: The uploaded image file.

    Returns:
        A ``DetectionResult`` with a fixed ``id`` of 1.

    Raises:
        HTTPException: 503 if the models have not been loaded, 400 if the
            upload is empty, 500 for any other failure (detail carries the
            underlying error message).
    """
    # Function-scope import: fastapi is already a dependency of this file.
    from fastapi.concurrency import run_in_threadpool

    if retinopathy_model_1 is None or retinopathy_model_2 is None:
        raise HTTPException(
            status_code=503, detail="Retinopathy models are not loaded"
        )

    try:
        image_bytes = await file.read()
        if not image_bytes:
            raise HTTPException(status_code=400, detail="Uploaded file is empty")

        def _run_inference():
            # Preprocessing and CNN inference are synchronous, CPU-bound calls.
            imgs1, imgs2 = preprocess_and_split(image_bytes)
            return predict_split(
                retinopathy_model_1,
                retinopathy_model_2,
                imgs1,
                imgs2,
            )

        # Run off the event loop so other requests are not stalled while the
        # models execute (the sync endpoints in this file already get this).
        label, confidence = await run_in_threadpool(_run_inference)

        return DetectionResult(
            id=1,
            retinopathy_level=label,
            confidence=round(confidence, 4),
        )
    except HTTPException:
        # Keep deliberate client errors (e.g. the 400 above) intact.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# =========================
# DIABETES (NUMERIC - LOGISTIC)
# =========================
# Columns of the single-row frame given to the numeric diabetes model, in
# the order the handler builds them. Each name is also the request field name.
_DIABETES_FEATURES = (
    "Pregnancies",
    "Glucose",
    "BloodPressure",
    "SkinThickness",
    "Insulin",
    "BMI",
    "DiabetesPedigreeFunction",
    "Age",
)

# (model, explainer-or-None) pair remembered across requests.
_diabetes_explainer_cache = None


def _get_diabetes_explainer():
    """Return a SHAP explainer for ``diabetes_model``, building it only once.

    The result (including a failed build, stored as ``None``) is cached per
    model object, so the explainer is not rebuilt on every request. A
    concurrent first call may build it twice; the last write wins.
    """
    global _diabetes_explainer_cache
    cached = _diabetes_explainer_cache
    if cached is not None and cached[0] is diabetes_model:
        return cached[1]
    try:
        # Neutral all-zero background row, one column per feature.
        background = np.zeros((1, len(_DIABETES_FEATURES)))
        explainer = shap.LinearExplainer(diabetes_model, background)
    except Exception as e:
        print("SHAP init error (diabetes):", e)
        explainer = None
    _diabetes_explainer_cache = (diabetes_model, explainer)
    return explainer


@app.post("/predict", response_model=DiabetesOutput)
def predict(data: DiabetesInput):
    """Predict diabetes from numeric measurements and explain the result.

    Args:
        data: Request body carrying the eight numeric fields listed in
            ``_DIABETES_FEATURES``.

    Returns:
        A ``DiabetesOutput`` with the predicted class, the probability taken
        from column 1 of ``predict_proba``, and per-feature SHAP values
        sorted by descending absolute value. The importance mapping is empty
        when SHAP is unavailable or fails.

    Raises:
        HTTPException: 503 if the model has not been loaded, 500 for any
            other failure (detail carries the underlying error message).
    """
    if diabetes_model is None:
        raise HTTPException(status_code=503, detail="Diabetes model is not loaded")

    try:
        features = pd.DataFrame(
            [{name: getattr(data, name) for name in _DIABETES_FEATURES}]
        )

        prediction = diabetes_model.predict(features)[0]
        probability = diabetes_model.predict_proba(features)[0][1]

        importance = {}
        # Explanation is best-effort: a SHAP failure must not fail the request.
        try:
            explainer = _get_diabetes_explainer()
            if explainer is not None:
                values = explainer(features).values[0]
                ranked = sorted(
                    zip(features.columns.tolist(), values.tolist()),
                    key=lambda item: abs(item[1]),
                    reverse=True,
                )
                importance = dict(ranked)
        except Exception as e:
            print("SHAP error (diabetes):", e)

        return DiabetesOutput(
            prediction=int(prediction),
            probability=float(probability),
            feature_importance=importance,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# =========================
# DIABETES (SYMPTOMS)
# =========================
@app.post("/predict-symptoms", response_model=SymptomsOutput)
def predict_symptoms(data: SymptomsInput):
    """Predict diabetes from reported symptoms and explain the result.

    Args:
        data: Request body carrying age, gender and the fourteen symptom
            fields read below.

    Returns:
        A ``SymptomsOutput`` with the predicted class, the probability taken
        from column 1 of ``predict_proba``, and per-feature SHAP values
        sorted by descending absolute value. The importance mapping is empty
        when no explainer was built at startup or SHAP fails.

    Raises:
        HTTPException: 503 if the model has not been loaded, 500 for any
            other failure (detail carries the underlying error message).
    """
    if symptoms_model is None:
        raise HTTPException(status_code=503, detail="Symptoms model is not loaded")

    try:
        # Column names contain spaces while the request fields use underscores,
        # so the mapping is spelled out explicitly.
        features = pd.DataFrame([{
            "Age": data.Age,
            "Gender": data.Gender,
            "Polyuria": data.Polyuria,
            "Polydipsia": data.Polydipsia,
            "sudden weight loss": data.sudden_weight_loss,
            "weakness": data.weakness,
            "Polyphagia": data.Polyphagia,
            "Genital thrush": data.Genital_thrush,
            "visual blurring": data.visual_blurring,
            "Itching": data.Itching,
            "Irritability": data.Irritability,
            "delayed healing": data.delayed_healing,
            "partial paresis": data.partial_paresis,
            "muscle stiffness": data.muscle_stiffness,
            "Alopecia": data.Alopecia,
            "Obesity": data.Obesity,
        }])

        prediction = symptoms_model.predict(features)[0]
        probability = symptoms_model.predict_proba(features)[0][1]

        importance = {}
        # Explanation is best-effort: a SHAP failure must not fail the request.
        try:
            # Identity check: do not rely on the explainer object's truthiness.
            if symptoms_explainer is not None:
                values = symptoms_explainer(features).values[0]
                ranked = sorted(
                    zip(features.columns.tolist(), values.tolist()),
                    key=lambda item: abs(item[1]),
                    reverse=True,
                )
                importance = dict(ranked)
        except Exception as e:
            print("SHAP error (symptoms):", e)

        return SymptomsOutput(
            prediction=int(prediction),
            probability=float(probability),
            feature_importance=importance,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e