MSTI-AI-CONTEST / app.py
selmandedeakay's picture
Upload 3 files
56c4bd2 verified
# app.py - Gradio app for MNIST Digit Generator
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import io
import matplotlib.pyplot as plt
# Define the same Generator architecture as in training
class Generator(nn.Module):
    """Conditional fully connected generator for 28x28 single-channel images.

    The class label is embedded into a vector of length ``num_classes``,
    concatenated with the noise vector, and pushed through a stack of
    ``Linear -> LeakyReLU(0.2) -> BatchNorm1d`` stages followed by a final
    ``Linear -> Tanh`` projection to ``img_size * img_size`` values.

    NOTE: the attribute names ``label_emb`` and ``main`` and the order of the
    layers inside ``main`` determine the ``state_dict`` keys, so they must not
    change if previously saved weights are to load.

    Args:
        latent_dim: Length of the noise vector fed to ``forward``.
        num_classes: Number of distinct labels (also the embedding width).
    """

    def __init__(self, latent_dim=100, num_classes=10):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.img_size = 28

        # One learned vector of length num_classes per class index.
        self.label_emb = nn.Embedding(num_classes, num_classes)

        # Hidden stages are created in order so parameter initialisation and
        # Sequential indices (0..10) are the same as a hand-written listing.
        stages = []
        in_features = latent_dim + num_classes
        for width in (256, 512, 1024):
            stages.extend([
                nn.Linear(in_features, width),
                nn.LeakyReLU(0.2),
                nn.BatchNorm1d(width),
            ])
            in_features = width
        stages.append(nn.Linear(in_features, self.img_size * self.img_size))
        stages.append(nn.Tanh())
        self.main = nn.Sequential(*stages)

    def forward(self, noise, labels):
        """Map noise and labels to images.

        Args:
            noise: Tensor whose second dimension is ``latent_dim``.
            labels: Integer class indices, one per row of ``noise``.

        Returns:
            Tensor reshaped to ``(batch, 1, img_size, img_size)``; values come
            from ``Tanh`` and therefore lie in ``[-1, 1]``.
        """
        conditioned = torch.cat([noise, self.label_emb(labels)], dim=1)
        flat = self.main(conditioned)
        return flat.view(flat.size(0), 1, self.img_size, self.img_size)
# Load the model
# Inference is pinned to the CPU; the checkpoint is remapped accordingly.
device = torch.device('cpu')
generator = Generator(latent_dim=100, num_classes=10).to(device)

# A missing checkpoint is tolerated (the app starts but generates nothing);
# any other loading problem is allowed to propagate.
try:
    _checkpoint = torch.load('generator.pth', map_location=device)
except FileNotFoundError:
    model_loaded = False
    print("Warning: generator.pth not found!")
else:
    generator.load_state_dict(_checkpoint)
    # Switch to inference mode so BatchNorm layers use their stored statistics.
    generator.eval()
    model_loaded = True
# Generate function for single digit
def generate_single_digit(digit, seed):
    """Generate one image of ``digit`` and return it upscaled for display.

    Args:
        digit: Class label to condition the generator on. Converted with
            ``int()`` because values coming from the UI/API may be floats
            such as ``7.0``.
        seed: RNG seed. ``-1`` means "do not reseed"; any other value seeds
            both the torch and NumPy global generators. Converted with
            ``int()`` because ``np.random.seed`` rejects floats.

    Returns:
        A mode ``"L"`` PIL image resized to 256x256 with nearest-neighbour
        sampling, or ``None`` when the model weights were not loaded.
    """
    if not model_loaded:
        return None

    digit = int(digit)
    seed = int(seed)

    # Set seed if provided
    if seed != -1:
        torch.manual_seed(seed)
        np.random.seed(seed)

    with torch.no_grad():
        # Noise width follows the model rather than repeating a literal.
        noise = torch.randn(1, generator.latent_dim, device=device)
        label = torch.tensor([digit], dtype=torch.long, device=device)
        generated_img = generator(noise, label)

        # Denormalize from [-1, 1] to [0, 1]
        generated_img = torch.clamp(generated_img * 0.5 + 0.5, 0, 1)

    # First sample, first channel -> 2-D array -> 8-bit greyscale image.
    img_numpy = generated_img.cpu().numpy()[0, 0]
    img_pil = Image.fromarray((img_numpy * 255).astype(np.uint8), mode='L')

    # Nearest-neighbour keeps the individual 28x28 pixels crisp when enlarged.
    return img_pil.resize((256, 256), Image.Resampling.NEAREST)
# Generate function for multiple digits
def generate_multiple_digits(digit, num_samples, seed):
    """Generate ``num_samples`` images of ``digit`` and render them as a grid.

    Args:
        digit: Class label to condition the generator on.
        num_samples: How many images to generate; must be at least 1.
        seed: RNG seed. ``-1`` means "do not reseed"; any other value seeds
            both the torch and NumPy global generators.

        All three arguments are converted with ``int()`` because values
        coming from the UI/API may be floats, which ``torch.randn`` (as a
        size) and ``np.random.seed`` reject.

    Returns:
        A PIL image of the rendered grid (at most 5 columns, each cell titled
        ``Sample N``), or ``None`` when the model weights were not loaded.

    Raises:
        ValueError: If ``num_samples`` is smaller than 1.
    """
    if not model_loaded:
        return None

    digit = int(digit)
    num_samples = int(num_samples)
    seed = int(seed)
    if num_samples < 1:
        # Otherwise cols would be 0 and the row computation divides by zero.
        raise ValueError("num_samples must be at least 1")

    if seed != -1:
        torch.manual_seed(seed)
        np.random.seed(seed)

    with torch.no_grad():
        noise = torch.randn(num_samples, generator.latent_dim, device=device)
        labels = torch.full((num_samples,), digit, dtype=torch.long, device=device)
        generated_imgs = generator(noise, labels)

        # Denormalize from [-1, 1] to [0, 1]
        generated_imgs = torch.clamp(generated_imgs * 0.5 + 0.5, 0, 1)

    # Drop the channel axis: one 2-D array per sample.
    samples = generated_imgs.cpu().numpy()[:, 0]

    # Create grid
    cols = min(5, num_samples)
    rows = (num_samples + cols - 1) // cols  # ceiling division

    # squeeze=False always yields a 2-D array of axes, so 1x1, 1xN and Nx1
    # layouts need no special-casing.
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 2, rows * 2), squeeze=False)
    try:
        cells = axes.ravel()  # row-major, same order as idx // cols, idx % cols
        for idx, (ax, sample) in enumerate(zip(cells, samples)):
            ax.imshow(sample, cmap='gray')
            ax.axis('off')
            ax.set_title(f'Sample {idx + 1}')

        # Remove empty subplots in the last row
        for ax in cells[num_samples:]:
            fig.delaxes(ax)

        fig.tight_layout()

        # Render through the figure object instead of pyplot's implicit
        # "current figure", which is shared process-wide.
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
# Generate all digits
def generate_all_digits(seed):
    """Generate one image for each digit 0-9 and render them in a 2x5 grid.

    Args:
        seed: RNG seed. ``-1`` means "do not reseed"; any other value seeds
            both the torch and NumPy global generators. Converted with
            ``int()`` because ``np.random.seed`` rejects floats.

    Returns:
        A PIL image of the rendered grid titled ``Generated Digits 0-9``, or
        ``None`` when the model weights were not loaded.
    """
    if not model_loaded:
        return None

    seed = int(seed)
    if seed != -1:
        torch.manual_seed(seed)
        np.random.seed(seed)

    fig, axes = plt.subplots(2, 5, figsize=(10, 4))
    try:
        with torch.no_grad():
            # Row-major traversal: digits 0-4 on the first row, 5-9 on the
            # second. Noise is drawn once per digit, in digit order, so a
            # given seed keeps producing the same ten images.
            for digit, ax in enumerate(axes.ravel()):
                noise = torch.randn(1, generator.latent_dim, device=device)
                label = torch.tensor([digit], dtype=torch.long, device=device)
                generated_img = generator(noise, label)

                # Denormalize from [-1, 1] to [0, 1]
                generated_img = torch.clamp(generated_img * 0.5 + 0.5, 0, 1)

                ax.imshow(generated_img.cpu().numpy()[0, 0], cmap='gray')
                ax.set_title(f'Digit {digit}')
                ax.axis('off')

        fig.suptitle('Generated Digits 0-9', fontsize=16)
        fig.tight_layout()

        # Render through the figure object instead of pyplot's implicit
        # "current figure", which is shared process-wide.
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure, even if generation or saving failed.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
# Create Gradio interface
with gr.Blocks(title="MNIST Digit Generator") as demo:
    # Heading shown above the tabs.
    gr.Markdown("# 🔢 MNIST Digit Generator")
    gr.Markdown("Generate handwritten digits using a trained GAN model")

    # ---- Tab: one image of a chosen digit -------------------------------
    with gr.Tab("Single Digit"):
        with gr.Row():
            with gr.Column():
                single_digit = gr.Slider(
                    label="Select Digit", minimum=0, maximum=9, step=1, value=0
                )
                single_seed = gr.Slider(
                    label="Seed (-1 for random)", minimum=-1, maximum=9999, step=1, value=-1
                )
                single_generate_btn = gr.Button("Generate", variant="primary")
            with gr.Column():
                single_output = gr.Image(label="Generated Digit", type="pil")

        single_generate_btn.click(
            fn=generate_single_digit,
            inputs=[single_digit, single_seed],
            outputs=single_output,
        )

    # ---- Tab: several samples of the same digit -------------------------
    with gr.Tab("Multiple Samples"):
        with gr.Row():
            with gr.Column():
                multi_digit = gr.Slider(
                    label="Select Digit", minimum=0, maximum=9, step=1, value=0
                )
                num_samples = gr.Slider(
                    label="Number of Samples", minimum=1, maximum=25, step=1, value=5
                )
                multi_seed = gr.Slider(
                    label="Seed (-1 for random)", minimum=-1, maximum=9999, step=1, value=-1
                )
                multi_generate_btn = gr.Button("Generate Multiple", variant="primary")
            with gr.Column():
                multi_output = gr.Image(label="Generated Samples", type="pil")

        multi_generate_btn.click(
            fn=generate_multiple_digits,
            inputs=[multi_digit, num_samples, multi_seed],
            outputs=multi_output,
        )

    # ---- Tab: one sample of every digit ---------------------------------
    with gr.Tab("All Digits"):
        with gr.Row():
            with gr.Column():
                all_seed = gr.Slider(
                    label="Seed (-1 for random)", minimum=-1, maximum=9999, step=1, value=-1
                )
                all_generate_btn = gr.Button("Generate All Digits (0-9)", variant="primary")
            with gr.Column():
                all_output = gr.Image(label="All Generated Digits", type="pil")

        all_generate_btn.click(
            fn=generate_all_digits,
            inputs=[all_seed],
            outputs=all_output,
        )

    # ---- Tab: static description ----------------------------------------
    # The literal is kept flush-left so its text is exactly what is rendered.
    with gr.Tab("About"):
        gr.Markdown("""
## About this Model
This app uses a **Generative Adversarial Network (GAN)** trained on the MNIST dataset to generate
handwritten digit images.
### How it works:
- The model takes random noise and a digit label as input
- It generates a 28x28 pixel grayscale image of the specified digit
- Each generation with the same seed will produce the same image
### Model Architecture:
- **Generator**: 4-layer neural network with LeakyReLU activation
- **Training**: 100 epochs on MNIST dataset
- **Approach**: Conditional GAN for digit-specific generation
### Features:
- Generate any digit from 0-9
- Control randomness with seed values
- Generate multiple samples at once
- View all digits in a single grid
""")

    # Examples
    # Each row is [digit, seed] for the single-digit generator; outputs are
    # pre-computed because cache_examples is enabled.
    # NOTE(review): assumed to sit at page level below the tabs — confirm.
    gr.Examples(
        examples=[[7, 42], [3, 123], [9, 999], [0, 2024]],
        inputs=[single_digit, single_seed],
        outputs=single_output,
        fn=generate_single_digit,
        cache_examples=True,
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()