File size: 3,557 Bytes
7d43055
 
 
 
02ee94f
 
 
 
 
7d43055
61c0af3
 
 
 
 
7d43055
61c0af3
7d43055
 
40472a6
a6e51fc
 
7d43055
 
8713743
91c0b3e
 
 
7d43055
a6e51fc
8713743
 
61c0af3
7d43055
91c0b3e
 
 
 
61c0af3
c23f320
 
 
 
 
7d43055
 
8713743
a46ba42
 
 
 
 
c23f320
61c0af3
a6e51fc
 
40472a6
 
 
91c0b3e
 
40472a6
91c0b3e
 
 
a6e51fc
 
40472a6
 
91c0b3e
40472a6
 
61c0af3
8713743
91c0b3e
a46ba42
61c0af3
 
 
 
 
 
 
 
 
 
 
 
a46ba42
61c0af3
a46ba42
 
 
 
 
 
 
 
8713743
91c0b3e
 
 
 
 
 
 
 
 
 
 
 
 
7d43055
4719d8c
61c0af3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os

# Quiet TensorFlow's startup output (oneDNN notice, C++ INFO/WARNING logs).
# These variables are consulted while TensorFlow is being loaded, so they are
# set before the `import tensorflow` line below; assigning them afterwards is
# too late to have any effect.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import gradio as gr  # noqa: E402  (must follow the environment setup above)
import numpy as np  # noqa: E402
import tensorflow as tf  # noqa: E402
from PIL import Image  # noqa: E402

# Initial configuration.
# Class names shown to the user. Position i in this list is paired with
# element i of the model's output vector when predictions are formatted.
cifar10_labels = [
    'avión', 'automóvil', 'pájaro', 'gato', 'venado',
    'perro', 'rana', 'caballo', 'barco', 'camión'
]

# Loaded once at import time; the path is relative to the working directory.
model = tf.keras.models.load_model('my_model.keras')

def preprocess_image(image: "Image.Image", size=(32, 32)) -> np.ndarray:
    """Turn a PIL image into the normalized array fed to the model.

    The image is converted to RGB *before* it is resized. Pillow falls back
    to nearest-neighbour resampling for palette ("P") and bilevel ("1")
    images, so resizing first would downsample such inputs with heavy
    aliasing; converting first lets the resize use proper filtering.

    Args:
        image: Source PIL image, in any mode.
        size: Target ``(width, height)`` in pixels. Defaults to ``(32, 32)``.

    Returns:
        A ``float32`` array of shape ``(height, width, 3)`` whose 8-bit
        channel values have been divided by 255.
    """
    rgb = image.convert('RGB').resize(tuple(size))
    return np.asarray(rgb, dtype=np.float32) / 255.0

def predict(image):
    """Run the model on one image and return a score per label.

    Args:
        image: PIL image supplied by the input component, or ``None`` when
            the user has not provided one.

    Returns:
        A dict mapping every entry of ``cifar10_labels`` to the model output
        at the same index, converted to a plain ``float``.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        raise gr.Error("¡Por favor sube una imagen o toma una foto!")

    # The model expects a batch dimension, so wrap the single image.
    batch = np.expand_dims(preprocess_image(image), axis=0)
    # verbose=0 keeps Keras from printing a progress bar for every request.
    scores = model.predict(batch, verbose=0)[0]
    return {label: float(scores[i]) for i, label in enumerate(cifar10_labels)}

# Example images: one single-element row per category, each pointing at a
# JPEG inside the "ejemplos" directory.
examples = [
    [f"ejemplos/{stem}.jpg"]
    for stem in (
        "avion", "automovil", "pajaro", "gato", "venado",
        "perro", "rana", "caballo", "barco", "camion",
    )
]

# Build the interface
with gr.Blocks(theme=gr.themes.Soft(), css="""
    .examples-grid {display: flex !important; flex-direction: column; gap: 1rem}
    .examples-row {display: flex !important; gap: 1rem; justify-content: center}
    """) as app:

    gr.Markdown("# 📷 Clasificador CIFAR-10")

    with gr.Row():
        # Left column: image input plus the action buttons.
        with gr.Column():
            input_image = gr.Image(
                sources=["upload", "webcam", "clipboard"],
                type="pil",
                label="Entrada de imagen",
                height=250
            )
            with gr.Row():
                submit_btn = gr.Button("Predecir", variant="primary")
                clear_btn = gr.Button("Limpiar")

        # Right column: the three highest-scoring classes.
        with gr.Column():
            output_label = gr.Label(
                label="Resultados",
                num_top_classes=3,
                show_label=True
            )

    # Examples section: one gr.Examples widget per category so that each
    # example carries its own caption, laid out as two rows (the first five
    # categories, then the rest).
    gr.Markdown("## Ejemplos de categorías")
    with gr.Column(elem_classes=["examples-grid"]):
        for row_slice in (slice(None, 5), slice(5, None)):
            with gr.Row(elem_classes=["examples-row"]):
                for example, label in zip(examples[row_slice],
                                          cifar10_labels[row_slice]):
                    gr.Examples(
                        examples=example,
                        inputs=[input_image],
                        label=label.capitalize(),
                        examples_per_page=1,
                        fn=predict,
                        outputs=[output_label],
                    )

    # Wire up events
    submit_btn.click(
        fn=predict,
        inputs=input_image,
        outputs=output_label,
        api_name="predict"
    )

    # Reset both the input image and the results panel.
    clear_btn.click(
        fn=lambda: [None, None],
        inputs=None,
        outputs=[input_image, output_label]
    )

# Launch the Gradio app only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    app.launch()