import gradio as gr from PIL import Image import numpy as np import tensorflow as tf import os # Configurar variables de entorno para reducir advertencias os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0' os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Configuración inicial cifar10_labels = [ 'avión', 'automóvil', 'pájaro', 'gato', 'venado', 'perro', 'rana', 'caballo', 'barco', 'camión' ] model = tf.keras.models.load_model('my_model.keras') def preprocess_image(image): """Preprocesado de imagen para el modelo""" img = image.resize((32, 32)).convert('RGB') return np.array(img).astype('float32') / 255 def predict(image): """Realizar predicción y formatear resultados""" if image is None: raise gr.Error("¡Por favor sube una imagen o toma una foto!") processed_img = preprocess_image(image) preds = model.predict(np.expand_dims(processed_img, axis=0))[0] return {label: float(preds[i]) for i, label in enumerate(cifar10_labels)} # Configurar ejemplos examples = [ ["ejemplos/avion.jpg"], ["ejemplos/automovil.jpg"], ["ejemplos/pajaro.jpg"], ["ejemplos/gato.jpg"], ["ejemplos/venado.jpg"], ["ejemplos/perro.jpg"], ["ejemplos/rana.jpg"], ["ejemplos/caballo.jpg"], ["ejemplos/barco.jpg"], ["ejemplos/camion.jpg"] ] # Construir interfaz with gr.Blocks(theme=gr.themes.Soft(), css=""" .examples-grid {display: flex !important; flex-direction: column; gap: 1rem} .examples-row {display: flex !important; gap: 1rem; justify-content: center} """) as app: gr.Markdown("# 📷 Clasificador CIFAR-10") with gr.Row(): with gr.Column(): input_image = gr.Image( sources=["upload", "webcam", "clipboard"], type="pil", label="Entrada de imagen", height=250 ) with gr.Row(): submit_btn = gr.Button("Predecir", variant="primary") clear_btn = gr.Button("Limpiar") with gr.Column(): output_label = gr.Label( label="Resultados", num_top_classes=3, show_label=True ) # Sección de ejemplos con interacción gr.Markdown("## Ejemplos de categorías") with gr.Column(elem_classes=["examples-grid"]): # Primera fila with gr.Row(elem_classes=["examples-row"]): for example, label in zip(examples[:5], cifar10_labels[:5]): gr.Examples( examples=example, inputs=[input_image], label=label.capitalize(), examples_per_page=1, fn=predict, outputs=[output_label], ) # Segunda fila with gr.Row(elem_classes=["examples-row"]): for example, label in zip(examples[5:], cifar10_labels[5:]): gr.Examples( examples=example, inputs=[input_image], label=label.capitalize(), examples_per_page=1, fn=predict, outputs=[output_label], ) # Conectar eventos submit_btn.click( fn=predict, inputs=input_image, outputs=output_label, api_name="predict" ) clear_btn.click( fn=lambda: [None, None], inputs=None, outputs=[input_image, output_label] ) if __name__ == "__main__": app.launch()