# task_3 / app.py — Gradio demo: generate MNIST handwritten digits with a trained Conditional VAE.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gradio as gr
from PIL import Image
import os
# Select the compute device once at import time: prefer the GPU when present.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class ConditionalVAE(nn.Module):
    """Conditional VAE over flattened 28x28 images, conditioned on a one-hot label.

    Both encoder and decoder receive the class label concatenated to their
    primary input, so sampling can be steered toward a chosen digit.
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20, num_classes=10):
        super().__init__()
        # Encoder: (image ++ label) -> hidden -> (mu, logvar)
        self.fc1 = nn.Linear(input_dim + num_classes, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # mean head
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # log-variance head
        # Decoder: (latent ++ label) -> hidden -> reconstruction
        self.fc3 = nn.Linear(latent_dim + num_classes, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)
        self.latent_dim = latent_dim
        self.num_classes = num_classes

    def encode(self, x, y):
        """Return the posterior parameters (mu, logvar) for input x with label y."""
        hidden = F.relu(self.fc1(torch.cat([x, y], dim=1)))
        return self.fc21(hidden), self.fc22(hidden)

    def reparameterize(self, mu, logvar):
        """Draw z ~ N(mu, sigma^2) via the reparameterization trick."""
        sigma = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(sigma) * sigma

    def decode(self, z, y):
        """Map a latent code z with label y to a reconstruction in [0, 1]."""
        hidden = F.relu(self.fc3(torch.cat([z, y], dim=1)))
        return torch.sigmoid(self.fc4(hidden))

    def forward(self, x, y):
        """Encode, sample, and decode; returns (reconstruction, mu, logvar)."""
        mu, logvar = self.encode(x.view(-1, 784), y)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, y), mu, logvar
# Load model
def load_model():
    """Load the trained CVAE weights from disk and return the model in eval mode.

    Returns:
        ConditionalVAE on `device`, ready for inference.

    Raises:
        FileNotFoundError: if 'mnist_cvae_model.pth' is not present.
    """
    model = ConditionalVAE(input_dim=784, hidden_dim=400, latent_dim=20, num_classes=10)
    # weights_only=True: the checkpoint is a plain state_dict, so restrict
    # torch.load to tensor data instead of arbitrary pickled objects (safer
    # against a tampered checkpoint file; default unpickling can run code).
    state = torch.load('mnist_cvae_model.pth', map_location=device, weights_only=True)
    model.load_state_dict(state)
    model = model.to(device)
    model.eval()
    return model
def generate_digits(model, digit, num_samples=5):
    """Sample `num_samples` images of class `digit` from the CVAE decoder.

    Args:
        model: trained ConditionalVAE (or any object exposing .eval(),
               .parameters(), .latent_dim, and .decode(z, label)).
        digit: class index in [0, 9] to condition on.
        num_samples: number of images to draw.

    Returns:
        uint8 numpy array of shape (num_samples, 28, 28) with values in [0, 255].
    """
    model.eval()
    # Infer the device from the model itself rather than relying on the
    # module-level `device` global — works wherever the model actually lives.
    dev = next(model.parameters()).device
    with torch.no_grad():
        # One-hot condition vector for the requested digit.
        label = torch.zeros(num_samples, 10, device=dev)
        label[:, digit] = 1
        # Sample latent codes from the standard-normal prior and decode.
        z = torch.randn(num_samples, model.latent_dim, device=dev)
        generated = model.decode(z, label).view(num_samples, 28, 28)
        # Scale [0, 1] sigmoid outputs to 8-bit grayscale.
        return (generated.cpu().numpy() * 255).astype(np.uint8)
def generate_digit_images(digit):
    """Generate five 112x112 grayscale PIL images of the requested digit.

    Any failure (missing checkpoint, bad input, ...) is logged and replaced
    by five gray placeholder images so the UI always has something to show.
    """
    try:
        cvae = load_model()
        samples = generate_digits(cvae, int(digit), num_samples=5)
        # Upscale each 28x28 sample with nearest-neighbor to keep crisp pixels.
        return [
            Image.fromarray(arr, mode='L').resize((112, 112), Image.NEAREST)
            for arr in samples
        ]
    except Exception as e:
        print(f"Error: {e}")
        placeholder = Image.new('L', (112, 112), color=128)
        return [placeholder] * 5
def generate_and_display(digit):
    """Gradio adapter: unpack the five generated images into five outputs."""
    first, second, third, fourth, fifth = generate_digit_images(digit)
    return first, second, third, fourth, fifth
# Build the Gradio interface: a digit slider, a generate button, and a row
# of five output images wired to generate_and_display.
with gr.Blocks(title="MNIST Digit Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# πŸ”’ MNIST Handwritten Digit Generator")
    gr.Markdown("Select a digit (0-9) and generate 5 unique handwritten samples using a trained Conditional VAE model.")
    with gr.Row():
        digit_input = gr.Slider(minimum=0, maximum=9, step=1, value=0, label="Select Digit to Generate")
    generate_btn = gr.Button("🎨 Generate 5 Digit Images", variant="primary", size="lg")
    gr.Markdown("## Generated Images")
    with gr.Row():
        # Five fixed-size preview slots, one per generated sample.
        sample_slots = [
            gr.Image(label=f"Sample {i}", width=112, height=112)
            for i in range(1, 6)
        ]
    generate_btn.click(
        fn=generate_and_display,
        inputs=[digit_input],
        outputs=sample_slots,
    )
    with gr.Accordion("πŸ“‹ Model Information", open=False):
        gr.Markdown("""
        ### Technical Details
        - **Architecture**: Conditional Variational Autoencoder (CVAE)
        - **Dataset**: MNIST (28Γ—28 grayscale images)
        - **Training**: From scratch on Google Colab T4 GPU
        - **Latent Dimension**: 20
        - **Training Epochs**: 15
        - **Loss Function**: BCE + KL Divergence
        The model generates diverse samples by sampling from the learned latent space conditioned on digit labels.
        """)

if __name__ == "__main__":
    demo.launch()