Ariel013 commited on
Commit
8713743
·
verified ·
1 Parent(s): 00c4254

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -17
app.py CHANGED
@@ -8,13 +8,11 @@ import os
8
  os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
9
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
10
 
11
- # Configuración inicial: solo 5 clases
12
- cifar10_labels = [
13
- 'avión', 'automóvil', 'pájaro', 'gato', 'venado'
14
- ]
15
 
16
- # Cargar el modelo (asegúrate de que esté entrenado para 5 clases)
17
- model = tf.keras.models.load_model('modelo_completo')
18
 
19
  def preprocess_image(image):
20
  """Preprocesado de imagen para el modelo"""
@@ -22,17 +20,17 @@ def preprocess_image(image):
22
  return np.array(img).astype('float32') / 255
23
 
24
  def predict(image):
25
- """Realizar predicción y formatear resultados para 5 clases"""
26
  if image is None:
27
  raise gr.Error("¡Por favor sube una imagen o toma una foto!")
28
 
29
  processed_img = preprocess_image(image)
30
  preds = model.predict(np.expand_dims(processed_img, axis=0))[0]
31
- # Si el modelo devuelve más de 5 predicciones, se limitan a las primeras 5
32
- preds = preds[:5]
33
- return {label: float(pred) for label, pred in zip(cifar10_labels, preds)}
34
 
35
- # Configurar ejemplos: solo para las 5 clases
 
 
 
36
  examples = [
37
  ["ejemplos/avion.jpg"],
38
  ["ejemplos/automovil.jpg"],
@@ -41,14 +39,14 @@ examples = [
41
  ["ejemplos/venado.jpg"]
42
  ]
43
 
44
- # Construir la interfaz con Gradio
45
  with gr.Blocks(theme=gr.themes.Soft(), css="""
46
  .examples-grid {display: flex !important; flex-direction: column; gap: 1rem}
47
  .examples-row {display: flex !important; gap: 1rem; justify-content: center}
48
  """) as app:
49
 
50
- gr.Markdown("# 📷 Clasificador CIFAR-10 by Aryy :3")
51
-
52
  with gr.Row():
53
  with gr.Column():
54
  input_image = gr.Image(
@@ -67,8 +65,8 @@ with gr.Blocks(theme=gr.themes.Soft(), css="""
67
  num_top_classes=3,
68
  show_label=True
69
  )
70
-
71
- # Sección de ejemplos con interacción (solo 1 fila para 5 clases)
72
  gr.Markdown("## Ejemplos de categorías")
73
  with gr.Column(elem_classes=["examples-grid"]):
74
  with gr.Row(elem_classes=["examples-row"]):
@@ -81,7 +79,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css="""
81
  fn=predict,
82
  outputs=[output_label],
83
  )
84
-
85
  # Conectar eventos
86
  submit_btn.click(
87
  fn=predict,
 
8
  os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
9
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
10
 
11
+ # Etiquetas de las clases (Solo 5 categorías)
12
+ cifar10_labels = ['avión', 'automóvil', 'pájaro', 'gato', 'venado']
 
 
13
 
14
+ # Cargar el modelo correctamente desde un archivo .h5
15
+ model = tf.keras.models.load_model('modelo_completo.h5')
16
 
17
  def preprocess_image(image):
18
  """Preprocesado de imagen para el modelo"""
 
20
  return np.array(img).astype('float32') / 255
21
 
22
  def predict(image):
23
+ """Realizar predicción y formatear resultados"""
24
  if image is None:
25
  raise gr.Error("¡Por favor sube una imagen o toma una foto!")
26
 
27
  processed_img = preprocess_image(image)
28
  preds = model.predict(np.expand_dims(processed_img, axis=0))[0]
 
 
 
29
 
30
+ # Solo devolver predicciones para las 5 clases que hemos seleccionado
31
+ return {label: float(preds[i]) for i, label in enumerate(cifar10_labels)}
32
+
33
+ # Configurar ejemplos (Solo para 5 clases)
34
  examples = [
35
  ["ejemplos/avion.jpg"],
36
  ["ejemplos/automovil.jpg"],
 
39
  ["ejemplos/venado.jpg"]
40
  ]
41
 
42
+ # Construir interfaz
43
  with gr.Blocks(theme=gr.themes.Soft(), css="""
44
  .examples-grid {display: flex !important; flex-direction: column; gap: 1rem}
45
  .examples-row {display: flex !important; gap: 1rem; justify-content: center}
46
  """) as app:
47
 
48
+ gr.Markdown("# 📷 Clasificador CIFAR-10 (5 Clases)")
49
+
50
  with gr.Row():
51
  with gr.Column():
52
  input_image = gr.Image(
 
65
  num_top_classes=3,
66
  show_label=True
67
  )
68
+
69
+ # Sección de ejemplos con interacción
70
  gr.Markdown("## Ejemplos de categorías")
71
  with gr.Column(elem_classes=["examples-grid"]):
72
  with gr.Row(elem_classes=["examples-row"]):
 
79
  fn=predict,
80
  outputs=[output_label],
81
  )
82
+
83
  # Conectar eventos
84
  submit_btn.click(
85
  fn=predict,