File size: 3,472 Bytes
98d33da
 
 
 
 
 
 
 
 
 
 
 
 
ae739be
98d33da
 
ae739be
98d33da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6fa1600
 
98d33da
 
 
 
 
 
 
 
 
65cb2b9
98d33da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os

# Reduce TensorFlow start-up noise: disable oneDNN custom ops (and their notice)
# and hide C++ INFO/WARNING log lines. TensorFlow reads these variables when it
# is first imported, so they must be set BEFORE `import tensorflow` below —
# assigning them afterwards has no effect.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import gradio as gr  # noqa: E402  (imported after the env setup on purpose)
from PIL import Image  # noqa: E402
import numpy as np  # noqa: E402
import tensorflow as tf  # noqa: E402

# Initial configuration.
# Class names shown to the user; position i is paired with model output i in
# predict(). NOTE(review): assumes this order matches how my_model.keras was
# trained — confirm.
cifar10_labels = [
    'avión', 'automóvil', 'pájaro', 'gato', 'venado',
    'perro', 'rana'
]

# Load the trained Keras model once at import time; every request reuses it.
model = tf.keras.models.load_model('my_model.keras')

def preprocess_image(image, size=(32, 32)):
    """Convert a PIL image into a normalized float array for the model.

    The image is converted to RGB *before* resizing. Pillow always uses
    nearest-neighbour resampling for mode "1" and "P" images, so resizing
    first would degrade bilevel/palette uploads. Converting first also lets
    the default resampling filter work on true RGB data.

    Args:
        image: A PIL ``Image`` (the Gradio input is declared with ``type="pil"``).
        size: Target ``(width, height)`` passed to ``Image.resize``; defaults
            to 32x32.

    Returns:
        A float32 array of shape ``(height, width, 3)`` with values in [0, 1].
    """
    rgb = image.convert('RGB').resize(size)
    return np.asarray(rgb, dtype=np.float32) / 255

def predict(image):
    """Run the model on one image and return per-class scores.

    Args:
        image: PIL image from the Gradio input, or ``None`` when the user
            submitted nothing.

    Returns:
        dict mapping each label in ``cifar10_labels`` to the model's score for
        that class, as a Python float (the mapping shape ``gr.Label`` displays).

    Raises:
        gr.Error: if no image was supplied (shown to the user in the UI).
    """
    if image is None:
        raise gr.Error("¡Por favor sube una imagen o toma una foto!")

    # The model takes a batch, so add a leading axis of size 1.
    batch = np.expand_dims(preprocess_image(image), axis=0)
    # verbose=0 stops Keras from printing a progress bar on every request.
    scores = model.predict(batch, verbose=0)[0]
    return {label: float(scores[i]) for i, label in enumerate(cifar10_labels)}

# Example images, one per class, listed in the same order as cifar10_labels
# (the UI pairs them positionally). Each entry is a one-element list holding
# the value for the single image input.
examples = [
    [f"ejemplos/{stem}.jpg"]
    for stem in ("avion", "automovil", "pajaro", "gato", "venado", "perro", "rana")
]

# Build the interface.
with gr.Blocks(theme=gr.themes.Soft(), css="""
    .examples-grid {display: flex !important; flex-direction: column; gap: 1rem}
    .examples-row {display: flex !important; gap: 1rem; justify-content: center}
    """) as app:

    gr.Markdown(f"# CIFAR-10 con {len(cifar10_labels)} clases")

    with gr.Row():
        # Left column: image input plus the action buttons.
        with gr.Column():
            input_image = gr.Image(
                sources=["upload", "webcam", "clipboard"],
                type="pil",
                label="Entrada de imagen",
                height=250
            )
            with gr.Row():
                submit_btn = gr.Button("Predecir", variant="primary")
                clear_btn = gr.Button("Limpiar")

        # Right column: the three highest-scoring classes.
        with gr.Column():
            output_label = gr.Label(
                label="Resultados",
                num_top_classes=3,
                show_label=True
            )

    # Example section: one single-example gr.Examples widget per class, laid out
    # in two rows (the first five classes, then the remaining ones).
    gr.Markdown("## Ejemplos de categorías")
    with gr.Column(elem_classes=["examples-grid"]):
        for row in (slice(0, 5), slice(5, None)):
            with gr.Row(elem_classes=["examples-row"]):
                for example, label in zip(examples[row], cifar10_labels[row]):
                    # NOTE(review): Gradio only invokes fn/outputs for examples
                    # when they are cached or run_on_click=True; neither is set
                    # here, so clicking presumably just fills input_image.
                    # Confirm this is intended.
                    gr.Examples(
                        examples=example,
                        inputs=[input_image],
                        label=label.capitalize(),
                        examples_per_page=1,
                        fn=predict,
                        outputs=[output_label],
                    )

    # Wire up events: "Predecir" runs the model (also exposed as the
    # "predict" API endpoint).
    submit_btn.click(
        fn=predict,
        inputs=input_image,
        outputs=output_label,
        api_name="predict"
    )

    # "Limpiar" resets both the input image and the results panel.
    clear_btn.click(
        fn=lambda: [None, None],
        inputs=None,
        outputs=[input_image, output_label]
    )

def _main():
    """Start the Gradio server for ``app`` with default launch settings."""
    app.launch()


if __name__ == "__main__":
    _main()