import streamlit as st import numpy as np from tensorflow.keras.models import load_model from PIL import Image # Cargar los modelos model_cnn = load_model('MNIST_conv.h5') # Modelo convolucional model_mlp = load_model('MNIST.h5') # Modelo no convolucional model_cnn_v2 = load_model('mnist_conv (1).h5') # Diccionario de etiquetas para Fashion-MNIST class_labels = { 0: "Cero", 1: "Uno", 2: "Dos", 3: "Tres", 4: "Cuatro", 5: "Cinco", 6: "Seis", 7: "Siete", 8: "Ocho", 9: "Nueve" } # Función para preparar la imagen def prepare_image(image): image = image.convert('L') # Convertir a escala de grises image = image.resize((28, 28)) # Redimensionar a 28x28 image_array = np.array(image) # Convertir a array numpy image_array = image_array / 255.0 # Normalizar a [0, 1] image_array = image_array.reshape(1, 28, 28, 1) # Cambiar la forma a (1, 28, 28, 1) return image_array # Título de la aplicación st.title("Predicción de Moda con Modelos") # Selección del modelo model_choice = st.selectbox("Selecciona el modelo:", ("Modelo Convolucional", "Modelo No Convolucional", "Modelo Convolucional (10 epocas mas (30))")) # Subida de imagen uploaded_file = st.file_uploader("Sube una imagen de 28x28 en escala de grises", type=["png", "jpg", "jpeg"]) # Botón para realizar la predicción if uploaded_file is not None: # Mostrar la imagen subida image = Image.open(uploaded_file) st.image(image, caption='Imagen subida', use_column_width=True) # Preparar la imagen para la predicción prepared_image = prepare_image(image) # Realizar la predicción según el modelo seleccionado if model_choice == "Modelo Convolucional": prediction = model_cnn.predict(prepared_image) elif model_choice == "Modelo No Convolucional": prediction = model_mlp.predict(prepared_image) else: prediction = model_cnn_v2.predict(prepared_image) predicted_class = np.argmax(prediction, axis=1)[0] # Obtener la clase con mayor probabilidad # Obtener la etiqueta correspondiente predicted_label = class_labels[predicted_class] # Mostrar el resultado st.success(f"La predicción es: {predicted_label}")