import gradio as gr import torch import torch.nn as nn from torchvision import models, transforms from PIL import Image # =============================================== # 1. CONFIGURACIÓN Y CARGA DEL MODELO # =============================================== # Definición de las 8 clases de salida # NOTA: El modelo DENSENET201 DEBE haber sido entrenado con estas 8 clases. CLASS_NAMES = [ "Normal", "Infarto Agudo", "Infarto Antiguo", "Fibrilación Auricular", "Bloqueo de Rama Izquierda", "Bloqueo de Rama Derecha", "Extrasístole Auricular", "Extrasístole Ventricular" ] # Configuración del dispositivo (GPU si está disponible, sino CPU) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def load_model(model_path: str = "densenet_model.pth"): """Carga los pesos del modelo DenseNet201 entrenado.""" # 1. Definir la arquitectura base (DenseNet201) # Se usa weights=None porque los pesos serán cargados de nuestro .pth model = models.densenet201(weights=None) # 2. Reemplazar la capa final para que coincida con el número de clases (8) num_ftrs = model.classifier.in_features model.classifier = nn.Linear(num_ftrs, len(CLASS_NAMES)) # 3. Cargar los pesos entrenados try: model.load_state_dict(torch.load(model_path, map_location=device)) model.to(device) model.eval() # Poner el modelo en modo de evaluación print(f"Modelo cargado con éxito desde {model_path} en {device}.") return model except FileNotFoundError: print(f"Error: El archivo del modelo '{model_path}' no se encontró. ¡La aplicación no funcionará correctamente sin él!") return None # Cargar el modelo globalmente (¡solo una vez!) ecg_model = load_model() # =============================================== # 2. FUNCIÓN DE PREDICCIÓN # =============================================== # Definición de las transformaciones necesarias para el pre-procesamiento preprocess = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), # Valores de normalización estándar para modelos pre-entrenados transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) def predict(image: Image.Image): """Realiza la inferencia en la imagen del ECG.""" if ecg_model is None: # Devuelve un mensaje de error legible en la UI de Gradio return {"ERROR: Modelo no cargado (Falta .pth)": 1.0} # 1. Preprocesar la imagen img_tensor = preprocess(image).unsqueeze(0).to(device) # 2. Inferir sin calcular gradientes (más rápido) with torch.no_grad(): output = ecg_model(img_tensor) # 3. Post-procesar (Convertir a probabilidades) probabilities = torch.nn.functional.softmax(output[0], dim=0) # Crear el diccionario de resultados para Gradio results = { CLASS_NAMES[i]: float(probabilities[i]) for i in range(len(CLASS_NAMES)) } return results # =============================================== # 3. INTERFAZ GRADIO # =============================================== # Asignamos la interfaz a la variable 'demo' demo = gr.Interface( fn=predict, inputs=gr.Image(type="pil", label="Sube una imagen de ECG"), # Muestra las 5 clases con mayor probabilidad outputs=gr.Label(num_top_classes=5, label="Clasificación del Modelo"), title="Análisis de ECG con IA: 8 Clases de Diagnóstico", description="Sube una imagen de tu electrocardiograma. El modelo clasifica en 8 condiciones cardíacas: Normal, Infarto Agudo, Infarto Antiguo, Fibrilación Auricular, Bloqueo de Rama Izquierda, Bloqueo de Rama Derecha, Extrasístole Auricular y Extrasístole Ventricular.", allow_flagging="auto" ) # =============================================== # 4. LANZAMIENTO # =============================================== if __name__ == "__main__": # Usa demo.launch() para iniciar la aplicación web # share=True proporciona un enlace público temporal demo.launch(share=True)