"""Perdix ADNI engine: dementia-stage classification of brain MRI images.

Wraps a pre-trained Hugging Face ``image-classification`` pipeline, folds its
dementia-stage labels into three diagnostic classes (CN / MCI / AD) and nudges
the result with a simple MMSE-based adjustment.
"""

import json  # noqa: F401  -- not referenced in the code shown here
import os

try:
    # Third-party runtime dependencies. ``torch`` is not referenced directly;
    # presumably it is imported to guarantee a PyTorch backend for
    # ``transformers`` -- NOTE(review): confirm.
    import torch  # noqa: F401
    from PIL import Image
    from transformers import pipeline
except ImportError as _exc:
    # Keep the module importable without the ML stack, but remember why the
    # import failed so the engine can report it instead of a bare NameError.
    torch = Image = pipeline = None
    _IMPORT_ERROR = _exc
else:
    _IMPORT_ERROR = None


class PerdixAdniEngine:
    """Classify a brain MRI image (JPG/PNG) as CN, MCI or AD.

    The pre-trained model is downloaded by ``transformers`` the first time the
    engine is constructed.
    """

    DEFAULT_MODEL = "prithivMLmods/Alzheimer-Stage-Classifier"

    # MMSE "late fusion" (manual, simulated): a clear-cut MMSE adds a fixed
    # bump to one class before the scores are renormalised.
    MMSE_AD_THRESHOLD = 20  # mmse_score strictly below this boosts AD
    MMSE_CN_THRESHOLD = 28  # mmse_score strictly above this boosts CN
    MMSE_BOOST = 0.15

    # (bucket key, display name) in tie-breaking order: first maximum wins.
    _CLASS_ORDER = (
        ("MCI", "Deterioro Cognitivo Leve (MCI)"),
        ("AD", "Alzheimer Agudo (AD)"),
        ("CN", "Cognitivamente Sano (CN)"),
    )

    def __init__(self, model_name: str = DEFAULT_MODEL):
        """Build the image-classification pipeline.

        Args:
            model_name: Hugging Face model id; defaults to the Alzheimer
                stage classifier this engine was written for.

        Raises:
            ImportError: if torch, Pillow or transformers could not be imported.
        """
        if _IMPORT_ERROR is not None:
            raise ImportError(
                "PerdixAdniEngine requires torch, Pillow and transformers"
            ) from _IMPORT_ERROR
        print("[Perdix] Descargando Inteligencia Artificial desde Research Hub...")
        # Downloads the pre-trained model the first time it is requested.
        self.classifier = pipeline("image-classification", model=model_name)

    @staticmethod
    def _stage_bucket(label: str):
        """Map a raw model label to 'CN', 'MCI', 'AD', or None (ignored).

        The model is expected to emit labels such as 'Mild Demented',
        'Moderate Demented', 'Non Demented' and 'Very Mild Demented'
        (matching is case-insensitive and substring-based):
        Non Demented -> CN, Very Mild / Mild -> MCI, Moderate -> AD.
        """
        label = label.lower()
        if "non" in label:
            return "CN"
        if "very mild" in label or ("mild" in label and "moderate" not in label):
            return "MCI"
        if "moderate" in label or "severe" in label:
            return "AD"
        if "demented" in label:
            # Unrecognised label that still indicates dementia. An earlier
            # comment said this fallback should go to AD, but the code has
            # always counted it as MCI; kept as MCI.
            # NOTE(review): confirm the intended bucket.
            return "MCI"
        return None

    def diagnose_patient(self, image_path: str, age: float, mmse_score: float,
                         edu_years: float) -> dict:
        """Classify one brain image with the pre-trained model.

        Args:
            image_path: path to a JPG/PNG brain image.
            age: patient age. Currently unused.
            mmse_score: MMSE score; below ``MMSE_AD_THRESHOLD`` adds
                ``MMSE_BOOST`` to AD, above ``MMSE_CN_THRESHOLD`` adds it to CN,
                before renormalisation.
            edu_years: years of education. Currently unused.

        Returns:
            On success::

                {"status": "success",
                 "prediction_class": <display name>,
                 "confidence_score": <winning % rounded to 2 dp>,
                 "breakdown": {"AD": %, "MCI": %, "CN": %}}

            If the image cannot be opened or classified::

                {"status": "error", "message": <description>}

        Raises:
            FileNotFoundError: if ``image_path`` does not exist.
        """
        if not os.path.exists(image_path):
            # Explicit check rather than ``assert`` so it survives ``python -O``.
            raise FileNotFoundError(
                f"CRÍTICO: No se puede localizar MRI en {image_path}"
            )

        print("\n--- [Perdix Engine] Analizando Imagen ---")
        try:
            # Context manager closes the underlying file handle.
            with Image.open(image_path) as raw:
                image = raw.convert("RGB")
            results = self.classifier(image)
        except Exception as e:
            # Boundary handler: report any load/inference failure as data.
            return {"status": "error", "message": f"Falló leer la imagen con la IA: {e}"}

        print(f"Predicciones Crudas del Modelo Transformer: {results}")

        scores = {"AD": 0.0, "MCI": 0.0, "CN": 0.0}
        for res in results:
            bucket = self._stage_bucket(res["label"])
            if bucket is not None:
                scores[bucket] += res["score"]

        if mmse_score < self.MMSE_AD_THRESHOLD:
            scores["AD"] += self.MMSE_BOOST
        elif mmse_score > self.MMSE_CN_THRESHOLD:
            scores["CN"] += self.MMSE_BOOST

        total = scores["AD"] + scores["MCI"] + scores["CN"]
        if total == 0:
            total = 1.0  # nothing recognised: avoid division by zero
        probs = {key: value / total for key, value in scores.items()}

        # max() returns the first maximal entry, so ties resolve MCI > AD > CN.
        winner_key, winner_name = max(self._CLASS_ORDER, key=lambda kv: probs[kv[0]])

        return {
            "status": "success",
            "prediction_class": winner_name,
            "confidence_score": round(probs[winner_key] * 100, 2),
            "breakdown": {
                "AD": round(probs["AD"] * 100, 2),
                "MCI": round(probs["MCI"] * 100, 2),
                "CN": round(probs["CN"] * 100, 2),
            },
        }