import gradio as gr
import tensorflow as tf
import numpy as np
import time # Para simular un retraso en la carga del modelo
# --- Configuración ---
IMG_SIZE = (224, 224)
MODEL_PATH = "dental_classifier_model.keras" # Asegúrate de que esta ruta sea correcta
CLASS_NAMES = ['no_valido', 'valido']
# --- Cargar Modelo con mensaje de carga ---
model = None
model_load_message = "Cargando modelo... por favor espera."
try:
# Simular una carga lenta para ver el mensaje (puedes quitar esto en producción)
time.sleep(2)
model = tf.keras.models.load_model(MODEL_PATH)
model_load_message = "Modelo cargado exitosamente."
print("Modelo cargado exitosamente.")
except Exception as e:
model_load_message = f"Error cargando el modelo: {e}. Asegúrate que 'dental_classifier_model.keras' existe."
print(model_load_message)
# --- Funciones de Procesamiento ---
def preprocess_image(img):
"""Preprocesa la imagen de entrada al formato que espera el modelo."""
if img is None:
return None
# Asegurarse de que la imagen tiene 3 canales si es en escala de grises
if len(img.shape) == 2:
img = np.stack((img,)*3, axis=-1)
elif img.shape[2] == 4: #RGBA a RGB
img = img[:, :, :3]
img = tf.image.resize(img, IMG_SIZE)
img_array = tf.expand_dims(img, 0)
img_array = img_array / 255.0 # Normalizar
return img_array
def predecir(rx_image):
"""Realiza la predicción y formatea la salida HTML."""
if model is None:
return "
Error: El modelo no se ha cargado.
"
if rx_image is None:
return "Por favor, sube una imagen RX para analizar.
"
img_array = preprocess_image(rx_image)
if img_array is None: # Si preprocess_image devuelve None por algún motivo
return "Error: No se pudo procesar la imagen.
"
try:
preds = model.predict(img_array)
except Exception as e:
print(f"Error durante la predicción: {e}")
return f"Error al realizar la predicción: {e}
"
score = tf.nn.softmax(preds[0])
predicted_index = np.argmax(score)
confidence = np.max(score) * 100
predicted_class = CLASS_NAMES[predicted_index]
other_index = 1 - predicted_index
other_class = CLASS_NAMES[other_index]
other_confidence = score[other_index] * 100
# Colores para el borde. El color del texto se heredará del CSS.
color_borde = "#4CAF50" if predicted_class == "valido" else "#FF6B6B" # Verde para válido, Rojo para no válido
# HTML para el resultado con estilos mejorados
# Eliminamos los 'style' de color de texto para que se hereden de la clase 'result-box'
resultado_texto = f"""
Resultado: {predicted_class.upper()}
Confianza: {confidence:.2f}%
(Probabilidad {other_class}: {other_confidence:.2f}%)
"""
return resultado_texto
# --- Interfaz de Gradio ---
with gr.Blocks(theme=gr.themes.Soft(primary_hue=gr.themes.colors.emerald, secondary_hue=gr.themes.colors.slate)) as demo:
# --- Estilos CSS personalizados (Adaptados a Dark Mode) ---
gr.HTML(f"""
""")
gr.Markdown("## Clasificador RX LAB 🦷 V1(529NV-348V) TFG Marta B.")
gr.Markdown("Sube una imagen de una radiografía dental para clasificarla como válida o no válida.
")
# Mensaje de carga del modelo
gr.Textbox(value=model_load_message, interactive=False, container=False,
show_label=False, elem_id="model_status_message",
label="Estado del Modelo",
render=True, # Asegura que se renderiza inicialmente
info="El modelo está cargando..." if "Cargando" in model_load_message else None,
visible=True if "Cargando" in model_load_message or "Error" in model_load_message else False,
)
# Texto de estado inicial para la caja de resultados
initial_result_html = f""
with gr.Row(variant="panel", scale=1):
with gr.Column(scale=1, min_width=400):
gr.Markdown("### Sube tu Radiografía")
rx_input = gr.Image(type="numpy", label="Imagen de Radiografía Dental", show_label=True, height=450)
with gr.Row():
boton_limpiar = gr.Button("Limpiar", variant="secondary", size="lg", )
boton_analizar = gr.Button("Analizar RX", variant="primary", size="lg",)
with gr.Column(scale=1, min_width=400):
gr.Markdown("### Resultado del Análisis")
resultado = gr.HTML(label="Análisis de Radiografía", show_label=True, value=initial_result_html)
# --- Conexiones de Eventos ---
boton_analizar.click(fn=predecir, inputs=rx_input, outputs=resultado)
boton_limpiar.click(lambda: (None, initial_result_html), inputs=[], outputs=[rx_input, resultado])
# --- Lanzar la App ---
if __name__ == "__main__":
demo.launch(share=False)