GreyCC99127 committed on
Commit
2b316b4
·
verified ·
1 Parent(s): 8c0d676

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. DTGM_model_167500.pt +3 -0
  2. eval.py +178 -0
  3. gdtls.py +505 -0
  4. models.py +273 -0
DTGM_model_167500.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ac52dc0c74c44bac9506bec31e9d94cadd40dd44921dea65401bb79d4b3af308
3
+ size 1250282226
eval.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch.cuda
3
+ from util.models import *
4
+ from gdtls import DTLS, Trainer
5
+ import argparse
6
+
7
+ parser = argparse.ArgumentParser()
8
+ parser.add_argument('--device', default="cuda:0", type=str)
9
+ parser.add_argument('--hr_size', default=256, type=int, help="size of HR image")
10
+ parser.add_argument('--lr_size', default=2, type=int, help="size of LR image")
11
+ parser.add_argument('--interval_mode', default="fibonacci", type=str, help="linear; exp; fibonacci")
12
+ parser.add_argument('--stride', default=2, type=int, help="size change between each step if linear mode is used")
13
+ parser.add_argument('--train_steps', default=200001, type=int)
14
+ parser.add_argument('--lr_rate', default=2e-5, help="learning rate")
15
+ parser.add_argument('--sample_every_iterations', default=5000, type=int, help="sample SR images for every number of iterations")
16
+ parser.add_argument('--save_folder', default="DTGM_segnoise_4b_50k", type=str, help="Folder to save your train or evaluation result")
17
+ parser.add_argument('--load_path', default="DTGM_segnoise_4b/DTGM_model_165000.pt", type=str, help="None or directory to pretrained model")
18
+ parser.add_argument('--data_path', default='/hdda/Datasets/Face_super_resolution/images1024x1024/', type=str, help="directory to your training dataset")
19
+
20
+ parser.add_argument('--batch_size', default=1, type=int)
21
+
22
+ args = parser.parse_args()
23
+ device = args.device if torch.cuda.is_available() else "cpu"
24
+ size_list = [256, 64, 32, 16, 8, 4, 3, 2]
25
+ timestep = len(size_list) - 1
26
+
27
+ print(f"Total steps for {args.lr_size} to {args.hr_size}: {timestep}")
28
+
29
+ model = UNet().to(device)
30
+ discriminator = Discriminator().to(device)
31
+
32
+ dtls = DTLS(
33
+ model,
34
+ image_size = args.hr_size,
35
+ stride = args.stride,
36
+ size_list=size_list,
37
+ timesteps = timestep, # number of steps
38
+ device=device,
39
+ ).to(device)
40
+
41
+
42
+ trainer = Trainer(
43
+ dtls,
44
+ discriminator,
45
+ args.data_path,
46
+ image_size = args.hr_size,
47
+ train_batch_size = args.batch_size,
48
+ train_num_steps = args.train_steps, # total training steps
49
+ ema_decay = 0.995, # exponential moving average decay
50
+ results_folder = args.save_folder,
51
+ load_path = args.load_path,
52
+ device = device,
53
+ eval_mode=True,
54
+ save_and_sample_every = args.sample_every_iterations
55
+ )
56
+
57
+ if __name__ == "__main__":
58
+ trainer.evaluation()
59
+ trainer.fid(created_dataset=args.save_folder)
60
+
61
+
62
+
63
+
64
+ # import copy
65
+ # from pathlib import Path
66
+ # import torch
67
+ # from util.models import *
68
+ # import argparse
69
+ # import random
70
+ # import torch.nn.functional as F
71
+ # from torchvision import utils
72
+ # import os
73
+ # import errno
74
+ #
75
+ # parser = argparse.ArgumentParser()
76
+ # parser.add_argument('--device', default="cuda:1", type=str)
77
+ # parser.add_argument('--hr_size', default=256, type=int, help="size of HR image")
78
+ # parser.add_argument('--lr_size', default=2, type=int, help="size of LR image")
79
+ # parser.add_argument('--num_sample', default=50000, type=str, help="Number of image to generate")
80
+ # parser.add_argument('--save_folder', default="DTGM_Xeii_lpips_FM_ii_80kpt_50k", type=str, help="Folder to save your train or evaluation result")
81
+ # parser.add_argument('--load_path', default="DTGM_Xeii_lpips_FM_ii/GDTLS_80000.pt", type=str, help="None or directory to pretrained model")
82
+ #
83
+ # def create_folder(path):
84
+ # try:
85
+ # os.mkdir(path)
86
+ # except OSError as exc:
87
+ # if exc.errno != errno.EEXIST:
88
+ # raise
89
+ # pass
90
+ #
91
+ # def transform_func_sample(img, target_size):
92
+ # n = target_size
93
+ # m = args.hr_size
94
+ #
95
+ # if m / n > 16:
96
+ # img_1 = F.interpolate(img, size=m // 4, mode='bicubic', antialias=True)
97
+ # img_1 = F.interpolate(img_1, size=m // 8, mode='bicubic', antialias=True)
98
+ # img_1 = F.interpolate(img_1, size=n, mode='bicubic', antialias=True)
99
+ # else:
100
+ # img_1 = F.interpolate(img, size=n, mode='bicubic', antialias=True)
101
+ # img_1 = F.interpolate(img_1, size=m, mode='bicubic', antialias=True)
102
+ #
103
+ # return img_1
104
+ #
105
+ # def transform_func_noise(img, device_, target_size, fixed_std=True):
106
+ # n = target_size
107
+ # m = args.hr_size
108
+ #
109
+ # random_mean = torch.rand(1).add(-0.5).item()
110
+ # if fixed_std:
111
+ # random_std = 0.5
112
+ # else:
113
+ # random_std = torch.rand(1).mul(0.5).item()
114
+ # decreasing_scale = 0.9 ** (n - 2)
115
+ #
116
+ # if m / n > 16:
117
+ # img_1 = F.interpolate(img, size=m // 4, mode='bicubic', antialias=True)
118
+ # img_1 = F.interpolate(img_1, size=m // 8, mode='bicubic', antialias=True)
119
+ # img_1 = F.interpolate(img_1, size=n, mode='bicubic', antialias=True)
120
+ # else:
121
+ # img_1 = F.interpolate(img, size=n, mode='bicubic', antialias=True)
122
+ #
123
+ # # noise = torch.normal(mean=random_mean, std=random_std, size=(img_1.shape[0], 3, 2, 2)).to(self.device)
124
+ # # noise = F.interpolate(noise, size=n, mode='bicubic', antialias=True)
125
+ # noise = torch.normal(mean=random_mean, std=random_std, size=img_1.shape).to(device_)
126
+ #
127
+ # img_1 += noise * decreasing_scale
128
+ # img_1 = F.interpolate(img_1, size=m, mode='bicubic', antialias=True)
129
+ #
130
+ # noise_refinement = torch.normal(mean=0, std=0.05, size=img_1.shape).to(device_)
131
+ # return img_1 + noise_refinement
132
+ #
133
+ # def random_vector(batch_size):
134
+ # mean = random.uniform(-0.5, 0.5)
135
+ # std = random.uniform(0.1, 0.3)
136
+ # vector = torch.normal(mean=mean, std=std, size=(batch_size, 1, 2, 2))
137
+ # for colors in range(2):
138
+ # mean = random.uniform(-0.5, 0.5)
139
+ # std = random.uniform(0.1, 0.3)
140
+ # rgb = torch.normal(mean=mean, std=std, size=(batch_size, 1, 2, 2))
141
+ # vector = torch.cat((vector, rgb), dim=1)
142
+ # return vector
143
+ #
144
+ # def sample(size_list_, device_, model_, batch_size=1, img=None, t=None):
145
+ # blur_img = transform_func_sample(img.clone(), size_list_[t])
146
+ # img_t = blur_img.clone()
147
+ #
148
+ # ####### Domain Transfer
149
+ # while t:
150
+ # next_step = size_list_[t - 1]
151
+ # step = torch.full((batch_size,), t, dtype=torch.long).to(device_)
152
+ # R_x = model_(img_t, step)
153
+ # if t == 1:
154
+ # return R_x
155
+ # else:
156
+ # img_t = transform_func_noise(R_x, device_, next_step, fixed_std=True)
157
+ # t -= 1
158
+ # return img_t
159
+ #
160
+ # if __name__ == "__main__":
161
+ # args = parser.parse_args()
162
+ # results_folder = Path(args.save_folder)
163
+ # results_folder.mkdir(exist_ok=True)
164
+ # device = args.device if torch.cuda.is_available() else "cpu"
165
+ # size_list = [256, 64, 32, 16, 8, 6, 4, 3, 2]
166
+ # timestep = len(size_list) - 1
167
+ # print(f"Total steps for {args.lr_size} to {args.hr_size}: {timestep}")
168
+ # model = UNet().to(device)
169
+ #
170
+ # if args.load_path is not None:
171
+ # data = torch.load(args.load_path, map_location=device)
172
+ # model.load_state_dict(data['ema'], strict=False)
173
+ #
174
+ # for i in range(args.num_sample):
175
+ # input_vector = random_vector(1).to(device)
176
+ # sample_hr = sample(size_list, device, model, batch_size=1, img=input_vector, t=timestep)
177
+ # utils.save_image(sample_hr.add(1).mul(0.5), f"{results_folder}/result_{i}.png", nrow=1)
178
+ # print("saving ", i)
gdtls.py ADDED
@@ -0,0 +1,505 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ import os
5
+ import errno
6
+ import torch
7
+ import shutil
8
+ import wandb
9
+ import random
10
+ import time
11
+ import lpips
12
+
13
+ from torch import nn
14
+ from torch.utils import data
15
+ from pathlib import Path
16
+ from torch.optim import Adam, AdamW
17
+ from torchvision import transforms, utils
18
+ from torchvision.transforms import InterpolationMode
19
+ from torchvision.transforms.v2 import RandomResize
20
+ from PIL import Image
21
+ from util.fid_score import calculate_fid_given_paths
22
+
23
# Optional NVIDIA Apex support for mixed-precision training.
try:
    from apex import amp
    APEX_AVAILABLE = True
except ImportError:
    # A bare `except:` would also swallow SystemExit/KeyboardInterrupt;
    # only a missing apex install should disable fp16 support.
    APEX_AVAILABLE = False
28
+
29
+ ####### helpers functions
30
+
31
def create_folder(path):
    """Create directory `path`; a pre-existing path is silently accepted."""
    try:
        os.mkdir(path)
    except FileExistsError:
        # Same as the EEXIST check: only "already exists" is ignored,
        # every other OSError still propagates.
        pass
38
+
39
def del_folder(path):
    """Recursively delete `path`; any filesystem error is swallowed."""
    try:
        shutil.rmtree(path)
    except OSError:
        pass
44
+
45
def cycle(dl):
    """Yield items from `dl` forever, restarting it whenever it is exhausted."""
    while True:
        yield from dl
49
+
50
def num_to_groups(num, divisor):
    """Split `num` into chunk sizes of `divisor`, e.g. (10, 4) -> [4, 4, 2]."""
    full_groups, leftover = divmod(num, divisor)
    sizes = [divisor] * full_groups
    if leftover:
        sizes.append(leftover)
    return sizes
57
+
58
def loss_backwards(fp16, loss, optimizer, **kwargs):
    """Backpropagate `loss`, going through apex's loss scaling when fp16 is on."""
    if not fp16:
        loss.backward(**kwargs)
        return
    # Mixed precision: scale the loss so small gradients survive fp16.
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward(**kwargs)
64
+
65
+ # small helper modules
66
def rand_bbox(size, lam):
    """Sample a random CutMix-style box for a 4-D tensor shape.

    size: shape tuple whose indices 2 and 3 are the spatial extents;
    lam: mix ratio in [0, 1] (box area scales with 1 - lam).
    Returns (bbx1, bby1, bbx2, bby2), clipped to the image bounds.
    """
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    # np.int was deprecated in NumPy 1.20 and removed in 1.24; the builtin
    # int truncates identically here.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2
83
+
84
+
85
def Huber(input, target, delta=0.1, reduce=True):
    """Elementwise Huber loss with threshold `delta`.

    Returns the mean over all elements when `reduce` is True, otherwise the
    unreduced per-element losses.
    """
    abs_error = torch.abs(input - target)
    # Quadratic region is capped at delta; anything beyond is charged linearly.
    quadratic = torch.clamp(abs_error, max=delta)
    # Writing the linear term as (abs_error - quadratic) keeps this term's
    # gradient 0 at abs_error == delta (max(abs_error - delta, 0) would give 1),
    # so the gradient is not doubled where the quadratic term already
    # contributes.
    linear = abs_error - quadratic
    losses = 0.5 * quadratic ** 2 + delta * linear
    return torch.mean(losses) if reduce else losses
101
+
102
class EMA():
    """Exponential moving average helper for model weights."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta  # decay: fraction of the old average kept per update

    def update_model_average(self, ma_model, current_model):
        """Blend `current_model`'s parameters into `ma_model` in place."""
        pairs = zip(current_model.parameters(), ma_model.parameters())
        for live, averaged in pairs:
            averaged.data = self.update_average(averaged.data, live.data)

    def update_average(self, old, new):
        """One EMA step: beta * old + (1 - beta) * new (`new` when uninitialised)."""
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new
116
+
117
+
118
class DTLS(nn.Module):
    """Domain Transfer in Latent Space generator wrapper.

    Each "timestep" t corresponds to a resolution from `size_list` (largest
    first). Training degrades an HR image to the resolution of a random step
    (plus noise) and asks the UNet to reverse one step; sampling walks the
    whole chain from the smallest resolution back up to `image_size`.
    """

    def __init__(
        self,
        model,                 # the UNet that predicts the next-resolution image
        *,
        image_size,            # HR working resolution (e.g. 256)
        size_list,             # resolutions per step, largest first (e.g. [256, ..., 2])
        stride,                # step between sizes for a linear schedule (stored, unused here)
        timesteps,             # number of transfer steps == len(size_list) - 1
        device,
        stochastic=False,      # NOTE(review): accepted but never stored or used
    ):
        super().__init__()
        self.image_size = image_size
        self.UNet = model

        self.num_timesteps = int(timesteps)
        self.size_list = size_list
        self.stride = stride
        self.device = device
        self.MSE_loss = nn.MSELoss()
        # self.vgg_loss = Vgg19()
        # LPIPS metric is constructed but not used in the loss below.
        self.lpips_loss = lpips.LPIPS(net='alex')

    def transform_func_loss(self, img, target_size):
        """Downsample `img` to `target_size` (staged when the jump is > 16x)."""
        n = target_size
        m = self.image_size

        if m/n > 16:
            # Very large jumps are done in stages to reduce aliasing.
            img_1 = F.interpolate(img, size=m//4, mode='bicubic', antialias=True)
            img_1 = F.interpolate(img_1, size=m//8, mode='bicubic', antialias=True)
            img_1 = F.interpolate(img_1, size=n, mode='bicubic', antialias=True)
        else:
            img_1 = F.interpolate(img, size=n, mode='bicubic', antialias=True)

        return img_1

    def transform_func_sample(self, img, target_size):
        """Downsample to `target_size`, then upsample back to `image_size`.

        Produces a blurred image that lives at the HR resolution but only
        carries `target_size` worth of detail.
        """
        n = target_size
        m = self.image_size

        if m/n > 16:
            img_1 = F.interpolate(img, size=m//4, mode='bicubic', antialias=True)
            img_1 = F.interpolate(img_1, size=m//8, mode='bicubic', antialias=True)
            img_1 = F.interpolate(img_1, size=n, mode='bicubic', antialias=True)
        else:
            img_1 = F.interpolate(img, size=n, mode='bicubic', antialias=True)
        img_1 = F.interpolate(img_1, size=m, mode='bicubic', antialias=True)

        return img_1

    def transform_func_noise(self, img, target_size, std_eval = False):
        """Like transform_func_sample but injects random low-frequency noise.

        The perturbation strength decays as 0.9**(n-2), so coarser steps
        (small n) receive stronger noise. `std_eval` is currently unused.
        """
        n = target_size
        m = self.image_size

        # Random noise mean in [-0.5, 0.5); std is fixed at 0.5 below.
        random_mean = torch.rand(1).add(-.5).item()
        # random_std = torch.rand(1).mul(0.5).item()
        decreasing_scale = 0.9 ** (n - 2)

        if m / n > 16:
            img_1 = F.interpolate(img, size=m // 4, mode='bicubic', antialias=True)
            img_1 = F.interpolate(img_1, size=m // 8, mode='bicubic', antialias=True)
            img_1 = F.interpolate(img_1, size=n, mode='bicubic', antialias=True)
        else:
            img_1 = F.interpolate(img, size=n, mode='bicubic', antialias=True)

        # Low-frequency noise: sampled at 2x2 then smoothed up to n x n.
        noise = torch.normal(mean=random_mean, std=0.5, size=(img_1.shape[0], 3, 2, 2)).to(self.device)
        noise = F.interpolate(noise, size=n, mode='bicubic', antialias=True)
        img_1 += noise * decreasing_scale
        img_1 = F.interpolate(img_1, size=m, mode='bicubic', antialias=True)

        if n >= 16:
            # Extra per-pixel refinement noise at finer resolutions.
            noise_refinement = torch.normal(mean=0, std=1, size=img_1.shape).to(self.device)
            img_1 = img_1 + noise_refinement * decreasing_scale
        return img_1


    @torch.no_grad()
    def sample(self, batch_size=16, img=None, t=None, save_folder=None):
        """Walk the transfer chain from step `t` down to full resolution.

        Returns (blur_img, result): the initial degraded image and the final
        UNet output. `save_folder` is accepted but unused.
        """
        if t == None:  # NOTE(review): `t is None` is the idiomatic test
            t = self.num_timesteps
        blur_img = self.transform_func_sample(img.clone(), self.size_list[t])
        img_t = blur_img.clone()
        ####### Domain Transfer
        while (t):
            next_step = self.size_list[t-1]
            step = torch.full((batch_size,), t, dtype=torch.long).to(self.device)
            R_x = self.UNet(img_t, step)
            if t == 1:
                # Last step: return the raw prediction without re-noising.
                return blur_img, R_x
            else:
                img_t = self.transform_func_noise(R_x, next_step)
            t -= 1
        return blur_img, img_t


    def p_losses(self, x_start, t):
        """Single training step loss for a batch with per-sample steps `t`.

        The loss compares the re-degraded reconstruction against the degraded
        input (a consistency loss in the low-resolution domain), not the
        reconstruction against x_start directly.
        """
        x_blur = x_start.clone()

        # Degrade each sample to the resolution of its own step t[i].
        for i in range(t.shape[0]):
            current_step = self.size_list[t[i]]
            x_blur[i] = self.transform_func_noise(x_blur[i].unsqueeze(0), current_step)

        x_recon = self.UNet(x_blur, t)

        ### Pattern Domain Similarity Loss
        x_clone = x_recon.clone()
        for i in range(t.shape[0]):
            current_step = self.size_list[t[i]]
            x_clone[i] = self.transform_func_sample(x_recon[i].unsqueeze(0), current_step)

        ### Lowest Pattern Domain Similarity Loss
        # x_clone = x_recon.clone()
        # x_clone = self.transform_func_loss(x_clone, self.size_list[-1])
        # x_blur = self.transform_func_loss(x_blur, self.size_list[-1])

        loss = self.MSE_loss(x_clone, x_blur)
        # lpips_loss = self.lpips_loss(x_recon, x_start).mean()
        return loss, x_recon

    def forward(self, x, *args, **kwargs):
        """Sample a random step per image and return (loss, reconstruction)."""
        b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
        assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
        # Steps are drawn from [1, num_timesteps]; 0 is the clean HR domain.
        t = torch.randint(1, self.num_timesteps + 1, (b,), device=device).long()
        return self.p_losses(x, t, *args, **kwargs)
243
+
244
+ # dataset classes
245
+
246
class Dataset(data.Dataset):
    """Recursively collects images under `folder` and serves augmented
    3-channel tensors scaled to [-1, 1]."""

    # Tuple default instead of a mutable list default argument.
    def __init__(self, folder, image_size, exts=('jpg', 'jpeg', 'png')):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        # Recursive glob so nested dataset layouts (e.g. sharded FFHQ) are found.
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

        self.transform = transforms.Compose([
            # Random scale jitter in [size, 1.2*size] before cropping back down.
            RandomResize(int(image_size), int(image_size*1.2), interpolation=InterpolationMode.BICUBIC, antialias=True),
            transforms.RandomCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Lambda(lambda t: (t * 2) - 1)  # [0, 1] -> [-1, 1]
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        # Force RGB: grayscale or RGBA files would otherwise produce tensors
        # with the wrong channel count for the 3-channel model.
        img = Image.open(path).convert('RGB')
        return self.transform(img)
268
+
269
+ # trainer class
270
+
271
class Trainer(object):
    """Adversarial trainer for the DTLS generator plus a discriminator.

    Handles the data pipeline, EMA shadow model, checkpointing, wandb logging,
    periodic sampling, and FID-based validation/evaluation.
    """

    def __init__(
        self,
        diffusion_model,                 # DTLS instance (generator wrapper)
        discriminator,
        folder,                          # root directory of the training images
        *,
        ema_decay = 0.9925,
        image_size = 128,                # NOTE(review): passed to Dataset; may differ from diffusion_model.image_size — confirm callers keep them equal
        train_batch_size = 32,
        train_num_steps = 200000,
        step_start_ema = 500,            # steps before EMA starts tracking (hard copy until then)
        update_ema_every = 10,
        save_and_sample_every = 1000,
        results_folder,
        load_path = None,                # checkpoint to resume from (None = fresh run)
        shuffle=True,
        eval_mode=False,                 # True skips wandb initialisation
        device,
    ):
        super().__init__()

        ########## Wandb ##########
        if not eval_mode:
            wandb.init(project="DTGM", notes=str(results_folder), name=results_folder)
        self.results_folder = Path(results_folder)
        self.results_folder.mkdir(exist_ok = True)

        self.device = device

        self.model = diffusion_model
        self.discriminator = discriminator
        self.model_size()

        self.ema = EMA(ema_decay)
        self.ema_model = copy.deepcopy(self.model)

        self.update_ema_every = update_ema_every
        self.step_start_ema = step_start_ema

        self.save_and_sample_every = save_and_sample_every

        self.image_size = diffusion_model.image_size
        self.batch_size = train_batch_size
        self.train_num_steps = train_num_steps
        self.nrow = train_batch_size // 2

        self.folder_path = folder
        self.ds = Dataset(folder, image_size)

        # Infinite iterator over the dataset (see `cycle`).
        self.dl = cycle(data.DataLoader(self.ds, batch_size = train_batch_size, shuffle=shuffle, pin_memory=True, num_workers=2))

        # Zero first-moment betas: common for GAN training stability.
        self.opt = AdamW(diffusion_model.parameters(), lr=2e-5, betas=(0.0, 0.9), eps=1e-8)
        self.opt_d = AdamW(self.discriminator.parameters(), lr=5e-5, betas=(0.0, 0.9), eps=1e-8)

        self.BCE_loss = torch.nn.BCEWithLogitsLoss()

        self.step = 0
        self.reset_parameters()
        self.best_quality = 0

        self.loss_dis_false_temp = 0
        self.loss_dis_true_temp = 0

        self.load_path = load_path
        self.n_mix = 0
        self.fid_list = []
        # fid_list is empty here, so this effectively truncates/creates fid.txt.
        with open(f'{self.results_folder}/fid.txt', 'w') as f:
            for a in self.fid_list:
                f.write(f'{self.step} {a}\n')

    def reset_parameters(self):
        """Hard-copy the live model weights into the EMA shadow model."""
        self.ema_model.load_state_dict(self.model.state_dict())

    def step_ema(self):
        """Update the EMA model (plain copy until step_start_ema is reached)."""
        if self.step < self.step_start_ema:
            self.reset_parameters()
            return
        self.ema.update_model_average(self.ema_model, self.model)

    def model_size(self):
        """Print the generator's parameter + buffer footprint in MB."""
        param_size = 0
        for param in self.model.parameters():
            param_size += param.nelement() * param.element_size()
        buffer_size = 0
        for buffer in self.model.buffers():
            buffer_size += buffer.nelement() * buffer.element_size()

        size_all_mb = (param_size + buffer_size) / 1024**2
        line = ('model size: {:.3f}MB'.format(size_all_mb))
        print(line)

    def save_ckpt(self):
        """Save a full resume checkpoint (step, model, EMA, discriminator)."""
        data = {
            'step': self.step,
            'model': self.model.state_dict(),
            'ema': self.ema_model.state_dict(),
            'dis': self.discriminator.state_dict(),
        }
        torch.save(data, str(self.results_folder / f'DTGM_ckpt_{self.step}.pt'))

    def save_model(self):
        """Save only the EMA weights (inference/evaluation artifact)."""
        data = {
            'ema': self.ema_model.state_dict(),
        }
        torch.save(data, str(self.results_folder / f'DTGM_model_{self.step}.pt'))


    def load_all(self, load_path):
        """Restore a full training checkpoint written by save_ckpt."""
        print("Loading : ", load_path)
        data = torch.load(load_path, map_location=self.device)

        self.step = data['step']
        self.model.load_state_dict(data['model'], strict=False)
        self.ema_model.load_state_dict(data['ema'], strict=False)
        self.discriminator.load_state_dict(data['dis'], strict=False)

    def load_for_eval(self, load_path):
        """Load only the EMA weights (matches the save_model format)."""
        data = torch.load(load_path, map_location=self.device)
        self.ema_model.load_state_dict(data['ema'], strict=False)

    def train(self):
        """Main adversarial training loop (generator + discriminator)."""
        if self.load_path is not None:
            self.load_all(self.load_path)

        while self.step < self.train_num_steps:
            start_time = time.time()  # NOTE(review): never read afterwards
            data = next(self.dl)
            data = data.to(self.device)

            # Generator forward: domain-similarity loss and reconstruction.
            loss_domain_sim, x_recon = self.model(data)

            # --- Discriminator update: real -> 1, fake (detached) -> 0 ---
            self.opt_d.zero_grad()
            score_true = self.discriminator(data)
            GAN_true = torch.ones_like(score_true)
            loss_dis_true = self.BCE_loss(score_true, GAN_true)
            loss_dis_true.backward()

            score_false = self.discriminator(x_recon.detach())
            GAN_false = torch.zeros_like(score_false)
            loss_dis_false = self.BCE_loss(score_false, GAN_false)
            loss_dis_false.backward()
            self.opt_d.step()

            self.loss_dis_false_temp = loss_dis_false.item()
            self.loss_dis_true_temp = loss_dis_true.item()

            # --- Generator update: adversarial term (down-weighted 1e-2)
            # plus the domain-similarity loss ---
            self.opt.zero_grad()
            score_fake = self.discriminator(x_recon)
            GAN_fake = torch.ones_like(score_fake)
            loss_gen = self.BCE_loss(score_fake, GAN_fake) * 1e-2
            (loss_gen + loss_domain_sim).backward()
            self.opt.step()

            if self.step % 10 == 0:
                print(f'{self.step} DTLS: Total loss: {loss_domain_sim.item() + loss_gen.item()} | Domain sim: {loss_domain_sim.item()} '
                      f'| Generate: {loss_gen.item()} '
                      f'| Dis real: {self.loss_dis_true_temp} | Dis false: {self.loss_dis_false_temp}')
                # f'| Features Matching loss: {loss_FM.item()}')

            wandb.log({"Total loss": loss_domain_sim.item() + loss_gen.item(), "Domain Similarity Loss": loss_domain_sim.item(),
                       "Generation loss": loss_gen.item(),
                       "Discriminator loss (real)": self.loss_dis_true_temp,
                       "Discriminator loss (fake)": self.loss_dis_false_temp,}, step=self.step)
            # "Feature Matching loss": loss_FM.item() "LPIPS loss": lpips_loss.item()

            if self.step % self.update_ema_every == 0:
                self.step_ema()

            # Periodic qualitative sample grid: EMA output, blurred input,
            # ground truth, and the live model's output.
            if self.step == 0 or self.step % self.save_and_sample_every == 0:
                lr_real, sr_real = self.ema_model.sample(batch_size=self.batch_size, img=data)
                _, sr_real_ii = self.model.sample(batch_size=self.batch_size, img=data)

                save_img = torch.cat((sr_real_ii, lr_real, data, sr_real),dim=0)
                utils.save_image((save_img+1)/2, str(self.results_folder / f'{self.step}_GDTLS.png'), nrow=self.nrow)

                wandb.log({"Checkpoint result": wandb.Image(str(self.results_folder / f'{self.step}_GDTLS.png'))})
            # After 100k steps, save EMA weights and run FID validation.
            if self.step >= 100000 and self.step % 2500 == 0:
                self.save_model()
                self.validate()
            if self.step != 0 and self.step % 10000 == 0:
                self.save_ckpt()

            self.step += 1
        print('training completed')
        wandb.finish()

    def validate(self):
        """Generate 2000 samples from the EMA model and log their FID."""
        # Assumes `device` is a "cuda:N" style string; the suffix makes the
        # temp folder unique per GPU.
        folder_name = f"temp_samples_{self.device.split(':')[-1]}"
        create_folder(folder_name)
        for i in range(2000):
            random_vector = self.random_vector(1)
            _, sample_hr = self.ema_model.sample(batch_size=1, img=random_vector, save_folder=self.results_folder)
            utils.save_image((sample_hr + 1) /2, f"{folder_name}/result_{i}.png", nrow=1)
        fid = calculate_fid_given_paths([folder_name, "/hdda/Datasets/ffhq256_mini"], 200, self.device, dims=2048, num_workers=4)
        self.fid_list.append(f"{self.step} {fid.item()}")
        # Rewrite the whole FID history each time.
        with open(f'{self.results_folder}/fid.txt', 'w') as f:
            for a in self.fid_list:
                f.write(f'{a}\n')
        wandb.log({"FID score": fid.item()})
        del_folder(folder_name)

    def random_vector(self, batch_size):
        """Draw a random 3-channel 2x2 seed tensor (one Gaussian per channel)."""
        mean = random.uniform(-0.75, 0.75)
        std = random.uniform(0.01, 0.5)
        vector = torch.normal(mean=mean, std=std, size=(batch_size, 1, 2, 2))
        for i in range(2):
            mean = random.uniform(-0.75, 0.75)
            std = random.uniform(0.01, 0.5)
            rgb = torch.normal(mean=mean, std=std, size=(batch_size, 1, 2, 2))
            vector = torch.cat((vector, rgb), dim=1)
        return vector.to(self.device)


    def evaluation(self, num_sample=50000, batch_size=16):
        """Generate `num_sample` images from the EMA model into results_folder."""
        if self.load_path != None:  # NOTE(review): `is not None` is idiomatic
            self.load_for_eval(self.load_path)
        img_count = 1
        while img_count <= num_sample:
            # for i in range(num_sample):
            random_vector = self.random_vector(batch_size)
            _, sample_hr = self.ema_model.sample(batch_size=batch_size, img=random_vector, save_folder=self.results_folder)
            for img in sample_hr:
                if img_count <= num_sample:
                    utils.save_image((img + 1) /2, str(self.results_folder / f'result_{img_count}.png'), nrow=1)
                    print("saving ", img_count)
                    img_count += 1

        # utils.save_image((blur_img_set + 1) /2, str(self.results_folder / f'random_vector_{i}.png'), nrow=4)
        # utils.save_image((sample_hr + 1) /2, str(self.results_folder / f'result_{i}.png'), nrow=1)
        # print("saving ", i)

    def fid(self, created_dataset, realistic_dataset="/hdda/Datasets/ffhq256"):
        """Print the FID between a generated-image folder and the reference set."""
        fid = calculate_fid_given_paths([created_dataset, realistic_dataset], 500, self.device, dims=2048, num_workers=4)
        print("FID Score 50k: ", fid)
models.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from einops import rearrange
5
+ from inspect import isfunction
6
+
7
def weights_init(m):
    """DCGAN-style init for use with `model.apply(weights_init)`.

    Conv* weights ~ N(0, 0.02); BatchNorm* weights ~ N(1, 0.02) with zero bias.
    Other modules are left untouched.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        try:
            m.weight.data.normal_(0.0, 0.02)
        except AttributeError:
            # Some Conv-named modules have no (tensor) weight; the original
            # bare `except:` would also have swallowed KeyboardInterrupt etc.
            pass
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
17
+
18
+
19
def exists(x):
    """Return True unless `x` is None."""
    if x is None:
        return False
    return True
21
+
22
def default(val, d):
    """Return `val` when it is set; otherwise the fallback `d`,
    calling it first when `d` is a function."""
    if val is not None:
        return val
    return d() if isfunction(d) else d
26
+
27
+
28
class Residual(nn.Module):
    """Skip-connection wrapper: forward(x) = fn(x) + x."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        out = self.fn(x, *args, **kwargs)
        return out + x
35
+
36
class SinusoidalPosEmb(nn.Module):
    """Transformer-style sinusoidal embedding of a 1-D position/step tensor.

    Maps a (batch,) tensor to (batch, dim): the first half holds sines, the
    second half cosines, over geometrically spaced frequencies.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        # Frequencies span 1 .. 1/10000 geometrically across half_dim slots.
        scale = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -scale)
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
49
+
50
+
51
class LayerNorm(nn.Module):
    """Channel-wise layer norm for NCHW tensors with learnable per-channel
    gain `g` and bias `b` (normalises over dim 1 only)."""

    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        mean = torch.mean(x, dim=1, keepdim=True)
        # Biased variance, matching the usual layer-norm definition.
        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        normalised = (x - mean) / (var + self.eps).sqrt()
        return normalised * self.g + self.b
62
+
63
class PreNorm(nn.Module):
    """Pre-normalisation wrapper: normalise the input, then run `fn`."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)
        # self.norm = nn.BatchNorm2d(dim)
        # self.norm = nn.GroupNorm(dim // 32, dim)

    def forward(self, x):
        return self.fn(self.norm(x))
74
+
75
+ # building block modules
76
+
77
+
78
class ConvNextBlock(nn.Module):
    """ https://arxiv.org/abs/2201.03545

    ConvNeXt-style residual block with optional FiLM-like conditioning on a
    time embedding: depthwise 7x7 conv, scale/shift from the embedding, then
    norm -> conv -> GELU -> conv, plus a 1x1 residual projection.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim = None, mult = 2, norm = True):
        super().__init__()
        # Projects the time embedding to 2*dim values: a per-channel
        # (weight, bias) pair used to modulate the depthwise features.
        self.mlp = nn.Sequential(
            nn.GELU(),
            nn.Linear(time_emb_dim, dim*2)
        ) if exists(time_emb_dim) else None

        # Depthwise 7x7 convolution (groups == channels).
        self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)

        self.net = nn.Sequential(
            LayerNorm(dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(dim_out * mult, dim_out, 3, 1, 1),
        )

        # self.noise_adding = NoiseInjection(dim_out)
        # 1x1 projection so the skip path matches dim_out.
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb = None):
        h = self.ds_conv(x)

        if exists(self.mlp):
            assert exists(time_emb), 'time emb must be passed in'
            condition = self.mlp(time_emb)
            condition = rearrange(condition, 'b c -> b c 1 1')
            # condition has 2*dim channels; split at dim (== x.shape[1])
            # into a multiplicative weight and an additive bias.
            weight, bias = torch.split(condition, x.shape[1],dim=1)
            h = h * (1 + weight) + bias

        h = self.net(h)
        # h = self.noise_adding(h)
        return h + self.res_conv(x)
113
+
114
+
115
class ConvNextBlock_dis(nn.Module):
    """ https://arxiv.org/abs/2201.03545

    Discriminator variant of the ConvNeXt block: BatchNorm instead of
    LayerNorm and no time conditioning in forward().
    """

    def __init__(self, dim, dim_out, *, time_emb_dim = None, mult = 2, norm = True):
        super().__init__()
        # NOTE(review): this mlp is never used by forward() below; it only
        # exists when time_emb_dim is passed (dead weights if so).
        self.mlp = nn.Sequential(
            nn.GELU(),
            nn.Linear(time_emb_dim, dim*2)
        ) if exists(time_emb_dim) else None

        # Depthwise 7x7 convolution (groups == channels).
        self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)

        self.net = nn.Sequential(
            nn.BatchNorm2d(dim) if norm else nn.Identity(),
            # LayerNorm(dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(dim_out * mult, dim_out, 3, 1, 1),
        )

        # 1x1 projection so the skip path matches dim_out.
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x):
        h = self.ds_conv(x)
        h = self.net(h)
        return h + self.res_conv(x)
141
+
142
+
143
+
144
class LinearAttention(nn.Module):
    """Linear-complexity attention over spatial positions of an NCHW map.

    Softmax is applied to the keys over positions, a global context matrix is
    built once, and queries read from it — O(n) in the number of pixels
    instead of O(n^2) for full attention.
    """

    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5  # standard 1/sqrt(d) query scaling
        self.heads = heads
        hidden_dim = dim_head * heads
        # Single 1x1 conv produces q, k, v stacked along channels.
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim = 1)
        # Flatten spatial dims: (b, heads, dim_head, h*w) per tensor.
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
        q = q * self.scale

        # Normalise keys across positions, then aggregate values into a
        # (dim_head x dim_head) context per head.
        k = k.softmax(dim = -1)
        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)

        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
        return self.to_out(out)
165
+
166
+
167
+ # model
168
class UNet(nn.Module):
    """Time-conditioned ConvNeXt UNet mapping a 3-channel image to 3 channels.

    Channel widths follow dim * dim_mults per level; linear attention is used
    only at the deepest three levels on the way down (and the first three on
    the way up).
    """

    def __init__(
        self,
        dim = 32,
        dim_mults=(1, 2, 4, 8, 16, 32, 32),
        channels = 3,
    ):
        super().__init__()
        self.channels = dim  # NOTE(review): stores `dim`, not the image channel count

        dims = [dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        self.model_depth = len(dim_mults)

        # Embed the scalar step index into a `dim`-wide vector.
        time_dim = dim
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 2),
            nn.GELU(),
            nn.Linear(dim * 2, dim)
        )

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])

        num_resolutions = len(in_out)

        self.initial = nn.Conv2d(channels, dim, 7,1,3, bias=False)

        # Encoder: block -> 2x avg-pool -> (attention at deepest 3) -> block.
        for ind, (dim_in, dim_out) in enumerate(in_out):
            self.downs.append(nn.ModuleList([
                ConvNextBlock(dim_in, dim_out, time_emb_dim = time_dim, norm = ind != 0),
                nn.AvgPool2d(2),
                Residual(PreNorm(dim_out, LinearAttention(dim_out))) if ind >= (num_resolutions - 3) else nn.Identity(),
                ConvNextBlock(dim_out, dim_out, time_emb_dim=time_dim),
            ]))

        # Decoder mirrors the encoder; the first block takes 2x channels
        # because the skip feature is concatenated in forward().
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            self.ups.append(nn.ModuleList([
                ConvNextBlock(dim_out * 2, dim_in, time_emb_dim = time_dim),
                nn.Upsample(scale_factor=2, mode='nearest'),
                Residual(PreNorm(dim_in, LinearAttention(dim_in))) if ind < 3 else nn.Identity(),
                ConvNextBlock(dim_in, dim_in, time_emb_dim=time_dim),
            ]))

        self.final_conv = nn.Conv2d(dim, 3, 1, bias=False)

    def forward(self, x, time):
        x = self.initial(x)
        t = self.time_mlp(time) if exists(self.time_mlp) else None
        h = []
        for convnext, downsample, attn, convnext2 in self.downs:
            x = convnext(x, t)
            x = downsample(x)
            # Skip feature is stored right after pooling, before attention.
            h.append(x)
            x = attn(x)
            x = convnext2(x, t)

        for convnext, upsample, attn, convnext2 in self.ups:
            # Concatenate the matching encoder feature (LIFO order).
            x = torch.cat((x, h.pop()), dim=1)
            x = convnext(x, t)
            x = upsample(x)
            x = attn(x)
            x = convnext2(x, t)

        return self.final_conv(x)
234
+
235
+
236
class Discriminator(nn.Module):
    """Convolutional discriminator: the UNet encoder path without skips,
    ending in a 1x1 conv that emits per-location real/fake logits."""

    def __init__(
        self,
        dim=32,
        dim_mults=(1, 2, 4, 8, 16, 32, 32),
        channels=3,
        with_time_emb=True,  # NOTE(review): accepted but never used
    ):
        super().__init__()
        self.channels = dim  # NOTE(review): stores `dim`, not the image channel count

        dims = [dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        self.model_depth = len(dim_mults)

        self.downs = nn.ModuleList([])
        num_resolutions = len(in_out)

        self.initial = nn.Conv2d(channels, dim, 7,1,3, bias=False)

        # Stack of block -> 2x avg-pool -> block stages, halving resolution each time.
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)  # NOTE(review): computed but unused
            self.downs.append(nn.ModuleList([
                ConvNextBlock_dis(dim_in, dim_out, norm=ind != 0),
                nn.AvgPool2d(2),
                ConvNextBlock_dis(dim_out, dim_out),
            ]))
        # Width of the final stage, used by the output head below.
        dim_out = dim_mults[-1] * dim

        self.out = nn.Conv2d(dim_out, 1, 1, bias=False)

    def forward(self, x):
        x = self.initial(x)
        for convnext, downsample, convnext2 in self.downs:
            x = convnext(x)
            x = downsample(x)
            x = convnext2(x)
        # Flatten the spatial logits into (batch, n_locations); raw logits
        # (no sigmoid) — trained with BCEWithLogitsLoss.
        return self.out(x).view(x.shape[0], -1)