|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
import gradio as gr |
|
|
from PIL import Image |
|
|
import os |
|
|
|
|
|
|
|
|
# Select the compute device once at import time; every model/tensor below is
# moved here. Falls back to CPU when CUDA is unavailable.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
|
class ConditionalVAE(nn.Module):
    """Conditional Variational Autoencoder for class-conditioned image generation.

    Both the encoder and the decoder are conditioned on a one-hot class label
    concatenated to their input, so sampling from the prior with a chosen label
    generates an image of that class.

    Args:
        input_dim: flattened image size (28*28 = 784 for MNIST).
        hidden_dim: width of the single hidden layer in encoder and decoder.
        latent_dim: dimensionality of the latent code z.
        num_classes: size of the one-hot label vector.
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20, num_classes=10):
        super().__init__()

        # Encoder: [x ; one-hot y] -> hidden -> (mu, logvar) heads.
        self.fc1 = nn.Linear(input_dim + num_classes, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # mu head
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # logvar head

        # Decoder: [z ; one-hot y] -> hidden -> reconstructed flat image.
        self.fc3 = nn.Linear(latent_dim + num_classes, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

        # Stored so forward() need not hard-code 784 (bug fix: the original
        # flattened to a literal 784, breaking any non-default input_dim).
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.num_classes = num_classes

    def encode(self, x, y):
        """Return (mu, logvar) of q(z | x, y).

        Args:
            x: (N, input_dim) flattened images.
            y: (N, num_classes) one-hot labels.
        """
        inputs = torch.cat([x, y], 1)
        h1 = F.relu(self.fc1(inputs))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * sigma with eps ~ N(0, I).

        The reparameterization trick keeps sampling differentiable w.r.t.
        mu and logvar.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, y):
        """Decode latent z conditioned on one-hot y; sigmoid output in [0, 1]."""
        inputs = torch.cat([z, y], 1)
        h3 = F.relu(self.fc3(inputs))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x, y):
        """Full VAE pass: returns (reconstruction, mu, logvar).

        x may be any shape that flattens to (N, input_dim), e.g. (N, 1, 28, 28).
        """
        mu, logvar = self.encode(x.view(-1, self.input_dim), y)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, y), mu, logvar
|
|
|
|
|
|
|
|
def load_model():
    """Build the CVAE, load trained weights from disk, and return it in eval mode.

    Weights are loaded from 'mnist_cvae_model.pth' and mapped onto the
    module-level `device`.
    """
    cvae = ConditionalVAE(input_dim=784, hidden_dim=400, latent_dim=20, num_classes=10)
    state = torch.load('mnist_cvae_model.pth', map_location=device)
    cvae.load_state_dict(state)
    # .to() and .eval() both return the module itself, so this chains safely.
    return cvae.to(device).eval()
|
|
|
|
|
def generate_digits(model, digit, num_samples=5):
    """Sample `num_samples` images of class `digit` from the CVAE decoder.

    Args:
        model: a ConditionalVAE-like module exposing `.latent_dim` and
            `.decode(z, label)` returning (N, 784) values in [0, 1].
        digit: class index in [0, 9] to condition on.
        num_samples: number of images to draw.

    Returns:
        uint8 numpy array of shape (num_samples, 28, 28), values 0-255.
    """
    model.eval()
    # Fix: place tensors on the device the model actually lives on, instead of
    # the module-level `device` global — a model moved elsewhere previously
    # caused a device-mismatch error.
    dev = next(model.parameters()).device
    with torch.no_grad():
        # One-hot condition vector, repeated for every sample.
        label = torch.zeros(num_samples, 10, device=dev)
        label[:, digit] = 1

        # Draw independent latent codes from the standard-normal prior.
        z = torch.randn(num_samples, model.latent_dim, device=dev)
        generated = model.decode(z, label)
        generated = generated.view(num_samples, 28, 28).cpu().numpy()
        # Scale [0, 1] sigmoid output to 8-bit grayscale.
        return (generated * 255).astype(np.uint8)
|
|
|
|
|
def generate_digit_images(digit):
    """Produce five 112x112 grayscale PIL previews for the requested digit.

    Any failure (e.g. missing weights file) is logged and replaced by five
    uniform gray placeholders so the UI always receives valid images.
    """
    try:
        cvae = load_model()
        samples = generate_digits(cvae, int(digit), num_samples=5)
        # Upscale each 28x28 sample 4x with nearest-neighbor to keep the
        # pixelated MNIST look.
        return [
            Image.fromarray(arr, mode='L').resize((112, 112), Image.NEAREST)
            for arr in samples
        ]
    except Exception as e:
        print(f"Error: {e}")
        blank = Image.new('L', (112, 112), color=128)
        return [blank] * 5
|
|
|
|
|
def generate_and_display(digit):
    """Gradio adapter: fan the five generated images out to five outputs."""
    first, second, third, fourth, fifth = generate_digit_images(digit)
    return first, second, third, fourth, fifth
|
|
|
|
|
|
|
|
# --- Gradio UI -------------------------------------------------------------
# Declarative layout: a slider selects the digit, a button triggers
# generate_and_display, and five Image widgets show the returned samples.
# NOTE(review): several Markdown/button strings below ("π’", "π¨", "π",
# "28Γ28") look like mojibake from a lost emoji/Unicode encoding — they are
# runtime strings, so left untouched here; confirm the source file's encoding.
with gr.Blocks(title="MNIST Digit Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# π’ MNIST Handwritten Digit Generator")
    gr.Markdown("Select a digit (0-9) and generate 5 unique handwritten samples using a trained Conditional VAE model.")

    with gr.Row():
        # Integer slider constrained to the valid MNIST class labels 0-9.
        digit_input = gr.Slider(
            minimum=0,
            maximum=9,
            step=1,
            value=0,
            label="Select Digit to Generate"
        )

    generate_btn = gr.Button("π¨ Generate 5 Digit Images", variant="primary", size="lg")

    gr.Markdown("## Generated Images")
    with gr.Row():
        # Five fixed-size slots; generate_and_display returns one 112x112
        # PIL image per slot (upscaled from the 28x28 model output).
        img1 = gr.Image(label="Sample 1", width=112, height=112)
        img2 = gr.Image(label="Sample 2", width=112, height=112)
        img3 = gr.Image(label="Sample 3", width=112, height=112)
        img4 = gr.Image(label="Sample 4", width=112, height=112)
        img5 = gr.Image(label="Sample 5", width=112, height=112)

    # Wire the button: slider value in, five images out.
    generate_btn.click(
        fn=generate_and_display,
        inputs=[digit_input],
        outputs=[img1, img2, img3, img4, img5]
    )

    # Collapsed info panel describing the model; purely informational.
    with gr.Accordion("π Model Information", open=False):
        gr.Markdown("""
        ### Technical Details
        - **Architecture**: Conditional Variational Autoencoder (CVAE)
        - **Dataset**: MNIST (28Γ28 grayscale images)
        - **Training**: From scratch on Google Colab T4 GPU
        - **Latent Dimension**: 20
        - **Training Epochs**: 15
        - **Loss Function**: BCE + KL Divergence

        The model generates diverse samples by sampling from the learned latent space conditioned on digit labels.
        """)

# Launch the web server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()