WZT006 commited on
Commit
a7dedcf
·
1 Parent(s): 95ec8d7

added models

Browse files
dataset.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.utils.data as data
2
+
3
+ from PIL import Image
4
+
5
+ import os
6
+ import os.path
7
+ from io import BytesIO
8
+
9
+ import lmdb
10
+ from torch.utils.data import Dataset
11
+
12
+ class MultiResolutionDataset(Dataset):
13
+ def __init__(self, path, transform, resolution=256):
14
+ self.env = lmdb.open(
15
+ path,
16
+ max_readers=32,
17
+ readonly=True,
18
+ lock=False,
19
+ readahead=False,
20
+ meminit=False,
21
+ )
22
+
23
+ if not self.env:
24
+ raise IOError('Cannot open lmdb dataset', path)
25
+
26
+ with self.env.begin(write=False) as txn:
27
+ self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))
28
+
29
+ self.resolution = resolution
30
+ self.transform = transform
31
+
32
+ def __len__(self):
33
+ return self.length
34
+
35
+ def __getitem__(self, index):
36
+ with self.env.begin(write=False) as txn:
37
+ key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')
38
+ img_bytes = txn.get(key)
39
+
40
+ buffer = BytesIO(img_bytes)
41
+ img = Image.open(buffer)
42
+ img = self.transform(img)
43
+
44
+ return img
45
+
46
+
47
+ def has_file_allowed_extension(filename, extensions):
48
+ """Checks if a file is an allowed extension.
49
+
50
+ Args:
51
+ filename (string): path to a file
52
+
53
+ Returns:
54
+ bool: True if the filename ends with a known image extension
55
+ """
56
+ filename_lower = filename.lower()
57
+ return any(filename_lower.endswith(ext) for ext in extensions)
58
+
59
+
60
+ def find_classes(dir):
61
+ classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
62
+ classes.sort()
63
+ class_to_idx = {classes[i]: i for i in range(len(classes))}
64
+ return classes, class_to_idx
65
+
66
+
67
+ def make_dataset(dir, extensions):
68
+ images = []
69
+ for root, _, fnames in sorted(os.walk(dir)):
70
+ for fname in sorted(fnames):
71
+ if has_file_allowed_extension(fname, extensions):
72
+ path = os.path.join(root, fname)
73
+ item = (path, 0)
74
+ images.append(item)
75
+
76
+ return images
77
+
78
+
79
+ class DatasetFolder(data.Dataset):
80
+ def __init__(self, root, loader, extensions, transform=None, target_transform=None):
81
+ # classes, class_to_idx = find_classes(root)
82
+ samples = make_dataset(root, extensions)
83
+ if len(samples) == 0:
84
+ raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
85
+ "Supported extensions are: " + ",".join(extensions)))
86
+
87
+ self.root = root
88
+ self.loader = loader
89
+ self.extensions = extensions
90
+ self.samples = samples
91
+
92
+ self.transform = transform
93
+ self.target_transform = target_transform
94
+
95
+ def __getitem__(self, index):
96
+ """
97
+ Args:
98
+ index (int): Index
99
+
100
+ Returns:
101
+ tuple: (sample, target) where target is class_index of the target class.
102
+ """
103
+ path, target = self.samples[index]
104
+ sample = self.loader(path)
105
+ if self.transform is not None:
106
+ sample = self.transform(sample)
107
+ if self.target_transform is not None:
108
+ target = self.target_transform(target)
109
+
110
+ return sample
111
+
112
+ def __len__(self):
113
+ return len(self.samples)
114
+
115
+ def __repr__(self):
116
+ fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
117
+ fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
118
+ fmt_str += ' Root Location: {}\n'.format(self.root)
119
+ tmp = ' Transforms (if any): '
120
+ fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
121
+ tmp = ' Target Transforms (if any): '
122
+ fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
123
+ return fmt_str
124
+
125
+
126
+ IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
127
+
128
+
129
+ def pil_loader(path):
130
+ # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
131
+ with open(path, 'rb') as f:
132
+ img = Image.open(f)
133
+ return img.convert('RGB')
134
+
135
+
136
+ def default_loader(path):
137
+ return pil_loader(path)
138
+
139
+
140
+ class ImageFolder(DatasetFolder):
141
+ def __init__(self, root, transform1=None, transform2=None, target_transform=None,
142
+ loader=default_loader):
143
+ super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
144
+ transform=transform1,
145
+ target_transform=target_transform)
146
+ self.imgs = self.samples
147
+ self.transform2 = transform2
148
+
149
+ def set_stage(self, stage):
150
+ if stage == 'last':
151
+ self.transform = self.transform2
152
+
153
+ class ListFolder(Dataset):
154
+ def __init__(self, txt, transform):
155
+ with open(txt) as f:
156
+ imgpaths= f.readlines()
157
+ self.imgpaths = [x.strip() for x in imgpaths]
158
+ self.transform = transform
159
+
160
+ def __getitem__(self, idx):
161
+ path = self.imgpaths[idx]
162
+ image = Image.open(path)
163
+ return self.transform(image)
164
+
165
+ def __len__(self):
166
+ return len(self.imgpaths)
167
+
distributed.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import pickle
3
+
4
+ import torch
5
+ from torch import distributed as dist
6
+ from torch.utils.data.sampler import Sampler
7
+
8
+
9
+ def get_rank():
10
+ if not dist.is_available():
11
+ return 0
12
+
13
+ if not dist.is_initialized():
14
+ return 0
15
+
16
+ return dist.get_rank()
17
+
18
+
19
+ def synchronize():
20
+ if not dist.is_available():
21
+ return
22
+
23
+ if not dist.is_initialized():
24
+ return
25
+
26
+ world_size = dist.get_world_size()
27
+
28
+ if world_size == 1:
29
+ return
30
+
31
+ dist.barrier()
32
+
33
+
34
+ def get_world_size():
35
+ if not dist.is_available():
36
+ return 1
37
+
38
+ if not dist.is_initialized():
39
+ return 1
40
+
41
+ return dist.get_world_size()
42
+
43
+
44
+ def reduce_sum(tensor):
45
+ if not dist.is_available():
46
+ return tensor
47
+
48
+ if not dist.is_initialized():
49
+ return tensor
50
+
51
+ tensor = tensor.clone()
52
+ dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
53
+
54
+ return tensor
55
+
56
+
57
+ def gather_grad(params):
58
+ world_size = get_world_size()
59
+
60
+ if world_size == 1:
61
+ return
62
+
63
+ for param in params:
64
+ if param.grad is not None:
65
+ dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
66
+ param.grad.data.div_(world_size)
67
+
68
+
69
+ def all_gather(data):
70
+ world_size = get_world_size()
71
+
72
+ if world_size == 1:
73
+ return [data]
74
+
75
+ buffer = pickle.dumps(data)
76
+ storage = torch.ByteStorage.from_buffer(buffer)
77
+ tensor = torch.ByteTensor(storage).to('cuda')
78
+
79
+ local_size = torch.IntTensor([tensor.numel()]).to('cuda')
80
+ size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
81
+ dist.all_gather(size_list, local_size)
82
+ size_list = [int(size.item()) for size in size_list]
83
+ max_size = max(size_list)
84
+
85
+ tensor_list = []
86
+ for _ in size_list:
87
+ tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))
88
+
89
+ if local_size != max_size:
90
+ padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
91
+ tensor = torch.cat((tensor, padding), 0)
92
+
93
+ dist.all_gather(tensor_list, tensor)
94
+
95
+ data_list = []
96
+
97
+ for size, tensor in zip(size_list, tensor_list):
98
+ buffer = tensor.cpu().numpy().tobytes()[:size]
99
+ data_list.append(pickle.loads(buffer))
100
+
101
+ return data_list
102
+
103
+
104
+ def reduce_loss_dict(loss_dict):
105
+ world_size = get_world_size()
106
+
107
+ if world_size < 2:
108
+ return loss_dict
109
+
110
+ with torch.no_grad():
111
+ keys = []
112
+ losses = []
113
+
114
+ for k in sorted(loss_dict.keys()):
115
+ keys.append(k)
116
+ losses.append(loss_dict[k])
117
+
118
+ losses = torch.stack(losses, 0)
119
+ dist.reduce(losses, dst=0)
120
+
121
+ if dist.get_rank() == 0:
122
+ losses /= world_size
123
+
124
+ reduced_losses = {k: v for k, v in zip(keys, losses)}
125
+
126
+ return reduced_losses
model.py ADDED
@@ -0,0 +1,757 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchvision
2
+ import math
3
+ import random
4
+ import functools
5
+ import operator
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+ from torch.autograd import Function
11
+
12
+ from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
13
+ n_latent = 11
14
+
15
+
16
+ channels = {
17
+ 4: 512,
18
+ 8: 512,
19
+ 16: 512,
20
+ 32: 512,
21
+ 64: 256,
22
+ 128: 128,
23
+ 256: 64,
24
+ 512: 32,
25
+ 1024: 16,
26
+ }
27
+
28
+ class LambdaLR():
29
+ def __init__(self, n_epochs, offset, decay_start_epoch):
30
+ assert ((n_epochs - decay_start_epoch) > 0), "Decay must start before the training session ends!"
31
+ self.n_epochs = n_epochs
32
+ self.offset = offset
33
+ self.decay_start_epoch = decay_start_epoch
34
+
35
+ def step(self, epoch):
36
+ return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch)/(self.n_epochs - self.decay_start_epoch)
37
+
38
+
39
+ class PixelNorm(nn.Module):
40
+ def __init__(self):
41
+ super().__init__()
42
+
43
+ def forward(self, input):
44
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
45
+
46
+ def make_kernel(k):
47
+ k = torch.tensor(k, dtype=torch.float32)
48
+
49
+ if k.ndim == 1:
50
+ k = k[None, :] * k[:, None]
51
+
52
+ k /= k.sum()
53
+
54
+ return k
55
+
56
+ class Upsample(nn.Module):
57
+ def __init__(self, kernel, factor=2):
58
+ super().__init__()
59
+
60
+ self.factor = factor
61
+ kernel = make_kernel(kernel) * (factor ** 2)
62
+ self.register_buffer('kernel', kernel)
63
+
64
+ p = kernel.shape[0] - factor
65
+
66
+ pad0 = (p + 1) // 2 + factor - 1
67
+ pad1 = p // 2
68
+
69
+ self.pad = (pad0, pad1)
70
+
71
+ def forward(self, input):
72
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
73
+
74
+ return out
75
+
76
+
77
+ class Downsample(nn.Module):
78
+ def __init__(self, kernel, factor=2):
79
+ super().__init__()
80
+
81
+ self.factor = factor
82
+ kernel = make_kernel(kernel)
83
+ self.register_buffer('kernel', kernel)
84
+
85
+ p = kernel.shape[0] - factor
86
+
87
+ pad0 = (p + 1) // 2
88
+ pad1 = p // 2
89
+
90
+ self.pad = (pad0, pad1)
91
+
92
+ def forward(self, input):
93
+ out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
94
+
95
+ return out
96
+
97
+
98
+ class Blur(nn.Module):
99
+ def __init__(self, kernel, pad, upsample_factor=1):
100
+ super().__init__()
101
+
102
+ kernel = make_kernel(kernel)
103
+
104
+ if upsample_factor > 1:
105
+ kernel = kernel * (upsample_factor ** 2)
106
+
107
+ self.register_buffer('kernel', kernel)
108
+
109
+ self.pad = pad
110
+
111
+ def forward(self, input):
112
+ out = upfirdn2d(input, self.kernel, pad=self.pad)
113
+
114
+ return out
115
+
116
+
117
+ class EqualConv2d(nn.Module):
118
+ def __init__(
119
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
120
+ ):
121
+ super().__init__()
122
+
123
+ self.weight = nn.Parameter(
124
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
125
+ )
126
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
127
+
128
+ self.stride = stride
129
+ self.padding = padding
130
+
131
+ if bias:
132
+ self.bias = nn.Parameter(torch.zeros(out_channel))
133
+
134
+ else:
135
+ self.bias = None
136
+
137
+ def forward(self, input):
138
+ out = F.conv2d(
139
+ input,
140
+ self.weight * self.scale,
141
+ bias=self.bias,
142
+ stride=self.stride,
143
+ padding=self.padding,
144
+ )
145
+
146
+ return out
147
+
148
+ def __repr__(self):
149
+ return (
150
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
151
+ f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
152
+ )
153
+
154
+
155
+ class EqualLinear(nn.Module):
156
+ def __init__(
157
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
158
+ ):
159
+ super().__init__()
160
+
161
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
162
+
163
+ if bias:
164
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
165
+
166
+ else:
167
+ self.bias = None
168
+
169
+ self.activation = activation
170
+
171
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
172
+ self.lr_mul = lr_mul
173
+
174
+ def forward(self, input):
175
+ bias = self.bias*self.lr_mul if self.bias is not None else None
176
+ if self.activation:
177
+ out = F.linear(input, self.weight * self.scale)
178
+ out = fused_leaky_relu(out, bias)
179
+
180
+ else:
181
+ out = F.linear(
182
+ input, self.weight * self.scale, bias=bias
183
+ )
184
+
185
+ return out
186
+
187
+ def __repr__(self):
188
+ return (
189
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
190
+ )
191
+
192
+
193
+ class ScaledLeakyReLU(nn.Module):
194
+ def __init__(self, negative_slope=0.2):
195
+ super().__init__()
196
+
197
+ self.negative_slope = negative_slope
198
+
199
+ def forward(self, input):
200
+ out = F.leaky_relu(input, negative_slope=self.negative_slope)
201
+
202
+ return out * math.sqrt(2)
203
+
204
+
205
+ class ModulatedConv2d(nn.Module):
206
+ def __init__(
207
+ self,
208
+ in_channel,
209
+ out_channel,
210
+ kernel_size,
211
+ style_dim,
212
+ use_style=True,
213
+ demodulate=True,
214
+ upsample=False,
215
+ downsample=False,
216
+ blur_kernel=[1, 3, 3, 1],
217
+ ):
218
+ super().__init__()
219
+
220
+ self.eps = 1e-8
221
+ self.kernel_size = kernel_size
222
+ self.in_channel = in_channel
223
+ self.out_channel = out_channel
224
+ self.upsample = upsample
225
+ self.downsample = downsample
226
+ self.use_style = use_style
227
+
228
+ if upsample:
229
+ factor = 2
230
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
231
+ pad0 = (p + 1) // 2 + factor - 1
232
+ pad1 = p // 2 + 1
233
+
234
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
235
+
236
+ if downsample:
237
+ factor = 2
238
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
239
+ pad0 = (p + 1) // 2
240
+ pad1 = p // 2
241
+
242
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
243
+
244
+ fan_in = in_channel * kernel_size ** 2
245
+ self.scale = 1 / math.sqrt(fan_in)
246
+ self.padding = kernel_size // 2
247
+
248
+ self.weight = nn.Parameter(
249
+ torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
250
+ )
251
+
252
+ if use_style:
253
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
254
+ else:
255
+ self.modulation = nn.Parameter(torch.Tensor(1, 1, in_channel, 1, 1).fill_(1))
256
+
257
+ self.demodulate = demodulate
258
+
259
+ def __repr__(self):
260
+ return (
261
+ f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
262
+ f'upsample={self.upsample}, downsample={self.downsample})'
263
+ )
264
+
265
+ def forward(self, input, style):
266
+ batch, in_channel, height, width = input.shape
267
+
268
+ if self.use_style:
269
+ style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
270
+ weight = self.scale * self.weight * style
271
+ else:
272
+ weight = self.scale * self.weight.expand(batch,-1,-1,-1,-1) * self.modulation
273
+
274
+ if self.demodulate:
275
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
276
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
277
+
278
+ weight = weight.view(
279
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
280
+ )
281
+
282
+ if self.upsample:
283
+ input = input.view(1, batch * in_channel, height, width)
284
+ weight = weight.view(
285
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
286
+ )
287
+ weight = weight.transpose(1, 2).reshape(
288
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
289
+ )
290
+ out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
291
+ _, _, height, width = out.shape
292
+ out = out.view(batch, self.out_channel, height, width)
293
+ out = self.blur(out)
294
+
295
+ elif self.downsample:
296
+ input = self.blur(input)
297
+ _, _, height, width = input.shape
298
+ input = input.view(1, batch * in_channel, height, width)
299
+ out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
300
+ _, _, height, width = out.shape
301
+ out = out.view(batch, self.out_channel, height, width)
302
+
303
+ else:
304
+ input = input.view(1, batch * in_channel, height, width)
305
+ out = F.conv2d(input, weight, padding=self.padding, groups=batch)
306
+ _, _, height, width = out.shape
307
+ out = out.view(batch, self.out_channel, height, width)
308
+
309
+ return out
310
+
311
+
312
+ class NoiseInjection(nn.Module):
313
+ def __init__(self):
314
+ super().__init__()
315
+
316
+ self.weight = nn.Parameter(torch.zeros(1))
317
+
318
+ def forward(self, image, noise=None):
319
+ if noise is None:
320
+ batch, _, height, width = image.shape
321
+ noise = image.new_empty(batch, 1, height, width).normal_()
322
+
323
+ return image + self.weight * noise
324
+
325
+
326
+ class ConstantInput(nn.Module):
327
+ def __init__(self, style_dim):
328
+ super().__init__()
329
+
330
+ self.input = nn.Parameter(torch.randn(1, style_dim))
331
+
332
+ def forward(self, input):
333
+ batch = input.shape[0]
334
+ out = self.input.repeat(batch, n_latent)
335
+
336
+ return out
337
+
338
+
339
+ class StyledConv(nn.Module):
340
+ def __init__(
341
+ self,
342
+ in_channel,
343
+ out_channel,
344
+ kernel_size,
345
+ style_dim,
346
+ use_style=True,
347
+ upsample=False,
348
+ downsample=False,
349
+ blur_kernel=[1, 3, 3, 1],
350
+ demodulate=True,
351
+ ):
352
+ super().__init__()
353
+ self.use_style = use_style
354
+
355
+ self.conv = ModulatedConv2d(
356
+ in_channel,
357
+ out_channel,
358
+ kernel_size,
359
+ style_dim,
360
+ use_style=use_style,
361
+ upsample=upsample,
362
+ downsample=downsample,
363
+ blur_kernel=blur_kernel,
364
+ demodulate=demodulate,
365
+ )
366
+
367
+ #if use_style:
368
+ # self.noise = NoiseInjection()
369
+ #else:
370
+ # self.noise = None
371
+ # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
372
+ # self.activate = ScaledLeakyReLU(0.2)
373
+ self.activate = FusedLeakyReLU(out_channel)
374
+
375
+ def forward(self, input, style=None, noise=None):
376
+ out = self.conv(input, style)
377
+ #if self.use_style:
378
+ # out = self.noise(out, noise=noise)
379
+ # out = out + self.bias
380
+ out = self.activate(out)
381
+
382
+ return out
383
+
384
+
385
+ class StyledResBlock(nn.Module):
386
+ def __init__(self, in_channel, style_dim, blur_kernel=[1, 3, 3, 1], demodulate=True):
387
+ super().__init__()
388
+
389
+ self.conv1 = StyledConv(in_channel, in_channel, 3, style_dim, upsample=False, blur_kernel=blur_kernel, demodulate=demodulate)
390
+ self.conv2 = StyledConv(in_channel, in_channel, 3, style_dim, upsample=False, blur_kernel=blur_kernel, demodulate=demodulate)
391
+
392
+ def forward(self, input, style):
393
+ out = self.conv1(input, style)
394
+ out = self.conv2(out, style)
395
+ out = (out + input) / math.sqrt(2)
396
+
397
+ return out
398
+
399
+ class ToRGB(nn.Module):
400
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
401
+ super().__init__()
402
+
403
+ if upsample:
404
+ self.upsample = Upsample(blur_kernel)
405
+
406
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
407
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
408
+
409
+ def forward(self, input, style, skip=None):
410
+ out = self.conv(input, style)
411
+ out = out + self.bias
412
+
413
+ if skip is not None:
414
+ skip = self.upsample(skip)
415
+
416
+ out = out + skip
417
+
418
+ return out
419
+
420
+
421
+ class Generator(nn.Module):
422
+ def __init__(
423
+ self,
424
+ size,
425
+ num_down,
426
+ latent_dim,
427
+ n_mlp,
428
+ n_res,
429
+ channel_multiplier=1,
430
+ blur_kernel=[1, 3, 3, 1],
431
+ lr_mlp=0.01,
432
+ ):
433
+ super().__init__()
434
+ self.size = size
435
+
436
+ style_dim = 512
437
+
438
+ mapping = [EqualLinear(latent_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu')]
439
+ for i in range(n_mlp-1):
440
+ mapping.append(EqualLinear(style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'))
441
+
442
+ self.mapping = nn.Sequential(*mapping)
443
+
444
+ self.encoder = Encoder(size, latent_dim, num_down, n_res, channel_multiplier)
445
+
446
+ self.log_size = int(math.log(size, 2)) #7
447
+ in_log_size = self.log_size - num_down #7-2 or 7-3
448
+ in_size = 2 ** in_log_size
449
+
450
+ in_channel = channels[in_size]
451
+ self.adain_bottleneck = nn.ModuleList()
452
+ for i in range(n_res):
453
+ self.adain_bottleneck.append(StyledResBlock(in_channel, style_dim))
454
+
455
+ self.conv1 = StyledConv(in_channel, in_channel, 3, style_dim, blur_kernel=blur_kernel)
456
+ self.to_rgb1 = ToRGB(in_channel, style_dim, upsample=False)
457
+
458
+ self.num_layers = (self.log_size - in_log_size) * 2 + 1 #7
459
+
460
+ self.convs = nn.ModuleList()
461
+ self.upsamples = nn.ModuleList()
462
+ self.to_rgbs = nn.ModuleList()
463
+ #self.noises = nn.Module()
464
+
465
+
466
+ #for layer_idx in range(self.num_layers):
467
+ # res = (layer_idx + (in_log_size*2+1)) // 2 #2,3,3,5 ... -> 4,5,5,6 ...
468
+ # shape = [1, 1, 2 ** res, 2 ** res]
469
+ # self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
470
+
471
+ for i in range(in_log_size+1, self.log_size + 1):
472
+ out_channel = channels[2 ** i]
473
+
474
+ self.convs.append(
475
+ StyledConv(
476
+ in_channel,
477
+ out_channel,
478
+ 3,
479
+ style_dim,
480
+ upsample=True,
481
+ blur_kernel=blur_kernel,
482
+ )
483
+ )
484
+
485
+ self.convs.append(
486
+ StyledConv(
487
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
488
+ )
489
+ )
490
+
491
+ self.to_rgbs.append(ToRGB(out_channel, style_dim))
492
+
493
+ in_channel = out_channel
494
+
495
+ def style_encode(self, input):
496
+ return self.encoder(input)[1]
497
+
498
+ def encode(self, input):
499
+ return self.encoder(input)
500
+
501
+ def forward(self, input, z=None):
502
+ content, style = self.encode(input)
503
+ if z is None:
504
+ out = self.decode(content, style)
505
+ else:
506
+ out = self.decode(content, z)
507
+
508
+ return out, content, style
509
+
510
+ def decode(self, input, styles, use_mapping=True):
511
+ if use_mapping:
512
+ styles = self.mapping(styles)
513
+ #styles = styles.repeat(1, n_latent).view(styles.size(0), n_latent, -1)
514
+ out = input
515
+ i = 0
516
+ for conv in self.adain_bottleneck:
517
+ out = conv(out, styles)
518
+ i += 1
519
+
520
+ out = self.conv1(out, styles, noise=None)
521
+ skip = self.to_rgb1(out, styles)
522
+ i += 2
523
+
524
+ for conv1, conv2, to_rgb in zip(
525
+ self.convs[::2], self.convs[1::2], self.to_rgbs
526
+ ):
527
+ out = conv1(out, styles, noise=None)
528
+ out = conv2(out, styles, noise=None)
529
+ skip = to_rgb(out, styles, skip)
530
+
531
+ i += 3
532
+
533
+ image = skip
534
+ return image
535
+
536
+ class ConvLayer(nn.Sequential):
537
+ def __init__(
538
+ self,
539
+ in_channel,
540
+ out_channel,
541
+ kernel_size,
542
+ downsample=False,
543
+ blur_kernel=[1, 3, 3, 1],
544
+ bias=True,
545
+ activate=True,
546
+ ):
547
+ layers = []
548
+
549
+ if downsample:
550
+ factor = 2
551
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
552
+ pad0 = (p + 1) // 2
553
+ pad1 = p // 2
554
+
555
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
556
+
557
+ stride = 2
558
+ self.padding = 0
559
+
560
+ else:
561
+ stride = 1
562
+ self.padding = kernel_size // 2
563
+
564
+ layers.append(
565
+ EqualConv2d(
566
+ in_channel,
567
+ out_channel,
568
+ kernel_size,
569
+ padding=self.padding,
570
+ stride=stride,
571
+ bias=bias and not activate,
572
+ )
573
+ )
574
+
575
+ if activate:
576
+ if bias:
577
+ layers.append(FusedLeakyReLU(out_channel))
578
+
579
+ else:
580
+ layers.append(ScaledLeakyReLU(0.2))
581
+
582
+ super().__init__(*layers)
583
+
584
+ class InResBlock(nn.Module):
585
+ def __init__(self, in_channel, blur_kernel=[1, 3, 3, 1]):
586
+ super().__init__()
587
+
588
+ self.conv1 = StyledConv(in_channel, in_channel, 3, None, blur_kernel=blur_kernel, demodulate=True, use_style=False)
589
+ self.conv2 = StyledConv(in_channel, in_channel, 3, None, blur_kernel=blur_kernel, demodulate=True, use_style=False)
590
+
591
+ def forward(self, input):
592
+ out = self.conv1(input, None)
593
+ out = self.conv2(out, None)
594
+ out = (out + input) / math.sqrt(2)
595
+
596
+ return out
597
+
598
+ class ResBlock(nn.Module):
599
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], downsample=True):
600
+ super().__init__()
601
+
602
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
603
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=downsample)
604
+
605
+ if downsample or in_channel != out_channel:
606
+ self.skip = ConvLayer(
607
+ in_channel, out_channel, 1, downsample=downsample, activate=False, bias=False
608
+ )
609
+ else:
610
+ self.skip = None
611
+
612
+ def forward(self, input):
613
+ out = self.conv1(input)
614
+ out = self.conv2(out)
615
+
616
+ if self.skip is None:
617
+ skip = input
618
+ else:
619
+ skip = self.skip(input)
620
+ out = (out + skip) / math.sqrt(2)
621
+
622
+ return out
623
+
624
+ class Discriminator(nn.Module):
625
+ def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
626
+ super().__init__()
627
+ self.size = size
628
+ l_branch = self.make_net_(32)
629
+ l_branch += [ConvLayer(channels[32], 1, 1, activate=False)]
630
+ self.l_branch = nn.Sequential(*l_branch)
631
+
632
+
633
+ g_branch = self.make_net_(8)
634
+ self.g_branch = nn.Sequential(*g_branch)
635
+ self.g_adv = ConvLayer(channels[8], 1, 1, activate=False)
636
+
637
+ self.g_std = nn.Sequential(ConvLayer(channels[8], channels[4], 3, downsample=True),
638
+ nn.Flatten(),
639
+ EqualLinear(channels[4] * 4 * 4, 128, activation='fused_lrelu'),
640
+ )
641
+ self.g_final = EqualLinear(128, 1, activation=False)
642
+
643
+
644
+ def make_net_(self, out_size):
645
+ size = self.size
646
+ convs = [ConvLayer(3, channels[size], 1)]
647
+ log_size = int(math.log(size, 2))
648
+ out_log_size = int(math.log(out_size, 2))
649
+ in_channel = channels[size]
650
+
651
+ for i in range(log_size, out_log_size, -1):
652
+ out_channel = channels[2 ** (i - 1)]
653
+ convs.append(ResBlock(in_channel, out_channel))
654
+ in_channel = out_channel
655
+
656
+ return convs
657
+
658
+ def forward(self, x):
659
+ l_adv = self.l_branch(x)
660
+
661
+ g_act = self.g_branch(x)
662
+ g_adv = self.g_adv(g_act)
663
+
664
+ output = self.g_std(g_act)
665
+ g_stddev = torch.sqrt(output.var(0, keepdim=True, unbiased=False) + 1e-8).repeat(x.size(0),1)
666
+ g_std = self.g_final(g_stddev)
667
+ return [l_adv, g_adv, g_std]
668
+
669
+
670
+
671
+ class Encoder(nn.Module):
672
+ def __init__(self, size, latent_dim, num_down, n_res, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
673
+ super().__init__()
674
+ stem = [ConvLayer(3, channels[size], 1)]
675
+ log_size = int(math.log(size, 2))
676
+ in_channel = channels[size]
677
+
678
+ for i in range(log_size, log_size-num_down, -1):
679
+ out_channel = channels[2 ** (i - 1)]
680
+ stem.append(ResBlock(in_channel, out_channel, downsample=True))
681
+ in_channel = out_channel
682
+ stem += [ResBlock(in_channel, in_channel, downsample=False) for i in range(n_res)]
683
+ self.stem = nn.Sequential(*stem)
684
+
685
+ self.content = nn.Sequential(
686
+ ConvLayer(in_channel, in_channel, 1),
687
+ ConvLayer(in_channel, in_channel, 1)
688
+ )
689
+ style = []
690
+ for i in range(log_size-num_down, 2, -1):
691
+ out_channel = channels[2 ** (i - 1)]
692
+ style.append(ConvLayer(in_channel, out_channel, 3, downsample=True))
693
+ in_channel = out_channel
694
+ style += [
695
+ nn.Flatten(),
696
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
697
+ EqualLinear(channels[4], latent_dim),
698
+ ]
699
+ self.style = nn.Sequential(*style)
700
+
701
+
702
+ def forward(self, input):
703
+ act = self.stem(input)
704
+ content = self.content(act)
705
+ style = self.style(act)
706
+ return content, style
707
+
708
+ class StyleEncoder(nn.Module):
709
+ def __init__(self, size, style_dim, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
710
+ super().__init__()
711
+ convs = [ConvLayer(3, channels[size], 1)]
712
+
713
+ log_size = int(math.log(size, 2))
714
+
715
+ in_channel = channels[size]
716
+ num_down = 6
717
+
718
+ for i in range(log_size, log_size-num_down, -1):
719
+ w = 2 ** (i - 1)
720
+ out_channel = channels[w]
721
+ convs.append(ConvLayer(in_channel, out_channel, 3, downsample=True))
722
+ in_channel = out_channel
723
+
724
+ convs += [
725
+ nn.Flatten(),
726
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'), EqualLinear(channels[4], style_dim),
727
+ ]
728
+ self.convs = nn.Sequential(*convs)
729
+
730
+ def forward(self, input):
731
+ style = self.convs(input)
732
+ return style.view(input.size(0), -1)
733
+
734
+ class LatDiscriminator(nn.Module):
735
+ def __init__(self, style_dim):
736
+ super().__init__()
737
+
738
+ fc = [EqualLinear(style_dim, 256, activation='fused_lrelu')]
739
+ for i in range(3):
740
+ fc += [EqualLinear(256, 256, activation='fused_lrelu')]
741
+ fc += [FCMinibatchStd(256, 256)]
742
+ fc += [EqualLinear(256, 1)]
743
+ self.fc = nn.Sequential(*fc)
744
+
745
+ def forward(self, input):
746
+ return [self.fc(input), ]
747
+
748
+ class FCMinibatchStd(nn.Module):
749
+ def __init__(self, in_channel, out_channel):
750
+ super().__init__()
751
+ self.fc = EqualLinear(in_channel+1, out_channel, activation='fused_lrelu')
752
+
753
+ def forward(self, out):
754
+ stddev = torch.sqrt(out.var(0, unbiased=False) + 1e-8).mean().view(1,1).repeat(out.size(0), 1)
755
+ out = torch.cat([out, stddev], 1)
756
+ out = self.fc(out)
757
+ return out
op/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .fused_act import FusedLeakyReLU, fused_leaky_relu
2
+ from .upfirdn2d import upfirdn2d
op/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (262 Bytes). View file
 
op/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (217 Bytes). View file
 
op/__pycache__/fused_act.cpython-310.pyc ADDED
Binary file (1.32 kB). View file
 
op/__pycache__/fused_act.cpython-38.pyc ADDED
Binary file (1.29 kB). View file
 
op/__pycache__/upfirdn2d.cpython-310.pyc ADDED
Binary file (1.46 kB). View file
 
op/__pycache__/upfirdn2d.cpython-38.pyc ADDED
Binary file (1.44 kB). View file
 
op/fused_act.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ module_path = os.path.dirname(__file__)
8
+
9
+
10
+
11
+ class FusedLeakyReLU(nn.Module):
12
+ def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
13
+ super().__init__()
14
+
15
+ self.bias = nn.Parameter(torch.zeros(channel))
16
+ self.negative_slope = negative_slope
17
+ self.scale = scale
18
+
19
+ def forward(self, input):
20
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
21
+
22
+
23
+ def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
24
+ rest_dim = [1] * (input.ndim - bias.ndim - 1)
25
+ if input.ndim == 3:
26
+ return (
27
+ F.leaky_relu(
28
+ input + bias.view(1, *rest_dim, bias.shape[0]), negative_slope=negative_slope
29
+ )
30
+ * scale
31
+ )
32
+ else:
33
+ return (
34
+ F.leaky_relu(
35
+ input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
36
+ )
37
+ * scale
38
+ )
39
+
op/upfirdn2d.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ from torch.nn import functional as F
5
+
6
+
7
+ module_path = os.path.dirname(__file__)
8
+
9
+
10
+
11
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
12
+ out = upfirdn2d_native(
13
+ input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
14
+ )
15
+
16
+ return out
17
+
18
+
19
+ def upfirdn2d_native(
20
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
21
+ ):
22
+ _, channel, in_h, in_w = input.shape
23
+ input = input.reshape(-1, in_h, in_w, 1)
24
+
25
+ _, in_h, in_w, minor = input.shape
26
+ kernel_h, kernel_w = kernel.shape
27
+
28
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
29
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
30
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
31
+
32
+ out = F.pad(
33
+ out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
34
+ )
35
+ out = out[
36
+ :,
37
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
38
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
39
+ :,
40
+ ]
41
+
42
+ out = out.permute(0, 3, 1, 2)
43
+ out = out.reshape(
44
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
45
+ )
46
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
47
+ out = F.conv2d(out, w)
48
+ out = out.reshape(
49
+ -1,
50
+ minor,
51
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
52
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
53
+ )
54
+ out = out.permute(0, 2, 3, 1)
55
+ out = out[:, ::down_y, ::down_x, :]
56
+
57
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
58
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
59
+
60
+ return out.view(-1, channel, out_h, out_w)
predict.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ GANsNRoses: Selfie to Anime https://github.com/mchong6/GANsNRoses"""
2
+ import os
3
+ import tempfile
4
+ from base64 import b64encode
5
+
6
+ import cv2
7
+ import dlib
8
+ import kornia.augmentation as K
9
+ import moviepy.video.io.ImageSequenceClip
10
+ import numpy as np
11
+ import scipy
12
+ import torch
13
+ from aubio import source, tempo
14
+ from cog import BasePredictor, File, Input, Path
15
+ from PIL import Image
16
+ from torch import nn
17
+ from torch.nn import functional as F
18
+ from torch.utils import data
19
+ from torchvision import transforms, utils
20
+ from tqdm import tqdm
21
+
22
+ from model import *
23
+ from util import *
24
+
25
+ torch.backends.cudnn.benchmark = True
26
+
27
+
28
+ class Predictor(BasePredictor):
29
+ def setup(self):
30
+ """Load the model into memory to make running multiple predictions efficient"""
31
+
32
+ # params
33
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
34
+
35
+ def predict(
36
+ self,
37
+ inpath: Path = Input(description="Input image or short video", default=None),
38
+ ) -> Path:
39
+
40
+ # get input file
41
+ inpath = str(inpath)
42
+
43
+ # model setup
44
+ latent_dim = 8
45
+ n_mlp = 5
46
+ num_down = 3
47
+
48
+ G_A2B = (
49
+ Generator(
50
+ 256, 4, latent_dim, n_mlp, channel_multiplier=1, lr_mlp=0.01, n_res=1
51
+ )
52
+ .to(self.device)
53
+ .eval()
54
+ )
55
+ ckpt = torch.load("GNR_checkpoint.pt", map_location=self.device)
56
+ G_A2B.load_state_dict(ckpt["G_A2B_ema"])
57
+
58
+ test_transform = transforms.Compose(
59
+ [
60
+ transforms.Resize((256, 256)),
61
+ transforms.ToTensor(),
62
+ transforms.Normalize(
63
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), inplace=True
64
+ ),
65
+ ]
66
+ )
67
+
68
+ if "mp4" in inpath: # video
69
+ print(f"*** Processing video input: {inpath} ***")
70
+
71
+ # use normal mode for demo purposes (see original repo for other modes)
72
+ mode = "normal"
73
+
74
+ # Frame numbers and length of output video
75
+ start_frame = 0
76
+ end_frame = None
77
+ frame_num = 0
78
+ mp4_fps = 30
79
+ faces = None
80
+ smoothing_sec = 0.7
81
+ eig_dir_idx = 1 # first eig isnt good so we skip it
82
+
83
+ frames = []
84
+ reader = cv2.VideoCapture(inpath)
85
+ num_frames = int(reader.get(cv2.CAP_PROP_FRAME_COUNT))
86
+
87
+ all_latents = torch.randn([8, latent_dim]).to(self.device)
88
+ in_latent = all_latents
89
+
90
+ # Face detector
91
+ face_detector = dlib.get_frontal_face_detector()
92
+
93
+ assert start_frame < num_frames - 1
94
+ end_frame = end_frame if end_frame else num_frames
95
+
96
+ while reader.isOpened():
97
+ _, image = reader.read()
98
+ if image is None:
99
+ break
100
+
101
+ if frame_num < start_frame:
102
+ continue
103
+ # Image size
104
+ height, width = image.shape[:2]
105
+
106
+ # 2. Detect with dlib
107
+ if faces is None:
108
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
109
+ faces = face_detector(gray, 1)
110
+ if len(faces):
111
+ # For now only take biggest face
112
+ face = faces[0]
113
+
114
+ # --- Prediction ---------------------------------------------------
115
+ # Face crop with dlib and bounding box scale enlargement
116
+ x, y, size = get_boundingbox(face, width, height)
117
+ cropped_face = image[y : y + size, x : x + size]
118
+ cropped_face = cv2.cvtColor(cropped_face, cv2.COLOR_BGR2RGB)
119
+ cropped_face = Image.fromarray(cropped_face)
120
+ frame = test_transform(cropped_face).unsqueeze(0).to(self.device)
121
+
122
+ with torch.no_grad():
123
+ A2B_content, A2B_style = G_A2B.encode(frame)
124
+
125
+ in_latent = all_latents
126
+
127
+ fake_A2B = G_A2B.decode(A2B_content.repeat(8, 1, 1, 1), in_latent)
128
+
129
+ fake_A2B = torch.cat([fake_A2B[:4], frame, fake_A2B[4:]], 0)
130
+
131
+ fake_A2B = utils.make_grid(
132
+ fake_A2B.cpu(), normalize=True, range=(-1, 1), nrow=3
133
+ )
134
+
135
+ # concatenate original image top
136
+ fake_A2B = fake_A2B.permute(1, 2, 0).cpu().numpy()
137
+ frames.append(fake_A2B * 255)
138
+
139
+ frame_num += 1
140
+
141
+ clip = moviepy.video.io.ImageSequenceClip.ImageSequenceClip(
142
+ frames, fps=mp4_fps
143
+ )
144
+
145
+ # save to temporary file. hack to make sure ffmpeg works
146
+ output_path = Path(tempfile.mkdtemp()) / "output.mp4"
147
+ clip.write_videofile(str(output_path))
148
+ print(f'saving to {output_path}')
149
+
150
+ return output_path
151
+
152
+ # else, just process the image
153
+ print(f"*** Processing image input: {inpath} ***")
154
+ num_styles = 5
155
+ style = torch.randn([num_styles, latent_dim]).to(self.device)
156
+
157
+ # read input image
158
+ image = cv2.imread(inpath)
159
+ height, width = image.shape[:2]
160
+
161
+ # Detect with dlib
162
+ face_detector = dlib.get_frontal_face_detector()
163
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
164
+ # grab first face
165
+ face = face_detector(gray, 1)[0]
166
+
167
+ # Face crop with dlib and bounding box scale enlargement
168
+ x, y, size = get_boundingbox(face, width, height)
169
+ cropped_face = image[y : y + size, x : x + size]
170
+ cropped_face = cv2.cvtColor(cropped_face, cv2.COLOR_BGR2RGB)
171
+ cropped_face = Image.fromarray(cropped_face)
172
+
173
+ real_A = cropped_face
174
+ real_A = test_transform(real_A).unsqueeze(0).to(self.device)
175
+
176
+ with torch.no_grad():
177
+ A2B_content, _ = G_A2B.encode(real_A)
178
+ fake_A2B = G_A2B.decode(A2B_content.repeat(num_styles, 1, 1, 1), style)
179
+ A2B = torch.cat([real_A, fake_A2B], 0)
180
+
181
+ # create and save output
182
+ output = utils.make_grid(A2B.cpu(), normalize=True, range=(-1, 1), nrow=10)
183
+ output_path = Path(tempfile.mkdtemp()) / "output.png"
184
+ torchvision.utils.save_image(output, output_path)
185
+ print(f'saving to {output_path}')
186
+
187
+ return output_path
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ tqdm
2
+ gdown
3
+ kornia
4
+ scipy
5
+ opencv-python
6
+ moviepy
7
+ lpips
8
+ ninja
9
+ gradio
10
+ torchvision
train.py ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import math
3
+ import random
4
+ import os
5
+ from util import *
6
+ import numpy as np
7
+ import torch
8
+ torch.backends.cudnn.benchmark = True
9
+ from torch import nn, autograd
10
+ from torch import optim
11
+ from torch.nn import functional as F
12
+ from torch.utils import data
13
+ import torch.distributed as dist
14
+
15
+ from torchvision import transforms, utils
16
+ from tqdm import tqdm
17
+ from torch.optim import lr_scheduler
18
+ import copy
19
+ import kornia.augmentation as K
20
+ import kornia
21
+ import lpips
22
+
23
+ from model import *
24
+ from dataset import ImageFolder
25
+ from distributed import (
26
+ get_rank,
27
+ synchronize,
28
+ reduce_loss_dict,
29
+ reduce_sum,
30
+ get_world_size,
31
+ )
32
+
33
+ mse_criterion = nn.MSELoss()
34
+
35
+
36
+ def test(args, genA2B, genB2A, testA_loader, testB_loader, name, step):
37
+ testA_loader = iter(testA_loader)
38
+ testB_loader = iter(testB_loader)
39
+ with torch.no_grad():
40
+ test_sample_num = 16
41
+
42
+ genA2B.eval(), genB2A.eval()
43
+ A2B = []
44
+ B2A = []
45
+ for i in range(test_sample_num):
46
+ real_A = testA_loader.next()
47
+ real_B = testB_loader.next()
48
+
49
+ real_A, real_B = real_A.cuda(), real_B.cuda()
50
+
51
+ A2B_content, A2B_style = genA2B.encode(real_A)
52
+ B2A_content, B2A_style = genB2A.encode(real_B)
53
+
54
+ if i % 2 == 0:
55
+ A2B_mod1 = torch.randn([1, args.latent_dim]).cuda()
56
+ B2A_mod1 = torch.randn([1, args.latent_dim]).cuda()
57
+ A2B_mod2 = torch.randn([1, args.latent_dim]).cuda()
58
+ B2A_mod2 = torch.randn([1, args.latent_dim]).cuda()
59
+
60
+ fake_B2B, _, _ = genA2B(real_B)
61
+ fake_A2A, _, _ = genB2A(real_A)
62
+
63
+ colsA = [real_A, fake_A2A]
64
+ colsB = [real_B, fake_B2B]
65
+
66
+ fake_A2B_1 = genA2B.decode(A2B_content, A2B_mod1)
67
+ fake_B2A_1 = genB2A.decode(B2A_content, B2A_mod1)
68
+
69
+ fake_A2B_2 = genA2B.decode(A2B_content, A2B_mod2)
70
+ fake_B2A_2 = genB2A.decode(B2A_content, B2A_mod2)
71
+
72
+ fake_A2B_3 = genA2B.decode(A2B_content, B2A_style)
73
+ fake_B2A_3 = genB2A.decode(B2A_content, A2B_style)
74
+
75
+ colsA += [fake_A2B_3, fake_A2B_1, fake_A2B_2]
76
+ colsB += [fake_B2A_3, fake_B2A_1, fake_B2A_2]
77
+
78
+ fake_A2B2A, _, _ = genB2A(fake_A2B_3, A2B_style)
79
+ fake_B2A2B, _, _ = genA2B(fake_B2A_3, B2A_style)
80
+ colsA.append(fake_A2B2A)
81
+ colsB.append(fake_B2A2B)
82
+
83
+ fake_A2B2A, _, _ = genB2A(fake_A2B_1, A2B_style)
84
+ fake_B2A2B, _, _ = genA2B(fake_B2A_1, B2A_style)
85
+ colsA.append(fake_A2B2A)
86
+ colsB.append(fake_B2A2B)
87
+
88
+ fake_A2B2A, _, _ = genB2A(fake_A2B_2, A2B_style)
89
+ fake_B2A2B, _, _ = genA2B(fake_B2A_2, B2A_style)
90
+ colsA.append(fake_A2B2A)
91
+ colsB.append(fake_B2A2B)
92
+
93
+ fake_A2B2A, _, _ = genB2A(fake_A2B_1)
94
+ fake_B2A2B, _, _ = genA2B(fake_B2A_1)
95
+ colsA.append(fake_A2B2A)
96
+ colsB.append(fake_B2A2B)
97
+
98
+ colsA = torch.cat(colsA, 2).detach().cpu()
99
+ colsB = torch.cat(colsB, 2).detach().cpu()
100
+
101
+ A2B.append(colsA)
102
+ B2A.append(colsB)
103
+ A2B = torch.cat(A2B, 0)
104
+ B2A = torch.cat(B2A, 0)
105
+
106
+ utils.save_image(A2B, f'{im_path}/{name}_A2B_{str(step).zfill(6)}.jpg', normalize=True, range=(-1, 1), nrow=16)
107
+ utils.save_image(B2A, f'{im_path}/{name}_B2A_{str(step).zfill(6)}.jpg', normalize=True, range=(-1, 1), nrow=16)
108
+
109
+ genA2B.train(), genB2A.train()
110
+
111
+
112
+ def train(args, trainA_loader, trainB_loader, testA_loader, testB_loader, G_A2B, G_B2A, D_A, D_B, G_optim, D_optim, device):
113
+ G_A2B.train(), G_B2A.train(), D_A.train(), D_B.train()
114
+ trainA_loader = sample_data(trainA_loader)
115
+ trainB_loader = sample_data(trainB_loader)
116
+ G_scheduler = lr_scheduler.StepLR(G_optim, step_size=100000, gamma=0.5)
117
+ D_scheduler = lr_scheduler.StepLR(D_optim, step_size=100000, gamma=0.5)
118
+
119
+ pbar = range(args.iter)
120
+
121
+ if get_rank() == 0:
122
+ pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.1)
123
+
124
+ loss_dict = {}
125
+ mean_path_length_A2B = 0
126
+ mean_path_length_B2A = 0
127
+
128
+ if args.distributed:
129
+ G_A2B_module = G_A2B.module
130
+ G_B2A_module = G_B2A.module
131
+ D_A_module = D_A.module
132
+ D_B_module = D_B.module
133
+ D_L_module = D_L.module
134
+
135
+ else:
136
+ G_A2B_module = G_A2B
137
+ G_B2A_module = G_B2A
138
+ D_A_module = D_A
139
+ D_B_module = D_B
140
+ D_L_module = D_L
141
+
142
+ for idx in pbar:
143
+ i = idx + args.start_iter
144
+
145
+ if i > args.iter:
146
+ print('Done!')
147
+ break
148
+
149
+ ori_A = next(trainA_loader)
150
+ ori_B = next(trainB_loader)
151
+ if isinstance(ori_A, list):
152
+ ori_A = ori_A[0]
153
+ if isinstance(ori_B, list):
154
+ ori_B = ori_B[0]
155
+
156
+ ori_A = ori_A.to(device)
157
+ ori_B = ori_B.to(device)
158
+ aug_A = aug(ori_A)
159
+ aug_B = aug(ori_B)
160
+ A = aug(ori_A[[np.random.randint(args.batch)]].expand_as(ori_A))
161
+ B = aug(ori_B[[np.random.randint(args.batch)]].expand_as(ori_B))
162
+
163
+ if i % args.d_reg_every == 0:
164
+ aug_A.requires_grad = True
165
+ aug_B.requires_grad = True
166
+
167
+ A2B_content, A2B_style = G_A2B.encode(A)
168
+ B2A_content, B2A_style = G_B2A.encode(B)
169
+
170
+ # get new style
171
+ aug_A2B_style = G_B2A.style_encode(aug_B)
172
+ aug_B2A_style = G_A2B.style_encode(aug_A)
173
+ rand_A2B_style = torch.randn([args.batch, args.latent_dim]).to(device).requires_grad_()
174
+ rand_B2A_style = torch.randn([args.batch, args.latent_dim]).to(device).requires_grad_()
175
+
176
+ # styles
177
+ idx = torch.randperm(2*args.batch)
178
+ input_A2B_style = torch.cat([rand_A2B_style, aug_A2B_style], 0)[idx][:args.batch]
179
+
180
+ idx = torch.randperm(2*args.batch)
181
+ input_B2A_style = torch.cat([rand_B2A_style, aug_B2A_style], 0)[idx][:args.batch]
182
+
183
+ fake_A2B = G_A2B.decode(A2B_content, input_A2B_style)
184
+ fake_B2A = G_B2A.decode(B2A_content, input_B2A_style)
185
+
186
+
187
+ # train disc
188
+ real_A_logit = D_A(aug_A)
189
+ real_B_logit = D_B(aug_B)
190
+ real_L_logit1 = D_L(rand_A2B_style)
191
+ real_L_logit2 = D_L(rand_B2A_style)
192
+
193
+ fake_B_logit = D_B(fake_A2B.detach())
194
+ fake_A_logit = D_A(fake_B2A.detach())
195
+ fake_L_logit1 = D_L(aug_A2B_style.detach())
196
+ fake_L_logit2 = D_L(aug_B2A_style.detach())
197
+
198
+ # global loss
199
+ D_loss = d_logistic_loss(real_A_logit, fake_A_logit) +\
200
+ d_logistic_loss(real_B_logit, fake_B_logit) +\
201
+ d_logistic_loss(real_L_logit1, fake_L_logit1) +\
202
+ d_logistic_loss(real_L_logit2, fake_L_logit2)
203
+
204
+ loss_dict['D_adv'] = D_loss
205
+
206
+ if i % args.d_reg_every == 0:
207
+ r1_A_loss = d_r1_loss(real_A_logit, aug_A)
208
+ r1_B_loss = d_r1_loss(real_B_logit, aug_B)
209
+ r1_L_loss = d_r1_loss(real_L_logit1, rand_A2B_style) + d_r1_loss(real_L_logit2, rand_B2A_style)
210
+ r1_loss = r1_A_loss + r1_B_loss + r1_L_loss
211
+ D_r1_loss = (args.r1 / 2 * r1_loss * args.d_reg_every)
212
+ D_loss += D_r1_loss
213
+
214
+ D_optim.zero_grad()
215
+ D_loss.backward()
216
+ D_optim.step()
217
+
218
+ #Generator
219
+ # adv loss
220
+ fake_B_logit = D_B(fake_A2B)
221
+ fake_A_logit = D_A(fake_B2A)
222
+ fake_L_logit1 = D_L(aug_A2B_style)
223
+ fake_L_logit2 = D_L(aug_B2A_style)
224
+
225
+ lambda_adv = (1, 1, 1)
226
+ G_adv_loss = 1 * (g_nonsaturating_loss(fake_A_logit, lambda_adv) +\
227
+ g_nonsaturating_loss(fake_B_logit, lambda_adv) +\
228
+ 2*g_nonsaturating_loss(fake_L_logit1, (1,)) +\
229
+ 2*g_nonsaturating_loss(fake_L_logit2, (1,)))
230
+
231
+ # style consis loss
232
+ G_con_loss = 50 * (A2B_style.var(0, unbiased=False).sum() + B2A_style.var(0, unbiased=False).sum())
233
+
234
+ # cycle recon
235
+ A2B2A_content, A2B2A_style = G_B2A.encode(fake_A2B)
236
+ B2A2B_content, B2A2B_style = G_A2B.encode(fake_B2A)
237
+ fake_A2B2A = G_B2A.decode(A2B2A_content, shuffle_batch(A2B_style))
238
+ fake_B2A2B = G_A2B.decode(B2A2B_content, shuffle_batch(B2A_style))
239
+
240
+ G_cycle_loss = 20 * (F.mse_loss(fake_A2B2A, A) + F.mse_loss(fake_B2A2B, B))
241
+ lpips_loss = 10 * (lpips_fn(fake_A2B2A, A).mean() + lpips_fn(fake_B2A2B, B).mean()) #10 for anime
242
+
243
+ # style reconstruction
244
+ G_style_loss = 5 * (mse_criterion(A2B2A_style, input_A2B_style) +\
245
+ mse_criterion(B2A2B_style, input_B2A_style))
246
+
247
+
248
+ G_loss = G_adv_loss + G_cycle_loss + G_con_loss + lpips_loss + G_style_loss
249
+
250
+ loss_dict['G_adv'] = G_adv_loss
251
+ loss_dict['G_con'] = G_con_loss
252
+ loss_dict['G_cycle'] = G_cycle_loss
253
+ loss_dict['lpips'] = lpips_loss
254
+
255
+ G_optim.zero_grad()
256
+ G_loss.backward()
257
+ G_optim.step()
258
+
259
+ G_scheduler.step()
260
+ D_scheduler.step()
261
+
262
+ accumulate(G_A2B_ema, G_A2B_module)
263
+ accumulate(G_B2A_ema, G_B2A_module)
264
+
265
+ loss_reduced = reduce_loss_dict(loss_dict)
266
+ D_adv_loss_val = loss_reduced['D_adv'].mean().item()
267
+
268
+ G_adv_loss_val = loss_reduced['G_adv'].mean().item()
269
+ G_cycle_loss_val = loss_reduced['G_cycle'].mean().item()
270
+ G_con_loss_val = loss_reduced['G_con'].mean().item()
271
+ lpips_val = loss_reduced['lpips'].mean().item()
272
+
273
+ if get_rank() == 0:
274
+ pbar.set_description(
275
+ (
276
+ f'Dadv: {D_adv_loss_val:.2f}; lpips: {lpips_val:.2f} '
277
+ f'Gadv: {G_adv_loss_val:.2f}; Gcycle: {G_cycle_loss_val:.2f}; GMS: {G_con_loss_val:.2f} {G_style_loss.item():.2f}'
278
+ )
279
+ )
280
+
281
+ if i % 1000 == 0:
282
+ with torch.no_grad():
283
+ test(args, G_A2B, G_B2A, testA_loader, testB_loader, 'normal', i)
284
+ test(args, G_A2B_ema, G_B2A_ema, testA_loader, testB_loader, 'ema', i)
285
+
286
+ if (i+1) % 2000 == 0:
287
+ torch.save(
288
+ {
289
+ 'G_A2B': G_A2B_module.state_dict(),
290
+ 'G_B2A': G_B2A_module.state_dict(),
291
+ 'G_A2B_ema': G_A2B_ema.state_dict(),
292
+ 'G_B2A_ema': G_B2A_ema.state_dict(),
293
+ 'D_A': D_A_module.state_dict(),
294
+ 'D_B': D_B_module.state_dict(),
295
+ 'D_L': D_L_module.state_dict(),
296
+ 'G_optim': G_optim.state_dict(),
297
+ 'D_optim': D_optim.state_dict(),
298
+ 'iter': i,
299
+ },
300
+ os.path.join(model_path, 'ck.pt'),
301
+ )
302
+
303
+
304
+ if __name__ == '__main__':
305
+ device = 'cuda'
306
+
307
+ parser = argparse.ArgumentParser()
308
+
309
+ parser.add_argument('--iter', type=int, default=300000)
310
+ parser.add_argument('--batch', type=int, default=4)
311
+ parser.add_argument('--n_sample', type=int, default=64)
312
+ parser.add_argument('--size', type=int, default=256)
313
+ parser.add_argument('--r1', type=float, default=10)
314
+ parser.add_argument('--lambda_cycle', type=int, default=1)
315
+ parser.add_argument('--path_regularize', type=float, default=2)
316
+ parser.add_argument('--path_batch_shrink', type=int, default=2)
317
+ parser.add_argument('--d_reg_every', type=int, default=16)
318
+ parser.add_argument('--g_reg_every', type=int, default=4)
319
+ parser.add_argument('--mixing', type=float, default=0.9)
320
+ parser.add_argument('--ckpt', type=str, default=None)
321
+ parser.add_argument('--lr', type=float, default=2e-3)
322
+ parser.add_argument('--local_rank', type=int, default=0)
323
+ parser.add_argument('--num_down', type=int, default=3)
324
+ parser.add_argument('--name', type=str, required=True)
325
+ parser.add_argument('--d_path', type=str, required=True)
326
+ parser.add_argument('--latent_dim', type=int, default=8)
327
+ parser.add_argument('--lr_mlp', type=float, default=0.01)
328
+ parser.add_argument('--n_res', type=int, default=1)
329
+
330
+ args = parser.parse_args()
331
+
332
+ n_gpu = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
333
+ args.distributed = False
334
+
335
+ if args.distributed:
336
+ torch.cuda.set_device(args.local_rank)
337
+ torch.distributed.init_process_group(backend='nccl', init_method='env://')
338
+ synchronize()
339
+
340
+ save_path = f'./{args.name}'
341
+ im_path = os.path.join(save_path, 'sample')
342
+ model_path = os.path.join(save_path, 'checkpoint')
343
+ os.makedirs(im_path, exist_ok=True)
344
+ os.makedirs(model_path, exist_ok=True)
345
+
346
+ args.n_mlp = 5
347
+
348
+ args.start_iter = 0
349
+
350
+ G_A2B = Generator( args.size, args.num_down, args.latent_dim, args.n_mlp, lr_mlp=args.lr_mlp, n_res=args.n_res).to(device)
351
+ D_A = Discriminator(args.size).to(device)
352
+ G_B2A = Generator( args.size, args.num_down, args.latent_dim, args.n_mlp, lr_mlp=args.lr_mlp, n_res=args.n_res).to(device)
353
+ D_B = Discriminator(args.size).to(device)
354
+ D_L = LatDiscriminator(args.latent_dim).to(device)
355
+ lpips_fn = lpips.LPIPS(net='vgg').to(device)
356
+
357
+ G_A2B_ema = copy.deepcopy(G_A2B).to(device).eval()
358
+ G_B2A_ema = copy.deepcopy(G_B2A).to(device).eval()
359
+
360
+ g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
361
+ d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)
362
+
363
+ G_optim = optim.Adam( list(G_A2B.parameters()) + list(G_B2A.parameters()), lr=args.lr, betas=(0, 0.99))
364
+ D_optim = optim.Adam(
365
+ list(D_L.parameters()) + list(D_A.parameters()) + list(D_B.parameters()),
366
+ lr=args.lr, betas=(0**d_reg_ratio, 0.99**d_reg_ratio))
367
+
368
+ if args.ckpt is not None:
369
+ ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage)
370
+
371
+ try:
372
+ ckpt_name = os.path.basename(args.ckpt)
373
+ args.start_iter = int(os.path.splitext(ckpt_name)[0])
374
+
375
+ except ValueError:
376
+ pass
377
+
378
+ G_A2B.load_state_dict(ckpt['G_A2B'])
379
+ G_B2A.load_state_dict(ckpt['G_B2A'])
380
+ G_A2B_ema.load_state_dict(ckpt['G_A2B_ema'])
381
+ G_B2A_ema.load_state_dict(ckpt['G_B2A_ema'])
382
+ D_A.load_state_dict(ckpt['D_A'])
383
+ D_B.load_state_dict(ckpt['D_B'])
384
+ D_L.load_state_dict(ckpt['D_L'])
385
+
386
+ G_optim.load_state_dict(ckpt['G_optim'])
387
+ D_optim.load_state_dict(ckpt['D_optim'])
388
+ args.start_iter = ckpt['iter']
389
+
390
+ if args.distributed:
391
+ G_A2B = nn.parallel.DistributedDataParallel(
392
+ G_A2B,
393
+ device_ids=[args.local_rank],
394
+ output_device=args.local_rank,
395
+ broadcast_buffers=False,
396
+ )
397
+
398
+ D_A = nn.parallel.DistributedDataParallel(
399
+ D_A,
400
+ device_ids=[args.local_rank],
401
+ output_device=args.local_rank,
402
+ broadcast_buffers=False,
403
+ )
404
+
405
+ G_B2A = nn.parallel.DistributedDataParallel(
406
+ G_B2A,
407
+ device_ids=[args.local_rank],
408
+ output_device=args.local_rank,
409
+ broadcast_buffers=False,
410
+ )
411
+
412
+ D_B = nn.parallel.DistributedDataParallel(
413
+ D_B,
414
+ device_ids=[args.local_rank],
415
+ output_device=args.local_rank,
416
+ broadcast_buffers=False,
417
+ )
418
+ D_L = nn.parallel.DistributedDataParallel(
419
+ D_L,
420
+ device_ids=[args.local_rank],
421
+ output_device=args.local_rank,
422
+ broadcast_buffers=False,
423
+ )
424
+ train_transform = transforms.Compose([
425
+ transforms.ToTensor(),
426
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), inplace=True)
427
+ ])
428
+
429
+ test_transform = transforms.Compose([
430
+ transforms.Resize((args.size, args.size)),
431
+ transforms.ToTensor(),
432
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), inplace=True)
433
+ ])
434
+
435
+ aug = nn.Sequential(
436
+ K.RandomAffine(degrees=(-20,20), scale=(0.8, 1.2), translate=(0.1, 0.1), shear=0.15),
437
+ kornia.geometry.transform.Resize(256+30),
438
+ K.RandomCrop((256,256)),
439
+ K.RandomHorizontalFlip(),
440
+ )
441
+
442
+
443
+ d_path = args.d_path
444
+ trainA = ImageFolder(os.path.join(d_path, 'trainA'), train_transform)
445
+ trainB = ImageFolder(os.path.join(d_path, 'trainB'), train_transform)
446
+ testA = ImageFolder(os.path.join(d_path, 'testA'), test_transform)
447
+ testB = ImageFolder(os.path.join(d_path, 'testB'), test_transform)
448
+
449
+ trainA_loader = data.DataLoader(trainA, batch_size=args.batch,
450
+ sampler=data_sampler(trainA, shuffle=True, distributed=args.distributed), drop_last=True, pin_memory=True, num_workers=5)
451
+ trainB_loader = data.DataLoader(trainB, batch_size=args.batch,
452
+ sampler=data_sampler(trainB, shuffle=True, distributed=args.distributed), drop_last=True, pin_memory=True, num_workers=5)
453
+
454
+ testA_loader = data.DataLoader(testA, batch_size=1, shuffle=False)
455
+ testB_loader = data.DataLoader(testB, batch_size=1, shuffle=False)
456
+
457
+
458
+ train(args, trainA_loader, trainB_loader, testA_loader, testB_loader, G_A2B, G_B2A, D_A, D_B, G_optim, D_optim, device)
util.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch.utils import data
4
+ from torch import nn, autograd
5
+ import os
6
+ import matplotlib.pyplot as plt
7
+
8
+
9
+ google_drive_paths = {
10
+ "GNR_checkpoint.pt": "https://drive.google.com/uc?id=1IMIVke4WDaGayUa7vk_xVw1uqIHikGtC",
11
+ "GNR_checkpoint_new.pt": "https://drive.google.com/uc?id=1PQ_SRLfFsXO_9z_OW5H9gKhhmIMn7H-p",
12
+ }
13
+
14
+ def ensure_checkpoint_exists(model_weights_filename):
15
+ if not os.path.isfile(model_weights_filename) and (
16
+ model_weights_filename in google_drive_paths
17
+ ):
18
+ gdrive_url = google_drive_paths[model_weights_filename]
19
+ try:
20
+ from gdown import download as drive_download
21
+
22
+ drive_download(gdrive_url, model_weights_filename, quiet=False)
23
+ except ModuleNotFoundError:
24
+ print(
25
+ "gdown module not found.",
26
+ "pip3 install gdown or, manually download the checkpoint file:",
27
+ gdrive_url
28
+ )
29
+
30
+ if not os.path.isfile(model_weights_filename) and (
31
+ model_weights_filename not in google_drive_paths
32
+ ):
33
+ print(
34
+ model_weights_filename,
35
+ " not found, you may need to manually download the model weights."
36
+ )
37
+
38
+ def shuffle_batch(x):
39
+ return x[torch.randperm(x.size(0))]
40
+
41
+ def data_sampler(dataset, shuffle, distributed):
42
+ if distributed:
43
+ return data.distributed.DistributedSampler(dataset, shuffle=shuffle)
44
+
45
+ if shuffle:
46
+ return data.RandomSampler(dataset)
47
+
48
+ else:
49
+ return data.SequentialSampler(dataset)
50
+
51
+
52
+ def accumulate(model1, model2, decay=0.999):
53
+ par1 = dict(model1.named_parameters())
54
+ par2 = dict(model2.named_parameters())
55
+
56
+ for k in par1.keys():
57
+ par1[k].data.mul_(decay).add_(1 - decay, par2[k].data)
58
+
59
+
60
+ def sample_data(loader):
61
+ while True:
62
+ for batch in loader:
63
+ yield batch
64
+
65
+
66
+ def d_logistic_loss(real_pred, fake_pred):
67
+ loss = 0
68
+ for real, fake in zip(real_pred, fake_pred):
69
+ real_loss = F.softplus(-real)
70
+ fake_loss = F.softplus(fake)
71
+ loss += real_loss.mean() + fake_loss.mean()
72
+
73
+ return loss
74
+
75
+
76
+ def d_r1_loss(real_pred, real_img):
77
+ grad_penalty = 0
78
+ for real in real_pred:
79
+ grad_real, = autograd.grad(
80
+ outputs=real.mean(), inputs=real_img, create_graph=True, only_inputs=True
81
+ )
82
+ grad_penalty += grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
83
+
84
+ return grad_penalty
85
+
86
+
87
+ def g_nonsaturating_loss(fake_pred, weights):
88
+ loss = 0
89
+ for fake, weight in zip(fake_pred, weights):
90
+ loss += weight*F.softplus(-fake).mean()
91
+
92
+ return loss / len(fake_pred)
93
+
94
+ def display_image(image, size=None, mode='nearest', unnorm=False, title=''):
95
+ # image is [3,h,w] or [1,3,h,w] tensor [0,1]
96
+ if image.is_cuda:
97
+ image = image.cpu()
98
+ if size is not None and image.size(-1) != size:
99
+ image = F.interpolate(image, size=(size,size), mode=mode)
100
+ if image.dim() == 4:
101
+ image = image[0]
102
+ image = image.permute(1, 2, 0).detach().numpy()
103
+ plt.figure()
104
+ plt.title(title)
105
+ plt.axis('off')
106
+ plt.imshow(image)
107
+
108
+ def normalize(x):
109
+ return ((x+1)/2).clamp(0,1)
110
+
111
+ def get_boundingbox(face, width, height, scale=1.3, minsize=None):
112
+ """
113
+ Expects a dlib face to generate a quadratic bounding box.
114
+ :param face: dlib face class
115
+ :param width: frame width
116
+ :param height: frame height
117
+ :param scale: bounding box size multiplier to get a bigger face region
118
+ :param minsize: set minimum bounding box size
119
+ :return: x, y, bounding_box_size in opencv form
120
+ """
121
+ x1 = face.left()
122
+ y1 = face.top()
123
+ x2 = face.right()
124
+ y2 = face.bottom()
125
+ size_bb = int(max(x2 - x1, y2 - y1) * scale)
126
+ if minsize:
127
+ if size_bb < minsize:
128
+ size_bb = minsize
129
+ center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2
130
+
131
+ # Check for out of bounds, x-y top left corner
132
+ x1 = max(int(center_x - size_bb // 2), 0)
133
+ y1 = max(int(center_y - size_bb // 2), 0)
134
+ # Check for too big bb size for given x, y
135
+ size_bb = min(width - x1, size_bb)
136
+ size_bb = min(height - y1, size_bb)
137
+
138
+ return x1, y1, size_bb
139
+
140
+
141
+ def preprocess_image(image, cuda=True):
142
+ """
143
+ Preprocesses the image such that it can be fed into our network.
144
+ During this process we envoke PIL to cast it into a PIL image.
145
+ :param image: numpy image in opencv form (i.e., BGR and of shape
146
+ :return: pytorch tensor of shape [1, 3, image_size, image_size], not
147
+ necessarily casted to cuda
148
+ """
149
+ # Revert from BGR
150
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
151
+ # Preprocess using the preprocessing function used during training and
152
+ # casting it to PIL image
153
+ preprocess = xception_default_data_transforms['test']
154
+ preprocessed_image = preprocess(pil_image.fromarray(image))
155
+ # Add first dimension as the network expects a batch
156
+ preprocessed_image = preprocessed_image.unsqueeze(0)
157
+ if cuda:
158
+ preprocessed_image = preprocessed_image.cuda()
159
+ return preprocessed_image
160
+
161
+ def truncate(x, truncation, mean_style):
162
+ return truncation*x + (1-truncation)*mean_style