dcavadia commited on
Commit
b05265c
·
verified ·
1 Parent(s): 0619e0b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -248
app.py CHANGED
@@ -1,10 +1,5 @@
1
- import os
2
- import sys
3
  import json
4
- import traceback
5
- import logging
6
- from typing import Tuple, Any, Dict, List
7
-
8
  import numpy as np
9
  import gradio as gr
10
  import onnxruntime
@@ -12,252 +7,120 @@ from PIL import Image
12
  from torchvision import transforms
13
  import pandas as pd
14
 
15
- # ------------------------------------------------------------------------------
16
- # Logging setup
17
- # ------------------------------------------------------------------------------
18
- LOG_LEVEL = os.getenv("LOG_LEVEL", "DEBUG").upper()
19
- logging.basicConfig(
20
- level=getattr(logging, LOG_LEVEL, logging.DEBUG),
21
- format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
22
- stream=sys.stdout,
23
- )
24
- log = logging.getLogger("app")
25
-
26
- def log_exception(prefix: str) -> str:
27
- etype, evalue, tb = sys.exc_info()
28
- stack = "".join(traceback.format_exception(etype, evalue, tb))
29
- log.error("%s: %s\n%s", prefix, evalue, stack)
30
- return f"{prefix}: {evalue}\nSee server logs for stack trace."
31
-
32
- # ------------------------------------------------------------------------------
33
- # Environment and package info
34
- # ------------------------------------------------------------------------------
35
- def log_env_info():
36
- log.info("Python: %s", sys.version)
37
- log.info("Platform: %s", sys.platform)
38
- try:
39
- import gradio
40
- import fastapi
41
- import starlette
42
- import onnx
43
- import torch
44
- log.info("gradio=%s fastapi=%s starlette=%s onnx=%s torch=%s",
45
- getattr(gradio, '__version__', '?'),
46
- getattr(fastapi, '__version__', '?'),
47
- getattr(starlette, '__version__', '?'),
48
- getattr(onnx, '__version__', '?'),
49
- getattr(torch, '__version__', '?'))
50
- except Exception:
51
- log.warning("Could not log all package versions", exc_info=True)
52
-
53
- log_env_info()
54
-
55
- # ------------------------------------------------------------------------------
56
- # Model and metadata loading
57
- # ------------------------------------------------------------------------------
58
- MODEL_PATH = os.getenv("MODEL_PATH", "model_new_new_final.onnx")
59
- JSON_PATH = os.getenv("JSON_PATH", "dat.json")
60
-
61
- def load_metadata(json_path: str) -> Tuple[List[str], Dict[str, Any]]:
62
- with open(json_path, 'r', encoding='utf-8') as f:
63
- meta = json.load(f)
64
- # Ensure deterministic order: keep file order if Py3.7+ or optionally use a stored labels list
65
- keys = list(meta.keys())
66
- log.info("Loaded metadata: %d classes from %s", len(keys), json_path)
67
- return keys, meta
68
-
69
- def load_session(model_path: str) -> onnxruntime.InferenceSession:
70
- so = onnxruntime.SessionOptions()
71
- so.log_severity_level = 0 # Verbose ORT logs if needed
72
- providers = ["CPUExecutionProvider"]
73
- sess = onnxruntime.InferenceSession(model_path, sess_options=so, providers=providers)
74
- # Log IO details
75
- inputs = sess.get_inputs()
76
- outputs = sess.get_outputs()
77
- for i, inp in enumerate(inputs):
78
- log.info("ONNX Input[%d]: name=%s shape=%s type=%s", i, inp.name, inp.shape, inp.type)
79
- for i, out in enumerate(outputs):
80
- log.info("ONNX Output[%d]: name=%s shape=%s type=%s", i, out.name, out.shape, out.type)
81
- return sess
82
-
83
- # Load once at startup with protective logging
84
- try:
85
- KEYS, META = load_metadata(JSON_PATH)
86
- except Exception:
87
- msg = log_exception(f"Failed to load metadata JSON at {JSON_PATH}")
88
- # Fallbacks to keep app launching
89
- KEYS, META = [], {}
90
- META["_startup_error"] = msg
91
 
92
- try:
93
- ORT = load_session(MODEL_PATH)
94
- except Exception:
95
- msg = log_exception(f"Failed to load ONNX model at {MODEL_PATH}")
96
- ORT = None
97
 
98
- # ------------------------------------------------------------------------------
99
- # Preprocessing
100
- # ------------------------------------------------------------------------------
101
- MEAN = [0.7611, 0.5869, 0.5923]
102
- STD = [0.1266, 0.1487, 0.1619]
103
 
104
- TFMS = transforms.Compose([
105
- transforms.Resize((100, 100)),
106
- transforms.ToTensor(),
107
- transforms.Normalize(mean=MEAN, std=STD),
108
- ])
109
-
110
- def preprocess(img: Image.Image) -> np.ndarray:
111
- if img is None:
112
- raise ValueError("No image provided")
113
- log.debug("Input image: mode=%s size=%s", img.mode, img.size)
114
- img = img.convert("RGB")
115
- t = TFMS(img).unsqueeze(0).numpy().astype(np.float32) # (1,C,H,W)
116
- log.debug("Tensor shape=%s dtype=%s min=%.4f max=%.4f mean=%.4f std=%.4f",
117
- t.shape, t.dtype, float(t.min()), float(t.max()),
118
- float(t.mean()), float(t.std()))
119
- return t
120
-
121
- # ------------------------------------------------------------------------------
122
- # Postprocessing and helpers
123
- # ------------------------------------------------------------------------------
124
- def probabilities_to_ints(probabilities: np.ndarray, total_sum: int = 100) -> np.ndarray:
125
- probs = np.asarray(probabilities).astype(np.float64)
126
- if probs.ndim != 1:
127
- raise ValueError(f"Expected 1D probs, got shape {probs.shape}")
128
- if not np.isfinite(probs).all():
129
- raise ValueError("Non-finite values in probabilities")
130
- positives = np.maximum(probs, 0)
131
- total = positives.sum()
132
- scaled = np.zeros_like(positives)
133
- if total > 0:
134
- scaled = positives / total * total_sum
135
- rounded = np.round(scaled).astype(int)
136
- diff = total_sum - int(rounded.sum())
137
- if diff != 0 and total > 0:
138
- rounded[int(np.argmax(positives))] += diff
139
- return rounded
140
-
141
- def safe_lookup(disease_name: str, meta: Dict[str, Any]) -> Tuple[str, str, str, str]:
142
- info = meta.get(disease_name, {})
143
- desc = info.get('description', '')
144
- symp = info.get('symptoms', '')
145
- causes = info.get('causes', '')
146
- treat = info.get('treatment-1', '') or info.get('treatment', '')
147
- return desc, symp, causes, treat
148
-
149
- def infer(tensor: np.ndarray) -> np.ndarray:
150
- if ORT is None:
151
- raise RuntimeError("ONNX session not initialized. Check model load errors in logs.")
152
- # Validate input against model input signature
153
- model_input = ORT.get_inputs()[0]
154
- if model_input.type not in ("tensor(float)", "tensor(float16)"):
155
- log.warning("Model expects %s but provided float32; ORT may cast automatically.", model_input.type)
156
- feed_name = model_input.name
157
- log.debug("Feeding input: name=%s shape=%s dtype=%s", feed_name, tensor.shape, tensor.dtype)
158
- out_list = ORT.run(None, {feed_name: tensor})
159
- if not out_list:
160
- raise RuntimeError("ONNX returned no outputs")
161
- logits = out_list
162
- log.debug("ONNX outputs: %d tensors", len(out_list))
163
- for i, o in enumerate(out_list):
164
- # Defensive stats
165
- try:
166
- min_v = float(np.nanmin(o))
167
- max_v = float(np.nanmax(o))
168
- mean_v = float(np.nanmean(o))
169
- log.debug("Output[%d]: shape=%s dtype=%s min=%.4f max=%.4f mean=%.4f",
170
- i, getattr(o, 'shape', '?'), getattr(o, 'dtype', '?'), min_v, max_v, mean_v)
171
- except Exception:
172
- log.debug("Output[%d]: non-numpy type %s", i, type(o))
173
- if isinstance(logits, list):
174
- logits = np.asarray(logits)
175
- logits = np.array(logits)
176
- if logits.ndim == 2:
177
- logits = logits[0]
178
- elif logits.ndim != 1:
179
- raise ValueError(f"Unexpected logits shape: {logits.shape}")
180
- if not np.isfinite(logits).all():
181
- raise ValueError("Logits contain non-finite values")
182
- return logits.astype(np.float32)
183
-
184
- # ------------------------------------------------------------------------------
185
- # Gradio Predict wrapper with robust error reporting
186
- # ------------------------------------------------------------------------------
187
- def predict_ui(image: Image.Image):
188
- try:
189
- # Startup errors surfaced in UI
190
- if "_startup_error" in META:
191
- raise RuntimeError(META["_startup_error"])
192
- if not KEYS:
193
- raise RuntimeError("Class list is empty. Verify dat.json content.")
194
- x = preprocess(image)
195
- logits = infer(x)
196
- if len(logits) != len(KEYS):
197
- raise ValueError(f"Logits length ({len(logits)}) != number of classes ({len(KEYS)}). "
198
- "Ensure label order matches model output.")
199
- pred_idx = int(np.argmax(logits))
200
- pred_name = KEYS[pred_idx]
201
- log.info("Prediction: idx=%d name=%s score=%.4f", pred_idx, pred_name, float(logits[pred_idx]))
202
- desc, symp, causes, treat = safe_lookup(pred_name, META)
203
- probs_int = probabilities_to_ints(logits)
204
- df = pd.DataFrame({"item": KEYS, "probability": probs_int.astype(int)})
205
- # Return without error
206
- return pred_name, desc, symp, causes, treat, df, ""
207
- except Exception as e:
208
- # Log and return placeholders plus an error message textbox
209
- err = log_exception("Inference failed")
210
- # Provide minimal but valid outputs so the UI doesn't crash
211
- empty_df = pd.DataFrame({"item": KEYS if KEYS else ["N/A"], "probability": [0]*(len(KEYS) if KEYS else 1)})
212
- return "Error", "", "", "", "", empty_df, err
213
-
214
- # ------------------------------------------------------------------------------
215
- # Gradio UI
216
- # ------------------------------------------------------------------------------
217
- bar_output = gr.BarPlot(
218
- x="item",
219
- y="probability",
220
- y_title="Probabilidad",
221
- x_title="Nombre de la Enfermedad",
222
- title="Distribucion de Probabilidad",
223
- tooltip=["item", "probability"],
224
- vertical=False
225
  )
226
 
227
- with gr.Blocks(title="Clasificacion de Enfermedades de la Piel") as demo:
228
- gr.Markdown(
229
- "Este espacio se ha desarrollado como parte de una tesis para la Universidad Central de Venezuela "
230
- "con el proposito de realizar diagnosticos precisos sobre una variedad de lesiones cutaneas. "
231
- "Su objetivo es ayudar en la identificacion temprana y precisa de condiciones dermatologicas, incluyendo:\n\n"
232
- "1) Queratosis Actinica \n\n"
233
- "2) Carcinoma Basocelular \n\n"
234
- "3) Dermatofibroma \n\n"
235
- "4) Melanoma \n\n"
236
- "5) Nevus \n\n"
237
- "6) Queratosis Pigmentada Benigna \n\n"
238
- "7) Queratosis Seborreica \n\n"
239
- "8) Carcinoma de Celulas Escamosas \n\n"
240
- "9) Lesion Vascular \n\n"
241
- )
242
-
243
- with gr.Row():
244
- img_in = gr.Image(type="pil", label="Imagen")
245
- with gr.Column():
246
- out_name = gr.Textbox(label='Nombre de la Enfermedad')
247
- out_desc = gr.Textbox(label='Descripcion')
248
- out_symp = gr.Textbox(label='Sintomas')
249
- out_causes = gr.Textbox(label='Causas')
250
- out_treat = gr.Textbox(label='Tratamiento')
251
- bar = bar_output
252
- err_box = gr.Textbox(label="Errores", interactive=False)
253
-
254
- btn = gr.Button("Predecir")
255
- btn.click(
256
- fn=predict_ui,
257
- inputs=[img_in],
258
- outputs=[out_name, out_desc, out_symp, out_causes, out_treat, bar, err_box]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
  )
260
 
261
- # Do not set share=True in Spaces
262
- if __name__ == "__main__":
263
- demo.launch(debug=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import json
2
+ import cv2
 
 
 
3
  import numpy as np
4
  import gradio as gr
5
  import onnxruntime
 
7
  from torchvision import transforms
8
  import pandas as pd
9
 
10
+ # Cargar el modelo ONNX
11
+ ort_session = onnxruntime.InferenceSession("model_new_new_final.onnx")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
+ # Abrir archivo JSON
14
+ with open('dat.json') as f:
15
+ data = json.load(f)
 
 
16
 
17
+ keys = list(data)
 
 
 
 
18
 
19
+ # DataFrame de ejemplo
20
+ simple = pd.DataFrame(
21
+ {
22
+ "item": keys,
23
+ "probability": [0] * len(keys)
24
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  )
26
 
27
+ def Predict(image):
28
+ # Preprocesar la imagen
29
+ img = cv2.resize(image, (100, 100))
30
+
31
+ # Convertir el arreglo NumPy de vuelta a una imagen PIL
32
+ image = Image.fromarray(image)
33
+
34
+ # Preprocesar la imagen
35
+ img = image.resize((100, 100))
36
+
37
+ # Definir transformaciones
38
+ test_tfms = transforms.Compose([
39
+ transforms.Resize((100, 100)),
40
+ transforms.ToTensor(),
41
+ transforms.Normalize(mean=[0.7611, 0.5869, 0.5923], std=[0.1266, 0.1487, 0.1619])
42
+ ])
43
+
44
+ # Aplicar transformaciones
45
+ input_image = test_tfms(img).unsqueeze(0).numpy() # Agregar dimension de lote y convertir a arreglo numpy
46
+
47
+ # Preparar tensor de entrada
48
+ input_name = ort_session.get_inputs()[0].name
49
+ input_dict = {input_name: input_image}
50
+
51
+ # Ejecutar inferencia
52
+ output = ort_session.run(None, input_dict)
53
+
54
+ # Obtener el indice de la clase predicha
55
+ prediction_idx = np.argmax(output)
56
+
57
+ # Recuperar informacion del JSON basada en la clase predicha
58
+ disease_name = keys[prediction_idx]
59
+ description = data[disease_name]['description']
60
+ symptoms = data[disease_name]['symptoms']
61
+ causes = data[disease_name]['causes']
62
+ treatment = data[disease_name]['treatment-1']
63
+
64
+ # Obtener probabilidades para cada clase y convertirlas a enteros
65
+ probabilities = output[0]
66
+ ints_probabilities = probabilities_to_ints(probabilities)
67
+
68
+ # Actualizar la probabilidad en el DataFrame
69
+ simple["probability"] = ints_probabilities[0]
70
+
71
+ # Crear grafico de barras para las probabilidades
72
+ bar_plot = gr.BarPlot(
73
+ value=simple,
74
+ x="item",
75
+ y="probability",
76
+ y_title="Probabilidad",
77
+ x_title="Nombre de la Enfermedad",
78
+ title="Distribucion de Probabilidad",
79
+ tooltip=["item", "probability"],
80
+ vertical = False
81
  )
82
 
83
+ return disease_name, description, symptoms, causes, treatment, bar_plot
84
+
85
+ # Funcion para convertir probabilidades a enteros
86
+ def probabilities_to_ints(probabilities, total_sum=100):
87
+ # Filtrar los valores negativos
88
+ positive_values = np.maximum(probabilities, 0)
89
+
90
+ # Encontrar el peso positivo total
91
+ total_positive_weight = np.sum(positive_values)
92
+
93
+ # Calcular probabilidades escaladas para valores positivos
94
+ scaled_probabilities = np.zeros_like(probabilities)
95
+ if total_positive_weight > 0:
96
+ scaled_probabilities = positive_values / total_positive_weight * total_sum
97
+
98
+ # Redondear las probabilidades escaladas a enteros
99
+ rounded_probabilities = np.round(scaled_probabilities).astype(int)
100
+
101
+ # Ajustar por errores de redondeo
102
+ rounding_diff = total_sum - np.sum(rounded_probabilities)
103
+ if rounding_diff != 0 and np.sum(positive_values) > 0:
104
+ # Agregar la diferencia de redondeo a la clase con mayor peso positivo
105
+ max_positive_index = np.argmax(positive_values)
106
+ flattened_probabilities = rounded_probabilities.flatten()
107
+ flattened_probabilities[max_positive_index] += rounding_diff
108
+ rounded_probabilities = np.reshape(flattened_probabilities, rounded_probabilities.shape)
109
+
110
+ return rounded_probabilities
111
+
112
+ # Definir la interfaz Gradio
113
+ demo = gr.Interface(fn=Predict,
114
+ inputs="image",
115
+ outputs=[
116
+ gr.Textbox(label='Nombre de la Enfermedad', type="text"),
117
+ gr.Textbox(label='Descripcion', type="text"),
118
+ gr.Textbox(label='Sintomas', type="text"),
119
+ gr.Textbox(label='Causas', type="text"),
120
+ gr.Textbox(label='Tratamiento', type="text"),
121
+ "bar_plot"
122
+ ],
123
+ title="Clasificacion de Enfermedades de la Piel",
124
+ description = 'Este espacio se ha desarrollado como parte de una tesis para la Universidad Central de Venezuela con el proposito de realizar diagnosticos precisos sobre una variedad de lesiones cutaneas. Su objetivo es ayudar en la identificacion temprana y precisa de condiciones dermatologicas, incluyendo:\n\n1)Queratosis Actinica \n\n2)Carcinoma Basocelular \n\n3)Dermatofibroma \n\n4)Melanoma \n\n5)Nevus \n\n6)Queratosis Pigmentada Benigna \n\n7)Queratosis Seborreica \n\n8)Carcinoma de Celulas Escamosas \n\n9)Lesion Vascular \n\n')
125
+
126
+ demo.launch(debug=True)