|
|
import gradio as gr |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torchvision import models, transforms |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Diagnostic labels, listed in the order of the classifier's output units:
# the model's i-th output is reported under CLASS_NAMES[i], and the classifier
# head is sized with len(CLASS_NAMES). Presumably this order must match the
# label order used when the weights file was trained -- confirm before
# reordering or adding entries.
CLASS_NAMES = [
    "Normal",
    # Myocardial infarction: acute / old
    "Infarto Agudo",
    "Infarto Antiguo",
    # Atrial fibrillation
    "Fibrilación Auricular",
    # Bundle branch block: left / right
    "Bloqueo de Rama Izquierda",
    "Bloqueo de Rama Derecha",
    # Premature (extrasystolic) beats: atrial / ventricular
    "Extrasístole Auricular",
    "Extrasístole Ventricular",
]

# Prefer the GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
|
|
|
|
|
def load_model(model_path: str = "densenet_model.pth"):
    """Load the trained DenseNet201 ECG classifier.

    Builds a DenseNet201 without pretrained weights and replaces its classifier
    head with a linear layer that has one output per entry in ``CLASS_NAMES``.
    It then loads the saved weights from *model_path* onto ``device``.

    Args:
        model_path: Path to a ``state_dict`` file written by ``torch.save``.

    Returns:
        The model, moved to ``device`` and switched to evaluation mode, or
        ``None`` if *model_path* does not exist. An error message is printed
        in that case.
    """
    model = models.densenet201(weights=None)

    # Replace the stock classifier head with one sized to our label set.
    num_ftrs = model.classifier.in_features
    model.classifier = nn.Linear(num_ftrs, len(CLASS_NAMES))

    try:
        # weights_only=True limits unpickling to tensors and plain containers,
        # so a tampered .pth file cannot run arbitrary code while loading.
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
    except FileNotFoundError:
        print(f"Error: El archivo del modelo '{model_path}' no se encontró. ¡La aplicación no funcionará correctamente sin él!")
        return None

    model.load_state_dict(state_dict)
    model.to(device)
    # Inference mode for layers such as dropout / batch-norm.
    model.eval()
    print(f"Modelo cargado con éxito desde {model_path} en {device}.")
    return model
|
|
|
|
|
|
|
|
# Load the classifier once, at import time, so every request reuses the same
# instance. This is None when the weights file could not be found.
ecg_model = load_model(model_path="densenet_model.pth")

# Image pipeline applied to each uploaded ECG before it reaches the model.
preprocess = transforms.Compose(
    [
        # Resize to a fixed 224x224 input.
        transforms.Resize((224, 224)),
        # PIL image -> float tensor.
        transforms.ToTensor(),
        # Per-channel standardization. These are the widely used ImageNet
        # statistics; presumably the same values were used in training --
        # confirm.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
|
|
|
def predict(image: Image.Image):
    """Classify a single ECG image.

    Args:
        image: The ECG image to classify. ``None`` is accepted and produces an
            error label instead of raising.

    Returns:
        A ``{class_name: probability}`` dict covering every entry of
        ``CLASS_NAMES``, with probabilities taken from a softmax over the
        model's outputs. If the model or the image is missing, returns a
        single-entry dict whose key is the error message.
    """
    if ecg_model is None:
        return {"ERROR: Modelo no cargado (Falta .pth)": 1.0}
    if image is None:
        return {"ERROR: No se recibió ninguna imagen": 1.0}

    # The 3-value Normalize step and the network both expect exactly three
    # channels. Images may arrive as RGBA or grayscale (e.g. when predict is
    # called directly), so force RGB first.
    img_tensor = preprocess(image.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        output = ecg_model(img_tensor)

    # unsqueeze(0) above added a batch dimension of size 1; take that row.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)

    # tolist() does a single device->host transfer instead of one per class.
    return dict(zip(CLASS_NAMES, probabilities.tolist()))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Build the UI text from CLASS_NAMES so the title/description always agree with
# the label set the model actually predicts ("A, B, ... y Z").
_class_list_text = ", ".join(CLASS_NAMES[:-1]) + " y " + CLASS_NAMES[-1]

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Sube una imagen de ECG"),
    # Display the five highest-probability diagnoses.
    outputs=gr.Label(num_top_classes=5, label="Clasificación del Modelo"),
    title=f"Análisis de ECG con IA: {len(CLASS_NAMES)} Clases de Diagnóstico",
    description=(
        "Sube una imagen de tu electrocardiograma. "
        f"El modelo clasifica en {len(CLASS_NAMES)} condiciones cardíacas: "
        f"{_class_list_text}."
    ),
    # NOTE(review): per the Gradio docs, "auto" flags every submission, which
    # would persist each uploaded ECG image to disk -- confirm this is
    # intended for medical data. Newer Gradio releases rename this parameter
    # to `flagging_mode`.
    allow_flagging="auto",
)


if __name__ == "__main__":
    # NOTE(review): share=True asks Gradio for a public share link; confirm
    # exposing this app publicly is intended.
    demo.launch(share=True)