# hellooooo / app.py
# lolzysiu's picture
# Rename streamlit_app.py to app.py
# 266d836 verified
import streamlit as st
import torch
from models_conv import ConvGenerator
import numpy as np
# Set page config (kept ahead of every other Streamlit call in the script)
st.set_page_config(page_title="MNIST Digit Generator", layout="centered")

# Checkpoint holding the trained generator weights
CHECKPOINT_PATH = 'checkpoints/wgan_checkpoint_epoch_190.pt'


@st.cache_resource
def load_generator(checkpoint_path, device_name):
    """Build the generator, load its trained weights and put it in eval mode.

    Streamlit re-executes this script on every widget interaction. Caching
    the loaded model means the checkpoint is read from disk once instead of
    on every rerun.

    Args:
        checkpoint_path: Path to a checkpoint file whose
            ``'generator_state_dict'`` entry holds the generator weights.
        device_name: Device string (``'cuda'`` or ``'cpu'``) the model is
            placed on. Passed as a plain string so it can be used as a
            cache key.

    Returns:
        The ``ConvGenerator`` instance, on the requested device, in
        evaluation mode.
    """
    target = torch.device(device_name)
    model = ConvGenerator().to(target)
    # NOTE: torch.load unpickles the file; only use trusted checkpoints here.
    checkpoint = torch.load(checkpoint_path, map_location=target)
    model.load_state_dict(checkpoint['generator_state_dict'])
    model.eval()
    return model


# Load the trained generator (GPU when available, CPU otherwise)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = load_generator(CHECKPOINT_PATH, device.type)
# Page heading and one-line description
st.title("MNIST Digit Generator")
st.write("Generate MNIST-like digits using a Wasserstein GAN")

# Sidebar controls: widgets are attached to the sidebar via object notation
st.sidebar.header("Generation Controls")
noise_seed = st.sidebar.slider("Noise Seed", min_value=1, max_value=1000, value=42)
num_images = st.sidebar.slider("Number of Images", min_value=1, max_value=16, value=4)
generate_button = st.sidebar.button("Generate New Images")
# Function to generate images
def generate_images(noise_seed, num_images, latent_dim=100):
    """Generate a batch of images from seeded latent noise.

    Args:
        noise_seed: Integer seed for the latent noise; the same seed yields
            the same noise vectors.
        num_images: Number of latent vectors (one image each) to sample.
        latent_dim: Length of each latent vector. Defaults to 100, the
            value this function previously hard-coded.

    Returns:
        NumPy array holding the generator output rescaled with
        ``(x + 1) / 2`` and clipped to ``[0, 1]``.
    """
    # Seed a private RNG rather than calling torch.manual_seed(): re-seeding
    # the global RNG would change random state for all other torch code
    # running in the same process.
    rng = torch.Generator().manual_seed(noise_seed)
    # Sample on the CPU, then move to the model's device.
    z = torch.randn(num_images, latent_dim, generator=rng).to(device)
    with torch.no_grad():
        imgs = generator(z)
    # Convert to a NumPy array for display
    imgs = imgs.cpu().numpy()
    # Rescale to [0, 1]; assumes the generator emits values in [-1, 1]
    # (e.g. a tanh output layer) -- TODO confirm against ConvGenerator.
    imgs = (imgs + 1) / 2
    # Clip so slightly out-of-range values cannot make st.image reject the
    # float data (it expects floats within [0, 1]).
    return np.clip(imgs, 0.0, 1.0)
# Main content
# Regenerate on an explicit button press or on the first run of a session;
# otherwise redisplay the images kept in session state.
if generate_button or 'generated_images' not in st.session_state:
    images = generate_images(noise_seed, num_images)
    st.session_state.generated_images = images
else:
    images = st.session_state.generated_images

# Display images in a grid (at most 4 columns).
# The column count comes from the images actually being shown, not from the
# slider: cached images may have been generated with a different
# "Number of Images" value, and sizing the grid from the slider would lay
# them out on a mismatched grid.
num_cols = max(1, min(4, len(images)))
cols = st.columns(num_cols)
for idx, img in enumerate(images):
    # Fill the grid row by row: image idx goes to column idx mod num_cols.
    with cols[idx % num_cols]:
        st.image(img.squeeze(), caption=f"Generated Image {idx+1}", use_column_width=True)
# Footer: horizontal rule, section heading, then a short model description
about_text = """
This is a Wasserstein GAN (WGAN) model trained on the MNIST dataset.
The model generates 28x28 grayscale images that resemble handwritten digits.
"""
for snippet in ("---", "### About the Model"):
    st.markdown(snippet)
st.write(about_text)