Ariel013 commited on
Commit
a6e51fc
·
verified ·
1 Parent(s): 4719d8c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -62
app.py CHANGED
@@ -1,9 +1,7 @@
1
  import gradio as gr
2
- import matplotlib.pyplot as plt
3
  from PIL import Image
4
  import numpy as np
5
  import tensorflow as tf
6
- import pandas as pd
7
 
8
  # Etiquetas en español
9
  cifar10_labels = [
@@ -11,75 +9,61 @@ cifar10_labels = [
11
  'perro', 'rana', 'caballo', 'barco', 'camión'
12
  ]
13
 
14
- # Cargar el modelo al iniciar la app
15
  model = tf.keras.models.load_model('my_model.keras')
16
 
17
  def preprocess_image(image):
18
- """Preprocesa la imagen para el modelo"""
19
- img = image.resize((32, 32)).convert('RGB') # Forzar formato RGB
20
- img = np.array(img).astype('float32') / 255 # Normalizar
21
- return img.reshape(1, 32, 32, 3)
22
 
23
  def predict(image):
24
- """Realiza la predicción y devuelve los resultados"""
25
  processed_img = preprocess_image(image)
26
- preds = model.predict(processed_img)[0]
27
-
28
- # Crear gráfico de barras profesional
29
- df = pd.DataFrame({
30
- 'Clase': cifar10_labels,
31
- 'Probabilidad': preds
32
- }).sort_values('Probabilidad', ascending=False)
33
-
34
- fig, ax = plt.subplots(figsize=(8, 5))
35
- bars = ax.barh(df['Clase'], df['Probabilidad'], color='skyblue')
36
- ax.set_xlim(0, 1)
37
- ax.set_title('Distribución de Probabilidades', pad=20)
38
- ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{x:.0%}'))
39
-
40
- # Añadir etiquetas de porcentaje
41
- for bar in bars:
42
- width = bar.get_width()
43
- ax.text(width + 0.03, bar.get_y() + bar.get_height()/2,
44
- f'{width:.1%}',
45
- ha='left', va='center')
46
-
47
- plt.tight_layout()
48
-
49
- # Devolver resultados en formato correcto
50
- return {cifar10_labels[i]: float(preds[i]) for i in range(10)}, fig
51
 
52
- # Configuración de la interfaz
53
- title = "Clasificador CIFAR-10 ✈️🚗"
54
- description = "Sube una imagen para clasificarla en una de las 10 categorías del dataset CIFAR-10"
55
  examples = [
56
- ['ejemplo_avion.jpg'], # avión
57
- ['ejemplo_auto.jpg'], # automóvil
58
- ['ejemplo_pajaro.jpg'], # pájaro
59
- ['ejemplo_gato.jpg'], # gato
60
- ['ejemplo_venado.jpg'], # venado
61
- ['ejemplo_perro.jpg'], # perro
62
- ['ejemplo_rana.jpg'], # rana
63
- ['ejemplo_caballo.jpg'], # caballo
64
- ['ejemplo_barco.jpg'], # barco
65
- ['ejemplo_camion.jpg'] # camión
66
  ]
67
 
68
- # Crear la interfaz Gradio
69
- interface = gr.Interface(
70
- fn=predict,
71
- inputs=gr.Image(type="pil", label="Imagen de entrada"),
72
- outputs=[
73
- gr.Label(num_top_classes=3, label="Top 3 Predicciones"),
74
- gr.Plot(label="Distribución Completa")
75
- ],
76
- title=title,
77
- description=description,
78
- examples=examples,
79
- theme=gr.themes.Soft(),
80
- allow_flagging="never"
81
- )
 
 
 
 
 
 
 
 
82
 
83
- # Lanzar la aplicación
84
  if __name__ == "__main__":
85
- interface.launch()
 
1
  import gradio as gr
 
2
  from PIL import Image
3
  import numpy as np
4
  import tensorflow as tf
 
5
 
6
  # Etiquetas en español
7
  cifar10_labels = [
 
9
  'perro', 'rana', 'caballo', 'barco', 'camión'
10
  ]
11
 
12
+ # Cargar el modelo
13
  model = tf.keras.models.load_model('my_model.keras')
14
 
15
  def preprocess_image(image):
16
+ """Preprocesado de imagen"""
17
+ img = image.resize((32, 32)).convert('RGB')
18
+ return np.array(img).astype('float32') / 255
 
19
 
20
  def predict(image):
21
+ """Realizar predicción"""
22
  processed_img = preprocess_image(image)
23
+ preds = model.predict(np.expand_dims(processed_img, axis=0))[0]
24
+ return {label: float(preds[i]) for i, label in enumerate(cifar10_labels)}
25
+
26
+ # Configurar ejemplos con etiquetas
27
+ dataset_info = "**Este dataset incluye las siguientes 10 categorías:**\n" + "\n".join(
28
+ [f"- {label.capitalize()}" for label in cifar10_labels]
29
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
 
 
 
31
  examples = [
32
+ ["ejemplos/avion.jpg", "avión"],
33
+ ["ejemplos/automovil.jpg", "automóvil"],
34
+ ["ejemplos/pajaro.jpg", "pájaro"],
35
+ ["ejemplos/gato.jpg", "gato"],
36
+ ["ejemplos/venado.jpg", "venado"],
37
+ ["ejemplos/perro.jpg", "perro"],
38
+ ["ejemplos/rana.jpg", "rana"],
39
+ ["ejemplos/caballo.jpg", "caballo"],
40
+ ["ejemplos/barco.jpg", "barco"],
41
+ ["ejemplos/camion.jpg", "camión"]
42
  ]
43
 
44
+ # Construir interfaz
45
+ with gr.Blocks(theme=gr.themes.Soft()) as app:
46
+ gr.Markdown("# Clasificador CIFAR-10 ✈️🚗")
47
+ gr.Markdown("Sube una imagen o prueba con nuestros ejemplos:")
48
+
49
+ with gr.Row():
50
+ with gr.Column():
51
+ input_image = gr.Image(type="pil", label="Imagen de entrada")
52
+ submit_btn = gr.Button("Clasificar")
53
+
54
+ with gr.Column():
55
+ output_label = gr.Label(label="Predicciones", num_top_classes=10)
56
+
57
+ gr.Markdown(dataset_info)
58
+
59
+ # Sección de ejemplos con etiquetas
60
+ gr.Examples(
61
+ examples=examples,
62
+ inputs=[input_image],
63
+ label="Ejemplos del Dataset",
64
+ examples_per_page=5
65
+ )
66
 
67
+ # Lanzar aplicación
68
  if __name__ == "__main__":
69
+ app.launch()