import gradio as gr import numpy as np import tensorflow as tf from tensorflow.keras.models import load_model import matplotlib.pyplot as plt from tensorflow.keras.metrics import Metric import tensorflow as tf from tensorflow.keras.saving import register_keras_serializable # Cargar el modelo entrenado en formato .keras # Forzar el uso de CPU tf.config.set_visible_devices([], 'GPU') import tensorflow as tf import gradio as gr # Registrar la métrica personalizada @tf.keras.utils.register_keras_serializable() class F1Score(tf.keras.metrics.Metric): def __init__(self, name='f1_score', **kwargs): super(F1Score, self).__init__(name=name, **kwargs) self.precision = tf.keras.metrics.Precision() self.recall = tf.keras.metrics.Recall() def update_state(self, y_true, y_pred, sample_weight=None): self.precision.update_state(y_true, y_pred, sample_weight) self.recall.update_state(y_true, y_pred, sample_weight) def result(self): precision = self.precision.result() recall = self.recall.result() return 2 * ((precision * recall) / (precision + recall + tf.keras.backend.epsilon())) def reset_states(self): self.precision.reset_states() self.recall.reset_states() # Cargar el modelo con la métrica personalizada model = tf.keras.models.load_model("my_model.keras", custom_objects={"F1Score": F1Score}) # Etiquetas en español para las clases de CIFAR-10 class_names_es = [ "avión", # airplane "automóvil", # automobile "pájaro", # bird "gato", # cat "ciervo", # deer ] # Función para preprocesar la imagen from PIL import Image def preprocess_image(image): if isinstance(image, np.ndarray): image = Image.fromarray(image) # Convertir de numpy a PIL si es necesario image = image.resize((32, 32)) # Redimensionar a 32x32 image = np.array(image) / 255.0 # Normalizar image = np.expand_dims(image, axis=0) # Añadir dimensión del lote return image # Función para hacer la predicción def predict_image(image): try: processed_image = preprocess_image(image) predictions = model.predict(processed_image, verbose=0)[0] except Exception as e: return "Error en la predicción", None, str(e) # Crear un gráfico de barras plt.figure(figsize=(10, 5)) plt.bar(class_names_es, predictions, color='skyblue') plt.title("Predicciones por clase") plt.xlabel("Clases") plt.ylabel("Probabilidad") plt.ylim(0, 1) plt.xticks(rotation=45) plt.tight_layout() # Guardar el gráfico en un archivo temporal plot_path = "predicciones.png" plt.savefig(plot_path) plt.close() # Crear una lista de clases con sus probabilidades class_probabilities = [ f"{class_name}: {prob:.2%}" for class_name, prob in zip(class_names_es, predictions) ] class_probabilities_str = "\n".join(class_probabilities) # Devolver la clase predicha, el gráfico y la lista de probabilidades predicted_class = class_names_es[np.argmax(predictions)] return predicted_class, plot_path, class_probabilities_str # Interfaz de Gradio with gr.Blocks() as demo: gr.Markdown("# Predicción de imágenes con CIFAR-10") gr.Markdown("Sube una imagen, pégala desde el portapapeles o usa tu cámara para predecir la clase.") with gr.Row(): with gr.Column(): image_input = gr.Image(label="Subir imagen", sources=["upload", "clipboard", "webcam"]) predict_button = gr.Button("Predecir") with gr.Column(): class_output = gr.Textbox(label="Clase predicha") plot_output = gr.Image(label="Gráfico de predicciones") probabilities_output = gr.Textbox(label="Probabilidades por clase", lines=10) # Conectar el botón a la función de predicción predict_button.click( fn=predict_image, inputs=image_input, outputs=[class_output, plot_output, probabilities_output] ) # Lanzar la interfaz demo.launch()