Spaces:
Build error
Build error
| import gradio as gr | |
| import numpy as np | |
| import tensorflow as tf | |
| from tensorflow.keras.models import load_model | |
| import matplotlib.pyplot as plt | |
| from tensorflow.keras.metrics import Metric | |
| import tensorflow as tf | |
| from tensorflow.keras.saving import register_keras_serializable | |
| # Cargar el modelo entrenado en formato .keras | |
# Force CPU execution: give TensorFlow an empty list of visible GPU devices.
tf.config.set_visible_devices(devices=[], device_type='GPU')
| import tensorflow as tf | |
| import gradio as gr | |
# Register the custom metric with Keras so saved models can resolve it by name.
@register_keras_serializable(package="Custom")
class F1Score(tf.keras.metrics.Metric):
    """F1 score metric built from a Keras ``Precision`` and ``Recall`` pair.

    Both inner metrics are created with their default configuration and are
    fed the same ``y_true`` / ``y_pred`` / ``sample_weight`` on every update.
    """

    def __init__(self, name='f1_score', **kwargs):
        """Create the metric.

        Args:
            name: Metric name passed to the Keras ``Metric`` base class.
            **kwargs: Extra keyword arguments forwarded to the base class.
        """
        super().__init__(name=name, **kwargs)
        self.precision = tf.keras.metrics.Precision()
        self.recall = tf.keras.metrics.Recall()

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate one batch into the inner precision and recall metrics."""
        self.precision.update_state(y_true, y_pred, sample_weight)
        self.recall.update_state(y_true, y_pred, sample_weight)

    def result(self):
        """Return the harmonic mean of the accumulated precision and recall."""
        precision = self.precision.result()
        recall = self.recall.result()
        # epsilon keeps the division defined when precision + recall is zero.
        return 2 * ((precision * recall) / (precision + recall + tf.keras.backend.epsilon()))

    def reset_state(self):
        """Clear the accumulated precision and recall statistics."""
        self.precision.reset_state()
        self.recall.reset_state()

    def reset_states(self):
        """Deprecated spelling kept for callers that still use the old name."""
        self.reset_state()
# Load the trained model; the serialized "F1Score" name is mapped to the
# custom metric class through custom_objects.
model = load_model(
    "my_model.keras",
    custom_objects=dict(F1Score=F1Score),
)
# Spanish labels for the classes, in the order of the model's output vector.
# NOTE(review): the original comment says CIFAR-10, but only five labels are
# listed here - confirm that the model really has five outputs.
class_names_es = [
    "avión",      # airplane
    "automóvil",  # automobile
    "pájaro",     # bird
    "gato",       # cat
    "ciervo",     # deer
]
| # Función para preprocesar la imagen | |
| from PIL import Image | |
def preprocess_image(image):
    """Turn an input image into a normalized batch of one for the model.

    Args:
        image: A PIL image, or a NumPy array that ``Image.fromarray`` accepts.

    Returns:
        A NumPy array holding the 32x32 RGB image scaled by 1/255, with a
        leading batch axis of length 1.

    Raises:
        ValueError: If ``image`` is ``None`` (no image was supplied).
    """
    if image is None:
        # Fail with a readable message instead of an AttributeError on None.
        raise ValueError("No se ha proporcionado ninguna imagen.")
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)  # wrap NumPy input as a PIL image
    # Force three channels so RGBA, greyscale or palette input still yields
    # an (H, W, 3) array, then shrink to 32x32.
    image = image.convert("RGB").resize((32, 32))
    pixels = np.array(image) / 255.0  # scale 0..255 pixel values to 0..1
    return np.expand_dims(pixels, axis=0)  # add the batch dimension
def _save_probability_plot(labels, probabilities):
    """Draw a bar chart of per-class probabilities and return the PNG path."""
    import tempfile

    # Work on an explicit figure/axes pair rather than pyplot's implicit
    # "current figure", so overlapping requests cannot draw on each other.
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.bar(labels, probabilities, color='skyblue')
        ax.set_title("Predicciones por clase")
        ax.set_xlabel("Clases")
        ax.set_ylabel("Probabilidad")
        ax.set_ylim(0, 1)
        ax.tick_params(axis='x', labelrotation=45)
        fig.tight_layout()
        # A unique file per call: a fixed name would be overwritten by
        # concurrent requests. The file is kept after closing (delete=False)
        # because the caller hands the path on after this function returns.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            plot_path = tmp.name
        fig.savefig(plot_path)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)
    return plot_path


# Prediction entry point used by the UI.
def predict_image(image):
    """Classify an image and build the outputs shown in the interface.

    Args:
        image: The image to classify, in any form ``preprocess_image`` accepts.

    Returns:
        A 3-tuple ``(predicted_class, plot_path, probabilities_text)``:
        the label with the highest score, the path of a PNG bar chart, and
        one ``"label: percent"`` line per class. On failure the tuple is
        ``("Error en la predicción", None, message)``.
    """
    error_label = "Error en la predicción"
    try:
        processed_image = preprocess_image(image)
        predictions = model.predict(processed_image, verbose=0)[0]
    except Exception as e:
        # UI boundary: report the failure in the output fields.
        return error_label, None, str(e)

    # The labels and the model output must line up one-to-one; otherwise the
    # chart call fails and zip() would silently drop entries.
    if len(predictions) != len(class_names_es):
        return (
            error_label,
            None,
            f"El modelo devolvió {len(predictions)} probabilidades, "
            f"pero hay {len(class_names_es)} etiquetas de clase.",
        )

    plot_path = _save_probability_plot(class_names_es, predictions)

    # One "label: percent" line per class.
    class_probabilities_str = "\n".join(
        f"{class_name}: {prob:.2%}"
        for class_name, prob in zip(class_names_es, predictions)
    )

    predicted_class = class_names_es[int(np.argmax(predictions))]
    return predicted_class, plot_path, class_probabilities_str
# Gradio interface: image input and button on the left, results on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Predicción de imágenes con CIFAR-10")
    gr.Markdown("Sube una imagen, pégala desde el portapapeles o usa tu cámara para predecir la clase.")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Subir imagen", sources=["upload", "clipboard", "webcam"])
            predict_button = gr.Button("Predecir")
        with gr.Column():
            class_output = gr.Textbox(label="Clase predicha")
            plot_output = gr.Image(label="Gráfico de predicciones")
            probabilities_output = gr.Textbox(label="Probabilidades por clase", lines=10)
    # Wire the button to the prediction function; the three return values
    # fill the three output components in order.
    predict_button.click(
        fn=predict_image,
        inputs=image_input,
        outputs=[class_output, plot_output, probabilities_output],
    )

# Start the interface.
demo.launch()