Spaces:
Running
Running
| import pandas as pd | |
| import numpy as np | |
| import os | |
| import joblib | |
| import tensorflow as tf | |
| from PIL import Image | |
| from sklearn.preprocessing import StandardScaler | |
| from sklearn.decomposition import PCA | |
| from sklearn.svm import SVC | |
| from collections import Counter | |
# Filesystem layout: the project root is the parent of the directory holding
# this module, and serialized models live in <project root>/models.
BASE_DIR = os.path.dirname(
    os.path.dirname(os.path.abspath(__file__))
)
MODELS_DIR = os.path.join(BASE_DIR, "models")

# MRI input settings (square side length, in pixels).
MRI_IMG_SIZE = 224  # NOTE(review): not used by the code in this module — verify before removing
DL_IMG_SIZE = 224  # size images are resized to before the deep-learning models

# Class labels. predict_mri indexes MRI_CLASSES_FR with the argmax of the model
# output, so this order presumably matches the models' output units — confirm
# against training. MRI_CLASSES_FR[i] is the French display label of MRI_CLASSES[i].
MRI_CLASSES = [
    'MildDemented',
    'ModerateDemented',
    'NonDemented',
    'VeryMildDemented',
]
MRI_CLASSES_FR = [
    'Démence Légère',
    'Démence Modérée',
    'Sain',
    'Démence Très Légère',
]
# --- Tabular Model Utilities (Risk Prediction) ---

# Holds the artifacts after the first successful load, so the joblib files are
# read from disk once per process instead of on every prediction.
_TABULAR_ARTIFACTS_CACHE = {}

# Artifact files, in the order they are returned by load_tabular_artifacts.
_TABULAR_ARTIFACT_FILES = (
    "tabular_model.joblib",
    "tabular_scaler.joblib",
    "tabular_pca.joblib",
    "tabular_features.joblib",
)


def load_tabular_artifacts():
    """Load the tabular risk model and its preprocessing artifacts.

    Returns:
        tuple: ``(model, scaler, pca, features)``, each deserialized with
        ``joblib.load`` from MODELS_DIR. If any file fails to load, the error
        is printed and ``(None, None, None, None)`` is returned.

    A successful load is cached for the lifetime of the process. A failed load
    is not cached, so the next call tries to read the files again.
    """
    cached = _TABULAR_ARTIFACTS_CACHE.get("artifacts")
    if cached is not None:
        return cached
    try:
        model, scaler, pca, features = (
            joblib.load(os.path.join(MODELS_DIR, filename))
            for filename in _TABULAR_ARTIFACT_FILES
        )
    except Exception as e:
        print(f"Error loading tabular artifacts: {e}")
        return None, None, None, None
    artifacts = (model, scaler, pca, features)
    _TABULAR_ARTIFACTS_CACHE["artifacts"] = artifacts
    return artifacts
def predict_risk(input_data):
    """Predict the tabular risk label and the probability in column 1.

    Args:
        input_data: tabular input that can be indexed by the saved feature
            list (presumably a pandas DataFrame). Only the first row's result
            is returned.

    Returns:
        ``(prediction, probability)`` for the first row. ``probability`` is
        column 1 of ``predict_proba``, assumed to be the positive/at-risk
        class (TODO confirm against the trained model's ``classes_``).
        Returns ``(None, None)`` if the artifacts are unavailable or any step
        of the prediction raises.
    """
    model, scaler, pca, features = load_tabular_artifacts()
    if model is None:
        return None, None
    try:
        # Apply the saved preprocessing chain: feature selection, scaling,
        # then PCA projection.
        reduced = pca.transform(scaler.transform(input_data[features]))
        labels = model.predict(reduced)
        probas = model.predict_proba(reduced)
        return labels[0], probas[0][1]
    except Exception as e:
        print(f"Error during tabular prediction: {e}")
        return None, None
# --- Recommendation System Utilities ---

# Holds the artifacts after the first successful load, so the files are read
# once per process instead of on every recommendation request.
_RECOMMENDATION_ARTIFACTS_CACHE = {}


def load_recommendation_artifacts():
    """Load the care-recommendation artifacts.

    Returns:
        tuple: ``(clusterer, recommender, scaler, ref_df, features)``.
        ``clusterer``, ``recommender`` and ``scaler`` are deserialized with
        ``joblib.load`` from MODELS_DIR. ``ref_df`` is
        ``patient_care_reference.csv`` read with pandas. ``features`` is the
        feature list returned by ``load_tabular_artifacts``.

        On any error, including an unavailable feature list, the error is
        printed and five ``None`` values are returned.

    A successful load is cached for the lifetime of the process. A failed load
    is not cached, so the next call retries.
    """
    cached = _RECOMMENDATION_ARTIFACTS_CACHE.get("artifacts")
    if cached is not None:
        return cached
    try:
        clusterer = joblib.load(os.path.join(MODELS_DIR, "care_clusterer.joblib"))
        recommender = joblib.load(os.path.join(MODELS_DIR, "care_recommender.joblib"))
        scaler = joblib.load(os.path.join(MODELS_DIR, "care_scaler.joblib"))
        ref_df = pd.read_csv(os.path.join(MODELS_DIR, "patient_care_reference.csv"))
        # Features used in recommendation are the same 16 Lasso features as
        # the tabular model, so reuse its saved feature list.
        _, _, _, features = load_tabular_artifacts()
        if features is None:
            # Fail here rather than return a tuple with features=None. That
            # tuple would only fail later, when the caller indexes its input
            # with None.
            raise RuntimeError("tabular feature list is unavailable")
    except Exception as e:
        print(f"Error loading recommendation artifacts: {e}")
        return None, None, None, None, None
    artifacts = (clusterer, recommender, scaler, ref_df, features)
    _RECOMMENDATION_ARTIFACTS_CACHE["artifacts"] = artifacts
    return artifacts
def get_care_recommendation(patient_input_df):
    """Recommend a care plan from the patient's nearest clinical peers.

    Args:
        patient_input_df: patient data containing the recommendation feature
            columns (presumably a pandas DataFrame). Only the first row's
            cluster and neighbours are used.

    Returns:
        dict with keys:

        - ``cluster_id``
        - ``recommended_focus``
        - ``care_plan_type``
        - ``actions``
        - ``confidence_score``: a string ``"<agreeing>/<peers> peers agree"``

        Returns ``None`` if the artifacts are unavailable or any step raises.
    """
    clusterer, recommender, scaler, ref_df, features = load_recommendation_artifacts()
    if clusterer is None:
        return None
    try:
        # 1. Select and scale the input features.
        X_scaled = scaler.transform(patient_input_df[features])
        # 2. Clustering: find the patient's clinical group.
        cluster_id = int(clusterer.predict(X_scaled)[0])
        # 3. Similarity: find the closest clinical peers in the reference set.
        #    The number of peers is whatever the fitted recommender returns.
        _, indices = recommender.kneighbors(X_scaled)
        peers = ref_df.iloc[indices[0]]
        # 4. Consensus: take the most common value among the peers. With ties,
        #    Series.mode() returns the tied values sorted, and [0] picks the
        #    first of them.
        reco_focus = peers['PrimaryFocus'].mode()[0]
        reco_type = peers['CarePlanType'].mode()[0]
        reco_actions = peers['RecommendedActions'].mode()[0]
        # Measure agreement against the number of peers actually returned,
        # rather than assuming the recommender was fitted with 10 neighbours.
        consensus_count = int((peers['PrimaryFocus'] == reco_focus).sum())
        n_peers = len(peers)
        return {
            "cluster_id": cluster_id,
            "recommended_focus": reco_focus,
            "care_plan_type": reco_type,
            "actions": reco_actions,
            "confidence_score": f"{consensus_count}/{n_peers} peers agree",
        }
    except Exception as e:
        print(f"Error during recommendation: {e}")
        return None
# --- MRI Model Utilities (Load once on startup for speed) ---
def _load_model_safe(filename):
    """Load a Keras model from MODELS_DIR without raising.

    Returns the loaded model on success. Failures are returned as strings
    instead of being raised:

    - ``"MISSING: <filename>"`` when the file does not exist.
    - ``"LOAD_ERROR: <message>"`` when loading raises.
    """
    model_path = os.path.join(MODELS_DIR, filename)
    if os.path.exists(model_path):
        try:
            return tf.keras.models.load_model(model_path)
        except Exception as err:
            return f"LOAD_ERROR: {str(err)}"
    return f"MISSING: {filename}"
# Global model cache, populated once when this module is imported. Each value
# is either a loaded Keras model or an error string from _load_model_safe
# ("MISSING: ..." or "LOAD_ERROR: ...").
MRI_MODELS = {
    key: _load_model_safe(model_file)
    for key, model_file in (
        ('cnn', "alzheimer_cnn.keras"),
        ('resnet', "alzheimer_resnet.keras"),
    )
}
def predict_mri(image):
    """Classify an MRI image with the ResNet50 model, cross-checked by the CNN.

    Args:
        image: a PIL image. It is converted to RGB and resized to
            DL_IMG_SIZE x DL_IMG_SIZE before inference.

    Returns:
        ``(label, confidence, details)`` on success:

        - ``label``: the French class label predicted by the ResNet model.
        - ``confidence``: the ResNet output value for that class, as a float.
        - ``details``: a dict with both models' labels and rounded scores,
          plus an agreement status.

        On any failure (model missing, model failed to load, or inference
        error), returns ``(None, None, {"error": <message>})``.
    """
    cnn_model = MRI_MODELS.get('cnn')
    resnet_model = MRI_MODELS.get('resnet')
    # A string in the cache is the error recorded by _load_model_safe.
    if isinstance(cnn_model, str):
        return None, None, {"error": f"CNN Model Error: {cnn_model}"}
    if isinstance(resnet_model, str):
        return None, None, {"error": f"ResNet Model Error: {resnet_model}"}
    if cnn_model is None or resnet_model is None:
        return None, None, {"error": "Models not initialized"}
    try:
        img_resized = image.convert('RGB').resize((DL_IMG_SIZE, DL_IMG_SIZE))
        # Scale pixel values by 1/255 and add a leading batch axis of size 1.
        img_dl_input = np.expand_dims(np.array(img_resized) / 255.0, axis=0)

        cnn_probs = cnn_model.predict(img_dl_input, verbose=0)[0]
        cnn_pred_idx = int(np.argmax(cnn_probs))
        resnet_probs = resnet_model.predict(img_dl_input, verbose=0)[0]
        resnet_pred_idx = int(np.argmax(resnet_probs))

        # The primary diagnosis comes from ResNet50; the custom CNN is a
        # second opinion that is reported in the details.
        confidence = float(resnet_probs[resnet_pred_idx])
        details = {
            "Primary (ResNet50)": MRI_CLASSES_FR[resnet_pred_idx],
            "Primary Confidence": round(confidence, 4),
            "Secondary (Custom CNN)": MRI_CLASSES_FR[cnn_pred_idx],
            "Secondary Confidence": round(float(cnn_probs[cnn_pred_idx]), 4),
            "Status": "Agreement" if resnet_pred_idx == cnn_pred_idx else "Minority Disagreement",
        }
        return MRI_CLASSES_FR[resnet_pred_idx], confidence, details
    except Exception as e:
        print(f"Error during MRI prediction: {e}")
        # Report inference failures in the same shape as the model errors
        # above, so callers can always read details["error"] when label is None.
        return None, None, {"error": f"MRI Prediction Error: {e}"}