erixest committed on
Commit
dbb5961
·
verified ·
1 Parent(s): c777814

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -0
app.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import gradio as gr
6
+ from PIL import Image
7
+ import os
8
+
9
+ # Check device
10
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
+
12
class ConditionalVAE(nn.Module):
    """Conditional VAE over flattened 28x28 MNIST digits.

    The encoder maps (flattened image, one-hot label) to the parameters of a
    Gaussian posterior in latent space; the decoder maps (latent sample,
    one-hot label) back to pixel probabilities in [0, 1].

    NOTE: the ``fc*`` attribute names are part of the saved state-dict layout
    (``mnist_cvae_model.pth``) and must not be renamed.
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20, num_classes=10):
        super(ConditionalVAE, self).__init__()

        # Encoder: consumes the image concatenated with its label.
        self.fc1 = nn.Linear(input_dim + num_classes, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # posterior mean head
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # posterior log-variance head

        # Decoder: consumes a latent code concatenated with the same label.
        self.fc3 = nn.Linear(latent_dim + num_classes, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

        self.latent_dim = latent_dim
        self.num_classes = num_classes

    def encode(self, x, y):
        """Return (mu, logvar) of the approximate posterior q(z | x, y)."""
        hidden = F.relu(self.fc1(torch.cat([x, y], dim=1)))
        return self.fc21(hidden), self.fc22(hidden)

    def reparameterize(self, mu, logvar):
        """Draw z = mu + sigma * eps with eps ~ N(0, I) (reparameterization trick)."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z, y):
        """Map a latent sample plus label to sigmoid pixel probabilities."""
        hidden = F.relu(self.fc3(torch.cat([z, y], dim=1)))
        return torch.sigmoid(self.fc4(hidden))

    def forward(self, x, y):
        """Encode, sample, and reconstruct; returns (reconstruction, mu, logvar)."""
        mu, logvar = self.encode(x.view(-1, 784), y)
        return self.decode(self.reparameterize(mu, logvar), y), mu, logvar
47
+
48
# Load model
def load_model():
    """Load the trained CVAE weights once and reuse the instance afterwards.

    Bug fixed: the original decorated this function with ``@st.cache_resource``,
    a Streamlit API, but ``st`` is never imported anywhere in this file — the
    module therefore crashed with ``NameError`` at import time. The caching
    intent is preserved with a simple memo attribute on the function itself.

    Returns:
        ConditionalVAE: the model in eval mode, moved to the global ``device``.

    Raises:
        FileNotFoundError: if ``mnist_cvae_model.pth`` is missing.
    """
    cached = getattr(load_model, "_model", None)
    if cached is not None:
        return cached

    model = ConditionalVAE(input_dim=784, hidden_dim=400, latent_dim=20, num_classes=10)
    # map_location lets a GPU-trained checkpoint load on a CPU-only host.
    model.load_state_dict(torch.load('mnist_cvae_model.pth', map_location=device))
    model = model.to(device)
    model.eval()

    load_model._model = model
    return model
56
+
57
def generate_digits(model, digit, num_samples=5):
    """Sample ``num_samples`` images of ``digit`` from the CVAE decoder.

    Returns:
        numpy.ndarray: uint8 array of shape (num_samples, 28, 28), values 0-255.
    """
    model.eval()
    with torch.no_grad():
        # Build a batch of one-hot labels for the requested digit.
        one_hot = torch.zeros(num_samples, 10).to(device)
        one_hot[:, digit] = 1

        # Draw latent codes from the standard-normal prior and decode them.
        latent = torch.randn(num_samples, model.latent_dim).to(device)
        images = model.decode(latent, one_hot).view(num_samples, 28, 28)

        # Scale [0, 1] sigmoid outputs to 8-bit grayscale.
        return (images.cpu().numpy() * 255).astype(np.uint8)
70
+
71
def generate_digit_images(digit):
    """Generate five 112x112 PIL images of ``digit``; gray placeholders on failure."""
    try:
        model = load_model()
        samples = generate_digits(model, int(digit), num_samples=5)
        # Upscale each 28x28 sample 4x with nearest-neighbour so pixels stay crisp.
        return [
            Image.fromarray(arr, mode='L').resize((112, 112), Image.NEAREST)
            for arr in samples
        ]
    except Exception as e:
        # Deliberate best-effort fallback: log the error and return flat gray
        # tiles so the UI degrades gracefully instead of crashing.
        print(f"Error: {e}")
        placeholder = Image.new('L', (112, 112), color=128)
        return [placeholder] * 5
87
+
88
def generate_and_display(digit):
    """Gradio adapter: fan the five generated images out to five output slots."""
    first, second, third, fourth, fifth = generate_digit_images(digit)
    return first, second, third, fourth, fifth
91
+
92
# Create Gradio interface.  NOTE: statement order inside gr.Blocks defines the
# rendered layout, so the component order below is preserved exactly.
with gr.Blocks(title="MNIST Digit Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🔢 MNIST Handwritten Digit Generator")
    gr.Markdown("Select a digit (0-9) and generate 5 unique handwritten samples using a trained Conditional VAE model.")

    # Input: which digit class to condition the decoder on.
    with gr.Row():
        digit_input = gr.Slider(minimum=0, maximum=9, step=1, value=0, label="Select Digit to Generate")

    generate_btn = gr.Button("🎨 Generate 5 Digit Images", variant="primary", size="lg")

    gr.Markdown("## Generated Images")
    # Five side-by-side output slots, one per generated sample.
    with gr.Row():
        sample_images = [
            gr.Image(label=f"Sample {slot}", width=112, height=112)
            for slot in range(1, 6)
        ]

    generate_btn.click(
        fn=generate_and_display,
        inputs=[digit_input],
        outputs=sample_images,
    )

    with gr.Accordion("📋 Model Information", open=False):
        gr.Markdown("""
        ### Technical Details
        - **Architecture**: Conditional Variational Autoencoder (CVAE)
        - **Dataset**: MNIST (28×28 grayscale images)
        - **Training**: From scratch on Google Colab T4 GPU
        - **Latent Dimension**: 20
        - **Training Epochs**: 15
        - **Loss Function**: BCE + KL Divergence

        The model generates diverse samples by sampling from the learned latent space conditioned on digit labels.
        """)

if __name__ == "__main__":
    demo.launch()