dcavadia commited on
Commit
b2e7796
·
verified ·
1 Parent(s): febc5cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -97
app.py CHANGED
@@ -1,4 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
  import json
3
  import numpy as np
4
  import gradio as gr
@@ -12,99 +11,98 @@ import os
12
  # ----------------------------
13
  # Model + metadata
14
  # ----------------------------
15
- ORT_PROVIDERS = ["CPUExecutionProvider"] # add "CUDAExecutionProvider" if available [1]
16
- ort_session = ort.InferenceSession("model_new_new_final.onnx", providers=ORT_PROVIDERS) # [1]
17
 
18
  with open("dat.json", "r", encoding="utf-8") as f:
19
- data = json.load(f) # [1]
20
- CLASSES = list(data) # ordered list of class names [1]
21
 
22
  def empty_df():
23
- # FIX: build zeros list via * len(CLASSES) instead of a starred expression in dict value [1]
24
- return pd.DataFrame({"item": CLASSES, "probability": [0] * len(CLASSES)}) # [3]
25
 
26
  # ----------------------------
27
  # Utils
28
  # ----------------------------
29
  def probabilities_to_ints(probabilities, total_sum=100):
30
- probabilities = np.array(probabilities) # [1]
31
- positive_values = np.maximum(probabilities, 0) # [1]
32
- total_positive = positive_values.sum() # [1]
33
  if total_positive == 0:
34
- return np.zeros_like(probabilities, dtype=int) # [1]
35
- scaled = positive_values / total_positive * total_sum # [1]
36
- rounded = np.round(scaled).astype(int) # [1]
37
- diff = total_sum - rounded.sum() # [1]
38
  if diff != 0:
39
- max_idx = int(np.argmax(positive_values)) # [1]
40
- rounded = rounded.flatten() # [1]
41
- rounded[max_idx] += diff # [1]
42
- rounded = rounded.reshape(scaled.shape) # [1]
43
- return rounded # [1]
44
-
45
- MEAN = [0.7611, 0.5869, 0.5923] # [1]
46
- STD = [0.1266, 0.1487, 0.1619] # [1]
47
  TFMS = transforms.Compose([
48
  transforms.Resize((100, 100)),
49
  transforms.ToTensor(),
50
  transforms.Normalize(mean=MEAN, std=STD),
51
- ]) # [1]
52
 
53
  def preprocess(pil_img: Image.Image):
54
- return TFMS(pil_img).unsqueeze(0).numpy() # [1]
55
 
56
  # ----------------------------
57
  # Inference function
58
  # ----------------------------
59
  def predict(image):
60
- # Handle clicks with no image gracefully [1]
61
  if image is None:
62
- return ("Cargue una imagen y presione Analizar.", "", "", "", "", "", empty_df(), "") # [1][3]
63
  if isinstance(image, Image.Image):
64
- pil = image.convert("RGB") # [1]
65
  else:
66
  try:
67
- pil = Image.fromarray(image).convert("RGB") # [1]
68
  except Exception:
69
- return ("Imagen inválida", "", "", "", "", "", empty_df(), "") # [1][3]
70
 
71
- t0 = time.time() # [1]
72
- input_tensor = preprocess(pil) # [1]
73
- input_name = ort_session.get_inputs().name # [1]
74
- output = ort_session.run(None, {input_name: input_tensor}) # [1]
75
 
76
- logits = output.squeeze() # [1]
77
- pred_idx = int(np.argmax(logits)) # [1]
78
- pred_name = CLASSES[pred_idx] # [1]
79
 
80
- # Softmax probabilities [1]
81
- exp = np.exp(logits - np.max(logits)) # [1]
82
- probs = exp / exp.sum() # [1]
83
- conf_text = f"{float(probs[pred_idx]) * 100:.1f}%" # [1]
84
 
85
- ints = probabilities_to_ints(probs * 100.0, total_sum=100) # [1]
86
  df = pd.DataFrame({"item": CLASSES, "probability": ints.astype(int)}).sort_values(
87
  "probability", ascending=True
88
- ) # [3]
89
 
90
- details = data[pred_name] # [1]
91
- descripcion = details.get("description", "") # [1]
92
- sintomas = details.get("symptoms", "") # [1]
93
- causas = details.get("causes", "") # [1]
94
- tratamiento = details.get("treatment-1", "") # [1]
95
 
96
- latency_ms = int((time.time() - t0) * 1000) # [1]
97
- return (pred_name, conf_text, descripcion, sintomas, causas, tratamiento, df, f"{latency_ms} ms") # [3][1]
98
 
99
  # ----------------------------
100
  # Theme (compatible across Gradio versions)
101
  # ----------------------------
102
  try:
103
- theme = gr.themes.Soft(primary_hue="rose", secondary_hue="slate") # [4]
104
  except Exception:
105
- theme = None # fallback to default theme [4]
106
 
107
- # CSS polish; tint bars via CSS for Gradio 4.27 [4]
108
  CUSTOM_CSS = """
109
  .header {display:flex; align-items:center; gap:12px;}
110
  .badge {font-size:12px; padding:4px 8px; border-radius:12px; background:#f1f5f9; color:#334155;}
@@ -113,7 +111,7 @@ CUSTOM_CSS = """
113
  button, .gradio-container .gr-box, .gradio-container .gr-panel { border-radius: 10px !important; }
114
  /* Uniform bar color in Vega-Lite (Gradio 4.27) */
115
  .vega-embed .mark-rect, .vega-embed .mark-bar, .vega-embed .role-mark rect { fill: #ef4444 !important; }
116
- """ # [4][3]
117
 
118
  # ----------------------------
119
  # UI
@@ -131,39 +129,39 @@ with gr.Blocks(theme=theme, css=CUSTOM_CSS) as demo:
131
  Sube una imagen dermatoscópica para ver la clase predicha, la confianza y la distribución de probabilidades.
132
  </p>
133
  """
134
- ) # [4]
135
  with gr.Column(scale=1, min_width=140):
136
  try:
137
- dark_toggle = gr.ThemeMode(label="Modo", value="system") # optional UI affordance [4]
138
  except Exception:
139
- gr.Markdown("") # [4]
140
 
141
  with gr.Row(equal_height=True):
142
  # Left column: input + actions
143
  with gr.Column(scale=5):
144
- image = gr.Image(type="numpy", label="Imagen de la lesión", height=420, sources=["upload", "clipboard"]) # [1]
145
  with gr.Row():
146
- analyze_btn = gr.Button("Analizar", variant="primary") # Always enabled [1]
147
- clear_btn = gr.Button("Limpiar") # [2]
148
  example_paths = [
149
  "examples/ak.jpg",
150
  "examples/bcc.jpg",
151
  "examples/df.jpg",
152
  "examples/melanoma.jpg",
153
  "examples/nevus.jpg",
154
- ] # [1]
155
- example_paths = [p for p in example_paths if os.path.exists(p)] # [1]
156
  if example_paths:
157
- gr.Examples(examples=example_paths, inputs=image, label="Ejemplos rápidos") # [1]
158
- latency = gr.Label(label="Latencia aproximada") # [1]
159
 
160
  # Right column: results
161
  with gr.Column(scale=5):
162
  with gr.Group():
163
  with gr.Row():
164
- nombre = gr.Label(label="Predicción principal", elem_classes=["pred-card"]) # [1]
165
- confianza = gr.Label(label="Confianza") # [1]
166
- # Default BarPlot; CSS applies color [3]
167
  prob_plot = gr.BarPlot(
168
  value=empty_df(),
169
  x="item",
@@ -175,54 +173,39 @@ with gr.Blocks(theme=theme, css=CUSTOM_CSS) as demo:
175
  tooltip=["item", "probability"],
176
  width=520,
177
  height=320,
178
- ) # [3]
179
  with gr.Tabs():
180
  with gr.TabItem("Detalles"):
181
  with gr.Accordion("Descripción", open=True):
182
- descripcion = gr.Textbox(lines=4, interactive=False) # [1]
183
  with gr.Accordion("Síntomas", open=False):
184
- sintomas = gr.Textbox(lines=4, interactive=False) # [1]
185
  with gr.Accordion("Causas", open=False):
186
- causas = gr.Textbox(lines=4, interactive=False) # [1]
187
  with gr.Accordion("Tratamiento", open=False):
188
- tratamiento = gr.Textbox(lines=4, interactive=False) # [1]
189
  with gr.TabItem("Acerca del modelo"):
190
  gr.Markdown(
191
  "- Arquitectura: CNN exportado a ONNX.<br>"
192
  "- Entrenamiento: dataset dermatoscópico (ver documentación).<br>"
193
  "- Nota: Esta herramienta es solo con fines educativos y no reemplaza una evaluación médica."
194
- ) # [4]
195
 
196
- gr.Markdown("<div class='footer'>Versión del modelo: 1.0 • Última actualización: 2025‑08 • Universidad Central de Venezuela</div>") # [4]
197
 
198
  # ----------------------------
199
- # Wiring
200
  # ----------------------------
201
- outputs = [nombre, confianza, descripcion, sintomas, causas, tratamiento, prob_plot, latency] # [1]
202
 
203
- # Analyze click runs prediction; predict() handles None safely [1]
204
- analyze_btn.click(fn=predict, inputs=[image], outputs=outputs, show_progress="full") # [1]
205
 
206
- # Clear resets input and outputs using explicit updates for compatibility [1][2]
207
  def clear_all():
208
- return (
209
- gr.Image.update(value=None), # image [2]
210
- gr.Label.update(value=""), # nombre [2]
211
- gr.Label.update(value=""), # confianza [2]
212
- gr.Textbox.update(value=""), # descripcion [2]
213
- gr.Textbox.update(value=""), # sintomas [2]
214
- gr.Textbox.update(value=""), # causas [2]
215
- gr.Textbox.update(value=""), # tratamiento [2]
216
- gr.BarPlot.update(value=empty_df()), # prob_plot expects a DataFrame [3]
217
- gr.Label.update(value=""), # latency [2]
218
- ) # [1]
219
-
220
- clear_btn.click(
221
- fn=clear_all,
222
- inputs=[],
223
- outputs=[image, nombre, confianza, descripcion, sintomas, causas, tratamiento, prob_plot, latency],
224
- show_progress="hidden"
225
- ) # [1][2]
226
 
227
  if __name__ == "__main__":
228
- demo.launch() # [1]
 
 
1
  import json
2
  import numpy as np
3
  import gradio as gr
 
11
  # ----------------------------
12
  # Model + metadata
13
  # ----------------------------
14
+ ORT_PROVIDERS = ["CPUExecutionProvider"] # add "CUDAExecutionProvider" if available
15
+ ort_session = ort.InferenceSession("model_new_new_final.onnx", providers=ORT_PROVIDERS)
16
 
17
  with open("dat.json", "r", encoding="utf-8") as f:
18
+ data = json.load(f)
19
+ CLASSES = list(data) # ordered list of class names
20
 
21
  def empty_df():
22
+ return pd.DataFrame({"item": CLASSES, "probability": [0] * len(CLASSES)})
 
23
 
24
  # ----------------------------
25
  # Utils
26
  # ----------------------------
27
  def probabilities_to_ints(probabilities, total_sum=100):
28
+ probabilities = np.array(probabilities)
29
+ positive_values = np.maximum(probabilities, 0)
30
+ total_positive = positive_values.sum()
31
  if total_positive == 0:
32
+ return np.zeros_like(probabilities, dtype=int)
33
+ scaled = positive_values / total_positive * total_sum
34
+ rounded = np.round(scaled).astype(int)
35
+ diff = total_sum - rounded.sum()
36
  if diff != 0:
37
+ max_idx = int(np.argmax(positive_values))
38
+ rounded = rounded.flatten()
39
+ rounded[max_idx] += diff
40
+ rounded = rounded.reshape(scaled.shape)
41
+ return rounded
42
+
43
+ MEAN = [0.7611, 0.5869, 0.5923]
44
+ STD = [0.1266, 0.1487, 0.1619]
45
  TFMS = transforms.Compose([
46
  transforms.Resize((100, 100)),
47
  transforms.ToTensor(),
48
  transforms.Normalize(mean=MEAN, std=STD),
49
+ ])
50
 
51
  def preprocess(pil_img: Image.Image):
52
+ return TFMS(pil_img).unsqueeze(0).numpy()
53
 
54
  # ----------------------------
55
  # Inference function
56
  # ----------------------------
57
  def predict(image):
58
+ # Handle clicks with no image gracefully
59
  if image is None:
60
+ return ("Cargue una imagen y presione Analizar.", "", "", "", "", "", empty_df(), "")
61
  if isinstance(image, Image.Image):
62
+ pil = image.convert("RGB")
63
  else:
64
  try:
65
+ pil = Image.fromarray(image).convert("RGB")
66
  except Exception:
67
+ return ("Imagen inválida", "", "", "", "", "", empty_df(), "")
68
 
69
+ t0 = time.time()
70
+ input_tensor = preprocess(pil)
71
+ input_name = ort_session.get_inputs()[0].name
72
+ output = ort_session.run(None, {input_name: input_tensor})
73
 
74
+ logits = output[0].squeeze()
75
+ pred_idx = int(np.argmax(logits))
76
+ pred_name = CLASSES[pred_idx]
77
 
78
+ # Softmax probabilities
79
+ exp = np.exp(logits - np.max(logits))
80
+ probs = exp / exp.sum()
81
+ conf_text = f"{float(probs[pred_idx]) * 100:.1f}%"
82
 
83
+ ints = probabilities_to_ints(probs * 100.0, total_sum=100)
84
  df = pd.DataFrame({"item": CLASSES, "probability": ints.astype(int)}).sort_values(
85
  "probability", ascending=True
86
+ )
87
 
88
+ details = data[pred_name]
89
+ descripcion = details.get("description", "")
90
+ sintomas = details.get("symptoms", "")
91
+ causas = details.get("causes", "")
92
+ tratamiento = details.get("treatment-1", "")
93
 
94
+ latency_ms = int((time.time() - t0) * 1000)
95
+ return (pred_name, conf_text, descripcion, sintomas, causas, tratamiento, df, f"{latency_ms} ms")
96
 
97
  # ----------------------------
98
  # Theme (compatible across Gradio versions)
99
  # ----------------------------
100
  try:
101
+ theme = gr.themes.Soft(primary_hue="rose", secondary_hue="slate")
102
  except Exception:
103
+ theme = None # fallback to default theme
104
 
105
+ # CSS polish; tint bars via CSS for Gradio 4.27
106
  CUSTOM_CSS = """
107
  .header {display:flex; align-items:center; gap:12px;}
108
  .badge {font-size:12px; padding:4px 8px; border-radius:12px; background:#f1f5f9; color:#334155;}
 
111
  button, .gradio-container .gr-box, .gradio-container .gr-panel { border-radius: 10px !important; }
112
  /* Uniform bar color in Vega-Lite (Gradio 4.27) */
113
  .vega-embed .mark-rect, .vega-embed .mark-bar, .vega-embed .role-mark rect { fill: #ef4444 !important; }
114
+ """
115
 
116
  # ----------------------------
117
  # UI
 
129
  Sube una imagen dermatoscópica para ver la clase predicha, la confianza y la distribución de probabilidades.
130
  </p>
131
  """
132
+ )
133
  with gr.Column(scale=1, min_width=140):
134
  try:
135
+ dark_toggle = gr.ThemeMode(label="Modo", value="system")
136
  except Exception:
137
+ gr.Markdown("")
138
 
139
  with gr.Row(equal_height=True):
140
  # Left column: input + actions
141
  with gr.Column(scale=5):
142
+ image = gr.Image(type="numpy", label="Imagen de la lesión", height=420, sources=["upload", "clipboard"])
143
  with gr.Row():
144
+ analyze_btn = gr.Button("Analizar", variant="primary") # Always enabled
145
+ clear_btn = gr.Button("Limpiar")
146
  example_paths = [
147
  "examples/ak.jpg",
148
  "examples/bcc.jpg",
149
  "examples/df.jpg",
150
  "examples/melanoma.jpg",
151
  "examples/nevus.jpg",
152
+ ]
153
+ example_paths = [p for p in example_paths if os.path.exists(p)]
154
  if example_paths:
155
+ gr.Examples(examples=example_paths, inputs=image, label="Ejemplos rápidos")
156
+ latency = gr.Label(label="Latencia aproximada")
157
 
158
  # Right column: results
159
  with gr.Column(scale=5):
160
  with gr.Group():
161
  with gr.Row():
162
+ nombre = gr.Label(label="Predicción principal", elem_classes=["pred-card"])
163
+ confianza = gr.Label(label="Confianza")
164
+ # Default BarPlot; CSS applies color
165
  prob_plot = gr.BarPlot(
166
  value=empty_df(),
167
  x="item",
 
173
  tooltip=["item", "probability"],
174
  width=520,
175
  height=320,
176
+ )
177
  with gr.Tabs():
178
  with gr.TabItem("Detalles"):
179
  with gr.Accordion("Descripción", open=True):
180
+ descripcion = gr.Textbox(lines=4, interactive=False)
181
  with gr.Accordion("Síntomas", open=False):
182
+ sintomas = gr.Textbox(lines=4, interactive=False)
183
  with gr.Accordion("Causas", open=False):
184
+ causas = gr.Textbox(lines=4, interactive=False)
185
  with gr.Accordion("Tratamiento", open=False):
186
+ tratamiento = gr.Textbox(lines=4, interactive=False)
187
  with gr.TabItem("Acerca del modelo"):
188
  gr.Markdown(
189
  "- Arquitectura: CNN exportado a ONNX.<br>"
190
  "- Entrenamiento: dataset dermatoscópico (ver documentación).<br>"
191
  "- Nota: Esta herramienta es solo con fines educativos y no reemplaza una evaluación médica."
192
+ )
193
 
194
+ gr.Markdown("<div class='footer'>Versión del modelo: 1.0 • Última actualización: 2025‑08 • Universidad Central de Venezuela</div>")
195
 
196
  # ----------------------------
197
+ # Wiring: original-like behavior
198
  # ----------------------------
199
+ outputs = [nombre, confianza, descripcion, sintomas, causas, tratamiento, prob_plot, latency]
200
 
201
+ # Analyze click runs prediction; predict() handles None safely
202
+ analyze_btn.click(fn=predict, inputs=[image], outputs=outputs, show_progress="full")
203
 
204
+ # Clear resets input and outputs
205
  def clear_all():
206
+ return (None, "", "", "", "", "", empty_df(), "")
207
+
208
+ clear_btn.click(fn=clear_all, inputs=None, outputs=[image] + outputs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
  if __name__ == "__main__":
211
+ demo.launch()