ECG_App / app.py
maxxxi100's picture
Create app.py
4885409 verified
import gradio as gr
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
# ===============================================
# 1. CONFIGURACIÓN Y CARGA DEL MODELO
# ===============================================
# Definición de las 8 clases de salida
# NOTA: El modelo DENSENET201 DEBE haber sido entrenado con estas 8 clases.
CLASS_NAMES = [
"Normal",
"Infarto Agudo",
"Infarto Antiguo",
"Fibrilación Auricular",
"Bloqueo de Rama Izquierda",
"Bloqueo de Rama Derecha",
"Extrasístole Auricular",
"Extrasístole Ventricular"
]
# Configuración del dispositivo (GPU si está disponible, sino CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_model(model_path: str = "densenet_model.pth"):
"""Carga los pesos del modelo DenseNet201 entrenado."""
# 1. Definir la arquitectura base (DenseNet201)
# Se usa weights=None porque los pesos serán cargados de nuestro .pth
model = models.densenet201(weights=None)
# 2. Reemplazar la capa final para que coincida con el número de clases (8)
num_ftrs = model.classifier.in_features
model.classifier = nn.Linear(num_ftrs, len(CLASS_NAMES))
# 3. Cargar los pesos entrenados
try:
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval() # Poner el modelo en modo de evaluación
print(f"Modelo cargado con éxito desde {model_path} en {device}.")
return model
except FileNotFoundError:
print(f"Error: El archivo del modelo '{model_path}' no se encontró. ¡La aplicación no funcionará correctamente sin él!")
return None
# Cargar el modelo globalmente (¡solo una vez!)
ecg_model = load_model()
# ===============================================
# 2. FUNCIÓN DE PREDICCIÓN
# ===============================================
# Definición de las transformaciones necesarias para el pre-procesamiento
preprocess = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
# Valores de normalización estándar para modelos pre-entrenados
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def predict(image: Image.Image):
"""Realiza la inferencia en la imagen del ECG."""
if ecg_model is None:
# Devuelve un mensaje de error legible en la UI de Gradio
return {"ERROR: Modelo no cargado (Falta .pth)": 1.0}
# 1. Preprocesar la imagen
img_tensor = preprocess(image).unsqueeze(0).to(device)
# 2. Inferir sin calcular gradientes (más rápido)
with torch.no_grad():
output = ecg_model(img_tensor)
# 3. Post-procesar (Convertir a probabilidades)
probabilities = torch.nn.functional.softmax(output[0], dim=0)
# Crear el diccionario de resultados para Gradio
results = {
CLASS_NAMES[i]: float(probabilities[i])
for i in range(len(CLASS_NAMES))
}
return results
# ===============================================
# 3. INTERFAZ GRADIO
# ===============================================
# Asignamos la interfaz a la variable 'demo'
demo = gr.Interface(
fn=predict,
inputs=gr.Image(type="pil", label="Sube una imagen de ECG"),
# Muestra las 5 clases con mayor probabilidad
outputs=gr.Label(num_top_classes=5, label="Clasificación del Modelo"),
title="Análisis de ECG con IA: 8 Clases de Diagnóstico",
description="Sube una imagen de tu electrocardiograma. El modelo clasifica en 8 condiciones cardíacas: Normal, Infarto Agudo, Infarto Antiguo, Fibrilación Auricular, Bloqueo de Rama Izquierda, Bloqueo de Rama Derecha, Extrasístole Auricular y Extrasístole Ventricular.",
allow_flagging="auto"
)
# ===============================================
# 4. LANZAMIENTO
# ===============================================
if __name__ == "__main__":
# Usa demo.launch() para iniciar la aplicación web
# share=True proporciona un enlace público temporal
demo.launch(share=True)