File size: 1,944 Bytes
4ad8a34
 
 
388fa58
4ad8a34
388fa58
 
 
4ad8a34
388fa58
 
 
4ad8a34
388fa58
4ad8a34
 
388fa58
 
 
4ad8a34
388fa58
4ad8a34
 
 
 
 
 
 
 
 
 
388fa58
 
 
 
4ad8a34
388fa58
4fb7754
4ad8a34
 
388fa58
 
4ad8a34
 
 
388fa58
 
 
 
4ad8a34
 
 
388fa58
4ad8a34
388fa58
 
78eb968
388fa58
 
4ad8a34
 
388fa58
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt

# Conditional VAE definition (same as training)
class CVAE(torch.nn.Module):
    """Conditional VAE over flattened 28x28 digit images.

    The attribute names and layer order must match the training script
    exactly, because this app restores weights with ``load_state_dict``.
    The constructor arguments only name sizes that used to be hard-coded;
    their defaults reproduce the original architecture.

    Args:
        latent_dim: Size of the latent code ``z``.
        num_classes: Number of rows in the label embedding table.
        embed_dim: Width of the label embedding appended to the encoder
            input and to ``z`` in the decoder.
        input_dim: Number of pixels in a flattened image (28 * 28).
        hidden_dim: Width of the hidden layer in encoder and decoder.
    """

    def __init__(self, latent_dim=20, num_classes=10, embed_dim=10,
                 input_dim=28 * 28, hidden_dim=400):
        super().__init__()
        self.latent_dim = latent_dim
        self.label_embed = torch.nn.Embedding(num_classes, embed_dim)

        # Encoder input width = image pixels + label-embedding width
        # (presumably [x, embed(y)] as in training; encode() is not defined
        # here, only the layers needed to load the checkpoint).
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(input_dim + embed_dim, hidden_dim),
            torch.nn.ReLU(),
        )
        self.fc_mu = torch.nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = torch.nn.Linear(hidden_dim, latent_dim)

        # Decoder maps [z, label embedding] to pixel intensities; the final
        # Sigmoid keeps every output in [0, 1].
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(latent_dim + embed_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, input_dim),
            torch.nn.Sigmoid(),
        )

    def reparameterize(self, mu, logvar):
        """Return ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        ``eps`` has the same shape as ``mu``/``logvar`` and is freshly drawn
        from the global torch RNG on every call.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, y):
        """Decode latent codes conditioned on integer labels.

        Args:
            z: Latent tensor of shape (batch, latent_dim).
            y: Integer label tensor of shape (batch,), each in
                ``[0, num_classes)``.

        Returns:
            Tensor of shape (batch, input_dim) with values in [0, 1].
        """
        y_embed = self.label_embed(y)
        inputs = torch.cat([z, y_embed], dim=1)
        return self.decoder(inputs)

# Build the network once at import time and restore the trained weights on CPU.
# torch.load is pickle-based; weights_only=True restricts it to tensors and
# plain containers (all a state_dict needs), so a tampered checkpoint cannot
# run arbitrary code. Requires PyTorch >= 1.13 (it is the default from 2.6).
model = CVAE()
model.load_state_dict(torch.load("cvae_mnist.pth", map_location='cpu', weights_only=True))
# Inference mode; this network has no dropout/batch-norm, so it is conventional
# rather than output-changing.
model.eval()

# Image generation function
def generate_digit_images(digit, num_images=5, *, cvae=None):
    """Sample grayscale digit images from the conditional VAE.

    Args:
        digit: Label to condition on; anything ``int()`` accepts (the Gradio
            dropdown passes the strings ``"0"`` .. ``"9"``).
        num_images: Number of independent samples to draw. Defaults to 5,
            matching the five image outputs of the Gradio interface.
        cvae: Model to sample from; defaults to the module-level ``model``.
            Must expose ``latent_dim`` and ``decode(z, y)``.

    Returns:
        A list of ``num_images`` uint8 arrays of shape (28, 28) with pixel
        values in 0..255.

    Raises:
        ValueError: If no digit was chosen, or it is not an integer in 0..9.
    """
    if cvae is None:
        cvae = model
    # An untouched Gradio dropdown submits None; int(None) would raise a
    # confusing TypeError, and labels >= 10 would overflow the embedding table.
    if digit is None or digit == "":
        raise ValueError("Please choose a digit between 0 and 9.")
    label = int(digit)
    if not 0 <= label <= 9:
        raise ValueError(f"Digit must be between 0 and 9, got {label}.")

    # One batched decode instead of num_images single-sample calls; the latent
    # size comes from the model so it cannot drift from the checkpoint.
    z = torch.randn(num_images, cvae.latent_dim)
    y = torch.full((num_images,), label, dtype=torch.long)
    with torch.no_grad():
        out = cvae.decode(z, y)

    # Decoder output is in [0, 1]; round (instead of truncating) to 8-bit and
    # clip as a guard in case a different decoder is supplied.
    pixels = out.reshape(num_images, 28, 28).numpy() * 255
    images = np.clip(np.rint(pixels), 0, 255).astype(np.uint8)
    return list(images)

# Launch Gradio app
iface = gr.Interface(
    fn=generate_digit_images,
    # Preselect "0" so pressing Submit without touching the dropdown does not
    # pass None to the generator.
    inputs=gr.Dropdown(
        choices=[str(i) for i in range(10)],
        value="0",
        label="Choose a digit (0–9)",
    ),
    # One grayscale image component per sample returned by the generator.
    outputs=[gr.Image(image_mode='L', label=f"Sample {i + 1}") for i in range(5)],
    title="Conditional VAE Handwritten Digit Generator",
    description="Generates 5 images of the digit you select (0–9) using a Conditional Variational Autoencoder trained on MNIST."
)

# Start the web server only when this file is run as a script, so importing the
# module (e.g. from tests) does not block inside launch().
if __name__ == "__main__":
    iface.launch()