# app.py (copied from the Hugging Face file view)
# Author: Hem345 · commit 8692d12 "Update app.py" (verified) · 2.64 kB
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers
from keras.datasets import mnist
def display_reconstruction(index, autoencoder, encoder, x_test):
    """Show a test image, its latent code and its reconstruction side by side.

    Renders a three-panel matplotlib figure into the Streamlit page.

    Args:
        index: Row of ``x_test`` to visualise.
        autoencoder: Keras model mapping a batch of flattened images to
            reconstructions of the same shape.
        encoder: Keras model mapping the same input to its latent vector.
        x_test: Array of flattened images; each row must reshape to 28x28
            (the main app passes MNIST scaled to [0, 1], shape (n, 784)).
    """
    original = x_test[index]
    # Keras models predict on batches, so wrap the single image in a batch of 1.
    batch = np.expand_dims(original, 0)
    # verbose=0 suppresses the per-call progress bar in the server log.
    latent_repr = encoder.predict(batch, verbose=0)[0]
    reconstructed = autoencoder.predict(batch, verbose=0)[0]

    fig, (ax_orig, ax_latent, ax_recon) = plt.subplots(1, 3, figsize=(12, 4))

    ax_orig.imshow(np.reshape(original, (28, 28)), cmap='gray')
    ax_orig.set_title('Original Image')
    ax_orig.axis('off')

    # Keep this panel's axes: without ticks the bar heights are unreadable.
    ax_latent.bar(range(len(latent_repr)), latent_repr)
    ax_latent.set_title('Latent Representation')
    ax_latent.set_xlabel('Latent dimension')
    ax_latent.set_ylabel('Activation')

    ax_recon.imshow(np.reshape(reconstructed, (28, 28)), cmap='gray')
    ax_recon.set_title('Reconstructed Image')
    ax_recon.axis('off')

    st.pyplot(fig)
    # pyplot keeps every figure alive until it is closed; on a long-running
    # Streamlit server each call would otherwise leak one figure.
    plt.close(fig)
# Main Streamlit app
#
# Streamlit re-runs this whole script from the top on every widget
# interaction, and `st.button` returns True only on the run triggered by its
# own click. Widgets nested under `if st.button("Train Autoencoder"):` would
# therefore vanish as soon as the user touched them, and "Display
# Reconstruction" could never show anything. The trained models and test set
# are kept in `st.session_state`, and the viewer lives outside the training
# branch.


def _load_mnist():
    """Load MNIST, scale pixels to [0, 1] and flatten each image to 784 values.

    Returns:
        ``(x_train, x_test)`` as float32 arrays of shape ``(n, 784)``. Labels
        are discarded because the autoencoder reconstructs its own input.
    """
    (x_train, _), (x_test, _) = mnist.load_data()
    x_train = x_train.astype('float32') / 255.0
    x_test = x_test.astype('float32') / 255.0
    # Flatten to (None, 784) to match the Dense input layer.
    return np.reshape(x_train, (-1, 784)), np.reshape(x_test, (-1, 784))


def _build_autoencoder():
    """Build and compile the dense 784-128-64-32-64-128-784 autoencoder.

    Returns:
        ``(autoencoder, encoder)``. ``encoder`` shares the first three Dense
        layers with ``autoencoder`` and outputs the 32-value latent vector.
    """
    input_img = keras.Input(shape=(784,))
    encoded = layers.Dense(128, activation='relu')(input_img)
    encoded = layers.Dense(64, activation='relu')(encoded)
    latent_vector = layers.Dense(32, activation='relu')(encoded)
    decoded = layers.Dense(64, activation='relu')(latent_vector)
    decoded = layers.Dense(128, activation='relu')(decoded)
    decoded = layers.Dense(784, activation='sigmoid')(decoded)
    autoencoder = keras.Model(input_img, decoded)
    # Sigmoid output + binary cross-entropy fits pixels scaled to [0, 1].
    autoencoder.compile(optimizer='adam', loss='binary_crossentropy')
    encoder = keras.Model(input_img, latent_vector)
    return autoencoder, encoder


st.title("Autoencoder Training and Visualization")

# Button to trigger training
if st.button("Train Autoencoder"):
    with st.spinner("Training in progress..."):
        x_train, x_test = _load_mnist()
        autoencoder, encoder = _build_autoencoder()
        autoencoder.fit(x_train, x_train, epochs=5, batch_size=128,
                        validation_data=(x_test, x_test))
    # Persist across reruns so the viewer below keeps working after this run.
    st.session_state["autoencoder"] = autoencoder
    st.session_state["encoder"] = encoder
    st.session_state["x_test"] = x_test

# The viewer is available once a model has been trained in this session.
if "autoencoder" in st.session_state:
    x_test = st.session_state["x_test"]
    max_index = len(x_test) - 1
    test_index = st.number_input(
        f"Enter an index (0-{max_index}) to view an image from the test set:",
        min_value=0,
        max_value=max_index,
    )
    # Button to display the reconstruction
    if st.button("Display Reconstruction"):
        display_reconstruction(
            int(test_index),
            st.session_state["autoencoder"],
            st.session_state["encoder"],
            x_test,
        )