PERDIX / ADNI_Fusion_Engine.py
erick6655
Replace fallback dummies with real ViT ML model from Hugging Face Research
22055af
try:
import json
import os
import torch
from PIL import Image
from transformers import pipeline
except ImportError:
pass
class MRINotFoundError(FileNotFoundError, AssertionError):
    """Raised when the MRI image handed to the engine does not exist.

    Also subclasses ``AssertionError`` so callers written against the former
    ``assert``-based check keep working, while the check itself is no longer
    stripped when Python runs with ``-O``.
    """


class PerdixAdniEngine:
    """Alzheimer's-stage classifier for brain MRI images.

    Wraps a pre-trained Hugging Face image-classification pipeline and folds
    its stage labels into three diagnostic classes (AD / MCI / CN), nudged by
    the patient's MMSE score.
    """

    # Hugging Face Hub id of the pre-trained classifier.
    MODEL_ID = "prithivMLmods/Alzheimer-Stage-Classifier"

    # Output class names in tie-break order: if two classes share the top
    # probability, the earliest one in this tuple wins.
    _CLASS_NAMES = (
        "Deterioro Cognitivo Leve (MCI)",
        "Alzheimer Agudo (AD)",
        "Cognitivamente Sano (CN)",
    )

    # Manual late-fusion prior: MMSE below _MMSE_LOW adds _MMSE_BOOST of raw
    # mass to AD; MMSE above _MMSE_HIGH adds it to CN (before renormalizing).
    _MMSE_LOW = 20
    _MMSE_HIGH = 28
    _MMSE_BOOST = 0.15

    def __init__(self):
        print("[Perdix] Descargando Inteligencia Artificial desde Research Hub...")
        # The pre-trained model is downloaded the first time it is requested.
        self.classifier = pipeline("image-classification", model=self.MODEL_ID)

    def diagnose_patient(self, image_path: str, age: float, mmse_score: float, edu_years: float) -> dict:
        """Classify a brain MRI image (JPG/PNG) as AD, MCI or CN.

        Args:
            image_path: Path to the MRI image file.
            age: Patient age. Currently unused by the fusion logic.
            mmse_score: Mini-Mental State Examination score; below 20 boosts
                AD, above 28 boosts CN.
            edu_years: Years of education. Currently unused.

        Returns:
            On success ``{"status": "success", "prediction_class": str,
            "confidence_score": float, "breakdown": {"AD", "MCI", "CN"}}``
            with percentages rounded to two decimals. If the image cannot be
            read/classified, or no model label is recognised,
            ``{"status": "error", "message": str}``.

        Raises:
            MRINotFoundError: ``image_path`` does not exist (this is both a
                ``FileNotFoundError`` and an ``AssertionError``).
        """
        if not os.path.exists(image_path):
            raise MRINotFoundError(f"CRÍTICO: No se puede localizar MRI en {image_path}")

        print("\n--- [Perdix Engine] Analizando Imagen ---")
        try:
            # The with-block closes the underlying file; convert() returns an
            # independent in-memory image that stays valid afterwards.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert("RGB")
            results = self.classifier(image)
        except Exception as e:
            # Boundary catch: any I/O or model failure is reported to the
            # caller as an error payload rather than propagated.
            return {"status": "error", "message": f"Falló leer la imagen con la IA: {e}"}

        print(f"Predicciones Crudas del Modelo Transformer: {results}")
        return self._fuse_scores(results, mmse_score)

    @staticmethod
    def _stage_scores(results) -> tuple:
        """Sum classifier scores into raw ``(AD, MCI, CN)`` masses.

        Reads each entry of ``results`` as a dict with ``"label"`` (str) and
        ``"score"`` (number). Labels are matched case-insensitively:
        contains "non" -> CN; contains "very mild", or "mild" without
        "moderate" -> MCI; contains "moderate" or "severe" -> AD; any other
        label containing "demented" -> MCI; everything else is ignored.
        Presumably the model emits labels such as 'Non Demented',
        'Very Mild Demented', 'Mild Demented', 'Moderate Demented' — confirm.
        """
        prob_ad = prob_mci = prob_cn = 0.0
        for res in results:
            label = res["label"].lower()
            score = res["score"]
            if "non" in label:
                prob_cn += score
            elif "very mild" in label or ("mild" in label and "moderate" not in label):
                prob_mci += score
            elif "moderate" in label or "severe" in label:
                prob_ad += score
            elif "demented" in label:
                # Unrecognised dementia label: counted as MCI.
                # NOTE(review): confirm MCI (not AD) is the intended bucket.
                prob_mci += score
        return prob_ad, prob_mci, prob_cn

    @classmethod
    def _fuse_scores(cls, results, mmse_score: float) -> dict:
        """Combine raw classifier output with the MMSE prior into a report."""
        prob_ad, prob_mci, prob_cn = cls._stage_scores(results)

        # Manual late fusion: only a clearly low/high MMSE shifts the result.
        if mmse_score < cls._MMSE_LOW:
            prob_ad += cls._MMSE_BOOST
        elif mmse_score > cls._MMSE_HIGH:
            prob_cn += cls._MMSE_BOOST

        total = prob_ad + prob_mci + prob_cn
        if total <= 0:
            # No recognised label and no MMSE prior: there is nothing to
            # normalise, so refuse to emit a 0%-confidence diagnosis.
            return {
                "status": "error",
                "message": "El modelo no devolvió ninguna etiqueta reconocible; no es posible emitir un diagnóstico.",
            }
        prob_ad /= total
        prob_mci /= total
        prob_cn /= total

        # Order must match _CLASS_NAMES; list.index picks the first maximum.
        probs = [prob_mci, prob_ad, prob_cn]
        best = max(probs)
        winner_index = probs.index(best)
        return {
            "status": "success",
            "prediction_class": cls._CLASS_NAMES[winner_index],
            "confidence_score": round(best * 100, 2),
            "breakdown": {
                "AD": round(prob_ad * 100, 2),
                "MCI": round(prob_mci * 100, 2),
                "CN": round(prob_cn * 100, 2),
            },
        }