File size: 1,944 Bytes
4ad8a34 388fa58 4ad8a34 388fa58 4ad8a34 388fa58 4ad8a34 388fa58 4ad8a34 388fa58 4ad8a34 388fa58 4ad8a34 388fa58 4ad8a34 388fa58 4fb7754 4ad8a34 388fa58 4ad8a34 388fa58 4ad8a34 388fa58 4ad8a34 388fa58 78eb968 388fa58 4ad8a34 388fa58 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import torch
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
# Conditional VAE definition (must match the training-time architecture).
class CVAE(torch.nn.Module):
    """Conditional VAE over flattened images, conditioned on a class label.

    Attribute names and layer order define the ``state_dict`` keys
    (``label_embed``, ``encoder.0``, ``fc_mu``, ``fc_logvar``, ``decoder.0``,
    ``decoder.2``), so they must stay as they are for a saved checkpoint to
    load with the default strict key matching. The encoder, ``fc_mu`` and
    ``fc_logvar`` are not called by any method defined here; presumably they
    are kept so the training checkpoint's encoder weights load — confirm
    against the training script.

    Args:
        latent_dim: Size of the latent vector ``z``.
        num_classes: Number of rows in the label embedding table.
        embed_dim: Width of each label embedding.
        input_dim: Flattened image size (28*28 for MNIST).
        hidden_dim: Width of the hidden layer in encoder and decoder.
    """

    def __init__(self, latent_dim=20, num_classes=10, embed_dim=10,
                 input_dim=28 * 28, hidden_dim=400):
        super().__init__()
        self.latent_dim = latent_dim
        self.label_embed = torch.nn.Embedding(num_classes, embed_dim)
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(input_dim + embed_dim, hidden_dim),
            torch.nn.ReLU(),
        )
        self.fc_mu = torch.nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = torch.nn.Linear(hidden_dim, latent_dim)
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(latent_dim + embed_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, input_dim),
            torch.nn.Sigmoid(),
        )

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        Args:
            mu: Mean tensor.
            logvar: Log-variance tensor with the same shape as ``mu``.

        Returns:
            A tensor shaped like ``mu``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, y):
        """Decode latent vectors conditioned on integer class labels.

        Args:
            z: Latent batch of shape (B, latent_dim).
            y: Long tensor of labels, shape (B,). A 0-dim label is broadcast
                to every row of ``z``.

        Returns:
            Tensor of shape (B, input_dim) with values in (0, 1), because the
            decoder ends in a Sigmoid.
        """
        if y.dim() == 0:
            # A single label for the whole batch; cat along dim=1 needs (B, E).
            y = y.expand(z.size(0))
        y_embed = self.label_embed(y)
        inputs = torch.cat([z, y_embed], dim=1)
        return self.decoder(inputs)
# Load the trained weights once at startup and put the model in eval mode.
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot run arbitrary code (supported since torch 1.13).
MODEL_PATH = "cvae_mnist.pth"
model = CVAE()
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu", weights_only=True))
model.eval()
# Image generation function
def generate_digit_images(digit, num_images=5):
    """Sample grayscale images of ``digit`` from the module-level CVAE decoder.

    Args:
        digit: Requested class, as an int or a numeric string such as ``"7"``
            (the Gradio dropdown passes strings).
        num_images: Number of independent samples. Gradio passes only
            ``digit``, so the default of 5 matches the five output slots.

    Returns:
        A list of ``num_images`` uint8 numpy arrays of shape (28, 28) with
        values in 0..255.

    Raises:
        gr.Error: If ``digit`` is missing, non-numeric, or outside the model's
            label range; Gradio displays the message to the user.
    """
    num_classes = model.label_embed.num_embeddings
    if digit is None or str(digit).strip() == "":
        raise gr.Error(f"Please choose a digit between 0 and {num_classes - 1}.")
    try:
        label = int(digit)
    except (TypeError, ValueError):
        raise gr.Error(f"Invalid digit: {digit!r}") from None
    if not 0 <= label < num_classes:
        raise gr.Error(f"Digit must be between 0 and {num_classes - 1}, got {label}.")

    # Decode all samples in one batch; the latent size comes from the model
    # rather than a hard-coded constant.
    z = torch.randn(num_images, model.latent_dim)
    y = torch.full((num_images,), label, dtype=torch.long)
    with torch.no_grad():
        out = model.decode(z, y)
    pixels = out.view(num_images, 28, 28).numpy()
    # Round (not truncate) so each intensity maps to the nearest gray level.
    return [np.rint(img * 255).astype(np.uint8) for img in pixels]
# Gradio UI: a digit dropdown in, five generated images out.
iface = gr.Interface(
    fn=generate_digit_images,
    inputs=gr.Dropdown(
        choices=[str(i) for i in range(10)],
        # A default avoids submitting None when the user picks nothing.
        value="0",
        label="Choose a digit (0–9)",
    ),
    outputs=[gr.Image(image_mode='L') for _ in range(5)],
    title="Conditional VAE Handwritten Digit Generator",
    description="Generates 5 images of the digit you select (0–9) using a Conditional Variational Autoencoder trained on MNIST.",
)

# Start the server only when run as a script, so importing this module does
# not block inside launch().
if __name__ == "__main__":
    iface.launch()