Diggz10 committed on
Commit
3cf3525
·
verified ·
1 Parent(s): 6348ece

Upload 6 files

Browse files
training/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ # empty
training/augment.py ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import numpy as np
10
+ import scipy.signal
11
+ import torch
12
+ from torch_utils import persistence
13
+ from torch_utils import misc
14
+ from torch_utils.ops import upfirdn2d
15
+ from torch_utils.ops import grid_sample_gradfix
16
+ from torch_utils.ops import conv2d_gradfix
17
+
18
+ #----------------------------------------------------------------------------
19
+ # Coefficients of various wavelet decomposition low-pass filters.
20
+
21
# Coefficients of various wavelet decomposition low-pass filters,
# keyed by the conventional wavelet family name (Haar, Daubechies, Symlets).
# Only the low-pass tap lists are stored; high-pass counterparts are derived
# where needed (see the filter-bank construction in AugmentPipe.__init__).
wavelets = {
    'haar': [0.7071067811865476, 0.7071067811865476],
    'db1':  [0.7071067811865476, 0.7071067811865476],
    'db2':  [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
    'db3':  [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
    'db4':  [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523],
    'db5':  [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125],
    'db6':  [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017],
    'db7':  [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236],
    'db8':  [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161],
    'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
    'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
    'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427],
    'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728],
    'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148],
    'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255],
    'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609],
}
39
+
40
+ #----------------------------------------------------------------------------
41
+ # Helpers for constructing transformation matrices.
42
+
43
def matrix(*rows, device=None):
    """Assemble a transformation matrix from rows whose entries may mix
    Python scalars and torch tensors.

    With no tensor entries the result is a constant matrix; with tensor
    entries the scalar entries are broadcast to the tensors' shape and the
    result gains those leading batch dimensions.
    """
    assert all(len(row) == len(rows[0]) for row in rows)
    flat = [entry for row in rows for entry in row]
    tensors = [entry for entry in flat if isinstance(entry, torch.Tensor)]
    if not tensors:
        # Purely scalar input: build a plain constant matrix.
        return misc.constant(np.asarray(rows), device=device)
    template = tensors[0]
    assert device is None or device == template.device
    flat = [
        entry if isinstance(entry, torch.Tensor)
        else misc.constant(entry, shape=template.shape, device=template.device)
        for entry in flat
    ]
    return torch.stack(flat, dim=-1).reshape(template.shape + (len(rows), -1))
52
+
53
def translate2d(tx, ty, **kwargs):
    """Return a 3x3 homogeneous matrix translating 2D points by (tx, ty)."""
    rows = (
        [1, 0, tx],
        [0, 1, ty],
        [0, 0, 1],
    )
    return matrix(*rows, **kwargs)
59
+
60
def translate3d(tx, ty, tz, **kwargs):
    """Return a 4x4 homogeneous matrix translating 3D points by (tx, ty, tz)."""
    rows = (
        [1, 0, 0, tx],
        [0, 1, 0, ty],
        [0, 0, 1, tz],
        [0, 0, 0, 1],
    )
    return matrix(*rows, **kwargs)
67
+
68
def scale2d(sx, sy, **kwargs):
    """Return a 3x3 homogeneous matrix scaling x by sx and y by sy."""
    rows = (
        [sx, 0, 0],
        [0, sy, 0],
        [0, 0, 1],
    )
    return matrix(*rows, **kwargs)
74
+
75
def scale3d(sx, sy, sz, **kwargs):
    """Return a 4x4 homogeneous matrix scaling the three axes by (sx, sy, sz)."""
    rows = (
        [sx, 0, 0, 0],
        [0, sy, 0, 0],
        [0, 0, sz, 0],
        [0, 0, 0, 1],
    )
    return matrix(*rows, **kwargs)
82
+
83
def rotate2d(theta, **kwargs):
    """Return a 3x3 homogeneous matrix rotating 2D points by angle `theta` (radians, tensor)."""
    cos_t = torch.cos(theta)
    sin_t = torch.sin(theta)
    neg_sin_t = torch.sin(-theta)
    return matrix(
        [cos_t, neg_sin_t, 0],
        [sin_t, cos_t, 0],
        [0, 0, 1],
        **kwargs)
89
+
90
def rotate3d(v, theta, **kwargs):
    """Return a 4x4 homogeneous matrix rotating by angle `theta` around axis `v`
    (Rodrigues' rotation formula; `v` is expected to be a unit vector)."""
    vx, vy, vz = v[..., 0], v[..., 1], v[..., 2]
    s = torch.sin(theta)
    c = torch.cos(theta)
    cc = 1 - c
    row0 = [vx*vx*cc+c,    vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0]
    row1 = [vy*vx*cc+vz*s, vy*vy*cc+c,    vy*vz*cc-vx*s, 0]
    row2 = [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c,    0]
    row3 = [0, 0, 0, 1]
    return matrix(row0, row1, row2, row3, **kwargs)
99
+
100
def translate2d_inv(tx, ty, **kwargs):
    """Inverse of translate2d(tx, ty): translate by the negated offsets."""
    return translate2d(-tx, -ty, **kwargs)
102
+
103
def scale2d_inv(sx, sy, **kwargs):
    """Inverse of scale2d(sx, sy): scale by the reciprocals."""
    inv_sx = 1 / sx
    inv_sy = 1 / sy
    return scale2d(inv_sx, inv_sy, **kwargs)
105
+
106
def rotate2d_inv(theta, **kwargs):
    """Inverse of rotate2d(theta): rotate by the negated angle."""
    return rotate2d(-theta, **kwargs)
108
+
109
+ #----------------------------------------------------------------------------
110
+ # Versatile image augmentation pipeline from the paper
111
+ # "Training Generative Adversarial Networks with Limited Data".
112
+ #
113
+ # All augmentations are disabled by default; individual augmentations can
114
+ # be enabled by setting their probability multipliers to 1.
115
+
116
@persistence.persistent_class
class AugmentPipe(torch.nn.Module):
    """Versatile image augmentation pipeline from the paper
    "Training Generative Adversarial Networks with Limited Data".

    All augmentations are disabled by default; individual augmentations can
    be enabled by setting their probability multipliers to 1. The overall
    augmentation strength is the buffer `p`, which is multiplied into every
    per-augmentation probability and is adjusted externally during training.
    """
    def __init__(self,
        xflip=0, rotate90=0, xint=0, xint_max=0.125,
        scale=0, rotate=0, aniso=0, xfrac=0, scale_std=0.2, rotate_max=1, aniso_std=0.2, xfrac_std=0.125,
        brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5, hue_max=1, saturation_std=1,
        imgfilter=0, imgfilter_bands=[1,1,1,1], imgfilter_std=1,  # NOTE(review): mutable default, but it is copied via list() below and never mutated.
        noise=0, cutout=0, noise_std=0.1, cutout_size=0.5,
    ):
        super().__init__()
        self.register_buffer('p', torch.ones([])) # Overall multiplier for augmentation probability.

        # Pixel blitting.
        self.xflip = float(xflip) # Probability multiplier for x-flip.
        self.rotate90 = float(rotate90) # Probability multiplier for 90 degree rotations.
        self.xint = float(xint) # Probability multiplier for integer translation.
        self.xint_max = float(xint_max) # Range of integer translation, relative to image dimensions.

        # General geometric transformations.
        self.scale = float(scale) # Probability multiplier for isotropic scaling.
        self.rotate = float(rotate) # Probability multiplier for arbitrary rotation.
        self.aniso = float(aniso) # Probability multiplier for anisotropic scaling.
        self.xfrac = float(xfrac) # Probability multiplier for fractional translation.
        self.scale_std = float(scale_std) # Log2 standard deviation of isotropic scaling.
        self.rotate_max = float(rotate_max) # Range of arbitrary rotation, 1 = full circle.
        self.aniso_std = float(aniso_std) # Log2 standard deviation of anisotropic scaling.
        self.xfrac_std = float(xfrac_std) # Standard deviation of frational translation, relative to image dimensions.

        # Color transformations.
        self.brightness = float(brightness) # Probability multiplier for brightness.
        self.contrast = float(contrast) # Probability multiplier for contrast.
        self.lumaflip = float(lumaflip) # Probability multiplier for luma flip.
        self.hue = float(hue) # Probability multiplier for hue rotation.
        self.saturation = float(saturation) # Probability multiplier for saturation.
        self.brightness_std = float(brightness_std) # Standard deviation of brightness.
        self.contrast_std = float(contrast_std) # Log2 standard deviation of contrast.
        self.hue_max = float(hue_max) # Range of hue rotation, 1 = full circle.
        self.saturation_std = float(saturation_std) # Log2 standard deviation of saturation.

        # Image-space filtering.
        self.imgfilter = float(imgfilter) # Probability multiplier for image-space filtering.
        self.imgfilter_bands = list(imgfilter_bands) # Probability multipliers for individual frequency bands.
        self.imgfilter_std = float(imgfilter_std) # Log2 standard deviation of image-space filter amplification.

        # Image-space corruptions.
        self.noise = float(noise) # Probability multiplier for additive RGB noise.
        self.cutout = float(cutout) # Probability multiplier for cutout.
        self.noise_std = float(noise_std) # Standard deviation of additive RGB noise.
        self.cutout_size = float(cutout_size) # Size of the cutout rectangle, relative to image dimensions.

        # Setup orthogonal lowpass filter for geometric augmentations.
        self.register_buffer('Hz_geom', upfirdn2d.setup_filter(wavelets['sym6']))

        # Construct filter bank for image-space filtering.
        # Each row of Hz_fbank is a bandpass filter covering one frequency band.
        Hz_lo = np.asarray(wavelets['sym2']) # H(z)
        Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size)) # H(-z)
        Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2 # H(z) * H(z^-1) / 2
        Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2 # H(-z) * H(-z^-1) / 2
        Hz_fbank = np.eye(4, 1) # Bandpass(H(z), b_i)
        for i in range(1, Hz_fbank.shape[0]):
            # Upsample the previous bands, convolve with the low-pass kernel,
            # and add the high-pass kernel as the new band's center taps.
            Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)]).reshape(Hz_fbank.shape[0], -1)[:, :-1]
            Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2])
            Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) // 2 : (Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2
        self.register_buffer('Hz_fbank', torch.as_tensor(Hz_fbank, dtype=torch.float32))

    def forward(self, images, debug_percentile=None):
        """Apply the enabled augmentations to a batch of images.

        Args:
            images: tensor of shape [batch, channels, height, width].
            debug_percentile: optional scalar in [0, 1]; when given, every
                augmentation parameter is pinned to that percentile of its
                sampling distribution instead of being drawn at random.

        Returns:
            Augmented images with the same shape as the input.
        """
        assert isinstance(images, torch.Tensor) and images.ndim == 4
        batch_size, num_channels, height, width = images.shape
        device = images.device
        if debug_percentile is not None:
            debug_percentile = torch.as_tensor(debug_percentile, dtype=torch.float32, device=device)

        # -------------------------------------
        # Select parameters for pixel blitting.
        # -------------------------------------

        # Initialize inverse homogeneous 2D transform: G_inv @ pixel_out ==> pixel_in
        I_3 = torch.eye(3, device=device)
        G_inv = I_3

        # Apply x-flip with probability (xflip * strength).
        if self.xflip > 0:
            i = torch.floor(torch.rand([batch_size], device=device) * 2)
            i = torch.where(torch.rand([batch_size], device=device) < self.xflip * self.p, i, torch.zeros_like(i))
            if debug_percentile is not None:
                i = torch.full_like(i, torch.floor(debug_percentile * 2))
            G_inv = G_inv @ scale2d_inv(1 - 2 * i, 1)

        # Apply 90 degree rotations with probability (rotate90 * strength).
        if self.rotate90 > 0:
            i = torch.floor(torch.rand([batch_size], device=device) * 4)
            i = torch.where(torch.rand([batch_size], device=device) < self.rotate90 * self.p, i, torch.zeros_like(i))
            if debug_percentile is not None:
                i = torch.full_like(i, torch.floor(debug_percentile * 4))
            G_inv = G_inv @ rotate2d_inv(-np.pi / 2 * i)

        # Apply integer translation with probability (xint * strength).
        if self.xint > 0:
            t = (torch.rand([batch_size, 2], device=device) * 2 - 1) * self.xint_max
            t = torch.where(torch.rand([batch_size, 1], device=device) < self.xint * self.p, t, torch.zeros_like(t))
            if debug_percentile is not None:
                t = torch.full_like(t, (debug_percentile * 2 - 1) * self.xint_max)
            G_inv = G_inv @ translate2d_inv(torch.round(t[:,0] * width), torch.round(t[:,1] * height))

        # --------------------------------------------------------
        # Select parameters for general geometric transformations.
        # --------------------------------------------------------

        # Apply isotropic scaling with probability (scale * strength).
        if self.scale > 0:
            s = torch.exp2(torch.randn([batch_size], device=device) * self.scale_std)
            s = torch.where(torch.rand([batch_size], device=device) < self.scale * self.p, s, torch.ones_like(s))
            if debug_percentile is not None:
                s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.scale_std))
            G_inv = G_inv @ scale2d_inv(s, s)

        # Apply pre-rotation with probability p_rot.
        # Rotation is split into pre- and post-aniso halves so that
        # P(pre OR post) equals the requested overall probability.
        p_rot = 1 - torch.sqrt((1 - self.rotate * self.p).clamp(0, 1)) # P(pre OR post) = p
        if self.rotate > 0:
            theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
            theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
            if debug_percentile is not None:
                theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.rotate_max)
            G_inv = G_inv @ rotate2d_inv(-theta) # Before anisotropic scaling.

        # Apply anisotropic scaling with probability (aniso * strength).
        if self.aniso > 0:
            s = torch.exp2(torch.randn([batch_size], device=device) * self.aniso_std)
            s = torch.where(torch.rand([batch_size], device=device) < self.aniso * self.p, s, torch.ones_like(s))
            if debug_percentile is not None:
                s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.aniso_std))
            G_inv = G_inv @ scale2d_inv(s, 1 / s)

        # Apply post-rotation with probability p_rot.
        if self.rotate > 0:
            theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
            theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
            if debug_percentile is not None:
                theta = torch.zeros_like(theta)
            G_inv = G_inv @ rotate2d_inv(-theta) # After anisotropic scaling.

        # Apply fractional translation with probability (xfrac * strength).
        if self.xfrac > 0:
            t = torch.randn([batch_size, 2], device=device) * self.xfrac_std
            t = torch.where(torch.rand([batch_size, 1], device=device) < self.xfrac * self.p, t, torch.zeros_like(t))
            if debug_percentile is not None:
                t = torch.full_like(t, torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std)
            G_inv = G_inv @ translate2d_inv(t[:,0] * width, t[:,1] * height)

        # ----------------------------------
        # Execute geometric transformations.
        # ----------------------------------

        # Execute if the transform is not identity.
        # NOTE(review): identity check relies on object identity (`is`), i.e.
        # it fires iff any branch above replaced G_inv.
        if G_inv is not I_3:

            # Calculate padding.
            cx = (width - 1) / 2
            cy = (height - 1) / 2
            cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz]
            cp = G_inv @ cp.t() # [batch, xyz, idx]
            Hz_pad = self.Hz_geom.shape[0] // 4
            margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx]
            margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1]
            margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device)
            margin = margin.max(misc.constant([0, 0] * 2, device=device))
            margin = margin.min(misc.constant([width-1, height-1] * 2, device=device))
            mx0, my0, mx1, my1 = margin.ceil().to(torch.int32)

            # Pad image and adjust origin.
            images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect')
            G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv

            # Upsample.
            images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2)
            G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device)
            G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device)

            # Execute transformation.
            shape = [batch_size, num_channels, (height + Hz_pad * 2) * 2, (width + Hz_pad * 2) * 2]
            G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device)
            grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False)
            images = grid_sample_gradfix.grid_sample(images, grid)

            # Downsample and crop.
            images = upfirdn2d.downsample2d(x=images, f=self.Hz_geom, down=2, padding=-Hz_pad*2, flip_filter=True)

        # --------------------------------------------
        # Select parameters for color transformations.
        # --------------------------------------------

        # Initialize homogeneous 3D transformation matrix: C @ color_in ==> color_out
        I_4 = torch.eye(4, device=device)
        C = I_4

        # Apply brightness with probability (brightness * strength).
        if self.brightness > 0:
            b = torch.randn([batch_size], device=device) * self.brightness_std
            b = torch.where(torch.rand([batch_size], device=device) < self.brightness * self.p, b, torch.zeros_like(b))
            if debug_percentile is not None:
                b = torch.full_like(b, torch.erfinv(debug_percentile * 2 - 1) * self.brightness_std)
            C = translate3d(b, b, b) @ C

        # Apply contrast with probability (contrast * strength).
        if self.contrast > 0:
            c = torch.exp2(torch.randn([batch_size], device=device) * self.contrast_std)
            c = torch.where(torch.rand([batch_size], device=device) < self.contrast * self.p, c, torch.ones_like(c))
            if debug_percentile is not None:
                c = torch.full_like(c, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.contrast_std))
            C = scale3d(c, c, c) @ C

        # Apply luma flip with probability (lumaflip * strength).
        v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device) # Luma axis.
        if self.lumaflip > 0:
            i = torch.floor(torch.rand([batch_size, 1, 1], device=device) * 2)
            i = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.lumaflip * self.p, i, torch.zeros_like(i))
            if debug_percentile is not None:
                i = torch.full_like(i, torch.floor(debug_percentile * 2))
            C = (I_4 - 2 * v.ger(v) * i) @ C # Householder reflection.

        # Apply hue rotation with probability (hue * strength).
        if self.hue > 0 and num_channels > 1:
            theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.hue_max
            theta = torch.where(torch.rand([batch_size], device=device) < self.hue * self.p, theta, torch.zeros_like(theta))
            if debug_percentile is not None:
                theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.hue_max)
            C = rotate3d(v, theta) @ C # Rotate around v.

        # Apply saturation with probability (saturation * strength).
        if self.saturation > 0 and num_channels > 1:
            s = torch.exp2(torch.randn([batch_size, 1, 1], device=device) * self.saturation_std)
            s = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.saturation * self.p, s, torch.ones_like(s))
            if debug_percentile is not None:
                s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.saturation_std))
            C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C

        # ------------------------------
        # Execute color transformations.
        # ------------------------------

        # Execute if the transform is not identity.
        if C is not I_4:
            images = images.reshape([batch_size, num_channels, height * width])
            if num_channels == 3:
                images = C[:, :3, :3] @ images + C[:, :3, 3:]
            elif num_channels == 1:
                # Grayscale: collapse the color transform to a single channel.
                C = C[:, :3, :].mean(dim=1, keepdims=True)
                images = images * C[:, :, :3].sum(dim=2, keepdims=True) + C[:, :, 3:]
            else:
                raise ValueError('Image must be RGB (3 channels) or L (1 channel)')
            images = images.reshape([batch_size, num_channels, height, width])

        # ----------------------
        # Image-space filtering.
        # ----------------------

        if self.imgfilter > 0:
            num_bands = self.Hz_fbank.shape[0]
            assert len(self.imgfilter_bands) == num_bands
            expected_power = misc.constant(np.array([10, 1, 1, 1]) / 13, device=device) # Expected power spectrum (1/f).

            # Apply amplification for each band with probability (imgfilter * strength * band_strength).
            g = torch.ones([batch_size, num_bands], device=device) # Global gain vector (identity).
            for i, band_strength in enumerate(self.imgfilter_bands):
                t_i = torch.exp2(torch.randn([batch_size], device=device) * self.imgfilter_std)
                t_i = torch.where(torch.rand([batch_size], device=device) < self.imgfilter * self.p * band_strength, t_i, torch.ones_like(t_i))
                if debug_percentile is not None:
                    t_i = torch.full_like(t_i, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.imgfilter_std)) if band_strength > 0 else torch.ones_like(t_i)
                t = torch.ones([batch_size, num_bands], device=device) # Temporary gain vector.
                t[:, i] = t_i # Replace i'th element.
                t = t / (expected_power * t.square()).sum(dim=-1, keepdims=True).sqrt() # Normalize power.
                g = g * t # Accumulate into global gain.

            # Construct combined amplification filter.
            Hz_prime = g @ self.Hz_fbank # [batch, tap]
            Hz_prime = Hz_prime.unsqueeze(1).repeat([1, num_channels, 1]) # [batch, channels, tap]
            Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, -1]) # [batch * channels, 1, tap]

            # Apply filter as two separable (vertical then horizontal) grouped convolutions.
            p = self.Hz_fbank.shape[1] // 2
            images = images.reshape([1, batch_size * num_channels, height, width])
            images = torch.nn.functional.pad(input=images, pad=[p,p,p,p], mode='reflect')
            images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(2), groups=batch_size*num_channels)
            images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(3), groups=batch_size*num_channels)
            images = images.reshape([batch_size, num_channels, height, width])

        # ------------------------
        # Image-space corruptions.
        # ------------------------

        # Apply additive RGB noise with probability (noise * strength).
        if self.noise > 0:
            sigma = torch.randn([batch_size, 1, 1, 1], device=device).abs() * self.noise_std
            sigma = torch.where(torch.rand([batch_size, 1, 1, 1], device=device) < self.noise * self.p, sigma, torch.zeros_like(sigma))
            if debug_percentile is not None:
                sigma = torch.full_like(sigma, torch.erfinv(debug_percentile) * self.noise_std)
            images = images + torch.randn([batch_size, num_channels, height, width], device=device) * sigma

        # Apply cutout with probability (cutout * strength).
        if self.cutout > 0:
            size = torch.full([batch_size, 2, 1, 1, 1], self.cutout_size, device=device)
            size = torch.where(torch.rand([batch_size, 1, 1, 1, 1], device=device) < self.cutout * self.p, size, torch.zeros_like(size))
            center = torch.rand([batch_size, 2, 1, 1, 1], device=device)
            if debug_percentile is not None:
                size = torch.full_like(size, self.cutout_size)
                center = torch.full_like(center, debug_percentile)
            coord_x = torch.arange(width, device=device).reshape([1, 1, 1, -1])
            coord_y = torch.arange(height, device=device).reshape([1, 1, -1, 1])
            mask_x = (((coord_x + 0.5) / width - center[:, 0]).abs() >= size[:, 0] / 2)
            mask_y = (((coord_y + 0.5) / height - center[:, 1]).abs() >= size[:, 1] / 2)
            mask = torch.logical_or(mask_x, mask_y).to(torch.float32)
            images = images * mask

        return images
430
+
431
+ #----------------------------------------------------------------------------
training/dataset.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import numpy as np
11
+ import zipfile
12
+ import PIL.Image
13
+ import json
14
+ import torch
15
+ import dnnlib
16
+
17
+ try:
18
+ import pyspng
19
+ except ImportError:
20
+ pyspng = None
21
+
22
+ #----------------------------------------------------------------------------
23
+
24
+ class Dataset(torch.utils.data.Dataset):
25
+ def __init__(self,
26
+ name, # Name of the dataset.
27
+ raw_shape, # Shape of the raw image data (NCHW).
28
+ max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
29
+ use_labels = False, # Enable conditioning labels? False = label dimension is zero.
30
+ xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size.
31
+ random_seed = 0, # Random seed to use when applying max_size.
32
+ ):
33
+ self._name = name
34
+ self._raw_shape = list(raw_shape)
35
+ self._use_labels = use_labels
36
+ self._raw_labels = None
37
+ self._label_shape = None
38
+
39
+ # Apply max_size.
40
+ self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
41
+ if (max_size is not None) and (self._raw_idx.size > max_size):
42
+ np.random.RandomState(random_seed).shuffle(self._raw_idx)
43
+ self._raw_idx = np.sort(self._raw_idx[:max_size])
44
+
45
+ # Apply xflip.
46
+ self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
47
+ if xflip:
48
+ self._raw_idx = np.tile(self._raw_idx, 2)
49
+ self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])
50
+
51
+ def _get_raw_labels(self):
52
+ if self._raw_labels is None:
53
+ self._raw_labels = self._load_raw_labels() if self._use_labels else None
54
+ if self._raw_labels is None:
55
+ self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
56
+ assert isinstance(self._raw_labels, np.ndarray)
57
+ assert self._raw_labels.shape[0] == self._raw_shape[0]
58
+ assert self._raw_labels.dtype in [np.float32, np.int64]
59
+ if self._raw_labels.dtype == np.int64:
60
+ assert self._raw_labels.ndim == 1
61
+ assert np.all(self._raw_labels >= 0)
62
+ return self._raw_labels
63
+
64
+ def close(self): # to be overridden by subclass
65
+ pass
66
+
67
+ def _load_raw_image(self, raw_idx): # to be overridden by subclass
68
+ raise NotImplementedError
69
+
70
+ def _load_raw_labels(self): # to be overridden by subclass
71
+ raise NotImplementedError
72
+
73
+ def __getstate__(self):
74
+ return dict(self.__dict__, _raw_labels=None)
75
+
76
+ def __del__(self):
77
+ try:
78
+ self.close()
79
+ except:
80
+ pass
81
+
82
+ def __len__(self):
83
+ return self._raw_idx.size
84
+
85
+ def __getitem__(self, idx):
86
+ image = self._load_raw_image(self._raw_idx[idx])
87
+ assert isinstance(image, np.ndarray)
88
+ assert list(image.shape) == self.image_shape
89
+ assert image.dtype == np.uint8
90
+ if self._xflip[idx]:
91
+ assert image.ndim == 3 # CHW
92
+ image = image[:, :, ::-1]
93
+ return image.copy(), self.get_label(idx)
94
+
95
+ def get_label(self, idx):
96
+ label = self._get_raw_labels()[self._raw_idx[idx]]
97
+ if label.dtype == np.int64:
98
+ onehot = np.zeros(self.label_shape, dtype=np.float32)
99
+ onehot[label] = 1
100
+ label = onehot
101
+ return label.copy()
102
+
103
+ def get_details(self, idx):
104
+ d = dnnlib.EasyDict()
105
+ d.raw_idx = int(self._raw_idx[idx])
106
+ d.xflip = (int(self._xflip[idx]) != 0)
107
+ d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
108
+ return d
109
+
110
+ @property
111
+ def name(self):
112
+ return self._name
113
+
114
+ @property
115
+ def image_shape(self):
116
+ return list(self._raw_shape[1:])
117
+
118
+ @property
119
+ def num_channels(self):
120
+ assert len(self.image_shape) == 3 # CHW
121
+ return self.image_shape[0]
122
+
123
+ @property
124
+ def resolution(self):
125
+ assert len(self.image_shape) == 3 # CHW
126
+ assert self.image_shape[1] == self.image_shape[2]
127
+ return self.image_shape[1]
128
+
129
+ @property
130
+ def label_shape(self):
131
+ if self._label_shape is None:
132
+ raw_labels = self._get_raw_labels()
133
+ if raw_labels.dtype == np.int64:
134
+ self._label_shape = [int(np.max(raw_labels)) + 1]
135
+ else:
136
+ self._label_shape = raw_labels.shape[1:]
137
+ return list(self._label_shape)
138
+
139
+ @property
140
+ def label_dim(self):
141
+ assert len(self.label_shape) == 1
142
+ return self.label_shape[0]
143
+
144
+ @property
145
+ def has_labels(self):
146
+ return any(x != 0 for x in self.label_shape)
147
+
148
+ @property
149
+ def has_onehot_labels(self):
150
+ return self._get_raw_labels().dtype == np.int64
151
+
152
+ #----------------------------------------------------------------------------
153
+
154
+ class ImageFolderDataset(Dataset):
155
+ def __init__(self,
156
+ path, # Path to directory or zip.
157
+ resolution = None, # Ensure specific resolution, None = highest available.
158
+ **super_kwargs, # Additional arguments for the Dataset base class.
159
+ ):
160
+ self._path = path
161
+ self._zipfile = None
162
+
163
+ if os.path.isdir(self._path):
164
+ self._type = 'dir'
165
+ self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
166
+ elif self._file_ext(self._path) == '.zip':
167
+ self._type = 'zip'
168
+ self._all_fnames = set(self._get_zipfile().namelist())
169
+ else:
170
+ raise IOError('Path must point to a directory or zip')
171
+
172
+ PIL.Image.init()
173
+ self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
174
+ if len(self._image_fnames) == 0:
175
+ raise IOError('No image files found in the specified path')
176
+
177
+ name = os.path.splitext(os.path.basename(self._path))[0]
178
+ raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
179
+ if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
180
+ raise IOError('Image files do not match the specified resolution')
181
+ super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)
182
+
183
+ @staticmethod
184
+ def _file_ext(fname):
185
+ return os.path.splitext(fname)[1].lower()
186
+
187
+ def _get_zipfile(self):
188
+ assert self._type == 'zip'
189
+ if self._zipfile is None:
190
+ self._zipfile = zipfile.ZipFile(self._path)
191
+ return self._zipfile
192
+
193
+ def _open_file(self, fname):
194
+ if self._type == 'dir':
195
+ return open(os.path.join(self._path, fname), 'rb')
196
+ if self._type == 'zip':
197
+ return self._get_zipfile().open(fname, 'r')
198
+ return None
199
+
200
+ def close(self):
201
+ try:
202
+ if self._zipfile is not None:
203
+ self._zipfile.close()
204
+ finally:
205
+ self._zipfile = None
206
+
207
+ def __getstate__(self):
208
+ return dict(super().__getstate__(), _zipfile=None)
209
+
210
+ def _load_raw_image(self, raw_idx):
211
+ fname = self._image_fnames[raw_idx]
212
+ with self._open_file(fname) as f:
213
+ if pyspng is not None and self._file_ext(fname) == '.png':
214
+ image = pyspng.load(f.read())
215
+ else:
216
+ image = np.array(PIL.Image.open(f))
217
+ if image.ndim == 2:
218
+ image = image[:, :, np.newaxis] # HW => HWC
219
+ image = image.transpose(2, 0, 1) # HWC => CHW
220
+ return image
221
+
222
+ def _load_raw_labels(self):
223
+ fname = 'dataset.json'
224
+ if fname not in self._all_fnames:
225
+ return None
226
+ with self._open_file(fname) as f:
227
+ labels = json.load(f)['labels']
228
+ if labels is None:
229
+ return None
230
+ labels = dict(labels)
231
+ labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames]
232
+ labels = np.array(labels)
233
+ labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
234
+ return labels
235
+
236
+ #----------------------------------------------------------------------------
training/loss.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import numpy as np
10
+ import torch
11
+ from torch_utils import training_stats
12
+ from torch_utils import misc
13
+ from torch_utils.ops import conv2d_gradfix
14
+
15
+ #----------------------------------------------------------------------------
16
+
17
class Loss:
    """Abstract base class for GAN training losses."""

    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain):  # to be overridden by subclass
        """Compute the loss terms for *phase* and accumulate gradients; subclasses must implement."""
        raise NotImplementedError()
20
+
21
+ #----------------------------------------------------------------------------
22
+
23
class StyleGAN2Loss(Loss):
    """Non-saturating StyleGAN2 loss with optional path-length (G) and R1 (D) regularization.

    accumulate_gradients() only calls backward(); the training loop owns the
    optimizer steps and decides which phase runs on each tick.
    """

    def __init__(self, device, G_mapping, G_synthesis, D, augment_pipe=None, style_mixing_prob=0.9, r1_gamma=10, pl_batch_shrink=2, pl_decay=0.01, pl_weight=2):
        super().__init__()
        self.device = device
        self.G_mapping = G_mapping              # Mapping network (z, c) -> ws.
        self.G_synthesis = G_synthesis          # Synthesis network ws -> image.
        self.D = D                              # Discriminator.
        self.augment_pipe = augment_pipe        # Optional augmentation pipeline applied to D inputs.
        self.style_mixing_prob = style_mixing_prob  # Probability of mixing two latents in run_G.
        self.r1_gamma = r1_gamma                # R1 regularization weight (0 disables Dr1).
        self.pl_batch_shrink = pl_batch_shrink  # Batch divisor for the path-length pass.
        self.pl_decay = pl_decay                # EMA decay for the running path-length mean.
        self.pl_weight = pl_weight              # Path-length regularization weight (0 disables Gpl).
        self.pl_mean = torch.zeros([], device=device)  # Running mean of path lengths, updated in-place.

    def run_G(self, z, c, sync):
        """Run the generator with optional style mixing; returns (image, ws)."""
        with misc.ddp_sync(self.G_mapping, sync):
            ws = self.G_mapping(z, c)
            if self.style_mixing_prob > 0:
                with torch.autograd.profiler.record_function('style_mixing'):
                    # With probability style_mixing_prob, replace ws[cutoff:] with
                    # ws from a second latent; otherwise cutoff = num_ws (no-op).
                    cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
                    cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
                    ws[:, cutoff:] = self.G_mapping(torch.randn_like(z), c, skip_w_avg_update=True)[:, cutoff:]
        with misc.ddp_sync(self.G_synthesis, sync):
            img = self.G_synthesis(ws)
        return img, ws

    def run_D(self, img, c, sync):
        """Run the discriminator, applying the augmentation pipeline first if configured."""
        if self.augment_pipe is not None:
            img = self.augment_pipe(img)
        with misc.ddp_sync(self.D, sync):
            logits = self.D(img, c)
        return logits

    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain):
        """Accumulate gradients for one training phase.

        phase selects which loss terms run: Gmain/Dmain are the adversarial
        terms, Greg/Dreg the regularizers, and Gboth/Dboth run both at once.
        gain scales the loss before backward (used for lazy regularization).
        """
        assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
        do_Gmain = (phase in ['Gmain', 'Gboth'])
        do_Dmain = (phase in ['Dmain', 'Dboth'])
        do_Gpl = (phase in ['Greg', 'Gboth']) and (self.pl_weight != 0)
        do_Dr1 = (phase in ['Dreg', 'Dboth']) and (self.r1_gamma != 0)

        # Gmain: Maximize logits for generated images.
        if do_Gmain:
            with torch.autograd.profiler.record_function('Gmain_forward'):
                gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=(sync and not do_Gpl)) # May get synced by Gpl.
                gen_logits = self.run_D(gen_img, gen_c, sync=False)
                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake', gen_logits.sign())
                loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
                training_stats.report('Loss/G/loss', loss_Gmain)
            with torch.autograd.profiler.record_function('Gmain_backward'):
                loss_Gmain.mean().mul(gain).backward()

        # Gpl: Apply path length regularization.
        if do_Gpl:
            with torch.autograd.profiler.record_function('Gpl_forward'):
                # Run a shrunken batch to reduce the cost of the double-backward pass.
                batch_size = gen_z.shape[0] // self.pl_batch_shrink
                gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size], sync=sync)
                pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
                with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients():
                    pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0]
                pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
                # Track the exponential moving average of path lengths in-place.
                pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
                self.pl_mean.copy_(pl_mean.detach())
                pl_penalty = (pl_lengths - pl_mean).square()
                training_stats.report('Loss/pl_penalty', pl_penalty)
                loss_Gpl = pl_penalty * self.pl_weight
                training_stats.report('Loss/G/reg', loss_Gpl)
            with torch.autograd.profiler.record_function('Gpl_backward'):
                # The (gen_img * 0) term keeps the graph alive for DDP gradient sync.
                (gen_img[:, 0, 0, 0] * 0 + loss_Gpl).mean().mul(gain).backward()

        # Dmain: Minimize logits for generated images.
        loss_Dgen = 0
        if do_Dmain:
            with torch.autograd.profiler.record_function('Dgen_forward'):
                gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=False)
                gen_logits = self.run_D(gen_img, gen_c, sync=False) # Gets synced by loss_Dreal.
                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake', gen_logits.sign())
                loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))
            with torch.autograd.profiler.record_function('Dgen_backward'):
                loss_Dgen.mean().mul(gain).backward()

        # Dmain: Maximize logits for real images.
        # Dr1: Apply R1 regularization.
        if do_Dmain or do_Dr1:
            name = 'Dreal_Dr1' if do_Dmain and do_Dr1 else 'Dreal' if do_Dmain else 'Dr1'
            with torch.autograd.profiler.record_function(name + '_forward'):
                # requires_grad only when R1 needs d(logits)/d(image).
                real_img_tmp = real_img.detach().requires_grad_(do_Dr1)
                real_logits = self.run_D(real_img_tmp, real_c, sync=sync)
                training_stats.report('Loss/scores/real', real_logits)
                training_stats.report('Loss/signs/real', real_logits.sign())

                loss_Dreal = 0
                if do_Dmain:
                    loss_Dreal = torch.nn.functional.softplus(-real_logits) # -log(sigmoid(real_logits))
                    training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal)

                loss_Dr1 = 0
                if do_Dr1:
                    with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
                        r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0]
                    r1_penalty = r1_grads.square().sum([1,2,3])
                    loss_Dr1 = r1_penalty * (self.r1_gamma / 2)
                    training_stats.report('Loss/r1_penalty', r1_penalty)
                    training_stats.report('Loss/D/reg', loss_Dr1)

            with torch.autograd.profiler.record_function(name + '_backward'):
                # (real_logits * 0) keeps real_logits in the graph even when only Dr1 runs.
                (real_logits * 0 + loss_Dreal + loss_Dr1).mean().mul(gain).backward()
132
+
133
+ #----------------------------------------------------------------------------
training/networks.py ADDED
@@ -0,0 +1,729 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import numpy as np
10
+ import torch
11
+ from torch_utils import misc
12
+ from torch_utils import persistence
13
+ from torch_utils.ops import conv2d_resample
14
+ from torch_utils.ops import upfirdn2d
15
+ from torch_utils.ops import bias_act
16
+ from torch_utils.ops import fma
17
+
18
+ #----------------------------------------------------------------------------
19
+
20
@misc.profiled_function
def normalize_2nd_moment(x, dim=1, eps=1e-8):
    """Scale x so its second moment along *dim* is approximately 1 (eps for numerical stability)."""
    inv_norm = (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt()
    return x * inv_norm
23
+
24
+ #----------------------------------------------------------------------------
25
+
26
@misc.profiled_function
def modulated_conv2d(
    x,                          # Input tensor of shape [batch_size, in_channels, in_height, in_width].
    weight,                     # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
    styles,                     # Modulation coefficients of shape [batch_size, in_channels].
    noise           = None,     # Optional noise tensor to add to the output activations.
    up              = 1,        # Integer upsampling factor.
    down            = 1,        # Integer downsampling factor.
    padding         = 0,        # Padding with respect to the upsampled image.
    resample_filter = None,     # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
    demodulate      = True,     # Apply weight demodulation?
    flip_weight     = True,     # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
    fused_modconv   = True,     # Perform modulation, convolution, and demodulation as a single fused operation?
):
    """Style-modulated 2D convolution (StyleGAN2).

    Two execution strategies with identical math: the fused path bakes the
    per-sample styles into the weights and runs one grouped convolution; the
    unfused path scales the activations before/after an ordinary convolution.
    """
    batch_size = x.shape[0]
    out_channels, in_channels, kh, kw = weight.shape
    misc.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk]
    misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW]
    misc.assert_shape(styles, [batch_size, in_channels]) # [NI]

    # Pre-normalize inputs to avoid FP16 overflow.
    if x.dtype == torch.float16 and demodulate:
        weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1,2,3], keepdim=True)) # max_Ikk
        styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I

    # Calculate per-sample weights and demodulation coefficients.
    w = None
    dcoefs = None
    if demodulate or fused_modconv:
        w = weight.unsqueeze(0) # [NOIkk]
        w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk]
    if demodulate:
        dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO]
    if demodulate and fused_modconv:
        w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk]

    # Execute by scaling the activations before and after the convolution.
    if not fused_modconv:
        x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
        x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight)
        if demodulate and noise is not None:
            # Fused multiply-add: demodulate and add noise in one kernel.
            x = fma.fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype))
        elif demodulate:
            x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)
        elif noise is not None:
            x = x.add_(noise.to(x.dtype))
        return x

    # Execute as one fused op using grouped convolution.
    with misc.suppress_tracer_warnings(): # this value will be treated as a constant
        batch_size = int(batch_size)
    misc.assert_shape(x, [batch_size, in_channels, None, None])
    # Fold the batch into the channel axis and run one grouped conv (one group per sample).
    x = x.reshape(1, -1, *x.shape[2:])
    w = w.reshape(-1, in_channels, kh, kw)
    x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight)
    x = x.reshape(batch_size, -1, *x.shape[2:])
    if noise is not None:
        x = x.add_(noise)
    return x
85
+
86
+ #----------------------------------------------------------------------------
87
+
88
@persistence.persistent_class
class FullyConnectedLayer(torch.nn.Module):
    """Fully connected layer with equalized learning rate and optional activation."""

    def __init__(self,
        in_features,                # Number of input features.
        out_features,               # Number of output features.
        bias            = True,     # Apply additive bias before the activation function?
        activation      = 'linear', # Activation function: 'relu', 'lrelu', etc.
        lr_multiplier   = 1,        # Learning rate multiplier.
        bias_init       = 0,        # Initial value for the additive bias.
    ):
        super().__init__()
        self.activation = activation
        # Weights are stored divided by lr_multiplier and rescaled at runtime
        # (equalized learning rate trick).
        self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
        self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
        self.weight_gain = lr_multiplier / np.sqrt(in_features)
        self.bias_gain = lr_multiplier

    def forward(self, x):
        weight = self.weight.to(x.dtype) * self.weight_gain
        bias = self.bias
        if bias is not None:
            bias = bias.to(x.dtype)
            if self.bias_gain != 1:
                bias = bias * self.bias_gain

        if self.activation == 'linear' and bias is not None:
            # Single fused op for the common linear+bias case.
            return torch.addmm(bias.unsqueeze(0), x, weight.t())
        out = x.matmul(weight.t())
        return bias_act.bias_act(out, bias, act=self.activation)
119
+
120
+ #----------------------------------------------------------------------------
121
+
122
@persistence.persistent_class
class Conv2dLayer(torch.nn.Module):
    """Plain (non-modulated) convolution layer with optional up/downsampling and activation."""

    def __init__(self,
        in_channels,                    # Number of input channels.
        out_channels,                   # Number of output channels.
        kernel_size,                    # Width and height of the convolution kernel.
        bias            = True,         # Apply additive bias before the activation function?
        activation      = 'linear',     # Activation function: 'relu', 'lrelu', etc.
        up              = 1,            # Integer upsampling factor.
        down            = 1,            # Integer downsampling factor.
        resample_filter = [1,3,3,1],    # Low-pass filter to apply when resampling activations.
        conv_clamp      = None,         # Clamp the output to +-X, None = disable clamping.
        channels_last   = False,        # Expect the input to have memory_format=channels_last?
        trainable       = True,         # Update the weights of this layer during training?
    ):
        super().__init__()
        self.activation = activation
        self.up = up
        self.down = down
        self.conv_clamp = conv_clamp
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        self.padding = kernel_size // 2
        self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
        self.act_gain = bias_act.activation_funcs[activation].def_gain

        memory_format = torch.channels_last if channels_last else torch.contiguous_format
        init_weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)
        init_bias = torch.zeros([out_channels]) if bias else None
        if trainable:
            self.weight = torch.nn.Parameter(init_weight)
            self.bias = torch.nn.Parameter(init_bias) if init_bias is not None else None
        else:
            # Frozen layers keep their weights as buffers so they are saved but not optimized.
            self.register_buffer('weight', init_weight)
            if init_bias is not None:
                self.register_buffer('bias', init_bias)
            else:
                self.bias = None

    def forward(self, x, gain=1):
        weight = self.weight * self.weight_gain
        bias = self.bias.to(x.dtype) if self.bias is not None else None
        flip_weight = (self.up == 1) # slightly faster
        x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=self.resample_filter, up=self.up, down=self.down, padding=self.padding, flip_weight=flip_weight)

        act_gain = self.act_gain * gain
        act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
        return bias_act.bias_act(x, bias, act=self.activation, gain=act_gain, clamp=act_clamp)
170
+
171
+ #----------------------------------------------------------------------------
172
+
173
@persistence.persistent_class
class MappingNetwork(torch.nn.Module):
    """Maps latents z (and optional labels c) to intermediate latents w.

    Tracks a moving average of w during training for truncation, and can
    broadcast the output to num_ws copies for the synthesis network.
    """

    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality, 0 = no latent.
        c_dim,                      # Conditioning label (C) dimensionality, 0 = no label.
        w_dim,                      # Intermediate latent (W) dimensionality.
        num_ws,                     # Number of intermediate latents to output, None = do not broadcast.
        num_layers      = 8,        # Number of mapping layers.
        embed_features  = None,     # Label embedding dimensionality, None = same as w_dim.
        layer_features  = None,     # Number of intermediate features in the mapping layers, None = same as w_dim.
        activation      = 'lrelu',  # Activation function: 'relu', 'lrelu', etc.
        lr_multiplier   = 0.01,     # Learning rate multiplier for the mapping layers.
        w_avg_beta      = 0.995,    # Decay for tracking the moving average of W during training, None = do not track.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.num_ws = num_ws
        self.num_layers = num_layers
        self.w_avg_beta = w_avg_beta

        if embed_features is None:
            embed_features = w_dim
        if c_dim == 0:
            embed_features = 0
        if layer_features is None:
            layer_features = w_dim
        features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]

        if c_dim > 0:
            self.embed = FullyConnectedLayer(c_dim, embed_features)
        for idx in range(num_layers):
            in_features = features_list[idx]
            out_features = features_list[idx + 1]
            layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
            setattr(self, f'fc{idx}', layer)

        if num_ws is not None and w_avg_beta is not None:
            self.register_buffer('w_avg', torch.zeros([w_dim]))

    def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
        """Return w of shape [batch, w_dim], or [batch, num_ws, w_dim] when broadcasting."""
        # Embed, normalize, and concat inputs.
        x = None
        with torch.autograd.profiler.record_function('input'):
            if self.z_dim > 0:
                misc.assert_shape(z, [None, self.z_dim])
                x = normalize_2nd_moment(z.to(torch.float32))
            if self.c_dim > 0:
                misc.assert_shape(c, [None, self.c_dim])
                y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
                x = torch.cat([x, y], dim=1) if x is not None else y

        # Main layers.
        for idx in range(self.num_layers):
            layer = getattr(self, f'fc{idx}')
            x = layer(x)

        # Update moving average of W (in-place on the buffer; training mode only).
        if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
            with torch.autograd.profiler.record_function('update_w_avg'):
                self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))

        # Broadcast.
        if self.num_ws is not None:
            with torch.autograd.profiler.record_function('broadcast'):
                x = x.unsqueeze(1).repeat([1, self.num_ws, 1])

        # Apply truncation: interpolate towards w_avg; cutoff limits it to the first layers.
        if truncation_psi != 1:
            with torch.autograd.profiler.record_function('truncate'):
                assert self.w_avg_beta is not None
                if self.num_ws is None or truncation_cutoff is None:
                    x = self.w_avg.lerp(x, truncation_psi)
                else:
                    x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
        return x
250
+
251
+ #----------------------------------------------------------------------------
252
+
253
@persistence.persistent_class
class SynthesisLayer(torch.nn.Module):
    """Single synthesis conv layer: style modulation, optional upsampling and noise."""

    def __init__(self,
        in_channels,                    # Number of input channels.
        out_channels,                   # Number of output channels.
        w_dim,                          # Intermediate latent (W) dimensionality.
        resolution,                     # Resolution of this layer.
        kernel_size     = 3,            # Convolution kernel size.
        up              = 1,            # Integer upsampling factor.
        use_noise       = True,         # Enable noise input?
        activation      = 'lrelu',      # Activation function: 'relu', 'lrelu', etc.
        resample_filter = [1,3,3,1],    # Low-pass filter to apply when resampling activations.
        conv_clamp      = None,         # Clamp the output of convolution layers to +-X, None = disable clamping.
        channels_last   = False,        # Use channels_last format for the weights?
    ):
        super().__init__()
        self.resolution = resolution
        self.up = up
        self.use_noise = use_noise
        self.activation = activation
        self.conv_clamp = conv_clamp
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        self.padding = kernel_size // 2
        self.act_gain = bias_act.activation_funcs[activation].def_gain

        # Affine transform from w to per-channel styles (bias_init=1 => identity styles at init).
        self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
        memory_format = torch.channels_last if channels_last else torch.contiguous_format
        self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
        if use_noise:
            self.register_buffer('noise_const', torch.randn([resolution, resolution]))
            self.noise_strength = torch.nn.Parameter(torch.zeros([]))
        self.bias = torch.nn.Parameter(torch.zeros([out_channels]))

    def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1):
        """Apply the modulated conv; noise_mode selects fresh/fixed/no noise."""
        assert noise_mode in ['random', 'const', 'none']
        in_resolution = self.resolution // self.up
        misc.assert_shape(x, [None, self.weight.shape[1], in_resolution, in_resolution])
        styles = self.affine(w)

        noise = None
        if self.use_noise and noise_mode == 'random':
            noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength
        if self.use_noise and noise_mode == 'const':
            noise = self.noise_const * self.noise_strength

        flip_weight = (self.up == 1) # slightly faster
        x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up,
            padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight, fused_modconv=fused_modconv)

        act_gain = self.act_gain * gain
        act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
        x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp)
        return x
306
+
307
+ #----------------------------------------------------------------------------
308
+
309
@persistence.persistent_class
class ToRGBLayer(torch.nn.Module):
    """Modulated 1x1 convolution projecting feature maps to image channels (no demodulation)."""

    def __init__(self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False):
        super().__init__()
        self.conv_clamp = conv_clamp
        self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
        memory_format = torch.channels_last if channels_last else torch.contiguous_format
        init_weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size])
        self.weight = torch.nn.Parameter(init_weight.to(memory_format=memory_format))
        self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
        self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))

    def forward(self, x, w, fused_modconv=True):
        # weight_gain is folded into the styles rather than the weights.
        styles = self.weight_gain * self.affine(w)
        out = modulated_conv2d(x=x, weight=self.weight, styles=styles, demodulate=False, fused_modconv=fused_modconv)
        return bias_act.bias_act(out, self.bias.to(x.dtype), clamp=self.conv_clamp)
325
+
326
+ #----------------------------------------------------------------------------
327
+
328
@persistence.persistent_class
class SynthesisBlock(torch.nn.Module):
    """One resolution level of the synthesis network: up to two conv layers plus ToRGB.

    The first block (in_channels == 0) starts from a learned constant; later
    blocks upsample 2x. Supports 'orig', 'skip' (progressive RGB), and 'resnet'
    architectures, and optional FP16 execution for high resolutions.
    """

    def __init__(self,
        in_channels,                        # Number of input channels, 0 = first block.
        out_channels,                       # Number of output channels.
        w_dim,                              # Intermediate latent (W) dimensionality.
        resolution,                         # Resolution of this block.
        img_channels,                       # Number of output color channels.
        is_last,                            # Is this the last block?
        architecture        = 'skip',       # Architecture: 'orig', 'skip', 'resnet'.
        resample_filter     = [1,3,3,1],    # Low-pass filter to apply when resampling activations.
        conv_clamp          = None,         # Clamp the output of convolution layers to +-X, None = disable clamping.
        use_fp16            = False,        # Use FP16 for this block?
        fp16_channels_last  = False,        # Use channels-last memory format with FP16?
        **layer_kwargs,                     # Arguments for SynthesisLayer.
    ):
        assert architecture in ['orig', 'skip', 'resnet']
        super().__init__()
        self.in_channels = in_channels
        self.w_dim = w_dim
        self.resolution = resolution
        self.img_channels = img_channels
        self.is_last = is_last
        self.architecture = architecture
        self.use_fp16 = use_fp16
        self.channels_last = (use_fp16 and fp16_channels_last)
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        # num_conv / num_torgb tell SynthesisNetwork how many ws this block consumes.
        self.num_conv = 0
        self.num_torgb = 0

        if in_channels == 0:
            self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution]))

        if in_channels != 0:
            self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim, resolution=resolution, up=2,
                resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
            self.num_conv += 1

        self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim, resolution=resolution,
            conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
        self.num_conv += 1

        if is_last or architecture == 'skip':
            self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim,
                conv_clamp=conv_clamp, channels_last=self.channels_last)
            self.num_torgb += 1

        if in_channels != 0 and architecture == 'resnet':
            self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2,
                resample_filter=resample_filter, channels_last=self.channels_last)

    def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, **layer_kwargs):
        """Process one block; returns (features, accumulated RGB image or None)."""
        misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
        w_iter = iter(ws.unbind(dim=1))
        dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
        memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
        if fused_modconv is None:
            with misc.suppress_tracer_warnings(): # this value will be treated as a constant
                fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1)

        # Input.
        if self.in_channels == 0:
            x = self.const.to(dtype=dtype, memory_format=memory_format)
            x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
        else:
            misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2])
            x = x.to(dtype=dtype, memory_format=memory_format)

        # Main layers.
        if self.in_channels == 0:
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
        elif self.architecture == 'resnet':
            # sqrt(0.5) gains keep the residual sum at unit variance.
            y = self.skip(x, gain=np.sqrt(0.5))
            x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
            x = y.add_(x)
        else:
            x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)

        # ToRGB: upsample the incoming image and add this block's contribution (always FP32).
        if img is not None:
            misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
            img = upfirdn2d.upsample2d(img, self.resample_filter)
        if self.is_last or self.architecture == 'skip':
            y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
            y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
            img = img.add_(y) if img is not None else y

        assert x.dtype == dtype
        assert img is None or img.dtype == torch.float32
        return x, img
420
+
421
+ #----------------------------------------------------------------------------
422
+
423
@persistence.persistent_class
class SynthesisNetwork(torch.nn.Module):
    """Stack of SynthesisBlocks from 4x4 up to img_resolution; maps ws to an image."""

    def __init__(self,
        w_dim,                      # Intermediate latent (W) dimensionality.
        img_resolution,             # Output image resolution.
        img_channels,               # Number of color channels.
        channel_base    = 32768,    # Overall multiplier for the number of channels.
        channel_max     = 512,      # Maximum number of channels in any layer.
        num_fp16_res    = 0,        # Use FP16 for the N highest resolutions.
        **block_kwargs,             # Arguments for SynthesisBlock.
    ):
        # Resolution must be a power of two >= 4.
        assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
        super().__init__()
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.block_resolutions = [2 ** i for i in range(2, self.img_resolution_log2 + 1)]
        # Channel count halves as resolution doubles, capped at channel_max.
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        self.num_ws = 0
        for res in self.block_resolutions:
            in_channels = channels_dict[res // 2] if res > 4 else 0
            out_channels = channels_dict[res]
            use_fp16 = (res >= fp16_resolution)
            is_last = (res == self.img_resolution)
            block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res,
                img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs)
            self.num_ws += block.num_conv
            if is_last:
                self.num_ws += block.num_torgb
            setattr(self, f'b{res}', block)

    def forward(self, ws, **block_kwargs):
        """Split ws across blocks and run them in order; returns the final image."""
        block_ws = []
        with torch.autograd.profiler.record_function('split_ws'):
            misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
            ws = ws.to(torch.float32)
            w_idx = 0
            for res in self.block_resolutions:
                block = getattr(self, f'b{res}')
                # w_idx advances by num_conv only: a block's torgb deliberately
                # shares the w used by the next block's first conv layer.
                block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb))
                w_idx += block.num_conv

        x = img = None
        for res, cur_ws in zip(self.block_resolutions, block_ws):
            block = getattr(self, f'b{res}')
            x, img = block(x, img, cur_ws, **block_kwargs)
        return img
473
+
474
+ #----------------------------------------------------------------------------
475
+
476
@persistence.persistent_class
class Generator(torch.nn.Module):
    """Full generator: mapping network (z, c) -> ws, then synthesis network ws -> image."""

    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality.
        c_dim,                      # Conditioning label (C) dimensionality.
        w_dim,                      # Intermediate latent (W) dimensionality.
        img_resolution,             # Output resolution.
        img_channels,               # Number of output color channels.
        mapping_kwargs      = {},   # Arguments for MappingNetwork.
        synthesis_kwargs    = {},   # Arguments for SynthesisNetwork.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels
        # Build synthesis first: its depth determines how many ws the mapping net must emit.
        self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs)
        self.num_ws = self.synthesis.num_ws
        self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)

    def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, **synthesis_kwargs):
        """Map (z, c) to the intermediate latent space, then synthesize an image."""
        latents = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)
        return self.synthesis(latents, **synthesis_kwargs)
501
+
502
+ #----------------------------------------------------------------------------
503
+
504
@persistence.persistent_class
class DiscriminatorBlock(torch.nn.Module):
    """One discriminator resolution block: optional FromRGB, two conv layers, 2x downsample.

    Supports the 'orig', 'skip' and 'resnet' architectures and optional FP16
    execution with channels-last memory format.
    """
    def __init__(self,
        in_channels,                        # Number of input channels, 0 = first block.
        tmp_channels,                       # Number of intermediate channels.
        out_channels,                       # Number of output channels.
        resolution,                         # Resolution of this block.
        img_channels,                       # Number of input color channels.
        first_layer_idx,                    # Index of the first layer.
        architecture        = 'resnet',     # Architecture: 'orig', 'skip', 'resnet'.
        activation          = 'lrelu',      # Activation function: 'relu', 'lrelu', etc.
        resample_filter     = [1,3,3,1],    # Low-pass filter to apply when resampling activations.
        conv_clamp          = None,         # Clamp the output of convolution layers to +-X, None = disable clamping.
        use_fp16            = False,        # Use FP16 for this block?
        fp16_channels_last  = False,        # Use channels-last memory format with FP16?
        freeze_layers       = 0,            # Freeze-D: Number of layers to freeze.
    ):
        assert in_channels in [0, tmp_channels]
        assert architecture in ['orig', 'skip', 'resnet']
        super().__init__()
        self.in_channels = in_channels
        self.resolution = resolution
        self.img_channels = img_channels
        self.first_layer_idx = first_layer_idx
        self.architecture = architecture
        self.use_fp16 = use_fp16
        self.channels_last = (use_fp16 and fp16_channels_last)
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))

        # Freeze-D: each layer construction below pulls one value from this
        # generator; layers whose global index < freeze_layers are non-trainable.
        # NOTE: the generator mutates self.num_layers as a side effect.
        self.num_layers = 0
        def trainable_gen():
            while True:
                layer_idx = self.first_layer_idx + self.num_layers
                trainable = (layer_idx >= freeze_layers)
                self.num_layers += 1
                yield trainable
        trainable_iter = trainable_gen()

        # FromRGB is needed for the very first block and for the 'skip' architecture.
        if in_channels == 0 or architecture == 'skip':
            self.fromrgb = Conv2dLayer(img_channels, tmp_channels, kernel_size=1, activation=activation,
                trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last)

        self.conv0 = Conv2dLayer(tmp_channels, tmp_channels, kernel_size=3, activation=activation,
            trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last)

        # conv1 performs the 2x downsampling.
        self.conv1 = Conv2dLayer(tmp_channels, out_channels, kernel_size=3, activation=activation, down=2,
            trainable=next(trainable_iter), resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last)

        # 1x1 residual shortcut (also downsampled) for the 'resnet' architecture.
        if architecture == 'resnet':
            self.skip = Conv2dLayer(tmp_channels, out_channels, kernel_size=1, bias=False, down=2,
                trainable=next(trainable_iter), resample_filter=resample_filter, channels_last=self.channels_last)

    def forward(self, x, img, force_fp32=False):
        """Apply the block.

        Returns (x, img): downsampled features, and the downsampled RGB image
        for the 'skip' architecture (None otherwise).
        """
        dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
        memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format

        # Input.
        if x is not None:
            misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution])
            x = x.to(dtype=dtype, memory_format=memory_format)

        # FromRGB.
        if self.in_channels == 0 or self.architecture == 'skip':
            misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
            img = img.to(dtype=dtype, memory_format=memory_format)
            y = self.fromrgb(img)
            x = x + y if x is not None else y
            img = upfirdn2d.downsample2d(img, self.resample_filter) if self.architecture == 'skip' else None

        # Main layers.
        if self.architecture == 'resnet':
            y = self.skip(x, gain=np.sqrt(0.5))
            x = self.conv0(x)
            x = self.conv1(x, gain=np.sqrt(0.5))
            x = y.add_(x)   # in-place add onto the skip branch; both branches scaled by sqrt(0.5)
        else:
            x = self.conv0(x)
            x = self.conv1(x)

        assert x.dtype == dtype
        return x, img
585
+
586
+ #----------------------------------------------------------------------------
587
+
588
@persistence.persistent_class
class MinibatchStdLayer(torch.nn.Module):
    """Appends minibatch-stddev statistics to the input as extra feature channels."""
    def __init__(self, group_size, num_channels=1):
        super().__init__()
        self.group_size = group_size
        self.num_channels = num_channels

    def forward(self, x):
        """Input/output are NCHW; output has num_channels extra channels."""
        batch, channels, height, width = x.shape
        with misc.suppress_tracer_warnings(): # as_tensor results are registered as constants
            group = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(batch)) if self.group_size is not None else batch
        num_feat = self.num_channels
        chan_per_feat = channels // num_feat

        # Split the batch into groups of size `group` and the channels into
        # num_feat statistics channels, then compute stddev within each group.
        stats = x.reshape(group, -1, num_feat, chan_per_feat, height, width)   # [GnFcHW]
        stats = stats - stats.mean(dim=0)       # center within each group
        stats = stats.square().mean(dim=0)      # variance over the group      [nFcHW]
        stats = (stats + 1e-8).sqrt()           # stddev
        stats = stats.mean(dim=[2,3,4])         # average over channels/pixels [nF]
        stats = stats.reshape(-1, num_feat, 1, 1)           # [nF11]
        stats = stats.repeat(group, 1, height, width)       # broadcast back to [NFHW]
        return torch.cat([x, stats], dim=1)     # append as new channels [NCHW]
611
+
612
+ #----------------------------------------------------------------------------
613
+
614
@persistence.persistent_class
class DiscriminatorEpilogue(torch.nn.Module):
    """Final discriminator stage at 4x4 resolution.

    Applies optional minibatch-stddev, a conv layer, and two fully-connected
    layers; projects onto the mapped conditioning label when cmap_dim > 0.
    Always runs in FP32.
    """
    def __init__(self,
        in_channels,                    # Number of input channels.
        cmap_dim,                       # Dimensionality of mapped conditioning label, 0 = no label.
        resolution,                     # Resolution of this block.
        img_channels,                   # Number of input color channels.
        architecture        = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
        mbstd_group_size    = 4,        # Group size for the minibatch standard deviation layer, None = entire minibatch.
        mbstd_num_channels  = 1,        # Number of features for the minibatch standard deviation layer, 0 = disable.
        activation          = 'lrelu',  # Activation function: 'relu', 'lrelu', etc.
        conv_clamp          = None,     # Clamp the output of convolution layers to +-X, None = disable clamping.
    ):
        assert architecture in ['orig', 'skip', 'resnet']
        super().__init__()
        self.in_channels = in_channels
        self.cmap_dim = cmap_dim
        self.resolution = resolution
        self.img_channels = img_channels
        self.architecture = architecture

        if architecture == 'skip':
            self.fromrgb = Conv2dLayer(img_channels, in_channels, kernel_size=1, activation=activation)
        self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None
        # conv input width accounts for the extra mbstd channels.
        self.conv = Conv2dLayer(in_channels + mbstd_num_channels, in_channels, kernel_size=3, activation=activation, conv_clamp=conv_clamp)
        self.fc = FullyConnectedLayer(in_channels * (resolution ** 2), in_channels, activation=activation)
        # Unconditional: single logit. Conditional: cmap_dim features to be projected onto cmap.
        self.out = FullyConnectedLayer(in_channels, 1 if cmap_dim == 0 else cmap_dim)

    def forward(self, x, img, cmap, force_fp32=False):
        """Return per-sample discriminator scores of shape [N, 1] (or [N, cmap_dim] pre-projection)."""
        misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution]) # [NCHW]
        _ = force_fp32 # unused: the epilogue always runs in FP32
        dtype = torch.float32
        memory_format = torch.contiguous_format

        # FromRGB.
        x = x.to(dtype=dtype, memory_format=memory_format)
        if self.architecture == 'skip':
            misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
            img = img.to(dtype=dtype, memory_format=memory_format)
            x = x + self.fromrgb(img)

        # Main layers.
        if self.mbstd is not None:
            x = self.mbstd(x)
        x = self.conv(x)
        x = self.fc(x.flatten(1))
        x = self.out(x)

        # Conditioning: projection discriminator — dot product with the mapped
        # label embedding, normalized by sqrt(cmap_dim).
        if self.cmap_dim > 0:
            misc.assert_shape(cmap, [None, self.cmap_dim])
            x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))

        assert x.dtype == dtype
        return x
669
+
670
+ #----------------------------------------------------------------------------
671
+
672
@persistence.persistent_class
class Discriminator(torch.nn.Module):
    """Complete discriminator network.

    A stack of DiscriminatorBlocks from img_resolution down to 8x8, followed by
    a 4x4 DiscriminatorEpilogue. Conditioning labels (c_dim > 0) are embedded
    via a MappingNetwork and projected in the epilogue.
    """
    def __init__(self,
        c_dim,                          # Conditioning label (C) dimensionality.
        img_resolution,                 # Input resolution.
        img_channels,                   # Number of input color channels.
        architecture        = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
        channel_base        = 32768,    # Overall multiplier for the number of channels.
        channel_max         = 512,      # Maximum number of channels in any layer.
        num_fp16_res        = 0,        # Use FP16 for the N highest resolutions.
        conv_clamp          = None,     # Clamp the output of convolution layers to +-X, None = disable clamping.
        cmap_dim            = None,     # Dimensionality of mapped conditioning label, None = default.
        block_kwargs        = {},       # Arguments for DiscriminatorBlock.
        mapping_kwargs      = {},       # Arguments for MappingNetwork.
        epilogue_kwargs     = {},       # Arguments for DiscriminatorEpilogue.
    ):
        super().__init__()
        self.c_dim = c_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        # Resolutions from img_resolution down to 8; the 4x4 stage is the epilogue.
        self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
        # Channel count doubles as resolution halves, capped at channel_max.
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
        # FP16 is used for the num_fp16_res highest resolutions (never below 8x8).
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        if cmap_dim is None:
            cmap_dim = channels_dict[4]
        if c_dim == 0:
            cmap_dim = 0    # unconditional: no label embedding

        common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
        cur_layer_idx = 0   # running global layer index, used by Freeze-D inside the blocks
        for res in self.block_resolutions:
            # The highest-resolution block takes the raw image (in_channels = 0).
            in_channels = channels_dict[res] if res < img_resolution else 0
            tmp_channels = channels_dict[res]
            out_channels = channels_dict[res // 2]
            use_fp16 = (res >= fp16_resolution)
            block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
                first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
            setattr(self, f'b{res}', block)
            cur_layer_idx += block.num_layers
        if c_dim > 0:
            # Label embedding network (no z input, no w averaging).
            self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
        self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)

    def forward(self, img, c, **block_kwargs):
        """Score a batch of images img (NCHW) with labels c; returns [N, 1] logits."""
        x = None
        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            x, img = block(x, img, **block_kwargs)

        cmap = None
        if self.c_dim > 0:
            cmap = self.mapping(None, c)
        x = self.b4(x, img, cmap)
        return x
728
+
729
+ #----------------------------------------------------------------------------
training/training_loop.py ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import time
11
+ import copy
12
+ import json
13
+ import pickle
14
+ import psutil
15
+ import PIL.Image
16
+ import numpy as np
17
+ import torch
18
+ import dnnlib
19
+ from torch_utils import misc
20
+ from torch_utils import training_stats
21
+ from torch_utils.ops import conv2d_gradfix
22
+ from torch_utils.ops import grid_sample_gradfix
23
+
24
+ import legacy
25
+ from metrics import metric_main
26
+
27
+ #----------------------------------------------------------------------------
28
+
29
def setup_snapshot_image_grid(training_set, random_seed=0):
    """Choose a grid of training-set samples for the snapshot image.

    Returns ((gw, gh), images, labels) where images and labels are stacked
    numpy arrays of gw * gh samples. Grid dimensions are sized so the full
    grid fits roughly in a 7680x4320 canvas.
    """
    rnd = np.random.RandomState(random_seed)
    gw = np.clip(7680 // training_set.image_shape[2], 7, 32)
    gh = np.clip(4320 // training_set.image_shape[1], 4, 32)

    if not training_set.has_labels:
        # Unconditional: show a random subset of training samples.
        order = list(range(len(training_set)))
        rnd.shuffle(order)
        grid_indices = [order[i % len(order)] for i in range(gw * gh)]
    else:
        # Conditional: bucket samples by label so each grid row shows one class.
        buckets = dict()  # label -> [idx, ...]
        for idx in range(len(training_set)):
            key = tuple(training_set.get_details(idx).raw_label.flat[::-1])
            if key not in buckets:
                buckets[key] = []
            buckets[key].append(idx)

        # Deterministic label order; shuffle samples within each bucket.
        ordered_labels = sorted(buckets.keys())
        for key in ordered_labels:
            rnd.shuffle(buckets[key])

        # Fill the grid row by row, cycling through labels.
        grid_indices = []
        for row in range(gh):
            key = ordered_labels[row % len(ordered_labels)]
            bucket = buckets[key]
            grid_indices += [bucket[col % len(bucket)] for col in range(gw)]
            # Rotate the bucket so the next row with this label shows fresh samples.
            buckets[key] = [bucket[(i + gw) % len(bucket)] for i in range(len(bucket))]

    # Load the selected samples.
    images, labels = zip(*[training_set[i] for i in grid_indices])
    return (gw, gh), np.stack(images), np.stack(labels)
65
+
66
+ #----------------------------------------------------------------------------
67
+
68
def save_image_grid(img, fname, drange, grid_size):
    """Tile a batch of images [N, C, H, W] into one grid and save it to fname.

    drange = (lo, hi) gives the input value range, mapped linearly to [0, 255].
    Supports grayscale (C == 1) and RGB (C == 3) images.
    """
    lo, hi = drange
    img = np.asarray(img, dtype=np.float32)
    scaled = (img - lo) * (255 / (hi - lo))
    img = np.rint(scaled).clip(0, 255).astype(np.uint8)

    gw, gh = grid_size
    _N, C, H, W = img.shape
    # [N,C,H,W] -> [gh,gw,C,H,W] -> [gh,H,gw,W,C] -> one [gh*H, gw*W, C] canvas.
    img = img.reshape(gh, gw, C, H, W)
    img = img.transpose(0, 3, 1, 4, 2)
    img = img.reshape(gh * H, gw * W, C)

    assert C in [1, 3]
    if C == 1:
        PIL.Image.fromarray(img[:, :, 0], 'L').save(fname)
    if C == 3:
        PIL.Image.fromarray(img, 'RGB').save(fname)
85
+
86
+ #----------------------------------------------------------------------------
87
+
88
def training_loop(
    run_dir                 = '.',      # Output directory.
    training_set_kwargs     = {},       # Options for training set.
    data_loader_kwargs      = {},       # Options for torch.utils.data.DataLoader.
    G_kwargs                = {},       # Options for generator network.
    D_kwargs                = {},       # Options for discriminator network.
    G_opt_kwargs            = {},       # Options for generator optimizer.
    D_opt_kwargs            = {},       # Options for discriminator optimizer.
    augment_kwargs          = None,     # Options for augmentation pipeline. None = disable.
    loss_kwargs             = {},       # Options for loss function.
    metrics                 = [],       # Metrics to evaluate during training.
    random_seed             = 0,        # Global random seed.
    num_gpus                = 1,        # Number of GPUs participating in the training.
    rank                    = 0,        # Rank of the current process in [0, num_gpus[.
    batch_size              = 4,        # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus.
    batch_gpu               = 4,        # Number of samples processed at a time by one GPU.
    ema_kimg                = 10,       # Half-life of the exponential moving average (EMA) of generator weights.
    ema_rampup              = None,     # EMA ramp-up coefficient.
    G_reg_interval          = 4,        # How often to perform regularization for G? None = disable lazy regularization.
    D_reg_interval          = 16,       # How often to perform regularization for D? None = disable lazy regularization.
    augment_p               = 0,        # Initial value of augmentation probability.
    ada_target              = None,     # ADA target value. None = fixed p.
    ada_interval            = 4,        # How often to perform ADA adjustment?
    ada_kimg                = 500,      # ADA adjustment speed, measured in how many kimg it takes for p to increase/decrease by one unit.
    total_kimg              = 25000,    # Total length of the training, measured in thousands of real images.
    kimg_per_tick           = 4,        # Progress snapshot interval.
    image_snapshot_ticks    = 50,       # How often to save image snapshots? None = disable.
    network_snapshot_ticks  = 50,       # How often to save network snapshots? None = disable.
    resume_pkl              = None,     # Network pickle to resume training from.
    cudnn_benchmark         = True,     # Enable torch.backends.cudnn.benchmark?
    allow_tf32              = False,    # Enable torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnn.allow_tf32?
    abort_fn                = None,     # Callback function for determining whether to abort training. Must return consistent results across ranks.
    progress_fn             = None,     # Callback function for updating training progress. Called for all ranks.
):
    """Main training loop, executed once per process (rank) out of num_gpus.

    Constructs G, D and G_ema, optionally resumes from a pickle, wraps modules
    in DistributedDataParallel, then alternates G/D training phases (with lazy
    regularization) until total_kimg images have been shown. Rank 0 also
    handles logging, image/network snapshots, and metric evaluation.
    """
    # Initialize.
    start_time = time.time()
    device = torch.device('cuda', rank)
    # Per-rank seeding so each GPU draws different latents/augmentations.
    np.random.seed(random_seed * num_gpus + rank)
    torch.manual_seed(random_seed * num_gpus + rank)
    torch.backends.cudnn.benchmark = cudnn_benchmark    # Improves training speed.
    torch.backends.cuda.matmul.allow_tf32 = allow_tf32  # Allow PyTorch to internally use tf32 for matmul
    torch.backends.cudnn.allow_tf32 = allow_tf32        # Allow PyTorch to internally use tf32 for convolutions
    conv2d_gradfix.enabled = True                       # Improves training speed.
    grid_sample_gradfix.enabled = True                  # Avoids errors with the augmentation pipe.

    # Load training set.
    if rank == 0:
        print('Loading training set...')
    training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs) # subclass of training.dataset.Dataset
    training_set_sampler = misc.InfiniteSampler(dataset=training_set, rank=rank, num_replicas=num_gpus, seed=random_seed)
    training_set_iterator = iter(torch.utils.data.DataLoader(dataset=training_set, sampler=training_set_sampler, batch_size=batch_size//num_gpus, **data_loader_kwargs))
    if rank == 0:
        print()
        print('Num images: ', len(training_set))
        print('Image shape:', training_set.image_shape)
        print('Label shape:', training_set.label_shape)
        print()

    # Construct networks.
    if rank == 0:
        print('Constructing networks...')
    common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution, img_channels=training_set.num_channels)
    # Gradients stay disabled by default; each training phase enables them only
    # for the module it is optimizing.
    G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
    D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
    G_ema = copy.deepcopy(G).eval()

    # Resume from existing pickle.
    # NOTE(review): only rank 0 loads the pickle; other ranks presumably receive
    # the weights via the DDP broadcast below — confirm against torch DDP docs.
    if (resume_pkl is not None) and (rank == 0):
        print(f'Resuming from "{resume_pkl}"')
        with dnnlib.util.open_url(resume_pkl) as f:
            resume_data = legacy.load_network_pkl(f)
        for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]:
            misc.copy_params_and_buffers(resume_data[name], module, require_all=False)

    # Print network summary tables.
    if rank == 0:
        z = torch.empty([batch_gpu, G.z_dim], device=device)
        c = torch.empty([batch_gpu, G.c_dim], device=device)
        img = misc.print_module_summary(G, [z, c])
        misc.print_module_summary(D, [img, c])

    # Setup augmentation.
    if rank == 0:
        print('Setting up augmentation...')
    augment_pipe = None
    ada_stats = None
    if (augment_kwargs is not None) and (augment_p > 0 or ada_target is not None):
        augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
        augment_pipe.p.copy_(torch.as_tensor(augment_p))
        if ada_target is not None:
            # Track the sign of D's outputs on reals to drive the ADA heuristic.
            ada_stats = training_stats.Collector(regex='Loss/signs/real')

    # Distribute across GPUs.
    if rank == 0:
        print(f'Distributing across {num_gpus} GPUs...')
    ddp_modules = dict()
    for name, module in [('G_mapping', G.mapping), ('G_synthesis', G.synthesis), ('D', D), (None, G_ema), ('augment_pipe', augment_pipe)]:
        if (num_gpus > 1) and (module is not None) and len(list(module.parameters())) != 0:
            # requires_grad must be True while DDP registers parameters, then is
            # switched back off (phases toggle it per-module during training).
            module.requires_grad_(True)
            module = torch.nn.parallel.DistributedDataParallel(module, device_ids=[device], broadcast_buffers=False)
            module.requires_grad_(False)
        if name is not None:
            ddp_modules[name] = module

    # Setup training phases.
    if rank == 0:
        print('Setting up training phases...')
    loss = dnnlib.util.construct_class_by_name(device=device, **ddp_modules, **loss_kwargs) # subclass of training.loss.Loss
    phases = []
    for name, module, opt_kwargs, reg_interval in [('G', G, G_opt_kwargs, G_reg_interval), ('D', D, D_opt_kwargs, D_reg_interval)]:
        if reg_interval is None:
            opt = dnnlib.util.construct_class_by_name(params=module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer
            phases += [dnnlib.EasyDict(name=name+'both', module=module, opt=opt, interval=1)]
        else: # Lazy regularization.
            # Rescale lr/betas so the effective optimizer behavior matches
            # non-lazy training despite the regularizer running every
            # reg_interval steps (StyleGAN2, Karras et al. 2020, Appendix B).
            mb_ratio = reg_interval / (reg_interval + 1)
            opt_kwargs = dnnlib.EasyDict(opt_kwargs)
            opt_kwargs.lr = opt_kwargs.lr * mb_ratio
            opt_kwargs.betas = [beta ** mb_ratio for beta in opt_kwargs.betas]
            opt = dnnlib.util.construct_class_by_name(module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer
            phases += [dnnlib.EasyDict(name=name+'main', module=module, opt=opt, interval=1)]
            phases += [dnnlib.EasyDict(name=name+'reg', module=module, opt=opt, interval=reg_interval)]
    for phase in phases:
        # CUDA events for per-phase timing; only rank 0 reports them.
        phase.start_event = None
        phase.end_event = None
        if rank == 0:
            phase.start_event = torch.cuda.Event(enable_timing=True)
            phase.end_event = torch.cuda.Event(enable_timing=True)

    # Export sample images.
    grid_size = None
    grid_z = None
    grid_c = None
    if rank == 0:
        print('Exporting sample images...')
        grid_size, images, labels = setup_snapshot_image_grid(training_set=training_set)
        save_image_grid(images, os.path.join(run_dir, 'reals.png'), drange=[0,255], grid_size=grid_size)
        # Fixed latents/labels reused for every image snapshot, split into
        # batch_gpu-sized chunks to bound memory.
        grid_z = torch.randn([labels.shape[0], G.z_dim], device=device).split(batch_gpu)
        grid_c = torch.from_numpy(labels).to(device).split(batch_gpu)
        images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
        save_image_grid(images, os.path.join(run_dir, 'fakes_init.png'), drange=[-1,1], grid_size=grid_size)

    # Initialize logs.
    if rank == 0:
        print('Initializing logs...')
    stats_collector = training_stats.Collector(regex='.*')
    stats_metrics = dict()
    stats_jsonl = None
    stats_tfevents = None
    if rank == 0:
        stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'wt')
        try:
            import torch.utils.tensorboard as tensorboard
            stats_tfevents = tensorboard.SummaryWriter(run_dir)
        except ImportError as err:
            # TensorBoard is optional; training proceeds without it.
            print('Skipping tfevents export:', err)

    # Train.
    if rank == 0:
        print(f'Training for {total_kimg} kimg...')
        print()
    cur_nimg = 0            # total real images shown so far
    cur_tick = 0            # maintenance-tick counter
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    maintenance_time = tick_start_time - start_time
    batch_idx = 0
    if progress_fn is not None:
        progress_fn(0, total_kimg)
    while True:

        # Fetch training data.
        with torch.autograd.profiler.record_function('data_fetch'):
            phase_real_img, phase_real_c = next(training_set_iterator)
            # uint8 [0,255] -> float32 [-1,1], chunked per GPU-batch.
            phase_real_img = (phase_real_img.to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu)
            phase_real_c = phase_real_c.to(device).split(batch_gpu)
            all_gen_z = torch.randn([len(phases) * batch_size, G.z_dim], device=device)
            all_gen_z = [phase_gen_z.split(batch_gpu) for phase_gen_z in all_gen_z.split(batch_size)]
            all_gen_c = [training_set.get_label(np.random.randint(len(training_set))) for _ in range(len(phases) * batch_size)]
            all_gen_c = torch.from_numpy(np.stack(all_gen_c)).pin_memory().to(device)
            all_gen_c = [phase_gen_c.split(batch_gpu) for phase_gen_c in all_gen_c.split(batch_size)]

        # Execute training phases.
        for phase, phase_gen_z, phase_gen_c in zip(phases, all_gen_z, all_gen_c):
            if batch_idx % phase.interval != 0:
                continue

            # Initialize gradient accumulation.
            if phase.start_event is not None:
                phase.start_event.record(torch.cuda.current_stream(device))
            phase.opt.zero_grad(set_to_none=True)
            phase.module.requires_grad_(True)

            # Accumulate gradients over multiple rounds.
            for round_idx, (real_img, real_c, gen_z, gen_c) in enumerate(zip(phase_real_img, phase_real_c, phase_gen_z, phase_gen_c)):
                # Sync DDP gradients only on the last accumulation round.
                sync = (round_idx == batch_size // (batch_gpu * num_gpus) - 1)
                gain = phase.interval   # compensates lazy-regularization intervals
                loss.accumulate_gradients(phase=phase.name, real_img=real_img, real_c=real_c, gen_z=gen_z, gen_c=gen_c, sync=sync, gain=gain)

            # Update weights.
            phase.module.requires_grad_(False)
            with torch.autograd.profiler.record_function(phase.name + '_opt'):
                for param in phase.module.parameters():
                    if param.grad is not None:
                        # Sanitize NaN/Inf gradients in place before stepping.
                        misc.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
                phase.opt.step()
            if phase.end_event is not None:
                phase.end_event.record(torch.cuda.current_stream(device))

        # Update G_ema.
        with torch.autograd.profiler.record_function('Gema'):
            ema_nimg = ema_kimg * 1000
            if ema_rampup is not None:
                # Shorter EMA horizon early in training.
                ema_nimg = min(ema_nimg, cur_nimg * ema_rampup)
            ema_beta = 0.5 ** (batch_size / max(ema_nimg, 1e-8))
            for p_ema, p in zip(G_ema.parameters(), G.parameters()):
                p_ema.copy_(p.lerp(p_ema, ema_beta))
            for b_ema, b in zip(G_ema.buffers(), G.buffers()):
                b_ema.copy_(b)

        # Update state.
        cur_nimg += batch_size
        batch_idx += 1

        # Execute ADA heuristic: nudge augmentation probability p toward
        # making E[sign(D(real))] hit ada_target.
        if (ada_stats is not None) and (batch_idx % ada_interval == 0):
            ada_stats.update()
            adjust = np.sign(ada_stats['Loss/signs/real'] - ada_target) * (batch_size * ada_interval) / (ada_kimg * 1000)
            augment_pipe.p.copy_((augment_pipe.p + adjust).max(misc.constant(0, device=device)))

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
            continue

        # Print status line, accumulating the same information in stats_collector.
        tick_end_time = time.time()
        fields = []
        fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"]
        fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}"]
        fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"]
        fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"]
        fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"]
        fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"]
        fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"]
        fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"]
        torch.cuda.reset_peak_memory_stats()
        fields += [f"augment {training_stats.report0('Progress/augment', float(augment_pipe.p.cpu()) if augment_pipe is not None else 0):.3f}"]
        training_stats.report0('Timing/total_hours', (tick_end_time - start_time) / (60 * 60))
        training_stats.report0('Timing/total_days', (tick_end_time - start_time) / (24 * 60 * 60))
        if rank == 0:
            print(' '.join(fields))

        # Check for abort.
        if (not done) and (abort_fn is not None) and abort_fn():
            done = True
            if rank == 0:
                print()
                print('Aborting...')

        # Save image snapshot.
        if (rank == 0) and (image_snapshot_ticks is not None) and (done or cur_tick % image_snapshot_ticks == 0):
            images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
            save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.png'), drange=[-1,1], grid_size=grid_size)

        # Save network snapshot.
        snapshot_pkl = None
        snapshot_data = None
        if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0):
            snapshot_data = dict(training_set_kwargs=dict(training_set_kwargs))
            for name, module in [('G', G), ('D', D), ('G_ema', G_ema), ('augment_pipe', augment_pipe)]:
                if module is not None:
                    if num_gpus > 1:
                        misc.check_ddp_consistency(module, ignore_regex=r'.*\.w_avg')
                    # Snapshot a frozen CPU copy so the pickle is device-agnostic.
                    module = copy.deepcopy(module).eval().requires_grad_(False).cpu()
                snapshot_data[name] = module
                del module # conserve memory
            snapshot_pkl = os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl')
            if rank == 0:
                with open(snapshot_pkl, 'wb') as f:
                    pickle.dump(snapshot_data, f)

        # Evaluate metrics.
        if (snapshot_data is not None) and (len(metrics) > 0):
            if rank == 0:
                print('Evaluating metrics...')
            for metric in metrics:
                result_dict = metric_main.calc_metric(metric=metric, G=snapshot_data['G_ema'],
                    dataset_kwargs=training_set_kwargs, num_gpus=num_gpus, rank=rank, device=device)
                if rank == 0:
                    metric_main.report_metric(result_dict, run_dir=run_dir, snapshot_pkl=snapshot_pkl)
                stats_metrics.update(result_dict.results)
        del snapshot_data # conserve memory

        # Collect statistics.
        for phase in phases:
            value = []
            if (phase.start_event is not None) and (phase.end_event is not None):
                phase.end_event.synchronize()
                value = phase.start_event.elapsed_time(phase.end_event)
            training_stats.report0('Timing/' + phase.name, value)
        stats_collector.update()
        stats_dict = stats_collector.as_dict()

        # Update logs.
        timestamp = time.time()
        if stats_jsonl is not None:
            fields = dict(stats_dict, timestamp=timestamp)
            stats_jsonl.write(json.dumps(fields) + '\n')
            stats_jsonl.flush()
        if stats_tfevents is not None:
            global_step = int(cur_nimg / 1e3)
            walltime = timestamp - start_time
            for name, value in stats_dict.items():
                stats_tfevents.add_scalar(name, value.mean, global_step=global_step, walltime=walltime)
            for name, value in stats_metrics.items():
                stats_tfevents.add_scalar(f'Metrics/{name}', value, global_step=global_step, walltime=walltime)
            stats_tfevents.flush()
        if progress_fn is not None:
            progress_fn(cur_nimg // 1000, total_kimg)

        # Update state.
        cur_tick += 1
        tick_start_nimg = cur_nimg
        tick_start_time = time.time()
        maintenance_time = tick_start_time - tick_end_time
        if done:
            break

    # Done.
    if rank == 0:
        print()
        print('Exiting...')
420
+
421
+ #----------------------------------------------------------------------------