# AETroyaB's picture
# Update app.py
# b510143 verified
import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
import matplotlib.pyplot as plt
from tensorflow.keras.metrics import Metric
import tensorflow as tf
from tensorflow.keras.saving import register_keras_serializable
# Run on CPU only: hide every GPU from TensorFlow. TensorFlow only accepts this
# before its devices are initialized, which is why it sits right after the imports.
tf.config.set_visible_devices([], device_type='GPU')
import tensorflow as tf
import gradio as gr
# Register the custom metric so Keras can deserialize it when the model is loaded.
@tf.keras.utils.register_keras_serializable()
class F1Score(tf.keras.metrics.Metric):
    """Streaming F1 score computed from Keras ``Precision`` and ``Recall`` metrics.

    F1 is ``2 * P * R / (P + R)``. ``epsilon`` is added to the denominator so the
    result is 0 rather than NaN when precision and recall are both 0.
    """

    def __init__(self, name='f1_score', **kwargs):
        super().__init__(name=name, **kwargs)
        self.precision = tf.keras.metrics.Precision()
        self.recall = tf.keras.metrics.Recall()

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate precision and recall statistics for one batch."""
        self.precision.update_state(y_true, y_pred, sample_weight)
        self.recall.update_state(y_true, y_pred, sample_weight)

    def result(self):
        """Return the F1 score over everything accumulated since the last reset."""
        precision = self.precision.result()
        recall = self.recall.result()
        return 2 * ((precision * recall) / (precision + recall + tf.keras.backend.epsilon()))

    def reset_state(self):
        """Reset the wrapped precision and recall accumulators.

        ``reset_state`` is the name current Keras versions call on metrics.
        ``reset_states`` is deprecated (and absent in Keras 3), so overriding only
        that name would not reliably reset this metric.
        """
        self.precision.reset_state()
        self.recall.reset_state()

    def reset_states(self):
        """Backward-compatible alias for :meth:`reset_state`."""
        self.reset_state()
# Load the trained Keras model from disk. F1Score is passed explicitly through
# custom_objects so the custom metric can be resolved if the saved model refers to it.
model = tf.keras.models.load_model(
    "my_model.keras",
    custom_objects=dict(F1Score=F1Score),
)
# Spanish display labels, indexed by the model's output position
# (the prediction code looks them up via argmax over the model output).
# NOTE(review): CIFAR-10 has 10 classes but only the first 5 are listed here;
# presumably the model was trained on this 5-class subset. Confirm that the
# model's output layer has exactly len(class_names_es) units.
class_names_es = [
    "avión",  # airplane
    "automóvil",  # automobile
    "pájaro",  # bird
    "gato",  # cat
    "ciervo",  # deer
]
# Función para preprocesar la imagen
from PIL import Image
def preprocess_image(image, size=(32, 32)):
    """Convert an input image into a model-ready batch containing one image.

    Args:
        image: a NumPy array (e.g. the value produced by ``gr.Image``) or a PIL
            image.
        size: ``(width, height)`` to resize to. It defaults to 32x32, the
            resolution the original code used.

    Returns:
        A float32 array of shape ``(1, height, width, 3)`` with values in [0, 1].

    Raises:
        ValueError: if ``image`` is None (e.g. "Predecir" clicked with no image).
    """
    if image is None:
        raise ValueError("No se ha proporcionado ninguna imagen")
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)  # NumPy -> PIL so it can be resized
    # Force 3 channels. Without this, RGBA (e.g. pasted PNGs), grayscale or palette
    # images would yield 4 or 1 channels instead of the (H, W, 3) layout.
    image = image.convert("RGB").resize(size)
    # convert("RGB") guarantees 8-bit channels, so /255 maps values to [0, 1].
    pixels = np.asarray(image, dtype=np.float32) / 255.0
    return np.expand_dims(pixels, axis=0)  # add the batch dimension
# Prediction: run the model and build the three outputs shown in the Gradio UI.
def _save_prediction_plot(labels, probabilities, path):
    """Save a bar chart of per-class probabilities to *path* and return *path*."""
    # Use the object-oriented Figure API instead of pyplot. It keeps no global
    # figure registry, so nothing leaks if drawing or saving raises midway.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(10, 5))
    ax = fig.subplots()
    ax.bar(labels, probabilities, color='skyblue')
    ax.set_title("Predicciones por clase")
    ax.set_xlabel("Clases")
    ax.set_ylabel("Probabilidad")
    ax.set_ylim(0, 1)
    ax.tick_params(axis="x", labelrotation=45)
    fig.tight_layout()
    fig.savefig(path)
    return path


def _format_probabilities(labels, probabilities):
    """Return one ``"<label>: <probability as percent>"`` line per class."""
    return "\n".join(
        f"{label}: {prob:.2%}" for label, prob in zip(labels, probabilities)
    )


def predict_image(image):
    """Classify *image* and build the outputs displayed in the Gradio UI.

    Args:
        image: the value from ``gr.Image`` (NumPy array or PIL image; ``None``
            when no image was provided).

    Returns:
        A 3-tuple ``(predicted_class, plot_path, probabilities_text)``. If
        preprocessing or inference fails, or the model's output size does not
        match the number of labels, it returns
        ``("Error en la predicción", None, <error message>)`` instead of raising,
        so the message is shown in the UI.
    """
    try:
        processed_image = preprocess_image(image)
        predictions = model.predict(processed_image, verbose=0)[0]
    except Exception as e:
        return "Error en la predicción", None, str(e)

    # Without this check, zip() would silently drop classes and the argmax lookup
    # below could raise IndexError when the model outputs more classes than labels.
    if len(predictions) != len(class_names_es):
        return (
            "Error en la predicción",
            None,
            f"El modelo devolvió {len(predictions)} probabilidades, "
            f"pero hay {len(class_names_es)} etiquetas",
        )

    # NOTE(review): fixed file name; concurrent requests would share it if the
    # Gradio event's concurrency limit is raised above 1.
    plot_path = _save_prediction_plot(class_names_es, predictions, "predicciones.png")
    class_probabilities_str = _format_probabilities(class_names_es, predictions)
    predicted_class = class_names_es[int(np.argmax(predictions))]
    return predicted_class, plot_path, class_probabilities_str
# Gradio interface: image input and predict button on the left, outputs on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Predicción de imágenes con CIFAR-10")
    gr.Markdown("Sube una imagen, pégala desde el portapapeles o usa tu cámara para predecir la clase.")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Subir imagen", sources=["upload", "clipboard", "webcam"])
            predict_button = gr.Button("Predecir")
        with gr.Column():
            class_output = gr.Textbox(label="Clase predicha")
            plot_output = gr.Image(label="Gráfico de predicciones")
            probabilities_output = gr.Textbox(label="Probabilidades por clase", lines=10)
    # Wire the button to predict_image. The outputs list maps 1:1, in order, onto
    # the (predicted_class, plot_path, probabilities_text) tuple it returns.
    # The listener is registered inside the Blocks context, as Gradio requires.
    predict_button.click(
        fn=predict_image,
        inputs=image_input,
        outputs=[class_output, plot_output, probabilities_output],
    )
# Launch the web interface only when this file is run as a script
# (`python app.py`), so importing the module does not start a blocking server.
if __name__ == "__main__":
    demo.launch()