Upload folder using huggingface_hub
- DTGM_model_167500.pt +3 -0
- eval.py +178 -0
- gdtls.py +505 -0
- models.py +273 -0
DTGM_model_167500.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ac52dc0c74c44bac9506bec31e9d94cadd40dd44921dea65401bb79d4b3af308
size 1250282226
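The checkpoint is stored as a Git LFS pointer: the repository tracks only the spec version, the SHA-256 object id, and the byte size (about 1.25 GB); the actual weights live in LFS storage. A minimal sketch of reading such a pointer (the parse_lfs_pointer helper is illustrative, not part of this repo):

# Hypothetical helper: parse a Git LFS pointer file into its key/value fields.
def parse_lfs_pointer(text: str) -> dict:
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields

pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:ac52dc0c74c44bac9506bec31e9d94cadd40dd44921dea65401bb79d4b3af308
size 1250282226"""
info = parse_lfs_pointer(pointer)
print(info["oid"], int(info["size"]))  # sha256:ac52... 1250282226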
eval.py
ADDED
@@ -0,0 +1,178 @@
import math
import torch.cuda
from util.models import *
from gdtls import DTLS, Trainer
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--device', default="cuda:0", type=str)
parser.add_argument('--hr_size', default=256, type=int, help="size of HR image")
parser.add_argument('--lr_size', default=2, type=int, help="size of LR image")
parser.add_argument('--interval_mode', default="fibonacci", type=str, help="linear; exp; fibonacci")
parser.add_argument('--stride', default=2, type=int, help="size change between each step if linear mode is used")
parser.add_argument('--train_steps', default=200001, type=int)
parser.add_argument('--lr_rate', default=2e-5, type=float, help="learning rate")
parser.add_argument('--sample_every_iterations', default=5000, type=int, help="sample SR images every N iterations")
parser.add_argument('--save_folder', default="DTGM_segnoise_4b_50k", type=str, help="folder to save your training or evaluation results")
parser.add_argument('--load_path', default="DTGM_segnoise_4b/DTGM_model_165000.pt", type=str, help="None or directory to pretrained model")
parser.add_argument('--data_path', default='/hdda/Datasets/Face_super_resolution/images1024x1024/', type=str, help="directory to your training dataset")

parser.add_argument('--batch_size', default=1, type=int)

args = parser.parse_args()
device = args.device if torch.cuda.is_available() else "cpu"
size_list = [256, 64, 32, 16, 8, 4, 3, 2]
timestep = len(size_list) - 1

print(f"Total steps for {args.lr_size} to {args.hr_size}: {timestep}")

model = UNet().to(device)
discriminator = Discriminator().to(device)

dtls = DTLS(
    model,
    image_size = args.hr_size,
    stride = args.stride,
    size_list = size_list,
    timesteps = timestep,       # number of steps
    device = device,
).to(device)


trainer = Trainer(
    dtls,
    discriminator,
    args.data_path,
    image_size = args.hr_size,
    train_batch_size = args.batch_size,
    train_num_steps = args.train_steps,     # total training steps
    ema_decay = 0.995,                      # exponential moving average decay
    results_folder = args.save_folder,
    load_path = args.load_path,
    device = device,
    eval_mode = True,
    save_and_sample_every = args.sample_every_iterations
)

if __name__ == "__main__":
    trainer.evaluation()
    trainer.fid(created_dataset=args.save_folder)


# import copy
# from pathlib import Path
# import torch
# from util.models import *
# import argparse
# import random
# import torch.nn.functional as F
# from torchvision import utils
# import os
# import errno
#
# parser = argparse.ArgumentParser()
# parser.add_argument('--device', default="cuda:1", type=str)
# parser.add_argument('--hr_size', default=256, type=int, help="size of HR image")
# parser.add_argument('--lr_size', default=2, type=int, help="size of LR image")
# parser.add_argument('--num_sample', default=50000, type=int, help="number of images to generate")
# parser.add_argument('--save_folder', default="DTGM_Xeii_lpips_FM_ii_80kpt_50k", type=str, help="folder to save your training or evaluation results")
# parser.add_argument('--load_path', default="DTGM_Xeii_lpips_FM_ii/GDTLS_80000.pt", type=str, help="None or directory to pretrained model")
#
# def create_folder(path):
#     try:
#         os.mkdir(path)
#     except OSError as exc:
#         if exc.errno != errno.EEXIST:
#             raise
#         pass
#
# def transform_func_sample(img, target_size):
#     n = target_size
#     m = args.hr_size
#
#     if m / n > 16:
#         img_1 = F.interpolate(img, size=m // 4, mode='bicubic', antialias=True)
#         img_1 = F.interpolate(img_1, size=m // 8, mode='bicubic', antialias=True)
#         img_1 = F.interpolate(img_1, size=n, mode='bicubic', antialias=True)
#     else:
#         img_1 = F.interpolate(img, size=n, mode='bicubic', antialias=True)
#     img_1 = F.interpolate(img_1, size=m, mode='bicubic', antialias=True)
#
#     return img_1
#
# def transform_func_noise(img, device_, target_size, fixed_std=True):
#     n = target_size
#     m = args.hr_size
#
#     random_mean = torch.rand(1).add(-0.5).item()
#     if fixed_std:
#         random_std = 0.5
#     else:
#         random_std = torch.rand(1).mul(0.5).item()
#     decreasing_scale = 0.9 ** (n - 2)
#
#     if m / n > 16:
#         img_1 = F.interpolate(img, size=m // 4, mode='bicubic', antialias=True)
#         img_1 = F.interpolate(img_1, size=m // 8, mode='bicubic', antialias=True)
#         img_1 = F.interpolate(img_1, size=n, mode='bicubic', antialias=True)
#     else:
#         img_1 = F.interpolate(img, size=n, mode='bicubic', antialias=True)
#
#     # noise = torch.normal(mean=random_mean, std=random_std, size=(img_1.shape[0], 3, 2, 2)).to(self.device)
#     # noise = F.interpolate(noise, size=n, mode='bicubic', antialias=True)
#     noise = torch.normal(mean=random_mean, std=random_std, size=img_1.shape).to(device_)
#
#     img_1 += noise * decreasing_scale
#     img_1 = F.interpolate(img_1, size=m, mode='bicubic', antialias=True)
#
#     noise_refinement = torch.normal(mean=0, std=0.05, size=img_1.shape).to(device_)
#     return img_1 + noise_refinement
#
# def random_vector(batch_size):
#     mean = random.uniform(-0.5, 0.5)
#     std = random.uniform(0.1, 0.3)
#     vector = torch.normal(mean=mean, std=std, size=(batch_size, 1, 2, 2))
#     for colors in range(2):
#         mean = random.uniform(-0.5, 0.5)
#         std = random.uniform(0.1, 0.3)
#         rgb = torch.normal(mean=mean, std=std, size=(batch_size, 1, 2, 2))
#         vector = torch.cat((vector, rgb), dim=1)
#     return vector
#
# def sample(size_list_, device_, model_, batch_size=1, img=None, t=None):
#     blur_img = transform_func_sample(img.clone(), size_list_[t])
#     img_t = blur_img.clone()
#
#     ####### Domain Transfer
#     while t:
#         next_step = size_list_[t - 1]
#         step = torch.full((batch_size,), t, dtype=torch.long).to(device_)
#         R_x = model_(img_t, step)
#         if t == 1:
#             return R_x
#         else:
#             img_t = transform_func_noise(R_x, device_, next_step, fixed_std=True)
#             t -= 1
#     return img_t
#
# if __name__ == "__main__":
#     args = parser.parse_args()
#     results_folder = Path(args.save_folder)
#     results_folder.mkdir(exist_ok=True)
#     device = args.device if torch.cuda.is_available() else "cpu"
#     size_list = [256, 64, 32, 16, 8, 6, 4, 3, 2]
#     timestep = len(size_list) - 1
#     print(f"Total steps for {args.lr_size} to {args.hr_size}: {timestep}")
#     model = UNet().to(device)
#
#     if args.load_path is not None:
#         data = torch.load(args.load_path, map_location=device)
#         model.load_state_dict(data['ema'], strict=False)
#
#     for i in range(args.num_sample):
#         input_vector = random_vector(1).to(device)
#         sample_hr = sample(size_list, device, model, batch_size=1, img=input_vector, t=timestep)
#         utils.save_image(sample_hr.add(1).mul(0.5), f"{results_folder}/result_{i}.png", nrow=1)
#         print("saving ", i)
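Note that eval.py accepts --interval_mode (linear; exp; fibonacci) and --stride but then hardcodes size_list = [256, 64, 32, 16, 8, 4, 3, 2]. A minimal sketch of how such a descending size schedule could be built for each mode, assuming it runs from hr_size down to lr_size (the build_size_list helper is an illustration, not the repo's actual builder):

# Illustrative sketch only: build a descending size schedule from hr_size to lr_size.
# The three modes are assumptions based on the flag's help text; the repo hardcodes its list.
def build_size_list(hr_size, lr_size, mode="linear", stride=2):
    if mode == "linear":                     # hr, hr-stride, ..., lr
        sizes = list(range(hr_size, lr_size - 1, -stride))
    elif mode == "exp":                      # hr, hr/2, hr/4, ..., lr
        sizes, s = [], hr_size
        while s >= lr_size:
            sizes.append(s)
            s //= 2
    elif mode == "fibonacci":                # fibonacci sizes between lr and hr, descending
        fib = [1, 2]
        while fib[-1] < hr_size:
            fib.append(fib[-1] + fib[-2])
        sizes = [f for f in reversed(fib) if lr_size <= f <= hr_size]
    else:
        raise ValueError(mode)
    if sizes[-1] != lr_size:
        sizes.append(lr_size)
    return sizes

print(build_size_list(256, 2, mode="exp"))   # [256, 128, 64, 32, 16, 8, 4, 2]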
gdtls.py
ADDED
@@ -0,0 +1,505 @@
import copy
import torch.nn.functional as F
import numpy as np
import os
import errno
import torch
import shutil
import wandb
import random
import time
import lpips

from torch import nn
from torch.utils import data
from pathlib import Path
from torch.optim import Adam, AdamW
from torchvision import transforms, utils
from torchvision.transforms import InterpolationMode
from torchvision.transforms.v2 import RandomResize
from PIL import Image
from util.fid_score import calculate_fid_given_paths

try:
    from apex import amp
    APEX_AVAILABLE = True
except ImportError:
    APEX_AVAILABLE = False

####### helper functions

def create_folder(path):
    try:
        os.mkdir(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise

def del_folder(path):
    try:
        shutil.rmtree(path)
    except OSError:
        pass

def cycle(dl):
    while True:
        for data in dl:
            yield data

def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr

def loss_backwards(fp16, loss, optimizer, **kwargs):
    if fp16:
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward(**kwargs)
    else:
        loss.backward(**kwargs)

# small helper modules
def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)  # np.int was removed in NumPy >= 1.24; plain int is equivalent here
    cut_h = int(H * cut_rat)

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2


def Huber(input, target, delta=0.1, reduce=True):
    abs_error = torch.abs(input - target)
    quadratic = torch.clamp(abs_error, max=delta)

    # The following expression is the same in value as
    # tf.maximum(abs_error - delta, 0), but importantly the gradient for the
    # expression when abs_error == delta is 0 (for tf.maximum it would be 1).
    # This is necessary to avoid doubling the gradient, since there is already a
    # nonzero contribution to the gradient from the quadratic term.
    linear = (abs_error - quadratic)
    losses = 0.5 * torch.pow(quadratic, 2) + delta * linear

    if reduce:
        return torch.mean(losses)
    else:
        return losses

class EMA():
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new


class DTLS(nn.Module):
    def __init__(
        self,
        model,
        *,
        image_size,
        size_list,
        stride,
        timesteps,
        device,
        stochastic=False,
    ):
        super().__init__()
        self.image_size = image_size
        self.UNet = model

        self.num_timesteps = int(timesteps)
        self.size_list = size_list
        self.stride = stride
        self.device = device
        self.MSE_loss = nn.MSELoss()
        # self.vgg_loss = Vgg19()
        self.lpips_loss = lpips.LPIPS(net='alex')

    def transform_func_loss(self, img, target_size):
        n = target_size
        m = self.image_size

        if m / n > 16:
            img_1 = F.interpolate(img, size=m // 4, mode='bicubic', antialias=True)
            img_1 = F.interpolate(img_1, size=m // 8, mode='bicubic', antialias=True)
            img_1 = F.interpolate(img_1, size=n, mode='bicubic', antialias=True)
        else:
            img_1 = F.interpolate(img, size=n, mode='bicubic', antialias=True)

        return img_1

    def transform_func_sample(self, img, target_size):
        n = target_size
        m = self.image_size

        if m / n > 16:
            img_1 = F.interpolate(img, size=m // 4, mode='bicubic', antialias=True)
            img_1 = F.interpolate(img_1, size=m // 8, mode='bicubic', antialias=True)
            img_1 = F.interpolate(img_1, size=n, mode='bicubic', antialias=True)
        else:
            img_1 = F.interpolate(img, size=n, mode='bicubic', antialias=True)
        img_1 = F.interpolate(img_1, size=m, mode='bicubic', antialias=True)

        return img_1

    def transform_func_noise(self, img, target_size, std_eval=False):
        n = target_size
        m = self.image_size

        random_mean = torch.rand(1).add(-.5).item()
        # random_std = torch.rand(1).mul(0.5).item()
        decreasing_scale = 0.9 ** (n - 2)

        if m / n > 16:
            img_1 = F.interpolate(img, size=m // 4, mode='bicubic', antialias=True)
            img_1 = F.interpolate(img_1, size=m // 8, mode='bicubic', antialias=True)
            img_1 = F.interpolate(img_1, size=n, mode='bicubic', antialias=True)
        else:
            img_1 = F.interpolate(img, size=n, mode='bicubic', antialias=True)

        noise = torch.normal(mean=random_mean, std=0.5, size=(img_1.shape[0], 3, 2, 2)).to(self.device)
        noise = F.interpolate(noise, size=n, mode='bicubic', antialias=True)
        img_1 += noise * decreasing_scale
        img_1 = F.interpolate(img_1, size=m, mode='bicubic', antialias=True)

        if n >= 16:
            noise_refinement = torch.normal(mean=0, std=1, size=img_1.shape).to(self.device)
            img_1 = img_1 + noise_refinement * decreasing_scale
        return img_1


    @torch.no_grad()
    def sample(self, batch_size=16, img=None, t=None, save_folder=None):
        if t is None:
            t = self.num_timesteps
        blur_img = self.transform_func_sample(img.clone(), self.size_list[t])
        img_t = blur_img.clone()
        ####### Domain Transfer
        while t:
            next_step = self.size_list[t - 1]
            step = torch.full((batch_size,), t, dtype=torch.long).to(self.device)
            R_x = self.UNet(img_t, step)
            if t == 1:
                return blur_img, R_x
            else:
                img_t = self.transform_func_noise(R_x, next_step)
                t -= 1
        return blur_img, img_t


    def p_losses(self, x_start, t):
        x_blur = x_start.clone()

        for i in range(t.shape[0]):
            current_step = self.size_list[t[i]]
            x_blur[i] = self.transform_func_noise(x_blur[i].unsqueeze(0), current_step)

        x_recon = self.UNet(x_blur, t)

        ### Pattern Domain Similarity Loss
        x_clone = x_recon.clone()
        for i in range(t.shape[0]):
            current_step = self.size_list[t[i]]
            x_clone[i] = self.transform_func_sample(x_recon[i].unsqueeze(0), current_step)

        ### Lowest Pattern Domain Similarity Loss
        # x_clone = x_recon.clone()
        # x_clone = self.transform_func_loss(x_clone, self.size_list[-1])
        # x_blur = self.transform_func_loss(x_blur, self.size_list[-1])

        loss = self.MSE_loss(x_clone, x_blur)
        # lpips_loss = self.lpips_loss(x_recon, x_start).mean()
        return loss, x_recon

    def forward(self, x, *args, **kwargs):
        b, c, h, w, device, img_size = *x.shape, x.device, self.image_size
        assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
        t = torch.randint(1, self.num_timesteps + 1, (b,), device=device).long()
        return self.p_losses(x, t, *args, **kwargs)

# dataset classes

class Dataset(data.Dataset):
    def __init__(self, folder, image_size, exts=['jpg', 'jpeg', 'png']):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

        self.transform = transforms.Compose([
            RandomResize(int(image_size), int(image_size * 1.2), interpolation=InterpolationMode.BICUBIC, antialias=True),
            transforms.RandomCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Lambda(lambda t: (t * 2) - 1)
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path).convert('RGB')  # force 3 channels; PNGs may carry an alpha channel
        return self.transform(img)

# trainer class

class Trainer(object):
    def __init__(
        self,
        diffusion_model,
        discriminator,
        folder,
        *,
        ema_decay = 0.9925,
        image_size = 128,
        train_batch_size = 32,
        train_num_steps = 200000,
        step_start_ema = 500,
        update_ema_every = 10,
        save_and_sample_every = 1000,
        results_folder,
        load_path = None,
        shuffle=True,
        eval_mode=False,
        device,
    ):
        super().__init__()

        ########## Wandb ##########
        if not eval_mode:
            wandb.init(project="DTGM", notes=str(results_folder), name=results_folder)
        self.results_folder = Path(results_folder)
        self.results_folder.mkdir(exist_ok=True)

        self.device = device

        self.model = diffusion_model
        self.discriminator = discriminator
        self.model_size()

        self.ema = EMA(ema_decay)
        self.ema_model = copy.deepcopy(self.model)

        self.update_ema_every = update_ema_every
        self.step_start_ema = step_start_ema

        self.save_and_sample_every = save_and_sample_every

        self.image_size = diffusion_model.image_size
        self.batch_size = train_batch_size
        self.train_num_steps = train_num_steps
        self.nrow = train_batch_size // 2

        self.folder_path = folder
        self.ds = Dataset(folder, image_size)

        self.dl = cycle(data.DataLoader(self.ds, batch_size=train_batch_size, shuffle=shuffle, pin_memory=True, num_workers=2))

        self.opt = AdamW(diffusion_model.parameters(), lr=2e-5, betas=(0.0, 0.9), eps=1e-8)
        self.opt_d = AdamW(self.discriminator.parameters(), lr=5e-5, betas=(0.0, 0.9), eps=1e-8)

        self.BCE_loss = torch.nn.BCEWithLogitsLoss()

        self.step = 0
        self.reset_parameters()
        self.best_quality = 0

        self.loss_dis_false_temp = 0
        self.loss_dis_true_temp = 0

        self.load_path = load_path
        self.n_mix = 0
        self.fid_list = []
        with open(f'{self.results_folder}/fid.txt', 'w') as f:
            for a in self.fid_list:
                f.write(f'{self.step} {a}\n')

    def reset_parameters(self):
        self.ema_model.load_state_dict(self.model.state_dict())

    def step_ema(self):
        if self.step < self.step_start_ema:
            self.reset_parameters()
            return
        self.ema.update_model_average(self.ema_model, self.model)

    def model_size(self):
        param_size = 0
        for param in self.model.parameters():
            param_size += param.nelement() * param.element_size()
        buffer_size = 0
        for buffer in self.model.buffers():
            buffer_size += buffer.nelement() * buffer.element_size()

        size_all_mb = (param_size + buffer_size) / 1024**2
        print('model size: {:.3f}MB'.format(size_all_mb))

    def save_ckpt(self):
        data = {
            'step': self.step,
            'model': self.model.state_dict(),
            'ema': self.ema_model.state_dict(),
            'dis': self.discriminator.state_dict(),
        }
        torch.save(data, str(self.results_folder / f'DTGM_ckpt_{self.step}.pt'))

    def save_model(self):
        data = {
            'ema': self.ema_model.state_dict(),
        }
        torch.save(data, str(self.results_folder / f'DTGM_model_{self.step}.pt'))


    def load_all(self, load_path):
        print("Loading : ", load_path)
        data = torch.load(load_path, map_location=self.device)

        self.step = data['step']
        self.model.load_state_dict(data['model'], strict=False)
        self.ema_model.load_state_dict(data['ema'], strict=False)
        self.discriminator.load_state_dict(data['dis'], strict=False)

    def load_for_eval(self, load_path):
        data = torch.load(load_path, map_location=self.device)
        self.ema_model.load_state_dict(data['ema'], strict=False)

    def train(self):
        if self.load_path is not None:
            self.load_all(self.load_path)

        while self.step < self.train_num_steps:
            start_time = time.time()
            data = next(self.dl)
            data = data.to(self.device)

            loss_domain_sim, x_recon = self.model(data)

            self.opt_d.zero_grad()
            score_true = self.discriminator(data)
            GAN_true = torch.ones_like(score_true)
            loss_dis_true = self.BCE_loss(score_true, GAN_true)
            loss_dis_true.backward()

            score_false = self.discriminator(x_recon.detach())
            GAN_false = torch.zeros_like(score_false)
            loss_dis_false = self.BCE_loss(score_false, GAN_false)
            loss_dis_false.backward()
            self.opt_d.step()

            self.loss_dis_false_temp = loss_dis_false.item()
            self.loss_dis_true_temp = loss_dis_true.item()

            self.opt.zero_grad()
            score_fake = self.discriminator(x_recon)
            GAN_fake = torch.ones_like(score_fake)
            loss_gen = self.BCE_loss(score_fake, GAN_fake) * 1e-2
            (loss_gen + loss_domain_sim).backward()
            self.opt.step()

            if self.step % 10 == 0:
                print(f'{self.step} DTLS: Total loss: {loss_domain_sim.item() + loss_gen.item()} | Domain sim: {loss_domain_sim.item()} '
                      f'| Generate: {loss_gen.item()} '
                      f'| Dis real: {self.loss_dis_true_temp} | Dis false: {self.loss_dis_false_temp}')
                      # f'| Features Matching loss: {loss_FM.item()}')

                wandb.log({"Total loss": loss_domain_sim.item() + loss_gen.item(),
                           "Domain Similarity Loss": loss_domain_sim.item(),
                           "Generation loss": loss_gen.item(),
                           "Discriminator loss (real)": self.loss_dis_true_temp,
                           "Discriminator loss (fake)": self.loss_dis_false_temp,}, step=self.step)
                # "Feature Matching loss": loss_FM.item()  "LPIPS loss": lpips_loss.item()

            if self.step % self.update_ema_every == 0:
                self.step_ema()

            if self.step == 0 or self.step % self.save_and_sample_every == 0:
                lr_real, sr_real = self.ema_model.sample(batch_size=self.batch_size, img=data)
                _, sr_real_ii = self.model.sample(batch_size=self.batch_size, img=data)

                save_img = torch.cat((sr_real_ii, lr_real, data, sr_real), dim=0)
                utils.save_image((save_img + 1) / 2, str(self.results_folder / f'{self.step}_GDTLS.png'), nrow=self.nrow)

                wandb.log({"Checkpoint result": wandb.Image(str(self.results_folder / f'{self.step}_GDTLS.png'))})
            if self.step >= 100000 and self.step % 2500 == 0:
                self.save_model()
                self.validate()
            if self.step != 0 and self.step % 10000 == 0:
                self.save_ckpt()

            self.step += 1
        print('training completed')
        wandb.finish()

    def validate(self):
        folder_name = f"temp_samples_{self.device.split(':')[-1]}"
        create_folder(folder_name)
        for i in range(2000):
            random_vector = self.random_vector(1)
            _, sample_hr = self.ema_model.sample(batch_size=1, img=random_vector, save_folder=self.results_folder)
            utils.save_image((sample_hr + 1) / 2, f"{folder_name}/result_{i}.png", nrow=1)
        fid = calculate_fid_given_paths([folder_name, "/hdda/Datasets/ffhq256_mini"], 200, self.device, dims=2048, num_workers=4)
        self.fid_list.append(f"{self.step} {fid.item()}")
        with open(f'{self.results_folder}/fid.txt', 'w') as f:
            for a in self.fid_list:
                f.write(f'{a}\n')
        wandb.log({"FID score": fid.item()})
        del_folder(folder_name)

    def random_vector(self, batch_size):
        mean = random.uniform(-0.75, 0.75)
        std = random.uniform(0.01, 0.5)
        vector = torch.normal(mean=mean, std=std, size=(batch_size, 1, 2, 2))
        for i in range(2):
            mean = random.uniform(-0.75, 0.75)
            std = random.uniform(0.01, 0.5)
            rgb = torch.normal(mean=mean, std=std, size=(batch_size, 1, 2, 2))
            vector = torch.cat((vector, rgb), dim=1)
        return vector.to(self.device)


    def evaluation(self, num_sample=50000, batch_size=16):
        if self.load_path is not None:
            self.load_for_eval(self.load_path)
        img_count = 1
        while img_count <= num_sample:
            # for i in range(num_sample):
            random_vector = self.random_vector(batch_size)
            _, sample_hr = self.ema_model.sample(batch_size=batch_size, img=random_vector, save_folder=self.results_folder)
            for img in sample_hr:
                if img_count <= num_sample:
                    utils.save_image((img + 1) / 2, str(self.results_folder / f'result_{img_count}.png'), nrow=1)
                    print("saving ", img_count)
                    img_count += 1

        # utils.save_image((blur_img_set + 1) / 2, str(self.results_folder / f'random_vector_{i}.png'), nrow=4)
        # utils.save_image((sample_hr + 1) / 2, str(self.results_folder / f'result_{i}.png'), nrow=1)
        # print("saving ", i)

    def fid(self, created_dataset, realistic_dataset="/hdda/Datasets/ffhq256"):
        fid = calculate_fid_given_paths([created_dataset, realistic_dataset], 500, self.device, dims=2048, num_workers=4)
        print("FID Score 50k: ", fid)
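For completeness, the same DTLS/Trainer pair can be wired for training rather than evaluation; a minimal sketch mirroring eval.py with eval_mode=False (the import path and dataset path below are placeholders, not fixed by this repo):

# Minimal training-mode sketch; adjust the import to wherever UNet/Discriminator live.
from models import UNet, Discriminator   # assumption: models.py is importable from here
from gdtls import DTLS, Trainer

device = "cuda:0"
size_list = [256, 64, 32, 16, 8, 4, 3, 2]

dtls = DTLS(UNet().to(device), image_size=256, stride=2, size_list=size_list,
            timesteps=len(size_list) - 1, device=device).to(device)

trainer = Trainer(dtls, Discriminator().to(device), "/path/to/train/images",
                  image_size=256, train_batch_size=4, train_num_steps=200001,
                  ema_decay=0.995, results_folder="DTGM_run", load_path=None,
                  device=device, eval_mode=False)  # eval_mode=False also enables wandb logging
trainer.train()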
models.py
ADDED
@@ -0,0 +1,273 @@
import math
import torch
from torch import nn
from einops import rearrange
from inspect import isfunction

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        try:
            m.weight.data.normal_(0.0, 0.02)
        except AttributeError:  # module matched 'Conv' but has no weight tensor
            pass
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


def exists(x):
    return x is not None

def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class LayerNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=1, keepdim=True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)
        # self.norm = nn.BatchNorm2d(dim)
        # self.norm = nn.GroupNorm(dim // 32, dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)

# building block modules


class ConvNextBlock(nn.Module):
    """ https://arxiv.org/abs/2201.03545 """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.GELU(),
            nn.Linear(time_emb_dim, dim * 2)
        ) if exists(time_emb_dim) else None

        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)

        self.net = nn.Sequential(
            LayerNorm(dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(dim_out * mult, dim_out, 3, 1, 1),
        )

        # self.noise_adding = NoiseInjection(dim_out)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.ds_conv(x)

        if exists(self.mlp):
            assert exists(time_emb), 'time emb must be passed in'
            condition = self.mlp(time_emb)
            condition = rearrange(condition, 'b c -> b c 1 1')
            weight, bias = torch.split(condition, x.shape[1], dim=1)
            h = h * (1 + weight) + bias

        h = self.net(h)
        # h = self.noise_adding(h)
        return h + self.res_conv(x)


class ConvNextBlock_dis(nn.Module):
    """ https://arxiv.org/abs/2201.03545 """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.GELU(),
            nn.Linear(time_emb_dim, dim * 2)
        ) if exists(time_emb_dim) else None

        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)

        self.net = nn.Sequential(
            nn.BatchNorm2d(dim) if norm else nn.Identity(),
            # LayerNorm(dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(dim_out * mult, dim_out, 3, 1, 1),
        )

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x):
        h = self.ds_conv(x)
        h = self.net(h)
        return h + self.res_conv(x)



class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads), qkv)
        q = q * self.scale

        k = k.softmax(dim=-1)
        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)

        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y', h=self.heads, x=h, y=w)
        return self.to_out(out)


# model
class UNet(nn.Module):
    def __init__(
        self,
        dim=32,
        dim_mults=(1, 2, 4, 8, 16, 32, 32),
        channels=3,
    ):
        super().__init__()
        self.channels = dim

        dims = [dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        self.model_depth = len(dim_mults)

        time_dim = dim
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 2),
            nn.GELU(),
            nn.Linear(dim * 2, dim)
        )

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])

        num_resolutions = len(in_out)

        self.initial = nn.Conv2d(channels, dim, 7, 1, 3, bias=False)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            self.downs.append(nn.ModuleList([
                ConvNextBlock(dim_in, dim_out, time_emb_dim=time_dim, norm=ind != 0),
                nn.AvgPool2d(2),
                Residual(PreNorm(dim_out, LinearAttention(dim_out))) if ind >= (num_resolutions - 3) else nn.Identity(),
                ConvNextBlock(dim_out, dim_out, time_emb_dim=time_dim),
            ]))

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            self.ups.append(nn.ModuleList([
                ConvNextBlock(dim_out * 2, dim_in, time_emb_dim=time_dim),
                nn.Upsample(scale_factor=2, mode='nearest'),
                Residual(PreNorm(dim_in, LinearAttention(dim_in))) if ind < 3 else nn.Identity(),
                ConvNextBlock(dim_in, dim_in, time_emb_dim=time_dim),
            ]))

        self.final_conv = nn.Conv2d(dim, 3, 1, bias=False)

    def forward(self, x, time):
        x = self.initial(x)
        t = self.time_mlp(time) if exists(self.time_mlp) else None
        h = []
        for convnext, downsample, attn, convnext2 in self.downs:
            x = convnext(x, t)
            x = downsample(x)
            h.append(x)
            x = attn(x)
            x = convnext2(x, t)

        for convnext, upsample, attn, convnext2 in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = convnext(x, t)
            x = upsample(x)
            x = attn(x)
            x = convnext2(x, t)

        return self.final_conv(x)


class Discriminator(nn.Module):
    def __init__(
        self,
        dim=32,
        dim_mults=(1, 2, 4, 8, 16, 32, 32),
        channels=3,
        with_time_emb=True,
    ):
        super().__init__()
        self.channels = dim

        dims = [dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        self.model_depth = len(dim_mults)

        self.downs = nn.ModuleList([])
        num_resolutions = len(in_out)

        self.initial = nn.Conv2d(channels, dim, 7, 1, 3, bias=False)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
            self.downs.append(nn.ModuleList([
                ConvNextBlock_dis(dim_in, dim_out, norm=ind != 0),
                nn.AvgPool2d(2),
                ConvNextBlock_dis(dim_out, dim_out),
            ]))
        dim_out = dim_mults[-1] * dim

        self.out = nn.Conv2d(dim_out, 1, 1, bias=False)

    def forward(self, x):
        x = self.initial(x)
        for convnext, downsample, convnext2 in self.downs:
            x = convnext(x)
            x = downsample(x)
            x = convnext2(x)
        return self.out(x).view(x.shape[0], -1)
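With the default seven-stage dim_mults, a 256x256 input is pooled down to 2x2 at the bottleneck and restored on the way back up, and the discriminator emits one logit per 2x2 spatial cell. A quick shape check on random data (CPU, illustrative only; assumes models.py is importable as shown):

# Shape smoke test for the default configurations, using random inputs.
import torch
from models import UNet, Discriminator   # assumption: models.py is on the import path

x = torch.randn(2, 3, 256, 256)           # batch of 2 RGB images at hr_size
t = torch.full((2,), 3, dtype=torch.long) # per-sample timestep index for the time MLP

unet = UNet()
print(unet(x, t).shape)                   # torch.Size([2, 3, 256, 256])

disc = Discriminator()
print(disc(x).shape)                      # torch.Size([2, 4]): one logit per 2x2 cell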