dcavadia commited on
Commit
7a6e646
·
verified ·
1 Parent(s): ee28ec8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -87
app.py CHANGED
@@ -1,26 +1,41 @@
 
 
1
  import json
 
 
2
  import numpy as np
 
3
  import gradio as gr
4
  import onnxruntime
5
  from PIL import Image
6
  from torchvision import transforms
7
- import pandas as pd
8
 
9
- # Load ONNX model
10
- ort_session = onnxruntime.InferenceSession("model_new_new_final.onnx")
 
11
 
12
- # Load metadata JSON
13
- with open('dat.json', 'r', encoding='utf-8') as f:
14
- data = json.load(f)
 
 
15
 
16
- # Ensure deterministic class ordering
 
 
17
  keys = list(data.keys())
 
 
 
 
 
 
18
 
19
- # Define transforms (PIL -> tensor NCHW)
20
- test_tfms = transforms.Compose([
21
  transforms.Resize((100, 100)),
22
  transforms.ToTensor(),
23
- transforms.Normalize(mean=[0.7611, 0.5869, 0.5923], std=[0.1266, 0.1487, 0.1619])
24
  ])
25
 
26
  def probabilities_to_ints(probabilities, total_sum=100):
@@ -36,83 +51,61 @@ def probabilities_to_ints(probabilities, total_sum=100):
36
  rounded[int(np.argmax(probs))] += diff
37
  return rounded
38
 
39
- def Predict(image: Image.Image):
40
- # Input is PIL.Image when inputs=gr.Image(type="pil")
41
- img = image.convert("RGB")
42
- tensor = test_tfms(img).unsqueeze(0).numpy().astype(np.float32) # (1,C,H,W)
43
-
44
- # ONNX inference
45
- input_name = ort_session.get_inputs().name
46
- outputs = ort_session.run(None, {input_name: tensor})
47
- logits = outputs
48
-
49
- # Flatten to 1D scores
50
- if logits.ndim == 2:
51
- scores = logits
52
- elif logits.ndim == 1:
53
- scores = logits
54
- else:
55
- raise ValueError(f"Unexpected logits shape: {logits.shape}")
56
-
57
- # Predicted class
58
- prediction_idx = int(np.argmax(scores))
59
- disease_name = keys[prediction_idx]
60
-
61
- # Lookup metadata (safe gets)
62
- info = data.get(disease_name, {})
63
- description = info.get('description', '')
64
- symptoms = info.get('symptoms', '')
65
- causes = info.get('causes', '')
66
- treatment = info.get('treatment-1', info.get('treatment', ''))
67
-
68
- # Build probabilities DataFrame for the bar plot
69
- probs_int = probabilities_to_ints(scores)
70
- df = pd.DataFrame({
71
- "item": keys,
72
- "probability": probs_int.astype(int)
73
- })
74
-
75
- # Return values matching declared outputs
76
- return disease_name, description, symptoms, causes, treatment, df
77
-
78
- # Declare a BarPlot output component that will receive a DataFrame
79
- bar_output = gr.BarPlot(
80
- x="item",
81
- y="probability",
82
- y_title="Probabilidad",
83
- x_title="Nombre de la Enfermedad",
84
- title="Distribucion de Probabilidad",
85
- tooltip=["item", "probability"],
86
- vertical=False
87
- )
88
 
89
- demo = gr.Interface(
90
- fn=Predict,
91
- inputs=gr.Image(type="pil", label="Imagen"),
92
- outputs=[
93
- gr.Textbox(label='Nombre de la Enfermedad'),
94
- gr.Textbox(label='Descripcion'),
95
- gr.Textbox(label='Sintomas'),
96
- gr.Textbox(label='Causas'),
97
- gr.Textbox(label='Tratamiento'),
98
- bar_output
99
- ],
100
- title="Clasificacion de Enfermedades de la Piel",
101
- description=(
102
- 'Este espacio se ha desarrollado como parte de una tesis para la Universidad Central de Venezuela '
103
- 'con el proposito de realizar diagnosticos precisos sobre una variedad de lesiones cutaneas. '
104
- 'Su objetivo es ayudar en la identificacion temprana y precisa de condiciones dermatologicas, incluyendo:\n\n'
105
- '1) Queratosis Actinica \n\n'
106
- '2) Carcinoma Basocelular \n\n'
107
- '3) Dermatofibroma \n\n'
108
- '4) Melanoma \n\n'
109
- '5) Nevus \n\n'
110
- '6) Queratosis Pigmentada Benigna \n\n'
111
- '7) Queratosis Seborreica \n\n'
112
- '8) Carcinoma de Celulas Escamosas \n\n'
113
- '9) Lesion Vascular \n\n'
114
  )
115
- )
 
 
116
 
117
- # In Spaces, do not set share=True
118
- demo.launch(debug=True)
 
 
1
+ import os
2
+ import sys
3
  import json
4
+ import logging
5
+ import traceback
6
  import numpy as np
7
+ import pandas as pd
8
  import gradio as gr
9
  import onnxruntime
10
  from PIL import Image
11
  from torchvision import transforms
 
12
 
13
+ # Logging
14
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s", stream=sys.stdout)
15
+ log = logging.getLogger("space")
16
 
17
+ def log_exc(prefix):
18
+ etype, evalue, tb = sys.exc_info()
19
+ stack = "".join(traceback.format_exception(etype, evalue, tb))
20
+ log.error("%s: %s\n%s", prefix, evalue, stack)
21
+ return f"{prefix}: {evalue}"
22
 
23
+ # Load metadata
24
+ with open("dat.json", "r", encoding="utf-8") as f:
25
+ data = json.load(f)
26
  keys = list(data.keys())
27
+ log.info("Loaded %d classes from dat.json", len(keys))
28
+
29
+ # Load ONNX
30
+ ort = onnxruntime.InferenceSession("model_new_new_final.onnx")
31
+ log.info("ONNX inputs: %s", [(i.name, i.shape, i.type) for i in ort.get_inputs()])
32
+ log.info("ONNX outputs: %s", [(o.name, o.shape, o.type) for o in ort.get_outputs()])
33
 
34
+ # Preprocess
35
+ tfms = transforms.Compose([
36
  transforms.Resize((100, 100)),
37
  transforms.ToTensor(),
38
+ transforms.Normalize(mean=[0.7611, 0.5869, 0.5923], std=[0.1266, 0.1487, 0.1619]),
39
  ])
40
 
41
  def probabilities_to_ints(probabilities, total_sum=100):
 
51
  rounded[int(np.argmax(probs))] += diff
52
  return rounded
53
 
54
+ def predict(image: Image.Image):
55
+ try:
56
+ if image is None:
57
+ return "Error", "", "", "", "", pd.DataFrame({"item": keys, "probability": *len(keys)}), "No image provided"
58
+ pil = image.convert("RGB")
59
+ x = tfms(pil).unsqueeze(0).numpy().astype(np.float32) # (1,C,H,W)
60
+ input_name = ort.get_inputs().name
61
+ outs = ort.run(None, {input_name: x})
62
+ logits = outs
63
+ if logits.ndim == 2:
64
+ scores = logits
65
+ elif logits.ndim == 1:
66
+ scores = logits
67
+ else:
68
+ raise ValueError(f"Unexpected logits shape: {logits.shape}")
69
+ if len(scores) != len(keys):
70
+ raise ValueError(f"Logits length {len(scores)} != classes {len(keys)}")
71
+ idx = int(np.argmax(scores))
72
+ name = keys[idx]
73
+ meta = data.get(name, {})
74
+ desc = meta.get("description", "")
75
+ symp = meta.get("symptoms", "")
76
+ causes = meta.get("causes", "")
77
+ treat = meta.get("treatment-1", meta.get("treatment", ""))
78
+ df = pd.DataFrame({"item": keys, "probability": probabilities_to_ints(scores).astype(int)})
79
+ return name, desc, symp, causes, treat, df, ""
80
+ except Exception:
81
+ err = log_exc("Inference failed")
82
+ df = pd.DataFrame({"item": keys if keys else ["N/A"], "probability": *(len(keys) if keys else 1)})
83
+ return "Error", "", "", "", "", df, err
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
+ with gr.Blocks(title="Clasificacion de Enfermedades de la Piel") as demo:
86
+ gr.Markdown("Suba una imagen y ejecute la prediccion.")
87
+ with gr.Row():
88
+ img = gr.Image(type="pil", label="Imagen")
89
+ with gr.Column():
90
+ out_name = gr.Textbox(label="Nombre de la Enfermedad")
91
+ out_desc = gr.Textbox(label="Descripcion")
92
+ out_symp = gr.Textbox(label="Sintomas")
93
+ out_causes = gr.Textbox(label="Causas")
94
+ out_treat = gr.Textbox(label="Tratamiento")
95
+ bar = gr.BarPlot(
96
+ x="item",
97
+ y="probability",
98
+ title="Distribucion de Probabilidad",
99
+ x_title="Nombre de la Enfermedad",
100
+ y_title="Probabilidad",
101
+ tooltip=["item", "probability"],
102
+ vertical=False,
103
+ label="Probabilidades"
 
 
 
 
 
 
104
  )
105
+ err = gr.Textbox(label="Errores", interactive=False)
106
+ btn = gr.Button("Predecir")
107
+ btn.click(fn=predict, inputs=[img], outputs=[out_name, out_desc, out_symp, out_causes, out_treat, bar, err])
108
 
109
+ if __name__ == "__main__":
110
+ # Spaces handles networking; no share=True
111
+ demo.launch(debug=True)