|
|
import os

# TensorFlow reads these flags while it is being imported and initialised,
# so they must be set *before* `import tensorflow` below; setting them after
# the import (as this file previously did) has no effect.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'  # turn off TensorFlow's oneDNN custom ops
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'   # reduce TensorFlow C++ log verbosity

import gradio as gr  # noqa: E402
from PIL import Image  # noqa: E402,F401
import numpy as np  # noqa: E402
import tensorflow as tf  # noqa: E402
|
|
|
|
|
|
|
|
# Spanish display names for the ten CIFAR-10 classes. Position i must match
# output unit i of the model: predict() pairs labels and scores by index.
# (The previous text was UTF-8 mis-decoded as GBK, e.g. 'avi贸n' for 'avión'.)
cifar10_labels = [
    'avión', 'automóvil', 'pájaro', 'gato', 'venado',
    'perro', 'rana', 'caballo', 'barco', 'camión',
]
|
|
|
|
|
# Location of the trained Keras model (relative to the working directory).
MODEL_PATH = 'my_model.keras'

# Loaded once at import time; every call to predict() reuses this instance.
model = tf.keras.models.load_model(MODEL_PATH)
|
|
|
|
|
def preprocess_image(image, size=(32, 32)):
    """Turn a PIL image into a normalised float array for the model.

    The image is converted to RGB *before* it is resized. Pillow always
    uses nearest-neighbour resampling for mode "1" and "P" (palette)
    images, so resizing first would give blocky thumbnails for those
    inputs. Converting first lets the normal resampling filter work in
    RGB space.

    Args:
        image: PIL ``Image`` as delivered by ``gr.Image(type="pil")``.
        size: Target ``(width, height)``. Defaults to 32x32, the input
            size this app has always fed to the model.

    Returns:
        ``float32`` array of shape ``(height, width, 3)`` with values
        scaled from 0-255 down to the range [0, 1].
    """
    rgb = image.convert('RGB')
    resized = rgb.resize(size)
    return np.asarray(resized, dtype=np.float32) / 255
|
|
|
|
|
def predict(image):
    """Classify one image and return per-class scores for ``gr.Label``.

    Args:
        image: PIL image from the input component, or ``None`` when the
            user has not supplied one.

    Returns:
        Dict mapping each name in ``cifar10_labels`` to the model's score
        for that class, as a plain Python ``float``.

    Raises:
        gr.Error: if ``image`` is ``None``. The message is intended for
            Gradio to show to the user.
    """
    if image is None:
        # Previously mis-encoded as "隆Por favor..."; the intended text opens with "¡".
        raise gr.Error("¡Por favor sube una imagen o toma una foto!")

    # The model works on batches, so add a leading batch axis of size 1.
    batch = np.expand_dims(preprocess_image(image), axis=0)
    # verbose=0 stops Keras printing a progress bar on every request.
    scores = model.predict(batch, verbose=0)[0]
    return {label: float(scores[i]) for i, label in enumerate(cifar10_labels)}
|
|
|
|
|
|
|
|
# One sample image per category, listed in the same order as cifar10_labels.
# Each entry is a one-element list holding the value for the single image
# input, which is the row shape gr.Examples accepts.
examples = [
    [path]
    for path in (
        "ejemplos/avion.jpg",
        "ejemplos/automovil.jpg",
        "ejemplos/pajaro.jpg",
        "ejemplos/gato.jpg",
        "ejemplos/venado.jpg",
        "ejemplos/perro.jpg",
        "ejemplos/rana.jpg",
        "ejemplos/caballo.jpg",
        "ejemplos/barco.jpg",
        "ejemplos/camion.jpg",
    )
]
|
|
|
|
|
|
|
|
# How many example thumbnails to show per row in the examples gallery.
EXAMPLES_PER_ROW = 5

# Assemble the Gradio UI: image input and buttons on the left, results on the
# right, and a gallery of one clickable example per category underneath.
with gr.Blocks(theme=gr.themes.Soft(), css="""
.examples-grid {display: flex !important; flex-direction: column; gap: 1rem}
.examples-row {display: flex !important; gap: 1rem; justify-content: center}
""") as app:
    # Heading emoji was previously mis-encoded as "馃摲".
    gr.Markdown("# 📷 Clasificador CIFAR-10")

    with gr.Row():
        # Left column: image source plus the action buttons.
        with gr.Column():
            input_image = gr.Image(
                sources=["upload", "webcam", "clipboard"],
                type="pil",
                label="Entrada de imagen",
                height=250,
            )
            with gr.Row():
                submit_btn = gr.Button("Predecir", variant="primary")
                clear_btn = gr.Button("Limpiar")

        # Right column: the three highest-scoring classes from predict().
        with gr.Column():
            output_label = gr.Label(
                label="Resultados",
                num_top_classes=3,
                show_label=True,
            )

    # Heading was previously mis-encoded as "categor铆as".
    gr.Markdown("## Ejemplos de categorías")
    with gr.Column(elem_classes=["examples-grid"]):
        # One gr.Examples per category so each thumbnail gets its own caption.
        # Categories are split into rows of EXAMPLES_PER_ROW. This replaces
        # two copy-pasted loops over hard-coded [:5] / [5:] slices.
        for start in range(0, len(examples), EXAMPLES_PER_ROW):
            stop = start + EXAMPLES_PER_ROW
            with gr.Row(elem_classes=["examples-row"]):
                for example, label in zip(examples[start:stop], cifar10_labels[start:stop]):
                    gr.Examples(
                        examples=example,
                        inputs=[input_image],
                        label=label.capitalize(),
                        examples_per_page=1,
                        fn=predict,
                        outputs=[output_label],
                    )

    submit_btn.click(
        fn=predict,
        inputs=input_image,
        outputs=output_label,
        api_name="predict",
    )

    # Reset both the input image and the results panel.
    clear_btn.click(
        fn=lambda: [None, None],
        inputs=None,
        outputs=[input_image, output_label],
    )
|
|
|
|
|
def main():
    """Start the Gradio web server that serves ``app``."""
    app.launch()


# Launch only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()