# alzheimer-api / app/utils.py
# Author: nessim9898 · commit 6fa307d ("add info", verified)
import pandas as pd
import numpy as np
import os
import joblib
import tensorflow as tf
from PIL import Image
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.svm import SVC
from collections import Counter
# --- Paths -------------------------------------------------------------------
# BASE_DIR is two directory levels above this file's absolute path (i.e. the
# parent of the directory containing this module); model artifacts are read
# from BASE_DIR/models.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
MODELS_DIR = os.path.join(BASE_DIR, "models")

# --- MRI image settings ------------------------------------------------------
# Square side length, in pixels, of the images fed to the networks.
MRI_IMG_SIZE = DL_IMG_SIZE = 224

# --- Class labels ------------------------------------------------------------
# MRI_CLASSES_FR[i] is the French display label for MRI_CLASSES[i]; the two
# lists must stay index-aligned.
MRI_CLASSES = [
    'MildDemented',
    'ModerateDemented',
    'NonDemented',
    'VeryMildDemented',
]
MRI_CLASSES_FR = [
    'Démence Légère',
    'Démence Modérée',
    'Sain',
    'Démence Très Légère',
]
# --- Tabular Model Utilities (Risk Prediction) ---
def load_tabular_artifacts():
    """Load the tabular risk-model artifacts from MODELS_DIR with joblib.

    Returns:
        A 4-tuple ``(model, scaler, pca, features)`` loaded, in that order, from
        ``tabular_model.joblib``, ``tabular_scaler.joblib``, ``tabular_pca.joblib``
        and ``tabular_features.joblib``. If any load raises, the error is printed
        and ``(None, None, None, None)`` is returned instead.
    """
    artifact_files = (
        "tabular_model.joblib",
        "tabular_scaler.joblib",
        "tabular_pca.joblib",
        "tabular_features.joblib",
    )
    try:
        loaded = tuple(
            joblib.load(os.path.join(MODELS_DIR, artifact)) for artifact in artifact_files
        )
    except Exception as e:
        print(f"Error loading tabular artifacts: {e}")
        return None, None, None, None
    return loaded
def predict_risk(input_data):
    """Run the tabular risk model on ``input_data``.

    The artifacts are (re)loaded via load_tabular_artifacts() on every call.
    The pipeline is: select ``features`` columns -> scaler -> PCA -> model.

    Args:
        input_data: object indexable by the stored feature list; presumably a
            pandas DataFrame with one row per patient (TODO confirm).

    Returns:
        ``(label, probability)`` for the first row: the model's predicted label
        and column 1 of ``predict_proba`` (presumably the positive class —
        confirm against the trained model). Returns ``(None, None)`` if the
        artifacts failed to load or if prediction raised (the error is printed).
    """
    model, scaler, pca, features = load_tabular_artifacts()
    if model is None:
        return None, None
    try:
        reduced = pca.transform(scaler.transform(input_data[features]))
        labels = model.predict(reduced)
        class_probs = model.predict_proba(reduced)
        return labels[0], class_probs[0][1]
    except Exception as e:
        print(f"Error during tabular prediction: {e}")
        return None, None
# --- Recommendation System Utilities ---
def load_recommendation_artifacts():
    """Load the care-recommendation artifacts from MODELS_DIR.

    Returns:
        A 5-tuple ``(clusterer, recommender, scaler, ref_df, features)``.
        ``ref_df`` is ``patient_care_reference.csv`` read with pandas.
        ``features`` is the object stored in ``tabular_features.joblib``, the
        same feature list used by the tabular risk model.
        If any load raises, the error is printed and a 5-tuple of ``None`` is
        returned. Callers can therefore test ``clusterer is None`` to detect
        failure.
    """
    try:
        clusterer = joblib.load(os.path.join(MODELS_DIR, "care_clusterer.joblib"))
        recommender = joblib.load(os.path.join(MODELS_DIR, "care_recommender.joblib"))
        scaler = joblib.load(os.path.join(MODELS_DIR, "care_scaler.joblib"))
        ref_df = pd.read_csv(os.path.join(MODELS_DIR, "patient_care_reference.csv"))
        # Only the feature list is needed here, so load it directly. Going through
        # load_tabular_artifacts() would also deserialize the tabular model, scaler
        # and PCA on every call. A failure there would also come back as
        # features=None, which slips past the caller's `clusterer is None` guard.
        features = joblib.load(os.path.join(MODELS_DIR, "tabular_features.joblib"))
    except Exception as e:
        print(f"Error loading recommendation artifacts: {e}")
        return None, None, None, None, None
    return clusterer, recommender, scaler, ref_df, features
def get_care_recommendation(patient_input_df):
    """Recommend a care plan from the consensus of the patient's nearest peers.

    Steps:
    1. Scale the patient's feature columns.
    2. Assign a cluster with the fitted clusterer.
    3. Look up nearest neighbours with the fitted recommender.
    4. Take the most frequent PrimaryFocus / CarePlanType / RecommendedActions
       among those neighbours in the reference table.

    Args:
        patient_input_df: pandas DataFrame containing the feature columns returned
            by load_recommendation_artifacts(). Only the first row's cluster and
            neighbours are used (presumably a single-patient frame — TODO confirm).

    Returns:
        A dict with keys "cluster_id", "recommended_focus", "care_plan_type",
        "actions" and "confidence_score".
        "confidence_score" has the form "<agreeing>/<total> peers agree".
        Returns None if the artifacts could not be loaded or any step raised
        (the error is printed).
    """
    clusterer, recommender, scaler, ref_df, features = load_recommendation_artifacts()
    if clusterer is None:
        return None
    try:
        # 1. Scale input
        X_scaled = scaler.transform(patient_input_df[features])
        # 2. Clustering (find the patient's clinical group)
        cluster_id = int(clusterer.predict(X_scaled)[0])
        # 3. Similarity: nearest clinical peers of the first row. The number of
        #    peers is whatever n_neighbors the recommender was fitted with.
        #    NOTE(review): assumes the recommender was fitted on ref_df rows in the
        #    same order, so positional indices map onto ref_df — confirm.
        _, indices = recommender.kneighbors(X_scaled)
        peers = ref_df.iloc[indices[0]]
        # 4. Consensus-based recommendations. Series.mode() returns the modes
        #    sorted, so on ties [0] picks the smallest tied value.
        reco_focus = peers['PrimaryFocus'].mode()[0]
        reco_type = peers['CarePlanType'].mode()[0]
        reco_actions = peers['RecommendedActions'].mode()[0]
        # Agreement is reported against the actual number of peers returned,
        # not a hard-coded 10, so the score stays correct if n_neighbors changes.
        consensus_count = int((peers['PrimaryFocus'] == reco_focus).sum())
        n_peers = len(peers)
        return {
            "cluster_id": cluster_id,
            "recommended_focus": reco_focus,
            "care_plan_type": reco_type,
            "actions": reco_actions,
            "confidence_score": f"{consensus_count}/{n_peers} peers agree",
        }
    except Exception as e:
        print(f"Error during recommendation: {e}")
        return None
# --- MRI Model Utilities (Load once on startup for speed) ---
def _load_model_safe(filename):
    """Load a Keras model from MODELS_DIR without raising.

    Returns:
        The loaded model on success. Otherwise it returns a string instead of
        raising:
        - ``"MISSING: <filename>"`` when the file does not exist.
        - ``"LOAD_ERROR: <message>"`` when ``load_model`` raises.
        Callers distinguish failure by checking ``isinstance(result, str)``.
    """
    model_path = os.path.join(MODELS_DIR, filename)
    if os.path.exists(model_path):
        try:
            return tf.keras.models.load_model(model_path)
        except Exception as err:
            return f"LOAD_ERROR: {str(err)}"
    return f"MISSING: {filename}"
# Global model cache, built once when this module is imported so predict_mri()
# does not reload the networks per call. Each value is either a loaded Keras
# model or the error string produced by _load_model_safe().
MRI_MODELS = {
    key: _load_model_safe(model_file)
    for key, model_file in (
        ("cnn", "alzheimer_cnn.keras"),
        ("resnet", "alzheimer_resnet.keras"),
    )
}
def predict_mri(image):
    """Classify an MRI image with the cached ResNet50 and custom CNN models.

    ResNet50 provides the primary diagnosis. The custom CNN is a secondary
    opinion, reported in ``details`` together with whether the two agree.

    Args:
        image: a PIL ``Image`` (any mode; it is converted to RGB).

    Returns:
        ``(label, confidence, details)``:
        - ``label`` is the French class name (from MRI_CLASSES_FR) of ResNet's
          top class.
        - ``confidence`` is ResNet's probability for that class, as a Python
          float.
        - ``details`` is a dict of per-model predictions for the UI.
        On any failure (model not loaded, or an exception during inference) the
        result is ``(None, None, {"error": <message>})``.
    """
    cnn_model = MRI_MODELS.get('cnn')
    resnet_model = MRI_MODELS.get('resnet')
    # _load_model_safe() stores an error string instead of a model on failure.
    if isinstance(cnn_model, str):
        return None, None, {"error": f"CNN Model Error: {cnn_model}"}
    if isinstance(resnet_model, str):
        return None, None, {"error": f"ResNet Model Error: {resnet_model}"}
    if cnn_model is None or resnet_model is None:
        return None, None, {"error": "Models not initialized"}
    try:
        # Both networks get the same input:
        # - RGB, resized to DL_IMG_SIZE x DL_IMG_SIZE;
        # - 8-bit pixel values scaled into [0, 1];
        # - a leading batch axis of size 1.
        img_resized = image.convert('RGB').resize((DL_IMG_SIZE, DL_IMG_SIZE))
        img_dl_input = np.expand_dims(np.array(img_resized) / 255.0, axis=0)

        cnn_probs = cnn_model.predict(img_dl_input, verbose=0)[0]
        resnet_probs = resnet_model.predict(img_dl_input, verbose=0)[0]
        cnn_pred_idx = int(np.argmax(cnn_probs))
        resnet_pred_idx = int(np.argmax(resnet_probs))

        # Primary diagnosis: ResNet50. Return a plain float rather than a numpy
        # scalar, matching the float() casts used in `details` and keeping the
        # result JSON-serializable.
        confidence = float(resnet_probs[resnet_pred_idx])
        details = {
            "Primary (ResNet50)": MRI_CLASSES_FR[resnet_pred_idx],
            "Primary Confidence": round(confidence, 4),
            "Secondary (Custom CNN)": MRI_CLASSES_FR[cnn_pred_idx],
            "Secondary Confidence": round(float(cnn_probs[cnn_pred_idx]), 4),
            "Status": "Agreement" if resnet_pred_idx == cnn_pred_idx else "Minority Disagreement"
        }
        return MRI_CLASSES_FR[resnet_pred_idx], confidence, details
    except Exception as e:
        print(f"Error during MRI prediction: {e}")
        # Same error shape as the early returns above, so callers can always
        # read details["error"] on failure instead of special-casing None.
        return None, None, {"error": f"MRI prediction failed: {e}"}