dcavadia commited on
Commit
170cfc0
·
verified ·
1 Parent(s): 7a6e646

restore faulty interface

Browse files
Files changed (1) hide show
  1. app.py +191 -91
app.py CHANGED
@@ -1,111 +1,211 @@
1
- import os
2
- import sys
3
  import json
4
- import logging
5
- import traceback
6
  import numpy as np
7
- import pandas as pd
8
  import gradio as gr
9
- import onnxruntime
10
  from PIL import Image
11
  from torchvision import transforms
 
 
 
12
 
13
- # Logging
14
- logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s", stream=sys.stdout)
15
- log = logging.getLogger("space")
16
-
17
- def log_exc(prefix):
18
- etype, evalue, tb = sys.exc_info()
19
- stack = "".join(traceback.format_exception(etype, evalue, tb))
20
- log.error("%s: %s\n%s", prefix, evalue, stack)
21
- return f"{prefix}: {evalue}"
22
 
23
- # Load metadata
24
  with open("dat.json", "r", encoding="utf-8") as f:
25
  data = json.load(f)
26
- keys = list(data.keys())
27
- log.info("Loaded %d classes from dat.json", len(keys))
28
 
29
- # Load ONNX
30
- ort = onnxruntime.InferenceSession("model_new_new_final.onnx")
31
- log.info("ONNX inputs: %s", [(i.name, i.shape, i.type) for i in ort.get_inputs()])
32
- log.info("ONNX outputs: %s", [(o.name, o.shape, o.type) for o in ort.get_outputs()])
33
 
34
- # Preprocess
35
- tfms = transforms.Compose([
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  transforms.Resize((100, 100)),
37
  transforms.ToTensor(),
38
- transforms.Normalize(mean=[0.7611, 0.5869, 0.5923], std=[0.1266, 0.1487, 0.1619]),
39
  ])
40
 
41
- def probabilities_to_ints(probabilities, total_sum=100):
42
- probs = np.asarray(probabilities, dtype=np.float64)
43
- probs = np.maximum(probs, 0)
44
- total = probs.sum()
45
- scaled = np.zeros_like(probs)
46
- if total > 0:
47
- scaled = probs / total * total_sum
48
- rounded = np.round(scaled).astype(int)
49
- diff = total_sum - int(rounded.sum())
50
- if diff != 0 and total > 0:
51
- rounded[int(np.argmax(probs))] += diff
52
- return rounded
53
 
54
- def predict(image: Image.Image):
55
- try:
56
- if image is None:
57
- return "Error", "", "", "", "", pd.DataFrame({"item": keys, "probability": *len(keys)}), "No image provided"
 
 
 
 
58
  pil = image.convert("RGB")
59
- x = tfms(pil).unsqueeze(0).numpy().astype(np.float32) # (1,C,H,W)
60
- input_name = ort.get_inputs().name
61
- outs = ort.run(None, {input_name: x})
62
- logits = outs
63
- if logits.ndim == 2:
64
- scores = logits
65
- elif logits.ndim == 1:
66
- scores = logits
67
- else:
68
- raise ValueError(f"Unexpected logits shape: {logits.shape}")
69
- if len(scores) != len(keys):
70
- raise ValueError(f"Logits length {len(scores)} != classes {len(keys)}")
71
- idx = int(np.argmax(scores))
72
- name = keys[idx]
73
- meta = data.get(name, {})
74
- desc = meta.get("description", "")
75
- symp = meta.get("symptoms", "")
76
- causes = meta.get("causes", "")
77
- treat = meta.get("treatment-1", meta.get("treatment", ""))
78
- df = pd.DataFrame({"item": keys, "probability": probabilities_to_ints(scores).astype(int)})
79
- return name, desc, symp, causes, treat, df, ""
80
- except Exception:
81
- err = log_exc("Inference failed")
82
- df = pd.DataFrame({"item": keys if keys else ["N/A"], "probability": *(len(keys) if keys else 1)})
83
- return "Error", "", "", "", "", df, err
84
-
85
- with gr.Blocks(title="Clasificacion de Enfermedades de la Piel") as demo:
86
- gr.Markdown("Suba una imagen y ejecute la prediccion.")
87
- with gr.Row():
88
- img = gr.Image(type="pil", label="Imagen")
89
- with gr.Column():
90
- out_name = gr.Textbox(label="Nombre de la Enfermedad")
91
- out_desc = gr.Textbox(label="Descripcion")
92
- out_symp = gr.Textbox(label="Sintomas")
93
- out_causes = gr.Textbox(label="Causas")
94
- out_treat = gr.Textbox(label="Tratamiento")
95
- bar = gr.BarPlot(
96
- x="item",
97
- y="probability",
98
- title="Distribucion de Probabilidad",
99
- x_title="Nombre de la Enfermedad",
100
- y_title="Probabilidad",
101
- tooltip=["item", "probability"],
102
- vertical=False,
103
- label="Probabilidades"
104
  )
105
- err = gr.Textbox(label="Errores", interactive=False)
106
- btn = gr.Button("Predecir")
107
- btn.click(fn=predict, inputs=[img], outputs=[out_name, out_desc, out_symp, out_causes, out_treat, bar, err])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
  if __name__ == "__main__":
110
- # Spaces handles networking; no share=True
111
- demo.launch(debug=True)
 
 
 
1
  import json
 
 
2
  import numpy as np
 
3
  import gradio as gr
4
+ import onnxruntime as ort
5
  from PIL import Image
6
  from torchvision import transforms
7
+ import pandas as pd
8
+ import time
9
+ import os
10
 
11
+ # ----------------------------
12
+ # Model + metadata
13
+ # ----------------------------
14
+ ORT_PROVIDERS = ["CPUExecutionProvider"] # add "CUDAExecutionProvider" if available
15
+ ort_session = ort.InferenceSession("model_new_new_final.onnx", providers=ORT_PROVIDERS)
 
 
 
 
16
 
 
17
  with open("dat.json", "r", encoding="utf-8") as f:
18
  data = json.load(f)
19
+ CLASSES = list(data) # ordered list of class names
 
20
 
21
+ def empty_df():
22
+ return pd.DataFrame({"item": CLASSES, "probability": [0] * len(CLASSES)})
 
 
23
 
24
+ # ----------------------------
25
+ # Utils
26
+ # ----------------------------
27
+ def probabilities_to_ints(probabilities, total_sum=100):
28
+ probabilities = np.array(probabilities)
29
+ positive_values = np.maximum(probabilities, 0)
30
+ total_positive = positive_values.sum()
31
+ if total_positive == 0:
32
+ return np.zeros_like(probabilities, dtype=int)
33
+ scaled = positive_values / total_positive * total_sum
34
+ rounded = np.round(scaled).astype(int)
35
+ diff = total_sum - rounded.sum()
36
+ if diff != 0:
37
+ max_idx = int(np.argmax(positive_values))
38
+ rounded = rounded.flatten()
39
+ rounded[max_idx] += diff
40
+ rounded = rounded.reshape(scaled.shape)
41
+ return rounded
42
+
43
+ MEAN = [0.7611, 0.5869, 0.5923]
44
+ STD = [0.1266, 0.1487, 0.1619]
45
+ TFMS = transforms.Compose([
46
  transforms.Resize((100, 100)),
47
  transforms.ToTensor(),
48
+ transforms.Normalize(mean=MEAN, std=STD),
49
  ])
50
 
51
+ def preprocess(pil_img: Image.Image):
52
+ return TFMS(pil_img).unsqueeze(0).numpy()
 
 
 
 
 
 
 
 
 
 
53
 
54
+ # ----------------------------
55
+ # Inference function
56
+ # ----------------------------
57
+ def predict(image):
58
+ # Handle clicks with no image gracefully
59
+ if image is None:
60
+ return ("Cargue una imagen y presione Analizar.", "", "", "", "", "", empty_df(), "")
61
+ if isinstance(image, Image.Image):
62
  pil = image.convert("RGB")
63
+ else:
64
+ try:
65
+ pil = Image.fromarray(image).convert("RGB")
66
+ except Exception:
67
+ return ("Imagen inválida", "", "", "", "", "", empty_df(), "")
68
+
69
+ t0 = time.time()
70
+ input_tensor = preprocess(pil)
71
+ input_name = ort_session.get_inputs()[0].name
72
+ output = ort_session.run(None, {input_name: input_tensor})
73
+
74
+ logits = output[0].squeeze()
75
+ pred_idx = int(np.argmax(logits))
76
+ pred_name = CLASSES[pred_idx]
77
+
78
+ # Softmax probabilities
79
+ exp = np.exp(logits - np.max(logits))
80
+ probs = exp / exp.sum()
81
+ conf_text = f"{float(probs[pred_idx]) * 100:.1f}%"
82
+
83
+ ints = probabilities_to_ints(probs * 100.0, total_sum=100)
84
+ df = pd.DataFrame({"item": CLASSES, "probability": ints.astype(int)}).sort_values(
85
+ "probability", ascending=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  )
87
+
88
+ details = data[pred_name]
89
+ descripcion = details.get("description", "")
90
+ sintomas = details.get("symptoms", "")
91
+ causas = details.get("causes", "")
92
+ tratamiento = details.get("treatment-1", "")
93
+
94
+ latency_ms = int((time.time() - t0) * 1000)
95
+ return (pred_name, conf_text, descripcion, sintomas, causas, tratamiento, df, f"{latency_ms} ms")
96
+
97
+ # ----------------------------
98
+ # Theme (compatible across Gradio versions)
99
+ # ----------------------------
100
+ try:
101
+ theme = gr.themes.Soft(primary_hue="rose", secondary_hue="slate")
102
+ except Exception:
103
+ theme = None # fallback to default theme
104
+
105
+ # CSS polish; tint bars via CSS for Gradio 4.27
106
+ CUSTOM_CSS = """
107
+ .header {display:flex; align-items:center; gap:12px;}
108
+ .badge {font-size:12px; padding:4px 8px; border-radius:12px; background:#f1f5f9; color:#334155;}
109
+ .pred-card {font-size:18px;}
110
+ .footer {font-size:12px; color:#64748b; text-align:center; padding:12px 0;}
111
+ button, .gradio-container .gr-box, .gradio-container .gr-panel { border-radius: 10px !important; }
112
+ /* Uniform bar color in Vega-Lite (Gradio 4.27) */
113
+ .vega-embed .mark-rect, .vega-embed .mark-bar, .vega-embed .role-mark rect { fill: #ef4444 !important; }
114
+ """
115
+
116
+ # ----------------------------
117
+ # UI
118
+ # ----------------------------
119
+ with gr.Blocks(theme=theme, css=CUSTOM_CSS) as demo:
120
+ with gr.Row():
121
+ with gr.Column(scale=6):
122
+ gr.Markdown(
123
+ """
124
+ <div class="header">
125
+ <h1 style="margin:0;">Clasificación de Enfermedades de la Piel</h1>
126
+ <span class="badge">Demo • No diagnóstico médico</span>
127
+ </div>
128
+ <p style="margin-top:6px;">
129
+ Sube una imagen dermatoscópica para ver la clase predicha, la confianza y la distribución de probabilidades.
130
+ </p>
131
+ """
132
+ )
133
+ with gr.Column(scale=1, min_width=140):
134
+ try:
135
+ dark_toggle = gr.ThemeMode(label="Modo", value="system")
136
+ except Exception:
137
+ gr.Markdown("")
138
+
139
+ with gr.Row(equal_height=True):
140
+ # Left column: input + actions
141
+ with gr.Column(scale=5):
142
+ image = gr.Image(type="numpy", label="Imagen de la lesión", height=420, sources=["upload", "clipboard"])
143
+ with gr.Row():
144
+ analyze_btn = gr.Button("Analizar", variant="primary") # Always enabled
145
+ clear_btn = gr.Button("Limpiar")
146
+ example_paths = [
147
+ "examples/ak.jpg",
148
+ "examples/bcc.jpg",
149
+ "examples/df.jpg",
150
+ "examples/melanoma.jpg",
151
+ "examples/nevus.jpg",
152
+ ]
153
+ example_paths = [p for p in example_paths if os.path.exists(p)]
154
+ if example_paths:
155
+ gr.Examples(examples=example_paths, inputs=image, label="Ejemplos rápidos")
156
+ latency = gr.Label(label="Latencia aproximada")
157
+
158
+ # Right column: results
159
+ with gr.Column(scale=5):
160
+ with gr.Group():
161
+ with gr.Row():
162
+ nombre = gr.Label(label="Predicción principal", elem_classes=["pred-card"])
163
+ confianza = gr.Label(label="Confianza")
164
+ # Default BarPlot; CSS applies color
165
+ prob_plot = gr.BarPlot(
166
+ value=empty_df(),
167
+ x="item",
168
+ y="probability",
169
+ title="Distribuci��n de probabilidad (Top‑k)",
170
+ x_title="Clase",
171
+ y_title="Probabilidad",
172
+ vertical=False,
173
+ tooltip=["item", "probability"],
174
+ width=520,
175
+ height=320,
176
+ )
177
+ with gr.Tabs():
178
+ with gr.TabItem("Detalles"):
179
+ with gr.Accordion("Descripción", open=True):
180
+ descripcion = gr.Textbox(lines=4, interactive=False)
181
+ with gr.Accordion("Síntomas", open=False):
182
+ sintomas = gr.Textbox(lines=4, interactive=False)
183
+ with gr.Accordion("Causas", open=False):
184
+ causas = gr.Textbox(lines=4, interactive=False)
185
+ with gr.Accordion("Tratamiento", open=False):
186
+ tratamiento = gr.Textbox(lines=4, interactive=False)
187
+ with gr.TabItem("Acerca del modelo"):
188
+ gr.Markdown(
189
+ "- Arquitectura: CNN exportado a ONNX.<br>"
190
+ "- Entrenamiento: dataset dermatoscópico (ver documentación).<br>"
191
+ "- Nota: Esta herramienta es solo con fines educativos y no reemplaza una evaluación médica."
192
+ )
193
+
194
+ gr.Markdown("<div class='footer'>Versión del modelo: 1.0 • Última actualización: 2025‑08 • Universidad Central de Venezuela</div>")
195
+
196
+ # ----------------------------
197
+ # Wiring: original-like behavior
198
+ # ----------------------------
199
+ outputs = [nombre, confianza, descripcion, sintomas, causas, tratamiento, prob_plot, latency]
200
+
201
+ # Analyze click runs prediction; predict() handles None safely
202
+ analyze_btn.click(fn=predict, inputs=[image], outputs=outputs, show_progress="full")
203
+
204
+ # Clear resets input and outputs
205
+ def clear_all():
206
+ return (None, "", "", "", "", "", empty_df(), "")
207
+
208
+ clear_btn.click(fn=clear_all, inputs=None, outputs=[image] + outputs)
209
 
210
  if __name__ == "__main__":
211
+ demo.launch()