"""Streamlit app that samples MNIST-like digits from a trained WGAN generator."""

import numpy as np
import streamlit as st
import torch

from models_conv import ConvGenerator

# Length of the noise vector fed to the generator (must match the trained model).
LATENT_DIM = 100
# Training checkpoint whose "generator_state_dict" entry holds the weights.
CHECKPOINT_PATH = "checkpoints/wgan_checkpoint_epoch_190.pt"
# Maximum number of columns in the image grid.
MAX_COLUMNS = 4

# Set page config (must be the first Streamlit command the script executes).
st.set_page_config(page_title="MNIST Digit Generator", layout="centered")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def load_generator(checkpoint_path=CHECKPOINT_PATH):
    """Build a ConvGenerator on ``device`` and load its trained weights.

    Streamlit re-executes this whole script on every widget interaction;
    caching the model with ``st.cache_resource`` means the checkpoint is
    read from disk once instead of on every rerun.

    Args:
        checkpoint_path: path to a checkpoint dict containing a
            ``"generator_state_dict"`` entry.

    Returns:
        The generator, switched to eval mode.
    """
    model = ConvGenerator().to(device)
    # weights_only=True limits unpickling to tensors and plain containers
    # (the default from PyTorch 2.6 onward). If the checkpoint also stores
    # arbitrary Python objects this will refuse to load it.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint["generator_state_dict"])
    model.eval()
    return model


# Load the trained generator (cached across reruns).
generator = load_generator()

# Title and description
st.title("MNIST Digit Generator")
st.write("Generate MNIST-like digits using a Wasserstein GAN")

# Sidebar controls
with st.sidebar:
    st.header("Generation Controls")
    noise_seed = st.slider("Noise Seed", 1, 1000, 42)
    num_images = st.slider("Number of Images", 1, 16, 4)
    generate_button = st.button("Generate New Images")


def generate_images(noise_seed, num_images):
    """Sample ``num_images`` images deterministically from ``noise_seed``.

    Args:
        noise_seed: seed for the latent noise; the same seed gives the same images.
        num_images: number of latent vectors (and therefore images) to draw.

    Returns:
        NumPy array of generator outputs, mapped with ``(x + 1) / 2`` and
        clipped to [0, 1]. This assumes the generator emits values in
        [-1, 1] (e.g. a tanh output); TODO confirm against ConvGenerator.
    """
    # A dedicated CPU generator yields exactly the same draws as
    # torch.manual_seed(seed) followed by torch.randn, but without reseeding
    # the process-wide RNG that every Streamlit session shares.
    rng = torch.Generator().manual_seed(noise_seed)
    z = torch.randn(num_images, LATENT_DIM, generator=rng).to(device)
    with torch.no_grad():
        imgs = generator(z)
    imgs = imgs.cpu().numpy()
    # Rescale to [0, 1]. st.image raises on float data outside [0, 1], so
    # clip away any overshoot.
    return np.clip((imgs + 1) / 2, 0.0, 1.0)


# Main content: regenerate on demand or on first load; otherwise reuse the
# images kept in this session's state.
if generate_button or "generated_images" not in st.session_state:
    st.session_state.generated_images = generate_images(noise_seed, num_images)
images = st.session_state.generated_images

# Size the grid from the images actually displayed. The cached batch may have
# been produced with a different "Number of Images" than the slider now shows.
n_cols = min(MAX_COLUMNS, len(images))
cols = st.columns(n_cols)
for idx, img in enumerate(images):
    with cols[idx % n_cols]:
        # NOTE: newer Streamlit releases deprecate use_column_width in favour
        # of use_container_width; kept here for compatibility with older versions.
        st.image(img.squeeze(), caption=f"Generated Image {idx+1}", use_column_width=True)

# Add information about the model
st.markdown("---")
st.markdown("### About the Model")
st.write("""
This is a Wasserstein GAN (WGAN) model trained on the MNIST dataset.
The model generates 28x28 grayscale images that resemble handwritten digits.
""")