Upload 6 files
Browse files- training/__init__.py +9 -0
- training/augment.py +431 -0
- training/dataset.py +236 -0
- training/loss.py +133 -0
- training/networks.py +729 -0
- training/training_loop.py +421 -0
training/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
# empty
|
training/augment.py
ADDED
|
@@ -0,0 +1,431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import scipy.signal
|
| 11 |
+
import torch
|
| 12 |
+
from torch_utils import persistence
|
| 13 |
+
from torch_utils import misc
|
| 14 |
+
from torch_utils.ops import upfirdn2d
|
| 15 |
+
from torch_utils.ops import grid_sample_gradfix
|
| 16 |
+
from torch_utils.ops import conv2d_gradfix
|
| 17 |
+
|
| 18 |
+
#----------------------------------------------------------------------------
|
| 19 |
+
# Coefficients of various wavelet decomposition low-pass filters.
|
| 20 |
+
|
| 21 |
+
# Low-pass decomposition filter taps for each supported wavelet family.
# Keys follow the PyWavelets naming convention ('dbN' = Daubechies,
# 'symN' = Symlets, 'haar' == 'db1'). Only 'sym6' (geometric lowpass) and
# 'sym2' (filter-bank construction) are referenced below; the rest are
# provided as alternatives.
wavelets = {
    'haar': [0.7071067811865476, 0.7071067811865476],
    'db1':  [0.7071067811865476, 0.7071067811865476],
    'db2':  [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
    'db3':  [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
    'db4':  [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523],
    'db5':  [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125],
    'db6':  [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017],
    'db7':  [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236],
    'db8':  [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161],
    'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
    'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
    'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427],
    'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728],
    'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148],
    'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255],
    'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609],
}
|
| 39 |
+
|
| 40 |
+
#----------------------------------------------------------------------------
|
| 41 |
+
# Helpers for constructing transformation matrices.
|
| 42 |
+
|
| 43 |
+
def matrix(*rows, device=None):
    """Assemble a (possibly batched) transformation matrix from row lists.

    Each argument is one row; entries may be plain numbers or tensors.
    With no tensor entries the result is a constant matrix on `device`.
    With tensor entries, every tensor must share a shape (the batch shape),
    plain numbers are broadcast to that shape via misc.constant, and the
    result has shape batch_shape + (num_rows, num_cols).
    """
    row_len = len(rows[0])
    assert all(len(r) == row_len for r in rows)
    flat = [entry for r in rows for entry in r]
    tensors = [entry for entry in flat if isinstance(entry, torch.Tensor)]
    if not tensors:
        # Purely numeric input: build a constant matrix directly.
        return misc.constant(np.asarray(rows), device=device)
    template = tensors[0]
    assert device is None or device == template.device
    # Promote scalar entries to tensors matching the batch shape.
    flat = [
        entry if isinstance(entry, torch.Tensor)
        else misc.constant(entry, shape=template.shape, device=template.device)
        for entry in flat
    ]
    return torch.stack(flat, dim=-1).reshape(template.shape + (len(rows), -1))
|
| 52 |
+
|
| 53 |
+
def translate2d(tx, ty, **kwargs):
    # Homogeneous 3x3 translation by (tx, ty); entries may be numbers or tensors.
    return matrix(
        [1, 0, tx],
        [0, 1, ty],
        [0, 0, 1],
        **kwargs)
|
| 59 |
+
|
| 60 |
+
def translate3d(tx, ty, tz, **kwargs):
    # Homogeneous 4x4 translation by (tx, ty, tz); used below for color offsets.
    return matrix(
        [1, 0, 0, tx],
        [0, 1, 0, ty],
        [0, 0, 1, tz],
        [0, 0, 0, 1],
        **kwargs)
|
| 67 |
+
|
| 68 |
+
def scale2d(sx, sy, **kwargs):
    # Homogeneous 3x3 scaling by (sx, sy).
    return matrix(
        [sx, 0, 0],
        [0, sy, 0],
        [0, 0, 1],
        **kwargs)
|
| 74 |
+
|
| 75 |
+
def scale3d(sx, sy, sz, **kwargs):
    # Homogeneous 4x4 scaling by (sx, sy, sz); used below for color gains.
    return matrix(
        [sx, 0, 0, 0],
        [0, sy, 0, 0],
        [0, 0, sz, 0],
        [0, 0, 0, 1],
        **kwargs)
|
| 82 |
+
|
| 83 |
+
def rotate2d(theta, **kwargs):
    # Homogeneous 3x3 rotation by angle theta (radians, counter-clockwise).
    # theta must be a tensor (torch.cos/torch.sin require tensor input);
    # sin(-theta) == -sin(theta), giving the standard rotation matrix.
    return matrix(
        [torch.cos(theta), torch.sin(-theta), 0],
        [torch.sin(theta), torch.cos(theta), 0],
        [0, 0, 1],
        **kwargs)
|
| 89 |
+
|
| 90 |
+
def rotate3d(v, theta, **kwargs):
    # Homogeneous 4x4 rotation by angle theta about axis v (axis-angle /
    # Rodrigues form). Assumes v is a unit vector -- the only caller below
    # passes the normalized luma axis [1,1,1,0]/sqrt(3).
    vx = v[..., 0]; vy = v[..., 1]; vz = v[..., 2]
    s = torch.sin(theta); c = torch.cos(theta); cc = 1 - c
    return matrix(
        [vx*vx*cc+c, vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0],
        [vy*vx*cc+vz*s, vy*vy*cc+c, vy*vz*cc-vx*s, 0],
        [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c, 0],
        [0, 0, 0, 1],
        **kwargs)
|
| 99 |
+
|
| 100 |
+
def translate2d_inv(tx, ty, **kwargs):
    # Inverse of translate2d: translate by (-tx, -ty).
    return translate2d(-tx, -ty, **kwargs)
|
| 102 |
+
|
| 103 |
+
def scale2d_inv(sx, sy, **kwargs):
    # Inverse of scale2d: scale by the reciprocals (sx, sy must be nonzero).
    return scale2d(1 / sx, 1 / sy, **kwargs)
|
| 105 |
+
|
| 106 |
+
def rotate2d_inv(theta, **kwargs):
    # Inverse of rotate2d: rotate by -theta.
    return rotate2d(-theta, **kwargs)
|
| 108 |
+
|
| 109 |
+
#----------------------------------------------------------------------------
|
| 110 |
+
# Versatile image augmentation pipeline from the paper
|
| 111 |
+
# "Training Generative Adversarial Networks with Limited Data".
|
| 112 |
+
#
|
| 113 |
+
# All augmentations are disabled by default; individual augmentations can
|
| 114 |
+
# be enabled by setting their probability multipliers to 1.
|
| 115 |
+
|
| 116 |
+
@persistence.persistent_class
class AugmentPipe(torch.nn.Module):
    """Versatile, differentiable image augmentation pipeline.

    Implements the augmentation stack from "Training Generative Adversarial
    Networks with Limited Data". Each augmentation category has a probability
    multiplier passed to the constructor; the effective per-sample probability
    is (multiplier * self.p), where the `p` buffer is the overall strength
    (typically adjusted adaptively by the training loop -- see its use in
    the torch.where gating throughout forward()). All augmentations are
    disabled by default (multipliers of 0); set a multiplier to 1 to enable.
    """
    def __init__(self,
        xflip=0, rotate90=0, xint=0, xint_max=0.125,
        scale=0, rotate=0, aniso=0, xfrac=0, scale_std=0.2, rotate_max=1, aniso_std=0.2, xfrac_std=0.125,
        brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5, hue_max=1, saturation_std=1,
        imgfilter=0, imgfilter_bands=[1,1,1,1], imgfilter_std=1,  # NOTE(review): mutable default is safe here -- it is only read and copied via list() below, never mutated.
        noise=0, cutout=0, noise_std=0.1, cutout_size=0.5,
    ):
        super().__init__()
        self.register_buffer('p', torch.ones([]))       # Overall multiplier for augmentation probability.

        # Pixel blitting.
        self.xflip            = float(xflip)            # Probability multiplier for x-flip.
        self.rotate90         = float(rotate90)         # Probability multiplier for 90 degree rotations.
        self.xint             = float(xint)             # Probability multiplier for integer translation.
        self.xint_max         = float(xint_max)         # Range of integer translation, relative to image dimensions.

        # General geometric transformations.
        self.scale            = float(scale)            # Probability multiplier for isotropic scaling.
        self.rotate           = float(rotate)           # Probability multiplier for arbitrary rotation.
        self.aniso            = float(aniso)            # Probability multiplier for anisotropic scaling.
        self.xfrac            = float(xfrac)            # Probability multiplier for fractional translation.
        self.scale_std        = float(scale_std)        # Log2 standard deviation of isotropic scaling.
        self.rotate_max       = float(rotate_max)       # Range of arbitrary rotation, 1 = full circle.
        self.aniso_std        = float(aniso_std)        # Log2 standard deviation of anisotropic scaling.
        self.xfrac_std        = float(xfrac_std)        # Standard deviation of frational translation, relative to image dimensions.

        # Color transformations.
        self.brightness       = float(brightness)       # Probability multiplier for brightness.
        self.contrast         = float(contrast)         # Probability multiplier for contrast.
        self.lumaflip         = float(lumaflip)         # Probability multiplier for luma flip.
        self.hue              = float(hue)              # Probability multiplier for hue rotation.
        self.saturation       = float(saturation)       # Probability multiplier for saturation.
        self.brightness_std   = float(brightness_std)   # Standard deviation of brightness.
        self.contrast_std     = float(contrast_std)     # Log2 standard deviation of contrast.
        self.hue_max          = float(hue_max)          # Range of hue rotation, 1 = full circle.
        self.saturation_std   = float(saturation_std)   # Log2 standard deviation of saturation.

        # Image-space filtering.
        self.imgfilter        = float(imgfilter)        # Probability multiplier for image-space filtering.
        self.imgfilter_bands  = list(imgfilter_bands)   # Probability multipliers for individual frequency bands.
        self.imgfilter_std    = float(imgfilter_std)    # Log2 standard deviation of image-space filter amplification.

        # Image-space corruptions.
        self.noise            = float(noise)            # Probability multiplier for additive RGB noise.
        self.cutout           = float(cutout)           # Probability multiplier for cutout.
        self.noise_std        = float(noise_std)        # Standard deviation of additive RGB noise.
        self.cutout_size      = float(cutout_size)      # Size of the cutout rectangle, relative to image dimensions.

        # Setup orthogonal lowpass filter for geometric augmentations.
        self.register_buffer('Hz_geom', upfirdn2d.setup_filter(wavelets['sym6']))

        # Construct filter bank for image-space filtering.
        Hz_lo = np.asarray(wavelets['sym2'])            # H(z)
        Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size)) # H(-z)
        Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2    # H(z) * H(z^-1) / 2
        Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2    # H(-z) * H(-z^-1) / 2
        Hz_fbank = np.eye(4, 1)                         # Bandpass(H(z), b_i)
        for i in range(1, Hz_fbank.shape[0]):
            # Upsample previous bands (insert zeros), convolve with the lowpass
            # product, then add the highpass product centered on row i.
            Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)]).reshape(Hz_fbank.shape[0], -1)[:, :-1]
            Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2])
            Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) // 2 : (Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2
        self.register_buffer('Hz_fbank', torch.as_tensor(Hz_fbank, dtype=torch.float32))

    def forward(self, images, debug_percentile=None):
        """Apply the enabled augmentations to a batch of images.

        Args:
            images: 4D tensor of shape [batch, channels, height, width]
                (enforced by the assert below); channels must be 3 or 1 for
                the color-transform stage to apply.
            debug_percentile: optional scalar in [0, 1]. When given, every
                random parameter draw is replaced by a fixed value derived
                from this percentile, making the output deterministic.

        Returns:
            Augmented images with the same shape as the input.
        """
        assert isinstance(images, torch.Tensor) and images.ndim == 4
        batch_size, num_channels, height, width = images.shape
        device = images.device
        if debug_percentile is not None:
            debug_percentile = torch.as_tensor(debug_percentile, dtype=torch.float32, device=device)

        # -------------------------------------
        # Select parameters for pixel blitting.
        # -------------------------------------

        # Initialize inverse homogeneous 2D transform: G_inv @ pixel_out ==> pixel_in
        I_3 = torch.eye(3, device=device)
        G_inv = I_3

        # Apply x-flip with probability (xflip * strength).
        if self.xflip > 0:
            i = torch.floor(torch.rand([batch_size], device=device) * 2)
            i = torch.where(torch.rand([batch_size], device=device) < self.xflip * self.p, i, torch.zeros_like(i))
            if debug_percentile is not None:
                i = torch.full_like(i, torch.floor(debug_percentile * 2))
            G_inv = G_inv @ scale2d_inv(1 - 2 * i, 1)

        # Apply 90 degree rotations with probability (rotate90 * strength).
        if self.rotate90 > 0:
            i = torch.floor(torch.rand([batch_size], device=device) * 4)
            i = torch.where(torch.rand([batch_size], device=device) < self.rotate90 * self.p, i, torch.zeros_like(i))
            if debug_percentile is not None:
                i = torch.full_like(i, torch.floor(debug_percentile * 4))
            G_inv = G_inv @ rotate2d_inv(-np.pi / 2 * i)

        # Apply integer translation with probability (xint * strength).
        if self.xint > 0:
            t = (torch.rand([batch_size, 2], device=device) * 2 - 1) * self.xint_max
            t = torch.where(torch.rand([batch_size, 1], device=device) < self.xint * self.p, t, torch.zeros_like(t))
            if debug_percentile is not None:
                t = torch.full_like(t, (debug_percentile * 2 - 1) * self.xint_max)
            G_inv = G_inv @ translate2d_inv(torch.round(t[:,0] * width), torch.round(t[:,1] * height))

        # --------------------------------------------------------
        # Select parameters for general geometric transformations.
        # --------------------------------------------------------

        # Apply isotropic scaling with probability (scale * strength).
        if self.scale > 0:
            s = torch.exp2(torch.randn([batch_size], device=device) * self.scale_std)
            s = torch.where(torch.rand([batch_size], device=device) < self.scale * self.p, s, torch.ones_like(s))
            if debug_percentile is not None:
                s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.scale_std))
            G_inv = G_inv @ scale2d_inv(s, s)

        # Apply pre-rotation with probability p_rot.
        # p_rot is chosen so that P(pre-rotation OR post-rotation) equals the
        # requested probability: P(pre OR post) = p.
        p_rot = 1 - torch.sqrt((1 - self.rotate * self.p).clamp(0, 1)) # P(pre OR post) = p
        if self.rotate > 0:
            theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
            theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
            if debug_percentile is not None:
                theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.rotate_max)
            G_inv = G_inv @ rotate2d_inv(-theta) # Before anisotropic scaling.

        # Apply anisotropic scaling with probability (aniso * strength).
        if self.aniso > 0:
            s = torch.exp2(torch.randn([batch_size], device=device) * self.aniso_std)
            s = torch.where(torch.rand([batch_size], device=device) < self.aniso * self.p, s, torch.ones_like(s))
            if debug_percentile is not None:
                s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.aniso_std))
            G_inv = G_inv @ scale2d_inv(s, 1 / s)

        # Apply post-rotation with probability p_rot.
        if self.rotate > 0:
            theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
            theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
            if debug_percentile is not None:
                theta = torch.zeros_like(theta) # Debug mode applies all rotation in the pre-rotation step.
            G_inv = G_inv @ rotate2d_inv(-theta) # After anisotropic scaling.

        # Apply fractional translation with probability (xfrac * strength).
        if self.xfrac > 0:
            t = torch.randn([batch_size, 2], device=device) * self.xfrac_std
            t = torch.where(torch.rand([batch_size, 1], device=device) < self.xfrac * self.p, t, torch.zeros_like(t))
            if debug_percentile is not None:
                t = torch.full_like(t, torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std)
            G_inv = G_inv @ translate2d_inv(t[:,0] * width, t[:,1] * height)

        # ----------------------------------
        # Execute geometric transformations.
        # ----------------------------------

        # Execute if the transform is not identity.
        # NOTE: identity comparison ('is not') detects whether any branch above
        # replaced G_inv -- an accumulated transform that happens to equal the
        # identity matrix is still executed.
        if G_inv is not I_3:

            # Calculate padding.
            cx = (width - 1) / 2
            cy = (height - 1) / 2
            cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz]
            cp = G_inv @ cp.t() # [batch, xyz, idx]
            Hz_pad = self.Hz_geom.shape[0] // 4
            margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx]
            margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1]
            margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device)
            margin = margin.max(misc.constant([0, 0] * 2, device=device))
            margin = margin.min(misc.constant([width-1, height-1] * 2, device=device))
            mx0, my0, mx1, my1 = margin.ceil().to(torch.int32)

            # Pad image and adjust origin.
            images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect')
            G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv

            # Upsample.
            images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2)
            G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device)
            G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device)

            # Execute transformation.
            shape = [batch_size, num_channels, (height + Hz_pad * 2) * 2, (width + Hz_pad * 2) * 2]
            G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device)
            grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False)
            images = grid_sample_gradfix.grid_sample(images, grid)

            # Downsample and crop.
            images = upfirdn2d.downsample2d(x=images, f=self.Hz_geom, down=2, padding=-Hz_pad*2, flip_filter=True)

        # --------------------------------------------
        # Select parameters for color transformations.
        # --------------------------------------------

        # Initialize homogeneous 3D transformation matrix: C @ color_in ==> color_out
        I_4 = torch.eye(4, device=device)
        C = I_4

        # Apply brightness with probability (brightness * strength).
        if self.brightness > 0:
            b = torch.randn([batch_size], device=device) * self.brightness_std
            b = torch.where(torch.rand([batch_size], device=device) < self.brightness * self.p, b, torch.zeros_like(b))
            if debug_percentile is not None:
                b = torch.full_like(b, torch.erfinv(debug_percentile * 2 - 1) * self.brightness_std)
            C = translate3d(b, b, b) @ C

        # Apply contrast with probability (contrast * strength).
        if self.contrast > 0:
            c = torch.exp2(torch.randn([batch_size], device=device) * self.contrast_std)
            c = torch.where(torch.rand([batch_size], device=device) < self.contrast * self.p, c, torch.ones_like(c))
            if debug_percentile is not None:
                c = torch.full_like(c, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.contrast_std))
            C = scale3d(c, c, c) @ C

        # Apply luma flip with probability (lumaflip * strength).
        v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device) # Luma axis.
        if self.lumaflip > 0:
            i = torch.floor(torch.rand([batch_size, 1, 1], device=device) * 2)
            i = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.lumaflip * self.p, i, torch.zeros_like(i))
            if debug_percentile is not None:
                i = torch.full_like(i, torch.floor(debug_percentile * 2))
            C = (I_4 - 2 * v.ger(v) * i) @ C # Householder reflection.

        # Apply hue rotation with probability (hue * strength).
        if self.hue > 0 and num_channels > 1:
            theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.hue_max
            theta = torch.where(torch.rand([batch_size], device=device) < self.hue * self.p, theta, torch.zeros_like(theta))
            if debug_percentile is not None:
                theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.hue_max)
            C = rotate3d(v, theta) @ C # Rotate around v.

        # Apply saturation with probability (saturation * strength).
        if self.saturation > 0 and num_channels > 1:
            s = torch.exp2(torch.randn([batch_size, 1, 1], device=device) * self.saturation_std)
            s = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.saturation * self.p, s, torch.ones_like(s))
            if debug_percentile is not None:
                s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.saturation_std))
            C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C

        # ------------------------------
        # Execute color transformations.
        # ------------------------------

        # Execute if the transform is not identity.
        if C is not I_4:
            images = images.reshape([batch_size, num_channels, height * width])
            if num_channels == 3:
                images = C[:, :3, :3] @ images + C[:, :3, 3:]
            elif num_channels == 1:
                # Grayscale: collapse the color transform to a single channel.
                C = C[:, :3, :].mean(dim=1, keepdims=True)
                images = images * C[:, :, :3].sum(dim=2, keepdims=True) + C[:, :, 3:]
            else:
                raise ValueError('Image must be RGB (3 channels) or L (1 channel)')
            images = images.reshape([batch_size, num_channels, height, width])

        # ----------------------
        # Image-space filtering.
        # ----------------------

        if self.imgfilter > 0:
            num_bands = self.Hz_fbank.shape[0]
            assert len(self.imgfilter_bands) == num_bands
            expected_power = misc.constant(np.array([10, 1, 1, 1]) / 13, device=device) # Expected power spectrum (1/f).

            # Apply amplification for each band with probability (imgfilter * strength * band_strength).
            g = torch.ones([batch_size, num_bands], device=device) # Global gain vector (identity).
            for i, band_strength in enumerate(self.imgfilter_bands):
                t_i = torch.exp2(torch.randn([batch_size], device=device) * self.imgfilter_std)
                t_i = torch.where(torch.rand([batch_size], device=device) < self.imgfilter * self.p * band_strength, t_i, torch.ones_like(t_i))
                if debug_percentile is not None:
                    t_i = torch.full_like(t_i, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.imgfilter_std)) if band_strength > 0 else torch.ones_like(t_i)
                t = torch.ones([batch_size, num_bands], device=device) # Temporary gain vector.
                t[:, i] = t_i # Replace i'th element.
                t = t / (expected_power * t.square()).sum(dim=-1, keepdims=True).sqrt() # Normalize power.
                g = g * t # Accumulate into global gain.

            # Construct combined amplification filter.
            Hz_prime = g @ self.Hz_fbank # [batch, tap]
            Hz_prime = Hz_prime.unsqueeze(1).repeat([1, num_channels, 1]) # [batch, channels, tap]
            Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, -1]) # [batch * channels, 1, tap]

            # Apply filter.
            # The separable filter is applied as two grouped 1D convolutions
            # (one per row direction, one per column direction).
            p = self.Hz_fbank.shape[1] // 2
            images = images.reshape([1, batch_size * num_channels, height, width])
            images = torch.nn.functional.pad(input=images, pad=[p,p,p,p], mode='reflect')
            images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(2), groups=batch_size*num_channels)
            images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(3), groups=batch_size*num_channels)
            images = images.reshape([batch_size, num_channels, height, width])

        # ------------------------
        # Image-space corruptions.
        # ------------------------

        # Apply additive RGB noise with probability (noise * strength).
        if self.noise > 0:
            sigma = torch.randn([batch_size, 1, 1, 1], device=device).abs() * self.noise_std
            sigma = torch.where(torch.rand([batch_size, 1, 1, 1], device=device) < self.noise * self.p, sigma, torch.zeros_like(sigma))
            if debug_percentile is not None:
                sigma = torch.full_like(sigma, torch.erfinv(debug_percentile) * self.noise_std)
            images = images + torch.randn([batch_size, num_channels, height, width], device=device) * sigma

        # Apply cutout with probability (cutout * strength).
        if self.cutout > 0:
            size = torch.full([batch_size, 2, 1, 1, 1], self.cutout_size, device=device)
            size = torch.where(torch.rand([batch_size, 1, 1, 1, 1], device=device) < self.cutout * self.p, size, torch.zeros_like(size))
            center = torch.rand([batch_size, 2, 1, 1, 1], device=device)
            if debug_percentile is not None:
                size = torch.full_like(size, self.cutout_size)
                center = torch.full_like(center, debug_percentile)
            coord_x = torch.arange(width, device=device).reshape([1, 1, 1, -1])
            coord_y = torch.arange(height, device=device).reshape([1, 1, -1, 1])
            mask_x = (((coord_x + 0.5) / width - center[:, 0]).abs() >= size[:, 0] / 2)
            mask_y = (((coord_y + 0.5) / height - center[:, 1]).abs() >= size[:, 1] / 2)
            mask = torch.logical_or(mask_x, mask_y).to(torch.float32)
            images = images * mask

        return images
|
| 430 |
+
|
| 431 |
+
#----------------------------------------------------------------------------
|
training/dataset.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import numpy as np
|
| 11 |
+
import zipfile
|
| 12 |
+
import PIL.Image
|
| 13 |
+
import json
|
| 14 |
+
import torch
|
| 15 |
+
import dnnlib
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
import pyspng
|
| 19 |
+
except ImportError:
|
| 20 |
+
pyspng = None
|
| 21 |
+
|
| 22 |
+
#----------------------------------------------------------------------------
|
| 23 |
+
|
| 24 |
+
class Dataset(torch.utils.data.Dataset):
    """Abstract base class for image datasets.

    Items are (uint8 CHW image, float32 label) pairs. Subclasses must
    implement _load_raw_image() and _load_raw_labels(); see
    ImageFolderDataset below for a concrete implementation.
    """

    def __init__(self,
        name,                   # Name of the dataset.
        raw_shape,              # Shape of the raw image data (NCHW).
        max_size    = None,     # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
        use_labels  = False,    # Enable conditioning labels? False = label dimension is zero.
        xflip       = False,    # Artificially double the size of the dataset via x-flips. Applied after max_size.
        random_seed = 0,        # Random seed to use when applying max_size.
    ):
        self._name = name
        self._raw_shape = list(raw_shape)
        self._use_labels = use_labels
        self._raw_labels = None     # Cache; populated lazily by _get_raw_labels().
        self._label_shape = None    # Cache; populated lazily by the label_shape property.

        # Apply max_size: keep a seed-deterministic random subset, then restore sorted order.
        self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
        if (max_size is not None) and (self._raw_idx.size > max_size):
            np.random.RandomState(random_seed).shuffle(self._raw_idx)
            self._raw_idx = np.sort(self._raw_idx[:max_size])

        # Apply xflip: duplicate every index and mark the second copy for mirroring.
        self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
        if xflip:
            self._raw_idx = np.tile(self._raw_idx, 2)
            self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])

    def _get_raw_labels(self):
        """Load and validate labels on first use; zero-width float32 labels when disabled/missing."""
        if self._raw_labels is None:
            self._raw_labels = self._load_raw_labels() if self._use_labels else None
            if self._raw_labels is None:
                self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
            assert isinstance(self._raw_labels, np.ndarray)
            assert self._raw_labels.shape[0] == self._raw_shape[0]
            assert self._raw_labels.dtype in [np.float32, np.int64]
            if self._raw_labels.dtype == np.int64:
                # int64 labels are non-negative class indices; get_label() expands them to one-hot.
                assert self._raw_labels.ndim == 1
                assert np.all(self._raw_labels >= 0)
        return self._raw_labels

    def close(self): # to be overridden by subclass
        pass

    def _load_raw_image(self, raw_idx): # to be overridden by subclass
        raise NotImplementedError

    def _load_raw_labels(self): # to be overridden by subclass
        raise NotImplementedError

    def __getstate__(self):
        # Drop the label cache when pickling (e.g. for DataLoader worker processes);
        # workers re-load labels lazily via _get_raw_labels().
        return dict(self.__dict__, _raw_labels=None)

    def __del__(self):
        # Best-effort cleanup; never raise from a destructor (may run during interpreter shutdown).
        try:
            self.close()
        except:
            pass

    def __len__(self):
        return self._raw_idx.size

    def __getitem__(self, idx):
        """Return (uint8 CHW image copy, label copy) for the given post-max_size/xflip index."""
        image = self._load_raw_image(self._raw_idx[idx])
        assert isinstance(image, np.ndarray)
        assert list(image.shape) == self.image_shape
        assert image.dtype == np.uint8
        if self._xflip[idx]:
            assert image.ndim == 3 # CHW
            image = image[:, :, ::-1]   # Mirror horizontally (flip the W axis).
        return image.copy(), self.get_label(idx)

    def get_label(self, idx):
        # int64 raw labels are class indices => expand to a one-hot float32 vector.
        label = self._get_raw_labels()[self._raw_idx[idx]]
        if label.dtype == np.int64:
            onehot = np.zeros(self.label_shape, dtype=np.float32)
            onehot[label] = 1
            label = onehot
        return label.copy()

    def get_details(self, idx):
        """Return debug info: raw index, flip flag, and the raw (pre-one-hot) label."""
        d = dnnlib.EasyDict()
        d.raw_idx = int(self._raw_idx[idx])
        d.xflip = (int(self._xflip[idx]) != 0)
        d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
        return d

    @property
    def name(self):
        return self._name

    @property
    def image_shape(self):
        # [channels, height, width]
        return list(self._raw_shape[1:])

    @property
    def num_channels(self):
        assert len(self.image_shape) == 3 # CHW
        return self.image_shape[0]

    @property
    def resolution(self):
        # Images are required to be square; returns the side length.
        assert len(self.image_shape) == 3 # CHW
        assert self.image_shape[1] == self.image_shape[2]
        return self.image_shape[1]

    @property
    def label_shape(self):
        # [num_classes] for int64 class-index labels, trailing raw shape otherwise.
        if self._label_shape is None:
            raw_labels = self._get_raw_labels()
            if raw_labels.dtype == np.int64:
                self._label_shape = [int(np.max(raw_labels)) + 1]
            else:
                self._label_shape = raw_labels.shape[1:]
        return list(self._label_shape)

    @property
    def label_dim(self):
        assert len(self.label_shape) == 1
        return self.label_shape[0]

    @property
    def has_labels(self):
        return any(x != 0 for x in self.label_shape)

    @property
    def has_onehot_labels(self):
        return self._get_raw_labels().dtype == np.int64
+
|
| 152 |
+
#----------------------------------------------------------------------------
|
| 153 |
+
|
| 154 |
+
class ImageFolderDataset(Dataset):
    """Dataset backed by a directory of image files or a zip archive of them.

    Optional conditioning labels are read from a 'dataset.json' file
    containing a {'labels': [[fname, label], ...]} mapping.
    """

    def __init__(self,
        path,                   # Path to directory or zip.
        resolution      = None, # Ensure specific resolution, None = highest available.
        **super_kwargs,         # Additional arguments for the Dataset base class.
    ):
        self._path = path
        self._zipfile = None    # Opened lazily by _get_zipfile(); dropped on pickling.

        if os.path.isdir(self._path):
            self._type = 'dir'
            # Collect every file under the directory, keyed by path relative to the root.
            self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
        elif self._file_ext(self._path) == '.zip':
            self._type = 'zip'
            self._all_fnames = set(self._get_zipfile().namelist())
        else:
            raise IOError('Path must point to a directory or zip')

        PIL.Image.init()    # Populate PIL.Image.EXTENSION with the supported file extensions.
        self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
        if len(self._image_fnames) == 0:
            raise IOError('No image files found in the specified path')

        name = os.path.splitext(os.path.basename(self._path))[0]
        # Probe the first image to determine channels/height/width; raw_shape is NCHW.
        raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
        if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
            raise IOError('Image files do not match the specified resolution')
        super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)

    @staticmethod
    def _file_ext(fname):
        # Lowercased extension including the leading dot, e.g. '.png'.
        return os.path.splitext(fname)[1].lower()

    def _get_zipfile(self):
        """Return the cached ZipFile handle, opening it on first use."""
        assert self._type == 'zip'
        if self._zipfile is None:
            self._zipfile = zipfile.ZipFile(self._path)
        return self._zipfile

    def _open_file(self, fname):
        """Open a member file for binary reading, from either the directory or the zip."""
        if self._type == 'dir':
            return open(os.path.join(self._path, fname), 'rb')
        if self._type == 'zip':
            return self._get_zipfile().open(fname, 'r')
        return None

    def close(self):
        # Always clear the handle, even if closing raises.
        try:
            if self._zipfile is not None:
                self._zipfile.close()
        finally:
            self._zipfile = None

    def __getstate__(self):
        # ZipFile handles are not picklable; drop them and reopen lazily after unpickling.
        return dict(super().__getstate__(), _zipfile=None)

    def _load_raw_image(self, raw_idx):
        """Load one image as a uint8 CHW array (pyspng fast path for PNGs when available)."""
        fname = self._image_fnames[raw_idx]
        with self._open_file(fname) as f:
            if pyspng is not None and self._file_ext(fname) == '.png':
                image = pyspng.load(f.read())
            else:
                image = np.array(PIL.Image.open(f))
        if image.ndim == 2:
            image = image[:, :, np.newaxis] # HW => HWC
        image = image.transpose(2, 0, 1) # HWC => CHW
        return image

    def _load_raw_labels(self):
        """Read labels from dataset.json; None when absent or explicitly null."""
        fname = 'dataset.json'
        if fname not in self._all_fnames:
            return None
        with self._open_file(fname) as f:
            labels = json.load(f)['labels']
        if labels is None:
            return None
        labels = dict(labels)
        # Order labels to match self._image_fnames; JSON keys use forward slashes.
        labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames]
        labels = np.array(labels)
        # 1-D => class indices (int64); 2-D => per-image label vectors (float32).
        labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
        return labels
+
|
| 236 |
+
#----------------------------------------------------------------------------
|
training/loss.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
from torch_utils import training_stats
|
| 12 |
+
from torch_utils import misc
|
| 13 |
+
from torch_utils.ops import conv2d_gradfix
|
| 14 |
+
|
| 15 |
+
#----------------------------------------------------------------------------
|
| 16 |
+
|
| 17 |
+
class Loss:
    """Abstract interface for GAN training losses."""

    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain):
        """Compute the loss terms for the given phase and backpropagate them.

        Must be overridden by subclasses (see StyleGAN2Loss below).
        """
        raise NotImplementedError()
+
|
| 21 |
+
#----------------------------------------------------------------------------
|
| 22 |
+
|
| 23 |
+
class StyleGAN2Loss(Loss):
    """StyleGAN2 training loss: non-saturating logistic loss with lazy
    R1 regularization (discriminator) and path-length regularization (generator).
    """

    def __init__(self, device, G_mapping, G_synthesis, D, augment_pipe=None, style_mixing_prob=0.9, r1_gamma=10, pl_batch_shrink=2, pl_decay=0.01, pl_weight=2):
        super().__init__()
        self.device = device
        self.G_mapping = G_mapping                  # Latent (z, c) -> intermediate latents (ws).
        self.G_synthesis = G_synthesis              # Intermediate latents (ws) -> image.
        self.D = D                                  # Discriminator.
        self.augment_pipe = augment_pipe            # Optional augmentation applied before D (both real and fake).
        self.style_mixing_prob = style_mixing_prob  # Probability of mixing in styles from a second latent.
        self.r1_gamma = r1_gamma                    # R1 regularization weight.
        self.pl_batch_shrink = pl_batch_shrink      # Batch-size divisor for the path-length pass.
        self.pl_decay = pl_decay                    # EMA decay for the path-length running mean.
        self.pl_weight = pl_weight                  # Path-length regularization weight.
        self.pl_mean = torch.zeros([], device=device)  # Running mean of observed path lengths.

    def run_G(self, z, c, sync):
        """Run the generator; optionally mix styles from a second random latent."""
        with misc.ddp_sync(self.G_mapping, sync):
            ws = self.G_mapping(z, c)
            if self.style_mixing_prob > 0:
                with torch.autograd.profiler.record_function('style_mixing'):
                    # Random crossover point in [1, num_ws); with prob (1 - style_mixing_prob)
                    # the cutoff is set to num_ws, i.e. no mixing for this batch.
                    cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
                    cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
                    ws[:, cutoff:] = self.G_mapping(torch.randn_like(z), c, skip_w_avg_update=True)[:, cutoff:]
        with misc.ddp_sync(self.G_synthesis, sync):
            img = self.G_synthesis(ws)
        return img, ws

    def run_D(self, img, c, sync):
        """Run the discriminator, applying the augmentation pipeline first if enabled."""
        if self.augment_pipe is not None:
            img = self.augment_pipe(img)
        with misc.ddp_sync(self.D, sync):
            logits = self.D(img, c)
        return logits

    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain):
        """Compute the loss terms for one phase and call backward() on each.

        phase selects which terms run: G/D main losses and/or their lazy
        regularizers. `sync` controls DDP gradient sync; `gain` scales the
        loss to account for lazy-regularization interval accumulation.
        """
        assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
        do_Gmain = (phase in ['Gmain', 'Gboth'])
        do_Dmain = (phase in ['Dmain', 'Dboth'])
        do_Gpl   = (phase in ['Greg', 'Gboth']) and (self.pl_weight != 0)
        do_Dr1   = (phase in ['Dreg', 'Dboth']) and (self.r1_gamma != 0)

        # Gmain: Maximize logits for generated images.
        if do_Gmain:
            with torch.autograd.profiler.record_function('Gmain_forward'):
                gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=(sync and not do_Gpl)) # May get synced by Gpl.
                gen_logits = self.run_D(gen_img, gen_c, sync=False)
                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake', gen_logits.sign())
                loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
                training_stats.report('Loss/G/loss', loss_Gmain)
            with torch.autograd.profiler.record_function('Gmain_backward'):
                loss_Gmain.mean().mul(gain).backward()

        # Gpl: Apply path length regularization.
        if do_Gpl:
            with torch.autograd.profiler.record_function('Gpl_forward'):
                # Run on a shrunken batch to reduce the cost of the double-backward pass.
                batch_size = gen_z.shape[0] // self.pl_batch_shrink
                gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size], sync=sync)
                pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
                with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients():
                    pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0]
                pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
                # Track the running mean of path lengths and penalize deviation from it.
                pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
                self.pl_mean.copy_(pl_mean.detach())
                pl_penalty = (pl_lengths - pl_mean).square()
                training_stats.report('Loss/pl_penalty', pl_penalty)
                loss_Gpl = pl_penalty * self.pl_weight
                training_stats.report('Loss/G/reg', loss_Gpl)
            with torch.autograd.profiler.record_function('Gpl_backward'):
                # The 0-weighted gen_img term keeps the graph connected so DDP sees all params.
                (gen_img[:, 0, 0, 0] * 0 + loss_Gpl).mean().mul(gain).backward()

        # Dmain: Minimize logits for generated images.
        loss_Dgen = 0
        if do_Dmain:
            with torch.autograd.profiler.record_function('Dgen_forward'):
                gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=False)
                gen_logits = self.run_D(gen_img, gen_c, sync=False) # Gets synced by loss_Dreal.
                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake', gen_logits.sign())
                loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))
            with torch.autograd.profiler.record_function('Dgen_backward'):
                loss_Dgen.mean().mul(gain).backward()

        # Dmain: Maximize logits for real images.
        # Dr1: Apply R1 regularization.
        if do_Dmain or do_Dr1:
            name = 'Dreal_Dr1' if do_Dmain and do_Dr1 else 'Dreal' if do_Dmain else 'Dr1'
            with torch.autograd.profiler.record_function(name + '_forward'):
                # requires_grad only when the R1 gradient penalty needs d(logits)/d(image).
                real_img_tmp = real_img.detach().requires_grad_(do_Dr1)
                real_logits = self.run_D(real_img_tmp, real_c, sync=sync)
                training_stats.report('Loss/scores/real', real_logits)
                training_stats.report('Loss/signs/real', real_logits.sign())

                loss_Dreal = 0
                if do_Dmain:
                    loss_Dreal = torch.nn.functional.softplus(-real_logits) # -log(sigmoid(real_logits))
                    training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal)

                loss_Dr1 = 0
                if do_Dr1:
                    with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
                        r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0]
                    r1_penalty = r1_grads.square().sum([1,2,3])
                    loss_Dr1 = r1_penalty * (self.r1_gamma / 2)
                    training_stats.report('Loss/r1_penalty', r1_penalty)
                    training_stats.report('Loss/D/reg', loss_Dr1)

            with torch.autograd.profiler.record_function(name + '_backward'):
                # The 0-weighted real_logits term keeps the graph connected for DDP.
                (real_logits * 0 + loss_Dreal + loss_Dr1).mean().mul(gain).backward()
+
|
| 133 |
+
#----------------------------------------------------------------------------
|
training/networks.py
ADDED
|
@@ -0,0 +1,729 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
from torch_utils import misc
|
| 12 |
+
from torch_utils import persistence
|
| 13 |
+
from torch_utils.ops import conv2d_resample
|
| 14 |
+
from torch_utils.ops import upfirdn2d
|
| 15 |
+
from torch_utils.ops import bias_act
|
| 16 |
+
from torch_utils.ops import fma
|
| 17 |
+
|
| 18 |
+
#----------------------------------------------------------------------------
|
| 19 |
+
|
| 20 |
+
@misc.profiled_function
def normalize_2nd_moment(x, dim=1, eps=1e-8):
    """Scale `x` so that its second moment along `dim` is approximately 1.

    `eps` guards the reciprocal square root against division by zero.
    """
    second_moment = x.square().mean(dim=dim, keepdim=True)
    return x * (second_moment + eps).rsqrt()
+
|
| 24 |
+
#----------------------------------------------------------------------------
|
| 25 |
+
|
| 26 |
+
@misc.profiled_function
def modulated_conv2d(
    x,                          # Input tensor of shape [batch_size, in_channels, in_height, in_width].
    weight,                     # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
    styles,                     # Modulation coefficients of shape [batch_size, in_channels].
    noise           = None,     # Optional noise tensor to add to the output activations.
    up              = 1,        # Integer upsampling factor.
    down            = 1,        # Integer downsampling factor.
    padding         = 0,        # Padding with respect to the upsampled image.
    resample_filter = None,     # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
    demodulate      = True,     # Apply weight demodulation?
    flip_weight     = True,     # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
    fused_modconv   = True,     # Perform modulation, convolution, and demodulation as a single fused operation?
):
    """Style-modulated conv2d (StyleGAN2): scale weights per sample by `styles`,
    convolve, and optionally demodulate so output feature magnitudes stay ~unit.

    Returns a tensor of shape [batch_size, out_channels, out_height, out_width].
    """
    batch_size = x.shape[0]
    out_channels, in_channels, kh, kw = weight.shape
    misc.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk]
    misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW]
    misc.assert_shape(styles, [batch_size, in_channels]) # [NI]

    # Pre-normalize inputs to avoid FP16 overflow.
    if x.dtype == torch.float16 and demodulate:
        weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1,2,3], keepdim=True)) # max_Ikk
        styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I

    # Calculate per-sample weights and demodulation coefficients.
    w = None
    dcoefs = None
    if demodulate or fused_modconv:
        w = weight.unsqueeze(0) # [NOIkk]
        w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk]
    if demodulate:
        dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO]
    if demodulate and fused_modconv:
        w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk]

    # Execute by scaling the activations before and after the convolution.
    # (Mathematically equivalent to the fused path, but uses the shared weight tensor.)
    if not fused_modconv:
        x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
        x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight)
        if demodulate and noise is not None:
            # Fused multiply-add: x * dcoefs + noise in one op.
            x = fma.fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype))
        elif demodulate:
            x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)
        elif noise is not None:
            x = x.add_(noise.to(x.dtype))
        return x

    # Execute as one fused op using grouped convolution: fold the batch into the
    # channel dimension and use one weight group per sample.
    with misc.suppress_tracer_warnings(): # this value will be treated as a constant
        batch_size = int(batch_size)
    misc.assert_shape(x, [batch_size, in_channels, None, None])
    x = x.reshape(1, -1, *x.shape[2:])
    w = w.reshape(-1, in_channels, kh, kw)
    x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight)
    x = x.reshape(batch_size, -1, *x.shape[2:])
    if noise is not None:
        x = x.add_(noise)
    return x
+
|
| 86 |
+
#----------------------------------------------------------------------------
|
| 87 |
+
|
| 88 |
+
@persistence.persistent_class
class FullyConnectedLayer(torch.nn.Module):
    """Fully-connected layer with equalized learning rate.

    Weights are stored scaled by 1/lr_multiplier and rescaled at runtime
    (weight_gain/bias_gain), so the effective learning rate of this layer
    is multiplied by lr_multiplier without changing the optimizer.
    """

    def __init__(self,
        in_features,                # Number of input features.
        out_features,               # Number of output features.
        bias            = True,     # Apply additive bias before the activation function?
        activation      = 'linear', # Activation function: 'relu', 'lrelu', etc.
        lr_multiplier   = 1,        # Learning rate multiplier.
        bias_init       = 0,        # Initial value for the additive bias.
    ):
        super().__init__()
        self.activation = activation
        self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
        self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
        # He-style runtime scaling: normalize by fan-in, apply lr multiplier.
        self.weight_gain = lr_multiplier / np.sqrt(in_features)
        self.bias_gain = lr_multiplier

    def forward(self, x):
        """Apply the affine transform followed by the configured activation."""
        w = self.weight.to(x.dtype) * self.weight_gain
        b = self.bias
        if b is not None:
            b = b.to(x.dtype)
            if self.bias_gain != 1:
                b = b * self.bias_gain

        if self.activation == 'linear' and b is not None:
            # Fast path: fused bias + matmul.
            x = torch.addmm(b.unsqueeze(0), x, w.t())
        else:
            x = x.matmul(w.t())
            x = bias_act.bias_act(x, b, act=self.activation)
        return x
+
|
| 120 |
+
#----------------------------------------------------------------------------
|
| 121 |
+
|
| 122 |
+
@persistence.persistent_class
class Conv2dLayer(torch.nn.Module):
    """2D convolution layer with equalized learning rate, optional up/down
    sampling, and optional clamped activation.
    """

    def __init__(self,
        in_channels,                    # Number of input channels.
        out_channels,                   # Number of output channels.
        kernel_size,                    # Width and height of the convolution kernel.
        bias            = True,         # Apply additive bias before the activation function?
        activation      = 'linear',     # Activation function: 'relu', 'lrelu', etc.
        up              = 1,            # Integer upsampling factor.
        down            = 1,            # Integer downsampling factor.
        resample_filter = [1,3,3,1],    # Low-pass filter to apply when resampling activations.
        conv_clamp      = None,         # Clamp the output to +-X, None = disable clamping.
        channels_last   = False,        # Expect the input to have memory_format=channels_last?
        trainable       = True,         # Update the weights of this layer during training?
    ):
        super().__init__()
        self.activation = activation
        self.up = up
        self.down = down
        self.conv_clamp = conv_clamp
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        self.padding = kernel_size // 2     # 'Same' padding for odd kernel sizes.
        # Equalized learning rate: normalize by fan-in at runtime.
        self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
        self.act_gain = bias_act.activation_funcs[activation].def_gain

        memory_format = torch.channels_last if channels_last else torch.contiguous_format
        weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)
        bias = torch.zeros([out_channels]) if bias else None
        if trainable:
            self.weight = torch.nn.Parameter(weight)
            self.bias = torch.nn.Parameter(bias) if bias is not None else None
        else:
            # Non-trainable layers keep weights as buffers so they are still
            # saved/restored and moved between devices with the module.
            self.register_buffer('weight', weight)
            if bias is not None:
                self.register_buffer('bias', bias)
            else:
                self.bias = None

    def forward(self, x, gain=1):
        """Convolve (with optional resampling), then apply bias/activation scaled by `gain`."""
        w = self.weight * self.weight_gain
        b = self.bias.to(x.dtype) if self.bias is not None else None
        flip_weight = (self.up == 1) # slightly faster
        x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=self.resample_filter, up=self.up, down=self.down, padding=self.padding, flip_weight=flip_weight)

        act_gain = self.act_gain * gain
        act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
        x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp)
        return x
+
|
| 171 |
+
#----------------------------------------------------------------------------
|
| 172 |
+
|
| 173 |
+
@persistence.persistent_class
class MappingNetwork(torch.nn.Module):
    """Maps latent codes Z (and optional labels C) to intermediate latents W.

    A stack of fully-connected layers. Optionally embeds the conditioning
    label, tracks a moving average of W for truncation, and broadcasts the
    result over `num_ws` synthesis-layer inputs.
    """
    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality, 0 = no latent.
        c_dim,                      # Conditioning label (C) dimensionality, 0 = no label.
        w_dim,                      # Intermediate latent (W) dimensionality.
        num_ws,                     # Number of intermediate latents to output, None = do not broadcast.
        num_layers      = 8,        # Number of mapping layers.
        embed_features  = None,     # Label embedding dimensionality, None = same as w_dim.
        layer_features  = None,     # Number of intermediate features in the mapping layers, None = same as w_dim.
        activation      = 'lrelu',  # Activation function: 'relu', 'lrelu', etc.
        lr_multiplier   = 0.01,     # Learning rate multiplier for the mapping layers.
        w_avg_beta      = 0.995,    # Decay for tracking the moving average of W during training, None = do not track.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.num_ws = num_ws
        self.num_layers = num_layers
        self.w_avg_beta = w_avg_beta

        # Resolve defaults: no label embedding when c_dim == 0.
        if embed_features is None:
            embed_features = w_dim
        if c_dim == 0:
            embed_features = 0
        if layer_features is None:
            layer_features = w_dim
        # Per-layer feature counts: input (z + embedded c), hidden, then w_dim.
        features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]

        if c_dim > 0:
            self.embed = FullyConnectedLayer(c_dim, embed_features)
        for idx in range(num_layers):
            in_features = features_list[idx]
            out_features = features_list[idx + 1]
            layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
            setattr(self, f'fc{idx}', layer)

        # Moving average of W, used for truncation at inference time.
        if num_ws is not None and w_avg_beta is not None:
            self.register_buffer('w_avg', torch.zeros([w_dim]))

    def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
        """Map (z, c) -> w, optionally broadcast and truncated toward w_avg."""
        # Embed, normalize, and concat inputs.
        x = None
        with torch.autograd.profiler.record_function('input'):
            if self.z_dim > 0:
                misc.assert_shape(z, [None, self.z_dim])
                x = normalize_2nd_moment(z.to(torch.float32))
            if self.c_dim > 0:
                misc.assert_shape(c, [None, self.c_dim])
                y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
                x = torch.cat([x, y], dim=1) if x is not None else y

        # Main layers.
        for idx in range(self.num_layers):
            layer = getattr(self, f'fc{idx}')
            x = layer(x)

        # Update moving average of W (training only; in-place on the buffer).
        if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
            with torch.autograd.profiler.record_function('update_w_avg'):
                self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))

        # Broadcast the single w across all synthesis-layer slots.
        if self.num_ws is not None:
            with torch.autograd.profiler.record_function('broadcast'):
                x = x.unsqueeze(1).repeat([1, self.num_ws, 1])

        # Apply truncation: interpolate toward w_avg, optionally only for the
        # first `truncation_cutoff` latent slots.
        if truncation_psi != 1:
            with torch.autograd.profiler.record_function('truncate'):
                assert self.w_avg_beta is not None
                if self.num_ws is None or truncation_cutoff is None:
                    x = self.w_avg.lerp(x, truncation_psi)
                else:
                    x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
        return x
|
| 250 |
+
|
| 251 |
+
#----------------------------------------------------------------------------
|
| 252 |
+
|
| 253 |
+
@persistence.persistent_class
class SynthesisLayer(torch.nn.Module):
    """Single synthesis layer: style-modulated conv, optional noise, activation."""
    def __init__(self,
        in_channels,                    # Number of input channels.
        out_channels,                   # Number of output channels.
        w_dim,                          # Intermediate latent (W) dimensionality.
        resolution,                     # Resolution of this layer.
        kernel_size     = 3,            # Convolution kernel size.
        up              = 1,            # Integer upsampling factor.
        use_noise       = True,         # Enable noise input?
        activation      = 'lrelu',      # Activation function: 'relu', 'lrelu', etc.
        resample_filter = [1,3,3,1],    # Low-pass filter to apply when resampling activations.
        conv_clamp      = None,         # Clamp the output of convolution layers to +-X, None = disable clamping.
        channels_last   = False,        # Use channels_last format for the weights?
    ):
        super().__init__()
        self.resolution = resolution
        self.up = up
        self.use_noise = use_noise
        self.activation = activation
        self.conv_clamp = conv_clamp
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        self.padding = kernel_size // 2
        self.act_gain = bias_act.activation_funcs[activation].def_gain

        # Affine mapping from w to per-input-channel modulation styles.
        self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
        memory_format = torch.channels_last if channels_last else torch.contiguous_format
        self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
        if use_noise:
            # Fixed noise image for noise_mode='const', plus a learned strength.
            self.register_buffer('noise_const', torch.randn([resolution, resolution]))
            self.noise_strength = torch.nn.Parameter(torch.zeros([]))
        self.bias = torch.nn.Parameter(torch.zeros([out_channels]))

    def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1):
        """Run modulated conv on x using styles derived from w.

        Args:
            x:             Input feature maps, NCHW at resolution // up.
            w:             Intermediate latent vector for this layer.
            noise_mode:    'random', 'const', or 'none'.
            fused_modconv: Forwarded to modulated_conv2d.
            gain:          Extra multiplier for activation gain and clamp.
        """
        assert noise_mode in ['random', 'const', 'none']
        in_resolution = self.resolution // self.up
        misc.assert_shape(x, [None, self.weight.shape[1], in_resolution, in_resolution])
        styles = self.affine(w)

        noise = None
        if self.use_noise and noise_mode == 'random':
            noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength
        if self.use_noise and noise_mode == 'const':
            noise = self.noise_const * self.noise_strength

        flip_weight = (self.up == 1) # slightly faster
        x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up,
            padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight, fused_modconv=fused_modconv)

        # Bias + activation with optional clamping, both scaled by `gain`.
        act_gain = self.act_gain * gain
        act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
        x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp)
        return x
|
| 306 |
+
|
| 307 |
+
#----------------------------------------------------------------------------
|
| 308 |
+
|
| 309 |
+
@persistence.persistent_class
class ToRGBLayer(torch.nn.Module):
    """Projects feature maps to image color channels via a modulated conv (no demodulation)."""

    def __init__(self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False):
        super().__init__()
        self.conv_clamp = conv_clamp
        self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
        fmt = torch.channels_last if channels_last else torch.contiguous_format
        init_weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size])
        self.weight = torch.nn.Parameter(init_weight.to(memory_format=fmt))
        self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
        # Equalized-LR gain for the conv weights.
        self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))

    def forward(self, x, w, fused_modconv=True):
        # Fold the equalized-LR gain into the styles rather than the weights.
        style_vec = self.affine(w) * self.weight_gain
        out = modulated_conv2d(x=x, weight=self.weight, styles=style_vec,
                               demodulate=False, fused_modconv=fused_modconv)
        out = bias_act.bias_act(out, self.bias.to(out.dtype), clamp=self.conv_clamp)
        return out
|
| 325 |
+
|
| 326 |
+
#----------------------------------------------------------------------------
|
| 327 |
+
|
| 328 |
+
@persistence.persistent_class
class SynthesisBlock(torch.nn.Module):
    """One resolution level of the synthesis network.

    Holds up to two SynthesisLayers (conv0 upsamples when in_channels != 0),
    an optional ToRGB layer, and an optional resnet skip connection.
    """
    def __init__(self,
        in_channels,                        # Number of input channels, 0 = first block.
        out_channels,                       # Number of output channels.
        w_dim,                              # Intermediate latent (W) dimensionality.
        resolution,                         # Resolution of this block.
        img_channels,                       # Number of output color channels.
        is_last,                            # Is this the last block?
        architecture        = 'skip',       # Architecture: 'orig', 'skip', 'resnet'.
        resample_filter     = [1,3,3,1],    # Low-pass filter to apply when resampling activations.
        conv_clamp          = None,         # Clamp the output of convolution layers to +-X, None = disable clamping.
        use_fp16            = False,        # Use FP16 for this block?
        fp16_channels_last  = False,        # Use channels-last memory format with FP16?
        **layer_kwargs,                     # Arguments for SynthesisLayer.
    ):
        assert architecture in ['orig', 'skip', 'resnet']
        super().__init__()
        self.in_channels = in_channels
        self.w_dim = w_dim
        self.resolution = resolution
        self.img_channels = img_channels
        self.is_last = is_last
        self.architecture = architecture
        self.use_fp16 = use_fp16
        self.channels_last = (use_fp16 and fp16_channels_last)
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        # Counters used by SynthesisNetwork to slice the per-block w vectors.
        self.num_conv = 0
        self.num_torgb = 0

        # First block starts from a learned constant instead of an input tensor.
        if in_channels == 0:
            self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution]))

        if in_channels != 0:
            self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim, resolution=resolution, up=2,
                resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
            self.num_conv += 1

        self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim, resolution=resolution,
            conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
        self.num_conv += 1

        if is_last or architecture == 'skip':
            self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim,
                conv_clamp=conv_clamp, channels_last=self.channels_last)
            self.num_torgb += 1

        if in_channels != 0 and architecture == 'resnet':
            self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2,
                resample_filter=resample_filter, channels_last=self.channels_last)

    def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, **layer_kwargs):
        """Process one resolution level.

        Args:
            x:   Features from the previous block (None/ignored for the first block).
            img: Accumulated RGB image so far (skip architecture), or None.
            ws:  Per-block latents, shape [N, num_conv + num_torgb, w_dim].
        Returns:
            (x, img): features at this resolution, and the updated RGB image.
        """
        misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
        w_iter = iter(ws.unbind(dim=1))
        dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
        memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
        if fused_modconv is None:
            with misc.suppress_tracer_warnings(): # this value will be treated as a constant
                fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1)

        # Input.
        if self.in_channels == 0:
            x = self.const.to(dtype=dtype, memory_format=memory_format)
            x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
        else:
            misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2])
            x = x.to(dtype=dtype, memory_format=memory_format)

        # Main layers.
        if self.in_channels == 0:
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
        elif self.architecture == 'resnet':
            # gain=sqrt(0.5) on both paths keeps the residual sum variance-preserving.
            y = self.skip(x, gain=np.sqrt(0.5))
            x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
            x = y.add_(x)
        else:
            x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)

        # ToRGB: upsample accumulated image, then add this block's contribution.
        if img is not None:
            misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
            img = upfirdn2d.upsample2d(img, self.resample_filter)
        if self.is_last or self.architecture == 'skip':
            y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
            y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
            img = img.add_(y) if img is not None else y

        assert x.dtype == dtype
        assert img is None or img.dtype == torch.float32
        return x, img
|
| 420 |
+
|
| 421 |
+
#----------------------------------------------------------------------------
|
| 422 |
+
|
| 423 |
+
@persistence.persistent_class
class SynthesisNetwork(torch.nn.Module):
    """Stack of SynthesisBlocks from 4x4 up to the output resolution."""
    def __init__(self,
        w_dim,                      # Intermediate latent (W) dimensionality.
        img_resolution,             # Output image resolution.
        img_channels,               # Number of color channels.
        channel_base    = 32768,    # Overall multiplier for the number of channels.
        channel_max     = 512,      # Maximum number of channels in any layer.
        num_fp16_res    = 0,        # Use FP16 for the N highest resolutions.
        **block_kwargs,             # Arguments for SynthesisBlock.
    ):
        # Output resolution must be a power of two >= 4.
        assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
        super().__init__()
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.block_resolutions = [2 ** i for i in range(2, self.img_resolution_log2 + 1)]
        # Channel count halves as resolution doubles, capped at channel_max.
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        self.num_ws = 0
        for res in self.block_resolutions:
            in_channels = channels_dict[res // 2] if res > 4 else 0
            out_channels = channels_dict[res]
            use_fp16 = (res >= fp16_resolution)
            is_last = (res == self.img_resolution)
            block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res,
                img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs)
            self.num_ws += block.num_conv
            if is_last:
                self.num_ws += block.num_torgb
            setattr(self, f'b{res}', block)

    def forward(self, ws, **block_kwargs):
        """Synthesize an image from broadcast latents ws of shape [N, num_ws, w_dim]."""
        # Slice ws into per-block chunks. Each slice also covers the block's
        # torgb latent, but w_idx only advances by num_conv: the torgb slot is
        # shared with the first conv of the next block.
        block_ws = []
        with torch.autograd.profiler.record_function('split_ws'):
            misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
            ws = ws.to(torch.float32)
            w_idx = 0
            for res in self.block_resolutions:
                block = getattr(self, f'b{res}')
                block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb))
                w_idx += block.num_conv

        x = img = None
        for res, cur_ws in zip(self.block_resolutions, block_ws):
            block = getattr(self, f'b{res}')
            x, img = block(x, img, cur_ws, **block_kwargs)
        return img
|
| 473 |
+
|
| 474 |
+
#----------------------------------------------------------------------------
|
| 475 |
+
|
| 476 |
+
@persistence.persistent_class
class Generator(torch.nn.Module):
    """Full generator: a mapping network feeding a synthesis network."""

    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality.
        c_dim,                      # Conditioning label (C) dimensionality.
        w_dim,                      # Intermediate latent (W) dimensionality.
        img_resolution,             # Output resolution.
        img_channels,               # Number of output color channels.
        mapping_kwargs      = {},   # Arguments for MappingNetwork.
        synthesis_kwargs    = {},   # Arguments for SynthesisNetwork.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels
        # Build synthesis first: its num_ws determines the mapping broadcast width.
        self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs)
        self.num_ws = self.synthesis.num_ws
        self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)

    def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, **synthesis_kwargs):
        latents = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)
        return self.synthesis(latents, **synthesis_kwargs)
|
| 501 |
+
|
| 502 |
+
#----------------------------------------------------------------------------
|
| 503 |
+
|
| 504 |
+
@persistence.persistent_class
class DiscriminatorBlock(torch.nn.Module):
    """One resolution level of the discriminator: optional FromRGB, two convs,
    optional resnet skip. Supports Freeze-D via per-layer trainable flags.
    """
    def __init__(self,
        in_channels,                        # Number of input channels, 0 = first block.
        tmp_channels,                       # Number of intermediate channels.
        out_channels,                       # Number of output channels.
        resolution,                         # Resolution of this block.
        img_channels,                       # Number of input color channels.
        first_layer_idx,                    # Index of the first layer.
        architecture        = 'resnet',     # Architecture: 'orig', 'skip', 'resnet'.
        activation          = 'lrelu',      # Activation function: 'relu', 'lrelu', etc.
        resample_filter     = [1,3,3,1],    # Low-pass filter to apply when resampling activations.
        conv_clamp          = None,         # Clamp the output of convolution layers to +-X, None = disable clamping.
        use_fp16            = False,        # Use FP16 for this block?
        fp16_channels_last  = False,        # Use channels-last memory format with FP16?
        freeze_layers       = 0,            # Freeze-D: Number of layers to freeze.
    ):
        assert in_channels in [0, tmp_channels]
        assert architecture in ['orig', 'skip', 'resnet']
        super().__init__()
        self.in_channels = in_channels
        self.resolution = resolution
        self.img_channels = img_channels
        self.first_layer_idx = first_layer_idx
        self.architecture = architecture
        self.use_fp16 = use_fp16
        self.channels_last = (use_fp16 and fp16_channels_last)
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))

        # Generator yielding one trainable flag per layer, in construction
        # order; layers with global index < freeze_layers are frozen.
        # NOTE: also increments self.num_layers as a side effect.
        self.num_layers = 0
        def trainable_gen():
            while True:
                layer_idx = self.first_layer_idx + self.num_layers
                trainable = (layer_idx >= freeze_layers)
                self.num_layers += 1
                yield trainable
        trainable_iter = trainable_gen()

        if in_channels == 0 or architecture == 'skip':
            self.fromrgb = Conv2dLayer(img_channels, tmp_channels, kernel_size=1, activation=activation,
                trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last)

        self.conv0 = Conv2dLayer(tmp_channels, tmp_channels, kernel_size=3, activation=activation,
            trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last)

        # conv1 downsamples by 2x.
        self.conv1 = Conv2dLayer(tmp_channels, out_channels, kernel_size=3, activation=activation, down=2,
            trainable=next(trainable_iter), resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last)

        if architecture == 'resnet':
            self.skip = Conv2dLayer(tmp_channels, out_channels, kernel_size=1, bias=False, down=2,
                trainable=next(trainable_iter), resample_filter=resample_filter, channels_last=self.channels_last)

    def forward(self, x, img, force_fp32=False):
        """Process one resolution level.

        Args:
            x:   Features from the previous block, or None for the first block.
            img: RGB image at this resolution ('skip' architecture), else consumed once.
        Returns:
            (x, img): downsampled features, and downsampled image (skip) or None.
        """
        dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
        memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format

        # Input.
        if x is not None:
            misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution])
            x = x.to(dtype=dtype, memory_format=memory_format)

        # FromRGB.
        if self.in_channels == 0 or self.architecture == 'skip':
            misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
            img = img.to(dtype=dtype, memory_format=memory_format)
            y = self.fromrgb(img)
            x = x + y if x is not None else y
            img = upfirdn2d.downsample2d(img, self.resample_filter) if self.architecture == 'skip' else None

        # Main layers.
        if self.architecture == 'resnet':
            # gain=sqrt(0.5) on both paths keeps the residual sum variance-preserving.
            y = self.skip(x, gain=np.sqrt(0.5))
            x = self.conv0(x)
            x = self.conv1(x, gain=np.sqrt(0.5))
            x = y.add_(x)
        else:
            x = self.conv0(x)
            x = self.conv1(x)

        assert x.dtype == dtype
        return x, img
|
| 585 |
+
|
| 586 |
+
#----------------------------------------------------------------------------
|
| 587 |
+
|
| 588 |
+
@persistence.persistent_class
class MinibatchStdLayer(torch.nn.Module):
    """Appends per-group stddev statistics of the minibatch as extra channels."""

    def __init__(self, group_size, num_channels=1):
        super().__init__()
        self.group_size = group_size
        self.num_channels = num_channels

    def forward(self, x):
        N, C, H, W = x.shape
        with misc.suppress_tracer_warnings(): # as_tensor results are registered as constants
            G = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) if self.group_size is not None else N
        F = self.num_channels
        c = C // F

        # [GnFcHW]: split batch N into n groups of size G, channels C into F groups of c.
        stats = x.reshape(G, -1, F, c, H, W)
        stats = stats - stats.mean(dim=0)   # Center within each group.
        stats = stats.square().mean(dim=0)  # Variance over the group -> [nFcHW].
        stats = (stats + 1e-8).sqrt()       # Stddev.
        stats = stats.mean(dim=[2,3,4])     # Average over channels and pixels -> [nF].
        stats = stats.reshape(-1, F, 1, 1)  # [nF11] restore spatial dims.
        stats = stats.repeat(G, 1, H, W)    # [NFHW] replicate over group and pixels.
        return torch.cat([x, stats], dim=1) # [NCHW] append as new channels.
|
| 611 |
+
|
| 612 |
+
#----------------------------------------------------------------------------
|
| 613 |
+
|
| 614 |
+
@persistence.persistent_class
class DiscriminatorEpilogue(torch.nn.Module):
    """Final discriminator stage at 4x4: minibatch stddev, conv, FC head, and
    optional projection onto the mapped conditioning label.
    """
    def __init__(self,
        in_channels,                    # Number of input channels.
        cmap_dim,                       # Dimensionality of mapped conditioning label, 0 = no label.
        resolution,                     # Resolution of this block.
        img_channels,                   # Number of input color channels.
        architecture        = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
        mbstd_group_size    = 4,        # Group size for the minibatch standard deviation layer, None = entire minibatch.
        mbstd_num_channels  = 1,        # Number of features for the minibatch standard deviation layer, 0 = disable.
        activation          = 'lrelu',  # Activation function: 'relu', 'lrelu', etc.
        conv_clamp          = None,     # Clamp the output of convolution layers to +-X, None = disable clamping.
    ):
        assert architecture in ['orig', 'skip', 'resnet']
        super().__init__()
        self.in_channels = in_channels
        self.cmap_dim = cmap_dim
        self.resolution = resolution
        self.img_channels = img_channels
        self.architecture = architecture

        if architecture == 'skip':
            self.fromrgb = Conv2dLayer(img_channels, in_channels, kernel_size=1, activation=activation)
        self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None
        self.conv = Conv2dLayer(in_channels + mbstd_num_channels, in_channels, kernel_size=3, activation=activation, conv_clamp=conv_clamp)
        self.fc = FullyConnectedLayer(in_channels * (resolution ** 2), in_channels, activation=activation)
        # Output: 1 logit when unconditional, else cmap_dim features for projection.
        self.out = FullyConnectedLayer(in_channels, 1 if cmap_dim == 0 else cmap_dim)

    def forward(self, x, img, cmap, force_fp32=False):
        """Reduce final features (and optional label map) to per-sample scores.

        Args:
            x:    Feature maps at self.resolution, NCHW.
            img:  RGB image at self.resolution (used only for 'skip' architecture).
            cmap: Mapped conditioning label [N, cmap_dim], ignored when cmap_dim == 0.
        """
        misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution]) # [NCHW]
        _ = force_fp32 # unused
        # The epilogue always runs in FP32 regardless of earlier blocks.
        dtype = torch.float32
        memory_format = torch.contiguous_format

        # FromRGB.
        x = x.to(dtype=dtype, memory_format=memory_format)
        if self.architecture == 'skip':
            misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
            img = img.to(dtype=dtype, memory_format=memory_format)
            x = x + self.fromrgb(img)

        # Main layers.
        if self.mbstd is not None:
            x = self.mbstd(x)
        x = self.conv(x)
        x = self.fc(x.flatten(1))
        x = self.out(x)

        # Conditioning: projection discriminator — dot product with the label map.
        if self.cmap_dim > 0:
            misc.assert_shape(cmap, [None, self.cmap_dim])
            x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))

        assert x.dtype == dtype
        return x
|
| 669 |
+
|
| 670 |
+
#----------------------------------------------------------------------------
|
| 671 |
+
|
| 672 |
+
@persistence.persistent_class
class Discriminator(torch.nn.Module):
    """Full discriminator: a pyramid of DiscriminatorBlocks down to 4x4,
    an optional label-mapping network, and the DiscriminatorEpilogue.
    """
    def __init__(self,
        c_dim,                          # Conditioning label (C) dimensionality.
        img_resolution,                 # Input resolution.
        img_channels,                   # Number of input color channels.
        architecture        = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
        channel_base        = 32768,    # Overall multiplier for the number of channels.
        channel_max         = 512,      # Maximum number of channels in any layer.
        num_fp16_res        = 0,        # Use FP16 for the N highest resolutions.
        conv_clamp          = None,     # Clamp the output of convolution layers to +-X, None = disable clamping.
        cmap_dim            = None,     # Dimensionality of mapped conditioning label, None = default.
        block_kwargs        = {},       # Arguments for DiscriminatorBlock.
        mapping_kwargs      = {},       # Arguments for MappingNetwork.
        epilogue_kwargs     = {},       # Arguments for DiscriminatorEpilogue.
    ):
        super().__init__()
        self.c_dim = c_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        # Resolutions from img_resolution down to 8; the 4x4 stage is the epilogue.
        self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        if cmap_dim is None:
            cmap_dim = channels_dict[4]
        if c_dim == 0:
            cmap_dim = 0

        common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
        # Track the global layer index across blocks for Freeze-D.
        cur_layer_idx = 0
        for res in self.block_resolutions:
            in_channels = channels_dict[res] if res < img_resolution else 0
            tmp_channels = channels_dict[res]
            out_channels = channels_dict[res // 2]
            use_fp16 = (res >= fp16_resolution)
            block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
                first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
            setattr(self, f'b{res}', block)
            cur_layer_idx += block.num_layers
        if c_dim > 0:
            # Label-only mapping (z_dim=0) producing cmap for the projection head.
            self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
        self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)

    def forward(self, img, c, **block_kwargs):
        """Score images. Returns per-sample logits of shape [N, 1]."""
        x = None
        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            x, img = block(x, img, **block_kwargs)

        cmap = None
        if self.c_dim > 0:
            cmap = self.mapping(None, c)
        x = self.b4(x, img, cmap)
        return x
|
| 728 |
+
|
| 729 |
+
#----------------------------------------------------------------------------
|
training/training_loop.py
ADDED
|
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import time
|
| 11 |
+
import copy
|
| 12 |
+
import json
|
| 13 |
+
import pickle
|
| 14 |
+
import psutil
|
| 15 |
+
import PIL.Image
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torch
|
| 18 |
+
import dnnlib
|
| 19 |
+
from torch_utils import misc
|
| 20 |
+
from torch_utils import training_stats
|
| 21 |
+
from torch_utils.ops import conv2d_gradfix
|
| 22 |
+
from torch_utils.ops import grid_sample_gradfix
|
| 23 |
+
|
| 24 |
+
import legacy
|
| 25 |
+
from metrics import metric_main
|
| 26 |
+
|
| 27 |
+
#----------------------------------------------------------------------------
|
| 28 |
+
|
| 29 |
+
def setup_snapshot_image_grid(training_set, random_seed=0):
    """Select training samples for the periodic snapshot image grid.

    Returns ((gw, gh), images, labels) where images/labels are numpy stacks
    of gw*gh entries. Grid dimensions are chosen so the grid roughly fills
    an 8K canvas, clamped to [7,32] columns and [4,32] rows.
    """
    rnd = np.random.RandomState(random_seed)
    gw = np.clip(7680 // training_set.image_shape[2], 7, 32)
    gh = np.clip(4320 // training_set.image_shape[1], 4, 32)

    if not training_set.has_labels:
        # No labels => show random subset of training samples.
        order = list(range(len(training_set)))
        rnd.shuffle(order)
        grid_indices = [order[i % len(order)] for i in range(gw * gh)]
    else:
        # Group training samples by label.
        label_groups = {}  # label => [idx, ...]
        for idx in range(len(training_set)):
            key = tuple(training_set.get_details(idx).raw_label.flat[::-1])
            label_groups.setdefault(key, []).append(idx)

        # Reorder: deterministic label order, shuffled members within each group.
        label_order = sorted(label_groups)
        for key in label_order:
            rnd.shuffle(label_groups[key])

        # Organize into grid: one label per row, rotating each group by gw
        # so consecutive rows of the same label show different samples.
        grid_indices = []
        for y in range(gh):
            key = label_order[y % len(label_order)]
            members = label_groups[key]
            grid_indices.extend(members[x % len(members)] for x in range(gw))
            label_groups[key] = [members[(i + gw) % len(members)] for i in range(len(members))]

    # Load data.
    images, labels = zip(*[training_set[i] for i in grid_indices])
    return (gw, gh), np.stack(images), np.stack(labels)
|
| 65 |
+
|
| 66 |
+
#----------------------------------------------------------------------------
|
| 67 |
+
|
| 68 |
+
def save_image_grid(img, fname, drange, grid_size):
    """Tile a batch of images [N, C, H, W] into a single grid and save to fname.

    `drange` = (lo, hi) gives the value range of `img`, which is linearly
    mapped to uint8 [0, 255]. `grid_size` = (gw, gh) with N == gw * gh.
    Only 1-channel (greyscale) and 3-channel (RGB) images are supported.
    """
    lo, hi = drange
    arr = np.asarray(img, dtype=np.float32)
    arr = (arr - lo) * (255 / (hi - lo))
    arr = np.rint(arr).clip(0, 255).astype(np.uint8)

    # [N, C, H, W] -> [gh, gw, C, H, W] -> [gh, H, gw, W, C] -> [gh*H, gw*W, C]
    gw, gh = grid_size
    _N, C, H, W = arr.shape
    arr = arr.reshape(gh, gw, C, H, W)
    arr = arr.transpose(0, 3, 1, 4, 2)
    arr = arr.reshape(gh * H, gw * W, C)

    assert C in [1, 3]
    if C == 1:
        PIL.Image.fromarray(arr[:, :, 0], 'L').save(fname)
    if C == 3:
        PIL.Image.fromarray(arr, 'RGB').save(fname)
|
| 85 |
+
|
| 86 |
+
#----------------------------------------------------------------------------
|
| 87 |
+
|
| 88 |
+
def training_loop(
    run_dir                 = '.',      # Output directory.
    training_set_kwargs     = {},       # Options for training set.
    data_loader_kwargs      = {},       # Options for torch.utils.data.DataLoader.
    G_kwargs                = {},       # Options for generator network.
    D_kwargs                = {},       # Options for discriminator network.
    G_opt_kwargs            = {},       # Options for generator optimizer.
    D_opt_kwargs            = {},       # Options for discriminator optimizer.
    augment_kwargs          = None,     # Options for augmentation pipeline. None = disable.
    loss_kwargs             = {},       # Options for loss function.
    metrics                 = [],       # Metrics to evaluate during training.
    random_seed             = 0,        # Global random seed.
    num_gpus                = 1,        # Number of GPUs participating in the training.
    rank                    = 0,        # Rank of the current process in [0, num_gpus[.
    batch_size              = 4,        # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus.
    batch_gpu               = 4,        # Number of samples processed at a time by one GPU.
    ema_kimg                = 10,       # Half-life of the exponential moving average (EMA) of generator weights.
    ema_rampup              = None,     # EMA ramp-up coefficient.
    G_reg_interval          = 4,        # How often to perform regularization for G? None = disable lazy regularization.
    D_reg_interval          = 16,       # How often to perform regularization for D? None = disable lazy regularization.
    augment_p               = 0,        # Initial value of augmentation probability.
    ada_target              = None,     # ADA target value. None = fixed p.
    ada_interval            = 4,        # How often to perform ADA adjustment?
    ada_kimg                = 500,      # ADA adjustment speed, measured in how many kimg it takes for p to increase/decrease by one unit.
    total_kimg              = 25000,    # Total length of the training, measured in thousands of real images.
    kimg_per_tick           = 4,        # Progress snapshot interval.
    image_snapshot_ticks    = 50,       # How often to save image snapshots? None = disable.
    network_snapshot_ticks  = 50,       # How often to save network snapshots? None = disable.
    resume_pkl              = None,     # Network pickle to resume training from.
    cudnn_benchmark         = True,     # Enable torch.backends.cudnn.benchmark?
    allow_tf32              = False,    # Enable torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnn.allow_tf32?
    abort_fn                = None,     # Callback function for determining whether to abort training. Must return consistent results across ranks.
    progress_fn             = None,     # Callback function for updating training progress. Called for all ranks.
):
    """Main GAN training loop, executed once per process (one process per GPU).

    Constructs the training set, G/D networks, optimizers, loss and optional
    augmentation pipeline from the given kwargs dicts, then repeatedly runs
    the training phases until `total_kimg` thousand real images have been
    consumed (or `abort_fn` requests a stop). Rank 0 additionally prints
    status lines and writes sample-image grids, network snapshot pickles,
    metric results and JSONL/TensorBoard stats under `run_dir`.
    """
    # Initialize.
    start_time = time.time()
    device = torch.device('cuda', rank)
    # Per-rank seeding so each GPU draws independent random streams.
    np.random.seed(random_seed * num_gpus + rank)
    torch.manual_seed(random_seed * num_gpus + rank)
    torch.backends.cudnn.benchmark = cudnn_benchmark    # Improves training speed.
    torch.backends.cuda.matmul.allow_tf32 = allow_tf32  # Allow PyTorch to internally use tf32 for matmul
    torch.backends.cudnn.allow_tf32 = allow_tf32        # Allow PyTorch to internally use tf32 for convolutions
    conv2d_gradfix.enabled = True                       # Improves training speed.
    grid_sample_gradfix.enabled = True                  # Avoids errors with the augmentation pipe.

    # Load training set.
    if rank == 0:
        print('Loading training set...')
    training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs) # subclass of training.dataset.Dataset
    training_set_sampler = misc.InfiniteSampler(dataset=training_set, rank=rank, num_replicas=num_gpus, seed=random_seed)
    training_set_iterator = iter(torch.utils.data.DataLoader(dataset=training_set, sampler=training_set_sampler, batch_size=batch_size//num_gpus, **data_loader_kwargs))
    if rank == 0:
        print()
        print('Num images: ', len(training_set))
        print('Image shape:', training_set.image_shape)
        print('Label shape:', training_set.label_shape)
        print()

    # Construct networks.
    if rank == 0:
        print('Constructing networks...')
    common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution, img_channels=training_set.num_channels)
    # Gradients are enabled selectively per phase below, hence requires_grad_(False) here.
    G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
    D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
    G_ema = copy.deepcopy(G).eval()  # Exponential moving average of G, updated after each batch.

    # Resume from existing pickle.
    # NOTE(review): only rank 0 loads the pickle here; other ranks presumably
    # receive the weights via DDP broadcast at construction — confirm upstream.
    if (resume_pkl is not None) and (rank == 0):
        print(f'Resuming from "{resume_pkl}"')
        with dnnlib.util.open_url(resume_pkl) as f:
            resume_data = legacy.load_network_pkl(f)
        for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]:
            misc.copy_params_and_buffers(resume_data[name], module, require_all=False)

    # Print network summary tables.
    if rank == 0:
        z = torch.empty([batch_gpu, G.z_dim], device=device)
        c = torch.empty([batch_gpu, G.c_dim], device=device)
        img = misc.print_module_summary(G, [z, c])
        misc.print_module_summary(D, [img, c])

    # Setup augmentation.
    if rank == 0:
        print('Setting up augmentation...')
    augment_pipe = None
    ada_stats = None
    if (augment_kwargs is not None) and (augment_p > 0 or ada_target is not None):
        augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
        augment_pipe.p.copy_(torch.as_tensor(augment_p))
        if ada_target is not None:
            # Track the sign of D's output on reals to drive the ADA heuristic below.
            ada_stats = training_stats.Collector(regex='Loss/signs/real')

    # Distribute across GPUs.
    if rank == 0:
        print(f'Distributing across {num_gpus} GPUs...')
    ddp_modules = dict()
    for name, module in [('G_mapping', G.mapping), ('G_synthesis', G.synthesis), ('D', D), (None, G_ema), ('augment_pipe', augment_pipe)]:
        if (num_gpus > 1) and (module is not None) and len(list(module.parameters())) != 0:
            # DDP requires grads enabled at wrap time; restored to False right after.
            module.requires_grad_(True)
            module = torch.nn.parallel.DistributedDataParallel(module, device_ids=[device], broadcast_buffers=False)
            module.requires_grad_(False)
        if name is not None:  # name=None (G_ema) is wrapped but not passed to the loss.
            ddp_modules[name] = module

    # Setup training phases.
    if rank == 0:
        print('Setting up training phases...')
    loss = dnnlib.util.construct_class_by_name(device=device, **ddp_modules, **loss_kwargs) # subclass of training.loss.Loss
    phases = []
    for name, module, opt_kwargs, reg_interval in [('G', G, G_opt_kwargs, G_reg_interval), ('D', D, D_opt_kwargs, D_reg_interval)]:
        if reg_interval is None:
            # Main loss and regularization computed together every minibatch.
            opt = dnnlib.util.construct_class_by_name(params=module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer
            phases += [dnnlib.EasyDict(name=name+'both', module=module, opt=opt, interval=1)]
        else: # Lazy regularization.
            # Rescale lr/betas to compensate for the reg phase running only
            # every reg_interval minibatches (shared optimizer for both phases).
            mb_ratio = reg_interval / (reg_interval + 1)
            opt_kwargs = dnnlib.EasyDict(opt_kwargs)
            opt_kwargs.lr = opt_kwargs.lr * mb_ratio
            opt_kwargs.betas = [beta ** mb_ratio for beta in opt_kwargs.betas]
            opt = dnnlib.util.construct_class_by_name(module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer
            phases += [dnnlib.EasyDict(name=name+'main', module=module, opt=opt, interval=1)]
            phases += [dnnlib.EasyDict(name=name+'reg', module=module, opt=opt, interval=reg_interval)]
    for phase in phases:
        # CUDA events for per-phase GPU timing (rank 0 only).
        phase.start_event = None
        phase.end_event = None
        if rank == 0:
            phase.start_event = torch.cuda.Event(enable_timing=True)
            phase.end_event = torch.cuda.Event(enable_timing=True)

    # Export sample images.
    grid_size = None
    grid_z = None
    grid_c = None
    if rank == 0:
        print('Exporting sample images...')
        grid_size, images, labels = setup_snapshot_image_grid(training_set=training_set)
        save_image_grid(images, os.path.join(run_dir, 'reals.png'), drange=[0,255], grid_size=grid_size)
        # Fixed latents/labels reused for every image snapshot during training.
        grid_z = torch.randn([labels.shape[0], G.z_dim], device=device).split(batch_gpu)
        grid_c = torch.from_numpy(labels).to(device).split(batch_gpu)
        images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
        save_image_grid(images, os.path.join(run_dir, 'fakes_init.png'), drange=[-1,1], grid_size=grid_size)

    # Initialize logs.
    if rank == 0:
        print('Initializing logs...')
    stats_collector = training_stats.Collector(regex='.*')
    stats_metrics = dict()
    stats_jsonl = None
    stats_tfevents = None
    if rank == 0:
        stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'wt')
        try:
            import torch.utils.tensorboard as tensorboard
            stats_tfevents = tensorboard.SummaryWriter(run_dir)
        except ImportError as err:
            # TensorBoard is optional; training proceeds without tfevents.
            print('Skipping tfevents export:', err)

    # Train.
    if rank == 0:
        print(f'Training for {total_kimg} kimg...')
        print()
    cur_nimg = 0                # Total real images shown so far (all GPUs).
    cur_tick = 0                # Maintenance-tick counter.
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    maintenance_time = tick_start_time - start_time
    batch_idx = 0
    if progress_fn is not None:
        progress_fn(0, total_kimg)
    while True:

        # Fetch training data.
        with torch.autograd.profiler.record_function('data_fetch'):
            phase_real_img, phase_real_c = next(training_set_iterator)
            # uint8 [0,255] -> float32 [-1,1], split into per-GPU microbatches.
            phase_real_img = (phase_real_img.to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu)
            phase_real_c = phase_real_c.to(device).split(batch_gpu)
            all_gen_z = torch.randn([len(phases) * batch_size, G.z_dim], device=device)
            all_gen_z = [phase_gen_z.split(batch_gpu) for phase_gen_z in all_gen_z.split(batch_size)]
            all_gen_c = [training_set.get_label(np.random.randint(len(training_set))) for _ in range(len(phases) * batch_size)]
            all_gen_c = torch.from_numpy(np.stack(all_gen_c)).pin_memory().to(device)
            all_gen_c = [phase_gen_c.split(batch_gpu) for phase_gen_c in all_gen_c.split(batch_size)]

        # Execute training phases.
        for phase, phase_gen_z, phase_gen_c in zip(phases, all_gen_z, all_gen_c):
            if batch_idx % phase.interval != 0:
                continue  # Lazy-regularization phases run only every `interval` batches.

            # Initialize gradient accumulation.
            if phase.start_event is not None:
                phase.start_event.record(torch.cuda.current_stream(device))
            phase.opt.zero_grad(set_to_none=True)
            phase.module.requires_grad_(True)

            # Accumulate gradients over multiple rounds.
            for round_idx, (real_img, real_c, gen_z, gen_c) in enumerate(zip(phase_real_img, phase_real_c, phase_gen_z, phase_gen_c)):
                # sync=True only on the last round — presumably this is what
                # triggers the cross-GPU gradient sync inside the loss; confirm
                # against training.loss.
                sync = (round_idx == batch_size // (batch_gpu * num_gpus) - 1)
                gain = phase.interval  # Scale loss to keep lazy-reg magnitude consistent.
                loss.accumulate_gradients(phase=phase.name, real_img=real_img, real_c=real_c, gen_z=gen_z, gen_c=gen_c, sync=sync, gain=gain)

            # Update weights.
            phase.module.requires_grad_(False)
            with torch.autograd.profiler.record_function(phase.name + '_opt'):
                for param in phase.module.parameters():
                    if param.grad is not None:
                        # Sanitize NaN/Inf gradients so a single bad batch cannot poison the weights.
                        misc.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
                phase.opt.step()
            if phase.end_event is not None:
                phase.end_event.record(torch.cuda.current_stream(device))

        # Update G_ema.
        with torch.autograd.profiler.record_function('Gema'):
            ema_nimg = ema_kimg * 1000
            if ema_rampup is not None:
                ema_nimg = min(ema_nimg, cur_nimg * ema_rampup)
            # Per-batch decay factor giving a half-life of ema_nimg images.
            ema_beta = 0.5 ** (batch_size / max(ema_nimg, 1e-8))
            for p_ema, p in zip(G_ema.parameters(), G.parameters()):
                p_ema.copy_(p.lerp(p_ema, ema_beta))
            for b_ema, b in zip(G_ema.buffers(), G.buffers()):
                b_ema.copy_(b)  # Buffers are copied verbatim, not averaged.

        # Update state.
        cur_nimg += batch_size
        batch_idx += 1

        # Execute ADA heuristic: nudge augmentation probability p up when the
        # real-sign statistic overshoots ada_target, down when it undershoots,
        # clamping at 0 from below.
        if (ada_stats is not None) and (batch_idx % ada_interval == 0):
            ada_stats.update()
            adjust = np.sign(ada_stats['Loss/signs/real'] - ada_target) * (batch_size * ada_interval) / (ada_kimg * 1000)
            augment_pipe.p.copy_((augment_pipe.p + adjust).max(misc.constant(0, device=device)))

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
            continue

        # Print status line, accumulating the same information in stats_collector.
        tick_end_time = time.time()
        fields = []
        fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"]
        fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}"]
        fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"]
        fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"]
        fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"]
        fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"]
        fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"]
        fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"]
        torch.cuda.reset_peak_memory_stats()
        fields += [f"augment {training_stats.report0('Progress/augment', float(augment_pipe.p.cpu()) if augment_pipe is not None else 0):.3f}"]
        training_stats.report0('Timing/total_hours', (tick_end_time - start_time) / (60 * 60))
        training_stats.report0('Timing/total_days', (tick_end_time - start_time) / (24 * 60 * 60))
        if rank == 0:
            print(' '.join(fields))

        # Check for abort.
        if (not done) and (abort_fn is not None) and abort_fn():
            done = True
            if rank == 0:
                print()
                print('Aborting...')

        # Save image snapshot.
        if (rank == 0) and (image_snapshot_ticks is not None) and (done or cur_tick % image_snapshot_ticks == 0):
            images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
            save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.png'), drange=[-1,1], grid_size=grid_size)

        # Save network snapshot.
        snapshot_pkl = None
        snapshot_data = None
        if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0):
            snapshot_data = dict(training_set_kwargs=dict(training_set_kwargs))
            for name, module in [('G', G), ('D', D), ('G_ema', G_ema), ('augment_pipe', augment_pipe)]:
                if module is not None:
                    if num_gpus > 1:
                        misc.check_ddp_consistency(module, ignore_regex=r'.*\.w_avg')
                    # Snapshot a CPU copy so the pickle is device-independent.
                    module = copy.deepcopy(module).eval().requires_grad_(False).cpu()
                snapshot_data[name] = module
                del module # conserve memory
            snapshot_pkl = os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl')
            if rank == 0:
                with open(snapshot_pkl, 'wb') as f:
                    pickle.dump(snapshot_data, f)

        # Evaluate metrics.
        if (snapshot_data is not None) and (len(metrics) > 0):
            if rank == 0:
                print('Evaluating metrics...')
            for metric in metrics:
                result_dict = metric_main.calc_metric(metric=metric, G=snapshot_data['G_ema'],
                    dataset_kwargs=training_set_kwargs, num_gpus=num_gpus, rank=rank, device=device)
                if rank == 0:
                    metric_main.report_metric(result_dict, run_dir=run_dir, snapshot_pkl=snapshot_pkl)
                stats_metrics.update(result_dict.results)
        del snapshot_data # conserve memory

        # Collect statistics.
        for phase in phases:
            value = []
            if (phase.start_event is not None) and (phase.end_event is not None):
                phase.end_event.synchronize()
                value = phase.start_event.elapsed_time(phase.end_event)
            training_stats.report0('Timing/' + phase.name, value)
        stats_collector.update()
        stats_dict = stats_collector.as_dict()

        # Update logs.
        timestamp = time.time()
        if stats_jsonl is not None:
            fields = dict(stats_dict, timestamp=timestamp)
            stats_jsonl.write(json.dumps(fields) + '\n')
            stats_jsonl.flush()
        if stats_tfevents is not None:
            global_step = int(cur_nimg / 1e3)
            walltime = timestamp - start_time
            for name, value in stats_dict.items():
                stats_tfevents.add_scalar(name, value.mean, global_step=global_step, walltime=walltime)
            for name, value in stats_metrics.items():
                stats_tfevents.add_scalar(f'Metrics/{name}', value, global_step=global_step, walltime=walltime)
            stats_tfevents.flush()
        if progress_fn is not None:
            progress_fn(cur_nimg // 1000, total_kimg)

        # Update state.
        cur_tick += 1
        tick_start_nimg = cur_nimg
        tick_start_time = time.time()
        maintenance_time = tick_start_time - tick_end_time
        if done:
            break

    # Done.
    if rank == 0:
        print()
        print('Exiting...')
|
| 420 |
+
|
| 421 |
+
#----------------------------------------------------------------------------
|