File size: 1,982 Bytes
f3cdb6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import streamlit as st
import torch
from models_conv import ConvGenerator
import numpy as np

# Set page config
st.set_page_config(page_title="MNIST Digit Generator", layout="centered")

# Checkpoint to serve. This is a relative path, so it is resolved against the
# working directory the app is launched from.
CHECKPOINT_PATH = 'checkpoints/wgan_checkpoint_epoch_190.pt'

# Load the trained generator
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


@st.cache_resource
def _load_generator(checkpoint_path, device_name):
    """Build a ConvGenerator, load its weights, and return it in eval mode.

    Streamlit re-executes this whole script on every widget interaction.
    Caching the model as a resource means the checkpoint is read from disk
    once per (checkpoint_path, device_name) pair instead of on every rerun.

    Args:
        checkpoint_path: Path to a checkpoint dict containing the key
            'generator_state_dict'.
        device_name: Device string (e.g. 'cpu', 'cuda') to place the model on.
            Passed as a plain string so it is trivially hashable as a cache key.

    Returns:
        The loaded ConvGenerator, moved to the device and switched to eval mode.
    """
    target = torch.device(device_name)
    model = ConvGenerator().to(target)
    checkpoint = torch.load(checkpoint_path, map_location=target)
    model.load_state_dict(checkpoint['generator_state_dict'])
    model.eval()
    return model


generator = _load_generator(CHECKPOINT_PATH, str(device))

# Title and description
st.title("MNIST Digit Generator")
st.write("Generate MNIST-like digits using a Wasserstein GAN")

# Sidebar controls
with st.sidebar:
    st.header("Generation Controls")
    noise_seed = st.slider("Noise Seed", 1, 1000, 42)
    num_images = st.slider("Number of Images", 1, 16, 4)
    generate_button = st.button("Generate New Images")

# Function to generate images
def generate_images(noise_seed, num_images, latent_dim=100, *, model=None, target_device=None):
    """Sample a batch of images from the generator, deterministically per seed.

    Args:
        noise_seed: Seed for the latent-noise RNG. The same seed always yields
            the same latent vectors.
        num_images: Number of latent vectors (and hence images) to draw.
        latent_dim: Length of each latent vector. Defaults to 100, the size
            this app has always fed to the generator.
        model: Callable mapping a latent batch to an image tensor. Defaults to
            the module-level ``generator``.
        target_device: Device the model runs on. Defaults to the module-level
            ``device``.

    Returns:
        A NumPy array of the model output mapped from [-1, 1] to [0, 1]
        (the [-1, 1] output range is assumed) and clipped to [0, 1].
    """
    if model is None:
        model = generator
    if target_device is None:
        target_device = device

    # A dedicated CPU RNG gives exactly the latents torch.manual_seed() would,
    # without reseeding torch's global RNG as a side effect.
    rng = torch.Generator().manual_seed(noise_seed)
    z = torch.randn(num_images, latent_dim, generator=rng).to(target_device)

    with torch.no_grad():
        imgs = model(z)

    # Convert to NumPy and map [-1, 1] -> [0, 1]. Clip so values slightly
    # outside the expected range cannot become invalid pixel intensities.
    imgs = (imgs.cpu().numpy() + 1) / 2
    return np.clip(imgs, 0.0, 1.0)

# Main content: generate on first load or when the button is pressed;
# otherwise redisplay the batch kept in session state across reruns.
if generate_button or 'generated_images' not in st.session_state:
    st.session_state.generated_images = generate_images(noise_seed, num_images)
images = st.session_state.generated_images

# Display images in a grid of at most 4 columns. The column count comes from
# the images actually shown, not from the slider: the slider may have moved
# since this batch was generated, because the stored batch is only replaced
# on first load or a button press.
n_cols = min(4, len(images))
cols = st.columns(n_cols)
for idx, img in enumerate(images):
    with cols[idx % n_cols]:
        st.image(img.squeeze(), caption=f"Generated Image {idx+1}", use_column_width=True)

# Add information about the model
st.markdown("---")
st.markdown("### About the Model")
st.write("""
This is a Wasserstein GAN (WGAN) model trained on the MNIST dataset. 
The model generates 28x28 grayscale images that resemble handwritten digits.
""")