dcavadia commited on
Commit
98b29fc
·
verified ·
1 Parent(s): 7a0bb09

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -75
app.py CHANGED
@@ -12,99 +12,99 @@ import os
12
  # ----------------------------
13
  # Model + metadata
14
  # ----------------------------
15
- ORT_PROVIDERS = ["CPUExecutionProvider"] # add "CUDAExecutionProvider" if available [4]
16
- ort_session = ort.InferenceSession("model_new_new_final.onnx", providers=ORT_PROVIDERS) # [4]
17
 
18
  with open("dat.json", "r", encoding="utf-8") as f:
19
- data = json.load(f) # [4]
20
- CLASSES = list(data) # ordered list of class names [4]
21
 
22
  def empty_df():
23
- # FIX: build zeros list via * len(CLASSES) instead of a starred expression in dict value [1]
24
- return pd.DataFrame({"item": CLASSES, "probability": [0] * len(CLASSES)}) # [3]
25
 
26
  # ----------------------------
27
  # Utils
28
  # ----------------------------
29
  def probabilities_to_ints(probabilities, total_sum=100):
30
- probabilities = np.array(probabilities) # [4]
31
- positive_values = np.maximum(probabilities, 0) # [4]
32
- total_positive = positive_values.sum() # [4]
33
  if total_positive == 0:
34
- return np.zeros_like(probabilities, dtype=int) # [4]
35
- scaled = positive_values / total_positive * total_sum # [4]
36
- rounded = np.round(scaled).astype(int) # [4]
37
- diff = total_sum - rounded.sum() # [4]
38
  if diff != 0:
39
- max_idx = int(np.argmax(positive_values)) # [4]
40
- rounded = rounded.flatten() # [4]
41
- rounded[max_idx] += diff # [4]
42
- rounded = rounded.reshape(scaled.shape) # [4]
43
- return rounded # [4]
44
-
45
- MEAN = [0.7611, 0.5869, 0.5923] # [4]
46
- STD = [0.1266, 0.1487, 0.1619] # [4]
47
  TFMS = transforms.Compose([
48
  transforms.Resize((100, 100)),
49
  transforms.ToTensor(),
50
  transforms.Normalize(mean=MEAN, std=STD),
51
- ]) # [4]
52
 
53
  def preprocess(pil_img: Image.Image):
54
- return TFMS(pil_img).unsqueeze(0).numpy() # [4]
55
 
56
  # ----------------------------
57
  # Inference function
58
  # ----------------------------
59
  def predict(image):
60
- # Handle clicks with no image gracefully [4]
61
  if image is None:
62
- return ("Cargue una imagen y presione Analizar.", "", "", "", "", "", empty_df(), "") # [4][3]
63
  if isinstance(image, Image.Image):
64
- pil = image.convert("RGB") # [4]
65
  else:
66
  try:
67
- pil = Image.fromarray(image).convert("RGB") # [4]
68
  except Exception:
69
- return ("Imagen inválida", "", "", "", "", "", empty_df(), "") # [4][3]
70
 
71
- t0 = time.time() # [4]
72
- input_tensor = preprocess(pil) # [4]
73
- input_name = ort_session.get_inputs().name # [4]
74
- output = ort_session.run(None, {input_name: input_tensor}) # [4]
75
 
76
- logits = output.squeeze() # [4]
77
- pred_idx = int(np.argmax(logits)) # [4]
78
- pred_name = CLASSES[pred_idx] # [4]
79
 
80
- # Softmax probabilities [4]
81
- exp = np.exp(logits - np.max(logits)) # [4]
82
- probs = exp / exp.sum() # [4]
83
- conf_text = f"{float(probs[pred_idx]) * 100:.1f}%" # [4]
84
 
85
- ints = probabilities_to_ints(probs * 100.0, total_sum=100) # [4]
86
  df = pd.DataFrame({"item": CLASSES, "probability": ints.astype(int)}).sort_values(
87
  "probability", ascending=True
88
  ) # [3]
89
 
90
- details = data[pred_name] # [4]
91
- descripcion = details.get("description", "") # [4]
92
- sintomas = details.get("symptoms", "") # [4]
93
- causas = details.get("causes", "") # [4]
94
- tratamiento = details.get("treatment-1", "") # [4]
95
 
96
- latency_ms = int((time.time() - t0) * 1000) # [4]
97
- return (pred_name, conf_text, descripcion, sintomas, causas, tratamiento, df, f"{latency_ms} ms") # [3][4]
98
 
99
  # ----------------------------
100
  # Theme (compatible across Gradio versions)
101
  # ----------------------------
102
  try:
103
- theme = gr.themes.Soft(primary_hue="rose", secondary_hue="slate") # [5]
104
  except Exception:
105
- theme = None # fallback to default theme [5]
106
 
107
- # CSS polish; tint bars via CSS for Gradio 4.27 [5]
108
  CUSTOM_CSS = """
109
  .header {display:flex; align-items:center; gap:12px;}
110
  .badge {font-size:12px; padding:4px 8px; border-radius:12px; background:#f1f5f9; color:#334155;}
@@ -113,7 +113,7 @@ CUSTOM_CSS = """
113
  button, .gradio-container .gr-box, .gradio-container .gr-panel { border-radius: 10px !important; }
114
  /* Uniform bar color in Vega-Lite (Gradio 4.27) */
115
  .vega-embed .mark-rect, .vega-embed .mark-bar, .vega-embed .role-mark rect { fill: #ef4444 !important; }
116
- """ # [5][3]
117
 
118
  # ----------------------------
119
  # UI
@@ -131,19 +131,19 @@ with gr.Blocks(theme=theme, css=CUSTOM_CSS) as demo:
131
  Sube una imagen dermatoscópica para ver la clase predicha, la confianza y la distribución de probabilidades.
132
  </p>
133
  """
134
- ) # [5]
135
  with gr.Column(scale=1, min_width=140):
136
  try:
137
- dark_toggle = gr.ThemeMode(label="Modo", value="system") # optional UI affordance [5]
138
  except Exception:
139
- gr.Markdown("") # [5]
140
 
141
  with gr.Row(equal_height=True):
142
  # Left column: input + actions
143
  with gr.Column(scale=5):
144
- image = gr.Image(type="numpy", label="Imagen de la lesión", height=420, sources=["upload", "clipboard"]) # [4]
145
  with gr.Row():
146
- analyze_btn = gr.Button("Analizar", variant="primary") # Always enabled [4]
147
  clear_btn = gr.Button("Limpiar") # [2]
148
  example_paths = [
149
  "examples/ak.jpg",
@@ -151,18 +151,18 @@ with gr.Blocks(theme=theme, css=CUSTOM_CSS) as demo:
151
  "examples/df.jpg",
152
  "examples/melanoma.jpg",
153
  "examples/nevus.jpg",
154
- ] # [4]
155
- example_paths = [p for p in example_paths if os.path.exists(p)] # [4]
156
  if example_paths:
157
- gr.Examples(examples=example_paths, inputs=image, label="Ejemplos rápidos") # [4]
158
- latency = gr.Label(label="Latencia aproximada") # [4]
159
 
160
  # Right column: results
161
  with gr.Column(scale=5):
162
  with gr.Group():
163
  with gr.Row():
164
- nombre = gr.Label(label="Predicción principal", elem_classes=["pred-card"]) # [4]
165
- confianza = gr.Label(label="Confianza") # [4]
166
  # Default BarPlot; CSS applies color [3]
167
  prob_plot = gr.BarPlot(
168
  value=empty_df(),
@@ -179,31 +179,31 @@ with gr.Blocks(theme=theme, css=CUSTOM_CSS) as demo:
179
  with gr.Tabs():
180
  with gr.TabItem("Detalles"):
181
  with gr.Accordion("Descripción", open=True):
182
- descripcion = gr.Textbox(lines=4, interactive=False) # [4]
183
  with gr.Accordion("Síntomas", open=False):
184
- sintomas = gr.Textbox(lines=4, interactive=False) # [4]
185
  with gr.Accordion("Causas", open=False):
186
- causas = gr.Textbox(lines=4, interactive=False) # [4]
187
  with gr.Accordion("Tratamiento", open=False):
188
- tratamiento = gr.Textbox(lines=4, interactive=False) # [4]
189
  with gr.TabItem("Acerca del modelo"):
190
  gr.Markdown(
191
  "- Arquitectura: CNN exportado a ONNX.<br>"
192
  "- Entrenamiento: dataset dermatoscópico (ver documentación).<br>"
193
  "- Nota: Esta herramienta es solo con fines educativos y no reemplaza una evaluación médica."
194
- ) # [5]
195
 
196
- gr.Markdown("<div class='footer'>Versión del modelo: 1.0 • Última actualización: 2025‑08 • Universidad Central de Venezuela</div>") # [5]
197
 
198
  # ----------------------------
199
  # Wiring
200
  # ----------------------------
201
- outputs = [nombre, confianza, descripcion, sintomas, causas, tratamiento, prob_plot, latency] # [4]
202
 
203
- # Analyze click runs prediction; predict() handles None safely [4]
204
- analyze_btn.click(fn=predict, inputs=[image], outputs=outputs, show_progress="full") # [4]
205
 
206
- # Clear resets input and outputs using explicit updates for compatibility [4][2]
207
  def clear_all():
208
  return (
209
  gr.Image.update(value=None), # image [2]
@@ -215,14 +215,14 @@ with gr.Blocks(theme=theme, css=CUSTOM_CSS) as demo:
215
  gr.Textbox.update(value=""), # tratamiento [2]
216
  gr.BarPlot.update(value=empty_df()), # prob_plot expects a DataFrame [3]
217
  gr.Label.update(value=""), # latency [2]
218
- ) # [4]
219
 
220
  clear_btn.click(
221
  fn=clear_all,
222
  inputs=[],
223
  outputs=[image, nombre, confianza, descripcion, sintomas, causas, tratamiento, prob_plot, latency],
224
  show_progress="hidden"
225
- ) # [4][2]
226
 
227
  if __name__ == "__main__":
228
- demo.launch() # [4]
 
12
  # ----------------------------
13
  # Model + metadata
14
  # ----------------------------
15
+ ORT_PROVIDERS = ["CPUExecutionProvider"] # add "CUDAExecutionProvider" if available [1]
16
+ ort_session = ort.InferenceSession("model_new_new_final.onnx", providers=ORT_PROVIDERS) # [1]
17
 
18
  with open("dat.json", "r", encoding="utf-8") as f:
19
+ data = json.load(f) # [1]
20
+ CLASSES = list(data) # ordered list of class names [1]
21
 
22
  def empty_df():
23
+ # Provide a fresh DataFrame for BarPlot resets (component expects a DataFrame) [3]
24
+ return pd.DataFrame({"item": CLASSES, "probability": * len(CLASSES)}) # [3]
25
 
26
  # ----------------------------
27
  # Utils
28
  # ----------------------------
29
  def probabilities_to_ints(probabilities, total_sum=100):
30
+ probabilities = np.array(probabilities) # [1]
31
+ positive_values = np.maximum(probabilities, 0) # [1]
32
+ total_positive = positive_values.sum() # [1]
33
  if total_positive == 0:
34
+ return np.zeros_like(probabilities, dtype=int) # [1]
35
+ scaled = positive_values / total_positive * total_sum # [1]
36
+ rounded = np.round(scaled).astype(int) # [1]
37
+ diff = total_sum - rounded.sum() # [1]
38
  if diff != 0:
39
+ max_idx = int(np.argmax(positive_values)) # [1]
40
+ rounded = rounded.flatten() # [1]
41
+ rounded[max_idx] += diff # [1]
42
+ rounded = rounded.reshape(scaled.shape) # [1]
43
+ return rounded # [1]
44
+
45
+ MEAN = [0.7611, 0.5869, 0.5923] # [1]
46
+ STD = [0.1266, 0.1487, 0.1619] # [1]
47
  TFMS = transforms.Compose([
48
  transforms.Resize((100, 100)),
49
  transforms.ToTensor(),
50
  transforms.Normalize(mean=MEAN, std=STD),
51
+ ]) # [1]
52
 
53
  def preprocess(pil_img: Image.Image):
54
+ return TFMS(pil_img).unsqueeze(0).numpy() # [1]
55
 
56
  # ----------------------------
57
  # Inference function
58
  # ----------------------------
59
  def predict(image):
60
+ # Handle clicks with no image gracefully [1]
61
  if image is None:
62
+ return ("Cargue una imagen y presione Analizar.", "", "", "", "", "", empty_df(), "") # [1][3]
63
  if isinstance(image, Image.Image):
64
+ pil = image.convert("RGB") # [1]
65
  else:
66
  try:
67
+ pil = Image.fromarray(image).convert("RGB") # [1]
68
  except Exception:
69
+ return ("Imagen inválida", "", "", "", "", "", empty_df(), "") # [1][3]
70
 
71
+ t0 = time.time() # [1]
72
+ input_tensor = preprocess(pil) # [1]
73
+ input_name = ort_session.get_inputs().name # [1]
74
+ output = ort_session.run(None, {input_name: input_tensor}) # [1]
75
 
76
+ logits = output.squeeze() # [1]
77
+ pred_idx = int(np.argmax(logits)) # [1]
78
+ pred_name = CLASSES[pred_idx] # [1]
79
 
80
+ # Softmax probabilities [1]
81
+ exp = np.exp(logits - np.max(logits)) # [1]
82
+ probs = exp / exp.sum() # [1]
83
+ conf_text = f"{float(probs[pred_idx]) * 100:.1f}%" # [1]
84
 
85
+ ints = probabilities_to_ints(probs * 100.0, total_sum=100) # [1]
86
  df = pd.DataFrame({"item": CLASSES, "probability": ints.astype(int)}).sort_values(
87
  "probability", ascending=True
88
  ) # [3]
89
 
90
+ details = data[pred_name] # [1]
91
+ descripcion = details.get("description", "") # [1]
92
+ sintomas = details.get("symptoms", "") # [1]
93
+ causas = details.get("causes", "") # [1]
94
+ tratamiento = details.get("treatment-1", "") # [1]
95
 
96
+ latency_ms = int((time.time() - t0) * 1000) # [1]
97
+ return (pred_name, conf_text, descripcion, sintomas, causas, tratamiento, df, f"{latency_ms} ms") # [3][1]
98
 
99
  # ----------------------------
100
  # Theme (compatible across Gradio versions)
101
  # ----------------------------
102
  try:
103
+ theme = gr.themes.Soft(primary_hue="rose", secondary_hue="slate") # [4]
104
  except Exception:
105
+ theme = None # fallback to default theme [4]
106
 
107
+ # CSS polish; tint bars via CSS for Gradio 4.27 [4]
108
  CUSTOM_CSS = """
109
  .header {display:flex; align-items:center; gap:12px;}
110
  .badge {font-size:12px; padding:4px 8px; border-radius:12px; background:#f1f5f9; color:#334155;}
 
113
  button, .gradio-container .gr-box, .gradio-container .gr-panel { border-radius: 10px !important; }
114
  /* Uniform bar color in Vega-Lite (Gradio 4.27) */
115
  .vega-embed .mark-rect, .vega-embed .mark-bar, .vega-embed .role-mark rect { fill: #ef4444 !important; }
116
+ """ # [4][3]
117
 
118
  # ----------------------------
119
  # UI
 
131
  Sube una imagen dermatoscópica para ver la clase predicha, la confianza y la distribución de probabilidades.
132
  </p>
133
  """
134
+ ) # [4]
135
  with gr.Column(scale=1, min_width=140):
136
  try:
137
+ dark_toggle = gr.ThemeMode(label="Modo", value="system") # optional UI affordance [4]
138
  except Exception:
139
+ gr.Markdown("") # [4]
140
 
141
  with gr.Row(equal_height=True):
142
  # Left column: input + actions
143
  with gr.Column(scale=5):
144
+ image = gr.Image(type="numpy", label="Imagen de la lesión", height=420, sources=["upload", "clipboard"]) # [1]
145
  with gr.Row():
146
+ analyze_btn = gr.Button("Analizar", variant="primary") # Always enabled [1]
147
  clear_btn = gr.Button("Limpiar") # [2]
148
  example_paths = [
149
  "examples/ak.jpg",
 
151
  "examples/df.jpg",
152
  "examples/melanoma.jpg",
153
  "examples/nevus.jpg",
154
+ ] # [1]
155
+ example_paths = [p for p in example_paths if os.path.exists(p)] # [1]
156
  if example_paths:
157
+ gr.Examples(examples=example_paths, inputs=image, label="Ejemplos rápidos") # [1]
158
+ latency = gr.Label(label="Latencia aproximada") # [1]
159
 
160
  # Right column: results
161
  with gr.Column(scale=5):
162
  with gr.Group():
163
  with gr.Row():
164
+ nombre = gr.Label(label="Predicción principal", elem_classes=["pred-card"]) # [1]
165
+ confianza = gr.Label(label="Confianza") # [1]
166
  # Default BarPlot; CSS applies color [3]
167
  prob_plot = gr.BarPlot(
168
  value=empty_df(),
 
179
  with gr.Tabs():
180
  with gr.TabItem("Detalles"):
181
  with gr.Accordion("Descripción", open=True):
182
+ descripcion = gr.Textbox(lines=4, interactive=False) # [1]
183
  with gr.Accordion("Síntomas", open=False):
184
+ sintomas = gr.Textbox(lines=4, interactive=False) # [1]
185
  with gr.Accordion("Causas", open=False):
186
+ causas = gr.Textbox(lines=4, interactive=False) # [1]
187
  with gr.Accordion("Tratamiento", open=False):
188
+ tratamiento = gr.Textbox(lines=4, interactive=False) # [1]
189
  with gr.TabItem("Acerca del modelo"):
190
  gr.Markdown(
191
  "- Arquitectura: CNN exportado a ONNX.<br>"
192
  "- Entrenamiento: dataset dermatoscópico (ver documentación).<br>"
193
  "- Nota: Esta herramienta es solo con fines educativos y no reemplaza una evaluación médica."
194
+ ) # [4]
195
 
196
+ gr.Markdown("<div class='footer'>Versión del modelo: 1.0 • Última actualización: 2025‑08 • Universidad Central de Venezuela</div>") # [4]
197
 
198
  # ----------------------------
199
  # Wiring
200
  # ----------------------------
201
+ outputs = [nombre, confianza, descripcion, sintomas, causas, tratamiento, prob_plot, latency] # [1]
202
 
203
+ # Analyze click runs prediction; predict() handles None safely [1]
204
+ analyze_btn.click(fn=predict, inputs=[image], outputs=outputs, show_progress="full") # [1]
205
 
206
+ # Clear resets input and outputs using explicit updates for compatibility [1][2]
207
  def clear_all():
208
  return (
209
  gr.Image.update(value=None), # image [2]
 
215
  gr.Textbox.update(value=""), # tratamiento [2]
216
  gr.BarPlot.update(value=empty_df()), # prob_plot expects a DataFrame [3]
217
  gr.Label.update(value=""), # latency [2]
218
+ ) # [1]
219
 
220
  clear_btn.click(
221
  fn=clear_all,
222
  inputs=[],
223
  outputs=[image, nombre, confianza, descripcion, sintomas, causas, tratamiento, prob_plot, latency],
224
  show_progress="hidden"
225
+ ) # [1][2]
226
 
227
  if __name__ == "__main__":
228
+ demo.launch() # [1]