|
|
import streamlit as st |
|
|
import torch |
|
|
from models_conv import ConvGenerator |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
# Configure the browser tab title and page layout. Kept as the first
# Streamlit call in the script, which older Streamlit versions require.
st.set_page_config(layout="centered", page_title="MNIST Digit Generator")
|
|
|
|
|
|
|
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Trained WGAN checkpoint, resolved relative to the directory the app is
# launched from.
CHECKPOINT_PATH = 'checkpoints/wgan_checkpoint_epoch_190.pt'


@st.cache_resource
def _load_generator(checkpoint_path, device_name):
    """Build the generator and load its trained weights from *checkpoint_path*.

    Streamlit re-executes this whole script on every widget interaction.
    Caching the model as a resource means the checkpoint is read from disk
    and the network is constructed once per (path, device) pair, instead of
    on every rerun. The cached model is shared across sessions; it is only
    used for inference under ``torch.no_grad()`` in this app.

    Args:
        checkpoint_path: Path to a file saved with ``torch.save`` whose
            loaded dict has a ``'generator_state_dict'`` entry.
        device_name: Device string (e.g. ``'cpu'`` or ``'cuda'``) to place
            the model on. It is passed as a plain string so Streamlit can
            hash it as a cache key.

    Returns:
        The ``ConvGenerator`` with loaded weights, in eval mode, on the
        requested device.
    """
    target = torch.device(device_name)
    model = ConvGenerator().to(target)
    checkpoint = torch.load(checkpoint_path, map_location=target)
    model.load_state_dict(checkpoint['generator_state_dict'])
    # Inference mode: switches layers such as dropout / batch-norm (if
    # ConvGenerator has any) to their evaluation behaviour.
    model.eval()
    return model


generator = _load_generator(CHECKPOINT_PATH, str(device))
|
|
|
|
|
|
|
|
# Page heading followed by a one-line description of the app.
st.title("MNIST Digit Generator")
st.markdown("Generate MNIST-like digits using a Wasserstein GAN")
|
|
|
|
|
|
|
|
# Sidebar widgets: the latent-noise seed, the batch size, and a button that
# requests a fresh batch. Their values are read on every rerun.
st.sidebar.header("Generation Controls")
noise_seed = st.sidebar.slider("Noise Seed", 1, 1000, 42)
num_images = st.sidebar.slider("Number of Images", 1, 16, 4)
generate_button = st.sidebar.button("Generate New Images")
|
|
|
|
|
|
|
|
def generate_images(noise_seed, num_images, latent_dim=100):
    """Sample a batch of images from the module-level ``generator``.

    Args:
        noise_seed: Seed for the latent noise. The same seed and count always
            produce the same batch.
        num_images: Number of images to generate.
        latent_dim: Length of each latent vector. Defaults to 100, the value
            this app has always used; it must match what ``ConvGenerator``
            expects.

    Returns:
        A NumPy array with ``num_images`` entries along its first axis and
        pixel values clamped to ``[0, 1]``. The per-image shape is whatever
        ``ConvGenerator`` emits (presumably (1, 28, 28) — TODO confirm).
    """
    # A dedicated RNG keeps the draw reproducible without reseeding torch's
    # global generator as a side effect. A freshly seeded CPU generator yields
    # the same samples that torch.manual_seed + torch.randn did.
    rng = torch.Generator().manual_seed(noise_seed)
    z = torch.randn(num_images, latent_dim, generator=rng).to(device)

    with torch.no_grad():
        imgs = generator(z)

    # Rescale from the assumed [-1, 1] output range (TODO confirm the
    # generator's final activation) to [0, 1]. Then clamp: st.image rejects
    # float data outside [0.0, 1.0], so any overshoot would crash the page.
    imgs = ((imgs + 1) / 2).clamp(0.0, 1.0)
    return imgs.cpu().numpy()
|
|
|
|
|
|
|
|
# Sample a new batch only when the button is pressed or on the very first
# run of this session; every other rerun reuses the batch stored in
# session state.
if generate_button or "generated_images" not in st.session_state:
    st.session_state["generated_images"] = generate_images(noise_seed, num_images)
images = st.session_state["generated_images"]
|
|
|
|
|
|
|
|
# Lay the images out in a grid of at most four columns. The column count is
# derived from the batch actually being shown, not from the slider. If the
# "Number of Images" slider moves without the button being pressed, the
# cached batch keeps its old size, and a slider-sized grid would mis-lay it.
n_cols = min(4, len(images))
if n_cols:  # st.columns rejects a count of 0
    cols = st.columns(n_cols)
    for idx, img in enumerate(images):
        with cols[idx % n_cols]:
            # squeeze() drops singleton axes (presumably the channel axis of
            # a (1, H, W) image — TODO confirm against ConvGenerator's output)
            # so st.image receives a 2-D grayscale array.
            st.image(img.squeeze(), caption=f"Generated Image {idx+1}", use_column_width=True)
|
|
|
|
|
|
|
|
# Footer: a horizontal rule, a section heading, and a short description of
# the model. st.write on a plain string renders it as Markdown, so the body
# is passed to st.markdown directly.
st.markdown("---")
st.markdown("### About the Model")
st.markdown("""


    This is a Wasserstein GAN (WGAN) model trained on the MNIST dataset.


    The model generates 28x28 grayscale images that resemble handwritten digits.


""")