maxxxi100 commited on
Commit
4885409
verified
1 Parent(s): ce3bd1e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -0
app.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import models, transforms
5
+ from PIL import Image
6
+
7
+ # ===============================================
8
+ # 1. CONFIGURACI脫N Y CARGA DEL MODELO
9
+ # ===============================================
10
+
11
+ # Definici贸n de las 8 clases de salida
12
+ # NOTA: El modelo DENSENET201 DEBE haber sido entrenado con estas 8 clases.
13
+ CLASS_NAMES = [
14
+ "Normal",
15
+ "Infarto Agudo",
16
+ "Infarto Antiguo",
17
+ "Fibrilaci贸n Auricular",
18
+ "Bloqueo de Rama Izquierda",
19
+ "Bloqueo de Rama Derecha",
20
+ "Extras铆stole Auricular",
21
+ "Extras铆stole Ventricular"
22
+ ]
23
+
24
+ # Configuraci贸n del dispositivo (GPU si est谩 disponible, sino CPU)
25
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+
27
+ def load_model(model_path: str = "densenet_model.pth"):
28
+ """Carga los pesos del modelo DenseNet201 entrenado."""
29
+
30
+ # 1. Definir la arquitectura base (DenseNet201)
31
+ # Se usa weights=None porque los pesos ser谩n cargados de nuestro .pth
32
+ model = models.densenet201(weights=None)
33
+
34
+ # 2. Reemplazar la capa final para que coincida con el n煤mero de clases (8)
35
+ num_ftrs = model.classifier.in_features
36
+ model.classifier = nn.Linear(num_ftrs, len(CLASS_NAMES))
37
+
38
+ # 3. Cargar los pesos entrenados
39
+ try:
40
+ model.load_state_dict(torch.load(model_path, map_location=device))
41
+ model.to(device)
42
+ model.eval() # Poner el modelo en modo de evaluaci贸n
43
+ print(f"Modelo cargado con 茅xito desde {model_path} en {device}.")
44
+ return model
45
+ except FileNotFoundError:
46
+ print(f"Error: El archivo del modelo '{model_path}' no se encontr贸. 隆La aplicaci贸n no funcionar谩 correctamente sin 茅l!")
47
+ return None
48
+
49
+ # Cargar el modelo globalmente (隆solo una vez!)
50
+ ecg_model = load_model()
51
+
52
+ # ===============================================
53
+ # 2. FUNCI脫N DE PREDICCI脫N
54
+ # ===============================================
55
+
56
+ # Definici贸n de las transformaciones necesarias para el pre-procesamiento
57
+ preprocess = transforms.Compose([
58
+ transforms.Resize((224, 224)),
59
+ transforms.ToTensor(),
60
+ # Valores de normalizaci贸n est谩ndar para modelos pre-entrenados
61
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
62
+ ])
63
+
64
+ def predict(image: Image.Image):
65
+ """Realiza la inferencia en la imagen del ECG."""
66
+ if ecg_model is None:
67
+ # Devuelve un mensaje de error legible en la UI de Gradio
68
+ return {"ERROR: Modelo no cargado (Falta .pth)": 1.0}
69
+
70
+ # 1. Preprocesar la imagen
71
+ img_tensor = preprocess(image).unsqueeze(0).to(device)
72
+
73
+ # 2. Inferir sin calcular gradientes (m谩s r谩pido)
74
+ with torch.no_grad():
75
+ output = ecg_model(img_tensor)
76
+
77
+ # 3. Post-procesar (Convertir a probabilidades)
78
+ probabilities = torch.nn.functional.softmax(output[0], dim=0)
79
+
80
+ # Crear el diccionario de resultados para Gradio
81
+ results = {
82
+ CLASS_NAMES[i]: float(probabilities[i])
83
+ for i in range(len(CLASS_NAMES))
84
+ }
85
+ return results
86
+
87
+ # ===============================================
88
+ # 3. INTERFAZ GRADIO
89
+ # ===============================================
90
+
91
+ # Asignamos la interfaz a la variable 'demo'
92
+ demo = gr.Interface(
93
+ fn=predict,
94
+ inputs=gr.Image(type="pil", label="Sube una imagen de ECG"),
95
+ # Muestra las 5 clases con mayor probabilidad
96
+ outputs=gr.Label(num_top_classes=5, label="Clasificaci贸n del Modelo"),
97
+ title="An谩lisis de ECG con IA: 8 Clases de Diagn贸stico",
98
+ description="Sube una imagen de tu electrocardiograma. El modelo clasifica en 8 condiciones card铆acas: Normal, Infarto Agudo, Infarto Antiguo, Fibrilaci贸n Auricular, Bloqueo de Rama Izquierda, Bloqueo de Rama Derecha, Extras铆stole Auricular y Extras铆stole Ventricular.",
99
+ allow_flagging="auto"
100
+ )
101
+
102
+ # ===============================================
103
+ # 4. LANZAMIENTO
104
+ # ===============================================
105
+
106
+ if __name__ == "__main__":
107
+ # Usa demo.launch() para iniciar la aplicaci贸n web
108
+ # share=True proporciona un enlace p煤blico temporal
109
+ demo.launch(share=True)