srijaydeshpande commited on
Commit
5ab9057
·
verified ·
1 Parent(s): 16b73bf

Upload vq_vae.py

Browse files
Files changed (1) hide show
  1. vq_vae.py +440 -0
vq_vae.py ADDED
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ from scipy.signal import savgol_filter
4
+ import os, cv2
5
+ import imageio, glob
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.utils.data import DataLoader
10
+ import torch.optim as optim
11
+ import torchvision.datasets as datasets
12
+ import torchvision.transforms as transforms
13
+ from torchvision.utils import make_grid, save_image
14
+ from gan_losses import get_gan_losses
15
+ from PIL import Image
16
+ import torchvision.utils as vutils
17
+
18
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+
20
+ """## Load Data"""
21
+
22
+ # data_variance = np.var(training_data.data / 255.0)
23
+ data_variance = 1
24
+
25
+ def mkdir(dir):
26
+ if not os.path.exists(dir):
27
+ os.makedirs(dir)
28
+
29
+ def read_image(img_path):
30
+ img = cv2.imread(img_path)
31
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
32
+ img = img / 255.0
33
+ return img
34
+
35
+ class VectorQuantizer(nn.Module):
36
+ def __init__(self, num_embeddings, embedding_dim, commitment_cost):
37
+ super(VectorQuantizer, self).__init__()
38
+
39
+ self._embedding_dim = embedding_dim
40
+ self._num_embeddings = num_embeddings
41
+
42
+ #codebook
43
+ self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
44
+ self._embedding.weight.data.uniform_(-1/self._num_embeddings, 1/self._num_embeddings)
45
+ self._commitment_cost = commitment_cost
46
+
47
+ def forward(self, inputs):
48
+ # convert inputs from BCHW -> BHWC
49
+ inputs = inputs.permute(0, 2, 3, 1).contiguous()
50
+ input_shape = inputs.shape
51
+
52
+ # Flatten input
53
+ flat_input = inputs.view(-1, self._embedding_dim)
54
+
55
+ # Calculate distances
56
+ distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
57
+ + torch.sum(self._embedding.weight**2, dim=1)
58
+ - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
59
+
60
+ # Encoding
61
+ encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
62
+ encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
63
+ encodings.scatter_(1, encoding_indices, 1)
64
+
65
+
66
+ # Quantize and unflatten
67
+ quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
68
+
69
+ # Loss
70
+ e_latent_loss = F.mse_loss(quantized.detach(), inputs)
71
+ q_latent_loss = F.mse_loss(quantized, inputs.detach())
72
+ loss = q_latent_loss + self._commitment_cost * e_latent_loss
73
+
74
+ quantized = inputs + (quantized - inputs).detach()
75
+ avg_probs = torch.mean(encodings, dim=0)
76
+ perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
77
+
78
+ # convert quantized from BHWC -> BCHW
79
+ return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encoding_indices
80
+
81
+ class VectorQuantizerEMA(nn.Module):
82
+ def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
83
+ super(VectorQuantizerEMA, self).__init__()
84
+
85
+ self._embedding_dim = embedding_dim
86
+ self._num_embeddings = num_embeddings
87
+
88
+ self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
89
+ self._embedding.weight.data.normal_()
90
+ self._commitment_cost = commitment_cost
91
+
92
+ self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
93
+ self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
94
+ self._ema_w.data.normal_()
95
+
96
+ self._decay = decay
97
+ self._epsilon = epsilon
98
+
99
+ def forward(self, inputs):
100
+
101
+ # convert inputs from BCHW -> BHWC
102
+ inputs = inputs.permute(0, 2, 3, 1).contiguous()
103
+ input_shape = inputs.shape
104
+
105
+ # Flatten input
106
+ flat_input = inputs.view(-1, self._embedding_dim)
107
+
108
+ # Calculate distances
109
+ distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
110
+ + torch.sum(self._embedding.weight**2, dim=1)
111
+ - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
112
+
113
+ # Encoding
114
+ encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
115
+ # encoding_indices[encoding_indices == 3] = 4 # 1 means background, 2 means epithelial cells, 4 means connective, 3 means neutrophil, 5 means plasma, 6 lymphocytes
116
+ encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
117
+ encodings.scatter_(1, encoding_indices, 1)
118
+
119
+ # Quantize and unflatten
120
+ quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
121
+
122
+ # Use EMA to update the embedding vectors
123
+ if self.training:
124
+ self._ema_cluster_size = self._ema_cluster_size * self._decay + \
125
+ (1 - self._decay) * torch.sum(encodings, 0)
126
+
127
+ # Laplace smoothing of the cluster size
128
+ n = torch.sum(self._ema_cluster_size.data)
129
+ self._ema_cluster_size = (
130
+ (self._ema_cluster_size + self._epsilon)
131
+ / (n + self._num_embeddings * self._epsilon) * n)
132
+
133
+ dw = torch.matmul(encodings.t(), flat_input)
134
+ self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)
135
+
136
+ self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))
137
+
138
+ # Loss
139
+ e_latent_loss = F.mse_loss(quantized.detach(), inputs)
140
+ loss = self._commitment_cost * e_latent_loss
141
+
142
+ # Straight Through Estimator
143
+ quantized = inputs + (quantized - inputs).detach()
144
+ avg_probs = torch.mean(encodings, dim=0)
145
+ perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
146
+
147
+ # convert quantized from BHWC -> BCHW
148
+ return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encoding_indices
149
+
150
+ class Residual(nn.Module):
151
+ def __init__(self, in_channels, num_hiddens, num_residual_hiddens):
152
+ super(Residual, self).__init__()
153
+ self._block = nn.Sequential(
154
+ nn.ReLU(True),
155
+ nn.Conv2d(in_channels=in_channels,
156
+ out_channels=num_residual_hiddens,
157
+ kernel_size=3, stride=1, padding=1, bias=False),
158
+ nn.ReLU(True),
159
+ nn.Conv2d(in_channels=num_residual_hiddens,
160
+ out_channels=num_hiddens,
161
+ kernel_size=1, stride=1, bias=False)
162
+ )
163
+
164
+ def forward(self, x):
165
+ return x + self._block(x)
166
+
167
+ class ResidualStack(nn.Module):
168
+ def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
169
+ super(ResidualStack, self).__init__()
170
+ self._num_residual_layers = num_residual_layers
171
+ self._layers = nn.ModuleList([Residual(in_channels, num_hiddens, num_residual_hiddens)
172
+ for _ in range(self._num_residual_layers)])
173
+
174
+ def forward(self, x):
175
+ for i in range(self._num_residual_layers):
176
+ x = self._layers[i](x)
177
+ return F.relu(x)
178
+
179
+ class Encoder(nn.Module):
180
+
181
+ def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens, embedding_dim):
182
+ super(Encoder, self).__init__()
183
+
184
+ self._conv_1 = nn.Conv2d(in_channels=in_channels,
185
+ out_channels=num_hiddens//2,
186
+ kernel_size=4,
187
+ stride=2, padding=1)
188
+ self._conv_2 = nn.Conv2d(in_channels=num_hiddens//2,
189
+ out_channels=num_hiddens,
190
+ kernel_size=4,
191
+ stride=2, padding=1)
192
+ self._conv_3 = nn.Conv2d(in_channels=num_hiddens,
193
+ out_channels=num_hiddens,
194
+ kernel_size=3,
195
+ stride=1, padding=1)
196
+ self._residual_stack = ResidualStack(in_channels=num_hiddens,
197
+ num_hiddens=num_hiddens,
198
+ num_residual_layers=num_residual_layers,
199
+ num_residual_hiddens=num_residual_hiddens)
200
+
201
+ self._pre_vq_conv = nn.Conv2d(in_channels=num_hiddens,
202
+ out_channels=embedding_dim,
203
+ kernel_size=1,
204
+ stride=1)
205
+
206
+ self.apply_tanh = nn.Tanh()
207
+
208
+ def forward(self, inputs):
209
+
210
+ x = self._conv_1(inputs)
211
+ x = F.relu(x)
212
+
213
+ x = self._conv_2(x)
214
+ x = F.relu(x)
215
+
216
+ x = self._conv_3(x)
217
+
218
+ x = self._residual_stack(x)
219
+
220
+ x = self._pre_vq_conv(x)
221
+
222
+ return x
223
+
224
+ class Decoder(nn.Module):
225
+ def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
226
+ super(Decoder, self).__init__()
227
+
228
+ self._conv_1 = nn.Conv2d(in_channels=in_channels,
229
+ out_channels=num_hiddens,
230
+ kernel_size=3,
231
+ stride=1, padding=1)
232
+
233
+ self._residual_stack = ResidualStack(in_channels=num_hiddens,
234
+ num_hiddens=num_hiddens,
235
+ num_residual_layers=num_residual_layers,
236
+ num_residual_hiddens=num_residual_hiddens)
237
+
238
+ self._conv_trans_1 = nn.ConvTranspose2d(in_channels=num_hiddens,
239
+ out_channels=num_hiddens//2,
240
+ kernel_size=4,
241
+ stride=2, padding=1)
242
+
243
+ self._conv_trans_2 = nn.ConvTranspose2d(in_channels=num_hiddens//2,
244
+ out_channels=3,
245
+ kernel_size=4,
246
+ stride=2, padding=1)
247
+
248
+ self.apply_tanh = nn.Tanh()
249
+
250
+ def forward(self, inputs):
251
+ x = self._conv_1(inputs)
252
+
253
+ x = self._residual_stack(x)
254
+
255
+ x = self._conv_trans_1(x)
256
+ x = F.relu(x)
257
+
258
+ x = self._conv_trans_2(x)
259
+
260
+ return self.apply_tanh(x)
261
+
262
+ class VQModel(nn.Module):
263
+
264
+ def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
265
+ num_embeddings, embedding_dim, commitment_cost, decay=0):
266
+ super(VQModel, self).__init__()
267
+
268
+ self._encoder = Encoder(3, num_hiddens,
269
+ num_residual_layers,
270
+ num_residual_hiddens,
271
+ embedding_dim)
272
+
273
+ if decay > 0.0:
274
+ self._vq_vae = VectorQuantizerEMA(num_embeddings, embedding_dim,
275
+ commitment_cost, decay)
276
+ else:
277
+ self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim,
278
+ commitment_cost)
279
+ self._decoder = Decoder(embedding_dim,
280
+ num_hiddens,
281
+ num_residual_layers,
282
+ num_residual_hiddens)
283
+
284
+ def forward(self, x):
285
+ z = self._encoder(x)
286
+ loss, quantized, perplexity, _ = self._vq_vae(z)
287
+ x_recon = self._decoder(quantized)
288
+
289
+ return loss, x_recon, perplexity
290
+
291
+ def save_generated_images(image_names, batch_images, ind, mode, type):
292
+ current_output_dir = os.path.join(output_dir, mode, type)
293
+ mkdir(current_output_dir)
294
+ num_images = batch_images.shape[0]
295
+ for i in range(0,num_images):
296
+ save_image(batch_images[i], os.path.join(current_output_dir,image_names[i]))
297
+
298
+ def generate_images_from_diffusion_latents(model, latents_path, output_dir):
299
+ latent_paths = glob.glob(os.path.join(latents_path, "*.pt"))
300
+ for latent_path in latent_paths:
301
+ latent = torch.load(latent_path).cuda()
302
+ latent = latent.detach()
303
+ _, quantized_latent, _, _ = model._vq_vae(latent)
304
+ image = model._decoder(quantized_latent)
305
+ image_name = os.path.basename(latent_path).split(".")[0]+".png"
306
+ save_image(image, os.path.join(output_dir, image_name))
307
+
308
+ class UNetDown(nn.Module):
309
+ def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
310
+ super(UNetDown, self).__init__()
311
+ layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]
312
+ if normalize:
313
+ layers.append(nn.InstanceNorm2d(out_size))
314
+ layers.append(nn.LeakyReLU(0.2))
315
+ if dropout:
316
+ layers.append(nn.Dropout(dropout))
317
+ self.model = nn.Sequential(*layers)
318
+
319
+ def forward(self, x):
320
+ return self.model(x)
321
+
322
+ class UNetUp(nn.Module):
323
+ def __init__(self, in_size, out_size, dropout=0.0):
324
+ super(UNetUp, self).__init__()
325
+ layers = [
326
+ nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
327
+ nn.InstanceNorm2d(out_size),
328
+ nn.ReLU(inplace=True),
329
+ ]
330
+ if dropout:
331
+ layers.append(nn.Dropout(dropout))
332
+
333
+ self.model = nn.Sequential(*layers)
334
+
335
+ def forward(self, x, skip_input):
336
+ x = self.model(x)
337
+ x = torch.cat((x, skip_input), 1)
338
+
339
+ return x
340
+
341
+ class Pix2PixGenerator(nn.Module):
342
+ def __init__(self, in_channels=3, out_channels=3):
343
+ super(Pix2PixGenerator, self).__init__()
344
+
345
+ self.down1 = UNetDown(in_channels, 64, normalize=False)
346
+ self.down2 = UNetDown(64, 128)
347
+ self.down3 = UNetDown(128, 256)
348
+ self.down4 = UNetDown(256, 512, dropout=0.5)
349
+ self.down5 = UNetDown(512, 512, dropout=0.5)
350
+ self.down6 = UNetDown(512, 512, dropout=0.5)
351
+ self.down7 = UNetDown(512, 512, dropout=0.5)
352
+ self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5)
353
+
354
+ self.up1 = UNetUp(512, 512, dropout=0.5)
355
+ self.up2 = UNetUp(1024, 512, dropout=0.5)
356
+ self.up3 = UNetUp(1024, 512, dropout=0.5)
357
+ self.up4 = UNetUp(1024, 512, dropout=0.5)
358
+ self.up5 = UNetUp(1024, 256)
359
+ self.up6 = UNetUp(512, 128)
360
+ self.up7 = UNetUp(256, 64)
361
+
362
+ self.final = nn.Sequential(
363
+ nn.Upsample(scale_factor=2),
364
+ nn.ZeroPad2d((1, 0, 1, 0)),
365
+ nn.Conv2d(128, out_channels, 4, padding=1),
366
+ nn.Tanh(),
367
+ )
368
+
369
+ # self.down1 = UNetDown(in_channels, 16, normalize=False)
370
+ # self.down2 = UNetDown(16, 32)
371
+ # self.down3 = UNetDown(32, 64)
372
+ # self.down4 = UNetDown(64, 128, dropout=0.5)
373
+ # self.down5 = UNetDown(128, 256, dropout=0.5)
374
+ # self.down6 = UNetDown(256, 512, dropout=0.5)
375
+ # self.down7 = UNetDown(512, 512, dropout=0.5)
376
+ # self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5)
377
+ #
378
+ # self.up1 = UNetUp(512, 512, dropout=0.5)
379
+ # self.up2 = UNetUp(1024, 512, dropout=0.5)
380
+ # self.up3 = UNetUp(1024, 256, dropout=0.5)
381
+ # self.up4 = UNetUp(512, 128, dropout=0.5)
382
+ # self.up5 = UNetUp(256, 64)
383
+ # self.up6 = UNetUp(128, 32)
384
+ # self.up7 = UNetUp(64, 16)
385
+ #
386
+ # self.final = nn.Sequential(
387
+ # nn.Upsample(scale_factor=2),
388
+ # nn.ZeroPad2d((1, 0, 1, 0)),
389
+ # nn.Conv2d(32, out_channels, 4, padding=1),
390
+ # nn.Tanh(),
391
+ # )
392
+
393
+
394
+ def forward(self, x):
395
+ # U-Net generator with skip connections from encoder to decoder
396
+ d1 = self.down1(x)
397
+ d2 = self.down2(d1)
398
+ d3 = self.down3(d2)
399
+ d4 = self.down4(d3)
400
+ d5 = self.down5(d4)
401
+ d6 = self.down6(d5)
402
+ d7 = self.down7(d6)
403
+ d8 = self.down8(d7)
404
+ u1 = self.up1(d8, d7)
405
+ u2 = self.up2(u1, d6)
406
+ u3 = self.up3(u2, d5)
407
+ u4 = self.up4(u3, d4)
408
+ u5 = self.up5(u4, d3)
409
+ u6 = self.up6(u5, d2)
410
+ u7 = self.up7(u6, d1)
411
+ return self.final(u7)
412
+
413
+ batch_size = 32 #Keep 16 for good results
414
+ num_training_updates = 30000
415
+
416
+ num_hiddens = 32 #Original: 128 , 32 used for masks
417
+ num_residual_hiddens = 32
418
+ num_residual_layers = 2 #Original was 2
419
+
420
+ embedding_dim = 3
421
+ num_embeddings = 2 #number of codebook vectors
422
+ commitment_cost = 0.25
423
+ decay = 0.99
424
+
425
+ model_name = "dp_bimask_2dim_1024size_tanhindecoder.pt"
426
+
427
+ def create_mask(model_dir, latents_path, final_output_dir):
428
+
429
+ model = VQModel(num_hiddens, num_residual_layers, num_residual_hiddens,
430
+ num_embeddings, embedding_dim,
431
+ commitment_cost, decay).to(device)
432
+
433
+ model.load_state_dict(torch.load(os.path.join(model_dir,model_name)))
434
+
435
+ model.eval()
436
+
437
+ mkdir(final_output_dir)
438
+ generate_images_from_diffusion_latents(model=model,
439
+ latents_path=latents_path,
440
+ output_dir=final_output_dir)