# CIFAR7 / app.py
# Last commit by Kellyss: "Update app.py" (65cb2b9, verified)
import os

import gradio as gr
import numpy as np
from PIL import Image

# Quiet TensorFlow's start-up output. These variables must be set BEFORE
# TensorFlow is imported: the oneDNN notice and the native-library log
# messages are emitted while the TF runtime loads, so assigning them after
# `import tensorflow` is too late to suppress them.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'  # turn off oneDNN custom ops (and their notice)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'   # hide C++ INFO and WARNING log lines

import tensorflow as tf  # noqa: E402  -- must come after the environment setup above
# Initial configuration: user-visible class names, indexed in the same order
# as the model's output scores (predict() maps score i to label i). The model
# covers a 7-class subset of CIFAR-10.
cifar10_labels = [
    'avión', 'automóvil', 'pájaro', 'gato', 'venado',
    'perro', 'rana'
]
# Load the trained Keras classifier once, at import time, so every call to
# predict() reuses the same model instance instead of reloading it.
_MODEL_PATH = 'my_model.keras'
model = tf.keras.models.load_model(_MODEL_PATH)
def preprocess_image(image):
    """Turn an input image into the array layout the model consumes.

    The image is converted to RGB *before* it is resized. Pillow always uses
    nearest-neighbour resampling for "1" (bilevel) and "P" (palette) images,
    so resizing first would alias those inputs. Converting first lets the
    default resampling filter apply to every input mode.

    Args:
        image: A PIL image in any mode Pillow can convert to RGB.

    Returns:
        A float32 array of shape (32, 32, 3) with pixel values scaled
        from [0, 255] to [0, 1].
    """
    rgb = image.convert('RGB')
    small = rgb.resize((32, 32))
    return np.asarray(small, dtype=np.float32) / 255
def predict(image):
    """Classify a single image and format the result for ``gr.Label``.

    Args:
        image: PIL image from the Gradio input, or None when the user
            submitted without providing one.

    Returns:
        dict mapping each name in ``cifar10_labels`` to the model's output
        score for that class, as a plain float. The scores are presumably
        softmax probabilities; confirm against the model definition.

    Raises:
        gr.Error: if no image was supplied; Gradio shows the message to
            the user.
    """
    if image is None:
        raise gr.Error("¡Por favor sube una imagen o toma una foto!")
    # The model takes a batch, so add a leading batch axis of size 1.
    batch = np.expand_dims(preprocess_image(image), axis=0)
    # For one small batch, calling the model directly is Keras' recommended
    # fast path. model.predict() is built for large datasets and adds
    # per-call overhead plus a progress bar on every request.
    scores = np.asarray(model(batch, training=False))[0]
    return {label: float(scores[i]) for i, label in enumerate(cifar10_labels)}
# Example inputs: one image per class, listed in the same order as
# cifar10_labels (the UI pairs them up with zip). Each entry is a one-item
# list because the example rows feed a single input component.
examples = [
    [path]
    for path in (
        "ejemplos/avion.jpg",
        "ejemplos/automovil.jpg",
        "ejemplos/pajaro.jpg",
        "ejemplos/gato.jpg",
        "ejemplos/venado.jpg",
        "ejemplos/perro.jpg",
        "ejemplos/rana.jpg",
    )
]
# Build the interface.
with gr.Blocks(theme=gr.themes.Soft(), css="""
.examples-grid {display: flex !important; flex-direction: column; gap: 1rem}
.examples-row {display: flex !important; gap: 1rem; justify-content: center}
""") as app:
    gr.Markdown("# CIFAR-10 con 7 clases")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                sources=["upload", "webcam", "clipboard"],
                type="pil",
                label="Entrada de imagen",
                height=250
            )
            with gr.Row():
                submit_btn = gr.Button("Predecir", variant="primary")
                clear_btn = gr.Button("Limpiar")
        with gr.Column():
            output_label = gr.Label(
                label="Resultados",
                num_top_classes=3,
                show_label=True
            )

    # Clickable examples, one per class.
    gr.Markdown("## Ejemplos de categorías")
    with gr.Column(elem_classes=["examples-grid"]):
        # Two rows: the first five classes, then the remaining ones. The same
        # slice is applied to examples and cifar10_labels to keep them paired.
        for row in (slice(0, 5), slice(5, None)):
            with gr.Row(elem_classes=["examples-row"]):
                for example, label in zip(examples[row], cifar10_labels[row]):
                    gr.Examples(
                        examples=example,
                        inputs=[input_image],
                        label=label.capitalize(),
                        examples_per_page=1,
                        fn=predict,
                        outputs=[output_label],
                    )

    # Wire up the button events.
    submit_btn.click(
        fn=predict,
        inputs=input_image,
        outputs=output_label,
        api_name="predict"
    )
    # Reset both the input image and the displayed result.
    clear_btn.click(
        fn=lambda: [None, None],
        inputs=None,
        outputs=[input_image, output_label]
    )

if __name__ == "__main__":
    app.launch()