Yash Nagraj committed on
Commit
f0ff580
·
1 Parent(s): 8d85e1f

Make changes

Browse files
Files changed (5) hide show
  1. dataset/celeba.py +6 -2
  2. dataset/dataset.py +142 -0
  3. models/blocks.py +252 -260
  4. models/vqvae.py +101 -98
  5. train_vqvae.py +4 -5
dataset/celeba.py CHANGED
@@ -8,11 +8,12 @@ from PIL import Image
8
 
9
 
10
  class ParquetImageDataset(Dataset):
11
- def __init__(self, parquet_files, transform=None, im_size=256):
12
  self.data = pd.concat([pd.read_parquet(file)
13
  for file in parquet_files], ignore_index=True)
14
  self.transform = transform
15
  self.im_size = im_size
 
16
 
17
  def __len__(self):
18
  return len(self.data)
@@ -27,7 +28,10 @@ class ParquetImageDataset(Dataset):
27
  ])(image)
28
  image.close()
29
  im_tensor = (2 * im_tensor) - 1 # type: ignore
30
- return im_tensor, caption
 
 
 
31
 
32
 
33
  def create_dataloader(parquet_dir, batch_size=32, shuffle=True, num_workers=4):
 
8
 
9
 
10
  class ParquetImageDataset(Dataset):
11
+ def __init__(self, parquet_files, transform=None, im_size=256,condition_config=None):
12
  self.data = pd.concat([pd.read_parquet(file)
13
  for file in parquet_files], ignore_index=True)
14
  self.transform = transform
15
  self.im_size = im_size
16
+ self.condition_types = [] if condition_config is None else condition_config['condition_types']
17
 
18
  def __len__(self):
19
  return len(self.data)
 
28
  ])(image)
29
  image.close()
30
  im_tensor = (2 * im_tensor) - 1 # type: ignore
31
+ if len(self.condition_types) == 0:
32
+ return im_tensor
33
+ else:
34
+ return im_tensor, caption
35
 
36
 
37
  def create_dataloader(parquet_dir, batch_size=32, shuffle=True, num_workers=4):
dataset/dataset.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ import random
4
+ import torch
5
+ import torchvision
6
+ import numpy as np
7
+ from PIL import Image
8
+ from tqdm import tqdm
9
+ from torch.utils.data.dataset import Dataset
10
+
11
+
12
class CelebDataset(Dataset):
    r"""
    CelebA-HQ dataset. Images are centre cropped, resized to ``im_size`` and
    scaled to the [-1, 1] range by default.

    Can be replaced by any other dataset as long as all the images are under
    one directory (``<im_path>/CelebA-HQ-img``).

    :param split: dataset split name (stored only; no filtering happens here)
    :param im_path: root directory containing 'CelebA-HQ-img' (and, when
        conditioning is enabled, 'celeba-caption' / 'CelebAMask-HQ-mask')
    :param im_size: output spatial size after resize + centre crop
    :param im_channels: number of image channels (stored only)
    :param im_ext: extra image extension to search for; png/jpg/jpeg are
        always included
    :param use_latents: accepted for interface compatibility.
        NOTE(review): latents are never loaded in this class --
        ``self.use_latents`` stays False unless it and ``latent_maps`` are
        populated externally after construction. Confirm whether a
        latent-loading step was intended here.
    :param latent_path: see ``use_latents``
    :param condition_config: dict with key 'condition_types'
        (e.g. ['text', 'image']) plus 'image_condition_config' when 'image'
        conditioning is used
    """

    def __init__(self, split, im_path, im_size=256, im_channels=3, im_ext='jpg',
                 use_latents=False, latent_path=None, condition_config=None):
        self.split = split
        self.im_size = im_size
        self.im_channels = im_channels
        self.im_ext = im_ext
        self.im_path = im_path
        self.latent_maps = None
        self.latent_path = latent_path
        # NOTE(review): the use_latents argument is currently ignored; the
        # latent branch in __getitem__ only activates if use_latents and
        # latent_maps are set externally. TODO confirm intended behavior.
        self.use_latents = False

        self.condition_types = [] if condition_config is None else condition_config['condition_types']

        self.idx_to_cls_map = {}
        self.cls_to_idx_map = {}

        if 'image' in self.condition_types:
            self.mask_channels = condition_config['image_condition_config']['image_condition_input_channels']
            self.mask_h = condition_config['image_condition_config']['image_condition_h']
            self.mask_w = condition_config['image_condition_config']['image_condition_w']

        self.images, self.texts, self.masks = self.load_images(im_path)

    def load_images(self, im_path):
        r"""
        Gets all images from the path specified and stacks them all up,
        along with caption lists / mask paths when the corresponding
        condition type is enabled.

        :param im_path: dataset root directory
        :return: (image paths, per-image caption lists, mask paths)
        """
        assert os.path.exists(
            im_path), "images path {} does not exist".format(im_path)
        ims = []
        # Fix: the im_ext constructor argument was previously ignored here;
        # include it (deduplicated) alongside the default extensions.
        extensions = ['png', 'jpg', 'jpeg']
        if self.im_ext not in extensions:
            extensions.append(self.im_ext)
        fnames = []
        for ext in extensions:
            fnames += glob.glob(os.path.join(
                im_path, 'CelebA-HQ-img/*.{}'.format(ext)))
        texts = []
        masks = []

        if 'image' in self.condition_types:
            # CelebAMask-HQ semantic classes; mask pixel value k+1 maps to
            # channel k in get_mask (0 is treated as background).
            label_list = ['skin', 'nose', 'eye_g', 'l_eye', 'r_eye', 'l_brow', 'r_brow', 'l_ear', 'r_ear', 'mouth',
                          'u_lip', 'l_lip', 'hair', 'hat', 'ear_r', 'neck_l', 'neck', 'cloth']
            self.idx_to_cls_map = {idx: label_list[idx]
                                   for idx in range(len(label_list))}
            self.cls_to_idx_map = {
                label_list[idx]: idx for idx in range(len(label_list))}

        for fname in tqdm(fnames):
            ims.append(fname)

            if 'text' in self.condition_types:
                im_name = os.path.split(fname)[1].split('.')[0]
                captions_im = []
                with open(os.path.join(im_path, 'celeba-caption/{}.txt'.format(im_name))) as f:
                    for line in f.readlines():
                        captions_im.append(line.strip())
                texts.append(captions_im)

            if 'image' in self.condition_types:
                im_name = int(os.path.split(fname)[1].split('.')[0])
                masks.append(os.path.join(
                    im_path, 'CelebAMask-HQ-mask', '{}.png'.format(im_name)))
        if 'text' in self.condition_types:
            assert len(texts) == len(
                ims), "Condition Type Text but could not find captions for all images"
        if 'image' in self.condition_types:
            assert len(masks) == len(
                ims), "Condition Type Image but could not find masks for all images"
        print('Found {} images'.format(len(ims)))
        print('Found {} masks'.format(len(masks)))
        print('Found {} captions'.format(len(texts)))
        return ims, texts, masks

    def get_mask(self, index):
        r"""
        Method to get the mask of WxH
        for given index and convert it into
        Classes x W x H mask image
        :param index:
        :return: float tensor of shape (mask_channels, mask_h, mask_w)
        """
        mask_im = Image.open(self.masks[index])
        mask_im = np.array(mask_im)
        im_base = np.zeros((self.mask_h, self.mask_w, self.mask_channels))
        # Pixel value (orig_idx + 1) marks class orig_idx; 0 is background.
        for orig_idx in range(len(self.idx_to_cls_map)):
            im_base[mask_im == (orig_idx + 1), orig_idx] = 1
        mask = torch.from_numpy(im_base).permute(2, 0, 1).float()
        return mask

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        ######## Set Conditioning Info ########
        cond_inputs = {}
        if 'text' in self.condition_types:
            # Pick one of the captions for this image at random.
            cond_inputs['text'] = random.sample(self.texts[index], k=1)[0]
        if 'image' in self.condition_types:
            mask = self.get_mask(index)
            cond_inputs['image'] = mask
        #######################################

        if self.use_latents:
            # Only reachable if use_latents/latent_maps were set externally
            # (see __init__ note).
            latent = self.latent_maps[self.images[index]]
            if len(self.condition_types) == 0:
                return latent
            else:
                return latent, cond_inputs
        else:
            im = Image.open(self.images[index])
            im_tensor = torchvision.transforms.Compose([
                torchvision.transforms.Resize(self.im_size),
                torchvision.transforms.CenterCrop(self.im_size),
                torchvision.transforms.ToTensor(),
            ])(im)
            im.close()

            # Convert input to -1 to 1 range.
            im_tensor = (2 * im_tensor) - 1
            if len(self.condition_types) == 0:
                return im_tensor
            else:
                return im_tensor, cond_inputs
models/blocks.py CHANGED
@@ -1,92 +1,99 @@
1
- from re import A
2
  import torch
3
  import torch.nn as nn
4
 
5
 
6
  def get_time_embedding(time_steps, temb_dim):
7
- assert time_steps % 2 == 0, "time embedding dimension must be divisible by 2"
8
-
 
 
 
 
 
 
 
 
9
  factor = 10000 ** ((torch.arange(
10
  start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device) / (temb_dim // 2))
11
  )
12
-
13
  # pos / factor
14
- # time_steps B -> B, 1 -> B, temb_dim
15
  t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor
16
  t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
17
  return t_emb
18
 
19
 
20
  class DownBlock(nn.Module):
 
 
 
 
 
 
21
  """
22
- Down Block that down samples the image, flows like this:
23
- 1) Resnet block with time embedding
24
- 2) Self Attention block
25
- 3) Down Sample
26
- """
27
-
28
- def __init__(self, in_channels, out_channels, t_emd_dim, down_sample, num_heads, num_layers, attn, norm_channels, cross_attn=False,
29
- context_dim=None):
30
  super().__init__()
 
31
  self.down_sample = down_sample
32
- self.cross_attn = cross_attn
33
  self.context_dim = context_dim
34
  self.cross_attn = cross_attn
35
- self.t_emb_dim = t_emd_dim
36
- self.num_layers = num_layers
37
- self.attn = attn
38
- self.resnet_conv_first = nn.ModuleList([
39
- nn.Sequential(
40
- nn.GroupNorm(norm_channels, in_channels if i ==
41
- 0 else out_channels),
42
- nn.SiLU(),
43
- nn.Conv2d(in_channels=in_channels if i == 0 else out_channels,
44
- out_channels=out_channels, kernel_size=3, stride=1, padding=1)
45
-
46
- ) for i in range(num_layers)
47
- ])
48
  if self.t_emb_dim is not None:
49
- self.time_embd_layers = nn.ModuleList([
50
  nn.Sequential(
51
  nn.SiLU(),
52
  nn.Linear(self.t_emb_dim, out_channels)
53
  )
54
  for _ in range(num_layers)
55
  ])
56
-
57
- self.resnet_conv_second = nn.ModuleList([
58
- nn.Sequential(
59
- nn.GroupNorm(norm_channels, out_channels),
60
- nn.SiLU(),
61
- nn.Conv2d(in_channels, out_channels,
62
- kernel_size=3, stride=1, padding=1),
63
- )
64
- for _ in range(num_layers)
65
- ])
66
-
 
67
  if self.attn:
68
  self.attention_norms = nn.ModuleList(
69
  [nn.GroupNorm(norm_channels, out_channels)
70
  for _ in range(num_layers)]
71
  )
72
-
73
- self.attention = nn.ModuleList(
74
- [nn.MultiheadAttention(
75
- out_channels, num_heads=num_heads, batch_first=True) for _ in range(num_layers)]
76
  )
77
-
78
  if self.cross_attn:
79
- assert context_dim is not None, "Context Dimension must be passed to cross attention"
80
- self.cross_attn_norms = nn.ModuleList(
81
  [nn.GroupNorm(norm_channels, out_channels)
82
  for _ in range(num_layers)]
83
  )
84
-
85
- self.cross_attention = nn.ModuleList(
86
- [nn.MultiheadAttention(
87
- out_channels, num_heads=num_heads, batch_first=True) for _ in range(num_layers)]
88
  )
89
-
90
  self.context_proj = nn.ModuleList(
91
  [nn.Linear(context_dim, out_channels)
92
  for _ in range(num_layers)]
@@ -94,177 +101,173 @@ class DownBlock(nn.Module):
94
 
95
  self.residual_input_conv = nn.ModuleList(
96
  [
97
- nn.Conv2d(in_channels=in_channels if i == 0 else out_channels,
98
- out_channels=out_channels, kernel_size=1)
99
  for i in range(num_layers)
100
-
101
  ]
102
  )
103
-
104
- self.resnet_down_conv = nn.Conv2d(out_channels, out_channels,
105
  4, 2, 1) if self.down_sample else nn.Identity()
 
106
  def forward(self, x, t_emb=None, context=None):
107
  out = x
108
  for i in range(self.num_layers):
109
- # Resnet Block
110
  resnet_input = out
111
  out = self.resnet_conv_first[i](out)
112
  if self.t_emb_dim is not None:
113
- out = out + self.time_embd_layers[i](t_emb)[:, :, None, None]
114
  out = self.resnet_conv_second[i](out)
115
  out = out + self.residual_input_conv[i](resnet_input)
116
-
117
- # Self Attention
118
  if self.attn:
 
119
  batch_size, channels, h, w = out.shape
120
- in_attn = out.reshape(batch_size, channels, h*w)
121
  in_attn = self.attention_norms[i](in_attn)
122
  in_attn = in_attn.transpose(1, 2)
123
- out_attn, _ = self.attention[i](in_attn, in_attn, in_attn)
124
- out_attn = out.transpose(1, 2).reshape(
125
- batch_size, channels, h, w)
126
  out = out + out_attn
127
-
128
- # Cross Attention
129
  if self.cross_attn:
130
- assert context is not None, "Context must be given for cross_attn"
131
  batch_size, channels, h, w = out.shape
132
  in_attn = out.reshape(batch_size, channels, h * w)
133
  in_attn = self.cross_attention_norms[i](in_attn)
134
  in_attn = in_attn.transpose(1, 2)
135
  assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
136
  context_proj = self.context_proj[i](context)
137
- out_attn, _ = self.cross_attentions[i](
138
- in_attn, context_proj, context_proj)
139
- out_attn = out_attn.transpose(1, 2).reshape(
140
- batch_size, channels, h, w)
141
  out = out + out_attn
142
-
143
- out = self.resnet_down_conv(out)
 
144
  return out
145
 
146
 
147
  class MidBlock(nn.Module):
 
 
 
 
 
 
148
  """
149
- Mid Block that works with same dimensions, flows like this:
150
- 1) Resnet block with time embedding
151
- 2) Self Attention block
152
- 3) Resnet block with time embedding
153
- """
154
-
155
- def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_dim, cross_attn=None, context_dim=None):
156
  super().__init__()
157
- self.in_channels = in_channels
158
- self.out_channels = out_channels
159
  self.t_emb_dim = t_emb_dim
160
- self.cross_attn = cross_attn
161
  self.context_dim = context_dim
162
- self.num_layers = num_layers
163
- self.resnet_conv_one = nn.ModuleList([
164
- nn.Sequential(
165
- nn.GroupNorm(norm_dim, in_channels if i ==
166
- 0 else out_channels),
167
- nn.SiLU(),
168
- nn.Conv2d(in_channels if i == 0 else out_channels,
169
- out_channels, 3, 1, 1)
170
- )
171
- for i in range(num_layers + 1)
172
- ])
173
-
 
174
  if self.t_emb_dim is not None:
175
- self.time_emb_layers = nn.ModuleList([
176
  nn.Sequential(
177
  nn.SiLU(),
178
  nn.Linear(t_emb_dim, out_channels)
179
  )
180
  for _ in range(num_layers + 1)
181
  ])
182
-
183
- self.resnet_conv_two = nn.ModuleList([
184
- nn.Sequential(
185
- nn.GroupNorm(norm_dim, out_channels),
186
- nn.SiLU(),
187
- nn.Conv2d(out_channels, out_channels, 3, 1, 1)
188
- ) for _ in range(num_layers + 1)
189
- ])
190
-
 
 
191
  self.attention_norms = nn.ModuleList(
192
- [nn.GroupNorm(norm_dim, out_channels) for _ in range(num_layers)]
 
193
  )
194
-
195
- self.attention_heads = nn.ModuleList(
196
  [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
197
  for _ in range(num_layers)]
198
  )
199
-
200
  if self.cross_attn:
201
- assert context_dim is not None, "Context must be given for cross attn"
202
- self.cross_attn_norms = nn.ModuleList(
203
- [nn.GroupNorm(norm_dim, out_channels)
204
  for _ in range(num_layers)]
205
  )
206
-
207
- self.cross_attn = nn.ModuleList(
208
- [nn.MultiheadAttention(
209
- out_channels, num_heads=num_heads, batch_first=True) for _ in range(num_layers)]
210
  )
211
-
212
- self.context_proj = nn.ModuleList([
213
- nn.Conv2d(in_channels if i == 0 else out_channels,
214
- out_channels, kernel_size=1)
 
 
 
215
  for i in range(num_layers + 1)
216
- ])
217
-
218
- self.residual_input_conv = nn.ModuleList([
219
- nn.Conv2d(in_channels if i == 0 else out_channels,
220
- out_channels, kernel_size=1)
221
- for i in range(num_layers + 1)
222
-
223
- ])
224
-
225
  def forward(self, x, t_emb=None, context=None):
226
  out = x
 
 
227
  resnet_input = out
228
- out = self.resnet_conv_one[0](out)
229
  if self.t_emb_dim is not None:
230
- out = out + self.time_emb_layers[0](t_emb)[:, :, None, None]
231
- out = self.resnet_conv_two[0](out)
232
  out = out + self.residual_input_conv[0](resnet_input)
233
-
234
  for i in range(self.num_layers):
 
235
  batch_size, channels, h, w = out.shape
236
- in_attn = out.reshape(batch_size, channels, h*w)
237
  in_attn = self.attention_norms[i](in_attn)
238
  in_attn = in_attn.transpose(1, 2)
239
- out_attn, _ = self.attention_heads[i](in_attn, in_attn, in_attn)
240
- out_attn = out_attn.reshape(batch_size, channels, h, w)
241
  out = out + out_attn
242
-
243
  if self.cross_attn:
244
- assert context is not None, "Context needed when using cross attn"
245
  batch_size, channels, h, w = out.shape
246
- in_attn = out.reshape(batch_size, channels, h*w)
247
- in_attn = self.cross_attn_norms[i](in_attn)
248
  in_attn = in_attn.transpose(1, 2)
249
  assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
250
  context_proj = self.context_proj[i](context)
251
- out_attn, _ = self.cross_attn[i](
252
- in_attn, context_proj, context_proj)
253
- out_attn = out_attn.transpose(1, 2).reshape(
254
- batch_size, channels, h, w)
255
  out = out + out_attn
256
-
 
 
257
  resnet_input = out
258
- out = self.resnet_conv_one[i+1](out)
259
  if self.t_emb_dim is not None:
260
- out = out + self.time_emb_layers[i+1](t_emb)[:, :, None, None]
261
- out = out + self.resnet_conv_two[i+1](out)
262
- out = out + self.residual_input_conv[i+1](resnet_input)
263
-
264
  return out
265
 
266
 
267
- class UpBlockUnet(nn.Module):
268
  r"""
269
  Up conv block with attention.
270
  Sequence of following blocks
@@ -273,20 +276,18 @@ class UpBlockUnet(nn.Module):
273
  2. Resnet block with time embedding
274
  3. Attention Block
275
  """
276
-
277
- def __init__(self, in_channels, out_channels, t_emb_dim, up_sample,
278
- num_heads, num_layers, norm_channels, cross_attn=False, context_dim=None):
279
  super().__init__()
280
  self.num_layers = num_layers
281
  self.up_sample = up_sample
282
  self.t_emb_dim = t_emb_dim
283
- self.cross_attn = cross_attn
284
- self.context_dim = context_dim
285
  self.resnet_conv_first = nn.ModuleList(
286
  [
287
  nn.Sequential(
288
- nn.GroupNorm(norm_channels, in_channels if i ==
289
- 0 else out_channels),
290
  nn.SiLU(),
291
  nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
292
  padding=1),
@@ -294,7 +295,7 @@ class UpBlockUnet(nn.Module):
294
  for i in range(num_layers)
295
  ]
296
  )
297
-
298
  if self.t_emb_dim is not None:
299
  self.t_emb_layers = nn.ModuleList([
300
  nn.Sequential(
@@ -303,104 +304,73 @@ class UpBlockUnet(nn.Module):
303
  )
304
  for _ in range(num_layers)
305
  ])
306
-
307
  self.resnet_conv_second = nn.ModuleList(
308
  [
309
  nn.Sequential(
310
  nn.GroupNorm(norm_channels, out_channels),
311
  nn.SiLU(),
312
- nn.Conv2d(out_channels, out_channels,
313
- kernel_size=3, stride=1, padding=1),
314
  )
315
  for _ in range(num_layers)
316
  ]
317
  )
318
-
319
- self.attention_norms = nn.ModuleList(
320
- [
321
- nn.GroupNorm(norm_channels, out_channels)
322
- for _ in range(num_layers)
323
- ]
324
- )
325
-
326
- self.attentions = nn.ModuleList(
327
- [
328
- nn.MultiheadAttention(
329
- out_channels, num_heads, batch_first=True)
330
- for _ in range(num_layers)
331
- ]
332
- )
333
-
334
- if self.cross_attn:
335
- assert context_dim is not None, "Context Dimension must be passed for cross attention"
336
- self.cross_attention_norms = nn.ModuleList(
337
- [nn.GroupNorm(norm_channels, out_channels)
338
- for _ in range(num_layers)]
339
- )
340
- self.cross_attentions = nn.ModuleList(
341
- [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
342
- for _ in range(num_layers)]
343
  )
344
- self.context_proj = nn.ModuleList(
345
- [nn.Linear(context_dim, out_channels)
346
- for _ in range(num_layers)]
 
 
 
347
  )
 
348
  self.residual_input_conv = nn.ModuleList(
349
  [
350
- nn.Conv2d(in_channels if i == 0 else out_channels,
351
- out_channels, kernel_size=1)
352
  for i in range(num_layers)
353
  ]
354
  )
355
- self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2,
356
  4, 2, 1) \
357
  if self.up_sample else nn.Identity()
358
-
359
- def forward(self, x, out_down=None, t_emb=None, context=None):
 
360
  x = self.up_sample_conv(x)
 
 
361
  if out_down is not None:
362
  x = torch.cat([x, out_down], dim=1)
363
-
364
  out = x
365
  for i in range(self.num_layers):
366
- # Resnet
367
  resnet_input = out
368
  out = self.resnet_conv_first[i](out)
369
  if self.t_emb_dim is not None:
370
  out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
371
  out = self.resnet_conv_second[i](out)
372
  out = out + self.residual_input_conv[i](resnet_input)
 
373
  # Self Attention
374
- batch_size, channels, h, w = out.shape
375
- in_attn = out.reshape(batch_size, channels, h * w)
376
- in_attn = self.attention_norms[i](in_attn)
377
- in_attn = in_attn.transpose(1, 2)
378
- out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
379
- out_attn = out_attn.transpose(1, 2).reshape(
380
- batch_size, channels, h, w)
381
- out = out + out_attn
382
- # Cross Attention
383
- if self.cross_attn:
384
- assert context is not None, "context cannot be None if cross attention layers are used"
385
  batch_size, channels, h, w = out.shape
386
  in_attn = out.reshape(batch_size, channels, h * w)
387
- in_attn = self.cross_attention_norms[i](in_attn)
388
  in_attn = in_attn.transpose(1, 2)
389
- assert len(context.shape) == 3, \
390
- "Context shape does not match B,_,CONTEXT_DIM"
391
- assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim, \
392
- "Context shape does not match B,_,CONTEXT_DIM"
393
- context_proj = self.context_proj[i](context)
394
- out_attn, _ = self.cross_attentions[i](
395
- in_attn, context_proj, context_proj)
396
- out_attn = out_attn.transpose(1, 2).reshape(
397
- batch_size, channels, h, w)
398
  out = out + out_attn
399
-
400
  return out
401
 
402
 
403
- class UpBlock(nn.Module):
404
  r"""
405
  Up conv block with attention.
406
  Sequence of following blocks
@@ -409,19 +379,19 @@ class UpBlock(nn.Module):
409
  2. Resnet block with time embedding
410
  3. Attention Block
411
  """
412
-
413
- def __init__(self, in_channels, out_channels, t_emb_dim,
414
- up_sample, num_heads, num_layers, attn, norm_channels):
415
  super().__init__()
416
  self.num_layers = num_layers
417
  self.up_sample = up_sample
418
  self.t_emb_dim = t_emb_dim
419
- self.attn = attn
 
420
  self.resnet_conv_first = nn.ModuleList(
421
  [
422
  nn.Sequential(
423
- nn.GroupNorm(norm_channels, in_channels if i ==
424
- 0 else out_channels),
425
  nn.SiLU(),
426
  nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
427
  padding=1),
@@ -429,7 +399,7 @@ class UpBlock(nn.Module):
429
  for i in range(num_layers)
430
  ]
431
  )
432
-
433
  if self.t_emb_dim is not None:
434
  self.t_emb_layers = nn.ModuleList([
435
  nn.Sequential(
@@ -438,71 +408,93 @@ class UpBlock(nn.Module):
438
  )
439
  for _ in range(num_layers)
440
  ])
441
-
442
  self.resnet_conv_second = nn.ModuleList(
443
  [
444
  nn.Sequential(
445
  nn.GroupNorm(norm_channels, out_channels),
446
  nn.SiLU(),
447
- nn.Conv2d(out_channels, out_channels,
448
- kernel_size=3, stride=1, padding=1),
449
  )
450
  for _ in range(num_layers)
451
  ]
452
  )
453
- if self.attn:
454
- self.attention_norms = nn.ModuleList(
455
- [
456
- nn.GroupNorm(norm_channels, out_channels)
457
- for _ in range(num_layers)
458
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
459
  )
460
-
461
- self.attentions = nn.ModuleList(
462
- [
463
- nn.MultiheadAttention(
464
- out_channels, num_heads, batch_first=True)
465
- for _ in range(num_layers)
466
- ]
467
  )
468
-
469
  self.residual_input_conv = nn.ModuleList(
470
  [
471
- nn.Conv2d(in_channels if i == 0 else out_channels,
472
- out_channels, kernel_size=1)
473
  for i in range(num_layers)
474
  ]
475
  )
476
- self.up_sample_conv = nn.ConvTranspose2d(in_channels, in_channels,
477
  4, 2, 1) \
478
  if self.up_sample else nn.Identity()
479
-
480
- def forward(self, x, out_down=None, t_emb=None):
481
- # Upsample
482
  x = self.up_sample_conv(x)
483
-
484
- # Concat with Downblock output
485
  if out_down is not None:
486
  x = torch.cat([x, out_down], dim=1)
487
-
488
  out = x
489
  for i in range(self.num_layers):
490
- # Resnet Block
491
  resnet_input = out
492
  out = self.resnet_conv_first[i](out)
493
  if self.t_emb_dim is not None:
494
  out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
495
  out = self.resnet_conv_second[i](out)
496
  out = out + self.residual_input_conv[i](resnet_input)
497
-
498
  # Self Attention
499
- if self.attn:
 
 
 
 
 
 
 
 
 
500
  batch_size, channels, h, w = out.shape
501
  in_attn = out.reshape(batch_size, channels, h * w)
502
- in_attn = self.attention_norms[i](in_attn)
503
  in_attn = in_attn.transpose(1, 2)
504
- out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
505
- out_attn = out_attn.transpose(1, 2).reshape(
506
- batch_size, channels, h, w)
 
 
 
 
507
  out = out + out_attn
 
508
  return out
 
 
 
1
  import torch
2
  import torch.nn as nn
3
 
4
 
5
def get_time_embedding(time_steps, temb_dim):
    r"""
    Convert time steps tensor into an embedding using the
    sinusoidal time embedding formula.

    :param time_steps: 1D tensor of length batch size
    :param temb_dim: Dimension of the embedding
    :return: BxD embedding representation of B time steps
    """
    assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"

    half_dim = temb_dim // 2
    # Frequencies: 10000^(2i/d_model) for i in [0, half_dim)
    freqs = 10000 ** (
        torch.arange(start=0, end=half_dim, dtype=torch.float32,
                     device=time_steps.device) / half_dim
    )

    # Broadcast (B, 1) / (half_dim,) -> (B, half_dim)
    scaled = time_steps[:, None] / freqs
    # Concatenate sin and cos halves -> (B, temb_dim)
    return torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=-1)
25
 
26
 
27
class DownBlock(nn.Module):
    r"""
    Down conv block with attention.
    Sequence of following block
    1. Resnet block with time embedding
    2. Attention block
    3. Downsample
    """

    def __init__(self, in_channels, out_channels, t_emb_dim,
                 down_sample, num_heads, num_layers, attn, norm_channels, cross_attn=False, context_dim=None):
        super().__init__()
        self.num_layers = num_layers
        self.down_sample = down_sample
        self.attn = attn
        self.context_dim = context_dim
        self.cross_attn = cross_attn
        self.t_emb_dim = t_emb_dim

        def _layer_in(idx):
            # First layer consumes in_channels, subsequent ones out_channels.
            return in_channels if idx == 0 else out_channels

        self.resnet_conv_first = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(norm_channels, _layer_in(i)),
                nn.SiLU(),
                nn.Conv2d(_layer_in(i), out_channels,
                          kernel_size=3, stride=1, padding=1),
            )
            for i in range(num_layers)
        ])
        if self.t_emb_dim is not None:
            # Projects the time embedding into a per-channel additive bias.
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, out_channels))
                for _ in range(num_layers)
            ])
        self.resnet_conv_second = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(norm_channels, out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels,
                          kernel_size=3, stride=1, padding=1),
            )
            for _ in range(num_layers)
        ])

        if self.attn:
            self.attention_norms = nn.ModuleList([
                nn.GroupNorm(norm_channels, out_channels)
                for _ in range(num_layers)
            ])
            self.attentions = nn.ModuleList([
                nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                for _ in range(num_layers)
            ])

        if self.cross_attn:
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = nn.ModuleList([
                nn.GroupNorm(norm_channels, out_channels)
                for _ in range(num_layers)
            ])
            self.cross_attentions = nn.ModuleList([
                nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                for _ in range(num_layers)
            ])
            self.context_proj = nn.ModuleList([
                nn.Linear(context_dim, out_channels)
                for _ in range(num_layers)
            ])

        # 1x1 conv so the skip connection matches out_channels.
        self.residual_input_conv = nn.ModuleList([
            nn.Conv2d(_layer_in(i), out_channels, kernel_size=1)
            for i in range(num_layers)
        ])
        self.down_sample_conv = nn.Conv2d(out_channels, out_channels,
                                          4, 2, 1) if self.down_sample else nn.Identity()

    def forward(self, x, t_emb=None, context=None):
        out = x
        for i in range(self.num_layers):
            # Resnet block of Unet
            residual = out
            out = self.resnet_conv_first[i](out)
            if self.t_emb_dim is not None:
                out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[i](out)
            out = out + self.residual_input_conv[i](residual)

            if self.attn:
                # Self attention across the flattened spatial positions
                b, c, h, w = out.shape
                attn_in = self.attention_norms[i](out.reshape(b, c, h * w))
                attn_in = attn_in.transpose(1, 2)
                attn_out, _ = self.attentions[i](attn_in, attn_in, attn_in)
                out = out + attn_out.transpose(1, 2).reshape(b, c, h, w)

            if self.cross_attn:
                # Cross attention against the projected context sequence
                assert context is not None, "context cannot be None if cross attention layers are used"
                b, c, h, w = out.shape
                attn_in = self.cross_attention_norms[i](out.reshape(b, c, h * w))
                attn_in = attn_in.transpose(1, 2)
                assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
                ctx = self.context_proj[i](context)
                attn_out, _ = self.cross_attentions[i](attn_in, ctx, ctx)
                out = out + attn_out.transpose(1, 2).reshape(b, c, h, w)

        # Downsample once after all layers
        out = self.down_sample_conv(out)
        return out
147
 
148
 
149
class MidBlock(nn.Module):
    r"""
    Mid conv block with attention.
    Sequence of following blocks
    1. Resnet block with time embedding
    2. Attention block
    3. Resnet block with time embedding
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_channels, cross_attn=None, context_dim=None):
        super().__init__()
        self.num_layers = num_layers
        self.t_emb_dim = t_emb_dim
        self.context_dim = context_dim
        self.cross_attn = cross_attn

        def _layer_in(idx):
            # First resnet sub-block consumes in_channels, the rest out_channels.
            return in_channels if idx == 0 else out_channels

        # num_layers + 1 resnet sub-blocks: one before the loop, one per layer.
        self.resnet_conv_first = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(norm_channels, _layer_in(i)),
                nn.SiLU(),
                nn.Conv2d(_layer_in(i), out_channels,
                          kernel_size=3, stride=1, padding=1),
            )
            for i in range(num_layers + 1)
        ])

        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
                for _ in range(num_layers + 1)
            ])
        self.resnet_conv_second = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(norm_channels, out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            )
            for _ in range(num_layers + 1)
        ])

        self.attention_norms = nn.ModuleList([
            nn.GroupNorm(norm_channels, out_channels)
            for _ in range(num_layers)
        ])
        self.attentions = nn.ModuleList([
            nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
            for _ in range(num_layers)
        ])

        if self.cross_attn:
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = nn.ModuleList([
                nn.GroupNorm(norm_channels, out_channels)
                for _ in range(num_layers)
            ])
            self.cross_attentions = nn.ModuleList([
                nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                for _ in range(num_layers)
            ])
            self.context_proj = nn.ModuleList([
                nn.Linear(context_dim, out_channels)
                for _ in range(num_layers)
            ])

        # 1x1 convs so skip connections match out_channels.
        self.residual_input_conv = nn.ModuleList([
            nn.Conv2d(_layer_in(i), out_channels, kernel_size=1)
            for i in range(num_layers + 1)
        ])

    def _resnet(self, idx, out, t_emb):
        # One resnet sub-block: norm-act-conv x2 + optional time bias + 1x1 skip.
        residual = out
        out = self.resnet_conv_first[idx](out)
        if self.t_emb_dim is not None:
            out = out + self.t_emb_layers[idx](t_emb)[:, :, None, None]
        out = self.resnet_conv_second[idx](out)
        return out + self.residual_input_conv[idx](residual)

    def forward(self, x, t_emb=None, context=None):
        # First resnet block
        out = self._resnet(0, x, t_emb)

        for i in range(self.num_layers):
            # Self attention over the flattened spatial positions
            b, c, h, w = out.shape
            attn_in = self.attention_norms[i](out.reshape(b, c, h * w))
            attn_in = attn_in.transpose(1, 2)
            attn_out, _ = self.attentions[i](attn_in, attn_in, attn_in)
            out = out + attn_out.transpose(1, 2).reshape(b, c, h, w)

            if self.cross_attn:
                # Cross attention against the projected context sequence
                assert context is not None, "context cannot be None if cross attention layers are used"
                b, c, h, w = out.shape
                attn_in = self.cross_attention_norms[i](out.reshape(b, c, h * w))
                attn_in = attn_in.transpose(1, 2)
                assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
                ctx = self.context_proj[i](context)
                attn_out, _ = self.cross_attentions[i](attn_in, ctx, ctx)
                out = out + attn_out.transpose(1, 2).reshape(b, c, h, w)

            # Resnet block after each attention stage
            out = self._resnet(i + 1, out, t_emb)

        return out
268
 
269
 
270
+ class UpBlock(nn.Module):
271
  r"""
272
  Up conv block with attention.
273
  Sequence of following blocks
 
276
  2. Resnet block with time embedding
277
  3. Attention Block
278
  """
279
+
280
+ def __init__(self, in_channels, out_channels, t_emb_dim,
281
+ up_sample, num_heads, num_layers, attn, norm_channels):
282
  super().__init__()
283
  self.num_layers = num_layers
284
  self.up_sample = up_sample
285
  self.t_emb_dim = t_emb_dim
286
+ self.attn = attn
 
287
  self.resnet_conv_first = nn.ModuleList(
288
  [
289
  nn.Sequential(
290
+ nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
 
291
  nn.SiLU(),
292
  nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
293
  padding=1),
 
295
  for i in range(num_layers)
296
  ]
297
  )
298
+
299
  if self.t_emb_dim is not None:
300
  self.t_emb_layers = nn.ModuleList([
301
  nn.Sequential(
 
304
  )
305
  for _ in range(num_layers)
306
  ])
307
+
308
  self.resnet_conv_second = nn.ModuleList(
309
  [
310
  nn.Sequential(
311
  nn.GroupNorm(norm_channels, out_channels),
312
  nn.SiLU(),
313
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
 
314
  )
315
  for _ in range(num_layers)
316
  ]
317
  )
318
+ if self.attn:
319
+ self.attention_norms = nn.ModuleList(
320
+ [
321
+ nn.GroupNorm(norm_channels, out_channels)
322
+ for _ in range(num_layers)
323
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
  )
325
+
326
+ self.attentions = nn.ModuleList(
327
+ [
328
+ nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
329
+ for _ in range(num_layers)
330
+ ]
331
  )
332
+
333
  self.residual_input_conv = nn.ModuleList(
334
  [
335
+ nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
 
336
  for i in range(num_layers)
337
  ]
338
  )
339
+ self.up_sample_conv = nn.ConvTranspose2d(in_channels, in_channels,
340
  4, 2, 1) \
341
  if self.up_sample else nn.Identity()
342
+
343
+ def forward(self, x, out_down=None, t_emb=None):
344
+ # Upsample
345
  x = self.up_sample_conv(x)
346
+
347
+ # Concat with Downblock output
348
  if out_down is not None:
349
  x = torch.cat([x, out_down], dim=1)
350
+
351
  out = x
352
  for i in range(self.num_layers):
353
+ # Resnet Block
354
  resnet_input = out
355
  out = self.resnet_conv_first[i](out)
356
  if self.t_emb_dim is not None:
357
  out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
358
  out = self.resnet_conv_second[i](out)
359
  out = out + self.residual_input_conv[i](resnet_input)
360
+
361
  # Self Attention
362
+ if self.attn:
 
 
 
 
 
 
 
 
 
 
363
  batch_size, channels, h, w = out.shape
364
  in_attn = out.reshape(batch_size, channels, h * w)
365
+ in_attn = self.attention_norms[i](in_attn)
366
  in_attn = in_attn.transpose(1, 2)
367
+ out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
368
+ out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
 
 
 
 
 
 
 
369
  out = out + out_attn
 
370
  return out
371
 
372
 
373
+ class UpBlockUnet(nn.Module):
374
  r"""
375
  Up conv block with attention.
376
  Sequence of following blocks
 
379
  2. Resnet block with time embedding
380
  3. Attention Block
381
  """
382
+
383
+ def __init__(self, in_channels, out_channels, t_emb_dim, up_sample,
384
+ num_heads, num_layers, norm_channels, cross_attn=False, context_dim=None):
385
  super().__init__()
386
  self.num_layers = num_layers
387
  self.up_sample = up_sample
388
  self.t_emb_dim = t_emb_dim
389
+ self.cross_attn = cross_attn
390
+ self.context_dim = context_dim
391
  self.resnet_conv_first = nn.ModuleList(
392
  [
393
  nn.Sequential(
394
+ nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
 
395
  nn.SiLU(),
396
  nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
397
  padding=1),
 
399
  for i in range(num_layers)
400
  ]
401
  )
402
+
403
  if self.t_emb_dim is not None:
404
  self.t_emb_layers = nn.ModuleList([
405
  nn.Sequential(
 
408
  )
409
  for _ in range(num_layers)
410
  ])
411
+
412
  self.resnet_conv_second = nn.ModuleList(
413
  [
414
  nn.Sequential(
415
  nn.GroupNorm(norm_channels, out_channels),
416
  nn.SiLU(),
417
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
 
418
  )
419
  for _ in range(num_layers)
420
  ]
421
  )
422
+
423
+ self.attention_norms = nn.ModuleList(
424
+ [
425
+ nn.GroupNorm(norm_channels, out_channels)
426
+ for _ in range(num_layers)
427
+ ]
428
+ )
429
+
430
+ self.attentions = nn.ModuleList(
431
+ [
432
+ nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
433
+ for _ in range(num_layers)
434
+ ]
435
+ )
436
+
437
+ if self.cross_attn:
438
+ assert context_dim is not None, "Context Dimension must be passed for cross attention"
439
+ self.cross_attention_norms = nn.ModuleList(
440
+ [nn.GroupNorm(norm_channels, out_channels)
441
+ for _ in range(num_layers)]
442
  )
443
+ self.cross_attentions = nn.ModuleList(
444
+ [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
445
+ for _ in range(num_layers)]
446
+ )
447
+ self.context_proj = nn.ModuleList(
448
+ [nn.Linear(context_dim, out_channels)
449
+ for _ in range(num_layers)]
450
  )
 
451
  self.residual_input_conv = nn.ModuleList(
452
  [
453
+ nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
 
454
  for i in range(num_layers)
455
  ]
456
  )
457
+ self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2,
458
  4, 2, 1) \
459
  if self.up_sample else nn.Identity()
460
+
461
+ def forward(self, x, out_down=None, t_emb=None, context=None):
 
462
  x = self.up_sample_conv(x)
 
 
463
  if out_down is not None:
464
  x = torch.cat([x, out_down], dim=1)
465
+
466
  out = x
467
  for i in range(self.num_layers):
468
+ # Resnet
469
  resnet_input = out
470
  out = self.resnet_conv_first[i](out)
471
  if self.t_emb_dim is not None:
472
  out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
473
  out = self.resnet_conv_second[i](out)
474
  out = out + self.residual_input_conv[i](resnet_input)
 
475
  # Self Attention
476
+ batch_size, channels, h, w = out.shape
477
+ in_attn = out.reshape(batch_size, channels, h * w)
478
+ in_attn = self.attention_norms[i](in_attn)
479
+ in_attn = in_attn.transpose(1, 2)
480
+ out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
481
+ out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
482
+ out = out + out_attn
483
+ # Cross Attention
484
+ if self.cross_attn:
485
+ assert context is not None, "context cannot be None if cross attention layers are used"
486
  batch_size, channels, h, w = out.shape
487
  in_attn = out.reshape(batch_size, channels, h * w)
488
+ in_attn = self.cross_attention_norms[i](in_attn)
489
  in_attn = in_attn.transpose(1, 2)
490
+ assert len(context.shape) == 3, \
491
+ "Context shape does not match B,_,CONTEXT_DIM"
492
+ assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim,\
493
+ "Context shape does not match B,_,CONTEXT_DIM"
494
+ context_proj = self.context_proj[i](context)
495
+ out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)
496
+ out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
497
  out = out + out_attn
498
+
499
  return out
500
+
models/vqvae.py CHANGED
@@ -1,6 +1,6 @@
1
  import torch
2
  import torch.nn as nn
3
- from models.blocks import DownBlock, UpBlock, MidBlock
4
 
5
 
6
  class VQVAE(nn.Module):
@@ -10,122 +10,125 @@ class VQVAE(nn.Module):
10
  self.mid_channels = model_config['mid_channels']
11
  self.down_sample = model_config['down_sample']
12
  self.num_down_layers = model_config['num_down_layers']
13
- self.num_up_layers = model_config['num_up_layers']
14
  self.num_mid_layers = model_config['num_mid_layers']
15
-
16
- # To disable attn in encoder and decoder blocks
17
- self.attns = model_config['attn']
18
-
 
19
  # Latent Dimension
20
- self.z_channels = model_config["z_channels"]
21
- self.codebook_size = model_config["codebook_size"]
22
- self.norm_channels = model_config["norm_channels"]
23
- self.num_heads = model_config["num_heads"]
24
-
 
25
  assert self.mid_channels[0] == self.down_channels[-1]
26
  assert self.mid_channels[-1] == self.down_channels[-1]
27
  assert len(self.down_sample) == len(self.down_channels) - 1
28
  assert len(self.attns) == len(self.down_channels) - 1
29
-
30
- self.upsample = list(reversed(self.down_sample))
31
-
32
- # Encoder
33
- self.encoder_conv_one = nn.Conv2d(
34
- im_channels, self.down_channels[0], kernel_size=3, padding=1, stride=1)
35
-
 
 
36
  self.encoder_layers = nn.ModuleList([])
37
  for i in range(len(self.down_channels) - 1):
38
- self.encoder_layers.append(DownBlock(self.down_channels[i], self.down_channels[i+1],
39
- t_emd_dim=None, down_sample=self.down_sample[i],
40
- num_heads=self.num_heads, num_layers=self.num_down_layers,
41
- attn=self.attns[i], norm_channels=self.norm_channels))
42
- self.encode_mid_blocks = nn.ModuleList([])
43
- for i in range(len(self.down_channels)-1):
44
- self.encode_mid_blocks.append(MidBlock(self.down_channels[i], self.down_channels[i+1],
45
- t_emb_dim=None, num_heads=self.num_heads, num_layers=self.num_mid_layers,
46
- norm_dim=self.norm_channels))
47
- self.encoder_norm_out = nn.GroupNorm(
48
- self.norm_channels, self.down_channels[-1])
49
- self.encoder_conv_out = nn.Conv2d(
50
- self.down_channels[-1], self.z_channels, kernel_size=3, padding=1)
51
-
52
- # Pre-Quantization Convolution (Before comparing to code blocks to get embedding matrix)
53
- self.pre_quant_conv = nn.Conv2d(
54
- self.z_channels, self.z_channels, kernel_size=1)
55
-
56
- # Code book
 
 
 
57
  self.embedding = nn.Embedding(self.codebook_size, self.z_channels)
58
-
59
- # Decoder
60
- self.post_quant_conv = nn.Conv2d(
61
- self.z_channels, self.z_channels, kernel_size=1)
62
- self.decoder_conv_out = nn.Conv2d(
63
- self.z_channels, self.mid_channels[-1], kernel_size=3, padding=1)
64
-
65
- # Midblock + UpBlock
66
- self.decode_mids = nn.ModuleList([])
67
  for i in reversed(range(1, len(self.mid_channels))):
68
- self.decode_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i-1],
69
- t_emb_dim=None, num_heads=self.num_heads,
70
- num_layers=self.num_mid_layers,
71
- norm_dim=self.norm_channels))
 
 
72
  self.decoder_layers = nn.ModuleList([])
73
  for i in reversed(range(1, len(self.down_channels))):
74
- self.decoder_layers.append(UpBlock(self.down_channels[i], self.down_channels[i-1],
75
- t_emb_dim=None, up_sample=self.down_sample[i-1], num_heads=self.num_heads,
76
- num_layers=self.num_up_layers,
77
- attn=self.attns[i-1],
78
- norm_channels=self.norm_channels))
79
-
80
- self.decoder_norm_out = nn.GroupNorm(
81
- self.norm_channels, self.down_channels[0])
82
- self.decoder_conv_out = nn.Conv2d(
83
- self.down_channels[0], im_channels, kernel_size=3, padding=1)
84
-
85
  def quantize(self, x):
86
  B, C, H, W = x.shape
87
-
88
- # B,C,H,W -> B,H,W,C
89
  x = x.permute(0, 2, 3, 1)
90
-
91
- # B,H,W,C -> B, H*W, C
92
  x = x.reshape(x.size(0), -1, x.size(-1))
93
-
94
- # Find nearest neighbours/codebook vectors
95
- # Distance between B,H*W,C and B,K,C
96
- dist = torch.cdist(
97
- x, self.embedding.weight[None, :].repeat((x.size(0), 1, 1)))
98
-
99
  min_encoding_indices = torch.argmin(dist, dim=-1)
100
-
101
- # Replace encoder output with codebook vector
102
- quant_out = torch.index_select(
103
- self.embedding.weight, 0, min_encoding_indices.view(-1))
104
-
105
- # x -> B*H*W,C
106
  x = x.reshape((-1, x.size(-1)))
107
- commitment_loss = torch.mean((quant_out.detach() - x) ** 2)
108
  codebook_loss = torch.mean((quant_out - x.detach()) ** 2)
109
- quantize_loss = {
110
- "codebook_loss": codebook_loss,
111
- "commitment_loss": commitment_loss
112
  }
113
-
114
  # Straight through estimation
115
- quant_out = x - (quant_out - x).detach()
116
-
117
- # quant_out -> B,C,H,W
118
  quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2)
119
- min_encoding_indices = min_encoding_indices.reshape(
120
- (-1, quant_out.size(-2), quant_out.size(-1)))
121
-
122
- return quant_out, quantize_loss, min_encoding_indices
123
 
124
  def encode(self, x):
125
- out = self.encoder_conv_one(x)
126
- for _, down in enumerate(self.encoder_layers):
127
  out = down(out)
128
- for mid in self.encode_mid_blocks:
129
  out = mid(out)
130
  out = self.encoder_norm_out(out)
131
  out = nn.SiLU()(out)
@@ -133,21 +136,21 @@ class VQVAE(nn.Module):
133
  out = self.pre_quant_conv(out)
134
  out, quant_losses, _ = self.quantize(out)
135
  return out, quant_losses
136
-
137
  def decode(self, z):
138
  out = z
139
  out = self.post_quant_conv(out)
140
  out = self.decoder_conv_in(out)
141
- for mid in self.decode_mids:
142
  out = mid(out)
143
- for up in self.decoder_layers:
144
  out = up(out)
145
-
146
  out = self.decoder_norm_out(out)
147
- out = nn.SiLU(out)
148
  out = self.decoder_conv_out(out)
149
  return out
150
-
151
  def forward(self, x):
152
  z, quant_losses = self.encode(x)
153
  out = self.decode(z)
 
1
  import torch
2
  import torch.nn as nn
3
+ from models.blocks import DownBlock, MidBlock, UpBlock
4
 
5
 
6
  class VQVAE(nn.Module):
 
10
  self.mid_channels = model_config['mid_channels']
11
  self.down_sample = model_config['down_sample']
12
  self.num_down_layers = model_config['num_down_layers']
 
13
  self.num_mid_layers = model_config['num_mid_layers']
14
+ self.num_up_layers = model_config['num_up_layers']
15
+
16
+ # To disable attention in Downblock of Encoder and Upblock of Decoder
17
+ self.attns = model_config['attn_down']
18
+
19
  # Latent Dimension
20
+ self.z_channels = model_config['z_channels']
21
+ self.codebook_size = model_config['codebook_size']
22
+ self.norm_channels = model_config['norm_channels']
23
+ self.num_heads = model_config['num_heads']
24
+
25
+ # Assertion to validate the channel information
26
  assert self.mid_channels[0] == self.down_channels[-1]
27
  assert self.mid_channels[-1] == self.down_channels[-1]
28
  assert len(self.down_sample) == len(self.down_channels) - 1
29
  assert len(self.attns) == len(self.down_channels) - 1
30
+
31
+ # Wherever we use downsampling in encoder correspondingly use
32
+ # upsampling in decoder
33
+ self.up_sample = list(reversed(self.down_sample))
34
+
35
+ ##################### Encoder ######################
36
+ self.encoder_conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1))
37
+
38
+ # Downblock + Midblock
39
  self.encoder_layers = nn.ModuleList([])
40
  for i in range(len(self.down_channels) - 1):
41
+ self.encoder_layers.append(DownBlock(self.down_channels[i], self.down_channels[i + 1],
42
+ t_emb_dim=None, down_sample=self.down_sample[i],
43
+ num_heads=self.num_heads,
44
+ num_layers=self.num_down_layers,
45
+ attn=self.attns[i],
46
+ norm_channels=self.norm_channels))
47
+
48
+ self.encoder_mids = nn.ModuleList([])
49
+ for i in range(len(self.mid_channels) - 1):
50
+ self.encoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1],
51
+ t_emb_dim=None,
52
+ num_heads=self.num_heads,
53
+ num_layers=self.num_mid_layers,
54
+ norm_channels=self.norm_channels))
55
+
56
+ self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1])
57
+ self.encoder_conv_out = nn.Conv2d(self.down_channels[-1], self.z_channels, kernel_size=3, padding=1)
58
+
59
+ # Pre Quantization Convolution
60
+ self.pre_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
61
+
62
+ # Codebook
63
  self.embedding = nn.Embedding(self.codebook_size, self.z_channels)
64
+
65
+ ##################### Decoder ######################
66
+
67
+ # Post Quantization Convolution
68
+ self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
69
+ self.decoder_conv_in = nn.Conv2d(self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1))
70
+
71
+ # Midblock + Upblock
72
+ self.decoder_mids = nn.ModuleList([])
73
  for i in reversed(range(1, len(self.mid_channels))):
74
+ self.decoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i - 1],
75
+ t_emb_dim=None,
76
+ num_heads=self.num_heads,
77
+ num_layers=self.num_mid_layers,
78
+ norm_channels=self.norm_channels))
79
+
80
  self.decoder_layers = nn.ModuleList([])
81
  for i in reversed(range(1, len(self.down_channels))):
82
+ self.decoder_layers.append(UpBlock(self.down_channels[i], self.down_channels[i - 1],
83
+ t_emb_dim=None, up_sample=self.down_sample[i - 1],
84
+ num_heads=self.num_heads,
85
+ num_layers=self.num_up_layers,
86
+ attn=self.attns[i-1],
87
+ norm_channels=self.norm_channels))
88
+
89
+ self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0])
90
+ self.decoder_conv_out = nn.Conv2d(self.down_channels[0], im_channels, kernel_size=3, padding=1)
91
+
 
92
  def quantize(self, x):
93
  B, C, H, W = x.shape
94
+
95
+ # B, C, H, W -> B, H, W, C
96
  x = x.permute(0, 2, 3, 1)
97
+
98
+ # B, H, W, C -> B, H*W, C
99
  x = x.reshape(x.size(0), -1, x.size(-1))
100
+
101
+ # Find nearest embedding/codebook vector
102
+ # dist between (B, H*W, C) and (B, K, C) -> (B, H*W, K)
103
+ dist = torch.cdist(x, self.embedding.weight[None, :].repeat((x.size(0), 1, 1)))
104
+ # (B, H*W)
 
105
  min_encoding_indices = torch.argmin(dist, dim=-1)
106
+
107
+ # Replace encoder output with nearest codebook
108
+ # quant_out -> B*H*W, C
109
+ quant_out = torch.index_select(self.embedding.weight, 0, min_encoding_indices.view(-1))
110
+
111
+ # x -> B*H*W, C
112
  x = x.reshape((-1, x.size(-1)))
113
+ commmitment_loss = torch.mean((quant_out.detach() - x) ** 2)
114
  codebook_loss = torch.mean((quant_out - x.detach()) ** 2)
115
+ quantize_losses = {
116
+ 'codebook_loss': codebook_loss,
117
+ 'commitment_loss': commmitment_loss
118
  }
 
119
  # Straight through estimation
120
+ quant_out = x + (quant_out - x).detach()
121
+
122
+ # quant_out -> B, C, H, W
123
  quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2)
124
+ min_encoding_indices = min_encoding_indices.reshape((-1, quant_out.size(-2), quant_out.size(-1)))
125
+ return quant_out, quantize_losses, min_encoding_indices
 
 
126
 
127
  def encode(self, x):
128
+ out = self.encoder_conv_in(x)
129
+ for idx, down in enumerate(self.encoder_layers):
130
  out = down(out)
131
+ for mid in self.encoder_mids:
132
  out = mid(out)
133
  out = self.encoder_norm_out(out)
134
  out = nn.SiLU()(out)
 
136
  out = self.pre_quant_conv(out)
137
  out, quant_losses, _ = self.quantize(out)
138
  return out, quant_losses
139
+
140
  def decode(self, z):
141
  out = z
142
  out = self.post_quant_conv(out)
143
  out = self.decoder_conv_in(out)
144
+ for mid in self.decoder_mids:
145
  out = mid(out)
146
+ for idx, up in enumerate(self.decoder_layers):
147
  out = up(out)
148
+
149
  out = self.decoder_norm_out(out)
150
+ out = nn.SiLU()(out)
151
  out = self.decoder_conv_out(out)
152
  return out
153
+
154
  def forward(self, x):
155
  z, quant_losses = self.encode(x)
156
  out = self.decode(z)
train_vqvae.py CHANGED
@@ -24,7 +24,6 @@ def train(args):
24
  except yaml.YAMLError as e:
25
  print(e)
26
 
27
-
28
  autoencoder_config = config["autoencoder_params"]
29
  train_config = config["train_config"]
30
  dataset_config = config["dataset_config"]
@@ -84,11 +83,11 @@ def train(args):
84
 
85
  # Image saving
86
  if steps % img_save_steps == 0 or steps == 1:
87
- sample_size = min(8, im.shape[0])
88
  save_output = torch.clamp(
89
  output[:sample_size], -1., 1.).detach().cpu()
90
  save_output = ((save_output + 1) / 2)
91
- save_input = ((im[:sample_size] + 1) / 2).detach().cpu()
92
 
93
  grid = make_grid(
94
  torch.cat([save_input, save_output], dim=0), nrow=sample_size)
@@ -97,8 +96,8 @@ def train(args):
97
  os.mkdir(os.path.join(
98
  train_config['task_name'], 'vqvae_autoencoder_samples'))
99
  img.save(os.path.join(train_config['task_name'], 'vqvae_autoencoder_samples',
100
- 'current_autoencoder_sample_{}.png'.format(img_save_count)))
101
- img_save_count += 1
102
  img.close()
103
 
104
  # Optimizing generator
 
24
  except yaml.YAMLError as e:
25
  print(e)
26
 
 
27
  autoencoder_config = config["autoencoder_params"]
28
  train_config = config["train_config"]
29
  dataset_config = config["dataset_config"]
 
83
 
84
  # Image saving
85
  if steps % img_save_steps == 0 or steps == 1:
86
+ sample_size = min(8, im_tensor.shape[0])
87
  save_output = torch.clamp(
88
  output[:sample_size], -1., 1.).detach().cpu()
89
  save_output = ((save_output + 1) / 2)
90
+ save_input = ((im_tensor[:sample_size] + 1) / 2).detach().cpu()
91
 
92
  grid = make_grid(
93
  torch.cat([save_input, save_output], dim=0), nrow=sample_size)
 
96
  os.mkdir(os.path.join(
97
  train_config['task_name'], 'vqvae_autoencoder_samples'))
98
  img.save(os.path.join(train_config['task_name'], 'vqvae_autoencoder_samples',
99
+ 'current_autoencoder_sample_{}.png'.format(img_saved)))
100
+ img_saved += 1
101
  img.close()
102
 
103
  # Optimizing generator