import gradio as gr from PIL import Image import numpy as np import tensorflow as tf import matplotlib.pyplot as plt # Definir los nombres de las clases de CIFAR-10 class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer'] # Cargar el modelo preentrenado (asegúrate de que esté disponible en el Space) model = tf.keras.models.load_model('final_model.keras') # Definir las medias y desviaciones estándar para la normalización mean = np.array([0.4914, 0.4822, 0.4465]) std = np.array([0.2470, 0.2435, 0.2616]) # Función para cargar y preprocesar la imagen def load_and_preprocess_image(image): img = Image.open(image) # Abrir la imagen img = img.resize((32, 32)) # Redimensionar la imagen a 32x32 píxeles img = np.array(img) # Convertir la imagen en un array de NumPy # Asegurarse de que la imagen tenga 3 canales (RGB) if img.ndim == 2: # Si es una imagen en escala de grises img = np.stack([img] * 3, axis=-1) img = img.astype('float32') # Asegurarse de que la imagen sea de tipo float32 img = (img / 255.0 - mean) / std # Normalizar la imagen # Añadir una dimensión adicional para representar el batch (1 imagen) img = np.expand_dims(img, axis=0) return img # Función para predecir una imagen cargada def predict_image(image): # Cargar y preprocesar la imagen img = load_and_preprocess_image(image) # Hacer la predicción prediction = model.predict(img) # Obtener las probabilidades de todas las clases probabilities = tf.nn.softmax(prediction[0]).numpy() # Crear un diccionario con las probabilidades de cada clase results = {class_names[i]: float(probabilities[i]) * 100 for i in range(len(class_names))} # Formatear el resultado como una cadena de texto result_str = "\n".join([f"{class_name}: {prob:.2f}%" for class_name, prob in results.items()]) return result_str # Crear la interfaz de Gradio interface = gr.Interface( fn=predict_image, # Función que realiza la predicción inputs=gr.Image(type="filepath", label="Sube una imagen"), # Entrada: imagen outputs=gr.Textbox(label="Probabilidades"), # Salida: texto con las probabilidades title="Clasificador de imágenes CIFAR-10", description="Sube una imagen de un avion, vehiculo, pajaro, gato, cirvo para ver las probabilidades de cada clase." ) # Lanzar la aplicación interface.launch()