Ariel013 commited on
Commit
4719d8c
verified
1 Parent(s): 8a46cb5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -14
app.py CHANGED
@@ -1,50 +1,85 @@
1
  import gradio as gr
 
2
  from PIL import Image
3
  import numpy as np
4
  import tensorflow as tf
 
5
 
6
  # Etiquetas en espa帽ol
7
- cifar10_labels = np.array([
8
  'avi贸n', 'autom贸vil', 'p谩jaro', 'gato', 'venado',
9
  'perro', 'rana', 'caballo', 'barco', 'cami贸n'
10
- ])
11
 
12
  # Cargar el modelo al iniciar la app
13
  model = tf.keras.models.load_model('my_model.keras')
14
 
15
  def preprocess_image(image):
16
  """Preprocesa la imagen para el modelo"""
17
- img = image.resize((32, 32)) # Redimensionar
18
- img = np.array(img) # Convertir a numpy array
19
- img = img.astype('float32') / 255 # Normalizar
20
- return img.reshape(1, 32, 32, 3) # Reformatear para el modelo
21
 
22
  def predict(image):
23
  """Realiza la predicci贸n y devuelve los resultados"""
24
  processed_img = preprocess_image(image)
25
  preds = model.predict(processed_img)[0]
26
- return {cifar10_labels[i]: float(preds[i]) for i in range(10)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  # Configuraci贸n de la interfaz
29
  title = "Clasificador CIFAR-10 鉁堬笍馃殫"
30
  description = "Sube una imagen para clasificarla en una de las 10 categor铆as del dataset CIFAR-10"
31
  examples = [
32
- 'ejemplo_avion.jpg',
33
- 'ejemplo_auto.jpg',
34
- 'ejemplo_pajaro.jpg',
35
- 'ejemplo_gato.jpg'
 
 
 
 
 
 
36
  ]
37
 
38
  # Crear la interfaz Gradio
39
  interface = gr.Interface(
40
  fn=predict,
41
  inputs=gr.Image(type="pil", label="Imagen de entrada"),
42
- outputs=gr.Label(num_top_classes=3, label="Predicciones"),
 
 
 
43
  title=title,
44
  description=description,
45
  examples=examples,
46
- theme=gr.themes.Soft()
 
47
  )
48
 
49
  # Lanzar la aplicaci贸n
50
- interface.launch()
 
 
1
  import gradio as gr
2
+ import matplotlib.pyplot as plt
3
  from PIL import Image
4
  import numpy as np
5
  import tensorflow as tf
6
+ import pandas as pd
7
 
8
  # Etiquetas en espa帽ol
9
+ cifar10_labels = [
10
  'avi贸n', 'autom贸vil', 'p谩jaro', 'gato', 'venado',
11
  'perro', 'rana', 'caballo', 'barco', 'cami贸n'
12
+ ]
13
 
14
  # Cargar el modelo al iniciar la app
15
  model = tf.keras.models.load_model('my_model.keras')
16
 
17
  def preprocess_image(image):
18
  """Preprocesa la imagen para el modelo"""
19
+ img = image.resize((32, 32)).convert('RGB') # Forzar formato RGB
20
+ img = np.array(img).astype('float32') / 255 # Normalizar
21
+ return img.reshape(1, 32, 32, 3)
 
22
 
23
  def predict(image):
24
  """Realiza la predicci贸n y devuelve los resultados"""
25
  processed_img = preprocess_image(image)
26
  preds = model.predict(processed_img)[0]
27
+
28
+ # Crear gr谩fico de barras profesional
29
+ df = pd.DataFrame({
30
+ 'Clase': cifar10_labels,
31
+ 'Probabilidad': preds
32
+ }).sort_values('Probabilidad', ascending=False)
33
+
34
+ fig, ax = plt.subplots(figsize=(8, 5))
35
+ bars = ax.barh(df['Clase'], df['Probabilidad'], color='skyblue')
36
+ ax.set_xlim(0, 1)
37
+ ax.set_title('Distribuci贸n de Probabilidades', pad=20)
38
+ ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{x:.0%}'))
39
+
40
+ # A帽adir etiquetas de porcentaje
41
+ for bar in bars:
42
+ width = bar.get_width()
43
+ ax.text(width + 0.03, bar.get_y() + bar.get_height()/2,
44
+ f'{width:.1%}',
45
+ ha='left', va='center')
46
+
47
+ plt.tight_layout()
48
+
49
+ # Devolver resultados en formato correcto
50
+ return {cifar10_labels[i]: float(preds[i]) for i in range(10)}, fig
51
 
52
  # Configuraci贸n de la interfaz
53
  title = "Clasificador CIFAR-10 鉁堬笍馃殫"
54
  description = "Sube una imagen para clasificarla en una de las 10 categor铆as del dataset CIFAR-10"
55
  examples = [
56
+ ['ejemplo_avion.jpg'], # avi贸n
57
+ ['ejemplo_auto.jpg'], # autom贸vil
58
+ ['ejemplo_pajaro.jpg'], # p谩jaro
59
+ ['ejemplo_gato.jpg'], # gato
60
+ ['ejemplo_venado.jpg'], # venado
61
+ ['ejemplo_perro.jpg'], # perro
62
+ ['ejemplo_rana.jpg'], # rana
63
+ ['ejemplo_caballo.jpg'], # caballo
64
+ ['ejemplo_barco.jpg'], # barco
65
+ ['ejemplo_camion.jpg'] # cami贸n
66
  ]
67
 
68
  # Crear la interfaz Gradio
69
  interface = gr.Interface(
70
  fn=predict,
71
  inputs=gr.Image(type="pil", label="Imagen de entrada"),
72
+ outputs=[
73
+ gr.Label(num_top_classes=3, label="Top 3 Predicciones"),
74
+ gr.Plot(label="Distribuci贸n Completa")
75
+ ],
76
  title=title,
77
  description=description,
78
  examples=examples,
79
+ theme=gr.themes.Soft(),
80
+ allow_flagging="never"
81
  )
82
 
83
  # Lanzar la aplicaci贸n
84
+ if __name__ == "__main__":
85
+ interface.launch()