Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
import json
|
| 2 |
-
import cv2
|
| 3 |
import numpy as np
|
| 4 |
import gradio as gr
|
| 5 |
import onnxruntime
|
|
@@ -7,120 +6,113 @@ from PIL import Image
|
|
| 7 |
from torchvision import transforms
|
| 8 |
import pandas as pd
|
| 9 |
|
| 10 |
-
#
|
| 11 |
ort_session = onnxruntime.InferenceSession("model_new_new_final.onnx")
|
| 12 |
|
| 13 |
-
#
|
| 14 |
-
with open('dat.json') as f:
|
| 15 |
data = json.load(f)
|
| 16 |
|
| 17 |
-
|
|
|
|
| 18 |
|
| 19 |
-
#
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
)
|
| 26 |
-
|
| 27 |
-
def Predict(image):
|
| 28 |
-
# Preprocesar la imagen
|
| 29 |
-
img = cv2.resize(image, (100, 100))
|
| 30 |
-
|
| 31 |
-
# Convertir el arreglo NumPy de vuelta a una imagen PIL
|
| 32 |
-
image = Image.fromarray(image)
|
| 33 |
-
|
| 34 |
-
# Preprocesar la imagen
|
| 35 |
-
img = image.resize((100, 100))
|
| 36 |
-
|
| 37 |
-
# Definir transformaciones
|
| 38 |
-
test_tfms = transforms.Compose([
|
| 39 |
-
transforms.Resize((100, 100)),
|
| 40 |
-
transforms.ToTensor(),
|
| 41 |
-
transforms.Normalize(mean=[0.7611, 0.5869, 0.5923], std=[0.1266, 0.1487, 0.1619])
|
| 42 |
-
])
|
| 43 |
-
|
| 44 |
-
# Aplicar transformaciones
|
| 45 |
-
input_image = test_tfms(img).unsqueeze(0).numpy() # Agregar dimension de lote y convertir a arreglo numpy
|
| 46 |
-
|
| 47 |
-
# Preparar tensor de entrada
|
| 48 |
-
input_name = ort_session.get_inputs()[0].name
|
| 49 |
-
input_dict = {input_name: input_image}
|
| 50 |
-
|
| 51 |
-
# Ejecutar inferencia
|
| 52 |
-
output = ort_session.run(None, input_dict)
|
| 53 |
-
|
| 54 |
-
# Obtener el indice de la clase predicha
|
| 55 |
-
prediction_idx = np.argmax(output)
|
| 56 |
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
disease_name = keys[prediction_idx]
|
| 59 |
-
description = data[disease_name]['description']
|
| 60 |
-
symptoms = data[disease_name]['symptoms']
|
| 61 |
-
causes = data[disease_name]['causes']
|
| 62 |
-
treatment = data[disease_name]['treatment-1']
|
| 63 |
|
| 64 |
-
#
|
| 65 |
-
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
-
#
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
)
|
|
|
|
| 82 |
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
# Funcion para convertir probabilidades a enteros
|
| 86 |
-
def probabilities_to_ints(probabilities, total_sum=100):
|
| 87 |
-
# Filtrar los valores negativos
|
| 88 |
-
positive_values = np.maximum(probabilities, 0)
|
| 89 |
-
|
| 90 |
-
# Encontrar el peso positivo total
|
| 91 |
-
total_positive_weight = np.sum(positive_values)
|
| 92 |
-
|
| 93 |
-
# Calcular probabilidades escaladas para valores positivos
|
| 94 |
-
scaled_probabilities = np.zeros_like(probabilities)
|
| 95 |
-
if total_positive_weight > 0:
|
| 96 |
-
scaled_probabilities = positive_values / total_positive_weight * total_sum
|
| 97 |
-
|
| 98 |
-
# Redondear las probabilidades escaladas a enteros
|
| 99 |
-
rounded_probabilities = np.round(scaled_probabilities).astype(int)
|
| 100 |
-
|
| 101 |
-
# Ajustar por errores de redondeo
|
| 102 |
-
rounding_diff = total_sum - np.sum(rounded_probabilities)
|
| 103 |
-
if rounding_diff != 0 and np.sum(positive_values) > 0:
|
| 104 |
-
# Agregar la diferencia de redondeo a la clase con mayor peso positivo
|
| 105 |
-
max_positive_index = np.argmax(positive_values)
|
| 106 |
-
flattened_probabilities = rounded_probabilities.flatten()
|
| 107 |
-
flattened_probabilities[max_positive_index] += rounding_diff
|
| 108 |
-
rounded_probabilities = np.reshape(flattened_probabilities, rounded_probabilities.shape)
|
| 109 |
-
|
| 110 |
-
return rounded_probabilities
|
| 111 |
-
|
| 112 |
-
# Definir la interfaz Gradio
|
| 113 |
-
demo = gr.Interface(fn=Predict,
|
| 114 |
-
inputs="image",
|
| 115 |
-
outputs=[
|
| 116 |
-
gr.Textbox(label='Nombre de la Enfermedad', type="text"),
|
| 117 |
-
gr.Textbox(label='Descripcion', type="text"),
|
| 118 |
-
gr.Textbox(label='Sintomas', type="text"),
|
| 119 |
-
gr.Textbox(label='Causas', type="text"),
|
| 120 |
-
gr.Textbox(label='Tratamiento', type="text"),
|
| 121 |
-
"bar_plot"
|
| 122 |
-
],
|
| 123 |
-
title="Clasificacion de Enfermedades de la Piel",
|
| 124 |
-
description = 'Este espacio se ha desarrollado como parte de una tesis para la Universidad Central de Venezuela con el proposito de realizar diagnosticos precisos sobre una variedad de lesiones cutaneas. Su objetivo es ayudar en la identificacion temprana y precisa de condiciones dermatologicas, incluyendo:\n\n1)Queratosis Actinica \n\n2)Carcinoma Basocelular \n\n3)Dermatofibroma \n\n4)Melanoma \n\n5)Nevus \n\n6)Queratosis Pigmentada Benigna \n\n7)Queratosis Seborreica \n\n8)Carcinoma de Celulas Escamosas \n\n9)Lesion Vascular \n\n')
|
| 125 |
-
|
| 126 |
demo.launch(debug=True)
|
|
|
|
| 1 |
import json
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
import gradio as gr
|
| 4 |
import onnxruntime
|
|
|
|
| 6 |
from torchvision import transforms
|
| 7 |
import pandas as pd
|
| 8 |
|
| 9 |
+
# Load ONNX model
|
| 10 |
ort_session = onnxruntime.InferenceSession("model_new_new_final.onnx")
|
| 11 |
|
| 12 |
+
# Load metadata JSON
|
| 13 |
+
with open('dat.json', 'r', encoding='utf-8') as f:
|
| 14 |
data = json.load(f)
|
| 15 |
|
| 16 |
+
# Ensure deterministic class ordering
|
| 17 |
+
keys = list(data.keys())
|
| 18 |
|
| 19 |
+
# Define transforms (PIL -> tensor NCHW)
|
| 20 |
+
test_tfms = transforms.Compose([
|
| 21 |
+
transforms.Resize((100, 100)),
|
| 22 |
+
transforms.ToTensor(),
|
| 23 |
+
transforms.Normalize(mean=[0.7611, 0.5869, 0.5923], std=[0.1266, 0.1487, 0.1619])
|
| 24 |
+
])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
+
def probabilities_to_ints(probabilities, total_sum=100):
|
| 27 |
+
probs = np.asarray(probabilities, dtype=np.float64)
|
| 28 |
+
probs = np.maximum(probs, 0)
|
| 29 |
+
total = probs.sum()
|
| 30 |
+
scaled = np.zeros_like(probs)
|
| 31 |
+
if total > 0:
|
| 32 |
+
scaled = probs / total * total_sum
|
| 33 |
+
rounded = np.round(scaled).astype(int)
|
| 34 |
+
diff = total_sum - int(rounded.sum())
|
| 35 |
+
if diff != 0 and total > 0:
|
| 36 |
+
rounded[int(np.argmax(probs))] += diff
|
| 37 |
+
return rounded
|
| 38 |
+
|
| 39 |
+
def Predict(image: Image.Image):
|
| 40 |
+
# Input is PIL.Image when inputs=gr.Image(type="pil")
|
| 41 |
+
img = image.convert("RGB")
|
| 42 |
+
tensor = test_tfms(img).unsqueeze(0).numpy().astype(np.float32) # (1,C,H,W)
|
| 43 |
+
|
| 44 |
+
# ONNX inference
|
| 45 |
+
input_name = ort_session.get_inputs().name
|
| 46 |
+
outputs = ort_session.run(None, {input_name: tensor})
|
| 47 |
+
logits = outputs
|
| 48 |
+
|
| 49 |
+
# Flatten to 1D scores
|
| 50 |
+
if logits.ndim == 2:
|
| 51 |
+
scores = logits
|
| 52 |
+
elif logits.ndim == 1:
|
| 53 |
+
scores = logits
|
| 54 |
+
else:
|
| 55 |
+
raise ValueError(f"Unexpected logits shape: {logits.shape}")
|
| 56 |
+
|
| 57 |
+
# Predicted class
|
| 58 |
+
prediction_idx = int(np.argmax(scores))
|
| 59 |
disease_name = keys[prediction_idx]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
+
# Lookup metadata (safe gets)
|
| 62 |
+
info = data.get(disease_name, {})
|
| 63 |
+
description = info.get('description', '')
|
| 64 |
+
symptoms = info.get('symptoms', '')
|
| 65 |
+
causes = info.get('causes', '')
|
| 66 |
+
treatment = info.get('treatment-1', info.get('treatment', ''))
|
| 67 |
|
| 68 |
+
# Build probabilities DataFrame for the bar plot
|
| 69 |
+
probs_int = probabilities_to_ints(scores)
|
| 70 |
+
df = pd.DataFrame({
|
| 71 |
+
"item": keys,
|
| 72 |
+
"probability": probs_int.astype(int)
|
| 73 |
+
})
|
| 74 |
+
|
| 75 |
+
# Return values matching declared outputs
|
| 76 |
+
return disease_name, description, symptoms, causes, treatment, df
|
| 77 |
+
|
| 78 |
+
# Declare a BarPlot output component that will receive a DataFrame
|
| 79 |
+
bar_output = gr.BarPlot(
|
| 80 |
+
x="item",
|
| 81 |
+
y="probability",
|
| 82 |
+
y_title="Probabilidad",
|
| 83 |
+
x_title="Nombre de la Enfermedad",
|
| 84 |
+
title="Distribucion de Probabilidad",
|
| 85 |
+
tooltip=["item", "probability"],
|
| 86 |
+
vertical=False
|
| 87 |
+
)
|
| 88 |
|
| 89 |
+
demo = gr.Interface(
|
| 90 |
+
fn=Predict,
|
| 91 |
+
inputs=gr.Image(type="pil", label="Imagen"),
|
| 92 |
+
outputs=[
|
| 93 |
+
gr.Textbox(label='Nombre de la Enfermedad'),
|
| 94 |
+
gr.Textbox(label='Descripcion'),
|
| 95 |
+
gr.Textbox(label='Sintomas'),
|
| 96 |
+
gr.Textbox(label='Causas'),
|
| 97 |
+
gr.Textbox(label='Tratamiento'),
|
| 98 |
+
bar_output
|
| 99 |
+
],
|
| 100 |
+
title="Clasificacion de Enfermedades de la Piel",
|
| 101 |
+
description=(
|
| 102 |
+
'Este espacio se ha desarrollado como parte de una tesis para la Universidad Central de Venezuela '
|
| 103 |
+
'con el proposito de realizar diagnosticos precisos sobre una variedad de lesiones cutaneas. '
|
| 104 |
+
'Su objetivo es ayudar en la identificacion temprana y precisa de condiciones dermatologicas, incluyendo:\n\n'
|
| 105 |
+
'1) Queratosis Actinica \n\n'
|
| 106 |
+
'2) Carcinoma Basocelular \n\n'
|
| 107 |
+
'3) Dermatofibroma \n\n'
|
| 108 |
+
'4) Melanoma \n\n'
|
| 109 |
+
'5) Nevus \n\n'
|
| 110 |
+
'6) Queratosis Pigmentada Benigna \n\n'
|
| 111 |
+
'7) Queratosis Seborreica \n\n'
|
| 112 |
+
'8) Carcinoma de Celulas Escamosas \n\n'
|
| 113 |
+
'9) Lesion Vascular \n\n'
|
| 114 |
)
|
| 115 |
+
)
|
| 116 |
|
| 117 |
+
# In Spaces, do not set share=True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
demo.launch(debug=True)
|