ankushthakurr09 committed
Commit 837c9f7 · verified · 1 Parent(s): de6c084

Create app.py

Files changed (1)
  1. app.py +133 -0
app.py ADDED
@@ -0,0 +1,133 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import gradio as gr
+ import numpy as np
+ from PIL import Image
+ import os
+ IMAGE_SIZE = 32  # 32x32 RGB inputs/outputs
+ CHANNELS = 3
+ BATCH_SIZE = 128  # training-time setting; unused by the demo itself
+ NUM_FEATURES = 128  # channel width of every conv / conv-transpose layer
+ Z_DIM = 200  # latent dimensionality
+ LEARNING_RATE = 0.0005  # training-time setting; unused by the demo itself
+ EPOCHS = 30  # training-time setting; unused by the demo itself
+ BETA = 2000  # weight on the reconstruction term of the VAE loss
+ LOAD_MODEL = False
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ def reparameterize(mu, log_var):
+     std = torch.exp(0.5 * log_var)  # sigma recovered from the log-variance
+     epsilon = torch.randn_like(std).to(mu.device)
+     return mu + std * epsilon  # reparameterization trick: z = mu + sigma * epsilon
+
+ class Encoder(nn.Module):
+     def __init__(self, image_size, channels, num_features, z_dim):
+         super(Encoder, self).__init__()
+         self.output_size = image_size // (2**4)  # spatial size after four stride-2 convs (32 -> 2)
+
+         self.main = nn.Sequential(
+             nn.Conv2d(channels, num_features, kernel_size=3, stride=2, padding=1, bias=False),
+             nn.BatchNorm2d(num_features),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.Conv2d(num_features, num_features, kernel_size=3, stride=2, padding=1, bias=False),
+             nn.BatchNorm2d(num_features),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.Conv2d(num_features, num_features, kernel_size=3, stride=2, padding=1, bias=False),
+             nn.BatchNorm2d(num_features),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.Conv2d(num_features, num_features, kernel_size=3, stride=2, padding=1, bias=False),
+             nn.BatchNorm2d(num_features),
+             nn.LeakyReLU(0.2, inplace=True),
+         )
+         self.flatten_size = num_features * self.output_size * self.output_size
+         self.fc_mu = nn.Linear(self.flatten_size, z_dim)
+         self.fc_log_var = nn.Linear(self.flatten_size, z_dim)
+
+     def forward(self, x):
+         x = self.main(x)
+         x = torch.flatten(x, 1)
+         mu = self.fc_mu(x)
+         log_var = self.fc_log_var(x)
+         z = reparameterize(mu, log_var)
+
+         return mu, log_var, z
+ class Decoder(nn.Module):
+     def __init__(self, image_size, channels, num_features, z_dim):
+         super(Decoder, self).__init__()
+
+         self.input_size = image_size // (2**4)
+         self.num_features = num_features
+         self.fc = nn.Linear(z_dim, num_features * self.input_size * self.input_size)
+
+         self.main = nn.Sequential(
+             nn.BatchNorm2d(num_features),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.ConvTranspose2d(num_features, num_features, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
+             nn.BatchNorm2d(num_features),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.ConvTranspose2d(num_features, num_features, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
+             nn.BatchNorm2d(num_features),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.ConvTranspose2d(num_features, num_features, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
+             nn.BatchNorm2d(num_features),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.ConvTranspose2d(num_features, num_features, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
+             nn.BatchNorm2d(num_features),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.ConvTranspose2d(num_features, channels, kernel_size=3, stride=1, padding=1, bias=False),
+             nn.Sigmoid()
+         )
+
+     def forward(self, z):
+         x = self.fc(z)
+         x = x.view(-1, self.num_features, self.input_size, self.input_size)
+         x = self.main(x)
+         return x
+ class VAE(nn.Module):
+     def __init__(self, encoder, decoder, beta):
+         super(VAE, self).__init__()
+         self.encoder = encoder
+         self.decoder = decoder
+         self.beta = beta
+
+     def forward(self, x):
+         mu, log_var, z = self.encoder(x)
+         reconstruction = self.decoder(z)
+         return reconstruction, mu, log_var
+
+     def vae_loss(self, recon_x, x, mu, log_var):
+         recon_loss = self.beta * F.mse_loss(recon_x, x, reduction='sum') / x.size(0)  # beta-weighted reconstruction term, per sample
+         kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) / x.size(0)  # KL divergence to the standard normal prior, per sample
+         total_loss = recon_loss + kl_loss
+         return total_loss, recon_loss, kl_loss
+ encoder = Encoder(IMAGE_SIZE, CHANNELS, NUM_FEATURES, Z_DIM).to(device)
+ decoder = Decoder(IMAGE_SIZE, CHANNELS, NUM_FEATURES, Z_DIM).to(device)
+ model = VAE(encoder, decoder, BETA).to(device)
+ model_weights_path = 'vae_final.pth'
+ if os.path.exists(model_weights_path):
+     try:
+         model.load_state_dict(torch.load(model_weights_path, map_location=device))
+         model.eval()  # Set to evaluation mode
+         print("Model weights loaded successfully.")
+     except Exception as e:
+         print(f"Error loading model weights: {e}")
+         model = None
+ else:
+     print(f"Error: Model weights file not found at: {model_weights_path}")
+     model = None
+ def generate_image():
+     if model is None:
+         return "Error: Model not loaded. Please ensure 'vae_final.pth' is available."
+
+     with torch.no_grad():
+         z = torch.randn(1, Z_DIM).to(device)
+         generated_image = model.decoder(z).squeeze(0)
+         generated_image = generated_image.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()  # [0, 1] CHW float -> HWC uint8
+         pil_image = Image.fromarray(generated_image, 'RGB')
+         return pil_image
+ if model is not None:
+     # Create Gradio interface
+     iface = gr.Interface(fn=generate_image, inputs=None, outputs="image")
+     iface.launch(debug=True)
+ else:
+     print("Cannot launch Gradio interface because the model was not loaded.")