dcavadia commited on
Commit
ee28ec8
·
verified ·
1 Parent(s): de49f33

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -105
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import json
2
- import cv2
3
  import numpy as np
4
  import gradio as gr
5
  import onnxruntime
@@ -7,120 +6,113 @@ from PIL import Image
7
  from torchvision import transforms
8
  import pandas as pd
9
 
10
- # Cargar el modelo ONNX
11
  ort_session = onnxruntime.InferenceSession("model_new_new_final.onnx")
12
 
13
- # Abrir archivo JSON
14
- with open('dat.json') as f:
15
  data = json.load(f)
16
 
17
- keys = list(data)
 
18
 
19
- # DataFrame de ejemplo
20
- simple = pd.DataFrame(
21
- {
22
- "item": keys,
23
- "probability": [0] * len(keys)
24
- }
25
- )
26
-
27
- def Predict(image):
28
- # Preprocesar la imagen
29
- img = cv2.resize(image, (100, 100))
30
-
31
- # Convertir el arreglo NumPy de vuelta a una imagen PIL
32
- image = Image.fromarray(image)
33
-
34
- # Preprocesar la imagen
35
- img = image.resize((100, 100))
36
-
37
- # Definir transformaciones
38
- test_tfms = transforms.Compose([
39
- transforms.Resize((100, 100)),
40
- transforms.ToTensor(),
41
- transforms.Normalize(mean=[0.7611, 0.5869, 0.5923], std=[0.1266, 0.1487, 0.1619])
42
- ])
43
-
44
- # Aplicar transformaciones
45
- input_image = test_tfms(img).unsqueeze(0).numpy() # Agregar dimension de lote y convertir a arreglo numpy
46
-
47
- # Preparar tensor de entrada
48
- input_name = ort_session.get_inputs()[0].name
49
- input_dict = {input_name: input_image}
50
-
51
- # Ejecutar inferencia
52
- output = ort_session.run(None, input_dict)
53
-
54
- # Obtener el indice de la clase predicha
55
- prediction_idx = np.argmax(output)
56
 
57
- # Recuperar informacion del JSON basada en la clase predicha
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  disease_name = keys[prediction_idx]
59
- description = data[disease_name]['description']
60
- symptoms = data[disease_name]['symptoms']
61
- causes = data[disease_name]['causes']
62
- treatment = data[disease_name]['treatment-1']
63
 
64
- # Obtener probabilidades para cada clase y convertirlas a enteros
65
- probabilities = output[0]
66
- ints_probabilities = probabilities_to_ints(probabilities)
 
 
 
67
 
68
- # Actualizar la probabilidad en el DataFrame
69
- simple["probability"] = ints_probabilities[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
- # Crear grafico de barras para las probabilidades
72
- bar_plot = gr.BarPlot(
73
- value=simple,
74
- x="item",
75
- y="probability",
76
- y_title="Probabilidad",
77
- x_title="Nombre de la Enfermedad",
78
- title="Distribucion de Probabilidad",
79
- tooltip=["item", "probability"],
80
- vertical = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  )
 
82
 
83
- return disease_name, description, symptoms, causes, treatment, bar_plot
84
-
85
- # Funcion para convertir probabilidades a enteros
86
- def probabilities_to_ints(probabilities, total_sum=100):
87
- # Filtrar los valores negativos
88
- positive_values = np.maximum(probabilities, 0)
89
-
90
- # Encontrar el peso positivo total
91
- total_positive_weight = np.sum(positive_values)
92
-
93
- # Calcular probabilidades escaladas para valores positivos
94
- scaled_probabilities = np.zeros_like(probabilities)
95
- if total_positive_weight > 0:
96
- scaled_probabilities = positive_values / total_positive_weight * total_sum
97
-
98
- # Redondear las probabilidades escaladas a enteros
99
- rounded_probabilities = np.round(scaled_probabilities).astype(int)
100
-
101
- # Ajustar por errores de redondeo
102
- rounding_diff = total_sum - np.sum(rounded_probabilities)
103
- if rounding_diff != 0 and np.sum(positive_values) > 0:
104
- # Agregar la diferencia de redondeo a la clase con mayor peso positivo
105
- max_positive_index = np.argmax(positive_values)
106
- flattened_probabilities = rounded_probabilities.flatten()
107
- flattened_probabilities[max_positive_index] += rounding_diff
108
- rounded_probabilities = np.reshape(flattened_probabilities, rounded_probabilities.shape)
109
-
110
- return rounded_probabilities
111
-
112
- # Definir la interfaz Gradio
113
- demo = gr.Interface(fn=Predict,
114
- inputs="image",
115
- outputs=[
116
- gr.Textbox(label='Nombre de la Enfermedad', type="text"),
117
- gr.Textbox(label='Descripcion', type="text"),
118
- gr.Textbox(label='Sintomas', type="text"),
119
- gr.Textbox(label='Causas', type="text"),
120
- gr.Textbox(label='Tratamiento', type="text"),
121
- "bar_plot"
122
- ],
123
- title="Clasificacion de Enfermedades de la Piel",
124
- description = 'Este espacio se ha desarrollado como parte de una tesis para la Universidad Central de Venezuela con el proposito de realizar diagnosticos precisos sobre una variedad de lesiones cutaneas. Su objetivo es ayudar en la identificacion temprana y precisa de condiciones dermatologicas, incluyendo:\n\n1)Queratosis Actinica \n\n2)Carcinoma Basocelular \n\n3)Dermatofibroma \n\n4)Melanoma \n\n5)Nevus \n\n6)Queratosis Pigmentada Benigna \n\n7)Queratosis Seborreica \n\n8)Carcinoma de Celulas Escamosas \n\n9)Lesion Vascular \n\n')
125
-
126
  demo.launch(debug=True)
 
1
  import json
 
2
  import numpy as np
3
  import gradio as gr
4
  import onnxruntime
 
6
  from torchvision import transforms
7
  import pandas as pd
8
 
9
+ # Load ONNX model
10
  ort_session = onnxruntime.InferenceSession("model_new_new_final.onnx")
11
 
12
+ # Load metadata JSON
13
+ with open('dat.json', 'r', encoding='utf-8') as f:
14
  data = json.load(f)
15
 
16
+ # Ensure deterministic class ordering
17
+ keys = list(data.keys())
18
 
19
+ # Define transforms (PIL -> tensor NCHW)
20
+ test_tfms = transforms.Compose([
21
+ transforms.Resize((100, 100)),
22
+ transforms.ToTensor(),
23
+ transforms.Normalize(mean=[0.7611, 0.5869, 0.5923], std=[0.1266, 0.1487, 0.1619])
24
+ ])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
+ def probabilities_to_ints(probabilities, total_sum=100):
27
+ probs = np.asarray(probabilities, dtype=np.float64)
28
+ probs = np.maximum(probs, 0)
29
+ total = probs.sum()
30
+ scaled = np.zeros_like(probs)
31
+ if total > 0:
32
+ scaled = probs / total * total_sum
33
+ rounded = np.round(scaled).astype(int)
34
+ diff = total_sum - int(rounded.sum())
35
+ if diff != 0 and total > 0:
36
+ rounded[int(np.argmax(probs))] += diff
37
+ return rounded
38
+
39
+ def Predict(image: Image.Image):
40
+ # Input is PIL.Image when inputs=gr.Image(type="pil")
41
+ img = image.convert("RGB")
42
+ tensor = test_tfms(img).unsqueeze(0).numpy().astype(np.float32) # (1,C,H,W)
43
+
44
+ # ONNX inference
45
+ input_name = ort_session.get_inputs().name
46
+ outputs = ort_session.run(None, {input_name: tensor})
47
+ logits = outputs
48
+
49
+ # Flatten to 1D scores
50
+ if logits.ndim == 2:
51
+ scores = logits
52
+ elif logits.ndim == 1:
53
+ scores = logits
54
+ else:
55
+ raise ValueError(f"Unexpected logits shape: {logits.shape}")
56
+
57
+ # Predicted class
58
+ prediction_idx = int(np.argmax(scores))
59
  disease_name = keys[prediction_idx]
 
 
 
 
60
 
61
+ # Lookup metadata (safe gets)
62
+ info = data.get(disease_name, {})
63
+ description = info.get('description', '')
64
+ symptoms = info.get('symptoms', '')
65
+ causes = info.get('causes', '')
66
+ treatment = info.get('treatment-1', info.get('treatment', ''))
67
 
68
+ # Build probabilities DataFrame for the bar plot
69
+ probs_int = probabilities_to_ints(scores)
70
+ df = pd.DataFrame({
71
+ "item": keys,
72
+ "probability": probs_int.astype(int)
73
+ })
74
+
75
+ # Return values matching declared outputs
76
+ return disease_name, description, symptoms, causes, treatment, df
77
+
78
+ # Declare a BarPlot output component that will receive a DataFrame
79
+ bar_output = gr.BarPlot(
80
+ x="item",
81
+ y="probability",
82
+ y_title="Probabilidad",
83
+ x_title="Nombre de la Enfermedad",
84
+ title="Distribucion de Probabilidad",
85
+ tooltip=["item", "probability"],
86
+ vertical=False
87
+ )
88
 
89
+ demo = gr.Interface(
90
+ fn=Predict,
91
+ inputs=gr.Image(type="pil", label="Imagen"),
92
+ outputs=[
93
+ gr.Textbox(label='Nombre de la Enfermedad'),
94
+ gr.Textbox(label='Descripcion'),
95
+ gr.Textbox(label='Sintomas'),
96
+ gr.Textbox(label='Causas'),
97
+ gr.Textbox(label='Tratamiento'),
98
+ bar_output
99
+ ],
100
+ title="Clasificacion de Enfermedades de la Piel",
101
+ description=(
102
+ 'Este espacio se ha desarrollado como parte de una tesis para la Universidad Central de Venezuela '
103
+ 'con el proposito de realizar diagnosticos precisos sobre una variedad de lesiones cutaneas. '
104
+ 'Su objetivo es ayudar en la identificacion temprana y precisa de condiciones dermatologicas, incluyendo:\n\n'
105
+ '1) Queratosis Actinica \n\n'
106
+ '2) Carcinoma Basocelular \n\n'
107
+ '3) Dermatofibroma \n\n'
108
+ '4) Melanoma \n\n'
109
+ '5) Nevus \n\n'
110
+ '6) Queratosis Pigmentada Benigna \n\n'
111
+ '7) Queratosis Seborreica \n\n'
112
+ '8) Carcinoma de Celulas Escamosas \n\n'
113
+ '9) Lesion Vascular \n\n'
114
  )
115
+ )
116
 
117
+ # In Spaces, do not set share=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  demo.launch(debug=True)