Yash Nagraj committed on
Commit ·
f0ff580
1
Parent(s): 8d85e1f
Make changes
Browse files- dataset/celeba.py +6 -2
- dataset/dataset.py +142 -0
- models/blocks.py +252 -260
- models/vqvae.py +101 -98
- train_vqvae.py +4 -5
dataset/celeba.py
CHANGED
|
@@ -8,11 +8,12 @@ from PIL import Image
|
|
| 8 |
|
| 9 |
|
| 10 |
class ParquetImageDataset(Dataset):
|
| 11 |
-
def __init__(self, parquet_files, transform=None, im_size=256):
|
| 12 |
self.data = pd.concat([pd.read_parquet(file)
|
| 13 |
for file in parquet_files], ignore_index=True)
|
| 14 |
self.transform = transform
|
| 15 |
self.im_size = im_size
|
|
|
|
| 16 |
|
| 17 |
def __len__(self):
|
| 18 |
return len(self.data)
|
|
@@ -27,7 +28,10 @@ class ParquetImageDataset(Dataset):
|
|
| 27 |
])(image)
|
| 28 |
image.close()
|
| 29 |
im_tensor = (2 * im_tensor) - 1 # type: ignore
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
|
| 33 |
def create_dataloader(parquet_dir, batch_size=32, shuffle=True, num_workers=4):
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
class ParquetImageDataset(Dataset):
|
| 11 |
+
def __init__(self, parquet_files, transform=None, im_size=256,condition_config=None):
|
| 12 |
self.data = pd.concat([pd.read_parquet(file)
|
| 13 |
for file in parquet_files], ignore_index=True)
|
| 14 |
self.transform = transform
|
| 15 |
self.im_size = im_size
|
| 16 |
+
self.condition_types = [] if condition_config is None else condition_config['condition_types']
|
| 17 |
|
| 18 |
def __len__(self):
|
| 19 |
return len(self.data)
|
|
|
|
| 28 |
])(image)
|
| 29 |
image.close()
|
| 30 |
im_tensor = (2 * im_tensor) - 1 # type: ignore
|
| 31 |
+
if len(self.condition_types) == 0:
|
| 32 |
+
return im_tensor
|
| 33 |
+
else:
|
| 34 |
+
return im_tensor, caption
|
| 35 |
|
| 36 |
|
| 37 |
def create_dataloader(parquet_dir, batch_size=32, shuffle=True, num_workers=4):
|
dataset/dataset.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import torch
|
| 5 |
+
import torchvision
|
| 6 |
+
import numpy as np
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from torch.utils.data.dataset import Dataset
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class CelebDataset(Dataset):
    r"""
    CelebA-HQ style image dataset with optional text / mask conditioning.

    Image files are collected from ``<im_path>/CelebA-HQ-img``. Each image is
    resized, centre cropped to ``im_size`` and converted to the -1 to 1 range
    when it is fetched. This can be replaced by any other dataset as long as
    all the images are under one directory.
    """

    def __init__(self, split, im_path, im_size=256, im_channels=3, im_ext='jpg',
                 use_latents=False, latent_path=None, condition_config=None):
        r"""
        :param split: Stored on the instance; not otherwise used by this class
        :param im_path: Root directory of the dataset
        :param im_size: Size passed to the resize and centre crop transforms
        :param im_channels: Stored on the instance; not otherwise used here
        :param im_ext: Stored on the instance; file discovery always looks
            for png, jpg and jpeg files regardless of this value
        :param use_latents: Accepted for interface compatibility. No latents
            are loaded by this class, so ``self.use_latents`` stays False
        :param latent_path: Accepted for interface compatibility; unused here
        :param condition_config: Optional dict holding a 'condition_types'
            list ('text' and/or 'image'). When 'image' is present it must also
            hold 'image_condition_config' with the mask channels, height and
            width
        """
        self.split = split
        self.im_size = im_size
        self.im_channels = im_channels
        self.im_ext = im_ext
        self.im_path = im_path
        self.latent_maps = None
        # NOTE(review): the use_latents / latent_path arguments are ignored;
        # nothing in this class populates latent_maps, so the latent branch
        # of __getitem__ is only reachable if a caller sets both attributes.
        self.use_latents = False

        self.condition_types = [] if condition_config is None else condition_config['condition_types']

        self.idx_to_cls_map = {}
        self.cls_to_idx_map = {}

        if 'image' in self.condition_types:
            image_config = condition_config['image_condition_config']
            self.mask_channels = image_config['image_condition_input_channels']
            self.mask_h = image_config['image_condition_h']
            self.mask_w = image_config['image_condition_w']

        self.images, self.texts, self.masks = self.load_images(im_path)

    def load_images(self, im_path):
        r"""
        Gets all image paths from the path specified together with the
        captions and mask paths required by the configured condition types
        :param im_path: Root directory of the dataset
        :return: Tuple of (image paths, captions per image, mask paths).
            The captions / masks lists are empty when the matching condition
            type is not enabled
        """
        assert os.path.exists(
            im_path), "images path {} does not exist".format(im_path)
        ims = []
        fnames = []
        # Same lookup order as before: png files first, then jpg, then jpeg.
        for ext in ('png', 'jpg', 'jpeg'):
            fnames += glob.glob(os.path.join(
                im_path, 'CelebA-HQ-img/*.{}'.format(ext)))
        texts = []
        masks = []

        if 'image' in self.condition_types:
            label_list = ['skin', 'nose', 'eye_g', 'l_eye', 'r_eye', 'l_brow', 'r_brow', 'l_ear', 'r_ear', 'mouth',
                          'u_lip', 'l_lip', 'hair', 'hat', 'ear_r', 'neck_l', 'neck', 'cloth']
            self.idx_to_cls_map = dict(enumerate(label_list))
            self.cls_to_idx_map = {
                label: idx for idx, label in enumerate(label_list)}

        for fname in tqdm(fnames):
            ims.append(fname)
            # File name without directory and without anything after the
            # first dot.
            stem = os.path.split(fname)[1].split('.')[0]

            if 'text' in self.condition_types:
                # One caption per line in celeba-caption/<image name>.txt
                with open(os.path.join(im_path, 'celeba-caption/{}.txt'.format(stem))) as f:
                    captions_im = [line.strip() for line in f.readlines()]
                texts.append(captions_im)

            if 'image' in self.condition_types:
                # The int() round trip means the mask file name is the
                # numeric image name without leading zeros.
                masks.append(os.path.join(
                    im_path, 'CelebAMask-HQ-mask', '{}.png'.format(int(stem))))
        if 'text' in self.condition_types:
            assert len(texts) == len(
                ims), "Condition Type Text but could not find captions for all images"
        if 'image' in self.condition_types:
            assert len(masks) == len(
                ims), "Condition Type Image but could not find masks for all images"
        print('Found {} images'.format(len(ims)))
        print('Found {} masks'.format(len(masks)))
        print('Found {} captions'.format(len(texts)))
        return ims, texts, masks

    def get_mask(self, index):
        r"""
        Method to load the mask image for the given index and convert it
        into a one-hot mask. Pixel value k (k >= 1) sets channel k - 1;
        pixels matching no class leave every channel at zero
        :param index: Index of the sample whose mask is required
        :return: Float tensor of shape mask_channels x mask_h x mask_w
        """
        # Copy the pixels out while the file is open so the handle is
        # released as soon as the array has been built.
        with Image.open(self.masks[index]) as mask_im:
            mask_arr = np.array(mask_im)
        # NOTE(review): assumes the mask file is a mask_h x mask_w single
        # channel image and mask_channels >= number of classes - confirm.
        im_base = np.zeros((self.mask_h, self.mask_w, self.mask_channels))
        for orig_idx in range(len(self.idx_to_cls_map)):
            im_base[mask_arr == (orig_idx + 1), orig_idx] = 1
        mask = torch.from_numpy(im_base).permute(2, 0, 1).float()
        return mask

    def __len__(self):
        r"""
        :return: Number of image paths found by load_images
        """
        return len(self.images)

    def __getitem__(self, index):
        r"""
        Fetch one sample
        :param index: Index of the sample
        :return: The image tensor (or the stored latent when use_latents is
            set) alone if no condition types are configured, otherwise a
            tuple of that value and a dict of conditioning inputs keyed by
            'text' and/or 'image'
        """
        ######## Set Conditioning Info ########
        cond_inputs = {}
        if 'text' in self.condition_types:
            # Pick one of the captions available for this image at random.
            cond_inputs['text'] = random.sample(self.texts[index], k=1)[0]
        if 'image' in self.condition_types:
            cond_inputs['image'] = self.get_mask(index)
        #######################################

        if self.use_latents:
            # Latents are looked up by image path.
            sample = self.latent_maps[self.images[index]]
        else:
            to_tensor = torchvision.transforms.Compose([
                torchvision.transforms.Resize(self.im_size),
                torchvision.transforms.CenterCrop(self.im_size),
                torchvision.transforms.ToTensor(),
            ])
            # The context manager closes the file even if a transform raises.
            with Image.open(self.images[index]) as im:
                im_tensor = to_tensor(im)

            # Convert input to -1 to 1 range.
            sample = (2 * im_tensor) - 1

        if len(self.condition_types) == 0:
            return sample
        return sample, cond_inputs
|
models/blocks.py
CHANGED
|
@@ -1,92 +1,99 @@
|
|
| 1 |
-
from re import A
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
| 4 |
|
| 5 |
|
| 6 |
def get_time_embedding(time_steps, temb_dim):
|
| 7 |
-
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
factor = 10000 ** ((torch.arange(
|
| 10 |
start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device) / (temb_dim // 2))
|
| 11 |
)
|
| 12 |
-
|
| 13 |
# pos / factor
|
| 14 |
-
#
|
| 15 |
t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor
|
| 16 |
t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
|
| 17 |
return t_emb
|
| 18 |
|
| 19 |
|
| 20 |
class DownBlock(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
"""
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
3) Down Sample
|
| 26 |
-
"""
|
| 27 |
-
|
| 28 |
-
def __init__(self, in_channels, out_channels, t_emd_dim, down_sample, num_heads, num_layers, attn, norm_channels, cross_attn=False,
|
| 29 |
-
context_dim=None):
|
| 30 |
super().__init__()
|
|
|
|
| 31 |
self.down_sample = down_sample
|
| 32 |
-
self.
|
| 33 |
self.context_dim = context_dim
|
| 34 |
self.cross_attn = cross_attn
|
| 35 |
-
self.t_emb_dim =
|
| 36 |
-
self.
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
])
|
| 48 |
if self.t_emb_dim is not None:
|
| 49 |
-
self.
|
| 50 |
nn.Sequential(
|
| 51 |
nn.SiLU(),
|
| 52 |
nn.Linear(self.t_emb_dim, out_channels)
|
| 53 |
)
|
| 54 |
for _ in range(num_layers)
|
| 55 |
])
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
|
|
|
| 67 |
if self.attn:
|
| 68 |
self.attention_norms = nn.ModuleList(
|
| 69 |
[nn.GroupNorm(norm_channels, out_channels)
|
| 70 |
for _ in range(num_layers)]
|
| 71 |
)
|
| 72 |
-
|
| 73 |
-
self.
|
| 74 |
-
[nn.MultiheadAttention(
|
| 75 |
-
|
| 76 |
)
|
| 77 |
-
|
| 78 |
if self.cross_attn:
|
| 79 |
-
assert context_dim is not None, "Context Dimension must be passed
|
| 80 |
-
self.
|
| 81 |
[nn.GroupNorm(norm_channels, out_channels)
|
| 82 |
for _ in range(num_layers)]
|
| 83 |
)
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
out_channels, num_heads=num_heads, batch_first=True) for _ in range(num_layers)]
|
| 88 |
)
|
| 89 |
-
|
| 90 |
self.context_proj = nn.ModuleList(
|
| 91 |
[nn.Linear(context_dim, out_channels)
|
| 92 |
for _ in range(num_layers)]
|
|
@@ -94,177 +101,173 @@ class DownBlock(nn.Module):
|
|
| 94 |
|
| 95 |
self.residual_input_conv = nn.ModuleList(
|
| 96 |
[
|
| 97 |
-
nn.Conv2d(in_channels
|
| 98 |
-
out_channels=out_channels, kernel_size=1)
|
| 99 |
for i in range(num_layers)
|
| 100 |
-
|
| 101 |
]
|
| 102 |
)
|
| 103 |
-
|
| 104 |
-
self.resnet_down_conv = nn.Conv2d(out_channels, out_channels,
|
| 105 |
4, 2, 1) if self.down_sample else nn.Identity()
|
|
|
|
| 106 |
def forward(self, x, t_emb=None, context=None):
|
| 107 |
out = x
|
| 108 |
for i in range(self.num_layers):
|
| 109 |
-
# Resnet
|
| 110 |
resnet_input = out
|
| 111 |
out = self.resnet_conv_first[i](out)
|
| 112 |
if self.t_emb_dim is not None:
|
| 113 |
-
out = out + self.
|
| 114 |
out = self.resnet_conv_second[i](out)
|
| 115 |
out = out + self.residual_input_conv[i](resnet_input)
|
| 116 |
-
|
| 117 |
-
# Self Attention
|
| 118 |
if self.attn:
|
|
|
|
| 119 |
batch_size, channels, h, w = out.shape
|
| 120 |
-
in_attn = out.reshape(batch_size, channels, h*w)
|
| 121 |
in_attn = self.attention_norms[i](in_attn)
|
| 122 |
in_attn = in_attn.transpose(1, 2)
|
| 123 |
-
out_attn, _ = self.
|
| 124 |
-
out_attn =
|
| 125 |
-
batch_size, channels, h, w)
|
| 126 |
out = out + out_attn
|
| 127 |
-
|
| 128 |
-
# Cross Attention
|
| 129 |
if self.cross_attn:
|
| 130 |
-
assert context is not None, "
|
| 131 |
batch_size, channels, h, w = out.shape
|
| 132 |
in_attn = out.reshape(batch_size, channels, h * w)
|
| 133 |
in_attn = self.cross_attention_norms[i](in_attn)
|
| 134 |
in_attn = in_attn.transpose(1, 2)
|
| 135 |
assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
|
| 136 |
context_proj = self.context_proj[i](context)
|
| 137 |
-
out_attn, _ = self.cross_attentions[i](
|
| 138 |
-
|
| 139 |
-
out_attn = out_attn.transpose(1, 2).reshape(
|
| 140 |
-
batch_size, channels, h, w)
|
| 141 |
out = out + out_attn
|
| 142 |
-
|
| 143 |
-
|
|
|
|
| 144 |
return out
|
| 145 |
|
| 146 |
|
| 147 |
class MidBlock(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
"""
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
2) Self Attention block
|
| 152 |
-
3) Resnet block with time embedding
|
| 153 |
-
"""
|
| 154 |
-
|
| 155 |
-
def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_dim, cross_attn=None, context_dim=None):
|
| 156 |
super().__init__()
|
| 157 |
-
self.
|
| 158 |
-
self.out_channels = out_channels
|
| 159 |
self.t_emb_dim = t_emb_dim
|
| 160 |
-
self.cross_attn = cross_attn
|
| 161 |
self.context_dim = context_dim
|
| 162 |
-
self.
|
| 163 |
-
self.
|
| 164 |
-
|
| 165 |
-
nn.
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
|
|
|
| 174 |
if self.t_emb_dim is not None:
|
| 175 |
-
self.
|
| 176 |
nn.Sequential(
|
| 177 |
nn.SiLU(),
|
| 178 |
nn.Linear(t_emb_dim, out_channels)
|
| 179 |
)
|
| 180 |
for _ in range(num_layers + 1)
|
| 181 |
])
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
|
|
|
|
|
|
| 191 |
self.attention_norms = nn.ModuleList(
|
| 192 |
-
[nn.GroupNorm(
|
|
|
|
| 193 |
)
|
| 194 |
-
|
| 195 |
-
self.
|
| 196 |
[nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
|
| 197 |
for _ in range(num_layers)]
|
| 198 |
)
|
| 199 |
-
|
| 200 |
if self.cross_attn:
|
| 201 |
-
assert context_dim is not None, "Context must be
|
| 202 |
-
self.
|
| 203 |
-
[nn.GroupNorm(
|
| 204 |
for _ in range(num_layers)]
|
| 205 |
)
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
out_channels, num_heads=num_heads, batch_first=True) for _ in range(num_layers)]
|
| 210 |
)
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
|
|
|
|
|
|
|
|
|
| 215 |
for i in range(num_layers + 1)
|
| 216 |
-
]
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
nn.Conv2d(in_channels if i == 0 else out_channels,
|
| 220 |
-
out_channels, kernel_size=1)
|
| 221 |
-
for i in range(num_layers + 1)
|
| 222 |
-
|
| 223 |
-
])
|
| 224 |
-
|
| 225 |
def forward(self, x, t_emb=None, context=None):
|
| 226 |
out = x
|
|
|
|
|
|
|
| 227 |
resnet_input = out
|
| 228 |
-
out = self.
|
| 229 |
if self.t_emb_dim is not None:
|
| 230 |
-
out = out + self.
|
| 231 |
-
out = self.
|
| 232 |
out = out + self.residual_input_conv[0](resnet_input)
|
| 233 |
-
|
| 234 |
for i in range(self.num_layers):
|
|
|
|
| 235 |
batch_size, channels, h, w = out.shape
|
| 236 |
-
in_attn = out.reshape(batch_size, channels, h*w)
|
| 237 |
in_attn = self.attention_norms[i](in_attn)
|
| 238 |
in_attn = in_attn.transpose(1, 2)
|
| 239 |
-
out_attn, _ = self.
|
| 240 |
-
out_attn = out_attn.reshape(batch_size, channels, h, w)
|
| 241 |
out = out + out_attn
|
| 242 |
-
|
| 243 |
if self.cross_attn:
|
| 244 |
-
assert context is not None, "
|
| 245 |
batch_size, channels, h, w = out.shape
|
| 246 |
-
in_attn = out.reshape(batch_size, channels, h*w)
|
| 247 |
-
in_attn = self.
|
| 248 |
in_attn = in_attn.transpose(1, 2)
|
| 249 |
assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
|
| 250 |
context_proj = self.context_proj[i](context)
|
| 251 |
-
out_attn, _ = self.
|
| 252 |
-
|
| 253 |
-
out_attn = out_attn.transpose(1, 2).reshape(
|
| 254 |
-
batch_size, channels, h, w)
|
| 255 |
out = out + out_attn
|
| 256 |
-
|
|
|
|
|
|
|
| 257 |
resnet_input = out
|
| 258 |
-
out = self.
|
| 259 |
if self.t_emb_dim is not None:
|
| 260 |
-
out = out + self.
|
| 261 |
-
out =
|
| 262 |
-
out = out + self.residual_input_conv[i+1](resnet_input)
|
| 263 |
-
|
| 264 |
return out
|
| 265 |
|
| 266 |
|
| 267 |
-
class
|
| 268 |
r"""
|
| 269 |
Up conv block with attention.
|
| 270 |
Sequence of following blocks
|
|
@@ -273,20 +276,18 @@ class UpBlockUnet(nn.Module):
|
|
| 273 |
2. Resnet block with time embedding
|
| 274 |
3. Attention Block
|
| 275 |
"""
|
| 276 |
-
|
| 277 |
-
def __init__(self, in_channels, out_channels, t_emb_dim,
|
| 278 |
-
num_heads, num_layers,
|
| 279 |
super().__init__()
|
| 280 |
self.num_layers = num_layers
|
| 281 |
self.up_sample = up_sample
|
| 282 |
self.t_emb_dim = t_emb_dim
|
| 283 |
-
self.
|
| 284 |
-
self.context_dim = context_dim
|
| 285 |
self.resnet_conv_first = nn.ModuleList(
|
| 286 |
[
|
| 287 |
nn.Sequential(
|
| 288 |
-
nn.GroupNorm(norm_channels, in_channels if i ==
|
| 289 |
-
0 else out_channels),
|
| 290 |
nn.SiLU(),
|
| 291 |
nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
|
| 292 |
padding=1),
|
|
@@ -294,7 +295,7 @@ class UpBlockUnet(nn.Module):
|
|
| 294 |
for i in range(num_layers)
|
| 295 |
]
|
| 296 |
)
|
| 297 |
-
|
| 298 |
if self.t_emb_dim is not None:
|
| 299 |
self.t_emb_layers = nn.ModuleList([
|
| 300 |
nn.Sequential(
|
|
@@ -303,104 +304,73 @@ class UpBlockUnet(nn.Module):
|
|
| 303 |
)
|
| 304 |
for _ in range(num_layers)
|
| 305 |
])
|
| 306 |
-
|
| 307 |
self.resnet_conv_second = nn.ModuleList(
|
| 308 |
[
|
| 309 |
nn.Sequential(
|
| 310 |
nn.GroupNorm(norm_channels, out_channels),
|
| 311 |
nn.SiLU(),
|
| 312 |
-
nn.Conv2d(out_channels, out_channels,
|
| 313 |
-
kernel_size=3, stride=1, padding=1),
|
| 314 |
)
|
| 315 |
for _ in range(num_layers)
|
| 316 |
]
|
| 317 |
)
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
)
|
| 325 |
-
|
| 326 |
-
self.attentions = nn.ModuleList(
|
| 327 |
-
[
|
| 328 |
-
nn.MultiheadAttention(
|
| 329 |
-
out_channels, num_heads, batch_first=True)
|
| 330 |
-
for _ in range(num_layers)
|
| 331 |
-
]
|
| 332 |
-
)
|
| 333 |
-
|
| 334 |
-
if self.cross_attn:
|
| 335 |
-
assert context_dim is not None, "Context Dimension must be passed for cross attention"
|
| 336 |
-
self.cross_attention_norms = nn.ModuleList(
|
| 337 |
-
[nn.GroupNorm(norm_channels, out_channels)
|
| 338 |
-
for _ in range(num_layers)]
|
| 339 |
-
)
|
| 340 |
-
self.cross_attentions = nn.ModuleList(
|
| 341 |
-
[nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
|
| 342 |
-
for _ in range(num_layers)]
|
| 343 |
)
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
|
|
|
|
|
|
|
|
|
| 347 |
)
|
|
|
|
| 348 |
self.residual_input_conv = nn.ModuleList(
|
| 349 |
[
|
| 350 |
-
nn.Conv2d(in_channels if i == 0 else out_channels,
|
| 351 |
-
out_channels, kernel_size=1)
|
| 352 |
for i in range(num_layers)
|
| 353 |
]
|
| 354 |
)
|
| 355 |
-
self.up_sample_conv = nn.ConvTranspose2d(in_channels
|
| 356 |
4, 2, 1) \
|
| 357 |
if self.up_sample else nn.Identity()
|
| 358 |
-
|
| 359 |
-
def forward(self, x, out_down=None, t_emb=None
|
|
|
|
| 360 |
x = self.up_sample_conv(x)
|
|
|
|
|
|
|
| 361 |
if out_down is not None:
|
| 362 |
x = torch.cat([x, out_down], dim=1)
|
| 363 |
-
|
| 364 |
out = x
|
| 365 |
for i in range(self.num_layers):
|
| 366 |
-
# Resnet
|
| 367 |
resnet_input = out
|
| 368 |
out = self.resnet_conv_first[i](out)
|
| 369 |
if self.t_emb_dim is not None:
|
| 370 |
out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
|
| 371 |
out = self.resnet_conv_second[i](out)
|
| 372 |
out = out + self.residual_input_conv[i](resnet_input)
|
|
|
|
| 373 |
# Self Attention
|
| 374 |
-
|
| 375 |
-
in_attn = out.reshape(batch_size, channels, h * w)
|
| 376 |
-
in_attn = self.attention_norms[i](in_attn)
|
| 377 |
-
in_attn = in_attn.transpose(1, 2)
|
| 378 |
-
out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
|
| 379 |
-
out_attn = out_attn.transpose(1, 2).reshape(
|
| 380 |
-
batch_size, channels, h, w)
|
| 381 |
-
out = out + out_attn
|
| 382 |
-
# Cross Attention
|
| 383 |
-
if self.cross_attn:
|
| 384 |
-
assert context is not None, "context cannot be None if cross attention layers are used"
|
| 385 |
batch_size, channels, h, w = out.shape
|
| 386 |
in_attn = out.reshape(batch_size, channels, h * w)
|
| 387 |
-
in_attn = self.
|
| 388 |
in_attn = in_attn.transpose(1, 2)
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim, \
|
| 392 |
-
"Context shape does not match B,_,CONTEXT_DIM"
|
| 393 |
-
context_proj = self.context_proj[i](context)
|
| 394 |
-
out_attn, _ = self.cross_attentions[i](
|
| 395 |
-
in_attn, context_proj, context_proj)
|
| 396 |
-
out_attn = out_attn.transpose(1, 2).reshape(
|
| 397 |
-
batch_size, channels, h, w)
|
| 398 |
out = out + out_attn
|
| 399 |
-
|
| 400 |
return out
|
| 401 |
|
| 402 |
|
| 403 |
-
class
|
| 404 |
r"""
|
| 405 |
Up conv block with attention.
|
| 406 |
Sequence of following blocks
|
|
@@ -409,19 +379,19 @@ class UpBlock(nn.Module):
|
|
| 409 |
2. Resnet block with time embedding
|
| 410 |
3. Attention Block
|
| 411 |
"""
|
| 412 |
-
|
| 413 |
-
def __init__(self, in_channels, out_channels, t_emb_dim,
|
| 414 |
-
|
| 415 |
super().__init__()
|
| 416 |
self.num_layers = num_layers
|
| 417 |
self.up_sample = up_sample
|
| 418 |
self.t_emb_dim = t_emb_dim
|
| 419 |
-
self.
|
|
|
|
| 420 |
self.resnet_conv_first = nn.ModuleList(
|
| 421 |
[
|
| 422 |
nn.Sequential(
|
| 423 |
-
nn.GroupNorm(norm_channels, in_channels if i ==
|
| 424 |
-
0 else out_channels),
|
| 425 |
nn.SiLU(),
|
| 426 |
nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
|
| 427 |
padding=1),
|
|
@@ -429,7 +399,7 @@ class UpBlock(nn.Module):
|
|
| 429 |
for i in range(num_layers)
|
| 430 |
]
|
| 431 |
)
|
| 432 |
-
|
| 433 |
if self.t_emb_dim is not None:
|
| 434 |
self.t_emb_layers = nn.ModuleList([
|
| 435 |
nn.Sequential(
|
|
@@ -438,71 +408,93 @@ class UpBlock(nn.Module):
|
|
| 438 |
)
|
| 439 |
for _ in range(num_layers)
|
| 440 |
])
|
| 441 |
-
|
| 442 |
self.resnet_conv_second = nn.ModuleList(
|
| 443 |
[
|
| 444 |
nn.Sequential(
|
| 445 |
nn.GroupNorm(norm_channels, out_channels),
|
| 446 |
nn.SiLU(),
|
| 447 |
-
nn.Conv2d(out_channels, out_channels,
|
| 448 |
-
kernel_size=3, stride=1, padding=1),
|
| 449 |
)
|
| 450 |
for _ in range(num_layers)
|
| 451 |
]
|
| 452 |
)
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 459 |
)
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
)
|
| 468 |
-
|
| 469 |
self.residual_input_conv = nn.ModuleList(
|
| 470 |
[
|
| 471 |
-
nn.Conv2d(in_channels if i == 0 else out_channels,
|
| 472 |
-
out_channels, kernel_size=1)
|
| 473 |
for i in range(num_layers)
|
| 474 |
]
|
| 475 |
)
|
| 476 |
-
self.up_sample_conv = nn.ConvTranspose2d(in_channels, in_channels,
|
| 477 |
4, 2, 1) \
|
| 478 |
if self.up_sample else nn.Identity()
|
| 479 |
-
|
| 480 |
-
def forward(self, x, out_down=None, t_emb=None):
|
| 481 |
-
# Upsample
|
| 482 |
x = self.up_sample_conv(x)
|
| 483 |
-
|
| 484 |
-
# Concat with Downblock output
|
| 485 |
if out_down is not None:
|
| 486 |
x = torch.cat([x, out_down], dim=1)
|
| 487 |
-
|
| 488 |
out = x
|
| 489 |
for i in range(self.num_layers):
|
| 490 |
-
# Resnet
|
| 491 |
resnet_input = out
|
| 492 |
out = self.resnet_conv_first[i](out)
|
| 493 |
if self.t_emb_dim is not None:
|
| 494 |
out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
|
| 495 |
out = self.resnet_conv_second[i](out)
|
| 496 |
out = out + self.residual_input_conv[i](resnet_input)
|
| 497 |
-
|
| 498 |
# Self Attention
|
| 499 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 500 |
batch_size, channels, h, w = out.shape
|
| 501 |
in_attn = out.reshape(batch_size, channels, h * w)
|
| 502 |
-
in_attn = self.
|
| 503 |
in_attn = in_attn.transpose(1, 2)
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 507 |
out = out + out_attn
|
|
|
|
| 508 |
return out
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
|
| 4 |
|
| 5 |
def get_time_embedding(time_steps, temb_dim):
    r"""
    Build sinusoidal embeddings for a batch of time steps.

    The first half of every embedding holds the sine terms and the second
    half holds the cosine terms.
    :param time_steps: 1D tensor of length batch size
    :param temb_dim: Dimension of the embedding, must be even
    :return: BxD embedding representation of B time steps
    """
    assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
    half_dim = temb_dim // 2

    # factor = 10000^(i / half_dim) for i in [0, half_dim)
    exponents = torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=time_steps.device) / half_dim
    factor = 10000 ** exponents

    # timesteps B -> B, 1 -> B, half_dim, then pos / factor
    angles = time_steps[:, None].repeat(1, half_dim) / factor
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
|
| 25 |
|
| 26 |
|
| 27 |
class DownBlock(nn.Module):
    r"""
    Down conv block with optional self and cross attention.
    Each of the num_layers layers runs
    1. Resnet block (with time embedding when t_emb_dim is given)
    2. Self attention block (when attn is set)
    3. Cross attention block over the context (when cross_attn is set)
    A single downsample step follows the last layer.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim,
                 down_sample, num_heads, num_layers, attn, norm_channels, cross_attn=False, context_dim=None):
        super().__init__()
        self.num_layers = num_layers
        self.down_sample = down_sample
        self.attn = attn
        self.context_dim = context_dim
        self.cross_attn = cross_attn
        self.t_emb_dim = t_emb_dim

        # Only the first layer consumes in_channels; later layers work on
        # out_channels.
        layer_inputs = [in_channels if i == 0 else out_channels
                        for i in range(num_layers)]

        first_convs = []
        for channels_in in layer_inputs:
            first_convs.append(nn.Sequential(
                nn.GroupNorm(norm_channels, channels_in),
                nn.SiLU(),
                nn.Conv2d(channels_in, out_channels,
                          kernel_size=3, stride=1, padding=1),
            ))
        self.resnet_conv_first = nn.ModuleList(first_convs)

        if self.t_emb_dim is not None:
            # Projects the time embedding to one bias per output channel.
            time_layers = []
            for _ in range(num_layers):
                time_layers.append(nn.Sequential(
                    nn.SiLU(),
                    nn.Linear(self.t_emb_dim, out_channels)
                ))
            self.t_emb_layers = nn.ModuleList(time_layers)

        second_convs = []
        for _ in range(num_layers):
            second_convs.append(nn.Sequential(
                nn.GroupNorm(norm_channels, out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels,
                          kernel_size=3, stride=1, padding=1),
            ))
        self.resnet_conv_second = nn.ModuleList(second_convs)

        if self.attn:
            self.attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]
            )
            self.attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                 for _ in range(num_layers)]
            )

        if self.cross_attn:
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]
            )
            self.cross_attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                 for _ in range(num_layers)]
            )
            # Maps the context features to out_channels so they can serve
            # as keys and values.
            self.context_proj = nn.ModuleList(
                [nn.Linear(context_dim, out_channels) for _ in range(num_layers)]
            )

        # 1x1 conv on the skip path so it matches out_channels.
        self.residual_input_conv = nn.ModuleList(
            [nn.Conv2d(channels_in, out_channels, kernel_size=1)
             for channels_in in layer_inputs]
        )
        if self.down_sample:
            # Kernel 4, stride 2, padding 1 convolution.
            self.down_sample_conv = nn.Conv2d(out_channels, out_channels, 4, 2, 1)
        else:
            self.down_sample_conv = nn.Identity()

    def forward(self, x, t_emb=None, context=None):
        r"""
        :param x: Input feature map, B x in_channels x H x W
        :param t_emb: Time embedding, used only when t_emb_dim is not None
        :param context: Context tensor for cross attention; required when
            cross_attn is set, with batch size matching x and last dimension
            equal to context_dim
        :return: Output feature map with out_channels channels
        """
        out = x
        for layer in range(self.num_layers):
            # Resnet block of Unet
            skip = out
            out = self.resnet_conv_first[layer](out)
            if self.t_emb_dim is not None:
                time_bias = self.t_emb_layers[layer](t_emb)
                out = out + time_bias[:, :, None, None]
            out = self.resnet_conv_second[layer](out)
            out = out + self.residual_input_conv[layer](skip)

            if self.attn:
                # Attention block of Unet: flatten the spatial grid into a
                # token sequence, attend, then restore the grid.
                batch_size, channels, h, w = out.shape
                tokens = out.reshape(batch_size, channels, h * w)
                tokens = self.attention_norms[layer](tokens).transpose(1, 2)
                attended, _ = self.attentions[layer](tokens, tokens, tokens)
                attended = attended.transpose(1, 2)
                out = out + attended.reshape(batch_size, channels, h, w)

            if self.cross_attn:
                assert context is not None, "context cannot be None if cross attention layers are used"
                batch_size, channels, h, w = out.shape
                tokens = out.reshape(batch_size, channels, h * w)
                tokens = self.cross_attention_norms[layer](tokens).transpose(1, 2)
                assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
                keys_values = self.context_proj[layer](context)
                attended, _ = self.cross_attentions[layer](tokens, keys_values, keys_values)
                attended = attended.transpose(1, 2)
                out = out + attended.reshape(batch_size, channels, h, w)

        # Downsample
        return self.down_sample_conv(out)
|
| 147 |
|
| 148 |
|
| 149 |
class MidBlock(nn.Module):
    r"""
    Mid conv block with attention.

    Sequence of following blocks
    1. Resnet block with time embedding
    2. ``num_layers`` repetitions of
        a. Self attention block
        b. Cross attention block (only if ``cross_attn`` is truthy)
        c. Resnet block with time embedding

    Every convolution is either 3x3 with stride 1 and padding 1 or 1x1, so
    the spatial size of the input is preserved; only the channel count
    changes from ``in_channels`` to ``out_channels``.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by every sub-block.
        t_emb_dim: size of the time embedding, or ``None`` to build no
            time-embedding layers (``t_emb`` is then ignored by ``forward``).
        num_heads: number of heads of every ``nn.MultiheadAttention``.
        num_layers: number of attention + resnet repetitions; one extra
            resnet block is always built in front of them.
        norm_channels: number of groups of every ``nn.GroupNorm``.
        cross_attn: when truthy, a cross attention layer follows each self
            attention layer.
        context_dim: feature size of the conditioning context; required
            when ``cross_attn`` is truthy.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_channels,
                 cross_attn=None, context_dim=None):
        super().__init__()
        self.num_layers = num_layers
        self.t_emb_dim = t_emb_dim
        self.context_dim = context_dim
        self.cross_attn = cross_attn

        # NOTE: attribute names and construction order are kept as they were
        # so that state_dict keys and seeded initialisation stay unchanged.
        self.resnet_conv_first = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
                    nn.SiLU(),
                    nn.Conv2d(in_channels if i == 0 else out_channels, out_channels,
                              kernel_size=3, stride=1, padding=1),
                )
                for i in range(num_layers + 1)
            ]
        )

        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(
                    nn.SiLU(),
                    nn.Linear(t_emb_dim, out_channels)
                )
                for _ in range(num_layers + 1)
            ])

        self.resnet_conv_second = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(norm_channels, out_channels),
                    nn.SiLU(),
                    nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
                )
                for _ in range(num_layers + 1)
            ]
        )

        self.attention_norms = nn.ModuleList(
            [nn.GroupNorm(norm_channels, out_channels)
             for _ in range(num_layers)]
        )

        self.attentions = nn.ModuleList(
            [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
             for _ in range(num_layers)]
        )

        if self.cross_attn:
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels)
                 for _ in range(num_layers)]
            )
            self.cross_attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                 for _ in range(num_layers)]
            )
            # Projects the context features to the attention embedding size.
            self.context_proj = nn.ModuleList(
                [nn.Linear(context_dim, out_channels)
                 for _ in range(num_layers)]
            )

        # 1x1 convolutions that bring the resnet input to out_channels for the skip path.
        self.residual_input_conv = nn.ModuleList(
            [
                nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
                for i in range(num_layers + 1)
            ]
        )

    def _resnet(self, idx, x, t_emb):
        """Apply the ``idx``-th resnet sub-block to ``x`` and return the result."""
        out = self.resnet_conv_first[idx](x)
        if self.t_emb_dim is not None:
            # Broadcast the projected time embedding over the two spatial axes.
            out = out + self.t_emb_layers[idx](t_emb)[:, :, None, None]
        out = self.resnet_conv_second[idx](out)
        return out + self.residual_input_conv[idx](x)

    @staticmethod
    def _attend(norm, attention, feat, key_value=None):
        """Add the output of ``attention`` residually to the 4-D map ``feat``.

        The map is flattened to a token sequence, normalised by ``norm`` and
        used as the query. ``key_value`` supplies keys and values; when it is
        ``None`` the tokens attend to themselves.
        """
        batch_size, channels, h, w = feat.shape
        tokens = norm(feat.reshape(batch_size, channels, h * w)).transpose(1, 2)
        if key_value is None:
            key_value = tokens
        attended, _ = attention(tokens, key_value, key_value)
        return feat + attended.transpose(1, 2).reshape(batch_size, channels, h, w)

    def forward(self, x, t_emb=None, context=None):
        """Run the block.

        Args:
            x: 4-D feature map (its shape is unpacked as batch, channels, h, w).
            t_emb: time embedding fed to the time-embedding layers; only
                used when the block was built with ``t_emb_dim``.
            context: conditioning tensor of shape (B, _, ``context_dim``);
                required when the block was built with ``cross_attn``.

        Returns:
            Feature map produced by the last resnet sub-block.

        Raises:
            AssertionError: if cross attention is enabled and ``context`` is
                missing or does not have the expected shape.
        """
        # First resnet block
        out = self._resnet(0, x, t_emb)

        for i in range(self.num_layers):
            # Self attention
            out = self._attend(self.attention_norms[i], self.attentions[i], out)

            # Cross attention
            if self.cross_attn:
                assert context is not None, "context cannot be None if cross attention layers are used"
                assert len(context.shape) == 3, \
                    "Context shape does not match B,_,CONTEXT_DIM"
                assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim, \
                    "Context shape does not match B,_,CONTEXT_DIM"
                out = self._attend(self.cross_attention_norms[i], self.cross_attentions[i],
                                   out, self.context_proj[i](context))

            # Resnet block
            out = self._resnet(i + 1, out, t_emb)

        return out
|
| 268 |
|
| 269 |
|
| 270 |
+
class UpBlock(nn.Module):
|
| 271 |
r"""
|
| 272 |
Up conv block with attention.
|
| 273 |
Sequence of following blocks
|
|
|
|
| 276 |
2. Resnet block with time embedding
|
| 277 |
3. Attention Block
|
| 278 |
"""
|
| 279 |
+
|
| 280 |
+
def __init__(self, in_channels, out_channels, t_emb_dim,
|
| 281 |
+
up_sample, num_heads, num_layers, attn, norm_channels):
|
| 282 |
super().__init__()
|
| 283 |
self.num_layers = num_layers
|
| 284 |
self.up_sample = up_sample
|
| 285 |
self.t_emb_dim = t_emb_dim
|
| 286 |
+
self.attn = attn
|
|
|
|
| 287 |
self.resnet_conv_first = nn.ModuleList(
|
| 288 |
[
|
| 289 |
nn.Sequential(
|
| 290 |
+
nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
|
|
|
|
| 291 |
nn.SiLU(),
|
| 292 |
nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
|
| 293 |
padding=1),
|
|
|
|
| 295 |
for i in range(num_layers)
|
| 296 |
]
|
| 297 |
)
|
| 298 |
+
|
| 299 |
if self.t_emb_dim is not None:
|
| 300 |
self.t_emb_layers = nn.ModuleList([
|
| 301 |
nn.Sequential(
|
|
|
|
| 304 |
)
|
| 305 |
for _ in range(num_layers)
|
| 306 |
])
|
| 307 |
+
|
| 308 |
self.resnet_conv_second = nn.ModuleList(
|
| 309 |
[
|
| 310 |
nn.Sequential(
|
| 311 |
nn.GroupNorm(norm_channels, out_channels),
|
| 312 |
nn.SiLU(),
|
| 313 |
+
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
|
|
|
|
| 314 |
)
|
| 315 |
for _ in range(num_layers)
|
| 316 |
]
|
| 317 |
)
|
| 318 |
+
if self.attn:
|
| 319 |
+
self.attention_norms = nn.ModuleList(
|
| 320 |
+
[
|
| 321 |
+
nn.GroupNorm(norm_channels, out_channels)
|
| 322 |
+
for _ in range(num_layers)
|
| 323 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
)
|
| 325 |
+
|
| 326 |
+
self.attentions = nn.ModuleList(
|
| 327 |
+
[
|
| 328 |
+
nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
|
| 329 |
+
for _ in range(num_layers)
|
| 330 |
+
]
|
| 331 |
)
|
| 332 |
+
|
| 333 |
self.residual_input_conv = nn.ModuleList(
|
| 334 |
[
|
| 335 |
+
nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
|
|
|
|
| 336 |
for i in range(num_layers)
|
| 337 |
]
|
| 338 |
)
|
| 339 |
+
self.up_sample_conv = nn.ConvTranspose2d(in_channels, in_channels,
|
| 340 |
4, 2, 1) \
|
| 341 |
if self.up_sample else nn.Identity()
|
| 342 |
+
|
| 343 |
+
def forward(self, x, out_down=None, t_emb=None):
    """Upsample ``x``, optionally merge a skip tensor, then run the layers.

    Args:
        x: input feature map, passed through ``self.up_sample_conv`` first.
        out_down: optional tensor concatenated to the upsampled input
            along dimension 1.
        t_emb: time embedding; only read when ``self.t_emb_dim`` is set.

    Returns:
        The feature map after ``self.num_layers`` resnet sub-blocks, each
        followed by self attention when ``self.attn`` is truthy.
    """
    # Upsample, then append the down-block output along the channel axis.
    feat = self.up_sample_conv(x)
    if out_down is not None:
        feat = torch.cat([feat, out_down], dim=1)

    for layer in range(self.num_layers):
        # Resnet sub-block: two conv stages plus a projected skip path.
        skip = feat
        hidden = self.resnet_conv_first[layer](skip)
        if self.t_emb_dim is not None:
            # Broadcast the projected time embedding over both spatial axes.
            hidden = hidden + self.t_emb_layers[layer](t_emb)[:, :, None, None]
        hidden = self.resnet_conv_second[layer](hidden)
        feat = hidden + self.residual_input_conv[layer](skip)

        if not self.attn:
            continue

        # Self attention over the flattened spatial positions.
        n, c, height, width = feat.shape
        tokens = feat.reshape(n, c, height * width)
        tokens = self.attention_norms[layer](tokens).transpose(1, 2)
        attended, _ = self.attentions[layer](tokens, tokens, tokens)
        feat = feat + attended.transpose(1, 2).reshape(n, c, height, width)
    return feat
|
| 371 |
|
| 372 |
|
| 373 |
+
class UpBlockUnet(nn.Module):
|
| 374 |
r"""
|
| 375 |
Up conv block with attention.
|
| 376 |
Sequence of following blocks
|
|
|
|
| 379 |
2. Resnet block with time embedding
|
| 380 |
3. Attention Block
|
| 381 |
"""
|
| 382 |
+
|
| 383 |
+
def __init__(self, in_channels, out_channels, t_emb_dim, up_sample,
|
| 384 |
+
num_heads, num_layers, norm_channels, cross_attn=False, context_dim=None):
|
| 385 |
super().__init__()
|
| 386 |
self.num_layers = num_layers
|
| 387 |
self.up_sample = up_sample
|
| 388 |
self.t_emb_dim = t_emb_dim
|
| 389 |
+
self.cross_attn = cross_attn
|
| 390 |
+
self.context_dim = context_dim
|
| 391 |
self.resnet_conv_first = nn.ModuleList(
|
| 392 |
[
|
| 393 |
nn.Sequential(
|
| 394 |
+
nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
|
|
|
|
| 395 |
nn.SiLU(),
|
| 396 |
nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
|
| 397 |
padding=1),
|
|
|
|
| 399 |
for i in range(num_layers)
|
| 400 |
]
|
| 401 |
)
|
| 402 |
+
|
| 403 |
if self.t_emb_dim is not None:
|
| 404 |
self.t_emb_layers = nn.ModuleList([
|
| 405 |
nn.Sequential(
|
|
|
|
| 408 |
)
|
| 409 |
for _ in range(num_layers)
|
| 410 |
])
|
| 411 |
+
|
| 412 |
self.resnet_conv_second = nn.ModuleList(
|
| 413 |
[
|
| 414 |
nn.Sequential(
|
| 415 |
nn.GroupNorm(norm_channels, out_channels),
|
| 416 |
nn.SiLU(),
|
| 417 |
+
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
|
|
|
|
| 418 |
)
|
| 419 |
for _ in range(num_layers)
|
| 420 |
]
|
| 421 |
)
|
| 422 |
+
|
| 423 |
+
self.attention_norms = nn.ModuleList(
|
| 424 |
+
[
|
| 425 |
+
nn.GroupNorm(norm_channels, out_channels)
|
| 426 |
+
for _ in range(num_layers)
|
| 427 |
+
]
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
self.attentions = nn.ModuleList(
|
| 431 |
+
[
|
| 432 |
+
nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
|
| 433 |
+
for _ in range(num_layers)
|
| 434 |
+
]
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
if self.cross_attn:
|
| 438 |
+
assert context_dim is not None, "Context Dimension must be passed for cross attention"
|
| 439 |
+
self.cross_attention_norms = nn.ModuleList(
|
| 440 |
+
[nn.GroupNorm(norm_channels, out_channels)
|
| 441 |
+
for _ in range(num_layers)]
|
| 442 |
)
|
| 443 |
+
self.cross_attentions = nn.ModuleList(
|
| 444 |
+
[nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
|
| 445 |
+
for _ in range(num_layers)]
|
| 446 |
+
)
|
| 447 |
+
self.context_proj = nn.ModuleList(
|
| 448 |
+
[nn.Linear(context_dim, out_channels)
|
| 449 |
+
for _ in range(num_layers)]
|
| 450 |
)
|
|
|
|
| 451 |
self.residual_input_conv = nn.ModuleList(
|
| 452 |
[
|
| 453 |
+
nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
|
|
|
|
| 454 |
for i in range(num_layers)
|
| 455 |
]
|
| 456 |
)
|
| 457 |
+
self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2,
|
| 458 |
4, 2, 1) \
|
| 459 |
if self.up_sample else nn.Identity()
|
| 460 |
+
|
| 461 |
+
def forward(self, x, out_down=None, t_emb=None, context=None):
    """Upsample ``x``, optionally merge a skip tensor, then run the layers.

    Each of the ``self.num_layers`` iterations applies a resnet sub-block,
    self attention and, when ``self.cross_attn`` is truthy, cross
    attention against ``context``.

    Args:
        x: input feature map, passed through ``self.up_sample_conv`` first.
        out_down: optional tensor concatenated to the upsampled input
            along dimension 1.
        t_emb: time embedding; only read when ``self.t_emb_dim`` is set.
        context: conditioning tensor of shape (B, _, ``self.context_dim``);
            required when cross attention is enabled.

    Returns:
        The feature map produced by the last iteration.

    Raises:
        AssertionError: if cross attention is enabled and ``context`` is
            missing or has an unexpected shape.
    """
    merged = self.up_sample_conv(x)
    if out_down is not None:
        merged = torch.cat([merged, out_down], dim=1)

    feat = merged
    for layer in range(self.num_layers):
        # Resnet
        skip = feat
        hidden = self.resnet_conv_first[layer](skip)
        if self.t_emb_dim is not None:
            # Broadcast the projected time embedding over both spatial axes.
            hidden = hidden + self.t_emb_layers[layer](t_emb)[:, :, None, None]
        hidden = self.resnet_conv_second[layer](hidden)
        feat = hidden + self.residual_input_conv[layer](skip)

        # Self Attention
        n, c, height, width = feat.shape
        tokens = self.attention_norms[layer](feat.reshape(n, c, height * width)).transpose(1, 2)
        attended, _ = self.attentions[layer](tokens, tokens, tokens)
        feat = feat + attended.transpose(1, 2).reshape(n, c, height, width)

        if not self.cross_attn:
            continue

        # Cross Attention
        assert context is not None, "context cannot be None if cross attention layers are used"
        n, c, height, width = feat.shape
        queries = self.cross_attention_norms[layer](feat.reshape(n, c, height * width)).transpose(1, 2)
        assert len(context.shape) == 3, \
            "Context shape does not match B,_,CONTEXT_DIM"
        batch_matches = context.shape[0] == merged.shape[0]
        assert batch_matches and context.shape[-1] == self.context_dim, \
            "Context shape does not match B,_,CONTEXT_DIM"
        keys_values = self.context_proj[layer](context)
        attended, _ = self.cross_attentions[layer](queries, keys_values, keys_values)
        feat = feat + attended.transpose(1, 2).reshape(n, c, height, width)

    return feat
|
| 500 |
+
|
models/vqvae.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
-
from models.blocks import DownBlock,
|
| 4 |
|
| 5 |
|
| 6 |
class VQVAE(nn.Module):
|
|
@@ -10,122 +10,125 @@ class VQVAE(nn.Module):
|
|
| 10 |
self.mid_channels = model_config['mid_channels']
|
| 11 |
self.down_sample = model_config['down_sample']
|
| 12 |
self.num_down_layers = model_config['num_down_layers']
|
| 13 |
-
self.num_up_layers = model_config['num_up_layers']
|
| 14 |
self.num_mid_layers = model_config['num_mid_layers']
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
|
|
|
| 19 |
# Latent Dimension
|
| 20 |
-
self.z_channels = model_config[
|
| 21 |
-
self.codebook_size = model_config[
|
| 22 |
-
self.norm_channels = model_config[
|
| 23 |
-
self.num_heads = model_config[
|
| 24 |
-
|
|
|
|
| 25 |
assert self.mid_channels[0] == self.down_channels[-1]
|
| 26 |
assert self.mid_channels[-1] == self.down_channels[-1]
|
| 27 |
assert len(self.down_sample) == len(self.down_channels) - 1
|
| 28 |
assert len(self.attns) == len(self.down_channels) - 1
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
|
|
|
|
|
|
| 36 |
self.encoder_layers = nn.ModuleList([])
|
| 37 |
for i in range(len(self.down_channels) - 1):
|
| 38 |
-
self.encoder_layers.append(DownBlock(self.down_channels[i], self.down_channels[i+1],
|
| 39 |
-
|
| 40 |
-
num_heads=self.num_heads,
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
self.
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
#
|
|
|
|
|
|
|
|
|
|
| 57 |
self.embedding = nn.Embedding(self.codebook_size, self.z_channels)
|
| 58 |
-
|
| 59 |
-
# Decoder
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
self.
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
# Midblock +
|
| 66 |
-
self.
|
| 67 |
for i in reversed(range(1, len(self.mid_channels))):
|
| 68 |
-
self.
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
| 72 |
self.decoder_layers = nn.ModuleList([])
|
| 73 |
for i in reversed(range(1, len(self.down_channels))):
|
| 74 |
-
self.decoder_layers.append(UpBlock(self.down_channels[i], self.down_channels[i-1],
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
self.decoder_conv_out = nn.Conv2d(
|
| 83 |
-
|
| 84 |
-
|
| 85 |
def quantize(self, x):
|
| 86 |
B, C, H, W = x.shape
|
| 87 |
-
|
| 88 |
-
# B,C,H,W -> B,H,W,C
|
| 89 |
x = x.permute(0, 2, 3, 1)
|
| 90 |
-
|
| 91 |
-
# B,H,W,C -> B, H*W, C
|
| 92 |
x = x.reshape(x.size(0), -1, x.size(-1))
|
| 93 |
-
|
| 94 |
-
# Find nearest
|
| 95 |
-
#
|
| 96 |
-
dist = torch.cdist(
|
| 97 |
-
|
| 98 |
-
|
| 99 |
min_encoding_indices = torch.argmin(dist, dim=-1)
|
| 100 |
-
|
| 101 |
-
# Replace encoder output with
|
| 102 |
-
quant_out
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
# x -> B*H*W,C
|
| 106 |
x = x.reshape((-1, x.size(-1)))
|
| 107 |
-
|
| 108 |
codebook_loss = torch.mean((quant_out - x.detach()) ** 2)
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
}
|
| 113 |
-
|
| 114 |
# Straight through estimation
|
| 115 |
-
quant_out = x
|
| 116 |
-
|
| 117 |
-
# quant_out -> B,C,H,W
|
| 118 |
quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2)
|
| 119 |
-
min_encoding_indices = min_encoding_indices.reshape(
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
return quant_out, quantize_loss, min_encoding_indices
|
| 123 |
|
| 124 |
def encode(self, x):
|
| 125 |
-
out = self.
|
| 126 |
-
for
|
| 127 |
out = down(out)
|
| 128 |
-
for mid in self.
|
| 129 |
out = mid(out)
|
| 130 |
out = self.encoder_norm_out(out)
|
| 131 |
out = nn.SiLU()(out)
|
|
@@ -133,21 +136,21 @@ class VQVAE(nn.Module):
|
|
| 133 |
out = self.pre_quant_conv(out)
|
| 134 |
out, quant_losses, _ = self.quantize(out)
|
| 135 |
return out, quant_losses
|
| 136 |
-
|
| 137 |
def decode(self, z):
|
| 138 |
out = z
|
| 139 |
out = self.post_quant_conv(out)
|
| 140 |
out = self.decoder_conv_in(out)
|
| 141 |
-
for mid in self.
|
| 142 |
out = mid(out)
|
| 143 |
-
for up in self.decoder_layers:
|
| 144 |
out = up(out)
|
| 145 |
-
|
| 146 |
out = self.decoder_norm_out(out)
|
| 147 |
-
out = nn.SiLU(out)
|
| 148 |
out = self.decoder_conv_out(out)
|
| 149 |
return out
|
| 150 |
-
|
| 151 |
def forward(self, x):
|
| 152 |
z, quant_losses = self.encode(x)
|
| 153 |
out = self.decode(z)
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
+
from models.blocks import DownBlock, MidBlock, UpBlock
|
| 4 |
|
| 5 |
|
| 6 |
class VQVAE(nn.Module):
|
|
|
|
| 10 |
self.mid_channels = model_config['mid_channels']
|
| 11 |
self.down_sample = model_config['down_sample']
|
| 12 |
self.num_down_layers = model_config['num_down_layers']
|
|
|
|
| 13 |
self.num_mid_layers = model_config['num_mid_layers']
|
| 14 |
+
self.num_up_layers = model_config['num_up_layers']
|
| 15 |
+
|
| 16 |
+
# To disable attention in Downblock of Encoder and Upblock of Decoder
|
| 17 |
+
self.attns = model_config['attn_down']
|
| 18 |
+
|
| 19 |
# Latent Dimension
|
| 20 |
+
self.z_channels = model_config['z_channels']
|
| 21 |
+
self.codebook_size = model_config['codebook_size']
|
| 22 |
+
self.norm_channels = model_config['norm_channels']
|
| 23 |
+
self.num_heads = model_config['num_heads']
|
| 24 |
+
|
| 25 |
+
# Assertion to validate the channel information
|
| 26 |
assert self.mid_channels[0] == self.down_channels[-1]
|
| 27 |
assert self.mid_channels[-1] == self.down_channels[-1]
|
| 28 |
assert len(self.down_sample) == len(self.down_channels) - 1
|
| 29 |
assert len(self.attns) == len(self.down_channels) - 1
|
| 30 |
+
|
| 31 |
+
# Wherever we use downsampling in encoder correspondingly use
|
| 32 |
+
# upsampling in decoder
|
| 33 |
+
self.up_sample = list(reversed(self.down_sample))
|
| 34 |
+
|
| 35 |
+
##################### Encoder ######################
|
| 36 |
+
self.encoder_conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1))
|
| 37 |
+
|
| 38 |
+
# Downblock + Midblock
|
| 39 |
self.encoder_layers = nn.ModuleList([])
|
| 40 |
for i in range(len(self.down_channels) - 1):
|
| 41 |
+
self.encoder_layers.append(DownBlock(self.down_channels[i], self.down_channels[i + 1],
|
| 42 |
+
t_emb_dim=None, down_sample=self.down_sample[i],
|
| 43 |
+
num_heads=self.num_heads,
|
| 44 |
+
num_layers=self.num_down_layers,
|
| 45 |
+
attn=self.attns[i],
|
| 46 |
+
norm_channels=self.norm_channels))
|
| 47 |
+
|
| 48 |
+
self.encoder_mids = nn.ModuleList([])
|
| 49 |
+
for i in range(len(self.mid_channels) - 1):
|
| 50 |
+
self.encoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1],
|
| 51 |
+
t_emb_dim=None,
|
| 52 |
+
num_heads=self.num_heads,
|
| 53 |
+
num_layers=self.num_mid_layers,
|
| 54 |
+
norm_channels=self.norm_channels))
|
| 55 |
+
|
| 56 |
+
self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1])
|
| 57 |
+
self.encoder_conv_out = nn.Conv2d(self.down_channels[-1], self.z_channels, kernel_size=3, padding=1)
|
| 58 |
+
|
| 59 |
+
# Pre Quantization Convolution
|
| 60 |
+
self.pre_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
|
| 61 |
+
|
| 62 |
+
# Codebook
|
| 63 |
self.embedding = nn.Embedding(self.codebook_size, self.z_channels)
|
| 64 |
+
|
| 65 |
+
##################### Decoder ######################
|
| 66 |
+
|
| 67 |
+
# Post Quantization Convolution
|
| 68 |
+
self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
|
| 69 |
+
self.decoder_conv_in = nn.Conv2d(self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1))
|
| 70 |
+
|
| 71 |
+
# Midblock + Upblock
|
| 72 |
+
self.decoder_mids = nn.ModuleList([])
|
| 73 |
for i in reversed(range(1, len(self.mid_channels))):
|
| 74 |
+
self.decoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i - 1],
|
| 75 |
+
t_emb_dim=None,
|
| 76 |
+
num_heads=self.num_heads,
|
| 77 |
+
num_layers=self.num_mid_layers,
|
| 78 |
+
norm_channels=self.norm_channels))
|
| 79 |
+
|
| 80 |
self.decoder_layers = nn.ModuleList([])
|
| 81 |
for i in reversed(range(1, len(self.down_channels))):
|
| 82 |
+
self.decoder_layers.append(UpBlock(self.down_channels[i], self.down_channels[i - 1],
|
| 83 |
+
t_emb_dim=None, up_sample=self.down_sample[i - 1],
|
| 84 |
+
num_heads=self.num_heads,
|
| 85 |
+
num_layers=self.num_up_layers,
|
| 86 |
+
attn=self.attns[i-1],
|
| 87 |
+
norm_channels=self.norm_channels))
|
| 88 |
+
|
| 89 |
+
self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0])
|
| 90 |
+
self.decoder_conv_out = nn.Conv2d(self.down_channels[0], im_channels, kernel_size=3, padding=1)
|
| 91 |
+
|
|
|
|
| 92 |
def quantize(self, x):
    """Replace every spatial vector of ``x`` with its nearest codebook entry.

    Args:
        x: 4-D tensor unpacked as (B, C, H, W); the C axis is compared
            against the rows of ``self.embedding.weight``.

    Returns:
        A tuple ``(quant_out, quantize_losses, min_encoding_indices)``:
        ``quant_out`` is the quantized tensor reshaped back to (B, C, H, W),
        ``quantize_losses`` is a dict with the keys ``'codebook_loss'`` and
        ``'commitment_loss'``, and ``min_encoding_indices`` holds the chosen
        codebook row of every position, reshaped to (B, H, W).
    """
    B, C, H, W = x.shape

    # B, C, H, W -> B, H, W, C
    x = x.permute(0, 2, 3, 1)

    # B, H, W, C -> B, H*W, C
    x = x.reshape(x.size(0), -1, x.size(-1))

    # Find nearest embedding/codebook vector
    # dist between (B, H*W, C) and (B, K, C) -> (B, H*W, K)
    # expand() gives the batched view of the codebook without first
    # materialising B copies of it the way repeat() does.
    codebook = self.embedding.weight[None, :].expand(x.size(0), -1, -1)
    dist = torch.cdist(x, codebook)
    # (B, H*W)
    min_encoding_indices = torch.argmin(dist, dim=-1)

    # Replace encoder output with nearest codebook
    # quant_out -> B*H*W, C
    quant_out = torch.index_select(self.embedding.weight, 0, min_encoding_indices.view(-1))

    # x -> B*H*W, C
    x = x.reshape((-1, x.size(-1)))
    # Each loss detaches one side so that it only updates the other one.
    commitment_loss = torch.mean((quant_out.detach() - x) ** 2)
    codebook_loss = torch.mean((quant_out - x.detach()) ** 2)
    quantize_losses = {
        'codebook_loss': codebook_loss,
        'commitment_loss': commitment_loss
    }

    # Straight through estimation: the forward value is the codebook vector,
    # while the gradient flows to x as if quantization were the identity.
    quant_out = x + (quant_out - x).detach()

    # quant_out -> B, C, H, W
    quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2)
    min_encoding_indices = min_encoding_indices.reshape((-1, H, W))
    return quant_out, quantize_losses, min_encoding_indices
|
|
|
|
|
|
|
| 126 |
|
| 127 |
def encode(self, x):
|
| 128 |
+
out = self.encoder_conv_in(x)
|
| 129 |
+
for idx, down in enumerate(self.encoder_layers):
|
| 130 |
out = down(out)
|
| 131 |
+
for mid in self.encoder_mids:
|
| 132 |
out = mid(out)
|
| 133 |
out = self.encoder_norm_out(out)
|
| 134 |
out = nn.SiLU()(out)
|
|
|
|
| 136 |
out = self.pre_quant_conv(out)
|
| 137 |
out, quant_losses, _ = self.quantize(out)
|
| 138 |
return out, quant_losses
|
| 139 |
+
|
| 140 |
def decode(self, z):
    """Map a latent tensor ``z`` through the decoder stack.

    Applies, in order: ``post_quant_conv``, ``decoder_conv_in``, every
    module of ``decoder_mids``, every module of ``decoder_layers``,
    ``decoder_norm_out``, a SiLU activation and ``decoder_conv_out``.

    Args:
        z: latent tensor accepted by ``self.post_quant_conv``.

    Returns:
        The output of ``self.decoder_conv_out``.
    """
    out = self.post_quant_conv(z)
    out = self.decoder_conv_in(out)
    for mid in self.decoder_mids:
        out = mid(out)
    for up in self.decoder_layers:
        out = up(out)

    out = self.decoder_norm_out(out)
    # Functional form avoids constructing a new nn.SiLU module on every call.
    out = nn.functional.silu(out)
    out = self.decoder_conv_out(out)
    return out
|
| 153 |
+
|
| 154 |
def forward(self, x):
|
| 155 |
z, quant_losses = self.encode(x)
|
| 156 |
out = self.decode(z)
|
train_vqvae.py
CHANGED
|
@@ -24,7 +24,6 @@ def train(args):
|
|
| 24 |
except yaml.YAMLError as e:
|
| 25 |
print(e)
|
| 26 |
|
| 27 |
-
|
| 28 |
autoencoder_config = config["autoencoder_params"]
|
| 29 |
train_config = config["train_config"]
|
| 30 |
dataset_config = config["dataset_config"]
|
|
@@ -84,11 +83,11 @@ def train(args):
|
|
| 84 |
|
| 85 |
# Image saving
|
| 86 |
if steps % img_save_steps == 0 or steps == 1:
|
| 87 |
-
sample_size = min(8,
|
| 88 |
save_output = torch.clamp(
|
| 89 |
output[:sample_size], -1., 1.).detach().cpu()
|
| 90 |
save_output = ((save_output + 1) / 2)
|
| 91 |
-
save_input = ((
|
| 92 |
|
| 93 |
grid = make_grid(
|
| 94 |
torch.cat([save_input, save_output], dim=0), nrow=sample_size)
|
|
@@ -97,8 +96,8 @@ def train(args):
|
|
| 97 |
os.mkdir(os.path.join(
|
| 98 |
train_config['task_name'], 'vqvae_autoencoder_samples'))
|
| 99 |
img.save(os.path.join(train_config['task_name'], 'vqvae_autoencoder_samples',
|
| 100 |
-
'current_autoencoder_sample_{}.png'.format(
|
| 101 |
-
|
| 102 |
img.close()
|
| 103 |
|
| 104 |
# Optimizing generator
|
|
|
|
| 24 |
except yaml.YAMLError as e:
|
| 25 |
print(e)
|
| 26 |
|
|
|
|
| 27 |
autoencoder_config = config["autoencoder_params"]
|
| 28 |
train_config = config["train_config"]
|
| 29 |
dataset_config = config["dataset_config"]
|
|
|
|
| 83 |
|
| 84 |
# Image saving
|
| 85 |
if steps % img_save_steps == 0 or steps == 1:
|
| 86 |
+
sample_size = min(8, im_tensor.shape[0])
|
| 87 |
save_output = torch.clamp(
|
| 88 |
output[:sample_size], -1., 1.).detach().cpu()
|
| 89 |
save_output = ((save_output + 1) / 2)
|
| 90 |
+
save_input = ((im_tensor[:sample_size] + 1) / 2).detach().cpu()
|
| 91 |
|
| 92 |
grid = make_grid(
|
| 93 |
torch.cat([save_input, save_output], dim=0), nrow=sample_size)
|
|
|
|
| 96 |
os.mkdir(os.path.join(
|
| 97 |
train_config['task_name'], 'vqvae_autoencoder_samples'))
|
| 98 |
img.save(os.path.join(train_config['task_name'], 'vqvae_autoencoder_samples',
|
| 99 |
+
'current_autoencoder_sample_{}.png'.format(img_saved)))
|
| 100 |
+
img_saved += 1
|
| 101 |
img.close()
|
| 102 |
|
| 103 |
# Optimizing generator
|