maxxxi100 commited on
Commit
b8ae86c
·
verified ·
1 Parent(s): e743899

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -46
app.py CHANGED
@@ -1,72 +1,104 @@
1
-
2
-
3
  import torch
4
  from monai.networks.nets import DenseNet121
5
-
6
  import gradio as gr
7
-
8
- #from PIL import Image
9
-
10
- model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=6)
11
- model.load_state_dict(torch.load('weights/mednist_model.pth', map_location=torch.device('cpu')))
12
-
13
  from monai.transforms import (
14
  EnsureChannelFirst,
15
  Compose,
16
  LoadImage,
17
  ScaleIntensity,
18
- # NUEVA IMPORTACIÓN
19
  Resize,
20
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- # 2. DEFINICIÓN DE TRANSFORMACIONES MEJORADAS
23
- # Redimensionamos explícitamente a 64x64, un tamaño común para MedNIST,
24
- # y nos aseguramos de que el cargador maneje la conversión a escala de grises ('L').
25
  test_transforms = Compose(
26
  [
27
- LoadImage(image_only=True, convert_to_tensor=True), # Asegura que se carga como tensor
28
- EnsureChannelFirst(), # Asegura la forma [C, H, W]
29
- Resize(spatial_size=(64, 64)), # REDIMENSIONAMIENTO CLAVE
30
- ScaleIntensity(), # Normaliza
 
 
 
31
  ]
32
  )
33
 
34
-
35
- class_names = [
36
- 'AbdomenCT', 'BreastMRI', 'CXR', 'ChestCT', 'Hand', 'HeadCT'
37
- ]
38
-
39
- import os, glob
40
-
41
- #examples_dir = './samples'
42
- #example_files = glob.glob(os.path.join(examples_dir, '*.jpg'))
43
 
44
  def classify_image(image_filepath):
45
- input = test_transforms(image_filepath)
46
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  model.eval()
48
  with torch.no_grad():
49
- pred = model(input.unsqueeze(dim=0))
50
-
51
- prob = torch.nn.functional.softmax(pred[0], dim=0)
52
-
53
- confidences = {class_names[i]: float(prob[i]) for i in range(6)}
54
- print(confidences)
55
-
56
- return confidences
57
-
58
-
59
- with gr.Blocks(title="Medical Image Classification with MONAI - ClassCat",
60
- css=".gradio-container {background:mintcream;}"
61
- ) as demo:
 
 
62
  gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;">Medical Image Classification with MONAI</div>""")
63
-
64
- with gr.Row():
65
- input_image = gr.Image(type="filepath", image_mode="L")
 
66
  output_label=gr.Label(label="Probabilities", num_top_classes=3)
67
-
68
  send_btn = gr.Button("Infer")
69
  send_btn.click(fn=classify_image, inputs=input_image, outputs=output_label)
 
 
70
 
71
  with gr.Row():
72
  gr.Examples(['./samples/mednist_AbdomenCT00.png'], label='Sample images : AbdomenCT', inputs=input_image)
@@ -75,5 +107,5 @@ with gr.Blocks(title="Medical Image Classification with MONAI - ClassCat",
75
  gr.Examples(['./samples/mednist_Hand01.png'], label='Hand', inputs=input_image)
76
  gr.Examples(['./samples/mednist_HeadCT07.png'], label='HeadCT', inputs=input_image)
77
 
78
- #demo.queue(concurrency_count=3)
79
  demo.launch(debug=True)
 
 
 
 
1
  import torch
2
  from monai.networks.nets import DenseNet121
 
3
  import gradio as gr
 
 
 
 
 
 
4
  from monai.transforms import (
5
  EnsureChannelFirst,
6
  Compose,
7
  LoadImage,
8
  ScaleIntensity,
 
9
  Resize,
10
  )
11
+ import os, glob
12
+
13
+ # =================================================================
14
+ # 1. CARGA Y MODELO (Sin cambios necesarios)
15
+ # =================================================================
16
+ model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=6)
17
+ # **Importante:** Asegúrate de que el archivo 'weights/mednist_model.pth' exista y sea accesible.
18
+ try:
19
+ model.load_state_dict(torch.load('weights/mednist_model.pth', map_location=torch.device('cpu')))
20
+ except FileNotFoundError:
21
+ print("ERROR: No se encontró el archivo de pesos 'mednist_model.pth'. La aplicación fallará.")
22
+ # Puedes añadir un placeholder o salir si el archivo no existe.
23
+ pass
24
+
25
+ class_names = [
26
+ 'AbdomenCT', 'BreastMRI', 'CXR', 'ChestCT', 'Hand', 'HeadCT']
27
+
28
+ # =================================================================
29
+ # 2. TRANSFORMACIONES (Añadido 'image_only=True' para robustez)
30
+ # =================================================================
31
+ # La resolución de 64x64 es un supuesto común para MedNIST, si el entrenamiento
32
+ # usó otro tamaño, ajusta 'spatial_size'.
33
 
 
 
 
34
  test_transforms = Compose(
35
  [
36
+ # LoadImage ahora usa 'image_only=True' para devolver un tensor simple
37
+ # y no un diccionario, simplificando la tubería.
38
+ # Además, añadiremos 'convert_to_tensor=True' y la gestión de canales.
39
+ LoadImage(image_only=True),
40
+ EnsureChannelFirst(),
41
+ Resize(spatial_size=(64, 64)),
42
+ ScaleIntensity(), # Normaliza al rango [0, 1]
43
  ]
44
  )
45
 
46
+ # =================================================================
47
+ # 3. FUNCIÓN DE CLASIFICACIÓN (Manejo de errores mejorado)
48
+ # =================================================================
 
 
 
 
 
 
49
 
50
  def classify_image(image_filepath):
51
+ # Manejamos explícitamente el caso donde no hay imagen
52
+ if image_filepath is None:
53
+ return {'Error': 1.0}
54
+
55
+ try:
56
+ # Aplicar las transformaciones.
57
+ # MONAI cargará y preprocesará la imagen.
58
+ input_tensor = test_transforms(image_filepath)
59
+ except Exception as e:
60
+ print(f"Error durante el preprocesamiento de la imagen: {e}")
61
+ # Devolvemos un error explícito a Gradio
62
+ return {"Error en preprocesamiento": 1.0}
63
+
64
+ # Aseguramos que la forma del tensor de entrada sea [1, 1, 64, 64] para el modelo 2D
65
+ # Si la imagen tiene 3 canales y LoadImage falla la conversión, esto fallará aquí o antes.
66
+ if input_tensor.ndim == 3 and input_tensor.shape[0] == 1:
67
+ input_tensor = input_tensor.unsqueeze(dim=0) # Añadir dimensión de lote: [1, C, H, W]
68
+ else:
69
+ print(f"Forma inesperada del tensor después de EnsureChannelFirst: {input_tensor.shape}")
70
+ return {"Error de forma (Canales/Dimensiones)": 1.0}
71
+
72
+
73
+ # Inferencia
74
  model.eval()
75
  with torch.no_grad():
76
+ try:
77
+ pred = model(input_tensor)
78
+ prob = torch.nn.functional.softmax(pred[0], dim=0)
79
+ confidences = {class_names[i]: float(prob[i]) for i in range(6)}
80
+ print(confidences)
81
+ return confidences
82
+ except Exception as e:
83
+ print(f"Error durante la inferencia del modelo: {e}")
84
+ return {"Error de inferencia (El modelo falló)": 1.0}
85
+
86
+ # =================================================================
87
+ # 4. INTERFAZ DE USUARIO (Ajuste de Gradio)
88
+ # =================================================================
89
+
90
+ with gr.Blocks(title="Medical Image Classification with MONAI - ClassCat", css=".gradio-container {background:mintcream;}" ) as demo:
91
  gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;">Medical Image Classification with MONAI</div>""")
92
+ with gr.Row():
93
+ # **AJUSTE CRUCIAL:** 'image_mode="L"' pide a Gradio que convierta
94
+ # la imagen a escala de grises al subirla, previniendo errores de 3 canales.
95
+ input_image = gr.Image(type="filepath", image_mode="L")
96
  output_label=gr.Label(label="Probabilities", num_top_classes=3)
97
+
98
  send_btn = gr.Button("Infer")
99
  send_btn.click(fn=classify_image, inputs=input_image, outputs=output_label)
100
+
101
+ # ... (Sección de ejemplos sin cambios) ...
102
 
103
  with gr.Row():
104
  gr.Examples(['./samples/mednist_AbdomenCT00.png'], label='Sample images : AbdomenCT', inputs=input_image)
 
107
  gr.Examples(['./samples/mednist_Hand01.png'], label='Hand', inputs=input_image)
108
  gr.Examples(['./samples/mednist_HeadCT07.png'], label='HeadCT', inputs=input_image)
109
 
 
110
  demo.launch(debug=True)
111
+