diff --git a/README.md b/README.md
index 154df8298fab5ecf322016157858e08cd1bccbe1..8b6dcd8543131edd1dfa1348dd7286c176ab9fd3 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,172 @@
----
-license: apache-2.0
----
+## Diffusion-GAN — Official PyTorch implementation
+
+**Diffusion-GAN: Training GANs with Diffusion**
+Zhendong Wang, Huangjie Zheng, Pengcheng He, Weizhu Chen and Mingyuan Zhou
+https://arxiv.org/abs/2206.02262
+
+Abstract: *For stable training of generative adversarial networks (GANs), injecting instance
+noise into the input of the discriminator is considered as a theoretically sound
+solution, which, however, has not yet delivered on its promise in practice. This
+paper introduces Diffusion-GAN that employs a Gaussian mixture distribution,
+defined over all the diffusion steps of a forward diffusion chain, to inject instance
+noise. A random sample from the mixture, which is diffused from an observed
+or generated data, is fed as the input to the discriminator. The generator is
+updated by backpropagating its gradient through the forward diffusion chain,
+whose length is adaptively adjusted to control the maximum noise-to-data ratio
+allowed at each training step. Theoretical analysis verifies the soundness of the
+proposed Diffusion-GAN, which provides model- and domain-agnostic differentiable
+augmentation. A rich set of experiments on diverse datasets show that DiffusionGAN can
+provide stable and data-efficient GAN training, bringing consistent
+performance improvement over strong GAN baselines for synthesizing photorealistic images.*
+
+[](https://paperswithcode.com/sota/image-generation-on-celeba-64x64?p=diffusion-gan-training-gans-with-diffusion)
+[](https://paperswithcode.com/sota/image-generation-on-stl-10?p=diffusion-gan-training-gans-with-diffusion)
+[](https://paperswithcode.com/sota/image-generation-on-lsun-bedroom-256-x-256?p=diffusion-gan-training-gans-with-diffusion)
+[](https://paperswithcode.com/sota/image-generation-on-afhq-wild?p=diffusion-gan-training-gans-with-diffusion)
+[](https://paperswithcode.com/sota/image-generation-on-afhq-cat?p=diffusion-gan-training-gans-with-diffusion)
+[](https://paperswithcode.com/sota/image-generation-on-afhq-dog?p=diffusion-gan-training-gans-with-diffusion)
+[](https://paperswithcode.com/sota/image-generation-on-lsun-churches-256-x-256?p=diffusion-gan-training-gans-with-diffusion)
+[](https://paperswithcode.com/sota/image-generation-on-ffhq-1024-x-1024?p=diffusion-gan-training-gans-with-diffusion)
+
+## ToDos
+- [x] Initial code release
+- [x] Providing pretrained models
+
+## Build your Diffusion-GAN
+Here, we explain how to train general GANs with diffusion. We provide two ways:
+a. plug-in as simple as a data augmentation method;
+b. training GANs on diffusion chains with a timestep-dependent discriminator.
+Currently, we didn't find significant empirical differences of the two approaches,
+while the second approach has stronger theoretical guarantees. We suspect when advanced timestep-dependent structure is applied in the discriminator,
+the second approach could become better, and we left that for future study.
+
+### Simple Plug-in
+* Design a proper diffusion process based on the ```diffusion.py``` file
+* Apply diffusion on the inputs of discriminators,
+```logits = Discriminator(Diffusion(gen/real_images))```
+* Add adaptiveness of diffusion into your training iterations
+```
+if update_diffusion: # batch_idx % ada_interval == 0
+ adjust = np.sign(sign(Discriminator(real_images)) - ada_target) * C # C = (batch_size * ada_interval) / (ada_kimg * 1000)
+ diffusion.p = (diffusion.p + adjust).clip(min=0., max=1.)
+ diffusion.update_T()
+```
+
+### Full Version
+* Add diffusion timestep `t` as an input for discriminators `logits = Discriminator(images, t)`.
+You may need some modifications in your discriminator architecture.
+* The other steps are the same as Simple Plug-in. Note that since discriminator depends on timesteps,
+you need to collect `t`.
+```
+diffused_images, t = Diffusion(images)
+logits = Discrimnator(diffused_images, t)
+```
+
+## Train our Diffusion-GAN
+
+### Requirements
+* 64-bit Python 3.7 and PyTorch 1.7.1/1.8.1. See [https://pytorch.org/](https://pytorch.org/) for PyTorch install instructions.
+* CUDA toolkit 11.0 or later.
+* Python libraries: `pip install click requests tqdm pyspng ninja imageio-ffmpeg==0.4.3`.
+
+### Data Preparation
+
+In our paper, we trained our model on [CIFAR-10 (32 x 32)](https://www.cs.toronto.edu/~kriz/cifar.html), [STL-10 (64 x 64)](https://cs.stanford.edu/~acoates/stl10/),
+[LSUN (256 x 256)](https://github.com/fyu/lsun), [AFHQ (512 x 512)](https://github.com/clovaai/stargan-v2) and [FFHQ (1024 x 1024)](https://github.com/NVlabs/ffhq-dataset).
+You can download the datasets we used in our paper at their respective websites.
+To prepare the dataset at the respective resolution, run for example
+```.bash
+python dataset_tool.py --source=~/downloads/lsun/raw/bedroom_lmdb --dest=~/datasets/lsun_bedroom200k.zip \
+ --transform=center-crop --width=256 --height=256 --max_images=200000
+
+python dataset_tool.py --source=~/downloads/lsun/raw/church_lmdb --dest=~/datasets/lsun_church200k.zip \
+ --transform=center-crop-wide --width=256 --height=256 --max_images=200000
+```
+
+### Training
+
+We show the training commands that we used below. In most cases, the training commands are similar, so below we use CIFAR-10 dataset
+as an example:
+
+For Diffusion-GAN,
+```.bash
+python train.py --outdir=training-runs --data="~/cifar10.zip" --gpus=4 --cfg cifar --kimg 50000 --aug no --target 0.6 --noise_sd 0.05 --ts_dist priority
+```
+For Diffusion-ProjectedGAN
+```.bash
+python train.py --outdir=training-runs --data="~/cifar10.zip" --gpus=4 --batch 64 --batch-gpu=16 --cfg fastgan --kimg 50000 --target 0.45 --d_pos first --noise_sd 0.5
+```
+For Diffusion-InsGen
+```.bash
+python train.py --outdir=training-runs --data="~/afhq-wild.zip" --gpus=8 --cfg paper512 --kimg 25000
+```
+
+We follows the `config` setting from [StyleGAN2-ADA](https://github.com/NVlabs/stylegan2-ada-pytorchhttps://github.com/NVlabs/stylegan2-ada-pytorch)
+and refer to them for more details. The other major hyperparameters are listed and discussed below:
+* `--target` the discriminator target, which balances the level of diffusion intensity.
+* `--aug` domain-specific image augmentation, such as ADA and Differentiable Augmentation, which is used for evaluate complementariness with diffusion.
+* `--noise_sd` diffusion noise standard deviation, which is set as 0.05 in our case.
+* ` --ts_dist` t sampling distribution, $\pi(t)$ in paper.
+
+We evaluated two `t` sampling distribution `['priority', 'uniform']`,
+where `'priority'` denotes the Equation (11) in paper and `'uniform'` denotes random sampling. In most cases, `priority` works slightly better, while in some cases, such as FFHQ,
+`'uniform'` is better.
+
+## Sampling and Evaluation with our checkpoints
+We report the FIDs of our Diffusion-GAN below and provide the trained checkpoints in the ``./checkpoints`` folder:
+
+| Model | Dataset | Resolution | FID |
+|:---------------------------:|:------------:|:----------:|:-----:|
+| Diffusion-StyleGAN2 | CIFAR-10 | 32x32 | 3.19 |
+| Diffusion-StyleGAN2 | CelebA | 64x64 | 1.69 |
+| Diffusion-StyleGAN2 | STL-10 | 64x64 | 11.53 |
+| Diffusion-StyleGAN2 | LSUN-Bedroom | 256x256 | 3.65 |
+| Diffusion-StyleGAN2 | LSUN-Church | 256x256 | 3.17 |
+| Diffusion-StyleGAN2 | FFHQ | 1024x1024 | 2.83 |
+| Diffusion-ProjectedGAN | CIFAR-10 | 32x32 | 2.54 |
+| Diffusion-ProjectedGAN | STL-10 | 64x64 | 6.91 |
+| Diffusion-ProjectedGAN | LSUN-Bedroom | 256x256 | 1.43 |
+| Diffusion-ProjectedGAN | LSUN-Church | 256x256 | 1.85 |
+| Diffusion-InsGen | AFHQ-Cat | 512x512 | 2.40 |
+| Diffusion-InsGen | AFHQ-Dog | 512x512 | 4.83 |
+| Diffusion-InsGen | AFHQ-Wild | 512x512 | 1.51 |
+
+
+To generate samples, run the following commands:
+
+```.bash
+# Generate FFHQ with pretrained Diffusion-StyleGAN2
+python generate.py --outdir=out --seeds=1-100 \
+ --network=https://tsciencescu.blob.core.windows.net/projectshzheng/DiffusionGAN/diffusion-stylegan2-ffhq.pkl
+
+# Generate LSUN-Church with pretrained Diffusion-ProjectedGAN
+python gen_images.py --outdir=out --seeds=1-100 \
+ --network=https://tsciencescu.blob.core.windows.net/projectshzheng/DiffusionGAN/diffusion-projectedgan-lsun-church.pkl
+```
+
+The checkpoints can be replaced with any pre-trained Diffusion-GAN checkpoint path downloaded from the table above.
+
+
+Similarly, the metrics can be calculated with the following commands:
+
+```.bash
+# Pre-trained network pickle: specify dataset explicitly, print result to stdout.
+python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq.zip --mirror=1 \
+ --network=https://tsciencescu.blob.core.windows.net/projectshzheng/DiffusionGAN/diffusion-stylegan2-ffhq.pkl
+```
+
+## Citation
+
+```
+@article{wang2022diffusiongan,
+ title = {Diffusion-GAN: Training GANs with Diffusion},
+ author = {Wang, Zhendong and Zheng, Huangjie and He, Pengcheng and Chen, Weizhu and Zhou, Mingyuan},
+ journal = {arXiv preprint arXiv:2206.02262},
+ year = {2022},
+ url = {https://arxiv.org/abs/2206.02262}
+}
+```
+
+## Acknowledgements
+
+Our code builds upon the awesome [StyleGAN2-ADA repo](https://github.com/NVlabs/stylegan2-ada-pytorch), [InsGen repo](https://github.com/genforce/insgen) and [ProjectedGAN repo](https://github.com/autonomousvision/projected_gan), respectively by Karras et al, Ceyuan Yang et al and Axel Sauer et al.
diff --git a/checkpoints/diffusion-insgen-afhqcat.pkl b/checkpoints/diffusion-insgen-afhqcat.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..7ef4fd8d456a1b8d04062207c3e1b8caf82a348b
--- /dev/null
+++ b/checkpoints/diffusion-insgen-afhqcat.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1c92c46b87bbaafc8fb914fb781b1c315d70a1bba99b99b54af7541e3669ca2f
+size 365039489
diff --git a/checkpoints/diffusion-insgen-afhqdog.pkl b/checkpoints/diffusion-insgen-afhqdog.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..d2de2de4e30b4aaba09228e293952b4b9215db37
--- /dev/null
+++ b/checkpoints/diffusion-insgen-afhqdog.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f0b1617c0af01a89795654337ca0b2510598c9f0c507760a9ddca63599f42039
+size 365039489
diff --git a/checkpoints/diffusion-insgen-afhqwild.pkl b/checkpoints/diffusion-insgen-afhqwild.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..a0237c5c1409b19803ebd64ce2c77a4bf2ddbd18
--- /dev/null
+++ b/checkpoints/diffusion-insgen-afhqwild.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7efc94c615be9cf76f3cde438bc8e832e397d421a2bbeb40e71918efd60a8e65
+size 365039490
diff --git a/checkpoints/diffusion-projectedgan-cifar10.pkl b/checkpoints/diffusion-projectedgan-cifar10.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..396d3813236df54559eebfa73fc8915974e9475d
--- /dev/null
+++ b/checkpoints/diffusion-projectedgan-cifar10.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3406ca783404806d7a8ee1b1daf9cf7936f143e94a2fa4a54057ed8c662679e0
+size 1788846251
diff --git a/checkpoints/diffusion-projectedgan-lsun-bedroom.pkl b/checkpoints/diffusion-projectedgan-lsun-bedroom.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..0e2db3acaecee0f637ba7287bef7b4dff020bef6
--- /dev/null
+++ b/checkpoints/diffusion-projectedgan-lsun-bedroom.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:98aaa6cbfc5cd115fc0afab4bfb8507f0bc0289b4422024f60ab321ae94f5938
+size 1788705080
diff --git a/checkpoints/diffusion-projectedgan-lsun-church.pkl b/checkpoints/diffusion-projectedgan-lsun-church.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..cb178d8afffcadca98dc5615001b235626b55611
--- /dev/null
+++ b/checkpoints/diffusion-projectedgan-lsun-church.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:573cfb3eefdd78ea1e41dbc2c03effca8f43cb8a794af3b29227894d8a9a0c83
+size 1788704999
diff --git a/checkpoints/diffusion-projectedgan-stl10.pkl b/checkpoints/diffusion-projectedgan-stl10.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..14a3f09146a0655c0a62715409fe316c6be7874b
--- /dev/null
+++ b/checkpoints/diffusion-projectedgan-stl10.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:902cbf4b2282dfa8a2ea5a326a08a6001a93eb766e3f7dbfe1ce55e9109b6d7e
+size 1788846259
diff --git a/checkpoints/diffusion-stylegan2-celeba64.pkl b/checkpoints/diffusion-stylegan2-celeba64.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..716a49f226071d23f53287d72045bc8280649968
--- /dev/null
+++ b/checkpoints/diffusion-stylegan2-celeba64.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:921e72e290870affb879bc90fe5334e8cb6d5f90ff486e4d6b540036a5606745
+size 319333518
diff --git a/checkpoints/diffusion-stylegan2-cifar10.pkl b/checkpoints/diffusion-stylegan2-cifar10.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..37efe58dd187875bd4ccc0ad53ddb060da9f27fa
--- /dev/null
+++ b/checkpoints/diffusion-stylegan2-cifar10.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7b828b7cb95c13688256497f7789ecb8f9dc556df32aa9be8815f9ab0e0ffe6a
+size 252092418
diff --git a/checkpoints/diffusion-stylegan2-ffhq.pkl b/checkpoints/diffusion-stylegan2-ffhq.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..2e9cbc2a688d2652b30da52f6a27e0ecc73fd148
--- /dev/null
+++ b/checkpoints/diffusion-stylegan2-ffhq.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8d425b3b85dbd7b79bdde5e8366a8da4dd6bd8bac0e5a1bb7dd543d86ced685a
+size 391116089
diff --git a/checkpoints/diffusion-stylegan2-lsun-bedroom.pkl b/checkpoints/diffusion-stylegan2-lsun-bedroom.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..3dc45dde2eb14525f84aecf3a4d38e1d92d908a2
--- /dev/null
+++ b/checkpoints/diffusion-stylegan2-lsun-bedroom.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:86805c3573922b686718957aac81729fc8e181d52796f6727b48719e87bbd7e0
+size 305245901
diff --git a/checkpoints/diffusion-stylegan2-lsun-church.pkl b/checkpoints/diffusion-stylegan2-lsun-church.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..3d7e4ff956055068b8e4038c6fbd785fd7cbffd5
--- /dev/null
+++ b/checkpoints/diffusion-stylegan2-lsun-church.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:90dbb7f40beeeac921764a283473e97f3ce06d984b4e394bbdc898f30e3ddf9c
+size 305242727
diff --git a/checkpoints/diffusion-stylegan2-stl10.pkl b/checkpoints/diffusion-stylegan2-stl10.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..9e50066cbcf1d197e444cf6651aa4e34de6eaab6
--- /dev/null
+++ b/checkpoints/diffusion-stylegan2-stl10.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:490017d05b39f3249c5a771bc1035bb2a697b0dba9adc820a081a51cf5fad0e1
+size 319325822
diff --git a/diffusion-insgen/calc_metrics.py b/diffusion-insgen/calc_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..03e828195a096f6f78da241b700c16f56327bdb8
--- /dev/null
+++ b/diffusion-insgen/calc_metrics.py
@@ -0,0 +1,190 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Calculate quality metrics for previous training run or pretrained network pickle."""
+
+import os
+import click
+import json
+import tempfile
+import copy
+import torch
+import dnnlib
+
+import legacy
+from metrics import metric_main
+from metrics import metric_utils
+from torch_utils import training_stats
+from torch_utils import custom_ops
+from torch_utils import misc
+
+#----------------------------------------------------------------------------
+
+def subprocess_fn(rank, args, temp_dir):
+ dnnlib.util.Logger(should_flush=True)
+
+ # Init torch.distributed.
+ if args.num_gpus > 1:
+ init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
+ if os.name == 'nt':
+ init_method = 'file:///' + init_file.replace('\\', '/')
+ torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus)
+ else:
+ init_method = f'file://{init_file}'
+ torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus)
+
+ # Init torch_utils.
+ sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None
+ training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
+ if rank != 0 or not args.verbose:
+ custom_ops.verbosity = 'none'
+
+ # Print network summary.
+ device = torch.device('cuda', rank)
+ torch.backends.cudnn.benchmark = True
+ torch.backends.cuda.matmul.allow_tf32 = False
+ torch.backends.cudnn.allow_tf32 = False
+ G = copy.deepcopy(args.G).eval().requires_grad_(False).to(device)
+ if rank == 0 and args.verbose:
+ z = torch.empty([1, G.z_dim], device=device)
+ c = torch.empty([1, G.c_dim], device=device)
+ misc.print_module_summary(G, [z, c])
+
+ # Calculate each metric.
+ for metric in args.metrics:
+ if rank == 0 and args.verbose:
+ print(f'Calculating {metric}...')
+ progress = metric_utils.ProgressMonitor(verbose=args.verbose)
+ result_dict = metric_main.calc_metric(metric=metric, G=G, dataset_kwargs=args.dataset_kwargs,
+ num_gpus=args.num_gpus, rank=rank, device=device, progress=progress)
+ if rank == 0:
+ metric_main.report_metric(result_dict, run_dir=args.run_dir, snapshot_pkl=args.network_pkl)
+ if rank == 0 and args.verbose:
+ print()
+
+ # Done.
+ if rank == 0 and args.verbose:
+ print('Exiting...')
+
+#----------------------------------------------------------------------------
+
+class CommaSeparatedList(click.ParamType):
+ name = 'list'
+
+ def convert(self, value, param, ctx):
+ _ = param, ctx
+ if value is None or value.lower() == 'none' or value == '':
+ return []
+ return value.split(',')
+
+#----------------------------------------------------------------------------
+
+@click.command()
+@click.pass_context
+@click.option('network_pkl', '--network', help='Network pickle filename or URL', metavar='PATH', required=True)
+@click.option('--metrics', help='Comma-separated list or "none"', type=CommaSeparatedList(), default='fid50k_full', show_default=True)
+@click.option('--data', help='Dataset to evaluate metrics against (directory or zip) [default: same as training data]', metavar='PATH')
+@click.option('--mirror', help='Whether the dataset was augmented with x-flips during training [default: look up]', type=bool, metavar='BOOL')
+@click.option('--gpus', help='Number of GPUs to use', type=int, default=1, metavar='INT', show_default=True)
+@click.option('--verbose', help='Print optional information', type=bool, default=True, metavar='BOOL', show_default=True)
+
+def calc_metrics(ctx, network_pkl, metrics, data, mirror, gpus, verbose):
+ """Calculate quality metrics for previous training run or pretrained network pickle.
+
+ Examples:
+
+ \b
+ # Previous training run: look up options automatically, save result to JSONL file.
+ python calc_metrics.py --metrics=pr50k3_full \\
+ --network=~/training-runs/00000-ffhq10k-res64-auto1/network-snapshot-000000.pkl
+
+ \b
+ # Pre-trained network pickle: specify dataset explicitly, print result to stdout.
+ python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq.zip --mirror=1 \\
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
+
+ Available metrics:
+
+ \b
+ ADA paper:
+ fid50k_full Frechet inception distance against the full dataset.
+ kid50k_full Kernel inception distance against the full dataset.
+ pr50k3_full Precision and recall againt the full dataset.
+ is50k Inception score for CIFAR-10.
+
+ \b
+ StyleGAN and StyleGAN2 papers:
+ fid50k Frechet inception distance against 50k real images.
+ kid50k Kernel inception distance against 50k real images.
+ pr50k3 Precision and recall against 50k real images.
+ ppl2_wend Perceptual path length in W at path endpoints against full image.
+ ppl_zfull Perceptual path length in Z for full paths against cropped image.
+ ppl_wfull Perceptual path length in W for full paths against cropped image.
+ ppl_zend Perceptual path length in Z at path endpoints against cropped image.
+ ppl_wend Perceptual path length in W at path endpoints against cropped image.
+ """
+ dnnlib.util.Logger(should_flush=True)
+
+ # Validate arguments.
+ args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, network_pkl=network_pkl, verbose=verbose)
+ if not all(metric_main.is_valid_metric(metric) for metric in args.metrics):
+ ctx.fail('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
+ if not args.num_gpus >= 1:
+ ctx.fail('--gpus must be at least 1')
+
+ # Load network.
+ if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl):
+ ctx.fail('--network must point to a file or URL')
+ if args.verbose:
+ print(f'Loading network from "{network_pkl}"...')
+ with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as f:
+ network_dict = legacy.load_network_pkl(f)
+ args.G = network_dict['G_ema'] # subclass of torch.nn.Module
+
+ # Initialize dataset options.
+ if data is not None:
+ args.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data)
+ elif network_dict['training_set_kwargs'] is not None:
+ args.dataset_kwargs = dnnlib.EasyDict(network_dict['training_set_kwargs'])
+ else:
+ ctx.fail('Could not look up dataset options; please specify --data')
+
+ # Finalize dataset options.
+ args.dataset_kwargs.resolution = args.G.img_resolution
+ args.dataset_kwargs.use_labels = (args.G.c_dim != 0)
+ if mirror is not None:
+ args.dataset_kwargs.xflip = mirror
+
+ # Print dataset options.
+ if args.verbose:
+ print('Dataset options:')
+ print(json.dumps(args.dataset_kwargs, indent=2))
+
+ # Locate run dir.
+ args.run_dir = None
+ if os.path.isfile(network_pkl):
+ pkl_dir = os.path.dirname(network_pkl)
+ if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')):
+ args.run_dir = pkl_dir
+
+ # Launch processes.
+ if args.verbose:
+ print('Launching processes...')
+ torch.multiprocessing.set_start_method('spawn')
+ with tempfile.TemporaryDirectory() as temp_dir:
+ if args.num_gpus == 1:
+ subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
+ else:
+ torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)
+
+#----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+ calc_metrics() # pylint: disable=no-value-for-parameter
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-insgen/dataset_tool.py b/diffusion-insgen/dataset_tool.py
new file mode 100644
index 0000000000000000000000000000000000000000..c59e6292891c3896722965020af7c60056729f2d
--- /dev/null
+++ b/diffusion-insgen/dataset_tool.py
@@ -0,0 +1,444 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import functools
+import io
+import json
+import os
+import pickle
+import sys
+import tarfile
+import gzip
+import zipfile
+from pathlib import Path
+from typing import Callable, Optional, Tuple, Union
+
+import click
+import numpy as np
+import PIL.Image
+from tqdm import tqdm
+
+#----------------------------------------------------------------------------
+
+def error(msg):
+ print('Error: ' + msg)
+ sys.exit(1)
+
+#----------------------------------------------------------------------------
+
+def maybe_min(a: int, b: Optional[int]) -> int:
+ if b is not None:
+ return min(a, b)
+ return a
+
+#----------------------------------------------------------------------------
+
+def file_ext(name: Union[str, Path]) -> str:
+ return str(name).split('.')[-1]
+
+#----------------------------------------------------------------------------
+
+def is_image_ext(fname: Union[str, Path]) -> bool:
+ ext = file_ext(fname).lower()
+ return f'.{ext}' in PIL.Image.EXTENSION # type: ignore
+
+#----------------------------------------------------------------------------
+
+def open_image_folder(source_dir, *, max_images: Optional[int]):
+ input_images = [str(f) for f in sorted(Path(source_dir).rglob('*')) if is_image_ext(f) and os.path.isfile(f)]
+
+ # Load labels.
+ labels = {}
+ meta_fname = os.path.join(source_dir, 'dataset.json')
+ if os.path.isfile(meta_fname):
+ with open(meta_fname, 'r') as file:
+ labels = json.load(file)['labels']
+ if labels is not None:
+ labels = { x[0]: x[1] for x in labels }
+ else:
+ labels = {}
+
+ max_idx = maybe_min(len(input_images), max_images)
+
+ def iterate_images():
+ for idx, fname in enumerate(input_images):
+ arch_fname = os.path.relpath(fname, source_dir)
+ arch_fname = arch_fname.replace('\\', '/')
+ img = np.array(PIL.Image.open(fname))
+ yield dict(img=img, label=labels.get(arch_fname))
+ if idx >= max_idx-1:
+ break
+ return max_idx, iterate_images()
+
+#----------------------------------------------------------------------------
+
+def open_image_zip(source, *, max_images: Optional[int]):
+ with zipfile.ZipFile(source, mode='r') as z:
+ input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)]
+
+ # Load labels.
+ labels = {}
+ if 'dataset.json' in z.namelist():
+ with z.open('dataset.json', 'r') as file:
+ labels = json.load(file)['labels']
+ if labels is not None:
+ labels = { x[0]: x[1] for x in labels }
+ else:
+ labels = {}
+
+ max_idx = maybe_min(len(input_images), max_images)
+
+ def iterate_images():
+ with zipfile.ZipFile(source, mode='r') as z:
+ for idx, fname in enumerate(input_images):
+ with z.open(fname, 'r') as file:
+ img = PIL.Image.open(file) # type: ignore
+ img = np.array(img)
+ yield dict(img=img, label=labels.get(fname))
+ if idx >= max_idx-1:
+ break
+ return max_idx, iterate_images()
+
+#----------------------------------------------------------------------------
+
+def open_lmdb(lmdb_dir: str, *, max_images: Optional[int]):
+ import cv2 # pip install opencv-python
+ import lmdb # pip install lmdb # pylint: disable=import-error
+
+ with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
+ max_idx = maybe_min(txn.stat()['entries'], max_images)
+
+ def iterate_images():
+ with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
+ for idx, (_key, value) in enumerate(txn.cursor()):
+ try:
+ try:
+ img = cv2.imdecode(np.frombuffer(value, dtype=np.uint8), 1)
+ if img is None:
+ raise IOError('cv2.imdecode failed')
+ img = img[:, :, ::-1] # BGR => RGB
+ except IOError:
+ img = np.array(PIL.Image.open(io.BytesIO(value)))
+ yield dict(img=img, label=None)
+ if idx >= max_idx-1:
+ break
+ except:
+ print(sys.exc_info()[1])
+
+ return max_idx, iterate_images()
+
+#----------------------------------------------------------------------------
+
+def open_cifar10(tarball: str, *, max_images: Optional[int]):
+ images = []
+ labels = []
+
+ with tarfile.open(tarball, 'r:gz') as tar:
+ for batch in range(1, 6):
+ member = tar.getmember(f'cifar-10-batches-py/data_batch_{batch}')
+ with tar.extractfile(member) as file:
+ data = pickle.load(file, encoding='latin1')
+ images.append(data['data'].reshape(-1, 3, 32, 32))
+ labels.append(data['labels'])
+
+ images = np.concatenate(images)
+ labels = np.concatenate(labels)
+ images = images.transpose([0, 2, 3, 1]) # NCHW -> NHWC
+ assert images.shape == (50000, 32, 32, 3) and images.dtype == np.uint8
+ assert labels.shape == (50000,) and labels.dtype in [np.int32, np.int64]
+ assert np.min(images) == 0 and np.max(images) == 255
+ assert np.min(labels) == 0 and np.max(labels) == 9
+
+ max_idx = maybe_min(len(images), max_images)
+
+ def iterate_images():
+ for idx, img in enumerate(images):
+ yield dict(img=img, label=int(labels[idx]))
+ if idx >= max_idx-1:
+ break
+
+ return max_idx, iterate_images()
+
+#----------------------------------------------------------------------------
+
+def open_mnist(images_gz: str, *, max_images: Optional[int]):
+ labels_gz = images_gz.replace('-images-idx3-ubyte.gz', '-labels-idx1-ubyte.gz')
+ assert labels_gz != images_gz
+ images = []
+ labels = []
+
+ with gzip.open(images_gz, 'rb') as f:
+ images = np.frombuffer(f.read(), np.uint8, offset=16)
+ with gzip.open(labels_gz, 'rb') as f:
+ labels = np.frombuffer(f.read(), np.uint8, offset=8)
+
+ images = images.reshape(-1, 28, 28)
+ images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0)
+ assert images.shape == (60000, 32, 32) and images.dtype == np.uint8
+ assert labels.shape == (60000,) and labels.dtype == np.uint8
+ assert np.min(images) == 0 and np.max(images) == 255
+ assert np.min(labels) == 0 and np.max(labels) == 9
+
+ max_idx = maybe_min(len(images), max_images)
+
+ def iterate_images():
+ for idx, img in enumerate(images):
+ yield dict(img=img, label=int(labels[idx]))
+ if idx >= max_idx-1:
+ break
+
+ return max_idx, iterate_images()
+
+#----------------------------------------------------------------------------
+
+def make_transform(
+ transform: Optional[str],
+ output_width: Optional[int],
+ output_height: Optional[int],
+ resize_filter: str
+) -> Callable[[np.ndarray], Optional[np.ndarray]]:
+ resample = { 'box': PIL.Image.BOX, 'lanczos': PIL.Image.LANCZOS }[resize_filter]
+ def scale(width, height, img):
+ w = img.shape[1]
+ h = img.shape[0]
+ if width == w and height == h:
+ return img
+ img = PIL.Image.fromarray(img)
+ ww = width if width is not None else w
+ hh = height if height is not None else h
+ img = img.resize((ww, hh), resample)
+ return np.array(img)
+
+ def center_crop(width, height, img):
+ crop = np.min(img.shape[:2])
+ img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]
+ img = PIL.Image.fromarray(img, 'RGB')
+ img = img.resize((width, height), resample)
+ return np.array(img)
+
+ def center_crop_wide(width, height, img):
+ ch = int(np.round(width * img.shape[0] / img.shape[1]))
+ if img.shape[1] < width or ch < height:
+ return None
+
+ img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2]
+ img = PIL.Image.fromarray(img, 'RGB')
+ img = img.resize((width, height), resample)
+ img = np.array(img)
+
+ canvas = np.zeros([width, width, 3], dtype=np.uint8)
+ canvas[(width - height) // 2 : (width + height) // 2, :] = img
+ return canvas
+
+ if transform is None:
+ return functools.partial(scale, output_width, output_height)
+ if transform == 'center-crop':
+ if (output_width is None) or (output_height is None):
+ error ('must specify --width and --height when using ' + transform + 'transform')
+ return functools.partial(center_crop, output_width, output_height)
+ if transform == 'center-crop-wide':
+ if (output_width is None) or (output_height is None):
+ error ('must specify --width and --height when using ' + transform + ' transform')
+ return functools.partial(center_crop_wide, output_width, output_height)
+ assert False, 'unknown transform'
+
+#----------------------------------------------------------------------------
+
+def open_dataset(source, *, max_images: Optional[int]):
+ if os.path.isdir(source):
+ if source.rstrip('/').endswith('_lmdb'):
+ return open_lmdb(source, max_images=max_images)
+ else:
+ return open_image_folder(source, max_images=max_images)
+ elif os.path.isfile(source):
+ if os.path.basename(source) == 'cifar-10-python.tar.gz':
+ return open_cifar10(source, max_images=max_images)
+ elif os.path.basename(source) == 'train-images-idx3-ubyte.gz':
+ return open_mnist(source, max_images=max_images)
+ elif file_ext(source) == 'zip':
+ return open_image_zip(source, max_images=max_images)
+ else:
+ assert False, 'unknown archive type'
+ else:
+ error(f'Missing input file or directory: {source}')
+
+#----------------------------------------------------------------------------
+
+def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]:
+ dest_ext = file_ext(dest)
+
+ if dest_ext == 'zip':
+ if os.path.dirname(dest) != '':
+ os.makedirs(os.path.dirname(dest), exist_ok=True)
+ zf = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED)
+ def zip_write_bytes(fname: str, data: Union[bytes, str]):
+ zf.writestr(fname, data)
+ return '', zip_write_bytes, zf.close
+ else:
+ # If the output folder already exists, check that is is
+ # empty.
+ #
+ # Note: creating the output directory is not strictly
+ # necessary as folder_write_bytes() also mkdirs, but it's better
+ # to give an error message earlier in case the dest folder
+ # somehow cannot be created.
+ if os.path.isdir(dest) and len(os.listdir(dest)) != 0:
+ error('--dest folder must be empty')
+ os.makedirs(dest, exist_ok=True)
+
+ def folder_write_bytes(fname: str, data: Union[bytes, str]):
+ os.makedirs(os.path.dirname(fname), exist_ok=True)
+ with open(fname, 'wb') as fout:
+ if isinstance(data, str):
+ data = data.encode('utf8')
+ fout.write(data)
+ return dest, folder_write_bytes, lambda: None
+
+#----------------------------------------------------------------------------
+
+@click.command()
+@click.pass_context
+@click.option('--source', help='Directory or archive name for input dataset', required=True, metavar='PATH')
+@click.option('--dest', help='Output directory or archive name for output dataset', required=True, metavar='PATH')
+@click.option('--max-images', help='Output only up to `max-images` images', type=int, default=None)
+@click.option('--resize-filter', help='Filter to use when resizing images for output resolution', type=click.Choice(['box', 'lanczos']), default='lanczos', show_default=True)
+@click.option('--transform', help='Input crop/resize mode', type=click.Choice(['center-crop', 'center-crop-wide']))
+@click.option('--width', help='Output width', type=int)
+@click.option('--height', help='Output height', type=int)
+def convert_dataset(
+ ctx: click.Context,
+ source: str,
+ dest: str,
+ max_images: Optional[int],
+ transform: Optional[str],
+ resize_filter: str,
+ width: Optional[int],
+ height: Optional[int]
+):
+ """Convert an image dataset into a dataset archive usable with StyleGAN2 ADA PyTorch.
+
+ The input dataset format is guessed from the --source argument:
+
+ \b
+ --source *_lmdb/ Load LSUN dataset
+ --source cifar-10-python.tar.gz Load CIFAR-10 dataset
+ --source train-images-idx3-ubyte.gz Load MNIST dataset
+ --source path/ Recursively load all images from path/
+ --source dataset.zip Recursively load all images from dataset.zip
+
+ Specifying the output format and path:
+
+ \b
+ --dest /path/to/dir Save output files under /path/to/dir
+ --dest /path/to/dataset.zip Save output files into /path/to/dataset.zip
+
+ The output dataset format can be either an image folder or an uncompressed zip archive.
+ Zip archives makes it easier to move datasets around file servers and clusters, and may
+ offer better training performance on network file systems.
+
+ Images within the dataset archive will be stored as uncompressed PNG.
+ Uncompresed PNGs can be efficiently decoded in the training loop.
+
+ Class labels are stored in a file called 'dataset.json' that is stored at the
+ dataset root folder. This file has the following structure:
+
+ \b
+ {
+ "labels": [
+ ["00000/img00000000.png",6],
+ ["00000/img00000001.png",9],
+ ... repeated for every image in the datase
+ ["00049/img00049999.png",1]
+ ]
+ }
+
+ If the 'dataset.json' file cannot be found, the dataset is interpreted as
+ not containing class labels.
+
+ Image scale/crop and resolution requirements:
+
+ Output images must be square-shaped and they must all have the same power-of-two
+ dimensions.
+
+ To scale arbitrary input image size to a specific width and height, use the
+ --width and --height options. Output resolution will be either the original
+ input resolution (if --width/--height was not specified) or the one specified with
+ --width/height.
+
+ Use the --transform=center-crop or --transform=center-crop-wide options to apply a
+ center crop transform on the input image. These options should be used with the
+ --width and --height options. For example:
+
+ \b
+ python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \\
+ --transform=center-crop-wide --width 512 --height=384
+ """
+
+ PIL.Image.init() # type: ignore
+
+ if dest == '':
+ ctx.fail('--dest output filename or directory must not be an empty string')
+
+ num_files, input_iter = open_dataset(source, max_images=max_images)
+ archive_root_dir, save_bytes, close_dest = open_dest(dest)
+
+ transform_image = make_transform(transform, width, height, resize_filter)
+
+ dataset_attrs = None
+
+ labels = []
+ for idx, image in tqdm(enumerate(input_iter), total=num_files):
+ idx_str = f'{idx:08d}'
+ archive_fname = f'{idx_str[:5]}/img{idx_str}.png'
+
+ # Apply crop and resize.
+ img = transform_image(image['img'])
+
+ # Transform may drop images.
+ if img is None:
+ continue
+
+ # Error check to require uniform image attributes across
+ # the whole dataset.
+ channels = img.shape[2] if img.ndim == 3 else 1
+ cur_image_attrs = {
+ 'width': img.shape[1],
+ 'height': img.shape[0],
+ 'channels': channels
+ }
+ if dataset_attrs is None:
+ dataset_attrs = cur_image_attrs
+ width = dataset_attrs['width']
+ height = dataset_attrs['height']
+ if width != height:
+ error(f'Image dimensions after scale and crop are required to be square. Got {width}x{height}')
+ if dataset_attrs['channels'] not in [1, 3]:
+ error('Input images must be stored as RGB or grayscale')
+ if width != 2 ** int(np.floor(np.log2(width))):
+ error('Image width/height after scale and crop are required to be power-of-two')
+ elif dataset_attrs != cur_image_attrs:
+ err = [f' dataset {k}/cur image {k}: {dataset_attrs[k]}/{cur_image_attrs[k]}' for k in dataset_attrs.keys()]
+ error(f'Image {archive_fname} attributes must be equal across all images of the dataset. Got:\n' + '\n'.join(err))
+
+ # Save the image as an uncompressed PNG.
+ img = PIL.Image.fromarray(img, { 1: 'L', 3: 'RGB' }[channels])
+ image_bits = io.BytesIO()
+ img.save(image_bits, format='png', compress_level=0, optimize=False)
+ save_bytes(os.path.join(archive_root_dir, archive_fname), image_bits.getbuffer())
+ labels.append([archive_fname, image['label']] if image['label'] is not None else None)
+
+ metadata = {
+ 'labels': labels if all(x is not None for x in labels) else None
+ }
+ save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata))
+ close_dest()
+
+#----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+ convert_dataset() # pylint: disable=no-value-for-parameter
diff --git a/diffusion-insgen/dnnlib/__init__.py b/diffusion-insgen/dnnlib/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f08cf36f11f9b0fd94c1b7caeadf69b98375b04
--- /dev/null
+++ b/diffusion-insgen/dnnlib/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+from .util import EasyDict, make_cache_dir_path
diff --git a/diffusion-insgen/dnnlib/util.py b/diffusion-insgen/dnnlib/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..76725336d01e75e1c68daa88be47f4fde0bbc63b
--- /dev/null
+++ b/diffusion-insgen/dnnlib/util.py
@@ -0,0 +1,477 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Miscellaneous utility classes and functions."""
+
+import ctypes
+import fnmatch
+import importlib
+import inspect
+import numpy as np
+import os
+import shutil
+import sys
+import types
+import io
+import pickle
+import re
+import requests
+import html
+import hashlib
+import glob
+import tempfile
+import urllib
+import urllib.request
+import uuid
+
+from distutils.util import strtobool
+from typing import Any, List, Tuple, Union
+
+
+# Util classes
+# ------------------------------------------------------------------------------------------
+
+
+class EasyDict(dict):
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
+
+ def __getattr__(self, name: str) -> Any:
+ try:
+ return self[name]
+ except KeyError:
+ raise AttributeError(name)
+
+ def __setattr__(self, name: str, value: Any) -> None:
+ self[name] = value
+
+ def __delattr__(self, name: str) -> None:
+ del self[name]
+
+
+class Logger(object):
+ """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
+
+ def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
+ self.file = None
+
+ if file_name is not None:
+ self.file = open(file_name, file_mode)
+
+ self.should_flush = should_flush
+ self.stdout = sys.stdout
+ self.stderr = sys.stderr
+
+ sys.stdout = self
+ sys.stderr = self
+
+ def __enter__(self) -> "Logger":
+ return self
+
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+ self.close()
+
+ def write(self, text: Union[str, bytes]) -> None:
+ """Write text to stdout (and a file) and optionally flush."""
+ if isinstance(text, bytes):
+ text = text.decode()
+ if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
+ return
+
+ if self.file is not None:
+ self.file.write(text)
+
+ self.stdout.write(text)
+
+ if self.should_flush:
+ self.flush()
+
+ def flush(self) -> None:
+ """Flush written text to both stdout and a file, if open."""
+ if self.file is not None:
+ self.file.flush()
+
+ self.stdout.flush()
+
+ def close(self) -> None:
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
+ self.flush()
+
+ # if using multiple loggers, prevent closing in wrong order
+ if sys.stdout is self:
+ sys.stdout = self.stdout
+ if sys.stderr is self:
+ sys.stderr = self.stderr
+
+ if self.file is not None:
+ self.file.close()
+ self.file = None
+
+
+# Cache directories
+# ------------------------------------------------------------------------------------------
+
+_dnnlib_cache_dir = None
+
+def set_cache_dir(path: str) -> None:
+ global _dnnlib_cache_dir
+ _dnnlib_cache_dir = path
+
+def make_cache_dir_path(*paths: str) -> str:
+ if _dnnlib_cache_dir is not None:
+ return os.path.join(_dnnlib_cache_dir, *paths)
+ if 'DNNLIB_CACHE_DIR' in os.environ:
+ return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
+ if 'HOME' in os.environ:
+ return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
+ if 'USERPROFILE' in os.environ:
+ return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
+ return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
+
+# Small util functions
+# ------------------------------------------------------------------------------------------
+
+
+def format_time(seconds: Union[int, float]) -> str:
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
+ s = int(np.rint(seconds))
+
+ if s < 60:
+ return "{0}s".format(s)
+ elif s < 60 * 60:
+ return "{0}m {1:02}s".format(s // 60, s % 60)
+ elif s < 24 * 60 * 60:
+ return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
+ else:
+ return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
+
+
+def ask_yes_no(question: str) -> bool:
+ """Ask the user the question until the user inputs a valid answer."""
+ while True:
+ try:
+ print("{0} [y/n]".format(question))
+ return strtobool(input().lower())
+ except ValueError:
+ pass
+
+
+def tuple_product(t: Tuple) -> Any:
+ """Calculate the product of the tuple elements."""
+ result = 1
+
+ for v in t:
+ result *= v
+
+ return result
+
+
+_str_to_ctype = {
+ "uint8": ctypes.c_ubyte,
+ "uint16": ctypes.c_uint16,
+ "uint32": ctypes.c_uint32,
+ "uint64": ctypes.c_uint64,
+ "int8": ctypes.c_byte,
+ "int16": ctypes.c_int16,
+ "int32": ctypes.c_int32,
+ "int64": ctypes.c_int64,
+ "float32": ctypes.c_float,
+ "float64": ctypes.c_double
+}
+
+
+def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
+ """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
+ type_str = None
+
+ if isinstance(type_obj, str):
+ type_str = type_obj
+ elif hasattr(type_obj, "__name__"):
+ type_str = type_obj.__name__
+ elif hasattr(type_obj, "name"):
+ type_str = type_obj.name
+ else:
+ raise RuntimeError("Cannot infer type name from input")
+
+ assert type_str in _str_to_ctype.keys()
+
+ my_dtype = np.dtype(type_str)
+ my_ctype = _str_to_ctype[type_str]
+
+ assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
+
+ return my_dtype, my_ctype
+
+
+def is_pickleable(obj: Any) -> bool:
+ try:
+ with io.BytesIO() as stream:
+ pickle.dump(obj, stream)
+ return True
+ except:
+ return False
+
+
+# Functionality to import modules/objects by name, and call functions by name
+# ------------------------------------------------------------------------------------------
+
+def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
+ """Searches for the underlying module behind the name to some python object.
+ Returns the module and the object name (original name with module part removed)."""
+
+ # allow convenience shorthands, substitute them by full names
+ obj_name = re.sub("^np.", "numpy.", obj_name)
+ obj_name = re.sub("^tf.", "tensorflow.", obj_name)
+
+ # list alternatives for (module_name, local_obj_name)
+ parts = obj_name.split(".")
+ name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
+
+ # try each alternative in turn
+ for module_name, local_obj_name in name_pairs:
+ try:
+ module = importlib.import_module(module_name) # may raise ImportError
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
+ return module, local_obj_name
+ except:
+ pass
+
+ # maybe some of the modules themselves contain errors?
+ for module_name, _local_obj_name in name_pairs:
+ try:
+ importlib.import_module(module_name) # may raise ImportError
+ except ImportError:
+ if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
+ raise
+
+ # maybe the requested attribute is missing?
+ for module_name, local_obj_name in name_pairs:
+ try:
+ module = importlib.import_module(module_name) # may raise ImportError
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
+ except ImportError:
+ pass
+
+ # we are out of luck, but we have no idea why
+ raise ImportError(obj_name)
+
+
+def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
+ """Traverses the object name and returns the last (rightmost) python object."""
+ if obj_name == '':
+ return module
+ obj = module
+ for part in obj_name.split("."):
+ obj = getattr(obj, part)
+ return obj
+
+
+def get_obj_by_name(name: str) -> Any:
+ """Finds the python object with the given name."""
+ module, obj_name = get_module_from_obj_name(name)
+ return get_obj_from_module(module, obj_name)
+
+
+def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
+ """Finds the python object with the given name and calls it as a function."""
+ assert func_name is not None
+ func_obj = get_obj_by_name(func_name)
+ assert callable(func_obj)
+ return func_obj(*args, **kwargs)
+
+
+def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
+ """Finds the python class with the given name and constructs it with the given arguments."""
+ return call_func_by_name(*args, func_name=class_name, **kwargs)
+
+
+def get_module_dir_by_obj_name(obj_name: str) -> str:
+ """Get the directory path of the module containing the given object name."""
+ module, _ = get_module_from_obj_name(obj_name)
+ return os.path.dirname(inspect.getfile(module))
+
+
+def is_top_level_function(obj: Any) -> bool:
+ """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
+ return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
+
+
+def get_top_level_function_name(obj: Any) -> str:
+ """Return the fully-qualified name of a top-level function."""
+ assert is_top_level_function(obj)
+ module = obj.__module__
+ if module == '__main__':
+ module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
+ return module + "." + obj.__name__
+
+
+# File system helpers
+# ------------------------------------------------------------------------------------------
+
+def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
+ """List all files recursively in a given directory while ignoring given file and directory names.
+ Returns list of tuples containing both absolute and relative paths."""
+ assert os.path.isdir(dir_path)
+ base_name = os.path.basename(os.path.normpath(dir_path))
+
+ if ignores is None:
+ ignores = []
+
+ result = []
+
+ for root, dirs, files in os.walk(dir_path, topdown=True):
+ for ignore_ in ignores:
+ dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
+
+ # dirs need to be edited in-place
+ for d in dirs_to_remove:
+ dirs.remove(d)
+
+ files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
+
+ absolute_paths = [os.path.join(root, f) for f in files]
+ relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
+
+ if add_base_to_relative:
+ relative_paths = [os.path.join(base_name, p) for p in relative_paths]
+
+ assert len(absolute_paths) == len(relative_paths)
+ result += zip(absolute_paths, relative_paths)
+
+ return result
+
+
+def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
+ """Takes in a list of tuples of (src, dst) paths and copies files.
+ Will create all necessary directories."""
+ for file in files:
+ target_dir_name = os.path.dirname(file[1])
+
+ # will create all intermediate-level directories
+ if not os.path.exists(target_dir_name):
+ os.makedirs(target_dir_name)
+
+ shutil.copyfile(file[0], file[1])
+
+
+# URL helpers
+# ------------------------------------------------------------------------------------------
+
+def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
+ """Determine whether the given object is a valid URL string."""
+ if not isinstance(obj, str) or not "://" in obj:
+ return False
+ if allow_file_urls and obj.startswith('file://'):
+ return True
+ try:
+ res = requests.compat.urlparse(obj)
+ if not res.scheme or not res.netloc or not "." in res.netloc:
+ return False
+ res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
+ if not res.scheme or not res.netloc or not "." in res.netloc:
+ return False
+ except:
+ return False
+ return True
+
+
+def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
+ """Download the given URL and return a binary-mode file object to access the data."""
+ assert num_attempts >= 1
+ assert not (return_filename and (not cache))
+
+ # Doesn't look like an URL scheme so interpret it as a local filename.
+ if not re.match('^[a-z]+://', url):
+ return url if return_filename else open(url, "rb")
+
+ # Handle file URLs. This code handles unusual file:// patterns that
+ # arise on Windows:
+ #
+ # file:///c:/foo.txt
+ #
+ # which would translate to a local '/c:/foo.txt' filename that's
+ # invalid. Drop the forward slash for such pathnames.
+ #
+ # If you touch this code path, you should test it on both Linux and
+ # Windows.
+ #
+ # Some internet resources suggest using urllib.request.url2pathname() but
+ # but that converts forward slashes to backslashes and this causes
+ # its own set of problems.
+ if url.startswith('file://'):
+ filename = urllib.parse.urlparse(url).path
+ if re.match(r'^/[a-zA-Z]:', filename):
+ filename = filename[1:]
+ return filename if return_filename else open(filename, "rb")
+
+ assert is_url(url)
+
+ # Lookup from cache.
+ if cache_dir is None:
+ cache_dir = make_cache_dir_path('downloads')
+
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
+ if cache:
+ cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
+ if len(cache_files) == 1:
+ filename = cache_files[0]
+ return filename if return_filename else open(filename, "rb")
+
+ # Download.
+ url_name = None
+ url_data = None
+ with requests.Session() as session:
+ if verbose:
+ print("Downloading %s ..." % url, end="", flush=True)
+ for attempts_left in reversed(range(num_attempts)):
+ try:
+ with session.get(url) as res:
+ res.raise_for_status()
+ if len(res.content) == 0:
+ raise IOError("No data received")
+
+ if len(res.content) < 8192:
+ content_str = res.content.decode("utf-8")
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
+ links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
+ if len(links) == 1:
+ url = requests.compat.urljoin(url, links[0])
+ raise IOError("Google Drive virus checker nag")
+ if "Google Drive - Quota exceeded" in content_str:
+ raise IOError("Google Drive download quota exceeded -- please try again later")
+
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
+ url_name = match[1] if match else url
+ url_data = res.content
+ if verbose:
+ print(" done")
+ break
+ except KeyboardInterrupt:
+ raise
+ except:
+ if not attempts_left:
+ if verbose:
+ print(" failed")
+ raise
+ if verbose:
+ print(".", end="", flush=True)
+
+ # Save to cache.
+ if cache:
+ safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
+ cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
+ temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
+ os.makedirs(cache_dir, exist_ok=True)
+ with open(temp_file, "wb") as f:
+ f.write(url_data)
+ os.replace(temp_file, cache_file) # atomic
+ if return_filename:
+ return cache_file
+
+ # Return data as file object.
+ assert not return_filename
+ return io.BytesIO(url_data)
diff --git a/diffusion-insgen/generate.py b/diffusion-insgen/generate.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7f961931e4e2947a74e29289b0e354d789d7bdc
--- /dev/null
+++ b/diffusion-insgen/generate.py
@@ -0,0 +1,129 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Generate images using pretrained network pickle."""
+
+import os
+import re
+from typing import List, Optional
+
+import click
+import dnnlib
+import numpy as np
+import PIL.Image
+import torch
+
+import legacy
+
+#----------------------------------------------------------------------------
+
+def num_range(s: str) -> List[int]:
+ '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''
+
+ range_re = re.compile(r'^(\d+)-(\d+)$')
+ m = range_re.match(s)
+ if m:
+ return list(range(int(m.group(1)), int(m.group(2))+1))
+ vals = s.split(',')
+ return [int(x) for x in vals]
+
+#----------------------------------------------------------------------------
+
+@click.command()
+@click.pass_context
+@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
+@click.option('--seeds', type=num_range, help='List of random seeds')
+@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
+@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
+@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
+@click.option('--projected-w', help='Projection result file', type=str, metavar='FILE')
+@click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
+def generate_images(
+ ctx: click.Context,
+ network_pkl: str,
+ seeds: Optional[List[int]],
+ truncation_psi: float,
+ noise_mode: str,
+ outdir: str,
+ class_idx: Optional[int],
+ projected_w: Optional[str]
+):
+ """Generate images using pretrained network pickle.
+
+ Examples:
+
+ \b
+ # Generate curated MetFaces images without truncation (Fig.10 left)
+ python generate.py --outdir=out --trunc=1 --seeds=85,265,297,849 \\
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
+
+ \b
+ # Generate uncurated MetFaces images with truncation (Fig.12 upper left)
+ python generate.py --outdir=out --trunc=0.7 --seeds=600-605 \\
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
+
+ \b
+ # Generate class conditional CIFAR-10 images (Fig.17 left, Car)
+ python generate.py --outdir=out --seeds=0-35 --class=1 \\
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl
+
+ \b
+ # Render an image from projected W
+ python generate.py --outdir=out --projected_w=projected_w.npz \\
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
+ """
+
+ print('Loading networks from "%s"...' % network_pkl)
+ device = torch.device('cuda')
+ with dnnlib.util.open_url(network_pkl) as f:
+ G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
+
+ os.makedirs(outdir, exist_ok=True)
+
+ # Synthesize the result of a W projection.
+ if projected_w is not None:
+ if seeds is not None:
+ print ('warn: --seeds is ignored when using --projected-w')
+ print(f'Generating images from projected W "{projected_w}"')
+ ws = np.load(projected_w)['w']
+ ws = torch.tensor(ws, device=device) # pylint: disable=not-callable
+ assert ws.shape[1:] == (G.num_ws, G.w_dim)
+ for idx, w in enumerate(ws):
+ img = G.synthesis(w.unsqueeze(0), noise_mode=noise_mode)
+ img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+ img = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/proj{idx:02d}.png')
+ return
+
+ if seeds is None:
+ ctx.fail('--seeds option is required when not using --projected-w')
+
+ # Labels.
+ label = torch.zeros([1, G.c_dim], device=device)
+ if G.c_dim != 0:
+ if class_idx is None:
+ ctx.fail('Must specify class label with --class when using a conditional network')
+ label[:, class_idx] = 1
+ else:
+ if class_idx is not None:
+ print ('warn: --class=lbl ignored when running on an unconditional network')
+
+ # Generate images.
+ for seed_idx, seed in enumerate(seeds):
+ print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
+ z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
+ img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
+ img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+ PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png')
+
+
+#----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+ generate_images() # pylint: disable=no-value-for-parameter
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-insgen/legacy.py b/diffusion-insgen/legacy.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c48c8906477dd3e85d0dc065a486dc19aa6775d
--- /dev/null
+++ b/diffusion-insgen/legacy.py
@@ -0,0 +1,332 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import click
+import pickle
+import re
+import copy
+import numpy as np
+import torch
+import dnnlib
+from torch_utils import misc
+
+#----------------------------------------------------------------------------
+
+def load_network_pkl(f, force_fp16=False):
+ data = _LegacyUnpickler(f).load()
+
+ # Legacy TensorFlow pickle => convert.
+ if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
+ tf_G, tf_D, tf_Gs = data
+ G = convert_tf_generator(tf_G)
+ D = convert_tf_discriminator(tf_D)
+ G_ema = convert_tf_generator(tf_Gs)
+ data = dict(G=G, D=D, G_ema=G_ema)
+
+ # extract nn.module from ddp
+ for k, v in data.items():
+ if isinstance(v, _DDPNetworkStub):
+ data[k] = v._modules['module']
+
+ # Add missing fields.
+ if 'training_set_kwargs' not in data:
+ data['training_set_kwargs'] = None
+ if 'augment_pipe' not in data:
+ data['augment_pipe'] = None
+
+ # Validate contents.
+ assert isinstance(data['G'], torch.nn.Module)
+ assert isinstance(data['D'], torch.nn.Module)
+ assert isinstance(data['G_ema'], torch.nn.Module)
+ assert isinstance(data['training_set_kwargs'], (dict, type(None)))
+ assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))
+
+ # Force FP16.
+ if force_fp16:
+ for key in ['G', 'D', 'G_ema']:
+ old = data[key]
+ kwargs = copy.deepcopy(old.init_kwargs)
+ if key.startswith('G'):
+ kwargs.synthesis_kwargs = dnnlib.EasyDict(kwargs.get('synthesis_kwargs', {}))
+ kwargs.synthesis_kwargs.num_fp16_res = 4
+ kwargs.synthesis_kwargs.conv_clamp = 256
+ if key.startswith('D'):
+ kwargs.num_fp16_res = 4
+ kwargs.conv_clamp = 256
+ if kwargs != old.init_kwargs:
+ new = type(old)(**kwargs).eval().requires_grad_(False)
+ misc.copy_params_and_buffers(old, new, require_all=True)
+ data[key] = new
+ return data
+
+#----------------------------------------------------------------------------
+
+class _DDPNetworkStub(dnnlib.EasyDict):
+ pass
+
+class _TFNetworkStub(dnnlib.EasyDict):
+ pass
+
+class _LegacyUnpickler(pickle.Unpickler):
+ def find_class(self, module, name):
+ if module == 'torch.nn.parallel.distributed' and name == 'DistributedDataParallel':
+ return _DDPNetworkStub
+ if module == 'dnnlib.tflib.network' and name == 'Network':
+ return _TFNetworkStub
+ if module == 'training.augment':
+ return _TFNetworkStub
+ return super().find_class(module, name)
+
+#----------------------------------------------------------------------------
+
+def _collect_tf_params(tf_net):
+ # pylint: disable=protected-access
+ tf_params = dict()
+ def recurse(prefix, tf_net):
+ for name, value in tf_net.variables:
+ tf_params[prefix + name] = value
+ for name, comp in tf_net.components.items():
+ recurse(prefix + name + '/', comp)
+ recurse('', tf_net)
+ return tf_params
+
+#----------------------------------------------------------------------------
+
+def _populate_module_params(module, *patterns):
+ for name, tensor in misc.named_params_and_buffers(module):
+ found = False
+ value = None
+ for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
+ match = re.fullmatch(pattern, name)
+ if match:
+ found = True
+ if value_fn is not None:
+ value = value_fn(*match.groups())
+ break
+ try:
+ assert found
+ if value is not None:
+ tensor.copy_(torch.from_numpy(np.array(value)))
+ except:
+ print(name, list(tensor.shape))
+ raise
+
+#----------------------------------------------------------------------------
+
+def convert_tf_generator(tf_G):
+ if tf_G.version < 4:
+ raise ValueError('TensorFlow pickle version too low')
+
+ # Collect kwargs.
+ tf_kwargs = tf_G.static_kwargs
+ known_kwargs = set()
+ def kwarg(tf_name, default=None, none=None):
+ known_kwargs.add(tf_name)
+ val = tf_kwargs.get(tf_name, default)
+ return val if val is not None else none
+
+ # Convert kwargs.
+ kwargs = dnnlib.EasyDict(
+ z_dim = kwarg('latent_size', 512),
+ c_dim = kwarg('label_size', 0),
+ w_dim = kwarg('dlatent_size', 512),
+ img_resolution = kwarg('resolution', 1024),
+ img_channels = kwarg('num_channels', 3),
+ mapping_kwargs = dnnlib.EasyDict(
+ num_layers = kwarg('mapping_layers', 8),
+ embed_features = kwarg('label_fmaps', None),
+ layer_features = kwarg('mapping_fmaps', None),
+ activation = kwarg('mapping_nonlinearity', 'lrelu'),
+ lr_multiplier = kwarg('mapping_lrmul', 0.01),
+ w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
+ ),
+ synthesis_kwargs = dnnlib.EasyDict(
+ channel_base = kwarg('fmap_base', 16384) * 2,
+ channel_max = kwarg('fmap_max', 512),
+ num_fp16_res = kwarg('num_fp16_res', 0),
+ conv_clamp = kwarg('conv_clamp', None),
+ architecture = kwarg('architecture', 'skip'),
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
+ use_noise = kwarg('use_noise', True),
+ activation = kwarg('nonlinearity', 'lrelu'),
+ ),
+ )
+
+ # Check for unknown kwargs.
+ kwarg('truncation_psi')
+ kwarg('truncation_cutoff')
+ kwarg('style_mixing_prob')
+ kwarg('structure')
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
+ if len(unknown_kwargs) > 0:
+ raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
+
+ # Collect params.
+ tf_params = _collect_tf_params(tf_G)
+ for name, value in list(tf_params.items()):
+ match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
+ if match:
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
+ tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
+ kwargs.synthesis.kwargs.architecture = 'orig'
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
+
+ # Convert params.
+ from training import networks
+ G = networks.Generator(**kwargs).eval().requires_grad_(False)
+ # pylint: disable=unnecessary-lambda
+ _populate_module_params(G,
+ r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
+ r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
+ r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
+ r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
+ r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
+ r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
+ r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
+ r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
+ r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
+ r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
+ r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
+ r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
+ r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
+ r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
+ r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
+ r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
+ r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
+ r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
+ r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
+ r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
+ r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
+ r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
+ r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
+ r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
+ r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
+ r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
+ r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
+ r'.*\.resample_filter', None,
+ )
+ return G
+
+#----------------------------------------------------------------------------
+
+def convert_tf_discriminator(tf_D):
+ if tf_D.version < 4:
+ raise ValueError('TensorFlow pickle version too low')
+
+ # Collect kwargs.
+ tf_kwargs = tf_D.static_kwargs
+ known_kwargs = set()
+ def kwarg(tf_name, default=None):
+ known_kwargs.add(tf_name)
+ return tf_kwargs.get(tf_name, default)
+
+ # Convert kwargs.
+ kwargs = dnnlib.EasyDict(
+ c_dim = kwarg('label_size', 0),
+ img_resolution = kwarg('resolution', 1024),
+ img_channels = kwarg('num_channels', 3),
+ architecture = kwarg('architecture', 'resnet'),
+ channel_base = kwarg('fmap_base', 16384) * 2,
+ channel_max = kwarg('fmap_max', 512),
+ num_fp16_res = kwarg('num_fp16_res', 0),
+ conv_clamp = kwarg('conv_clamp', None),
+ cmap_dim = kwarg('mapping_fmaps', None),
+ block_kwargs = dnnlib.EasyDict(
+ activation = kwarg('nonlinearity', 'lrelu'),
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
+ freeze_layers = kwarg('freeze_layers', 0),
+ ),
+ mapping_kwargs = dnnlib.EasyDict(
+ num_layers = kwarg('mapping_layers', 0),
+ embed_features = kwarg('mapping_fmaps', None),
+ layer_features = kwarg('mapping_fmaps', None),
+ activation = kwarg('nonlinearity', 'lrelu'),
+ lr_multiplier = kwarg('mapping_lrmul', 0.1),
+ ),
+ epilogue_kwargs = dnnlib.EasyDict(
+ mbstd_group_size = kwarg('mbstd_group_size', None),
+ mbstd_num_channels = kwarg('mbstd_num_features', 1),
+ activation = kwarg('nonlinearity', 'lrelu'),
+ ),
+ )
+
+ # Check for unknown kwargs.
+ kwarg('structure')
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
+ if len(unknown_kwargs) > 0:
+ raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
+
+ # Collect params.
+ tf_params = _collect_tf_params(tf_D)
+ for name, value in list(tf_params.items()):
+ match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
+ if match:
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
+ tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
+ kwargs.architecture = 'orig'
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
+
+ # Convert params.
+ from training import networks
+ D = networks.Discriminator(**kwargs).eval().requires_grad_(False)
+ # pylint: disable=unnecessary-lambda
+ _populate_module_params(D,
+ r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
+ r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
+ r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
+ r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
+ r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
+ r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
+ r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
+ r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
+ r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
+ r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
+ r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
+ r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
+ r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
+ r'.*\.resample_filter', None,
+ )
+ return D
+
+#----------------------------------------------------------------------------
+
+@click.command()
+@click.option('--source', help='Input pickle', required=True, metavar='PATH')
+@click.option('--dest', help='Output pickle', required=True, metavar='PATH')
+@click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
+def convert_network_pickle(source, dest, force_fp16):
+ """Convert legacy network pickle into the native PyTorch format.
+
+ The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
+ It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.
+
+ Example:
+
+ \b
+ python legacy.py \\
+ --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
+ --dest=stylegan2-cat-config-f.pkl
+ """
+ print(f'Loading "{source}"...')
+ with dnnlib.util.open_url(source) as f:
+ data = load_network_pkl(f, force_fp16=force_fp16)
+ print(f'Saving "{dest}"...')
+ with open(dest, 'wb') as f:
+ pickle.dump(data, f)
+ print('Done.')
+
+#----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+ convert_network_pickle() # pylint: disable=no-value-for-parameter
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-insgen/metrics/__init__.py b/diffusion-insgen/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1e1a5ba99e56a56ecaa14f7d4fa41777789c0cf
--- /dev/null
+++ b/diffusion-insgen/metrics/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+# empty
diff --git a/diffusion-insgen/metrics/frechet_inception_distance.py b/diffusion-insgen/metrics/frechet_inception_distance.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d38ec731b33e6f4f20cd4601e58d7e5ce2eaaa3
--- /dev/null
+++ b/diffusion-insgen/metrics/frechet_inception_distance.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Frechet Inception Distance (FID) from the paper
+"GANs trained by a two time-scale update rule converge to a local Nash
+equilibrium". Matches the original implementation by Heusel et al. at
+https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""
+
+import numpy as np
+import scipy.linalg
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+def compute_fid(opts, max_real, num_gen):
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+ detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
+ detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
+
+ mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov()
+
+ mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov()
+
+ if opts.rank != 0:
+ return float('nan')
+
+ m = np.square(mu_gen - mu_real).sum()
+ s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
+ fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
+ return float(fid)
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-insgen/metrics/inception_score.py b/diffusion-insgen/metrics/inception_score.py
new file mode 100644
index 0000000000000000000000000000000000000000..3822c1435901a47e8c192b52cd3ed1ce5de67acd
--- /dev/null
+++ b/diffusion-insgen/metrics/inception_score.py
@@ -0,0 +1,38 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Inception Score (IS) from the paper "Improved techniques for training
+GANs". Matches the original implementation by Salimans et al. at
+https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""
+
+import numpy as np
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+def compute_is(opts, num_gen, num_splits):
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+ detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
+ detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer.
+
+ gen_probs = metric_utils.compute_feature_stats_for_generator(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ capture_all=True, max_items=num_gen).get_all()
+
+ if opts.rank != 0:
+ return float('nan'), float('nan')
+
+ scores = []
+ for i in range(num_splits):
+ part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
+ kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
+ kl = np.mean(np.sum(kl, axis=1))
+ scores.append(np.exp(kl))
+ return float(np.mean(scores)), float(np.std(scores))
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-insgen/metrics/kernel_inception_distance.py b/diffusion-insgen/metrics/kernel_inception_distance.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ac978925b5cf810463ef8e8a6f0dcd3f9078e6d
--- /dev/null
+++ b/diffusion-insgen/metrics/kernel_inception_distance.py
@@ -0,0 +1,46 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Kernel Inception Distance (KID) from the paper "Demystifying MMD
+GANs". Matches the original implementation by Binkowski et al. at
+https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""
+
+import numpy as np
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+ detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
+ detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
+
+ real_features = metric_utils.compute_feature_stats_for_dataset(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all()
+
+ gen_features = metric_utils.compute_feature_stats_for_generator(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all()
+
+ if opts.rank != 0:
+ return float('nan')
+
+ n = real_features.shape[1]
+ m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size)
+ t = 0
+ for _subset_idx in range(num_subsets):
+ x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)]
+ y = real_features[np.random.choice(real_features.shape[0], m, replace=False)]
+ a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
+ b = (x @ y.T / n + 1) ** 3
+ t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
+ kid = t / num_subsets / m
+ return float(kid)
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-insgen/metrics/metric_main.py b/diffusion-insgen/metrics/metric_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..738804a6fbdba7bee3b0c68ca2fca4646527bc28
--- /dev/null
+++ b/diffusion-insgen/metrics/metric_main.py
@@ -0,0 +1,152 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import os
+import time
+import json
+import torch
+import dnnlib
+
+from . import metric_utils
+from . import frechet_inception_distance
+from . import kernel_inception_distance
+from . import precision_recall
+from . import perceptual_path_length
+from . import inception_score
+
+#----------------------------------------------------------------------------
+
+_metric_dict = dict() # name => fn
+
+def register_metric(fn):
+ assert callable(fn)
+ _metric_dict[fn.__name__] = fn
+ return fn
+
+def is_valid_metric(metric):
+ return metric in _metric_dict
+
+def list_valid_metrics():
+ return list(_metric_dict.keys())
+
+#----------------------------------------------------------------------------
+
+def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments.
+ assert is_valid_metric(metric)
+ opts = metric_utils.MetricOptions(**kwargs)
+
+ # Calculate.
+ start_time = time.time()
+ results = _metric_dict[metric](opts)
+ total_time = time.time() - start_time
+
+ # Broadcast results.
+ for key, value in list(results.items()):
+ if opts.num_gpus > 1:
+ value = torch.as_tensor(value, dtype=torch.float64, device=opts.device)
+ torch.distributed.broadcast(tensor=value, src=0)
+ value = float(value.cpu())
+ results[key] = value
+
+ # Decorate with metadata.
+ return dnnlib.EasyDict(
+ results = dnnlib.EasyDict(results),
+ metric = metric,
+ total_time = total_time,
+ total_time_str = dnnlib.util.format_time(total_time),
+ num_gpus = opts.num_gpus,
+ )
+
+#----------------------------------------------------------------------------
+
+def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
+ metric = result_dict['metric']
+ assert is_valid_metric(metric)
+ if run_dir is not None and snapshot_pkl is not None:
+ snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)
+
+ jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time()))
+ print(jsonl_line)
+ if run_dir is not None and os.path.isdir(run_dir):
+ with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f:
+ f.write(jsonl_line + '\n')
+
+#----------------------------------------------------------------------------
+# Primary metrics.
+
+@register_metric
+def fid50k_full(opts):
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
+ fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000)
+ return dict(fid50k_full=fid)
+
+@register_metric
+def kid50k_full(opts):
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
+ kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000)
+ return dict(kid50k_full=kid)
+
+@register_metric
+def pr50k3_full(opts):
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
+ precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
+ return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall)
+
+@register_metric
+def ppl2_wend(opts):
+ ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2)
+ return dict(ppl2_wend=ppl)
+
+@register_metric
+def is50k(opts):
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
+ mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10)
+ return dict(is50k_mean=mean, is50k_std=std)
+
+#----------------------------------------------------------------------------
+# Legacy metrics.
+
+@register_metric
+def fid50k(opts):
+ opts.dataset_kwargs.update(max_size=None)
+ fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
+ return dict(fid50k=fid)
+
+@register_metric
+def kid50k(opts):
+ opts.dataset_kwargs.update(max_size=None)
+ kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000)
+ return dict(kid50k=kid)
+
+@register_metric
+def pr50k3(opts):
+ opts.dataset_kwargs.update(max_size=None)
+ precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
+ return dict(pr50k3_precision=precision, pr50k3_recall=recall)
+
+@register_metric
+def ppl_zfull(opts):
+ ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='z', sampling='full', crop=True, batch_size=2)
+ return dict(ppl_zfull=ppl)
+
+@register_metric
+def ppl_wfull(opts):
+ ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='full', crop=True, batch_size=2)
+ return dict(ppl_wfull=ppl)
+
+@register_metric
+def ppl_zend(opts):
+ ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='z', sampling='end', crop=True, batch_size=2)
+ return dict(ppl_zend=ppl)
+
+@register_metric
+def ppl_wend(opts):
+ ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=True, batch_size=2)
+ return dict(ppl_wend=ppl)
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-insgen/metrics/metric_utils.py b/diffusion-insgen/metrics/metric_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..16de1eae3ee79549412eff5313dcf26b5d7a4bb9
--- /dev/null
+++ b/diffusion-insgen/metrics/metric_utils.py
@@ -0,0 +1,275 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import os
+import time
+import hashlib
+import pickle
+import copy
+import uuid
+import numpy as np
+import torch
+import dnnlib
+
+#----------------------------------------------------------------------------
+
+class MetricOptions:
+ def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True):
+ assert 0 <= rank < num_gpus
+ self.G = G
+ self.G_kwargs = dnnlib.EasyDict(G_kwargs)
+ self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs)
+ self.num_gpus = num_gpus
+ self.rank = rank
+ self.device = device if device is not None else torch.device('cuda', rank)
+ self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor()
+ self.cache = cache
+
+#----------------------------------------------------------------------------
+
+_feature_detector_cache = dict()
+
+def get_feature_detector_name(url):
+ return os.path.splitext(url.split('/')[-1])[0]
+
+def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
+ assert 0 <= rank < num_gpus
+ key = (url, device)
+ if key not in _feature_detector_cache:
+ is_leader = (rank == 0)
+ if not is_leader and num_gpus > 1:
+ torch.distributed.barrier() # leader goes first
+ with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f:
+ _feature_detector_cache[key] = torch.jit.load(f).eval().to(device)
+ if is_leader and num_gpus > 1:
+ torch.distributed.barrier() # others follow
+ return _feature_detector_cache[key]
+
+#----------------------------------------------------------------------------
+
+class FeatureStats:
+ def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
+ self.capture_all = capture_all
+ self.capture_mean_cov = capture_mean_cov
+ self.max_items = max_items
+ self.num_items = 0
+ self.num_features = None
+ self.all_features = None
+ self.raw_mean = None
+ self.raw_cov = None
+
+ def set_num_features(self, num_features):
+ if self.num_features is not None:
+ assert num_features == self.num_features
+ else:
+ self.num_features = num_features
+ self.all_features = []
+ self.raw_mean = np.zeros([num_features], dtype=np.float64)
+ self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
+
+ def is_full(self):
+ return (self.max_items is not None) and (self.num_items >= self.max_items)
+
+ def append(self, x):
+ x = np.asarray(x, dtype=np.float32)
+ assert x.ndim == 2
+ if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
+ if self.num_items >= self.max_items:
+ return
+ x = x[:self.max_items - self.num_items]
+
+ self.set_num_features(x.shape[1])
+ self.num_items += x.shape[0]
+ if self.capture_all:
+ self.all_features.append(x)
+ if self.capture_mean_cov:
+ x64 = x.astype(np.float64)
+ self.raw_mean += x64.sum(axis=0)
+ self.raw_cov += x64.T @ x64
+
+ def append_torch(self, x, num_gpus=1, rank=0):
+ assert isinstance(x, torch.Tensor) and x.ndim == 2
+ assert 0 <= rank < num_gpus
+ if num_gpus > 1:
+ ys = []
+ for src in range(num_gpus):
+ y = x.clone()
+ torch.distributed.broadcast(y, src=src)
+ ys.append(y)
+ x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
+ self.append(x.cpu().numpy())
+
+ def get_all(self):
+ assert self.capture_all
+ return np.concatenate(self.all_features, axis=0)
+
+ def get_all_torch(self):
+ return torch.from_numpy(self.get_all())
+
+ def get_mean_cov(self):
+ assert self.capture_mean_cov
+ mean = self.raw_mean / self.num_items
+ cov = self.raw_cov / self.num_items
+ cov = cov - np.outer(mean, mean)
+ return mean, cov
+
+ def save(self, pkl_file):
+ with open(pkl_file, 'wb') as f:
+ pickle.dump(self.__dict__, f)
+
+ @staticmethod
+ def load(pkl_file):
+ with open(pkl_file, 'rb') as f:
+ s = dnnlib.EasyDict(pickle.load(f))
+ obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items)
+ obj.__dict__.update(s)
+ return obj
+
+#----------------------------------------------------------------------------
+
+class ProgressMonitor:
+ def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000):
+ self.tag = tag
+ self.num_items = num_items
+ self.verbose = verbose
+ self.flush_interval = flush_interval
+ self.progress_fn = progress_fn
+ self.pfn_lo = pfn_lo
+ self.pfn_hi = pfn_hi
+ self.pfn_total = pfn_total
+ self.start_time = time.time()
+ self.batch_time = self.start_time
+ self.batch_items = 0
+ if self.progress_fn is not None:
+ self.progress_fn(self.pfn_lo, self.pfn_total)
+
+ def update(self, cur_items):
+ assert (self.num_items is None) or (cur_items <= self.num_items)
+ if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items):
+ return
+ cur_time = time.time()
+ total_time = cur_time - self.start_time
+ time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1)
+ if (self.verbose) and (self.tag is not None):
+ print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}')
+ self.batch_time = cur_time
+ self.batch_items = cur_items
+
+ if (self.progress_fn is not None) and (self.num_items is not None):
+ self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total)
+
+ def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1):
+ return ProgressMonitor(
+ tag = tag,
+ num_items = num_items,
+ flush_interval = flush_interval,
+ verbose = self.verbose,
+ progress_fn = self.progress_fn,
+ pfn_lo = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo,
+ pfn_hi = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi,
+ pfn_total = self.pfn_total,
+ )
+
+#----------------------------------------------------------------------------
+
+def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, **stats_kwargs):
+ dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
+ if data_loader_kwargs is None:
+ data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
+
+ # Try to lookup from cache.
+ cache_file = None
+ if opts.cache:
+ # Choose cache file name.
+ args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs)
+ md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8'))
+ cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}'
+ cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl')
+
+ # Check if the file exists (all processes must agree).
+ flag = os.path.isfile(cache_file) if opts.rank == 0 else False
+ if opts.num_gpus > 1:
+ flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device)
+ torch.distributed.broadcast(tensor=flag, src=0)
+ flag = (float(flag.cpu()) != 0)
+
+ # Load.
+ if flag:
+ return FeatureStats.load(cache_file)
+
+ # Initialize.
+ num_items = len(dataset)
+ if max_items is not None:
+ num_items = min(num_items, max_items)
+ stats = FeatureStats(max_items=num_items, **stats_kwargs)
+ progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi)
+ detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
+
+ # Main loop.
+ item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
+ for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs):
+ if images.shape[1] == 1:
+ images = images.repeat([1, 3, 1, 1])
+ features = detector(images.to(opts.device), **detector_kwargs)
+ stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
+ progress.update(stats.num_items)
+
+ # Save to cache.
+ if cache_file is not None and opts.rank == 0:
+ os.makedirs(os.path.dirname(cache_file), exist_ok=True)
+ temp_file = cache_file + '.' + uuid.uuid4().hex
+ stats.save(temp_file)
+ os.replace(temp_file, cache_file) # atomic
+ return stats
+
+#----------------------------------------------------------------------------
+
+def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, jit=False, **stats_kwargs):
+ if batch_gen is None:
+ batch_gen = min(batch_size, 4)
+ assert batch_size % batch_gen == 0
+
+ # Setup generator and load labels.
+ G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
+ dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
+
+ # Image generation func.
+ def run_generator(z, c):
+ img = G(z=z, c=c, **opts.G_kwargs)
+ img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+ return img
+
+ # JIT.
+ if jit:
+ z = torch.zeros([batch_gen, G.z_dim], device=opts.device)
+ c = torch.zeros([batch_gen, G.c_dim], device=opts.device)
+ run_generator = torch.jit.trace(run_generator, [z, c], check_trace=False)
+
+ # Initialize.
+ stats = FeatureStats(**stats_kwargs)
+ assert stats.max_items is not None
+ progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
+ detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
+
+ # Main loop.
+ while not stats.is_full():
+ images = []
+ for _i in range(batch_size // batch_gen):
+ z = torch.randn([batch_gen, G.z_dim], device=opts.device)
+ c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_gen)]
+ c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
+ images.append(run_generator(z, c))
+ images = torch.cat(images)
+ if images.shape[1] == 1:
+ images = images.repeat([1, 3, 1, 1])
+ features = detector(images, **detector_kwargs)
+ stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
+ progress.update(stats.num_items)
+ return stats
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-insgen/metrics/perceptual_path_length.py b/diffusion-insgen/metrics/perceptual_path_length.py
new file mode 100644
index 0000000000000000000000000000000000000000..d070f45a04efed7e9492fddb85078be306753282
--- /dev/null
+++ b/diffusion-insgen/metrics/perceptual_path_length.py
@@ -0,0 +1,131 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Perceptual Path Length (PPL) from the paper "A Style-Based Generator
+Architecture for Generative Adversarial Networks". Matches the original
+implementation by Karras et al. at
+https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py"""
+
+import copy
+import numpy as np
+import torch
+import dnnlib
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+# Spherical interpolation of a batch of vectors.
+def slerp(a, b, t):
+ a = a / a.norm(dim=-1, keepdim=True)
+ b = b / b.norm(dim=-1, keepdim=True)
+ d = (a * b).sum(dim=-1, keepdim=True)
+ p = t * torch.acos(d)
+ c = b - d * a
+ c = c / c.norm(dim=-1, keepdim=True)
+ d = a * torch.cos(p) + c * torch.sin(p)
+ d = d / d.norm(dim=-1, keepdim=True)
+ return d
+
+#----------------------------------------------------------------------------
+
+class PPLSampler(torch.nn.Module):
+ def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16):
+ assert space in ['z', 'w']
+ assert sampling in ['full', 'end']
+ super().__init__()
+ self.G = copy.deepcopy(G)
+ self.G_kwargs = G_kwargs
+ self.epsilon = epsilon
+ self.space = space
+ self.sampling = sampling
+ self.crop = crop
+ self.vgg16 = copy.deepcopy(vgg16)
+
+ def forward(self, c):
+ # Generate random latents and interpolation t-values.
+ t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0)
+ z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2)
+
+ # Interpolate in W or Z.
+ if self.space == 'w':
+ w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2)
+ wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2))
+ wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon)
+ else: # space == 'z'
+ zt0 = slerp(z0, z1, t.unsqueeze(1))
+ zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon)
+ wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2)
+
+ # Randomize noise buffers.
+ for name, buf in self.G.named_buffers():
+ if name.endswith('.noise_const'):
+ buf.copy_(torch.randn_like(buf))
+
+ # Generate images.
+ img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs)
+
+ # Center crop.
+ if self.crop:
+ assert img.shape[2] == img.shape[3]
+ c = img.shape[2] // 8
+ img = img[:, :, c*3 : c*7, c*2 : c*6]
+
+ # Downsample to 256x256.
+ factor = self.G.img_resolution // 256
+ if factor > 1:
+ img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5])
+
+ # Scale dynamic range from [-1,1] to [0,255].
+ img = (img + 1) * (255 / 2)
+ if self.G.img_channels == 1:
+ img = img.repeat([1, 3, 1, 1])
+
+ # Evaluate differential LPIPS.
+ lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2)
+ dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2
+ return dist
+
+#----------------------------------------------------------------------------
+
+def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size, jit=False):
+ dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
+ vgg16_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
+ vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose)
+
+ # Setup sampler.
+ sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16)
+ sampler.eval().requires_grad_(False).to(opts.device)
+ if jit:
+ c = torch.zeros([batch_size, opts.G.c_dim], device=opts.device)
+ sampler = torch.jit.trace(sampler, [c], check_trace=False)
+
+ # Sampling loop.
+ dist = []
+ progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples)
+ for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
+ progress.update(batch_start)
+ c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_size)]
+ c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
+ x = sampler(c)
+ for src in range(opts.num_gpus):
+ y = x.clone()
+ if opts.num_gpus > 1:
+ torch.distributed.broadcast(y, src=src)
+ dist.append(y)
+ progress.update(num_samples)
+
+ # Compute PPL.
+ if opts.rank != 0:
+ return float('nan')
+ dist = torch.cat(dist)[:num_samples].cpu().numpy()
+ lo = np.percentile(dist, 1, interpolation='lower')
+ hi = np.percentile(dist, 99, interpolation='higher')
+ ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean()
+ return float(ppl)
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-insgen/metrics/precision_recall.py b/diffusion-insgen/metrics/precision_recall.py
new file mode 100644
index 0000000000000000000000000000000000000000..8200b7ef51963ae218e3b871de270a826bf10459
--- /dev/null
+++ b/diffusion-insgen/metrics/precision_recall.py
@@ -0,0 +1,62 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Precision/Recall (PR) from the paper "Improved Precision and Recall
+Metric for Assessing Generative Models". Matches the original implementation
+by Kynkaanniemi et al. at
+https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py"""
+
+import torch
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size):
+ assert 0 <= rank < num_gpus
+ num_cols = col_features.shape[0]
+ num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus
+ col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches)
+ dist_batches = []
+ for col_batch in col_batches[rank :: num_gpus]:
+ dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0]
+ for src in range(num_gpus):
+ dist_broadcast = dist_batch.clone()
+ if num_gpus > 1:
+ torch.distributed.broadcast(dist_broadcast, src=src)
+ dist_batches.append(dist_broadcast.cpu() if rank == 0 else None)
+ return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None
+
+#----------------------------------------------------------------------------
+
+def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size):
+ detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
+ detector_kwargs = dict(return_features=True)
+
+ real_features = metric_utils.compute_feature_stats_for_dataset(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device)
+
+ gen_features = metric_utils.compute_feature_stats_for_generator(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device)
+
+ results = dict()
+ for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]:
+ kth = []
+ for manifold_batch in manifold.split(row_batch_size):
+ dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
+ kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None)
+ kth = torch.cat(kth) if opts.rank == 0 else None
+ pred = []
+ for probes_batch in probes.split(row_batch_size):
+ dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
+ pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None)
+ results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan')
+ return results['precision'], results['recall']
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-insgen/projector.py b/diffusion-insgen/projector.py
new file mode 100644
index 0000000000000000000000000000000000000000..36041a08619a602304deb603a6769dbfed9437c8
--- /dev/null
+++ b/diffusion-insgen/projector.py
@@ -0,0 +1,212 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Project given image to the latent space of pretrained network pickle."""
+
+import copy
+import os
+from time import perf_counter
+
+import click
+import imageio
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn.functional as F
+
+import dnnlib
+import legacy
+
+def project(
+ G,
+ target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
+ *,
+ num_steps = 1000,
+ w_avg_samples = 10000,
+ initial_learning_rate = 0.1,
+ initial_noise_factor = 0.05,
+ lr_rampdown_length = 0.25,
+ lr_rampup_length = 0.05,
+ noise_ramp_length = 0.75,
+ regularize_noise_weight = 1e5,
+ verbose = False,
+ device: torch.device
+):
+ assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)
+
+ def logprint(*args):
+ if verbose:
+ print(*args)
+
+ G = copy.deepcopy(G).eval().requires_grad_(False).to(device) # type: ignore
+
+ # Compute w stats.
+ logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
+ z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
+ w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None) # [N, L, C]
+ w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32) # [N, 1, C]
+ w_avg = np.mean(w_samples, axis=0, keepdims=True) # [1, 1, C]
+ w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5
+
+ # Setup noise inputs.
+ noise_bufs = { name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name }
+
+ # Load VGG16 feature detector.
+ url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
+ with dnnlib.util.open_url(url) as f:
+ vgg16 = torch.jit.load(f).eval().to(device)
+
+ # Features for target image.
+ target_images = target.unsqueeze(0).to(device).to(torch.float32)
+ if target_images.shape[2] > 256:
+ target_images = F.interpolate(target_images, size=(256, 256), mode='area')
+ target_features = vgg16(target_images, resize_images=False, return_lpips=True)
+
+ w_opt = torch.tensor(w_avg, dtype=torch.float32, device=device, requires_grad=True) # pylint: disable=not-callable
+ w_out = torch.zeros([num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device=device)
+ optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=initial_learning_rate)
+
+ # Init noise.
+ for buf in noise_bufs.values():
+ buf[:] = torch.randn_like(buf)
+ buf.requires_grad = True
+
+ for step in range(num_steps):
+ # Learning rate schedule.
+ t = step / num_steps
+ w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
+ lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
+ lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
+ lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
+ lr = initial_learning_rate * lr_ramp
+ for param_group in optimizer.param_groups:
+ param_group['lr'] = lr
+
+ # Synth images from opt_w.
+ w_noise = torch.randn_like(w_opt) * w_noise_scale
+ ws = (w_opt + w_noise).repeat([1, G.mapping.num_ws, 1])
+ synth_images = G.synthesis(ws, noise_mode='const')
+
+ # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
+ synth_images = (synth_images + 1) * (255/2)
+ if synth_images.shape[2] > 256:
+ synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')
+
+ # Features for synth images.
+ synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
+ dist = (target_features - synth_features).square().sum()
+
+ # Noise regularization.
+ reg_loss = 0.0
+ for v in noise_bufs.values():
+ noise = v[None,None,:,:] # must be [1,1,H,W] for F.avg_pool2d()
+ while True:
+ reg_loss += (noise*torch.roll(noise, shifts=1, dims=3)).mean()**2
+ reg_loss += (noise*torch.roll(noise, shifts=1, dims=2)).mean()**2
+ if noise.shape[2] <= 8:
+ break
+ noise = F.avg_pool2d(noise, kernel_size=2)
+ loss = dist + reg_loss * regularize_noise_weight
+
+ # Step
+ optimizer.zero_grad(set_to_none=True)
+ loss.backward()
+ optimizer.step()
+ logprint(f'step {step+1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')
+
+ # Save projected W for each optimization step.
+ w_out[step] = w_opt.detach()[0]
+
+ # Normalize noise.
+ with torch.no_grad():
+ for buf in noise_bufs.values():
+ buf -= buf.mean()
+ buf *= buf.square().mean().rsqrt()
+
+ return w_out.repeat([1, G.mapping.num_ws, 1])
+
+#----------------------------------------------------------------------------
+
+@click.command()
+@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
+@click.option('--target', 'target_fname', help='Target image file to project to', required=True, metavar='FILE')
+@click.option('--num-steps', help='Number of optimization steps', type=int, default=1000, show_default=True)
+@click.option('--seed', help='Random seed', type=int, default=303, show_default=True)
+@click.option('--save-video', help='Save an mp4 video of optimization progress', type=bool, default=True, show_default=True)
+@click.option('--outdir', help='Where to save the output images', required=True, metavar='DIR')
+def run_projection(
+ network_pkl: str,
+ target_fname: str,
+ outdir: str,
+ save_video: bool,
+ seed: int,
+ num_steps: int
+):
+ """Project given image to the latent space of pretrained network pickle.
+
+ Examples:
+
+ \b
+ python projector.py --outdir=out --target=~/mytargetimg.png \\
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
+ """
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+
+ # Load networks.
+ print('Loading networks from "%s"...' % network_pkl)
+ device = torch.device('cuda')
+ with dnnlib.util.open_url(network_pkl) as fp:
+ G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device) # type: ignore
+
+ # Load target image.
+ target_pil = PIL.Image.open(target_fname).convert('RGB')
+ w, h = target_pil.size
+ s = min(w, h)
+ target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
+ target_pil = target_pil.resize((G.img_resolution, G.img_resolution), PIL.Image.LANCZOS)
+ target_uint8 = np.array(target_pil, dtype=np.uint8)
+
+ # Optimize projection.
+ start_time = perf_counter()
+ projected_w_steps = project(
+ G,
+ target=torch.tensor(target_uint8.transpose([2, 0, 1]), device=device), # pylint: disable=not-callable
+ num_steps=num_steps,
+ device=device,
+ verbose=True
+ )
+ print (f'Elapsed: {(perf_counter()-start_time):.1f} s')
+
+ # Render debug output: optional video and projected image and W vector.
+ os.makedirs(outdir, exist_ok=True)
+ if save_video:
+ video = imageio.get_writer(f'{outdir}/proj.mp4', mode='I', fps=10, codec='libx264', bitrate='16M')
+ print (f'Saving optimization progress video "{outdir}/proj.mp4"')
+ for projected_w in projected_w_steps:
+ synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
+ synth_image = (synth_image + 1) * (255/2)
+ synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
+ video.append_data(np.concatenate([target_uint8, synth_image], axis=1))
+ video.close()
+
+ # Save final projected frame and W vector.
+ target_pil.save(f'{outdir}/target.png')
+ projected_w = projected_w_steps[-1]
+ synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
+ synth_image = (synth_image + 1) * (255/2)
+ synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
+ PIL.Image.fromarray(synth_image, 'RGB').save(f'{outdir}/proj.png')
+ np.savez(f'{outdir}/projected_w.npz', w=projected_w.unsqueeze(0).cpu().numpy())
+
+#----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+ run_projection() # pylint: disable=no-value-for-parameter
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-insgen/style_mixing.py b/diffusion-insgen/style_mixing.py
new file mode 100644
index 0000000000000000000000000000000000000000..c47bebbc44c0126b6fd00a55b8b487dc7b159653
--- /dev/null
+++ b/diffusion-insgen/style_mixing.py
@@ -0,0 +1,118 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Generate style mixing image matrix using pretrained network pickle."""
+
+import os
+import re
+from typing import List
+
+import click
+import dnnlib
+import numpy as np
+import PIL.Image
+import torch
+
+import legacy
+
+#----------------------------------------------------------------------------
+
+def num_range(s: str) -> List[int]:
+ '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''
+
+ range_re = re.compile(r'^(\d+)-(\d+)$')
+ m = range_re.match(s)
+ if m:
+ return list(range(int(m.group(1)), int(m.group(2))+1))
+ vals = s.split(',')
+ return [int(x) for x in vals]
+
+#----------------------------------------------------------------------------
+
+@click.command()
+@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
+@click.option('--rows', 'row_seeds', type=num_range, help='Random seeds to use for image rows', required=True)
+@click.option('--cols', 'col_seeds', type=num_range, help='Random seeds to use for image columns', required=True)
+@click.option('--styles', 'col_styles', type=num_range, help='Style layer range', default='0-6', show_default=True)
+@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
+@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
+@click.option('--outdir', type=str, required=True)
+def generate_style_mix(
+ network_pkl: str,
+ row_seeds: List[int],
+ col_seeds: List[int],
+ col_styles: List[int],
+ truncation_psi: float,
+ noise_mode: str,
+ outdir: str
+):
+ """Generate images using pretrained network pickle.
+
+ Examples:
+
+ \b
+ python style_mixing.py --outdir=out --rows=85,100,75,458,1500 --cols=55,821,1789,293 \\
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
+ """
+ print('Loading networks from "%s"...' % network_pkl)
+ device = torch.device('cuda')
+ with dnnlib.util.open_url(network_pkl) as f:
+ G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
+
+ os.makedirs(outdir, exist_ok=True)
+
+ print('Generating W vectors...')
+ all_seeds = list(set(row_seeds + col_seeds))
+ all_z = np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])
+ all_w = G.mapping(torch.from_numpy(all_z).to(device), None)
+ w_avg = G.mapping.w_avg
+ all_w = w_avg + (all_w - w_avg) * truncation_psi
+ w_dict = {seed: w for seed, w in zip(all_seeds, list(all_w))}
+
+ print('Generating images...')
+ all_images = G.synthesis(all_w, noise_mode=noise_mode)
+ all_images = (all_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()
+ image_dict = {(seed, seed): image for seed, image in zip(all_seeds, list(all_images))}
+
+ print('Generating style-mixed images...')
+ for row_seed in row_seeds:
+ for col_seed in col_seeds:
+ w = w_dict[row_seed].clone()
+ w[col_styles] = w_dict[col_seed][col_styles]
+ image = G.synthesis(w[np.newaxis], noise_mode=noise_mode)
+ image = (image.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+ image_dict[(row_seed, col_seed)] = image[0].cpu().numpy()
+
+ print('Saving images...')
+ os.makedirs(outdir, exist_ok=True)
+ for (row_seed, col_seed), image in image_dict.items():
+ PIL.Image.fromarray(image, 'RGB').save(f'{outdir}/{row_seed}-{col_seed}.png')
+
+ print('Saving image grid...')
+ W = G.img_resolution
+ H = G.img_resolution
+ canvas = PIL.Image.new('RGB', (W * (len(col_seeds) + 1), H * (len(row_seeds) + 1)), 'black')
+ for row_idx, row_seed in enumerate([0] + row_seeds):
+ for col_idx, col_seed in enumerate([0] + col_seeds):
+ if row_idx == 0 and col_idx == 0:
+ continue
+ key = (row_seed, col_seed)
+ if row_idx == 0:
+ key = (col_seed, col_seed)
+ if col_idx == 0:
+ key = (row_seed, row_seed)
+ canvas.paste(PIL.Image.fromarray(image_dict[key], 'RGB'), (W * col_idx, H * row_idx))
+ canvas.save(f'{outdir}/grid.png')
+
+
+#----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+ generate_style_mix() # pylint: disable=no-value-for-parameter
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-insgen/torch_utils/__init__.py b/diffusion-insgen/torch_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..81a93a14f64afc0dfdb2edcda0d98554bc168cbe
--- /dev/null
+++ b/diffusion-insgen/torch_utils/__init__.py
@@ -0,0 +1,2 @@
+
+# empty
diff --git a/diffusion-insgen/torch_utils/custom_ops.py b/diffusion-insgen/torch_utils/custom_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb19b96bb02b0fb5a72b3323f462f13d0d887427
--- /dev/null
+++ b/diffusion-insgen/torch_utils/custom_ops.py
@@ -0,0 +1,119 @@
+
+import os
+import glob
+import torch
+import torch.utils.cpp_extension
+import importlib
+import hashlib
+import shutil
+from pathlib import Path
+
+from torch.utils.file_baton import FileBaton
+
+#----------------------------------------------------------------------------
+# Global options.
+
+verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
+
+#----------------------------------------------------------------------------
+# Internal helper funcs.
+
+def _find_compiler_bindir():
+ patterns = [
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
+ 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
+ ]
+ for pattern in patterns:
+ matches = sorted(glob.glob(pattern))
+ if len(matches):
+ return matches[-1]
+ return None
+
+#----------------------------------------------------------------------------
+# Main entry point for compiling and loading C++/CUDA plugins.
+
+_cached_plugins = dict()
+
+def get_plugin(module_name, sources, **build_kwargs):
+ assert verbosity in ['none', 'brief', 'full']
+
+ # Already cached?
+ if module_name in _cached_plugins:
+ return _cached_plugins[module_name]
+
+ # Print status.
+ if verbosity == 'full':
+ print(f'Setting up PyTorch plugin "{module_name}"...')
+ elif verbosity == 'brief':
+ print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
+
+ try: # pylint: disable=too-many-nested-blocks
+ # Make sure we can find the necessary compiler binaries.
+ if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
+ compiler_bindir = _find_compiler_bindir()
+ if compiler_bindir is None:
+ raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
+ os.environ['PATH'] += ';' + compiler_bindir
+
+ # Compile and load.
+ verbose_build = (verbosity == 'full')
+
+ # Incremental build md5sum trickery. Copies all the input source files
+ # into a cached build directory under a combined md5 digest of the input
+ # source files. Copying is done only if the combined digest has changed.
+ # This keeps input file timestamps and filenames the same as in previous
+ # extension builds, allowing for fast incremental rebuilds.
+ #
+ # This optimization is done only in case all the source files reside in
+ # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
+ # environment variable is set (we take this as a signal that the user
+ # actually cares about this.)
+ source_dirs_set = set(os.path.dirname(source) for source in sources)
+ if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
+ all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))
+
+ # Compute a combined hash digest for all source files in the same
+ # custom op directory (usually .cu, .cpp, .py and .h files).
+ hash_md5 = hashlib.md5()
+ for src in all_source_files:
+ with open(src, 'rb') as f:
+ hash_md5.update(f.read())
+ build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
+ digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())
+
+ if not os.path.isdir(digest_build_dir):
+ os.makedirs(digest_build_dir, exist_ok=True)
+ baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
+ if baton.try_acquire():
+ try:
+ for src in all_source_files:
+ shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
+ finally:
+ baton.release()
+ else:
+ # Someone else is copying source files under the digest dir,
+ # wait until done and continue.
+ baton.wait()
+ digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
+ torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
+ verbose=verbose_build, sources=digest_sources, **build_kwargs)
+ else:
+ torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
+ module = importlib.import_module(module_name)
+
+ except:
+ if verbosity == 'brief':
+ print('Failed!')
+ raise
+
+ # Print status and add to cache.
+ if verbosity == 'full':
+ print(f'Done setting up PyTorch plugin "{module_name}".')
+ elif verbosity == 'brief':
+ print('Done.')
+ _cached_plugins[module_name] = module
+ return module
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-insgen/torch_utils/misc.py b/diffusion-insgen/torch_utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..577ed974d792089963061e976fccd21f9ae282dc
--- /dev/null
+++ b/diffusion-insgen/torch_utils/misc.py
@@ -0,0 +1,260 @@
+
+import re
+import contextlib
+import numpy as np
+import torch
+import warnings
+import dnnlib
+
+#----------------------------------------------------------------------------
+# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
+# same constant is used multiple times.
+
+_constant_cache = dict()
+
+def constant(value, shape=None, dtype=None, device=None, memory_format=None):
+ value = np.asarray(value)
+ if shape is not None:
+ shape = tuple(shape)
+ if dtype is None:
+ dtype = torch.get_default_dtype()
+ if device is None:
+ device = torch.device('cpu')
+ if memory_format is None:
+ memory_format = torch.contiguous_format
+
+ key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
+ tensor = _constant_cache.get(key, None)
+ if tensor is None:
+ tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
+ if shape is not None:
+ tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
+ tensor = tensor.contiguous(memory_format=memory_format)
+ _constant_cache[key] = tensor
+ return tensor
+
+#----------------------------------------------------------------------------
+# Replace NaN/Inf with specified numerical values.
+
+try:
+ nan_to_num = torch.nan_to_num # 1.8.0a0
+except AttributeError:
+ def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
+ assert isinstance(input, torch.Tensor)
+ if posinf is None:
+ posinf = torch.finfo(input.dtype).max
+ if neginf is None:
+ neginf = torch.finfo(input.dtype).min
+ assert nan == 0
+ return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
+
+#----------------------------------------------------------------------------
+# Symbolic assert.
+
+try:
+ symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
+except AttributeError:
+ symbolic_assert = torch.Assert # 1.7.0
+
+#----------------------------------------------------------------------------
+# Context manager to suppress known warnings in torch.jit.trace().
+
+class suppress_tracer_warnings(warnings.catch_warnings):
+ def __enter__(self):
+ super().__enter__()
+ warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
+ return self
+
+#----------------------------------------------------------------------------
+# Assert that the shape of a tensor matches the given list of integers.
+# None indicates that the size of a dimension is allowed to vary.
+# Performs symbolic assertion when used in torch.jit.trace().
+
+def assert_shape(tensor, ref_shape):
+ if tensor.ndim != len(ref_shape):
+ raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
+ for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
+ if ref_size is None:
+ pass
+ elif isinstance(ref_size, torch.Tensor):
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
+ symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
+ elif isinstance(size, torch.Tensor):
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
+ symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
+ elif size != ref_size:
+ raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
+
+#----------------------------------------------------------------------------
+# Function decorator that calls torch.autograd.profiler.record_function().
+
+def profiled_function(fn):
+ def decorator(*args, **kwargs):
+ with torch.autograd.profiler.record_function(fn.__name__):
+ return fn(*args, **kwargs)
+ decorator.__name__ = fn.__name__
+ return decorator
+
+#----------------------------------------------------------------------------
+# Sampler for torch.utils.data.DataLoader that loops over the dataset
+# indefinitely, shuffling items as it goes.
+
+class InfiniteSampler(torch.utils.data.Sampler):
+ def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
+ assert len(dataset) > 0
+ assert num_replicas > 0
+ assert 0 <= rank < num_replicas
+ assert 0 <= window_size <= 1
+ super().__init__(dataset)
+ self.dataset = dataset
+ self.rank = rank
+ self.num_replicas = num_replicas
+ self.shuffle = shuffle
+ self.seed = seed
+ self.window_size = window_size
+
+ def __iter__(self):
+ order = np.arange(len(self.dataset))
+ rnd = None
+ window = 0
+ if self.shuffle:
+ rnd = np.random.RandomState(self.seed)
+ rnd.shuffle(order)
+ window = int(np.rint(order.size * self.window_size))
+
+ idx = 0
+ while True:
+ i = idx % order.size
+ if idx % self.num_replicas == self.rank:
+ yield order[i]
+ if window >= 2:
+ j = (i - rnd.randint(window)) % order.size
+ order[i], order[j] = order[j], order[i]
+ idx += 1
+
+#----------------------------------------------------------------------------
+# Utilities for operating with torch.nn.Module parameters and buffers.
+
+def params_and_buffers(module):
+ assert isinstance(module, torch.nn.Module)
+ return list(module.parameters()) + list(module.buffers())
+
+def named_params_and_buffers(module):
+ assert isinstance(module, torch.nn.Module)
+ return list(module.named_parameters()) + list(module.named_buffers())
+
+def copy_params_and_buffers(src_module, dst_module, require_all=False):
+ assert isinstance(src_module, torch.nn.Module)
+ assert isinstance(dst_module, torch.nn.Module)
+ src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)}
+ for name, tensor in named_params_and_buffers(dst_module):
+ assert (name in src_tensors) or (not require_all)
+ if name in src_tensors:
+ tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
+
+#----------------------------------------------------------------------------
+# Context manager for easily enabling/disabling DistributedDataParallel
+# synchronization.
+
+@contextlib.contextmanager
+def ddp_sync(module, sync):
+ assert isinstance(module, torch.nn.Module)
+ if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
+ yield
+ else:
+ with module.no_sync():
+ yield
+
+#----------------------------------------------------------------------------
+# Check DistributedDataParallel consistency across processes.
+
+def check_ddp_consistency(module, ignore_regex=None):
+ assert isinstance(module, torch.nn.Module)
+ for name, tensor in named_params_and_buffers(module):
+ fullname = type(module).__name__ + '.' + name
+ if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
+ continue
+ tensor = tensor.detach()
+ other = tensor.clone()
+ torch.distributed.broadcast(tensor=other, src=0)
+ assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname
+
+#----------------------------------------------------------------------------
+# Print summary table of module hierarchy.
+
+def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
+ assert isinstance(module, torch.nn.Module)
+ assert not isinstance(module, torch.jit.ScriptModule)
+ assert isinstance(inputs, (tuple, list))
+
+ # Register hooks.
+ entries = []
+ nesting = [0]
+ def pre_hook(_mod, _inputs):
+ nesting[0] += 1
+ def post_hook(mod, _inputs, outputs):
+ nesting[0] -= 1
+ if nesting[0] <= max_nesting:
+ outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
+ outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
+ entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
+ hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
+ hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
+
+ # Run module.
+ outputs = module(*inputs)
+ for hook in hooks:
+ hook.remove()
+
+ # Identify unique outputs, parameters, and buffers.
+ tensors_seen = set()
+ for e in entries:
+ e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
+ e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
+ e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
+ tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
+
+ # Filter out redundant entries.
+ if skip_redundant:
+ entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
+
+ # Construct table.
+ rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
+ rows += [['---'] * len(rows[0])]
+ param_total = 0
+ buffer_total = 0
+ submodule_names = {mod: name for name, mod in module.named_modules()}
+ for e in entries:
+ name = '' if e.mod is module else submodule_names[e.mod]
+ param_size = sum(t.numel() for t in e.unique_params)
+ buffer_size = sum(t.numel() for t in e.unique_buffers)
+ output_shapes = [str(list(e.outputs[0].shape)) for t in e.outputs]
+ output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
+ rows += [[
+ name + (':0' if len(e.outputs) >= 2 else ''),
+ str(param_size) if param_size else '-',
+ str(buffer_size) if buffer_size else '-',
+ (output_shapes + ['-'])[0],
+ (output_dtypes + ['-'])[0],
+ ]]
+ for idx in range(1, len(e.outputs)):
+ rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
+ param_total += param_size
+ buffer_total += buffer_size
+ rows += [['---'] * len(rows[0])]
+ rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
+
+ # Print table.
+ widths = [max(len(cell) for cell in column) for column in zip(*rows)]
+ print()
+ for row in rows:
+ print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
+ print()
+ return outputs
+
+#----------------------------------------------------------------------------
+
+import os
+
+def get_ckpt_path(run_dir):
+ return os.path.join(run_dir, f'network-snapshot.pkl')
diff --git a/diffusion-insgen/torch_utils/ops/__init__.py b/diffusion-insgen/torch_utils/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..81a93a14f64afc0dfdb2edcda0d98554bc168cbe
--- /dev/null
+++ b/diffusion-insgen/torch_utils/ops/__init__.py
@@ -0,0 +1,2 @@
+
+# empty
diff --git a/diffusion-insgen/torch_utils/ops/bias_act.cpp b/diffusion-insgen/torch_utils/ops/bias_act.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..5d2425d8054991a8e8b6f7a940fd0ff7fa0bb330
--- /dev/null
+++ b/diffusion-insgen/torch_utils/ops/bias_act.cpp
@@ -0,0 +1,99 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include
+#include
+#include
+#include "bias_act.h"
+
+//------------------------------------------------------------------------
+
+static bool has_same_layout(torch::Tensor x, torch::Tensor y)
+{
+ if (x.dim() != y.dim())
+ return false;
+ for (int64_t i = 0; i < x.dim(); i++)
+ {
+ if (x.size(i) != y.size(i))
+ return false;
+ if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
+ return false;
+ }
+ return true;
+}
+
+//------------------------------------------------------------------------
+
+static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
+{
+ // Validate arguments.
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+ TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
+ TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
+ TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
+ TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
+ TORCH_CHECK(b.dim() == 1, "b must have rank 1");
+ TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
+ TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
+ TORCH_CHECK(grad >= 0, "grad must be non-negative");
+
+ // Validate layout.
+ TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
+ TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
+ TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
+ TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
+ TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
+
+ // Create output tensor.
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+ torch::Tensor y = torch::empty_like(x);
+ TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
+
+ // Initialize CUDA kernel parameters.
+ bias_act_kernel_params p;
+ p.x = x.data_ptr();
+ p.b = (b.numel()) ? b.data_ptr() : NULL;
+ p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
+ p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
+ p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
+ p.y = y.data_ptr();
+ p.grad = grad;
+ p.act = act;
+ p.alpha = alpha;
+ p.gain = gain;
+ p.clamp = clamp;
+ p.sizeX = (int)x.numel();
+ p.sizeB = (int)b.numel();
+ p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
+
+ // Choose CUDA kernel.
+ void* kernel;
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
+ {
+ kernel = choose_bias_act_kernel(p);
+ });
+ TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
+
+ // Launch CUDA kernel.
+ p.loopX = 4;
+ int blockSize = 4 * 32;
+ int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
+ void* args[] = {&p};
+ AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
+ return y;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("bias_act", &bias_act);
+}
+
+//------------------------------------------------------------------------
diff --git a/diffusion-insgen/torch_utils/ops/bias_act.cu b/diffusion-insgen/torch_utils/ops/bias_act.cu
new file mode 100644
index 0000000000000000000000000000000000000000..dd8fc4756d7d94727f94af738665b68d9c518880
--- /dev/null
+++ b/diffusion-insgen/torch_utils/ops/bias_act.cu
@@ -0,0 +1,173 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include
+#include "bias_act.h"
+
+//------------------------------------------------------------------------
+// Helpers.
+
+template struct InternalType;
+template <> struct InternalType { typedef double scalar_t; };
+template <> struct InternalType { typedef float scalar_t; };
+template <> struct InternalType { typedef float scalar_t; };
+
+//------------------------------------------------------------------------
+// CUDA kernel.
+
+template
+__global__ void bias_act_kernel(bias_act_kernel_params p)
+{
+ typedef typename InternalType::scalar_t scalar_t;
+ int G = p.grad;
+ scalar_t alpha = (scalar_t)p.alpha;
+ scalar_t gain = (scalar_t)p.gain;
+ scalar_t clamp = (scalar_t)p.clamp;
+ scalar_t one = (scalar_t)1;
+ scalar_t two = (scalar_t)2;
+ scalar_t expRange = (scalar_t)80;
+ scalar_t halfExpRange = (scalar_t)40;
+ scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
+ scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
+
+ // Loop over elements.
+ int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
+ for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
+ {
+ // Load.
+ scalar_t x = (scalar_t)((const T*)p.x)[xi];
+ scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
+ scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
+ scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
+ scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
+ scalar_t yy = (gain != 0) ? yref / gain : 0;
+ scalar_t y = 0;
+
+ // Apply bias.
+ ((G == 0) ? x : xref) += b;
+
+ // linear
+ if (A == 1)
+ {
+ if (G == 0) y = x;
+ if (G == 1) y = x;
+ }
+
+ // relu
+ if (A == 2)
+ {
+ if (G == 0) y = (x > 0) ? x : 0;
+ if (G == 1) y = (yy > 0) ? x : 0;
+ }
+
+ // lrelu
+ if (A == 3)
+ {
+ if (G == 0) y = (x > 0) ? x : x * alpha;
+ if (G == 1) y = (yy > 0) ? x : x * alpha;
+ }
+
+ // tanh
+ if (A == 4)
+ {
+ if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
+ if (G == 1) y = x * (one - yy * yy);
+ if (G == 2) y = x * (one - yy * yy) * (-two * yy);
+ }
+
+ // sigmoid
+ if (A == 5)
+ {
+ if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
+ if (G == 1) y = x * yy * (one - yy);
+ if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
+ }
+
+ // elu
+ if (A == 6)
+ {
+ if (G == 0) y = (x >= 0) ? x : exp(x) - one;
+ if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
+ }
+
+ // selu
+ if (A == 7)
+ {
+ if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
+ if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
+ }
+
+ // softplus
+ if (A == 8)
+ {
+ if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
+ if (G == 1) y = x * (one - exp(-yy));
+ if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
+ }
+
+ // swish
+ if (A == 9)
+ {
+ if (G == 0)
+ y = (x < -expRange) ? 0 : x / (exp(-x) + one);
+ else
+ {
+ scalar_t c = exp(xref);
+ scalar_t d = c + one;
+ if (G == 1)
+ y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
+ else
+ y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
+ yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
+ }
+ }
+
+ // Apply gain.
+ y *= gain * dy;
+
+ // Clamp.
+ if (clamp >= 0)
+ {
+ if (G == 0)
+ y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
+ else
+ y = (yref > -clamp & yref < clamp) ? y : 0;
+ }
+
+ // Store.
+ ((T*)p.y)[xi] = (T)y;
+ }
+}
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template void* choose_bias_act_kernel(const bias_act_kernel_params& p)
+{
+ if (p.act == 1) return (void*)bias_act_kernel;
+ if (p.act == 2) return (void*)bias_act_kernel;
+ if (p.act == 3) return (void*)bias_act_kernel;
+ if (p.act == 4) return (void*)bias_act_kernel;
+ if (p.act == 5) return (void*)bias_act_kernel;
+ if (p.act == 6) return (void*)bias_act_kernel;
+ if (p.act == 7) return (void*)bias_act_kernel;
+ if (p.act == 8) return (void*)bias_act_kernel;
+ if (p.act == 9) return (void*)bias_act_kernel;
+ return NULL;
+}
+
+//------------------------------------------------------------------------
+// Template specializations.
+
+template void* choose_bias_act_kernel (const bias_act_kernel_params& p);
+template void* choose_bias_act_kernel (const bias_act_kernel_params& p);
+template void* choose_bias_act_kernel (const bias_act_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/diffusion-insgen/torch_utils/ops/bias_act.h b/diffusion-insgen/torch_utils/ops/bias_act.h
new file mode 100644
index 0000000000000000000000000000000000000000..a32187e1fb7e3bae509d4eceaf900866866875a4
--- /dev/null
+++ b/diffusion-insgen/torch_utils/ops/bias_act.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+//------------------------------------------------------------------------
+// CUDA kernel parameters.
+
+struct bias_act_kernel_params
+{
+ const void* x; // [sizeX]
+ const void* b; // [sizeB] or NULL
+ const void* xref; // [sizeX] or NULL
+ const void* yref; // [sizeX] or NULL
+ const void* dy; // [sizeX] or NULL
+ void* y; // [sizeX]
+
+ int grad;
+ int act;
+ float alpha;
+ float gain;
+ float clamp;
+
+ int sizeX;
+ int sizeB;
+ int stepB;
+ int loopX;
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template void* choose_bias_act_kernel(const bias_act_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/diffusion-insgen/torch_utils/ops/bias_act.py b/diffusion-insgen/torch_utils/ops/bias_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a0363594a7d54203631566552d38c1d1c351947
--- /dev/null
+++ b/diffusion-insgen/torch_utils/ops/bias_act.py
@@ -0,0 +1,205 @@
+
+"""Custom PyTorch ops for efficient bias and activation."""
+
+import os
+import warnings
+import numpy as np
+import torch
+import dnnlib
+import traceback
+
+from .. import custom_ops
+from .. import misc
+
+#----------------------------------------------------------------------------
+
+activation_funcs = {
+ 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
+ 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
+ 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
+ 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
+ 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
+ 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
+ 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
+ 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
+ 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
+}
+
+#----------------------------------------------------------------------------
+
+_inited = False
+_plugin = None
+_null_tensor = torch.empty([0])
+
+def _init():
+ global _inited, _plugin
+ if not _inited:
+ _inited = True
+ sources = ['bias_act.cpp', 'bias_act.cu']
+ sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
+ try:
+ _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
+ except:
+ warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
+ return _plugin is not None
+
+#----------------------------------------------------------------------------
+
+def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
+ r"""Fused bias and activation function.
+
+ Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
+ and scales the result by `gain`. Each of the steps is optional. In most cases,
+ the fused op is considerably more efficient than performing the same calculation
+ using standard PyTorch ops. It supports first and second order gradients,
+ but not third order gradients.
+
+ Args:
+ x: Input activation tensor. Can be of any shape.
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
+ as `x`. The shape must be known, and it must match the dimension of `x`
+ corresponding to `dim`.
+ dim: The dimension in `x` corresponding to the elements of `b`.
+ The value of `dim` is ignored if `b` is not specified.
+ act: Name of the activation function to evaluate, or `"linear"` to disable.
+ Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
+ See `activation_funcs` for a full list. `None` is not allowed.
+ alpha: Shape parameter for the activation function, or `None` to use the default.
+ gain: Scaling factor for the output tensor, or `None` to use default.
+ See `activation_funcs` for the default scaling of each activation function.
+ If unsure, consider specifying 1.
+ clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
+ the clamping (default).
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
+
+ Returns:
+ Tensor of the same shape and datatype as `x`.
+ """
+ assert isinstance(x, torch.Tensor)
+ assert impl in ['ref', 'cuda']
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
+ return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
+ return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
+ """Slow reference implementation of `bias_act()` using standard TensorFlow ops.
+ """
+ assert isinstance(x, torch.Tensor)
+ assert clamp is None or clamp >= 0
+ spec = activation_funcs[act]
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
+ gain = float(gain if gain is not None else spec.def_gain)
+ clamp = float(clamp if clamp is not None else -1)
+
+ # Add bias.
+ if b is not None:
+ assert isinstance(b, torch.Tensor) and b.ndim == 1
+ assert 0 <= dim < x.ndim
+ assert b.shape[0] == x.shape[dim]
+ x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
+
+ # Evaluate activation function.
+ alpha = float(alpha)
+ x = spec.func(x, alpha=alpha)
+
+ # Scale by gain.
+ gain = float(gain)
+ if gain != 1:
+ x = x * gain
+
+ # Clamp.
+ if clamp >= 0:
+ x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
+ return x
+
+#----------------------------------------------------------------------------
+
+_bias_act_cuda_cache = dict()
+
+def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
+ """Fast CUDA implementation of `bias_act()` using custom ops.
+ """
+ # Parse arguments.
+ assert clamp is None or clamp >= 0
+ spec = activation_funcs[act]
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
+ gain = float(gain if gain is not None else spec.def_gain)
+ clamp = float(clamp if clamp is not None else -1)
+
+ # Lookup from cache.
+ key = (dim, act, alpha, gain, clamp)
+ if key in _bias_act_cuda_cache:
+ return _bias_act_cuda_cache[key]
+
+ # Forward op.
+ class BiasActCuda(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, b): # pylint: disable=arguments-differ
+ ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format
+ x = x.contiguous(memory_format=ctx.memory_format)
+ b = b.contiguous() if b is not None else _null_tensor
+ y = x
+ if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
+ y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
+ ctx.save_for_backward(
+ x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
+ b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
+ y if 'y' in spec.ref else _null_tensor)
+ return y
+
+ @staticmethod
+ def backward(ctx, dy): # pylint: disable=arguments-differ
+ dy = dy.contiguous(memory_format=ctx.memory_format)
+ x, b, y = ctx.saved_tensors
+ dx = None
+ db = None
+
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+ dx = dy
+ if act != 'linear' or gain != 1 or clamp >= 0:
+ dx = BiasActCudaGrad.apply(dy, x, b, y)
+
+ if ctx.needs_input_grad[1]:
+ db = dx.sum([i for i in range(dx.ndim) if i != dim])
+
+ return dx, db
+
+ # Backward op.
+ class BiasActCudaGrad(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
+ ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format
+ dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
+ ctx.save_for_backward(
+ dy if spec.has_2nd_grad else _null_tensor,
+ x, b, y)
+ return dx
+
+ @staticmethod
+ def backward(ctx, d_dx): # pylint: disable=arguments-differ
+ d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
+ dy, x, b, y = ctx.saved_tensors
+ d_dy = None
+ d_x = None
+ d_b = None
+ d_y = None
+
+ if ctx.needs_input_grad[0]:
+ d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
+
+ if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
+ d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
+
+ if spec.has_2nd_grad and ctx.needs_input_grad[2]:
+ d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
+
+ return d_dy, d_x, d_b, d_y
+
+ # Add to cache.
+ _bias_act_cuda_cache[key] = BiasActCuda
+ return BiasActCuda
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-insgen/torch_utils/ops/conv2d_gradfix.py b/diffusion-insgen/torch_utils/ops/conv2d_gradfix.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e0a0b28f8bc83936c187b999c1082610dcfbb9f
--- /dev/null
+++ b/diffusion-insgen/torch_utils/ops/conv2d_gradfix.py
@@ -0,0 +1,172 @@
+
+"""Custom replacement for `torch.nn.functional.conv2d` that supports
+arbitrarily high order gradients with zero performance penalty."""
+
+import warnings
+import contextlib
+import torch
+from distutils.version import LooseVersion
+
+# pylint: disable=redefined-builtin
+# pylint: disable=arguments-differ
+# pylint: disable=protected-access
+
+#----------------------------------------------------------------------------
+
+enabled = False # Enable the custom op by setting this to true.
+weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
+old_version = LooseVersion(torch.__version__) < LooseVersion('1.11.0')
+
+@contextlib.contextmanager
+def no_weight_gradients():
+ global weight_gradients_disabled
+ old = weight_gradients_disabled
+ weight_gradients_disabled = True
+ yield
+ weight_gradients_disabled = old
+
+#----------------------------------------------------------------------------
+
+def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
+ if _should_use_custom_op(input):
+ return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
+ return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
+
+def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
+ if _should_use_custom_op(input):
+ return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
+ return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
+
+#----------------------------------------------------------------------------
+
+def _should_use_custom_op(input):
+ assert isinstance(input, torch.Tensor)
+ if (not enabled) or (not torch.backends.cudnn.enabled):
+ return False
+ if input.device.type != 'cuda':
+ return False
+ if LooseVersion(torch.__version__) >= LooseVersion('1.7.0'):
+ return True
+ warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
+ return False
+
+def _tuple_of_ints(xs, ndim):
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
+ assert len(xs) == ndim
+ assert all(isinstance(x, int) for x in xs)
+ return xs
+
+#----------------------------------------------------------------------------
+
+_conv2d_gradfix_cache = dict()
+
+def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
+ # Parse arguments.
+ ndim = 2
+ weight_shape = tuple(weight_shape)
+ stride = _tuple_of_ints(stride, ndim)
+ padding = _tuple_of_ints(padding, ndim)
+ output_padding = _tuple_of_ints(output_padding, ndim)
+ dilation = _tuple_of_ints(dilation, ndim)
+
+ # Lookup from cache.
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
+ if key in _conv2d_gradfix_cache:
+ return _conv2d_gradfix_cache[key]
+
+ # Validate arguments.
+ assert groups >= 1
+ assert len(weight_shape) == ndim + 2
+ assert all(stride[i] >= 1 for i in range(ndim))
+ assert all(padding[i] >= 0 for i in range(ndim))
+ assert all(dilation[i] >= 0 for i in range(ndim))
+ if not transpose:
+ assert all(output_padding[i] == 0 for i in range(ndim))
+ else: # transpose
+ assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
+
+ # Helpers.
+ common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
+ def calc_output_padding(input_shape, output_shape):
+ if transpose:
+ return [0, 0]
+ return [
+ input_shape[i + 2]
+ - (output_shape[i + 2] - 1) * stride[i]
+ - (1 - 2 * padding[i])
+ - dilation[i] * (weight_shape[i + 2] - 1)
+ for i in range(ndim)
+ ]
+
+ # Forward & backward.
+ class Conv2d(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, weight, bias):
+ assert weight.shape == weight_shape
+ if not transpose:
+ output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
+ else: # transpose
+ output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
+ ctx.save_for_backward(input, weight, bias)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, weight, bias = ctx.saved_tensors
+ grad_input = None
+ grad_weight = None
+ grad_bias = None
+
+ if ctx.needs_input_grad[0]:
+ p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
+ grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None)
+ assert grad_input.shape == input.shape
+
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled:
+ grad_weight = Conv2dGradWeight.apply(grad_output, input, bias)
+ assert grad_weight.shape == weight_shape
+
+ if ctx.needs_input_grad[2]:
+ grad_bias = grad_output.sum([0, 2, 3])
+
+ return grad_input, grad_weight, grad_bias
+
+ # Gradient with respect to the weights.
+ class Conv2dGradWeight(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, grad_output, input, bias):
+ if old_version:
+ op = torch._C._jit_get_operation(
+ 'aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
+ flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic,
+ torch.backends.cudnn.allow_tf32]
+ grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
+ else:
+ bias_shape = bias.shape if (bias is not None) else None
+ empty_weight = torch.empty(weight_shape, dtype=input.dtype, layout=input.layout, device=input.device)
+ grad_weight = torch.ops.aten.convolution_backward(grad_output, input, empty_weight, bias_sizes=bias_shape, stride=stride, padding=padding, dilation=dilation, transposed=transpose, output_padding=output_padding, groups=groups, output_mask=[0,1,0])[1]
+ assert grad_weight.shape == weight_shape
+ ctx.save_for_backward(grad_output, input)
+ return grad_weight
+
+ @staticmethod
+ def backward(ctx, grad2_grad_weight):
+ grad_output, input = ctx.saved_tensors
+ grad2_grad_output = None
+ grad2_input = None
+
+ if ctx.needs_input_grad[0]:
+ grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
+ assert grad2_grad_output.shape == grad_output.shape
+
+ if ctx.needs_input_grad[1]:
+ p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
+ grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
+ assert grad2_input.shape == input.shape
+
+ return grad2_grad_output, grad2_input
+
+ _conv2d_gradfix_cache[key] = Conv2d
+ return Conv2d
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-insgen/torch_utils/ops/conv2d_resample.py b/diffusion-insgen/torch_utils/ops/conv2d_resample.py
new file mode 100644
index 0000000000000000000000000000000000000000..10c081d218eee704f6c63b76e5bd0b508ca3f1e9
--- /dev/null
+++ b/diffusion-insgen/torch_utils/ops/conv2d_resample.py
@@ -0,0 +1,149 @@
+
+"""2D convolution with optional up/downsampling."""
+
+import torch
+
+from .. import misc
+from . import conv2d_gradfix
+from . import upfirdn2d
+from .upfirdn2d import _parse_padding
+from .upfirdn2d import _get_filter_size
+
+#----------------------------------------------------------------------------
+
+def _get_weight_shape(w):
+ with misc.suppress_tracer_warnings(): # this value will be treated as a constant
+ shape = [int(sz) for sz in w.shape]
+ misc.assert_shape(w, shape)
+ return shape
+
+#----------------------------------------------------------------------------
+
+def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
+ """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
+ """
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
+
+ # Flip weight if requested.
+ if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
+ w = w.flip([2, 3])
+
+ # Workaround performance pitfall in cuDNN 8.0.5, triggered when using
+ # 1x1 kernel + memory_format=channels_last + less than 64 channels.
+ if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose:
+ if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
+ if out_channels <= 4 and groups == 1:
+ in_shape = x.shape
+ x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1])
+ x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
+ else:
+ x = x.to(memory_format=torch.contiguous_format)
+ w = w.to(memory_format=torch.contiguous_format)
+ x = conv2d_gradfix.conv2d(x, w, groups=groups)
+ return x.to(memory_format=torch.channels_last)
+
+ # Otherwise => execute using conv2d_gradfix.
+ op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
+ return op(x, w, stride=stride, padding=padding, groups=groups)
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
+ r"""2D convolution with optional up/downsampling.
+
+ Padding is performed only once at the beginning, not between the operations.
+
+ Args:
+ x: Input tensor of shape
+ `[batch_size, in_channels, in_height, in_width]`.
+ w: Weight tensor of shape
+ `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
+ f: Low-pass filter for up/downsampling. Must be prepared beforehand by
+ calling upfirdn2d.setup_filter(). None = identity (default).
+ up: Integer upsampling factor (default: 1).
+ down: Integer downsampling factor (default: 1).
+ padding: Padding with respect to the upsampled image. Can be a single number
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+ (default: 0).
+ groups: Split input channels into N groups (default: 1).
+ flip_weight: False = convolution, True = correlation (default: True).
+ flip_filter: False = convolution, True = correlation (default: False).
+
+ Returns:
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+ """
+ # Validate arguments.
+ assert isinstance(x, torch.Tensor) and (x.ndim == 4)
+ assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
+ assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
+ assert isinstance(up, int) and (up >= 1)
+ assert isinstance(down, int) and (down >= 1)
+ assert isinstance(groups, int) and (groups >= 1)
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
+ fw, fh = _get_filter_size(f)
+ px0, px1, py0, py1 = _parse_padding(padding)
+
+ # Adjust padding to account for up/downsampling.
+ if up > 1:
+ px0 += (fw + up - 1) // 2
+ px1 += (fw - up) // 2
+ py0 += (fh + up - 1) // 2
+ py1 += (fh - up) // 2
+ if down > 1:
+ px0 += (fw - down + 1) // 2
+ px1 += (fw - down) // 2
+ py0 += (fh - down + 1) // 2
+ py1 += (fh - down) // 2
+
+ # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
+ if kw == 1 and kh == 1 and (down > 1 and up == 1):
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
+ return x
+
+ # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
+ if kw == 1 and kh == 1 and (up > 1 and down == 1):
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
+ x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
+ return x
+
+ # Fast path: downsampling only => use strided convolution.
+ if down > 1 and up == 1:
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
+ x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
+ return x
+
+ # Fast path: upsampling with optional downsampling => use transpose strided convolution.
+ if up > 1:
+ if groups == 1:
+ w = w.transpose(0, 1)
+ else:
+ w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
+ w = w.transpose(1, 2)
+ w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
+ px0 -= kw - 1
+ px1 -= kw - up
+ py0 -= kh - 1
+ py1 -= kh - up
+ pxt = max(min(-px0, -px1), 0)
+ pyt = max(min(-py0, -py1), 0)
+ x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
+ if down > 1:
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
+ return x
+
+ # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
+ if up == 1 and down == 1:
+ if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
+ return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
+
+ # Fallback: Generic reference implementation.
+ x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
+ if down > 1:
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
+ return x
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-insgen/torch_utils/ops/fma.py b/diffusion-insgen/torch_utils/ops/fma.py
new file mode 100644
index 0000000000000000000000000000000000000000..48ab4e07052d72149fbd10dcd195e8d5cf0a22ca
--- /dev/null
+++ b/diffusion-insgen/torch_utils/ops/fma.py
@@ -0,0 +1,53 @@
+
+"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
+
+import torch
+
+#----------------------------------------------------------------------------
+
+def fma(a, b, c): # => a * b + c
+ return _FusedMultiplyAdd.apply(a, b, c)
+
+#----------------------------------------------------------------------------
+
+class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
+ @staticmethod
+ def forward(ctx, a, b, c): # pylint: disable=arguments-differ
+ out = torch.addcmul(c, a, b)
+ ctx.save_for_backward(a, b)
+ ctx.c_shape = c.shape
+ return out
+
+ @staticmethod
+ def backward(ctx, dout): # pylint: disable=arguments-differ
+ a, b = ctx.saved_tensors
+ c_shape = ctx.c_shape
+ da = None
+ db = None
+ dc = None
+
+ if ctx.needs_input_grad[0]:
+ da = _unbroadcast(dout * b, a.shape)
+
+ if ctx.needs_input_grad[1]:
+ db = _unbroadcast(dout * a, b.shape)
+
+ if ctx.needs_input_grad[2]:
+ dc = _unbroadcast(dout, c_shape)
+
+ return da, db, dc
+
+#----------------------------------------------------------------------------
+
+def _unbroadcast(x, shape):
+ extra_dims = x.ndim - len(shape)
+ assert extra_dims >= 0
+ dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
+ if len(dim):
+ x = x.sum(dim=dim, keepdim=True)
+ if extra_dims:
+ x = x.reshape(-1, *x.shape[extra_dims+1:])
+ assert x.shape == shape
+ return x
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-insgen/torch_utils/ops/grid_sample_gradfix.py b/diffusion-insgen/torch_utils/ops/grid_sample_gradfix.py
new file mode 100644
index 0000000000000000000000000000000000000000..07828ac9fd4e93b446f5ba678626dc815c46a883
--- /dev/null
+++ b/diffusion-insgen/torch_utils/ops/grid_sample_gradfix.py
@@ -0,0 +1,77 @@
+
+"""Custom replacement for `torch.nn.functional.grid_sample` that
+supports arbitrarily high order gradients between the input and output.
+Only works on 2D images and assumes
+`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
+
+import warnings
+import torch
+from distutils.version import LooseVersion
+
+# pylint: disable=redefined-builtin
+# pylint: disable=arguments-differ
+# pylint: disable=protected-access
+
+#----------------------------------------------------------------------------
+
+enabled = False # Enable the custom op by setting this to true.
+
+#----------------------------------------------------------------------------
+
+def grid_sample(input, grid):
+ if _should_use_custom_op():
+ return _GridSample2dForward.apply(input, grid)
+ return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
+
+#----------------------------------------------------------------------------
+
+def _should_use_custom_op():
+ if not enabled:
+ return False
+ if LooseVersion(torch.__version__) >= LooseVersion('1.7.0'):
+ return True
+ warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
+ return False
+
+#----------------------------------------------------------------------------
+
+class _GridSample2dForward(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, grid):
+ assert input.ndim == 4
+ assert grid.ndim == 4
+ output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
+ ctx.save_for_backward(input, grid)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, grid = ctx.saved_tensors
+ grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
+ return grad_input, grad_grid
+
+#----------------------------------------------------------------------------
+
+class _GridSample2dBackward(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, grad_output, input, grid):
+ op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
+ grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
+ ctx.save_for_backward(grid)
+ return grad_input, grad_grid
+
+ @staticmethod
+ def backward(ctx, grad2_grad_input, grad2_grad_grid):
+ _ = grad2_grad_grid # unused
+ grid, = ctx.saved_tensors
+ grad2_grad_output = None
+ grad2_input = None
+ grad2_grid = None
+
+ if ctx.needs_input_grad[0]:
+ grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
+
+ assert not ctx.needs_input_grad[2]
+ return grad2_grad_output, grad2_input, grad2_grid
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-insgen/torch_utils/ops/upfirdn2d.cpp b/diffusion-insgen/torch_utils/ops/upfirdn2d.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..2d7177fc60040751d20e9a8da0301fa3ab64968a
--- /dev/null
+++ b/diffusion-insgen/torch_utils/ops/upfirdn2d.cpp
@@ -0,0 +1,103 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include
+#include
+#include
+#include "upfirdn2d.h"
+
+//------------------------------------------------------------------------
+
+static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
+{
+ // Validate arguments.
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+ TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
+ TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
+ TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
+ TORCH_CHECK(f.dim() == 2, "f must be rank 2");
+ TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
+ TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
+ TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
+
+ // Create output tensor.
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+ int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
+ int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
+ TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
+ torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
+ TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
+
+ // Initialize CUDA kernel parameters.
+ upfirdn2d_kernel_params p;
+ p.x = x.data_ptr();
+ p.f = f.data_ptr();
+ p.y = y.data_ptr();
+ p.up = make_int2(upx, upy);
+ p.down = make_int2(downx, downy);
+ p.pad0 = make_int2(padx0, pady0);
+ p.flip = (flip) ? 1 : 0;
+ p.gain = gain;
+ p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
+ p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
+ p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
+ p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
+ p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
+ p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
+ p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
+ p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
+
+ // Choose CUDA kernel.
+ upfirdn2d_kernel_spec spec;
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
+ {
+ spec = choose_upfirdn2d_kernel(p);
+ });
+
+ // Set looping options.
+ p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
+ p.loopMinor = spec.loopMinor;
+ p.loopX = spec.loopX;
+ p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
+ p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
+
+ // Compute grid size.
+ dim3 blockSize, gridSize;
+ if (spec.tileOutW < 0) // large
+ {
+ blockSize = dim3(4, 32, 1);
+ gridSize = dim3(
+ ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
+ (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
+ p.launchMajor);
+ }
+ else // small
+ {
+ blockSize = dim3(256, 1, 1);
+ gridSize = dim3(
+ ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
+ (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
+ p.launchMajor);
+ }
+
+ // Launch CUDA kernel.
+ void* args[] = {&p};
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
+ return y;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("upfirdn2d", &upfirdn2d);
+}
+
+//------------------------------------------------------------------------
diff --git a/diffusion-insgen/torch_utils/ops/upfirdn2d.cu b/diffusion-insgen/torch_utils/ops/upfirdn2d.cu
new file mode 100644
index 0000000000000000000000000000000000000000..ebdd9879f4bb16fc57a23cbc81f9de8ef54e4916
--- /dev/null
+++ b/diffusion-insgen/torch_utils/ops/upfirdn2d.cu
@@ -0,0 +1,350 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include
+#include "upfirdn2d.h"
+
+//------------------------------------------------------------------------
+// Helpers.
+
+template struct InternalType;
+template <> struct InternalType { typedef double scalar_t; };
+template <> struct InternalType { typedef float scalar_t; };
+template <> struct InternalType { typedef float scalar_t; };
+
+static __device__ __forceinline__ int floor_div(int a, int b)
+{
+ int t = 1 - a / b;
+ return (a + t * b) / b - t;
+}
+
+//------------------------------------------------------------------------
+// Generic CUDA implementation for large filters.
+
+template static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
+{
+ typedef typename InternalType::scalar_t scalar_t;
+
+ // Calculate thread index.
+ int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
+ int outY = minorBase / p.launchMinor;
+ minorBase -= outY * p.launchMinor;
+ int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
+ int majorBase = blockIdx.z * p.loopMajor;
+ if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
+ return;
+
+ // Setup Y receptive field.
+ int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
+ int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
+ int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
+ int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
+ if (p.flip)
+ filterY = p.filterSize.y - 1 - filterY;
+
+ // Loop over major, minor, and X.
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
+ for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
+ {
+ int nc = major * p.sizeMinor + minor;
+ int n = nc / p.inSize.z;
+ int c = nc - n * p.inSize.z;
+ for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
+ {
+ // Setup X receptive field.
+ int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
+ int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
+ int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
+ int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
+ if (p.flip)
+ filterX = p.filterSize.x - 1 - filterX;
+
+ // Initialize pointers.
+ const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
+ const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
+ int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
+ int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
+
+ // Inner loop.
+ scalar_t v = 0;
+ for (int y = 0; y < h; y++)
+ {
+ for (int x = 0; x < w; x++)
+ {
+ v += (scalar_t)(*xp) * (scalar_t)(*fp);
+ xp += p.inStride.x;
+ fp += filterStepX;
+ }
+ xp += p.inStride.y - w * p.inStride.x;
+ fp += filterStepY - w * filterStepX;
+ }
+
+ // Store result.
+ v *= p.gain;
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
+ }
+ }
+}
+
+//------------------------------------------------------------------------
+// Specialized CUDA implementation for small filters.
+
+template
+static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
+{
+ typedef typename InternalType::scalar_t scalar_t;
+ const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
+ const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
+ __shared__ volatile scalar_t sf[filterH][filterW];
+ __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
+
+ // Calculate tile index.
+ int minorBase = blockIdx.x;
+ int tileOutY = minorBase / p.launchMinor;
+ minorBase -= tileOutY * p.launchMinor;
+ minorBase *= loopMinor;
+ tileOutY *= tileOutH;
+ int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
+ int majorBase = blockIdx.z * p.loopMajor;
+ if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
+ return;
+
+ // Load filter (flipped).
+ for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
+ {
+ int fy = tapIdx / filterW;
+ int fx = tapIdx - fy * filterW;
+ scalar_t v = 0;
+ if (fx < p.filterSize.x & fy < p.filterSize.y)
+ {
+ int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
+ int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
+ v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
+ }
+ sf[fy][fx] = v;
+ }
+
+ // Loop over major and X.
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
+ {
+ int baseNC = major * p.sizeMinor + minorBase;
+ int n = baseNC / p.inSize.z;
+ int baseC = baseNC - n * p.inSize.z;
+ for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
+ {
+ // Load input pixels.
+ int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
+ int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
+ int tileInX = floor_div(tileMidX, upx);
+ int tileInY = floor_div(tileMidY, upy);
+ __syncthreads();
+ for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
+ {
+ int relC = inIdx;
+ int relInX = relC / loopMinor;
+ int relInY = relInX / tileInW;
+ relC -= relInX * loopMinor;
+ relInX -= relInY * tileInW;
+ int c = baseC + relC;
+ int inX = tileInX + relInX;
+ int inY = tileInY + relInY;
+ scalar_t v = 0;
+ if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
+ v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
+ sx[relInY][relInX][relC] = v;
+ }
+
+ // Loop over output pixels.
+ __syncthreads();
+ for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
+ {
+ int relC = outIdx;
+ int relOutX = relC / loopMinor;
+ int relOutY = relOutX / tileOutW;
+ relC -= relOutX * loopMinor;
+ relOutX -= relOutY * tileOutW;
+ int c = baseC + relC;
+ int outX = tileOutX + relOutX;
+ int outY = tileOutY + relOutY;
+
+ // Setup receptive field.
+ int midX = tileMidX + relOutX * downx;
+ int midY = tileMidY + relOutY * downy;
+ int inX = floor_div(midX, upx);
+ int inY = floor_div(midY, upy);
+ int relInX = inX - tileInX;
+ int relInY = inY - tileInY;
+ int filterX = (inX + 1) * upx - midX - 1; // flipped
+ int filterY = (inY + 1) * upy - midY - 1; // flipped
+
+ // Inner loop.
+ if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
+ {
+ scalar_t v = 0;
+ #pragma unroll
+ for (int y = 0; y < filterH / upy; y++)
+ #pragma unroll
+ for (int x = 0; x < filterW / upx; x++)
+ v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
+ v *= p.gain;
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
+ }
+ }
+ }
+ }
+}
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
+{
+ int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
+
+ upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large, -1,-1,1, 4}; // contiguous
+ if (s == 1) spec = {(void*)upfirdn2d_kernel_large, -1,-1,4, 1}; // channels_last
+
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
+ {
+ if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ }
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
+ {
+ if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ }
+ if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
+ {
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ }
+ if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
+ {
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ }
+ if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
+ {
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ }
+ if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
+ {
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ }
+ if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
+ {
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ }
+ if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
+ {
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ }
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous
+ {
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1};
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1};
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1};
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1};
+ }
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last
+ {
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1};
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1};
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1};
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1};
+ }
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous
+ {
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1};
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1};
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1};
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1};
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1};
+ }
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last
+ {
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1};
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1};
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1};
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1};
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1};
+ }
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous
+ {
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1};
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1};
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1};
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1};
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1};
+ }
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last
+ {
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1};
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1};
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1};
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1};
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1};
+ }
+ return spec;
+}
+
+//------------------------------------------------------------------------
+// Template specializations.
+
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p);
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p);
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/diffusion-insgen/torch_utils/ops/upfirdn2d.h b/diffusion-insgen/torch_utils/ops/upfirdn2d.h
new file mode 100644
index 0000000000000000000000000000000000000000..c9e2032bcac9d2abde7a75eea4d812da348afadd
--- /dev/null
+++ b/diffusion-insgen/torch_utils/ops/upfirdn2d.h
@@ -0,0 +1,59 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include
+
+//------------------------------------------------------------------------
+// CUDA kernel parameters.
+
+struct upfirdn2d_kernel_params
+{
+ const void* x;
+ const float* f;
+ void* y;
+
+ int2 up;
+ int2 down;
+ int2 pad0;
+ int flip;
+ float gain;
+
+ int4 inSize; // [width, height, channel, batch]
+ int4 inStride;
+ int2 filterSize; // [width, height]
+ int2 filterStride;
+ int4 outSize; // [width, height, channel, batch]
+ int4 outStride;
+ int sizeMinor;
+ int sizeMajor;
+
+ int loopMinor;
+ int loopMajor;
+ int loopX;
+ int launchMinor;
+ int launchMajor;
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel specialization.
+
+struct upfirdn2d_kernel_spec
+{
+ void* kernel;
+ int tileOutW;
+ int tileOutH;
+ int loopMinor;
+ int loopX;
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/diffusion-insgen/torch_utils/ops/upfirdn2d.py b/diffusion-insgen/torch_utils/ops/upfirdn2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..6817174717fe0ef92d19904a68c6174f9ff24297
--- /dev/null
+++ b/diffusion-insgen/torch_utils/ops/upfirdn2d.py
@@ -0,0 +1,377 @@
+
+"""Custom PyTorch ops for efficient resampling of 2D images."""
+
+import os
+import warnings
+import numpy as np
+import torch
+import traceback
+
+from .. import custom_ops
+from .. import misc
+from . import conv2d_gradfix
+
+#----------------------------------------------------------------------------
+
+_inited = False
+_plugin = None
+
+def _init():
+ global _inited, _plugin
+ if not _inited:
+ sources = ['upfirdn2d.cpp', 'upfirdn2d.cu']
+ sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
+ try:
+ _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
+ except:
+ warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
+ return _plugin is not None
+
+def _parse_scaling(scaling):
+ if isinstance(scaling, int):
+ scaling = [scaling, scaling]
+ assert isinstance(scaling, (list, tuple))
+ assert all(isinstance(x, int) for x in scaling)
+ sx, sy = scaling
+ assert sx >= 1 and sy >= 1
+ return sx, sy
+
+def _parse_padding(padding):
+ if isinstance(padding, int):
+ padding = [padding, padding]
+ assert isinstance(padding, (list, tuple))
+ assert all(isinstance(x, int) for x in padding)
+ if len(padding) == 2:
+ padx, pady = padding
+ padding = [padx, padx, pady, pady]
+ padx0, padx1, pady0, pady1 = padding
+ return padx0, padx1, pady0, pady1
+
+def _get_filter_size(f):
+ if f is None:
+ return 1, 1
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
+ fw = f.shape[-1]
+ fh = f.shape[0]
+ with misc.suppress_tracer_warnings():
+ fw = int(fw)
+ fh = int(fh)
+ misc.assert_shape(f, [fh, fw][:f.ndim])
+ assert fw >= 1 and fh >= 1
+ return fw, fh
+
+#----------------------------------------------------------------------------
+
+def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
+ r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
+
+ Args:
+ f: Torch tensor, numpy array, or python list of the shape
+ `[filter_height, filter_width]` (non-separable),
+ `[filter_taps]` (separable),
+ `[]` (impulse), or
+ `None` (identity).
+ device: Result device (default: cpu).
+ normalize: Normalize the filter so that it retains the magnitude
+ for constant input signal (DC)? (default: True).
+ flip_filter: Flip the filter? (default: False).
+ gain: Overall scaling factor for signal magnitude (default: 1).
+ separable: Return a separable filter? (default: select automatically).
+
+ Returns:
+ Float32 tensor of the shape
+ `[filter_height, filter_width]` (non-separable) or
+ `[filter_taps]` (separable).
+ """
+ # Validate.
+ if f is None:
+ f = 1
+ f = torch.as_tensor(f, dtype=torch.float32)
+ assert f.ndim in [0, 1, 2]
+ assert f.numel() > 0
+ if f.ndim == 0:
+ f = f[np.newaxis]
+
+ # Separable?
+ if separable is None:
+ separable = (f.ndim == 1 and f.numel() >= 8)
+ if f.ndim == 1 and not separable:
+ f = f.ger(f)
+ assert f.ndim == (1 if separable else 2)
+
+ # Apply normalize, flip, gain, and device.
+ if normalize:
+ f /= f.sum()
+ if flip_filter:
+ f = f.flip(list(range(f.ndim)))
+ f = f * (gain ** (f.ndim / 2))
+ f = f.to(device=device)
+ return f
+
+#----------------------------------------------------------------------------
+
+def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
+ r"""Pad, upsample, filter, and downsample a batch of 2D images.
+
+ Performs the following sequence of operations for each channel:
+
+ 1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
+
+ 2. Pad the image with the specified number of zeros on each side (`padding`).
+ Negative padding corresponds to cropping the image.
+
+ 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
+ so that the footprint of all output pixels lies within the input image.
+
+ 4. Downsample the image by keeping every Nth pixel (`down`).
+
+ This sequence of operations bears close resemblance to scipy.signal.upfirdn().
+ The fused op is considerably more efficient than performing the same calculation
+ using standard PyTorch ops. It supports gradients of arbitrary order.
+
+ Args:
+ x: Float32/float64/float16 input tensor of the shape
+ `[batch_size, num_channels, in_height, in_width]`.
+ f: Float32 FIR filter of the shape
+ `[filter_height, filter_width]` (non-separable),
+ `[filter_taps]` (separable), or
+ `None` (identity).
+ up: Integer upsampling factor. Can be a single int or a list/tuple
+ `[x, y]` (default: 1).
+ down: Integer downsampling factor. Can be a single int or a list/tuple
+ `[x, y]` (default: 1).
+ padding: Padding with respect to the upsampled image. Can be a single number
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+ (default: 0).
+ flip_filter: False = convolution, True = correlation (default: False).
+ gain: Overall scaling factor for signal magnitude (default: 1).
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
+
+ Returns:
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+ """
+ assert isinstance(x, torch.Tensor)
+ assert impl in ['ref', 'cuda']
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
+ return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
+ return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
+ """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
+ """
+ # Validate arguments.
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
+ if f is None:
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
+ assert f.dtype == torch.float32 and not f.requires_grad
+ batch_size, num_channels, in_height, in_width = x.shape
+ upx, upy = _parse_scaling(up)
+ downx, downy = _parse_scaling(down)
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
+
+ # Upsample by inserting zeros.
+ x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
+ x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
+ x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
+
+ # Pad or crop.
+ x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
+ x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
+
+ # Setup filter.
+ f = f * (gain ** (f.ndim / 2))
+ f = f.to(x.dtype)
+ if not flip_filter:
+ f = f.flip(list(range(f.ndim)))
+
+ # Convolve with the filter.
+ f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
+ if f.ndim == 4:
+ x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
+ else:
+ x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
+ x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
+
+ # Downsample by throwing away pixels.
+ x = x[:, :, ::downy, ::downx]
+ return x
+
+#----------------------------------------------------------------------------
+
+_upfirdn2d_cuda_cache = dict()
+
+def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
+ """Fast CUDA implementation of `upfirdn2d()` using custom ops.
+ """
+ # Parse arguments.
+ upx, upy = _parse_scaling(up)
+ downx, downy = _parse_scaling(down)
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
+
+ # Lookup from cache.
+ key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
+ if key in _upfirdn2d_cuda_cache:
+ return _upfirdn2d_cuda_cache[key]
+
+ # Forward op.
+ class Upfirdn2dCuda(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, f): # pylint: disable=arguments-differ
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
+ if f is None:
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
+ y = x
+ if f.ndim == 2:
+ y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
+ else:
+ y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain))
+ y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain))
+ ctx.save_for_backward(f)
+ ctx.x_shape = x.shape
+ return y
+
+ @staticmethod
+ def backward(ctx, dy): # pylint: disable=arguments-differ
+ f, = ctx.saved_tensors
+ _, _, ih, iw = ctx.x_shape
+ _, _, oh, ow = dy.shape
+ fw, fh = _get_filter_size(f)
+ p = [
+ fw - padx0 - 1,
+ iw * upx - ow * downx + padx0 - upx + 1,
+ fh - pady0 - 1,
+ ih * upy - oh * downy + pady0 - upy + 1,
+ ]
+ dx = None
+ df = None
+
+ if ctx.needs_input_grad[0]:
+ dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)
+
+ assert not ctx.needs_input_grad[1]
+ return dx, df
+
+ # Add to cache.
+ _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
+ return Upfirdn2dCuda
+
+#----------------------------------------------------------------------------
+
+def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
+ r"""Filter a batch of 2D images using the given 2D FIR filter.
+
+ By default, the result is padded so that its shape matches the input.
+ User-specified padding is applied on top of that, with negative values
+ indicating cropping. Pixels outside the image are assumed to be zero.
+
+ Args:
+ x: Float32/float64/float16 input tensor of the shape
+ `[batch_size, num_channels, in_height, in_width]`.
+ f: Float32 FIR filter of the shape
+ `[filter_height, filter_width]` (non-separable),
+ `[filter_taps]` (separable), or
+ `None` (identity).
+ padding: Padding with respect to the output. Can be a single number or a
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+ (default: 0).
+ flip_filter: False = convolution, True = correlation (default: False).
+ gain: Overall scaling factor for signal magnitude (default: 1).
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
+
+ Returns:
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+ """
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
+ fw, fh = _get_filter_size(f)
+ p = [
+ padx0 + fw // 2,
+ padx1 + (fw - 1) // 2,
+ pady0 + fh // 2,
+ pady1 + (fh - 1) // 2,
+ ]
+ return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
+
+#----------------------------------------------------------------------------
+
+def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
+ r"""Upsample a batch of 2D images using the given 2D FIR filter.
+
+ By default, the result is padded so that its shape is a multiple of the input.
+ User-specified padding is applied on top of that, with negative values
+ indicating cropping. Pixels outside the image are assumed to be zero.
+
+ Args:
+ x: Float32/float64/float16 input tensor of the shape
+ `[batch_size, num_channels, in_height, in_width]`.
+ f: Float32 FIR filter of the shape
+ `[filter_height, filter_width]` (non-separable),
+ `[filter_taps]` (separable), or
+ `None` (identity).
+ up: Integer upsampling factor. Can be a single int or a list/tuple
+ `[x, y]` (default: 1).
+ padding: Padding with respect to the output. Can be a single number or a
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+ (default: 0).
+ flip_filter: False = convolution, True = correlation (default: False).
+ gain: Overall scaling factor for signal magnitude (default: 1).
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
+
+ Returns:
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+ """
+ upx, upy = _parse_scaling(up)
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
+ fw, fh = _get_filter_size(f)
+ p = [
+ padx0 + (fw + upx - 1) // 2,
+ padx1 + (fw - upx) // 2,
+ pady0 + (fh + upy - 1) // 2,
+ pady1 + (fh - upy) // 2,
+ ]
+ return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
+
+#----------------------------------------------------------------------------
+
+def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
+ r"""Downsample a batch of 2D images using the given 2D FIR filter.
+
+ By default, the result is padded so that its shape is a fraction of the input.
+ User-specified padding is applied on top of that, with negative values
+ indicating cropping. Pixels outside the image are assumed to be zero.
+
+ Args:
+ x: Float32/float64/float16 input tensor of the shape
+ `[batch_size, num_channels, in_height, in_width]`.
+ f: Float32 FIR filter of the shape
+ `[filter_height, filter_width]` (non-separable),
+ `[filter_taps]` (separable), or
+ `None` (identity).
+ down: Integer downsampling factor. Can be a single int or a list/tuple
+ `[x, y]` (default: 1).
+ padding: Padding with respect to the input. Can be a single number or a
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+ (default: 0).
+ flip_filter: False = convolution, True = correlation (default: False).
+ gain: Overall scaling factor for signal magnitude (default: 1).
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
+
+ Returns:
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+ """
+ downx, downy = _parse_scaling(down)
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
+ fw, fh = _get_filter_size(f)
+ p = [
+ padx0 + (fw - downx + 1) // 2,
+ padx1 + (fw - downx) // 2,
+ pady0 + (fh - downy + 1) // 2,
+ pady1 + (fh - downy) // 2,
+ ]
+ return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-insgen/torch_utils/persistence.py b/diffusion-insgen/torch_utils/persistence.py
new file mode 100644
index 0000000000000000000000000000000000000000..08921dddba38d2a7e6726d065ca15ec606026790
--- /dev/null
+++ b/diffusion-insgen/torch_utils/persistence.py
@@ -0,0 +1,244 @@
+
+"""Facilities for pickling Python code alongside other data.
+
+The pickled code is automatically imported into a separate Python module
+during unpickling. This way, any previously exported pickles will remain
+usable even if the original code is no longer available, or if the current
+version of the code is not consistent with what was originally pickled."""
+
+import sys
+import pickle
+import io
+import inspect
+import copy
+import uuid
+import types
+import dnnlib
+
+#----------------------------------------------------------------------------
+
+_version = 6 # internal version number
+_decorators = set() # {decorator_class, ...}
+_import_hooks = [] # [hook_function, ...]
+_module_to_src_dict = dict() # {module: src, ...}
+_src_to_module_dict = dict() # {src: module, ...}
+
+#----------------------------------------------------------------------------
+
+def persistent_class(orig_class):
+ r"""Class decorator that extends a given class to save its source code
+ when pickled.
+
+ Example:
+
+ from torch_utils import persistence
+
+ @persistence.persistent_class
+ class MyNetwork(torch.nn.Module):
+ def __init__(self, num_inputs, num_outputs):
+ super().__init__()
+ self.fc = MyLayer(num_inputs, num_outputs)
+ ...
+
+ @persistence.persistent_class
+ class MyLayer(torch.nn.Module):
+ ...
+
+ When pickled, any instance of `MyNetwork` and `MyLayer` will save its
+ source code alongside other internal state (e.g., parameters, buffers,
+ and submodules). This way, any previously exported pickle will remain
+ usable even if the class definitions have been modified or are no
+ longer available.
+
+ The decorator saves the source code of the entire Python module
+ containing the decorated class. It does *not* save the source code of
+ any imported modules. Thus, the imported modules must be available
+ during unpickling, also including `torch_utils.persistence` itself.
+
+ It is ok to call functions defined in the same module from the
+ decorated class. However, if the decorated class depends on other
+ classes defined in the same module, they must be decorated as well.
+ This is illustrated in the above example in the case of `MyLayer`.
+
+ It is also possible to employ the decorator just-in-time before
+ calling the constructor. For example:
+
+ cls = MyLayer
+ if want_to_make_it_persistent:
+ cls = persistence.persistent_class(cls)
+ layer = cls(num_inputs, num_outputs)
+
+ As an additional feature, the decorator also keeps track of the
+ arguments that were used to construct each instance of the decorated
+ class. The arguments can be queried via `obj.init_args` and
+ `obj.init_kwargs`, and they are automatically pickled alongside other
+ object state. A typical use case is to first unpickle a previous
+ instance of a persistent class, and then upgrade it to use the latest
+ version of the source code:
+
+ with open('old_pickle.pkl', 'rb') as f:
+ old_net = pickle.load(f)
+ new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs)
+ misc.copy_params_and_buffers(old_net, new_net, require_all=True)
+ """
+ assert isinstance(orig_class, type)
+ if is_persistent(orig_class):
+ return orig_class
+
+ assert orig_class.__module__ in sys.modules
+ orig_module = sys.modules[orig_class.__module__]
+ orig_module_src = _module_to_src(orig_module)
+
+ class Decorator(orig_class):
+ _orig_module_src = orig_module_src
+ _orig_class_name = orig_class.__name__
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._init_args = copy.deepcopy(args)
+ self._init_kwargs = copy.deepcopy(kwargs)
+ assert orig_class.__name__ in orig_module.__dict__
+ _check_pickleable(self.__reduce__())
+
+ @property
+ def init_args(self):
+ return copy.deepcopy(self._init_args)
+
+ @property
+ def init_kwargs(self):
+ return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
+
+ def __reduce__(self):
+ fields = list(super().__reduce__())
+ fields += [None] * max(3 - len(fields), 0)
+ if fields[0] is not _reconstruct_persistent_obj:
+ meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
+ fields[0] = _reconstruct_persistent_obj # reconstruct func
+ fields[1] = (meta,) # reconstruct args
+ fields[2] = None # state dict
+ return tuple(fields)
+
+ Decorator.__name__ = orig_class.__name__
+ _decorators.add(Decorator)
+ return Decorator
+
+#----------------------------------------------------------------------------
+
+def is_persistent(obj):
+ r"""Test whether the given object or class is persistent, i.e.,
+ whether it will save its source code when pickled.
+ """
+ try:
+ if obj in _decorators:
+ return True
+ except TypeError:
+ pass
+ return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
+
+#----------------------------------------------------------------------------
+
+def import_hook(hook):
+ r"""Register an import hook that is called whenever a persistent object
+ is being unpickled. A typical use case is to patch the pickled source
+ code to avoid errors and inconsistencies when the API of some imported
+ module has changed.
+
+ The hook should have the following signature:
+
+ hook(meta) -> modified meta
+
+ `meta` is an instance of `dnnlib.EasyDict` with the following fields:
+
+ type: Type of the persistent object, e.g. `'class'`.
+ version: Internal version number of `torch_utils.persistence`.
+ module_src Original source code of the Python module.
+ class_name: Class name in the original Python module.
+ state: Internal state of the object.
+
+ Example:
+
+ @persistence.import_hook
+ def wreck_my_network(meta):
+ if meta.class_name == 'MyNetwork':
+ print('MyNetwork is being imported. I will wreck it!')
+ meta.module_src = meta.module_src.replace("True", "False")
+ return meta
+ """
+ assert callable(hook)
+ _import_hooks.append(hook)
+
+#----------------------------------------------------------------------------
+
+def _reconstruct_persistent_obj(meta):
+ r"""Hook that is called internally by the `pickle` module to unpickle
+ a persistent object.
+ """
+ meta = dnnlib.EasyDict(meta)
+ meta.state = dnnlib.EasyDict(meta.state)
+ for hook in _import_hooks:
+ meta = hook(meta)
+ assert meta is not None
+
+ assert meta.version == _version
+ module = _src_to_module(meta.module_src)
+
+ assert meta.type == 'class'
+ orig_class = module.__dict__[meta.class_name]
+ decorator_class = persistent_class(orig_class)
+ obj = decorator_class.__new__(decorator_class)
+
+ setstate = getattr(obj, '__setstate__', None)
+ if callable(setstate):
+ setstate(meta.state) # pylint: disable=not-callable
+ else:
+ obj.__dict__.update(meta.state)
+ return obj
+
+#----------------------------------------------------------------------------
+
+def _module_to_src(module):
+ r"""Query the source code of a given Python module.
+ """
+ src = _module_to_src_dict.get(module, None)
+ if src is None:
+ src = inspect.getsource(module)
+ _module_to_src_dict[module] = src
+ _src_to_module_dict[src] = module
+ return src
+
+def _src_to_module(src):
+ r"""Get or create a Python module for the given source code.
+ """
+ module = _src_to_module_dict.get(src, None)
+ if module is None:
+ module_name = "_imported_module_" + uuid.uuid4().hex
+ module = types.ModuleType(module_name)
+ sys.modules[module_name] = module
+ _module_to_src_dict[module] = src
+ _src_to_module_dict[src] = module
+ exec(src, module.__dict__) # pylint: disable=exec-used
+ return module
+
+#----------------------------------------------------------------------------
+
+def _check_pickleable(obj):
+ r"""Check that the given object is pickleable, raising an exception if
+ it is not. This function is expected to be considerably more efficient
+ than actually pickling the object.
+ """
+ def recurse(obj):
+ if isinstance(obj, (list, tuple, set)):
+ return [recurse(x) for x in obj]
+ if isinstance(obj, dict):
+ return [[recurse(x), recurse(y)] for x, y in obj.items()]
+ if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
+ return None # Python primitive types are pickleable.
+ if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']:
+ return None # NumPy arrays and PyTorch tensors are pickleable.
+ if is_persistent(obj):
+ return None # Persistent objects are pickleable, by virtue of the constructor check.
+ return obj
+ with io.BytesIO() as f:
+ pickle.dump(recurse(obj), f)
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-insgen/torch_utils/training_stats.py b/diffusion-insgen/torch_utils/training_stats.py
new file mode 100644
index 0000000000000000000000000000000000000000..feb4e54d1c628801ec6588596e247ced25b38442
--- /dev/null
+++ b/diffusion-insgen/torch_utils/training_stats.py
@@ -0,0 +1,261 @@
+
+"""Facilities for reporting and collecting training statistics across
+multiple processes and devices. The interface is designed to minimize
+synchronization overhead as well as the amount of boilerplate in user
+code."""
+
+import re
+import numpy as np
+import torch
+import dnnlib
+
+from . import misc
+
+#----------------------------------------------------------------------------
+
+_num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
+_reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
+_counter_dtype = torch.float64 # Data type to use for the internal counters.
+_rank = 0 # Rank of the current process.
+_sync_device = None # Device to use for multiprocess communication. None = single-process.
+_sync_called = False # Has _sync() been called yet?
+_counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
+_cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
+
+#----------------------------------------------------------------------------
+
+def init_multiprocessing(rank, sync_device):
+ r"""Initializes `torch_utils.training_stats` for collecting statistics
+ across multiple processes.
+
+ This function must be called after
+ `torch.distributed.init_process_group()` and before `Collector.update()`.
+ The call is not necessary if multi-process collection is not needed.
+
+ Args:
+ rank: Rank of the current process.
+ sync_device: PyTorch device to use for inter-process
+ communication, or None to disable multi-process
+ collection. Typically `torch.device('cuda', rank)`.
+ """
+ global _rank, _sync_device
+ assert not _sync_called
+ _rank = rank
+ _sync_device = sync_device
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def report(name, value):
+ r"""Broadcasts the given set of scalars to all interested instances of
+ `Collector`, across device and process boundaries.
+
+ This function is expected to be extremely cheap and can be safely
+ called from anywhere in the training loop, loss function, or inside a
+ `torch.nn.Module`.
+
+ Warning: The current implementation expects the set of unique names to
+ be consistent across processes. Please make sure that `report()` is
+ called at least once for each unique name by each process, and in the
+ same order. If a given process has no scalars to broadcast, it can do
+ `report(name, [])` (empty list).
+
+ Args:
+ name: Arbitrary string specifying the name of the statistic.
+ Averages are accumulated separately for each unique name.
+ value: Arbitrary set of scalars. Can be a list, tuple,
+ NumPy array, PyTorch tensor, or Python scalar.
+
+ Returns:
+ The same `value` that was passed in.
+ """
+ if name not in _counters:
+ _counters[name] = dict()
+
+ elems = torch.as_tensor(value)
+ if elems.numel() == 0:
+ return value
+
+ elems = elems.detach().flatten().to(_reduce_dtype)
+ moments = torch.stack([
+ torch.ones_like(elems).sum(),
+ elems.sum(),
+ elems.square().sum(),
+ ])
+ assert moments.ndim == 1 and moments.shape[0] == _num_moments
+ moments = moments.to(_counter_dtype)
+
+ device = moments.device
+ if device not in _counters[name]:
+ _counters[name][device] = torch.zeros_like(moments)
+ _counters[name][device].add_(moments)
+ return value
+
+#----------------------------------------------------------------------------
+
+def report0(name, value):
+ r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
+ but ignores any scalars provided by the other processes.
+ See `report()` for further details.
+ """
+ report(name, value if _rank == 0 else [])
+ return value
+
+#----------------------------------------------------------------------------
+
+class Collector:
+ r"""Collects the scalars broadcasted by `report()` and `report0()` and
+ computes their long-term averages (mean and standard deviation) over
+ user-defined periods of time.
+
+ The averages are first collected into internal counters that are not
+ directly visible to the user. They are then copied to the user-visible
+ state as a result of calling `update()` and can then be queried using
+ `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
+ internal counters for the next round, so that the user-visible state
+ effectively reflects averages collected between the last two calls to
+ `update()`.
+
+ Args:
+ regex: Regular expression defining which statistics to
+ collect. The default is to collect everything.
+ keep_previous: Whether to retain the previous averages if no
+ scalars were collected on a given round
+ (default: True).
+ """
+ def __init__(self, regex='.*', keep_previous=True):
+ self._regex = re.compile(regex)
+ self._keep_previous = keep_previous
+ self._cumulative = dict()
+ self._moments = dict()
+ self.update()
+ self._moments.clear()
+
+ def names(self):
+ r"""Returns the names of all statistics broadcasted so far that
+ match the regular expression specified at construction time.
+ """
+ return [name for name in _counters if self._regex.fullmatch(name)]
+
+ def update(self):
+ r"""Copies current values of the internal counters to the
+ user-visible state and resets them for the next round.
+
+ If `keep_previous=True` was specified at construction time, the
+ operation is skipped for statistics that have received no scalars
+ since the last update, retaining their previous averages.
+
+ This method performs a number of GPU-to-CPU transfers and one
+ `torch.distributed.all_reduce()`. It is intended to be called
+ periodically in the main training loop, typically once every
+ N training steps.
+ """
+ if not self._keep_previous:
+ self._moments.clear()
+ for name, cumulative in _sync(self.names()):
+ if name not in self._cumulative:
+ self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
+ delta = cumulative - self._cumulative[name]
+ self._cumulative[name].copy_(cumulative)
+ if float(delta[0]) != 0:
+ self._moments[name] = delta
+
+ def _get_delta(self, name):
+ r"""Returns the raw moments that were accumulated for the given
+ statistic between the last two calls to `update()`, or zero if
+ no scalars were collected.
+ """
+ assert self._regex.fullmatch(name)
+ if name not in self._moments:
+ self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
+ return self._moments[name]
+
+ def num(self, name):
+ r"""Returns the number of scalars that were accumulated for the given
+ statistic between the last two calls to `update()`, or zero if
+ no scalars were collected.
+ """
+ delta = self._get_delta(name)
+ return int(delta[0])
+
+ def mean(self, name):
+ r"""Returns the mean of the scalars that were accumulated for the
+ given statistic between the last two calls to `update()`, or NaN if
+ no scalars were collected.
+ """
+ delta = self._get_delta(name)
+ if int(delta[0]) == 0:
+ return float('nan')
+ return float(delta[1] / delta[0])
+
+ def std(self, name):
+ r"""Returns the standard deviation of the scalars that were
+ accumulated for the given statistic between the last two calls to
+ `update()`, or NaN if no scalars were collected.
+ """
+ delta = self._get_delta(name)
+ if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
+ return float('nan')
+ if int(delta[0]) == 1:
+ return float(0)
+ mean = float(delta[1] / delta[0])
+ raw_var = float(delta[2] / delta[0])
+ return np.sqrt(max(raw_var - np.square(mean), 0))
+
+ def as_dict(self):
+ r"""Returns the averages accumulated between the last two calls to
+ `update()` as an `dnnlib.EasyDict`. The contents are as follows:
+
+ dnnlib.EasyDict(
+ NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
+ ...
+ )
+ """
+ stats = dnnlib.EasyDict()
+ for name in self.names():
+ stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
+ return stats
+
+ def __getitem__(self, name):
+ r"""Convenience getter.
+ `collector[name]` is a synonym for `collector.mean(name)`.
+ """
+ return self.mean(name)
+
+#----------------------------------------------------------------------------
+
+def _sync(names):
+ r"""Synchronize the global cumulative counters across devices and
+ processes. Called internally by `Collector.update()`.
+ """
+ if len(names) == 0:
+ return []
+ global _sync_called
+ _sync_called = True
+
+ # Collect deltas within current rank.
+ deltas = []
+ device = _sync_device if _sync_device is not None else torch.device('cpu')
+ for name in names:
+ delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
+ for counter in _counters[name].values():
+ delta.add_(counter.to(device))
+ counter.copy_(torch.zeros_like(counter))
+ deltas.append(delta)
+ deltas = torch.stack(deltas)
+
+ # Sum deltas across ranks.
+ if _sync_device is not None:
+ torch.distributed.all_reduce(deltas)
+
+ # Update cumulative values.
+ deltas = deltas.cpu()
+ for idx, name in enumerate(names):
+ if name not in _cumulative:
+ _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
+ _cumulative[name].add_(deltas[idx])
+
+ # Return name-value pairs.
+ return [(name, _cumulative[name]) for name in names]
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-insgen/train.py b/diffusion-insgen/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..d314761a9825a0ed6bbc89915292a7e44ddb8cb2
--- /dev/null
+++ b/diffusion-insgen/train.py
@@ -0,0 +1,605 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Train a GAN using the techniques described in the paper
+"Training Generative Adversarial Networks with Limited Data"."""
+
+import os
+import click
+import re
+import json
+import tempfile
+import torch
+import dnnlib
+
+from training import training_loop
+from metrics import metric_main
+from torch_utils import training_stats
+from torch_utils import custom_ops
+
+#----------------------------------------------------------------------------
+
+class UserError(Exception):
+ pass
+
+#----------------------------------------------------------------------------
+
+def setup_training_loop_kwargs(
+ # General options (not included in desc).
+ gpus = None, # Number of GPUs: , default = 1 gpu
+ snap = None, # Snapshot interval: , default = 50 ticks
+ metrics = None, # List of metric names: [], ['fid50k_full'] (default), ...
+ seed = None, # Random seed: , default = 0
+
+ # Dataset.
+ data = None, # Training dataset (required):
+ cond = None, # Train conditional model based on dataset labels: , default = False
+ subset = None, # Train with only N images: , default = all
+ mirror = None, # Augment dataset with x-flips: , default = False
+
+ # Base config.
+ cfg = None, # Base config: 'auto' (default), 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar'
+ gamma = None, # Override R1 gamma:
+ kimg = None, # Override training duration:
+ batch = None, # Override batch size:
+
+ # Discriminator augmentation.
+ aug = None, # Augmentation mode: 'ada' (default), 'noaug', 'fixed'
+ p = None, # Specify p for 'fixed' (required):
+ target = None, # Override ADA target for 'ada': , default = depends on aug
+
+ # Transfer learning.
+ resume = None, # Load previous network: 'noresume' (default), 'ffhq256', 'ffhq512', 'ffhq1024', 'celebahq256', 'lsundog256', ,
+ freezed = None, # Freeze-D: , default = 0 discriminator layers
+
+ # Performance options (not included in desc).
+ fp32 = None, # Disable mixed-precision training: , default = False
+ nhwc = None, # Use NHWC memory format with FP16: , default = False
+ allow_tf32 = None, # Allow PyTorch to use TF32 for matmul and convolutions: , default = False
+ nobench = None, # Disable cuDNN benchmarking: , default = False
+ workers = None, # Override number of DataLoader workers: , default = 3
+ # InsGen related options
+ no_insgen = False, # Disable insgen for training: , default = False
+ rqs = None, # Size of real image queue: , default = 5% * len(dataset)
+ fqs = None, # Size of fake image queue: , default = 5% * len(dataset)
+ no_cl_on_g = False, # Disable fake instance discrimination for generator: , default = False
+ ada_linear = False, # Whether to linearly increase the strength of ADA: , default = False
+
+ # Added
+ exp = None,
+ daug = 'ADA',
+
+ # Adaptive Diffusion config.
+ beta_schedule = None,
+ beta_start = None,
+ beta_end = None,
+ t_min = None,
+ t_max = None,
+ noise_sd = None,
+ ts_dist = None,
+ ada_maxp = None,
+):
+ args = dnnlib.EasyDict()
+
+ # ------------------------------------------
+ # General options: gpus, snap, metrics, seed
+ # ------------------------------------------
+
+ if gpus is None:
+ gpus = 1
+ assert isinstance(gpus, int)
+ if not (gpus >= 1 and gpus & (gpus - 1) == 0):
+ raise UserError('--gpus must be a power of two')
+ args.num_gpus = gpus
+
+ if snap is None:
+ snap = 50
+ assert isinstance(snap, int)
+ if snap < 1:
+ raise UserError('--snap must be at least 1')
+ args.image_snapshot_ticks = snap
+ args.network_snapshot_ticks = snap
+
+ if metrics is None:
+ metrics = ['fid50k_full']
+ assert isinstance(metrics, list)
+ if not all(metric_main.is_valid_metric(metric) for metric in metrics):
+ raise UserError('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
+ args.metrics = metrics
+
+ if seed is None:
+ seed = 0
+ assert isinstance(seed, int)
+ args.random_seed = seed
+
+ # -----------------------------------
+ # Dataset: data, cond, subset, mirror
+ # -----------------------------------
+
+ assert data is not None
+ assert isinstance(data, str)
+ args.training_set_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data, use_labels=True, max_size=None, xflip=False)
+ args.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=3, prefetch_factor=2)
+ try:
+ training_set = dnnlib.util.construct_class_by_name(**args.training_set_kwargs) # subclass of training.dataset.Dataset
+ args.training_set_kwargs.resolution = training_set.resolution # be explicit about resolution
+ args.training_set_kwargs.use_labels = training_set.has_labels # be explicit about labels
+ args.training_set_kwargs.max_size = len(training_set) # be explicit about dataset size
+ desc = training_set.name
+ del training_set # conserve memory
+ except IOError as err:
+ raise UserError(f'--data: {err}')
+
+ if exp is not None:
+ desc += f'-{exp}'
+
+ if cond is None:
+ cond = False
+ assert isinstance(cond, bool)
+ if cond:
+ if not args.training_set_kwargs.use_labels:
+ raise UserError('--cond=True requires labels specified in dataset.json')
+ desc += '-cond'
+ else:
+ args.training_set_kwargs.use_labels = False
+
+ if subset is not None:
+ assert isinstance(subset, int)
+ if not 1 <= subset <= args.training_set_kwargs.max_size:
+ raise UserError(f'--subset must be between 1 and {args.training_set_kwargs.max_size}')
+ desc += f'-subset{subset}'
+ if subset < args.training_set_kwargs.max_size:
+ args.training_set_kwargs.max_size = subset
+ args.training_set_kwargs.random_seed = args.random_seed
+
+ if mirror is None:
+ mirror = False
+ assert isinstance(mirror, bool)
+ if mirror:
+ desc += '-mirror'
+ args.training_set_kwargs.xflip = True
+
+ # ------------------------------------
+ # Base config: cfg, gamma, kimg, batch
+ # ------------------------------------
+
+ if cfg is None:
+ cfg = 'auto'
+ assert isinstance(cfg, str)
+ desc += f'-{cfg}'
+
+ cfg_specs = {
+ 'auto': dict(ref_gpus=-1, kimg=25000, mb=-1, mbstd=-1, fmaps=-1, lrate=-1, gamma=-1, ema=-1, ramp=0.05, map=2), # Populated dynamically based on resolution and GPU count.
+ 'stylegan2': dict(ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8), # Uses mixed-precision, unlike the original StyleGAN2.
+ 'paper256': dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=0.5, lrate=0.0025, gamma=1, ema=20, ramp=None, map=8),
+ 'paper512': dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=1, lrate=0.0025, gamma=0.5, ema=20, ramp=None, map=8),
+ 'paper1024': dict(ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1, lrate=0.002, gamma=2, ema=10, ramp=None, map=8),
+ 'cifar': dict(ref_gpus=4, kimg=100000, mb=64, mbstd=32, fmaps=1, lrate=0.0025, gamma=0.01, ema=500, ramp=0.05, map=2),
+ }
+
+ assert cfg in cfg_specs
+ spec = dnnlib.EasyDict(cfg_specs[cfg])
+ if cfg == 'auto':
+ desc += f'{gpus:d}'
+ spec.ref_gpus = gpus
+ res = args.training_set_kwargs.resolution
+ spec.mb = max(min(gpus * min(4096 // res, 32), 64), gpus) # keep gpu memory consumption at bay
+ spec.mbstd = min(spec.mb // gpus, 4) # other hyperparams behave more predictably if mbstd group size remains fixed
+ spec.fmaps = 1 if res >= 512 else 0.5
+ spec.lrate = 0.002 if res >= 1024 else 0.0025
+ spec.gamma = 0.0002 * (res ** 2) / spec.mb # heuristic formula
+ spec.ema = spec.mb * 10 / 32
+
+ args.G_kwargs = dnnlib.EasyDict(class_name='training.networks.Generator', z_dim=512, w_dim=512, mapping_kwargs=dnnlib.EasyDict(), synthesis_kwargs=dnnlib.EasyDict())
+ args.D_kwargs = dnnlib.EasyDict(class_name='training.networks.Discriminator', block_kwargs=dnnlib.EasyDict(), mapping_kwargs=dnnlib.EasyDict(), epilogue_kwargs=dnnlib.EasyDict())
+ args.G_kwargs.synthesis_kwargs.channel_base = args.D_kwargs.channel_base = int(spec.fmaps * 32768)
+ args.G_kwargs.synthesis_kwargs.channel_max = args.D_kwargs.channel_max = 512
+ args.G_kwargs.mapping_kwargs.num_layers = spec.map
+ args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 4 # enable mixed-precision training
+ args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = 256 # clamp activations to avoid float16 overflow
+ args.D_kwargs.epilogue_kwargs.mbstd_group_size = spec.mbstd
+ args.D_kwargs.mapping_kwargs.num_layers = 0 # align with tensorflow implementation of ADA
+
+ args.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=spec.lrate, betas=[0,0.99], eps=1e-8)
+ args.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=spec.lrate, betas=[0,0.99], eps=1e-8)
+ args.loss_kwargs = dnnlib.EasyDict(class_name='training.loss.StyleGAN2Loss', r1_gamma=spec.gamma)
+
+ args.total_kimg = spec.kimg
+ args.batch_size = spec.mb
+ args.batch_gpu = spec.mb // spec.ref_gpus
+ args.ema_kimg = spec.ema
+ args.ema_rampup = spec.ramp
+
+ if cfg == 'cifar':
+ args.loss_kwargs.pl_weight = 0 # disable path length regularization
+ args.loss_kwargs.style_mixing_prob = 0 # disable style mixing
+ args.D_kwargs.architecture = 'orig' # disable residual skip connections
+
+ if gamma is not None:
+ assert isinstance(gamma, float)
+ if not gamma >= 0:
+ raise UserError('--gamma must be non-negative')
+ desc += f'-gamma{gamma:g}'
+ args.loss_kwargs.r1_gamma = gamma
+
+ if kimg is not None:
+ assert isinstance(kimg, int)
+ if not kimg >= 1:
+ raise UserError('--kimg must be at least 1')
+ desc += f'-kimg{kimg:d}'
+ args.total_kimg = kimg
+
+ if batch is not None:
+ assert isinstance(batch, int)
+ if not (batch >= 1 and batch % gpus == 0):
+ raise UserError('--batch must be at least 1 and divisible by --gpus')
+ desc += f'-batch{batch}'
+ args.batch_size = batch
+ args.batch_gpu = batch // gpus
+
+ # ---------------------------------------------------
+ # Discriminator augmentation: aug, p, target, augpipe
+ # ---------------------------------------------------
+
+ if aug is None:
+ aug = 'ada'
+ else:
+ assert isinstance(aug, str)
+ desc += f'-{aug}'
+
+ if aug == 'ada':
+ args.ada_target = 0.6
+
+ elif aug == 'noaug':
+ pass
+
+ elif aug == 'fixed':
+ if p is None:
+ raise UserError(f'--aug={aug} requires specifying --p')
+
+ else:
+ raise UserError(f'--aug={aug} not supported')
+
+ if p is not None:
+ assert isinstance(p, float)
+ if aug != 'fixed':
+ raise UserError('--p can only be specified with --aug=fixed')
+ if not 0 <= p <= 1:
+ raise UserError('--p must be between 0 and 1')
+ desc += f'-p{p:g}'
+ args.augment_p = p
+
+ if target is not None:
+ assert isinstance(target, float)
+ if aug != 'ada':
+ raise UserError('--target can only be specified with --aug=ada')
+ if not 0 <= target <= 1:
+ raise UserError('--target must be between 0 and 1')
+ desc += f'-target{target:g}'
+ args.ada_target = target
+
+ diffusion_specs = dict(beta_schedule=beta_schedule, beta_start=beta_start, beta_end=beta_end,
+ t_min=t_min, t_max=t_max, noise_std=noise_sd,
+ aug=daug, ada_maxp=ada_maxp, ts_dist=ts_dist)
+
+ desc += f"-ts_dist-{ts_dist}"
+ if aug != 'noaug':
+ args.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', **diffusion_specs)
+
+ # ----------------------------------
+ # Transfer learning: resume, freezed
+ # ----------------------------------
+
+ resume_specs = {
+ 'ffhq256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl',
+ 'ffhq512': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res512-mirror-stylegan2-noaug.pkl',
+ 'ffhq1024': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res1024-mirror-stylegan2-noaug.pkl',
+ 'celebahq256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl',
+ 'lsundog256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/lsundog-res256-paper256-kimg100000-noaug.pkl',
+ }
+
+ assert resume is None or isinstance(resume, str)
+ if resume is None:
+ resume = 'noresume'
+ elif resume == 'noresume':
+ desc += '-noresume'
+ elif resume in resume_specs:
+ desc += f'-resume{resume}'
+ args.resume_pkl = resume_specs[resume] # predefined url
+ else:
+ desc += '-resumecustom'
+ args.resume_pkl = resume # custom path or url
+
+ if resume != 'noresume':
+ args.ada_kimg = 100 # make ADA react faster at the beginning
+ args.ema_rampup = None # disable EMA rampup
+ args.ada_kimg = 100
+
+ if freezed is not None:
+ assert isinstance(freezed, int)
+ if not freezed >= 0:
+ raise UserError('--freezed must be non-negative')
+ desc += f'-freezed{freezed:d}'
+ args.D_kwargs.block_kwargs.freeze_layers = freezed
+
+ # -------------------------------------------------
+ # Performance options: fp32, nhwc, nobench, workers
+ # -------------------------------------------------
+
+ if fp32 is None:
+ fp32 = False
+ assert isinstance(fp32, bool)
+ if fp32:
+ args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 0
+ args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = None
+
+ if nhwc is None:
+ nhwc = False
+ assert isinstance(nhwc, bool)
+ if nhwc:
+ args.G_kwargs.synthesis_kwargs.fp16_channels_last = args.D_kwargs.block_kwargs.fp16_channels_last = True
+
+ if nobench is None:
+ nobench = False
+ assert isinstance(nobench, bool)
+ if nobench:
+ args.cudnn_benchmark = False
+
+ if allow_tf32 is None:
+ allow_tf32 = False
+ assert isinstance(allow_tf32, bool)
+ if allow_tf32:
+ args.allow_tf32 = True
+
+ if workers is not None:
+ assert isinstance(workers, int)
+ if not workers >= 1:
+ raise UserError('--workers must be at least 1')
+ args.data_loader_kwargs.num_workers = workers
+
+ # ----------------------------------------------------
+ # InsGen: contrastive_head, no_cl_on_g, cl_loss_weight
+ # ----------------------------------------------------
+ use_insgen = True
+ if no_insgen is not None:
+ assert isinstance(no_insgen, bool)
+ use_insgen = not no_insgen
+
+ if use_insgen:
+ # Overwrite class name of loss function
+ args.loss_kwargs.class_name = 'training.contrastive_loss.StyleGAN2LossCL'
+
+ args.DHead_kwargs = dnnlib.EasyDict(class_name='training.contrastive_head.CLHead', inplanes=512, temperature=0.2, momentum=0.999, queue_size=-1)
+ args.GHead_kwargs = dnnlib.EasyDict(class_name='training.contrastive_head.CLHead', inplanes=512, temperature=0.2, momentum=0.999, queue_size=-1)
+ # Default queue size is 0.05 * len(dataset)
+ default_queue_size = int(0.05 * args.training_set_kwargs.max_size)
+ if args.training_set_kwargs.xflip:
+ default_queue_size *= 2
+ args.DHead_kwargs.queue_size = default_queue_size if rqs is None else rqs
+ args.GHead_kwargs.queue_size = default_queue_size if fqs is None else fqs
+
+ if no_cl_on_g is not None:
+ assert isinstance(no_cl_on_g, bool)
+ args.no_cl_on_g = no_cl_on_g
+ if ada_linear is not None:
+ assert isinstance(ada_linear, bool)
+ args.ada_linear = ada_linear
+ # Default loss weight for real instance discrimination, fake instance discrimination and fake instance discrimination on g
+ args.cl_loss_weight = dnnlib.EasyDict(lw_real_cl=1.0, lw_fake_cl=1.0, lw_fake_cl_on_g=0.1)
+ else:
+ args.DHead_kwargs = None
+ args.GHead_kwargs = None
+
+ return desc, args
+
+#----------------------------------------------------------------------------
+
+def subprocess_fn(rank, args, temp_dir):
+ dnnlib.util.Logger(file_name=os.path.join(args.run_dir, 'log.txt'), file_mode='a', should_flush=True)
+
+ # Init torch.distributed.
+ if args.num_gpus > 1:
+ init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
+ if os.name == 'nt':
+ init_method = 'file:///' + init_file.replace('\\', '/')
+ torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus)
+ else:
+ init_method = f'file://{init_file}'
+ torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus)
+
+ # Init torch_utils.
+ sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None
+ training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
+ if rank != 0:
+ custom_ops.verbosity = 'none'
+
+ # Execute training loop.
+ training_loop.training_loop(rank=rank, **args)
+
+#----------------------------------------------------------------------------
+
+class CommaSeparatedList(click.ParamType):
+ name = 'list'
+
+ def convert(self, value, param, ctx):
+ _ = param, ctx
+ if value is None or value.lower() == 'none' or value == '':
+ return []
+ return value.split(',')
+
+#----------------------------------------------------------------------------
+
+@click.command()
+@click.pass_context
+
+# General options.
+@click.option('--outdir', help='Where to save the results', required=True, metavar='DIR')
+@click.option('--gpus', help='Number of GPUs to use [default: 1]', type=int, metavar='INT')
+@click.option('--snap', help='Snapshot interval [default: 50 ticks]', type=int, metavar='INT')
+@click.option('--metrics', help='Comma-separated list or "none" [default: fid50k_full]', type=CommaSeparatedList())
+@click.option('--seed', help='Random seed [default: 0]', type=int, metavar='INT')
+@click.option('-n', '--dry-run', help='Print training options and exit', is_flag=True)
+@click.option('--exp', help='exp id', type=str)
+
+# Dataset.
+@click.option('--data', help='Training data (directory or zip)', metavar='PATH', required=True)
+@click.option('--cond', help='Train conditional model based on dataset labels [default: false]', type=bool, metavar='BOOL')
+@click.option('--subset', help='Train with only N images [default: all]', type=int, metavar='INT')
+@click.option('--mirror', help='Enable dataset x-flips [default: false]', type=bool, metavar='BOOL', default=1)
+
+# Base config.
+@click.option('--cfg', help='Base config [default: auto]', type=click.Choice(['auto', 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar']))
+@click.option('--gamma', help='Override R1 gamma', type=float)
+@click.option('--kimg', help='Override training duration', type=int, metavar='INT')
+@click.option('--batch', help='Override batch size', type=int, metavar='INT')
+
+# Discriminator augmentation.
+@click.option('--aug', help='Augmentation mode [default: ada]', type=click.Choice(['noaug', 'ada', 'fixed']))
+@click.option('--daug', help='Augmentation mode [default: ada]', type=click.Choice(['NO', 'ADA', 'DIFF']), default='ADA')
+@click.option('--p', help='Augmentation probability for --aug=fixed', type=float)
+
+# Adaptive diffusion config.
+@click.option('--beta_schedule', help='Forward diffusion beta schedule (we use linear always)', type=str, default='linear')
+@click.option('--beta_start', help='Forward diffusion process beta_start', type=float, default=1e-4)
+@click.option('--beta_end', help='Forward diffusion process beta_end', type=float, default=2e-2)
+@click.option('--t_min', help='Minimum # of timesteps for adaptively modification', type=int, default=10)
+@click.option('--t_max', help='Maximum # of timesteps for adaptively modification', type=int, default=500)
+@click.option('--noise_sd', help='Diffusion noise standard deviation', type=float, default=0.05)
+@click.option('--ts_dist', help='Diffusion t sampling way', type=click.Choice(['priority', 'uniform']), default='uniform')
+@click.option('--target', help='Discriminator target value', type=float, default=0.6)
+
+# Transfer learning.
+@click.option('--resume', help='Resume training [default: noresume]', metavar='PKL')
+@click.option('--freezed', help='Freeze-D [default: 0 layers]', type=int, metavar='INT')
+
+# Performance options.
+@click.option('--fp32', help='Disable mixed-precision training', type=bool, metavar='BOOL')
+@click.option('--nhwc', help='Use NHWC memory format with FP16', type=bool, metavar='BOOL')
+@click.option('--nobench', help='Disable cuDNN benchmarking', type=bool, metavar='BOOL')
+@click.option('--allow-tf32', help='Allow PyTorch to use TF32 internally', type=bool, metavar='BOOL')
+@click.option('--workers', help='Override number of DataLoader workers', type=int, metavar='INT')
+
+# InsGen related options.
+@click.option('--no_insgen', help='Disable InsGen back to ADA [default: False]', type=bool, metavar='BOOL')
+@click.option('--rqs', help='Size of real image queue [default: 5% * len(dataset)]', type=int, metavar='INT')
+@click.option('--fqs', help='Size of fake image queue [default: 5% * len(dataset)]', type=int, metavar='INT')
+@click.option('--no_cl_on_g', help='Disable fake instance discrimination for generator [default: False]', type=bool, metavar='BOOL')
+@click.option('--ada_linear', help='Whether to linearly increase the strength of ADA [default: False]', type=bool, metavar='BOOL')
+
+
+def main(ctx, outdir, dry_run, **config_kwargs):
+ """Train a GAN using the techniques described in the paper
+ "Training Generative Adversarial Networks with Limited Data".
+
+ Examples:
+
+ \b
+ # Train with custom dataset using 1 GPU.
+ python train.py --outdir=~/training-runs --data=~/mydataset.zip --gpus=1
+
+ \b
+ # Train class-conditional CIFAR-10 using 2 GPUs.
+ python train.py --outdir=~/training-runs --data=~/datasets/cifar10.zip \\
+ --gpus=2 --cfg=cifar --cond=1
+
+ \b
+ # Transfer learn MetFaces from FFHQ using 4 GPUs.
+ python train.py --outdir=~/training-runs --data=~/datasets/metfaces.zip \\
+ --gpus=4 --cfg=paper1024 --mirror=1 --resume=ffhq1024 --snap=10
+
+ \b
+ # Reproduce original StyleGAN2 config F.
+ python train.py --outdir=~/training-runs --data=~/datasets/ffhq.zip \\
+ --gpus=8 --cfg=stylegan2 --mirror=1 --aug=noaug
+
+ \b
+ Base configs (--cfg):
+ auto Automatically select reasonable defaults based on resolution
+ and GPU count. Good starting point for new datasets.
+ stylegan2 Reproduce results for StyleGAN2 config F at 1024x1024.
+ paper256 Reproduce results for FFHQ and LSUN Cat at 256x256.
+ paper512 Reproduce results for BreCaHAD and AFHQ at 512x512.
+ paper1024 Reproduce results for MetFaces at 1024x1024.
+ cifar Reproduce results for CIFAR-10 at 32x32.
+
+ \b
+ Transfer learning source networks (--resume):
+ ffhq256 FFHQ trained at 256x256 resolution.
+ ffhq512 FFHQ trained at 512x512 resolution.
+ ffhq1024 FFHQ trained at 1024x1024 resolution.
+ celebahq256 CelebA-HQ trained at 256x256 resolution.
+ lsundog256 LSUN Dog trained at 256x256 resolution.
+ Custom network pickle.
+ """
+ dnnlib.util.Logger(should_flush=True)
+
+ # Setup training options.
+ try:
+ run_desc, args = setup_training_loop_kwargs(**config_kwargs)
+ except UserError as err:
+ ctx.fail(err)
+
+ # Pick output directory.
+ prev_run_dirs = []
+ if os.path.isdir(outdir):
+ prev_run_dirs = [x for x in os.listdir(outdir) if os.path.isdir(os.path.join(outdir, x))]
+
+ matching_dirs = [re.fullmatch(r'\d{5}' + f'-{run_desc}', x) for x in prev_run_dirs if
+ re.fullmatch(r'\d{5}' + f'-{run_desc}', x) is not None]
+ if len(matching_dirs) > 0: # expect unique desc, continue in this directory
+ assert len(matching_dirs) == 1, f'Multiple directories found for resuming: {matching_dirs}'
+ run_dir = os.path.join(outdir, matching_dirs[0].group())
+ else: # fallback to standard
+ prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs]
+ prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]
+ cur_run_id = max(prev_run_ids, default=-1) + 1
+ run_dir = os.path.join(outdir, f'{cur_run_id:05d}-{run_desc}')
+ assert not os.path.exists(run_dir)
+ args.run_dir = run_dir
+
+ # Print options.
+ print()
+ print('Training options:')
+ print(json.dumps(args, indent=2))
+ print()
+ print(f'Output directory: {args.run_dir}')
+ print(f'Training data: {args.training_set_kwargs.path}')
+ print(f'Training duration: {args.total_kimg} kimg')
+ print(f'Number of GPUs: {args.num_gpus}')
+ print(f'Number of images: {args.training_set_kwargs.max_size}')
+ print(f'Image resolution: {args.training_set_kwargs.resolution}')
+ print(f'Conditional model: {args.training_set_kwargs.use_labels}')
+ print(f'Dataset x-flips: {args.training_set_kwargs.xflip}')
+ print()
+
+ # Dry run?
+ if dry_run:
+ print('Dry run; exiting.')
+ return
+
+ # Create output directory.
+ print('Creating output directory...')
+ os.makedirs(args.run_dir, exist_ok=True)
+ with open(os.path.join(args.run_dir, 'training_options.json'), 'wt') as f:
+ json.dump(args, f, indent=2)
+
+ # Launch processes.
+ print('Launching processes...')
+ torch.multiprocessing.set_start_method('spawn')
+ with tempfile.TemporaryDirectory() as temp_dir:
+ if args.num_gpus == 1:
+ subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
+ else:
+ torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)
+
+#----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+ main() # pylint: disable=no-value-for-parameter
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-insgen/training/__init__.py b/diffusion-insgen/training/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1e1a5ba99e56a56ecaa14f7d4fa41777789c0cf
--- /dev/null
+++ b/diffusion-insgen/training/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+# empty
diff --git a/diffusion-insgen/training/adaaug.py b/diffusion-insgen/training/adaaug.py
new file mode 100644
index 0000000000000000000000000000000000000000..163f439c7cd51b78b7eb8a259abe5e3baad439e9
--- /dev/null
+++ b/diffusion-insgen/training/adaaug.py
@@ -0,0 +1,449 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import numpy as np
+import scipy.signal
+import torch
+from torch_utils import persistence
+from torch_utils import misc
+from torch_utils.ops import upfirdn2d
+from torch_utils.ops import grid_sample_gradfix
+from torch_utils.ops import conv2d_gradfix
+
+
+def AdaAugment(p=0.25, spec='bgc'):
+ return ADA(p=p, **augpipe_specs[spec])
+
+#----------------------------------------------------------------------------
+# Coefficients of various wavelet decomposition low-pass filters.
+
+augpipe_specs = {
+ 'blit': dict(xflip=1, rotate90=1, xint=1),
+ 'geom': dict(scale=1, rotate=1, aniso=1, xfrac=1),
+ 'color': dict(brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
+ 'filter': dict(imgfilter=1),
+ 'noise': dict(noise=1),
+ 'cutout': dict(cutout=1),
+ 'bg': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1),
+ 'bgc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
+ 'bgcf': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1),
+ 'bgcfn': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1),
+ 'bgcfnc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1, cutout=1),
+}
+
+wavelets = {
+ 'haar': [0.7071067811865476, 0.7071067811865476],
+ 'db1': [0.7071067811865476, 0.7071067811865476],
+ 'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
+ 'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
+ 'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523],
+ 'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125],
+ 'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017],
+ 'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236],
+ 'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161],
+ 'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
+ 'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
+ 'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427],
+ 'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728],
+ 'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148],
+ 'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255],
+ 'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609],
+}
+
+#----------------------------------------------------------------------------
+# Helpers for constructing transformation matrices.
+
+def matrix(*rows, device=None):
+ assert all(len(row) == len(rows[0]) for row in rows)
+ elems = [x for row in rows for x in row]
+ ref = [x for x in elems if isinstance(x, torch.Tensor)]
+ if len(ref) == 0:
+ return misc.constant(np.asarray(rows), device=device)
+ assert device is None or device == ref[0].device
+ elems = [x if isinstance(x, torch.Tensor) else misc.constant(x, shape=ref[0].shape, device=ref[0].device) for x in elems]
+ return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1))
+
+def translate2d(tx, ty, **kwargs):
+ return matrix(
+ [1, 0, tx],
+ [0, 1, ty],
+ [0, 0, 1],
+ **kwargs)
+
+def translate3d(tx, ty, tz, **kwargs):
+ return matrix(
+ [1, 0, 0, tx],
+ [0, 1, 0, ty],
+ [0, 0, 1, tz],
+ [0, 0, 0, 1],
+ **kwargs)
+
+def scale2d(sx, sy, **kwargs):
+ return matrix(
+ [sx, 0, 0],
+ [0, sy, 0],
+ [0, 0, 1],
+ **kwargs)
+
+def scale3d(sx, sy, sz, **kwargs):
+ return matrix(
+ [sx, 0, 0, 0],
+ [0, sy, 0, 0],
+ [0, 0, sz, 0],
+ [0, 0, 0, 1],
+ **kwargs)
+
+def rotate2d(theta, **kwargs):
+ return matrix(
+ [torch.cos(theta), torch.sin(-theta), 0],
+ [torch.sin(theta), torch.cos(theta), 0],
+ [0, 0, 1],
+ **kwargs)
+
+def rotate3d(v, theta, **kwargs):
+ vx = v[..., 0]; vy = v[..., 1]; vz = v[..., 2]
+ s = torch.sin(theta); c = torch.cos(theta); cc = 1 - c
+ return matrix(
+ [vx*vx*cc+c, vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0],
+ [vy*vx*cc+vz*s, vy*vy*cc+c, vy*vz*cc-vx*s, 0],
+ [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c, 0],
+ [0, 0, 0, 1],
+ **kwargs)
+
+def translate2d_inv(tx, ty, **kwargs):
+ return translate2d(-tx, -ty, **kwargs)
+
+def scale2d_inv(sx, sy, **kwargs):
+ return scale2d(1 / sx, 1 / sy, **kwargs)
+
+def rotate2d_inv(theta, **kwargs):
+ return rotate2d(-theta, **kwargs)
+
+#----------------------------------------------------------------------------
+# Versatile image augmentation pipeline from the paper
+# "Training Generative Adversarial Networks with Limited Data".
+#
+# All augmentations are disabled by default; individual augmentations can
+# be enabled by setting their probability multipliers to 1.
+
+@persistence.persistent_class
+class ADA(torch.nn.Module):
+ def __init__(self, p=0.6,
+ xflip=0, rotate90=0, xint=0, xint_max=0.125,
+ scale=0, rotate=0, aniso=0, xfrac=0, scale_std=0.2, rotate_max=1, aniso_std=0.2, xfrac_std=0.125,
+ brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5, hue_max=1, saturation_std=1,
+ imgfilter=0, imgfilter_bands=[1,1,1,1], imgfilter_std=1,
+ noise=0, cutout=0, noise_std=0.1, cutout_size=0.5
+ ):
+ super().__init__()
+ self.p = torch.tensor(p) # Overall multiplier for augmentation probability.
+
+ # Pixel blitting.
+ self.xflip = float(xflip) # Probability multiplier for x-flip.
+ self.rotate90 = float(rotate90) # Probability multiplier for 90 degree rotations.
+ self.xint = float(xint) # Probability multiplier for integer translation.
+ self.xint_max = float(xint_max) # Range of integer translation, relative to image dimensions.
+
+ # General geometric transformations.
+ self.scale = float(scale) # Probability multiplier for isotropic scaling.
+ self.rotate = float(rotate) # Probability multiplier for arbitrary rotation.
+ self.aniso = float(aniso) # Probability multiplier for anisotropic scaling.
+ self.xfrac = float(xfrac) # Probability multiplier for fractional translation.
+ self.scale_std = float(scale_std) # Log2 standard deviation of isotropic scaling.
+ self.rotate_max = float(rotate_max) # Range of arbitrary rotation, 1 = full circle.
+ self.aniso_std = float(aniso_std) # Log2 standard deviation of anisotropic scaling.
+ self.xfrac_std = float(xfrac_std) # Standard deviation of frational translation, relative to image dimensions.
+
+ # Color transformations.
+ self.brightness = float(brightness) # Probability multiplier for brightness.
+ self.contrast = float(contrast) # Probability multiplier for contrast.
+ self.lumaflip = float(lumaflip) # Probability multiplier for luma flip.
+ self.hue = float(hue) # Probability multiplier for hue rotation.
+ self.saturation = float(saturation) # Probability multiplier for saturation.
+ self.brightness_std = float(brightness_std) # Standard deviation of brightness.
+ self.contrast_std = float(contrast_std) # Log2 standard deviation of contrast.
+ self.hue_max = float(hue_max) # Range of hue rotation, 1 = full circle.
+ self.saturation_std = float(saturation_std) # Log2 standard deviation of saturation.
+
+ # Image-space filtering.
+ self.imgfilter = float(imgfilter) # Probability multiplier for image-space filtering.
+ self.imgfilter_bands = list(imgfilter_bands) # Probability multipliers for individual frequency bands.
+ self.imgfilter_std = float(imgfilter_std) # Log2 standard deviation of image-space filter amplification.
+
+ # Image-space corruptions.
+ self.noise = float(noise) # Probability multiplier for additive RGB noise.
+ self.cutout = float(cutout) # Probability multiplier for cutout.
+ self.noise_std = float(noise_std) # Standard deviation of additive RGB noise.
+ self.cutout_size = float(cutout_size) # Size of the cutout rectangle, relative to image dimensions.
+
+ # Setup orthogonal lowpass filter for geometric augmentations.
+ self.register_buffer('Hz_geom', upfirdn2d.setup_filter(wavelets['sym6']))
+
+ # Construct filter bank for image-space filtering.
+ Hz_lo = np.asarray(wavelets['sym2']) # H(z)
+ Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size)) # H(-z)
+ Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2 # H(z) * H(z^-1) / 2
+ Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2 # H(-z) * H(-z^-1) / 2
+ Hz_fbank = np.eye(4, 1) # Bandpass(H(z), b_i)
+ for i in range(1, Hz_fbank.shape[0]):
+ Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)]).reshape(Hz_fbank.shape[0], -1)[:, :-1]
+ Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2])
+ Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) // 2 : (Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2
+ self.register_buffer('Hz_fbank', torch.as_tensor(Hz_fbank, dtype=torch.float32))
+
+ def forward(self, images, debug_percentile=None):
+ assert isinstance(images, torch.Tensor) and images.ndim == 4
+ batch_size, num_channels, height, width = images.shape
+ device = images.device
+ if debug_percentile is not None:
+ debug_percentile = torch.as_tensor(debug_percentile, dtype=torch.float32, device=device)
+
+ # -------------------------------------
+ # Select parameters for pixel blitting.
+ # -------------------------------------
+
+ # Initialize inverse homogeneous 2D transform: G_inv @ pixel_out ==> pixel_in
+ I_3 = torch.eye(3, device=device)
+ G_inv = I_3
+
+ # Apply x-flip with probability (xflip * strength).
+ if self.xflip > 0:
+ i = torch.floor(torch.rand([batch_size], device=device) * 2)
+ i = torch.where(torch.rand([batch_size], device=device) < self.xflip * self.p, i, torch.zeros_like(i))
+ if debug_percentile is not None:
+ i = torch.full_like(i, torch.floor(debug_percentile * 2))
+ G_inv = G_inv @ scale2d_inv(1 - 2 * i, 1)
+
+ # Apply 90 degree rotations with probability (rotate90 * strength).
+ if self.rotate90 > 0:
+ i = torch.floor(torch.rand([batch_size], device=device) * 4)
+ i = torch.where(torch.rand([batch_size], device=device) < self.rotate90 * self.p, i, torch.zeros_like(i))
+ if debug_percentile is not None:
+ i = torch.full_like(i, torch.floor(debug_percentile * 4))
+ G_inv = G_inv @ rotate2d_inv(-np.pi / 2 * i)
+
+ # Apply integer translation with probability (xint * strength).
+ if self.xint > 0:
+ t = (torch.rand([batch_size, 2], device=device) * 2 - 1) * self.xint_max
+ t = torch.where(torch.rand([batch_size, 1], device=device) < self.xint * self.p, t, torch.zeros_like(t))
+ if debug_percentile is not None:
+ t = torch.full_like(t, (debug_percentile * 2 - 1) * self.xint_max)
+ G_inv = G_inv @ translate2d_inv(torch.round(t[:,0] * width), torch.round(t[:,1] * height))
+
+ # --------------------------------------------------------
+ # Select parameters for general geometric transformations.
+ # --------------------------------------------------------
+
+ # Apply isotropic scaling with probability (scale * strength).
+ if self.scale > 0:
+ s = torch.exp2(torch.randn([batch_size], device=device) * self.scale_std)
+ s = torch.where(torch.rand([batch_size], device=device) < self.scale * self.p, s, torch.ones_like(s))
+ if debug_percentile is not None:
+ s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.scale_std))
+ G_inv = G_inv @ scale2d_inv(s, s)
+
+ # Apply pre-rotation with probability p_rot.
+ p_rot = 1 - torch.sqrt((1 - self.rotate * self.p).clamp(0, 1)) # P(pre OR post) = p
+ if self.rotate > 0:
+ theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
+ theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
+ if debug_percentile is not None:
+ theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.rotate_max)
+ G_inv = G_inv @ rotate2d_inv(-theta) # Before anisotropic scaling.
+
+ # Apply anisotropic scaling with probability (aniso * strength).
+ if self.aniso > 0:
+ s = torch.exp2(torch.randn([batch_size], device=device) * self.aniso_std)
+ s = torch.where(torch.rand([batch_size], device=device) < self.aniso * self.p, s, torch.ones_like(s))
+ if debug_percentile is not None:
+ s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.aniso_std))
+ G_inv = G_inv @ scale2d_inv(s, 1 / s)
+
+ # Apply post-rotation with probability p_rot.
+ if self.rotate > 0:
+ theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
+ theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
+ if debug_percentile is not None:
+ theta = torch.zeros_like(theta)
+ G_inv = G_inv @ rotate2d_inv(-theta) # After anisotropic scaling.
+
+ # Apply fractional translation with probability (xfrac * strength).
+ if self.xfrac > 0:
+ t = torch.randn([batch_size, 2], device=device) * self.xfrac_std
+ t = torch.where(torch.rand([batch_size, 1], device=device) < self.xfrac * self.p, t, torch.zeros_like(t))
+ if debug_percentile is not None:
+ t = torch.full_like(t, torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std)
+ G_inv = G_inv @ translate2d_inv(t[:,0] * width, t[:,1] * height)
+
+ # ----------------------------------
+ # Execute geometric transformations.
+ # ----------------------------------
+
+ # Execute if the transform is not identity.
+ if G_inv is not I_3:
+
+ # Calculate padding.
+ cx = (width - 1) / 2
+ cy = (height - 1) / 2
+ cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz]
+ cp = G_inv @ cp.t() # [batch, xyz, idx]
+ Hz_pad = self.Hz_geom.shape[0] // 4
+ margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx]
+ margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1]
+ margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device)
+ margin = margin.max(misc.constant([0, 0] * 2, device=device))
+ margin = margin.min(misc.constant([width-1, height-1] * 2, device=device))
+ mx0, my0, mx1, my1 = margin.ceil().to(torch.int32)
+
+ # Pad image and adjust origin.
+ images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect')
+ G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv
+
+ # Upsample.
+ images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2)
+ G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device)
+ G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device)
+
+ # Execute transformation.
+ shape = [batch_size, num_channels, (height + Hz_pad * 2) * 2, (width + Hz_pad * 2) * 2]
+ G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device)
+ grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False)
+ images = grid_sample_gradfix.grid_sample(images, grid)
+
+ # Downsample and crop.
+ images = upfirdn2d.downsample2d(x=images, f=self.Hz_geom, down=2, padding=-Hz_pad*2, flip_filter=True)
+
+ # --------------------------------------------
+ # Select parameters for color transformations.
+ # --------------------------------------------
+
+ # Initialize homogeneous 3D transformation matrix: C @ color_in ==> color_out
+ I_4 = torch.eye(4, device=device)
+ C = I_4
+
+ # Apply brightness with probability (brightness * strength).
+ if self.brightness > 0:
+ b = torch.randn([batch_size], device=device) * self.brightness_std
+ b = torch.where(torch.rand([batch_size], device=device) < self.brightness * self.p, b, torch.zeros_like(b))
+ if debug_percentile is not None:
+ b = torch.full_like(b, torch.erfinv(debug_percentile * 2 - 1) * self.brightness_std)
+ C = translate3d(b, b, b) @ C
+
+ # Apply contrast with probability (contrast * strength).
+ if self.contrast > 0:
+ c = torch.exp2(torch.randn([batch_size], device=device) * self.contrast_std)
+ c = torch.where(torch.rand([batch_size], device=device) < self.contrast * self.p, c, torch.ones_like(c))
+ if debug_percentile is not None:
+ c = torch.full_like(c, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.contrast_std))
+ C = scale3d(c, c, c) @ C
+
+ # Apply luma flip with probability (lumaflip * strength).
+ v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device) # Luma axis.
+ if self.lumaflip > 0:
+ i = torch.floor(torch.rand([batch_size, 1, 1], device=device) * 2)
+ i = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.lumaflip * self.p, i, torch.zeros_like(i))
+ if debug_percentile is not None:
+ i = torch.full_like(i, torch.floor(debug_percentile * 2))
+ C = (I_4 - 2 * v.ger(v) * i) @ C # Householder reflection.
+
+ # Apply hue rotation with probability (hue * strength).
+ if self.hue > 0 and num_channels > 1:
+ theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.hue_max
+ theta = torch.where(torch.rand([batch_size], device=device) < self.hue * self.p, theta, torch.zeros_like(theta))
+ if debug_percentile is not None:
+ theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.hue_max)
+ C = rotate3d(v, theta) @ C # Rotate around v.
+
+ # Apply saturation with probability (saturation * strength).
+ if self.saturation > 0 and num_channels > 1:
+ s = torch.exp2(torch.randn([batch_size, 1, 1], device=device) * self.saturation_std)
+ s = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.saturation * self.p, s, torch.ones_like(s))
+ if debug_percentile is not None:
+ s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.saturation_std))
+ C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C
+
+ # ------------------------------
+ # Execute color transformations.
+ # ------------------------------
+
+ # Execute if the transform is not identity.
+ if C is not I_4:
+ images = images.reshape([batch_size, num_channels, height * width])
+ if num_channels == 3:
+ images = C[:, :3, :3] @ images + C[:, :3, 3:]
+ elif num_channels == 1:
+ C = C[:, :3, :].mean(dim=1, keepdims=True)
+ images = images * C[:, :, :3].sum(dim=2, keepdims=True) + C[:, :, 3:]
+ else:
+ raise ValueError('Image must be RGB (3 channels) or L (1 channel)')
+ images = images.reshape([batch_size, num_channels, height, width])
+
+ # ----------------------
+ # Image-space filtering.
+ # ----------------------
+
+ if self.imgfilter > 0:
+ num_bands = self.Hz_fbank.shape[0]
+ assert len(self.imgfilter_bands) == num_bands
+ expected_power = misc.constant(np.array([10, 1, 1, 1]) / 13, device=device) # Expected power spectrum (1/f).
+
+ # Apply amplification for each band with probability (imgfilter * strength * band_strength).
+ g = torch.ones([batch_size, num_bands], device=device) # Global gain vector (identity).
+ for i, band_strength in enumerate(self.imgfilter_bands):
+ t_i = torch.exp2(torch.randn([batch_size], device=device) * self.imgfilter_std)
+ t_i = torch.where(torch.rand([batch_size], device=device) < self.imgfilter * self.p * band_strength, t_i, torch.ones_like(t_i))
+ if debug_percentile is not None:
+ t_i = torch.full_like(t_i, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.imgfilter_std)) if band_strength > 0 else torch.ones_like(t_i)
+ t = torch.ones([batch_size, num_bands], device=device) # Temporary gain vector.
+ t[:, i] = t_i # Replace i'th element.
+ t = t / (expected_power * t.square()).sum(dim=-1, keepdims=True).sqrt() # Normalize power.
+ g = g * t # Accumulate into global gain.
+
+ # Construct combined amplification filter.
+ Hz_prime = g @ self.Hz_fbank # [batch, tap]
+ Hz_prime = Hz_prime.unsqueeze(1).repeat([1, num_channels, 1]) # [batch, channels, tap]
+ Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, -1]) # [batch * channels, 1, tap]
+
+ # Apply filter.
+ p = self.Hz_fbank.shape[1] // 2
+ images = images.reshape([1, batch_size * num_channels, height, width])
+ images = torch.nn.functional.pad(input=images, pad=[p,p,p,p], mode='reflect')
+ images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(2), groups=batch_size*num_channels)
+ images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(3), groups=batch_size*num_channels)
+ images = images.reshape([batch_size, num_channels, height, width])
+
+ # ------------------------
+ # Image-space corruptions.
+ # ------------------------
+
+ # Apply additive RGB noise with probability (noise * strength).
+ if self.noise > 0:
+ sigma = torch.randn([batch_size, 1, 1, 1], device=device).abs() * self.noise_std
+ sigma = torch.where(torch.rand([batch_size, 1, 1, 1], device=device) < self.noise * self.p, sigma, torch.zeros_like(sigma))
+ if debug_percentile is not None:
+ sigma = torch.full_like(sigma, torch.erfinv(debug_percentile) * self.noise_std)
+ images = images + torch.randn([batch_size, num_channels, height, width], device=device) * sigma
+
+ # Apply cutout with probability (cutout * strength).
+ if self.cutout > 0:
+ size = torch.full([batch_size, 2, 1, 1, 1], self.cutout_size, device=device)
+ size = torch.where(torch.rand([batch_size, 1, 1, 1, 1], device=device) < self.cutout * self.p, size, torch.zeros_like(size))
+ center = torch.rand([batch_size, 2, 1, 1, 1], device=device)
+ if debug_percentile is not None:
+ size = torch.full_like(size, self.cutout_size)
+ center = torch.full_like(center, debug_percentile)
+ coord_x = torch.arange(width, device=device).reshape([1, 1, 1, -1])
+ coord_y = torch.arange(height, device=device).reshape([1, 1, -1, 1])
+ mask_x = (((coord_x + 0.5) / width - center[:, 0]).abs() >= size[:, 0] / 2)
+ mask_y = (((coord_y + 0.5) / height - center[:, 1]).abs() >= size[:, 1] / 2)
+ mask = torch.logical_or(mask_x, mask_y).to(torch.float32)
+ images = images * mask
+
+ return images
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-insgen/training/augment.py b/diffusion-insgen/training/augment.py
new file mode 100644
index 0000000000000000000000000000000000000000..477f1ef1ee5b1c40d6609baf20320db97c527696
--- /dev/null
+++ b/diffusion-insgen/training/augment.py
@@ -0,0 +1,209 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import numpy as np
+import scipy.signal
+import torch
+from torch_utils import persistence
+from torch_utils import misc
+from torch_utils.ops import upfirdn2d
+from torch_utils.ops import grid_sample_gradfix
+from torch_utils.ops import conv2d_gradfix
+
+from training.diffaug import DiffAugment
+from training.adaaug import AdaAugment
+
+#----------------------------------------------------------------------------
+# Helpers for doing defusion process.
+
+
+def get_beta_schedule(beta_schedule, beta_start, beta_end, num_diffusion_timesteps):
+ def sigmoid(x):
+ return 1 / (np.exp(-x) + 1)
+
+ def continuous_t_beta(t, T):
+ b_max = 5.
+ b_min = 0.1
+ alpha = np.exp(-b_min / T - 0.5 * (b_max - b_min) * (2 * t - 1) / T ** 2)
+ return 1 - alpha
+
+ if beta_schedule == "continuous_t":
+ betas = continuous_t_beta(np.arange(1, num_diffusion_timesteps+1), num_diffusion_timesteps)
+ elif beta_schedule == "quad":
+ betas = (
+ np.linspace(
+ beta_start ** 0.5,
+ beta_end ** 0.5,
+ num_diffusion_timesteps,
+ dtype=np.float64,
+ )
+ ** 2
+ )
+ elif beta_schedule == "linear":
+ betas = np.linspace(
+ beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
+ )
+ elif beta_schedule == "const":
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
+ elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
+ betas = 1.0 / np.linspace(
+ num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
+ )
+ elif beta_schedule == "sigmoid":
+ betas = np.linspace(-6, 6, num_diffusion_timesteps)
+ betas = sigmoid(betas) * (beta_end - beta_start) + beta_start
+ elif beta_schedule == 'cosine':
+ """
+ cosine schedule
+ as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
+ """
+ s = 0.008
+ steps = num_diffusion_timesteps + 1
+ x = np.linspace(0, steps, steps)
+ alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
+ alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
+ betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
+ betas_clipped = np.clip(betas, a_min=0, a_max=0.999)
+ return betas_clipped
+ else:
+ raise NotImplementedError(beta_schedule)
+ assert betas.shape == (num_diffusion_timesteps,)
+ return betas
+
+
+def q_sample(x_0, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, t, noise_type='gauss', noise_std=1.0):
+ if noise_type == 'gauss':
+ noise = torch.randn_like(x_0, device=x_0.device) * noise_std
+ elif noise_type == 'bernoulli':
+ noise = (torch.bernoulli(torch.ones_like(x_0) * 0.5) * 2 - 1.) * noise_std
+ else:
+ raise NotImplementedError(noise_type)
+ alphas_t_sqrt = alphas_bar_sqrt[t].view(-1, 1, 1, 1)
+ one_minus_alphas_bar_t_sqrt = one_minus_alphas_bar_sqrt[t].view(-1, 1, 1, 1)
+ x_t = alphas_t_sqrt * x_0 + one_minus_alphas_bar_t_sqrt * noise
+ return x_t
+
+
+def q_sample_c(x_0, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, t, noise_type='gauss', noise_std=1.0):
+ batch_size, num_channels, _, _ = x_0.shape
+ if noise_type == 'gauss':
+ noise = torch.randn_like(x_0, device=x_0.device) * noise_std
+ elif noise_type == 'bernoulli':
+ noise = (torch.bernoulli(torch.ones_like(x_0) * 0.5) * 2 - 1.) * noise_std
+ else:
+ raise NotImplementedError(noise_type)
+ alphas_t_sqrt = alphas_bar_sqrt[t].view(batch_size, num_channels, 1, 1)
+ one_minus_alphas_bar_t_sqrt = one_minus_alphas_bar_sqrt[t].view(batch_size, num_channels, 1, 1)
+ x_t = alphas_t_sqrt * x_0 + one_minus_alphas_bar_t_sqrt * noise
+ return x_t
+
+
+class Identity(torch.nn.Module):
+ def __init__(self):
+ super(Identity, self).__init__()
+
+ def forward(self, x):
+ return x
+
+
+@persistence.persistent_class
+class AugmentPipe(torch.nn.Module):
+ def __init__(self,
+ beta_schedule='linear', beta_start=1e-4, beta_end=2e-2, t_min=10, t_max=1000,
+ noise_std=0.05, aug='NO', ada_maxp=None, ts_dist='priority', update_beta=True,
+ ):
+ super().__init__()
+ self.p = 0.0 # Overall multiplier for augmentation probability.
+ self.aug_type = aug
+ self.ada_maxp = ada_maxp
+ self.noise_type = self.base_noise_type = 'gauss'
+ self.beta_schedule = beta_schedule
+ self.beta_start = beta_start
+ self.beta_end = beta_end
+ self.t_min = t_min
+ self.t_max = t_max
+ self.t_add = int(t_max - t_min)
+ self.ts_dist = ts_dist
+
+ # Image-space corruptions.
+ self.noise_std = float(noise_std) # Standard deviation of additive RGB noise.
+ self.noise_type = "gauss"
+ if aug == 'ADA':
+ self.aug = AdaAugment(p=0.0)
+ elif aug == 'DIFF':
+ self.aug = DiffAugment()
+ else:
+ self.aug = Identity()
+
+ self.update_beta = update_beta
+ if not update_beta:
+ self.set_diffusion_process(t_max, beta_schedule)
+ self.update_T()
+
+ def set_diffusion_process(self, t, beta_schedule):
+
+ betas = get_beta_schedule(
+ beta_schedule=beta_schedule,
+ beta_start=self.beta_start,
+ beta_end=self.beta_end,
+ num_diffusion_timesteps=t,
+ )
+
+ betas = self.betas = torch.from_numpy(betas).float()
+ self.num_timesteps = betas.shape[0]
+
+ alphas = self.alphas = 1.0 - betas
+ alphas_cumprod = torch.cat([torch.tensor([1.]), alphas.cumprod(dim=0)])
+ self.alphas_bar_sqrt = torch.sqrt(alphas_cumprod)
+ self.one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_cumprod)
+
+ def update_T(self):
+ if self.aug_type == 'ADA':
+ _p = min(self.p, self.ada_maxp) if self.ada_maxp else self.p
+ self.aug.p.copy_(torch.tensor(_p))
+
+ t_adjust = round(self.p * self.t_add)
+ t = np.clip(int(self.t_min + t_adjust), a_min=self.t_min, a_max=self.t_max)
+
+ if self.update_beta:
+ if self.beta_schedule == 'linear_cosine':
+ if t >= 500:
+ self.set_diffusion_process(t, 'cosine')
+ else:
+ self.set_diffusion_process(t, 'linear')
+ else:
+ self.set_diffusion_process(t, self.beta_schedule)
+
+ # sampling t
+ self.t_epl = np.zeros(64, dtype=np.int)
+ diffusion_ind = 32
+ t_diffusion = np.zeros((diffusion_ind,)).astype(np.int)
+ if self.ts_dist == 'priority':
+ prob_t = np.arange(t) / np.arange(t).sum()
+ t_diffusion = np.random.choice(np.arange(1, t + 1), size=diffusion_ind, p=prob_t)
+ elif self.ts_dist == 'uniform':
+ t_diffusion = np.random.choice(np.arange(1, t + 1), size=diffusion_ind)
+ self.t_epl[:diffusion_ind] = t_diffusion
+
+ def forward(self, x_0):
+ x_0 = self.aug(x_0)
+ assert isinstance(x_0, torch.Tensor) and x_0.ndim == 4
+ batch_size, num_channels, height, width = x_0.shape
+ device = x_0.device
+
+ alphas_bar_sqrt = self.alphas_bar_sqrt.to(device)
+ one_minus_alphas_bar_sqrt = self.one_minus_alphas_bar_sqrt.to(device)
+
+ t = torch.from_numpy(np.random.choice(self.t_epl, size=batch_size, replace=True)).to(device)
+ x_t = q_sample(x_0, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, t,
+ noise_type=self.noise_type,
+ noise_std=self.noise_std)
+ # x_t = self.aug(x_t)
+ return x_t, t.view(-1, 1)
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-insgen/training/contrastive_head.py b/diffusion-insgen/training/contrastive_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..e09367e3a201beadd0f45d6aede3d37d866f0964
--- /dev/null
+++ b/diffusion-insgen/training/contrastive_head.py
@@ -0,0 +1,203 @@
+# Code are mainly borrowed from the official implementation of MoCo (https://github.com/facebookresearch/moco)
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch_utils import misc
+from torch_utils import persistence
+
+#----------------------------------------------------------------------------
+
+# Contrastive head
+@persistence.persistent_class
+class CLHead(torch.nn.Module):
+ def __init__(self,
+ inplanes = 256, # Number of input features
+ temperature = 0.2, # Temperature of logits
+ queue_size = 3500, # Number of stored negative samples
+ momentum = 0.999, # Momentum for updating network
+ ):
+ super().__init__()
+ self.inplanes = inplanes
+ self.temperature = temperature
+ self.queue_size = queue_size
+ self.m = momentum
+
+ self.mlp = nn.Sequential(nn.Linear(inplanes, inplanes), nn.ReLU(), nn.Linear(inplanes, 128))
+ self.momentum_mlp = nn.Sequential(nn.Linear(inplanes, inplanes), nn.ReLU(), nn.Linear(inplanes, 128))
+ self.momentum_mlp.requires_grad_(False)
+
+ for param_q, param_k in zip(self.mlp.parameters(), self.momentum_mlp.parameters()):
+ param_k.data.copy_(param_q.data) # initialize
+ param_k.requires_grad = False # not update by gradient
+
+ # create the queue
+ self.register_buffer("queue", torch.randn(128, self.queue_size))
+ self.queue = nn.functional.normalize(self.queue, dim=0)
+ self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
+
+
+ @torch.no_grad()
+ def _momentum_update_key_encoder(self):
+ """
+ Momentum update of the key encoder
+ """
+ for param_q, param_k in zip(self.mlp.parameters(), self.momentum_mlp.parameters()):
+ param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
+
+ @torch.no_grad()
+ def _dequeue_and_enqueue(self, keys):
+ # gather keys before updating queue
+ keys = concat_all_gather(keys)
+
+ batch_size = keys.shape[0]
+ keys = keys.T
+ ptr = int(self.queue_ptr)
+ if batch_size > self.queue_size:
+ self.queue[:, 0:] = keys[:, :self.queue_size]
+
+ elif ptr + batch_size > self.queue_size:
+ self.queue[:, ptr:] = keys[:, :self.queue_size - ptr]
+ self.queue[:, :batch_size - (self.queue_size - ptr)] = keys[:, self.queue_size-ptr:]
+ self.queue_ptr[0] = batch_size - (self.queue_size - ptr)
+ else:
+ self.queue[:, ptr:ptr + batch_size] = keys
+ self.queue_ptr[0] = ptr + batch_size
+
+ @torch.no_grad()
+ def _batch_shuffle_ddp(self, x):
+ """
+ Batch shuffle, for making use of BatchNorm.
+ """
+
+ # If non-distributed now, return raw input directly.
+ # We have no idea the effect of disabling shuffle BN to MoCo.
+ # Thus, we recommand train InsGen with more than 1 GPU always.
+ if not torch.distributed.is_initialized():
+ return x, torch.arange(x.shape[0])
+
+ # gather from all gpus
+ device = x.device
+ batch_size_this = x.shape[0]
+ x_gather = concat_all_gather(x)
+ batch_size_all = x_gather.shape[0]
+
+ num_gpus = batch_size_all // batch_size_this
+
+ # random shuffle index
+ idx_shuffle = torch.randperm(batch_size_all).cuda(device)
+
+ # broadcast to all gpus
+ torch.distributed.broadcast(idx_shuffle, src=0)
+
+ # index for restoring
+ idx_unshuffle = torch.argsort(idx_shuffle)
+
+ # shuffled index for this gpu
+ gpu_idx = torch.distributed.get_rank()
+ idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]
+
+ return x_gather[idx_this], idx_unshuffle
+
+ @torch.no_grad()
+ def _batch_unshuffle_ddp(self, x, idx_unshuffle):
+ """
+ Undo batch shuffle.
+ """
+ # If non-distributed now, return raw input directly.
+ # We have no idea the effect of disabling shuffle BN to MoCo.
+ # Thus, we recommand train InsGen with more than 1 GPU always.
+ if not torch.distributed.is_initialized():
+ return x
+
+ # gather from all gpus
+ batch_size_this = x.shape[0]
+ x_gather = concat_all_gather(x)
+ batch_size_all = x_gather.shape[0]
+
+ num_gpus = batch_size_all // batch_size_this
+
+ # restored index for this gpu
+ gpu_idx = torch.distributed.get_rank()
+ idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]
+
+ return x_gather[idx_this]
+
+
+ def forward(self, im_q, im_k, loss_only=False, update_q=False):
+ """
+ Input:
+ im_q: a batch of query images
+ im_k: a batch of key images
+ Output:
+ logits, targets
+ """
+ device = im_q.device
+ im_q = im_q.to(torch.float32)
+ im_k = im_k.to(torch.float32)
+ # compute query features
+ if im_q.ndim > 2:
+ im_q = im_q.mean([2,3])
+ q = self.mlp(im_q) # queries: NxC
+ q = nn.functional.normalize(q, dim=1)
+
+ # compute key features
+ with torch.no_grad(): # no gradient to keys
+ self._momentum_update_key_encoder() # update the key encoder
+ if im_k.ndim > 2:
+ im_k = im_k.mean([2,3])
+ # shuffle for making use of BN
+ im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)
+ k = self.momentum_mlp(im_k) # keys: NxC
+ k = nn.functional.normalize(k, dim=1)
+
+ # undo shuffle
+ k = self._batch_unshuffle_ddp(k, idx_unshuffle)
+
+ # compute logits
+ # Einstein sum is more intuitive
+ # positive logits: Nx1
+ l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
+ # negative logits: NxK
+ l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
+
+ # logits: Nx(1+K)
+ logits = torch.cat([l_pos, l_neg], dim=1)
+
+ # apply temperature
+ logits /= self.temperature
+
+ # labels: positive key indicators
+ labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda(device)
+
+ # dequeue and enqueue
+ if not loss_only:
+ if update_q:
+ with torch.no_grad():
+ temp_im_q, idx_unshuffle = self._batch_shuffle_ddp(im_q)
+ temp_q = self.momentum_mlp(temp_im_q)
+ temp_q = nn.functional.normalize(temp_q, dim=1)
+ temp_q = self._batch_unshuffle_ddp(temp_q, idx_unshuffle)
+ self._dequeue_and_enqueue(temp_q)
+ else:
+ self._dequeue_and_enqueue(k)
+
+ # calculate loss
+ loss = nn.functional.cross_entropy(logits, labels)
+
+ return loss
+
+@torch.no_grad()
+def concat_all_gather(tensor):
+ """
+ Performs all_gather operation on the provided tensors.
+ *** Warning ***: torch.distributed.all_gather has no gradient.
+ """
+ tensors_gather = [torch.ones_like(tensor)
+ for _ in range(torch.distributed.get_world_size())]
+ torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
+
+ output = torch.cat(tensors_gather, dim=0)
+ return output
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-insgen/training/contrastive_loss.py b/diffusion-insgen/training/contrastive_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a072c3fb34b2fbae30a16003a0bb9ea2390de2f
--- /dev/null
+++ b/diffusion-insgen/training/contrastive_loss.py
@@ -0,0 +1,194 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import numpy as np
+import torch
+from torch_utils import training_stats
+from torch_utils import misc
+from torch_utils.ops import conv2d_gradfix
+
+from training.adaaug import AdaAugment
+
+#----------------------------------------------------------------------------
+
+class Loss:
+ def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain): # to be overridden by subclass
+ raise NotImplementedError()
+
+#----------------------------------------------------------------------------
+
+class StyleGAN2LossCL(Loss):
+ def __init__(self, device, G_mapping, G_synthesis, D, augment_pipe=None, style_mixing_prob=0.9, r1_gamma=10, pl_batch_shrink=2, pl_decay=0.01, pl_weight=2):
+ super().__init__()
+ self.device = device
+ self.G_mapping = G_mapping
+ self.G_synthesis = G_synthesis
+ self.D = D
+ self.augment_pipe = augment_pipe
+ self.style_mixing_prob = style_mixing_prob
+ self.r1_gamma = r1_gamma
+ self.pl_batch_shrink = pl_batch_shrink
+ self.pl_decay = pl_decay
+ self.pl_weight = pl_weight
+ self.pl_mean = torch.zeros([], device=device)
+ self.image_disturb = AdaAugment(p=0.2).to(device)
+
+ def run_G(self, z, c, sync):
+ with misc.ddp_sync(self.G_mapping, sync):
+ ws = self.G_mapping(z, c)
+ if self.style_mixing_prob > 0:
+ with torch.autograd.profiler.record_function('style_mixing'):
+ cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
+ cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
+ ws[:, cutoff:] = self.G_mapping(torch.randn_like(z), c, skip_w_avg_update=True)[:, cutoff:]
+ with misc.ddp_sync(self.G_synthesis, sync):
+ img = self.G_synthesis(ws)
+ return img, ws
+
+ def run_D(self, img, c, sync):
+ if self.augment_pipe is not None:
+ img, t = self.augment_pipe(img)
+ with misc.ddp_sync(self.D, sync):
+ logits = self.D(img, c, t)
+ return logits
+
+ def run_cl(self, img, c, sync, contrastive_head, D_ema, loss_name='', loss_only=False, img1=None, update_q=False):
+ # contrastive loss fwd
+
+ # augmentation first via ada-aug
+ # assert(self.augment_pipe is not None)
+ img0 = self.image_disturb(img)
+ img1 = self.image_disturb(img) if img1 is None else self.image_disturb(img1)
+ batch_size, device = img.shape[0], img.device
+ # img0 = img
+ # img1 = img.clone() + torch.randn_like(img) * 0.02 if img1 is None else img1
+
+ # extract features for two views via D and momentum D
+ _, logits0 = self.D(img0, c, torch.zeros((batch_size, 1)).long().to(device), return_feats=True)
+ with torch.no_grad():
+ _, logits1 = D_ema(img1, c, torch.zeros((batch_size, 1)).long().to(device), return_feats=True)
+
+ # project features into the unit sphere and calculate contrastive loss
+ loss = contrastive_head(logits0, logits1, loss_only=loss_only, update_q=update_q)
+ training_stats.report('Loss/'+loss_name, loss)
+ return loss
+
+ def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain, cl_phases=None, D_ema=None, lw_real_cl=1.0, lw_fake_cl=1.0, lw_fake_cl_on_g=1.0, g_fake_cl=False):
+ assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
+ do_Gmain = (phase in ['Gmain', 'Gboth'])
+ do_Dmain = (phase in ['Dmain', 'Dboth'])
+ do_Gpl = (phase in ['Greg', 'Gboth']) and (self.pl_weight != 0)
+ do_Dr1 = (phase in ['Dreg', 'Dboth']) and (self.r1_gamma != 0)
+
+ # Gmain: Maximize logits for generated images.
+ if do_Gmain:
+ with torch.autograd.profiler.record_function('Gmain_forward'):
+ gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=(sync and not do_Gpl)) # May get synced by Gpl.
+ gen_logits = self.run_D(gen_img, gen_c, sync=False)
+ training_stats.report('Loss/scores/fake', gen_logits)
+ training_stats.report('Loss/signs/fake', gen_logits.sign())
+ loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
+ training_stats.report('Loss/G/loss', loss_Gmain)
+
+ # Diversity generation loss from fake instance discrimination
+ if cl_phases.get('GHeadmain', None) is not None and g_fake_cl:
+ # when fake cl on g, no params in D encoder and head would be updated, including feature queue.
+ Gphase = cl_phases['GHeadmain']
+ Gphase.module.requires_grad_(False)
+ # fake_cl on g: gradients bp to generator
+ loss_Gmain = loss_Gmain + lw_fake_cl_on_g * self.run_cl(gen_img, gen_c, False, Gphase.module, D_ema, loss_name='G_cl_on_g', loss_only=True)
+
+ with torch.autograd.profiler.record_function('Gmain_backward'):
+ loss_Gmain.mean().mul(gain).backward()
+
+ if cl_phases.get('GHeadmain', None) is not None and g_fake_cl:
+ Gphase = cl_phases['GHeadmain']
+ Gphase.module.requires_grad_(True)
+
+ # Gpl: Apply path length regularization.
+ if do_Gpl:
+ with torch.autograd.profiler.record_function('Gpl_forward'):
+ batch_size = gen_z.shape[0] // self.pl_batch_shrink
+ gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size], sync=sync)
+ pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
+ with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients():
+ pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0]
+ pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
+ pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
+ self.pl_mean.copy_(pl_mean.detach())
+ pl_penalty = (pl_lengths - pl_mean).square()
+ training_stats.report('Loss/pl_penalty', pl_penalty)
+ loss_Gpl = pl_penalty * self.pl_weight
+ training_stats.report('Loss/G/reg', loss_Gpl)
+ with torch.autograd.profiler.record_function('Gpl_backward'):
+ (gen_img[:, 0, 0, 0] * 0 + loss_Gpl).mean().mul(gain).backward()
+
+ # Dmain: Minimize logits for generated images.
+ loss_Dgen = 0
+ if do_Dmain:
+ with torch.autograd.profiler.record_function('Dgen_forward'):
+ gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=False)
+ gen_logits = self.run_D(gen_img, gen_c, sync=False) # Gets synced by loss_Dreal.
+ training_stats.report('Loss/scores/fake', gen_logits)
+ training_stats.report('Loss/signs/fake', gen_logits.sign())
+ loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))
+ with torch.autograd.profiler.record_function('Dgen_backward'):
+ loss_Dgen.mean().mul(gain).backward()
+
+ # Dmain: Maximize logits for real images.
+ # Dr1: Apply R1 regularization.
+ if do_Dmain or do_Dr1:
+ name = 'Dreal_Dr1' if do_Dmain and do_Dr1 else 'Dreal' if do_Dmain else 'Dr1'
+ with torch.autograd.profiler.record_function(name + '_forward'):
+ real_img_tmp = real_img.detach().requires_grad_(do_Dr1)
+ real_logits = self.run_D(real_img_tmp, real_c, sync=sync)
+ training_stats.report('Loss/scores/real', real_logits)
+ training_stats.report('Loss/signs/real', real_logits.sign())
+
+ loss_Dreal = 0
+ if do_Dmain:
+ loss_Dreal = torch.nn.functional.softplus(-real_logits) # -log(sigmoid(real_logits))
+ training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal)
+ # Contrastive loss would be added to the normal binary cls loss of D
+ # real instance discrimination
+ if cl_phases.get('DHeadmain', None) is not None:
+ Dphase = cl_phases['DHeadmain']
+ Dphase.opt.zero_grad(set_to_none=True)
+ loss_Dreal = loss_Dreal + lw_real_cl * self.run_cl(real_img_tmp, real_c, sync, Dphase.module, D_ema, loss_name='D_cl')
+
+ # fake instance discrimination
+ if cl_phases.get('GHeadmain', None) is not None:
+ Gphase = cl_phases['GHeadmain']
+ Gphase.opt.zero_grad(set_to_none=True)
+ # noisy perturbation
+ with torch.no_grad():
+ delta_z = torch.randn(gen_z.shape, device=gen_z.device) * 0.15
+ noisy_gen_img, _ = self.run_G(gen_z + delta_z, gen_c, sync=False)
+ loss_Dreal = loss_Dreal + lw_fake_cl * self.run_cl(gen_img, gen_c, False, Gphase.module, D_ema, loss_name='G_cl', img1=noisy_gen_img, update_q=True)
+
+ loss_Dr1 = 0
+ if do_Dr1:
+ with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
+ r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0]
+ r1_penalty = r1_grads.square().sum([1,2,3])
+ loss_Dr1 = r1_penalty * (self.r1_gamma / 2)
+ training_stats.report('Loss/r1_penalty', r1_penalty)
+ training_stats.report('Loss/D/reg', loss_Dr1)
+
+ with torch.autograd.profiler.record_function(name + '_backward'):
+ (real_logits * 0 + loss_Dreal + loss_Dr1).mean().mul(gain).backward()
+
+ # after backward of contrastive loss together with the original loss of D,
+ # manually call optim.step() to update the parameters of contrative head
+ if cl_phases.get('DHeadmain', None) is not None and do_Dmain:
+ Dphase.opt.step()
+
+ if cl_phases.get('GHeadmain', None) is not None and do_Dmain:
+ Gphase.opt.step()
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-insgen/training/dataset.py b/diffusion-insgen/training/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..82dcabadd46b5abda69b2030fb4dd7569133e2c0
--- /dev/null
+++ b/diffusion-insgen/training/dataset.py
@@ -0,0 +1,236 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import os
+import numpy as np
+import zipfile
+import PIL.Image
+import json
+import torch
+import dnnlib
+
+try:
+ import pyspng
+except ImportError:
+ pyspng = None
+
+#----------------------------------------------------------------------------
+
+class Dataset(torch.utils.data.Dataset):
+ def __init__(self,
+ name, # Name of the dataset.
+ raw_shape, # Shape of the raw image data (NCHW).
+ max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
+ use_labels = False, # Enable conditioning labels? False = label dimension is zero.
+ xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size.
+ random_seed = 0, # Random seed to use when applying max_size.
+ ):
+ self._name = name
+ self._raw_shape = list(raw_shape)
+ self._use_labels = use_labels
+ self._raw_labels = None
+ self._label_shape = None
+
+ # Apply max_size.
+ self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
+ if (max_size is not None) and (self._raw_idx.size > max_size):
+ np.random.RandomState(random_seed).shuffle(self._raw_idx)
+ self._raw_idx = np.sort(self._raw_idx[:max_size])
+
+ # Apply xflip.
+ self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
+ if xflip:
+ self._raw_idx = np.tile(self._raw_idx, 2)
+ self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])
+
+ def _get_raw_labels(self):
+ if self._raw_labels is None:
+ self._raw_labels = self._load_raw_labels() if self._use_labels else None
+ if self._raw_labels is None:
+ self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
+ assert isinstance(self._raw_labels, np.ndarray)
+ assert self._raw_labels.shape[0] == self._raw_shape[0]
+ assert self._raw_labels.dtype in [np.float32, np.int64]
+ if self._raw_labels.dtype == np.int64:
+ assert self._raw_labels.ndim == 1
+ assert np.all(self._raw_labels >= 0)
+ return self._raw_labels
+
+ def close(self): # to be overridden by subclass
+ pass
+
+ def _load_raw_image(self, raw_idx): # to be overridden by subclass
+ raise NotImplementedError
+
+ def _load_raw_labels(self): # to be overridden by subclass
+ raise NotImplementedError
+
+ def __getstate__(self):
+ return dict(self.__dict__, _raw_labels=None)
+
+ def __del__(self):
+ try:
+ self.close()
+ except:
+ pass
+
+ def __len__(self):
+ return self._raw_idx.size
+
+ def __getitem__(self, idx):
+ image = self._load_raw_image(self._raw_idx[idx])
+ assert isinstance(image, np.ndarray)
+ assert list(image.shape) == self.image_shape
+ assert image.dtype == np.uint8
+ if self._xflip[idx]:
+ assert image.ndim == 3 # CHW
+ image = image[:, :, ::-1]
+ return image.copy(), self.get_label(idx)
+
+ def get_label(self, idx):
+ label = self._get_raw_labels()[self._raw_idx[idx]]
+ if label.dtype == np.int64:
+ onehot = np.zeros(self.label_shape, dtype=np.float32)
+ onehot[label] = 1
+ label = onehot
+ return label.copy()
+
+ def get_details(self, idx):
+ d = dnnlib.EasyDict()
+ d.raw_idx = int(self._raw_idx[idx])
+ d.xflip = (int(self._xflip[idx]) != 0)
+ d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
+ return d
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def image_shape(self):
+ return list(self._raw_shape[1:])
+
+ @property
+ def num_channels(self):
+ assert len(self.image_shape) == 3 # CHW
+ return self.image_shape[0]
+
+ @property
+ def resolution(self):
+ assert len(self.image_shape) == 3 # CHW
+ assert self.image_shape[1] == self.image_shape[2]
+ return self.image_shape[1]
+
+ @property
+ def label_shape(self):
+ if self._label_shape is None:
+ raw_labels = self._get_raw_labels()
+ if raw_labels.dtype == np.int64:
+ self._label_shape = [int(np.max(raw_labels)) + 1]
+ else:
+ self._label_shape = raw_labels.shape[1:]
+ return list(self._label_shape)
+
+ @property
+ def label_dim(self):
+ assert len(self.label_shape) == 1
+ return self.label_shape[0]
+
+ @property
+ def has_labels(self):
+ return any(x != 0 for x in self.label_shape)
+
+ @property
+ def has_onehot_labels(self):
+ return self._get_raw_labels().dtype == np.int64
+
+#----------------------------------------------------------------------------
+
+class ImageFolderDataset(Dataset):
+ def __init__(self,
+ path, # Path to directory or zip.
+ resolution = None, # Ensure specific resolution, None = highest available.
+ **super_kwargs, # Additional arguments for the Dataset base class.
+ ):
+ self._path = path
+ self._zipfile = None
+
+ if os.path.isdir(self._path):
+ self._type = 'dir'
+ self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
+ elif self._file_ext(self._path) == '.zip':
+ self._type = 'zip'
+ self._all_fnames = set(self._get_zipfile().namelist())
+ else:
+ raise IOError('Path must point to a directory or zip')
+
+ PIL.Image.init()
+ self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
+ if len(self._image_fnames) == 0:
+ raise IOError('No image files found in the specified path')
+
+ name = os.path.splitext(os.path.basename(self._path))[0]
+ raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
+ if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
+ raise IOError('Image files do not match the specified resolution')
+ super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)
+
+ @staticmethod
+ def _file_ext(fname):
+ return os.path.splitext(fname)[1].lower()
+
+ def _get_zipfile(self):
+ assert self._type == 'zip'
+ if self._zipfile is None:
+ self._zipfile = zipfile.ZipFile(self._path)
+ return self._zipfile
+
+ def _open_file(self, fname):
+ if self._type == 'dir':
+ return open(os.path.join(self._path, fname), 'rb')
+ if self._type == 'zip':
+ return self._get_zipfile().open(fname, 'r')
+ return None
+
+ def close(self):
+ try:
+ if self._zipfile is not None:
+ self._zipfile.close()
+ finally:
+ self._zipfile = None
+
+ def __getstate__(self):
+ return dict(super().__getstate__(), _zipfile=None)
+
+ def _load_raw_image(self, raw_idx):
+ fname = self._image_fnames[raw_idx]
+ with self._open_file(fname) as f:
+ if pyspng is not None and self._file_ext(fname) == '.png':
+ image = pyspng.load(f.read())
+ else:
+ image = np.array(PIL.Image.open(f))
+ if image.ndim == 2:
+ image = image[:, :, np.newaxis] # HW => HWC
+ image = image.transpose(2, 0, 1) # HWC => CHW
+ return image
+
+ def _load_raw_labels(self):
+ fname = 'dataset.json'
+ if fname not in self._all_fnames:
+ return None
+ with self._open_file(fname) as f:
+ labels = json.load(f)['labels']
+ if labels is None:
+ return None
+ labels = dict(labels)
+ labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames]
+ labels = np.array(labels)
+ labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
+ return labels
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-insgen/training/diffaug.py b/diffusion-insgen/training/diffaug.py
new file mode 100644
index 0000000000000000000000000000000000000000..7876439a83e683789f2ab644e94c7c1c8d9282c5
--- /dev/null
+++ b/diffusion-insgen/training/diffaug.py
@@ -0,0 +1,92 @@
+# Differentiable Augmentation for Data-Efficient GAN Training
+# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
+# https://arxiv.org/pdf/2006.10738
+
+import torch
+import torch.nn.functional as F
+
+class DiffAugment(torch.nn.Module):
+ def __init__(self, policy='color,translation,cutout', channels_first=True):
+ super().__init__()
+ self.policy = policy
+ self.channels_first = channels_first
+
+ def forward(self, x):
+ if not self.channels_first:
+ x = x.permute(0, 3, 1, 2)
+ for p in self.policy.split(','):
+ for f in AUGMENT_FNS[p]:
+ x = f(x)
+ if not self.channels_first:
+ x = x.permute(0, 2, 3, 1)
+ x = x.contiguous()
+ return x
+
+# def DiffAugment(x, policy='', channels_first=True):
+# if policy:
+# if not channels_first:
+# x = x.permute(0, 3, 1, 2)
+# for p in policy.split(','):
+# for f in AUGMENT_FNS[p]:
+# x = f(x)
+# if not channels_first:
+# x = x.permute(0, 2, 3, 1)
+# x = x.contiguous()
+# return x
+
+
+def rand_brightness(x):
+ x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
+ return x
+
+
+def rand_saturation(x):
+ x_mean = x.mean(dim=1, keepdim=True)
+ x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
+ return x
+
+
+def rand_contrast(x):
+ x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
+ x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
+ return x
+
+
+def rand_translation(x, ratio=0.125):
+ shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
+ translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
+ translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
+ grid_batch, grid_x, grid_y = torch.meshgrid(
+ torch.arange(x.size(0), dtype=torch.long, device=x.device),
+ torch.arange(x.size(2), dtype=torch.long, device=x.device),
+ torch.arange(x.size(3), dtype=torch.long, device=x.device),
+ )
+ grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
+ grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
+ x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
+ x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
+ return x
+
+
+def rand_cutout(x, ratio=0.2):
+ cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
+ offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
+ offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
+ grid_batch, grid_x, grid_y = torch.meshgrid(
+ torch.arange(x.size(0), dtype=torch.long, device=x.device),
+ torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
+ torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
+ )
+ grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
+ grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
+ mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
+ mask[grid_batch, grid_x, grid_y] = 0
+ x = x * mask.unsqueeze(1)
+ return x
+
+
+AUGMENT_FNS = {
+ 'color': [rand_brightness, rand_saturation, rand_contrast],
+ 'translation': [rand_translation],
+ 'cutout': [rand_cutout],
+}
diff --git a/diffusion-insgen/training/loss.py b/diffusion-insgen/training/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..b31d1f9e43493e9613149cf5498e6dfa30d461d2
--- /dev/null
+++ b/diffusion-insgen/training/loss.py
@@ -0,0 +1,133 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import numpy as np
+import torch
+from torch_utils import training_stats
+from torch_utils import misc
+from torch_utils.ops import conv2d_gradfix
+
+#----------------------------------------------------------------------------
+
+class Loss:
+ def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain): # to be overridden by subclass
+ raise NotImplementedError()
+
+#----------------------------------------------------------------------------
+
+class StyleGAN2Loss(Loss):
+ def __init__(self, device, G_mapping, G_synthesis, D, augment_pipe=None, style_mixing_prob=0.9, r1_gamma=10, pl_batch_shrink=2, pl_decay=0.01, pl_weight=2):
+ super().__init__()
+ self.device = device
+ self.G_mapping = G_mapping
+ self.G_synthesis = G_synthesis
+ self.D = D
+ self.augment_pipe = augment_pipe
+ self.style_mixing_prob = style_mixing_prob
+ self.r1_gamma = r1_gamma
+ self.pl_batch_shrink = pl_batch_shrink
+ self.pl_decay = pl_decay
+ self.pl_weight = pl_weight
+ self.pl_mean = torch.zeros([], device=device)
+
+ def run_G(self, z, c, sync):
+ with misc.ddp_sync(self.G_mapping, sync):
+ ws = self.G_mapping(z, c)
+ if self.style_mixing_prob > 0:
+ with torch.autograd.profiler.record_function('style_mixing'):
+ cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
+ cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
+ ws[:, cutoff:] = self.G_mapping(torch.randn_like(z), c, skip_w_avg_update=True)[:, cutoff:]
+ with misc.ddp_sync(self.G_synthesis, sync):
+ img = self.G_synthesis(ws)
+ return img, ws
+
+ def run_D(self, img, c, sync):
+ if self.augment_pipe is not None:
+ img, t = self.augment_pipe(img)
+ with misc.ddp_sync(self.D, sync):
+ logits = self.D(img, c, t)
+ return logits
+
+ def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain, **kwargs):
+ assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
+ do_Gmain = (phase in ['Gmain', 'Gboth'])
+ do_Dmain = (phase in ['Dmain', 'Dboth'])
+ do_Gpl = (phase in ['Greg', 'Gboth']) and (self.pl_weight != 0)
+ do_Dr1 = (phase in ['Dreg', 'Dboth']) and (self.r1_gamma != 0)
+
+ # Gmain: Maximize logits for generated images.
+ if do_Gmain:
+ with torch.autograd.profiler.record_function('Gmain_forward'):
+ gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=(sync and not do_Gpl)) # May get synced by Gpl.
+ gen_logits = self.run_D(gen_img, gen_c, sync=False)
+ training_stats.report('Loss/scores/fake', gen_logits)
+ training_stats.report('Loss/signs/fake', gen_logits.sign())
+ loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
+ training_stats.report('Loss/G/loss', loss_Gmain)
+ with torch.autograd.profiler.record_function('Gmain_backward'):
+ loss_Gmain.mean().mul(gain).backward()
+
+ # Gpl: Apply path length regularization.
+ if do_Gpl:
+ with torch.autograd.profiler.record_function('Gpl_forward'):
+ batch_size = gen_z.shape[0] // self.pl_batch_shrink
+ gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size], sync=sync)
+ pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
+ with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients():
+ pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0]
+ pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
+ pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
+ self.pl_mean.copy_(pl_mean.detach())
+ pl_penalty = (pl_lengths - pl_mean).square()
+ training_stats.report('Loss/pl_penalty', pl_penalty)
+ loss_Gpl = pl_penalty * self.pl_weight
+ training_stats.report('Loss/G/reg', loss_Gpl)
+ with torch.autograd.profiler.record_function('Gpl_backward'):
+ (gen_img[:, 0, 0, 0] * 0 + loss_Gpl).mean().mul(gain).backward()
+
+ # Dmain: Minimize logits for generated images.
+ loss_Dgen = 0
+ if do_Dmain:
+ with torch.autograd.profiler.record_function('Dgen_forward'):
+ gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=False)
+ gen_logits = self.run_D(gen_img, gen_c, sync=False) # Gets synced by loss_Dreal.
+ training_stats.report('Loss/scores/fake', gen_logits)
+ training_stats.report('Loss/signs/fake', gen_logits.sign())
+ loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))
+ with torch.autograd.profiler.record_function('Dgen_backward'):
+ loss_Dgen.mean().mul(gain).backward()
+
+ # Dmain: Maximize logits for real images.
+ # Dr1: Apply R1 regularization.
+ if do_Dmain or do_Dr1:
+ name = 'Dreal_Dr1' if do_Dmain and do_Dr1 else 'Dreal' if do_Dmain else 'Dr1'
+ with torch.autograd.profiler.record_function(name + '_forward'):
+ real_img_tmp = real_img.detach().requires_grad_(do_Dr1)
+ real_logits = self.run_D(real_img_tmp, real_c, sync=sync)
+ training_stats.report('Loss/scores/real', real_logits)
+ training_stats.report('Loss/signs/real', real_logits.sign())
+
+ loss_Dreal = 0
+ if do_Dmain:
+ loss_Dreal = torch.nn.functional.softplus(-real_logits) # -log(sigmoid(real_logits))
+ training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal)
+
+ loss_Dr1 = 0
+ if do_Dr1:
+ with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
+ r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0]
+ r1_penalty = r1_grads.square().sum([1,2,3])
+ loss_Dr1 = r1_penalty * (self.r1_gamma / 2)
+ training_stats.report('Loss/r1_penalty', r1_penalty)
+ training_stats.report('Loss/D/reg', loss_Dr1)
+
+ with torch.autograd.profiler.record_function(name + '_backward'):
+ (real_logits * 0 + loss_Dreal + loss_Dr1).mean().mul(gain).backward()
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-insgen/training/networks.py b/diffusion-insgen/training/networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..4053a8ccd582fc825a8f2e191d632ffbd23b0a8b
--- /dev/null
+++ b/diffusion-insgen/training/networks.py
@@ -0,0 +1,740 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import numpy as np
+import torch
+from torch_utils import misc
+from torch_utils import persistence
+from torch_utils.ops import conv2d_resample
+from torch_utils.ops import upfirdn2d
+from torch_utils.ops import bias_act
+from torch_utils.ops import fma
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def normalize_2nd_moment(x, dim=1, eps=1e-8):
+ return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt()
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def modulated_conv2d(
+ x, # Input tensor of shape [batch_size, in_channels, in_height, in_width].
+ weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
+ styles, # Modulation coefficients of shape [batch_size, in_channels].
+ noise = None, # Optional noise tensor to add to the output activations.
+ up = 1, # Integer upsampling factor.
+ down = 1, # Integer downsampling factor.
+ padding = 0, # Padding with respect to the upsampled image.
+ resample_filter = None, # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
+ demodulate = True, # Apply weight demodulation?
+ flip_weight = True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
+ fused_modconv = True, # Perform modulation, convolution, and demodulation as a single fused operation?
+):
+ batch_size = x.shape[0]
+ out_channels, in_channels, kh, kw = weight.shape
+ misc.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk]
+ misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW]
+ misc.assert_shape(styles, [batch_size, in_channels]) # [NI]
+
+ # Pre-normalize inputs to avoid FP16 overflow.
+ if x.dtype == torch.float16 and demodulate:
+ weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1,2,3], keepdim=True)) # max_Ikk
+ styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I
+
+ # Calculate per-sample weights and demodulation coefficients.
+ w = None
+ dcoefs = None
+ if demodulate or fused_modconv:
+ w = weight.unsqueeze(0) # [NOIkk]
+ w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk]
+ if demodulate:
+ dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO]
+ if demodulate and fused_modconv:
+ w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk]
+
+ # Execute by scaling the activations before and after the convolution.
+ if not fused_modconv:
+ x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
+ x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight)
+ if demodulate and noise is not None:
+ x = fma.fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype))
+ elif demodulate:
+ x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)
+ elif noise is not None:
+ x = x.add_(noise.to(x.dtype))
+ return x
+
+ # Execute as one fused op using grouped convolution.
+ with misc.suppress_tracer_warnings(): # this value will be treated as a constant
+ batch_size = int(batch_size)
+ misc.assert_shape(x, [batch_size, in_channels, None, None])
+ x = x.reshape(1, -1, *x.shape[2:])
+ w = w.reshape(-1, in_channels, kh, kw)
+ x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight)
+ x = x.reshape(batch_size, -1, *x.shape[2:])
+ if noise is not None:
+ x = x.add_(noise)
+ return x
+
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class FullyConnectedLayer(torch.nn.Module):
+ def __init__(self,
+ in_features, # Number of input features.
+ out_features, # Number of output features.
+ bias = True, # Apply additive bias before the activation function?
+ activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
+ lr_multiplier = 1, # Learning rate multiplier.
+ bias_init = 0, # Initial value for the additive bias.
+ ):
+ super().__init__()
+ self.activation = activation
+ self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
+ self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
+ self.weight_gain = lr_multiplier / np.sqrt(in_features)
+ self.bias_gain = lr_multiplier
+
+ def forward(self, x):
+ w = self.weight.to(x.dtype) * self.weight_gain
+ b = self.bias
+ if b is not None:
+ b = b.to(x.dtype)
+ if self.bias_gain != 1:
+ b = b * self.bias_gain
+
+ if self.activation == 'linear' and b is not None:
+ x = torch.addmm(b.unsqueeze(0), x, w.t())
+ else:
+ x = x.matmul(w.t())
+ x = bias_act.bias_act(x, b, act=self.activation)
+ return x
+
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class Conv2dLayer(torch.nn.Module):
+ def __init__(self,
+ in_channels, # Number of input channels.
+ out_channels, # Number of output channels.
+ kernel_size, # Width and height of the convolution kernel.
+ bias = True, # Apply additive bias before the activation function?
+ activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
+ up = 1, # Integer upsampling factor.
+ down = 1, # Integer downsampling factor.
+ resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
+ conv_clamp = None, # Clamp the output to +-X, None = disable clamping.
+ channels_last = False, # Expect the input to have memory_format=channels_last?
+ trainable = True, # Update the weights of this layer during training?
+ ):
+ super().__init__()
+ self.activation = activation
+ self.up = up
+ self.down = down
+ self.conv_clamp = conv_clamp
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
+ self.padding = kernel_size // 2
+ self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
+ self.act_gain = bias_act.activation_funcs[activation].def_gain
+
+ memory_format = torch.channels_last if channels_last else torch.contiguous_format
+ weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)
+ bias = torch.zeros([out_channels]) if bias else None
+ if trainable:
+ self.weight = torch.nn.Parameter(weight)
+ self.bias = torch.nn.Parameter(bias) if bias is not None else None
+ else:
+ self.register_buffer('weight', weight)
+ if bias is not None:
+ self.register_buffer('bias', bias)
+ else:
+ self.bias = None
+
+ def forward(self, x, gain=1):
+ w = self.weight * self.weight_gain
+ b = self.bias.to(x.dtype) if self.bias is not None else None
+ flip_weight = (self.up == 1) # slightly faster
+ x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=self.resample_filter, up=self.up, down=self.down, padding=self.padding, flip_weight=flip_weight)
+
+ act_gain = self.act_gain * gain
+ act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
+ x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp)
+ return x
+
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class MappingNetwork(torch.nn.Module):
+ def __init__(self,
+ z_dim, # Input latent (Z) dimensionality, 0 = no latent.
+ c_dim, # Conditioning label (C) dimensionality, 0 = no label.
+ w_dim, # Intermediate latent (W) dimensionality.
+ num_ws, # Number of intermediate latents to output, None = do not broadcast.
+ num_layers = 8, # Number of mapping layers.
+ embed_features = None, # Label embedding dimensionality, None = same as w_dim.
+ layer_features = None, # Number of intermediate features in the mapping layers, None = same as w_dim.
+ activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
+ lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers.
+ w_avg_beta = 0.995, # Decay for tracking the moving average of W during training, None = do not track.
+ ):
+ super().__init__()
+ self.z_dim = z_dim
+ self.c_dim = c_dim
+ self.w_dim = w_dim
+ self.num_ws = num_ws
+ self.num_layers = num_layers
+ self.w_avg_beta = w_avg_beta
+
+ if embed_features is None:
+ embed_features = w_dim
+ if c_dim == 0:
+ embed_features = 0
+ if layer_features is None:
+ layer_features = w_dim
+ features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]
+
+ if c_dim > 0:
+ self.embed = FullyConnectedLayer(c_dim, embed_features)
+ for idx in range(num_layers):
+ in_features = features_list[idx]
+ out_features = features_list[idx + 1]
+ layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
+ setattr(self, f'fc{idx}', layer)
+
+ if num_ws is not None and w_avg_beta is not None:
+ self.register_buffer('w_avg', torch.zeros([w_dim]))
+
+ def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
+ # Embed, normalize, and concat inputs.
+ x = None
+ with torch.autograd.profiler.record_function('input'):
+ if self.z_dim > 0:
+ misc.assert_shape(z, [None, self.z_dim])
+ x = normalize_2nd_moment(z.to(torch.float32))
+ if self.c_dim > 0:
+ misc.assert_shape(c, [None, self.c_dim])
+ y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
+ x = torch.cat([x, y], dim=1) if x is not None else y
+
+ # Main layers.
+ for idx in range(self.num_layers):
+ layer = getattr(self, f'fc{idx}')
+ x = layer(x)
+
+ # Update moving average of W.
+ if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
+ with torch.autograd.profiler.record_function('update_w_avg'):
+ self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
+
+ # Broadcast.
+ if self.num_ws is not None:
+ with torch.autograd.profiler.record_function('broadcast'):
+ x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
+
+ # Apply truncation.
+ if truncation_psi != 1:
+ with torch.autograd.profiler.record_function('truncate'):
+ assert self.w_avg_beta is not None
+ if self.num_ws is None or truncation_cutoff is None:
+ x = self.w_avg.lerp(x, truncation_psi)
+ else:
+ x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
+ return x
+
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class SynthesisLayer(torch.nn.Module):
+ def __init__(self,
+ in_channels, # Number of input channels.
+ out_channels, # Number of output channels.
+ w_dim, # Intermediate latent (W) dimensionality.
+ resolution, # Resolution of this layer.
+ kernel_size = 3, # Convolution kernel size.
+ up = 1, # Integer upsampling factor.
+ use_noise = True, # Enable noise input?
+ activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
+ resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
+ conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ channels_last = False, # Use channels_last format for the weights?
+ ):
+ super().__init__()
+ self.resolution = resolution
+ self.up = up
+ self.use_noise = use_noise
+ self.activation = activation
+ self.conv_clamp = conv_clamp
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
+ self.padding = kernel_size // 2
+ self.act_gain = bias_act.activation_funcs[activation].def_gain
+
+ self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
+ memory_format = torch.channels_last if channels_last else torch.contiguous_format
+ self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
+ if use_noise:
+ self.register_buffer('noise_const', torch.randn([resolution, resolution]))
+ self.noise_strength = torch.nn.Parameter(torch.zeros([]))
+ self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
+
+ def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1):
+ assert noise_mode in ['random', 'const', 'none']
+ in_resolution = self.resolution // self.up
+ misc.assert_shape(x, [None, self.weight.shape[1], in_resolution, in_resolution])
+ styles = self.affine(w)
+
+ noise = None
+ if self.use_noise and noise_mode == 'random':
+ noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength
+ if self.use_noise and noise_mode == 'const':
+ noise = self.noise_const * self.noise_strength
+
+ flip_weight = (self.up == 1) # slightly faster
+ x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up,
+ padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight, fused_modconv=fused_modconv)
+
+ act_gain = self.act_gain * gain
+ act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
+ x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp)
+ return x
+
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class ToRGBLayer(torch.nn.Module):
+ def __init__(self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False):
+ super().__init__()
+ self.conv_clamp = conv_clamp
+ self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
+ memory_format = torch.channels_last if channels_last else torch.contiguous_format
+ self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
+ self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
+ self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
+
+ def forward(self, x, w, fused_modconv=True):
+ styles = self.affine(w) * self.weight_gain
+ x = modulated_conv2d(x=x, weight=self.weight, styles=styles, demodulate=False, fused_modconv=fused_modconv)
+ x = bias_act.bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp)
+ return x
+
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class SynthesisBlock(torch.nn.Module):
+ def __init__(self,
+ in_channels, # Number of input channels, 0 = first block.
+ out_channels, # Number of output channels.
+ w_dim, # Intermediate latent (W) dimensionality.
+ resolution, # Resolution of this block.
+ img_channels, # Number of output color channels.
+ is_last, # Is this the last block?
+ architecture = 'skip', # Architecture: 'orig', 'skip', 'resnet'.
+ resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
+ conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ use_fp16 = False, # Use FP16 for this block?
+ fp16_channels_last = False, # Use channels-last memory format with FP16?
+ **layer_kwargs, # Arguments for SynthesisLayer.
+ ):
+ assert architecture in ['orig', 'skip', 'resnet']
+ super().__init__()
+ self.in_channels = in_channels
+ self.w_dim = w_dim
+ self.resolution = resolution
+ self.img_channels = img_channels
+ self.is_last = is_last
+ self.architecture = architecture
+ self.use_fp16 = use_fp16
+ self.channels_last = (use_fp16 and fp16_channels_last)
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
+ self.num_conv = 0
+ self.num_torgb = 0
+
+ if in_channels == 0:
+ self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution]))
+
+ if in_channels != 0:
+ self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim, resolution=resolution, up=2,
+ resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
+ self.num_conv += 1
+
+ self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim, resolution=resolution,
+ conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
+ self.num_conv += 1
+
+ if is_last or architecture == 'skip':
+ self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim,
+ conv_clamp=conv_clamp, channels_last=self.channels_last)
+ self.num_torgb += 1
+
+ if in_channels != 0 and architecture == 'resnet':
+ self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2,
+ resample_filter=resample_filter, channels_last=self.channels_last)
+
+ def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, **layer_kwargs):
+ misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
+ w_iter = iter(ws.unbind(dim=1))
+ dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
+ memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
+ if fused_modconv is None:
+ with misc.suppress_tracer_warnings(): # this value will be treated as a constant
+ fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1)
+
+ # Input.
+ if self.in_channels == 0:
+ x = self.const.to(dtype=dtype, memory_format=memory_format)
+ x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
+ else:
+ misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2])
+ x = x.to(dtype=dtype, memory_format=memory_format)
+
+ # Main layers.
+ if self.in_channels == 0:
+ x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
+ elif self.architecture == 'resnet':
+ y = self.skip(x, gain=np.sqrt(0.5))
+ x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
+ x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
+ x = y.add_(x)
+ else:
+ x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
+ x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
+
+ # ToRGB.
+ if img is not None:
+ misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
+ img = upfirdn2d.upsample2d(img, self.resample_filter)
+ if self.is_last or self.architecture == 'skip':
+ y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
+ y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
+ img = img.add_(y) if img is not None else y
+
+ assert x.dtype == dtype
+ assert img is None or img.dtype == torch.float32
+ return x, img
+
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class SynthesisNetwork(torch.nn.Module):
+ def __init__(self,
+ w_dim, # Intermediate latent (W) dimensionality.
+ img_resolution, # Output image resolution.
+ img_channels, # Number of color channels.
+ channel_base = 32768, # Overall multiplier for the number of channels.
+ channel_max = 512, # Maximum number of channels in any layer.
+ num_fp16_res = 0, # Use FP16 for the N highest resolutions.
+ **block_kwargs, # Arguments for SynthesisBlock.
+ ):
+ assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
+ super().__init__()
+ self.w_dim = w_dim
+ self.img_resolution = img_resolution
+ self.img_resolution_log2 = int(np.log2(img_resolution))
+ self.img_channels = img_channels
+ self.block_resolutions = [2 ** i for i in range(2, self.img_resolution_log2 + 1)]
+ channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions}
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
+
+ self.num_ws = 0
+ for res in self.block_resolutions:
+ in_channels = channels_dict[res // 2] if res > 4 else 0
+ out_channels = channels_dict[res]
+ use_fp16 = (res >= fp16_resolution)
+ is_last = (res == self.img_resolution)
+ block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res,
+ img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs)
+ self.num_ws += block.num_conv
+ if is_last:
+ self.num_ws += block.num_torgb
+ setattr(self, f'b{res}', block)
+
+ def forward(self, ws, **block_kwargs):
+ block_ws = []
+ with torch.autograd.profiler.record_function('split_ws'):
+ misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
+ ws = ws.to(torch.float32)
+ w_idx = 0
+ for res in self.block_resolutions:
+ block = getattr(self, f'b{res}')
+ block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb))
+ w_idx += block.num_conv
+
+ x = img = None
+ for res, cur_ws in zip(self.block_resolutions, block_ws):
+ block = getattr(self, f'b{res}')
+ x, img = block(x, img, cur_ws, **block_kwargs)
+ return img
+
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class Generator(torch.nn.Module):
+ def __init__(self,
+ z_dim, # Input latent (Z) dimensionality.
+ c_dim, # Conditioning label (C) dimensionality.
+ w_dim, # Intermediate latent (W) dimensionality.
+ img_resolution, # Output resolution.
+ img_channels, # Number of output color channels.
+ mapping_kwargs = {}, # Arguments for MappingNetwork.
+ synthesis_kwargs = {}, # Arguments for SynthesisNetwork.
+ ):
+ super().__init__()
+ self.z_dim = z_dim
+ self.c_dim = c_dim
+ self.w_dim = w_dim
+ self.img_resolution = img_resolution
+ self.img_channels = img_channels
+ self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs)
+ self.num_ws = self.synthesis.num_ws
+ self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)
+
+ def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, **synthesis_kwargs):
+ ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)
+ img = self.synthesis(ws, **synthesis_kwargs)
+ return img
+
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class DiscriminatorBlock(torch.nn.Module):
+ def __init__(self,
+ in_channels, # Number of input channels, 0 = first block.
+ tmp_channels, # Number of intermediate channels.
+ out_channels, # Number of output channels.
+ resolution, # Resolution of this block.
+ img_channels, # Number of input color channels.
+ first_layer_idx, # Index of the first layer.
+ architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
+ activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
+ resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
+ conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ use_fp16 = False, # Use FP16 for this block?
+ fp16_channels_last = False, # Use channels-last memory format with FP16?
+ freeze_layers = 0, # Freeze-D: Number of layers to freeze.
+ ):
+ assert in_channels in [0, tmp_channels]
+ assert architecture in ['orig', 'skip', 'resnet']
+ super().__init__()
+ self.in_channels = in_channels
+ self.resolution = resolution
+ self.img_channels = img_channels
+ self.first_layer_idx = first_layer_idx
+ self.architecture = architecture
+ self.use_fp16 = use_fp16
+ self.channels_last = (use_fp16 and fp16_channels_last)
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
+
+ self.num_layers = 0
+ def trainable_gen():
+ while True:
+ layer_idx = self.first_layer_idx + self.num_layers
+ trainable = (layer_idx >= freeze_layers)
+ self.num_layers += 1
+ yield trainable
+ trainable_iter = trainable_gen()
+
+ if in_channels == 0 or architecture == 'skip':
+ self.fromrgb = Conv2dLayer(img_channels, tmp_channels, kernel_size=1, activation=activation,
+ trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last)
+
+ self.conv0 = Conv2dLayer(tmp_channels, tmp_channels, kernel_size=3, activation=activation,
+ trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last)
+
+ self.conv1 = Conv2dLayer(tmp_channels, out_channels, kernel_size=3, activation=activation, down=2,
+ trainable=next(trainable_iter), resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last)
+
+ if architecture == 'resnet':
+ self.skip = Conv2dLayer(tmp_channels, out_channels, kernel_size=1, bias=False, down=2,
+ trainable=next(trainable_iter), resample_filter=resample_filter, channels_last=self.channels_last)
+
+ def forward(self, x, img, force_fp32=False):
+ dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
+ memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
+
+ # Input.
+ if x is not None:
+ misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution])
+ x = x.to(dtype=dtype, memory_format=memory_format)
+
+ # FromRGB.
+ if self.in_channels == 0 or self.architecture == 'skip':
+ misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
+ img = img.to(dtype=dtype, memory_format=memory_format)
+ y = self.fromrgb(img)
+ x = x + y if x is not None else y
+ img = upfirdn2d.downsample2d(img, self.resample_filter) if self.architecture == 'skip' else None
+
+ # Main layers.
+ if self.architecture == 'resnet':
+ y = self.skip(x, gain=np.sqrt(0.5))
+ x = self.conv0(x)
+ x = self.conv1(x, gain=np.sqrt(0.5))
+ x = y.add_(x)
+ else:
+ x = self.conv0(x)
+ x = self.conv1(x)
+
+ assert x.dtype == dtype
+ return x, img
+
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class MinibatchStdLayer(torch.nn.Module):
+ def __init__(self, group_size, num_channels=1):
+ super().__init__()
+ self.group_size = group_size
+ self.num_channels = num_channels
+
+ def forward(self, x):
+ N, C, H, W = x.shape
+ with misc.suppress_tracer_warnings(): # as_tensor results are registered as constants
+ G = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) if self.group_size is not None else N
+ F = self.num_channels
+ c = C // F
+
+ y = x.reshape(G, -1, F, c, H, W) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
+ y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group.
+ y = y.square().mean(dim=0) # [nFcHW] Calc variance over group.
+ y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group.
+ y = y.mean(dim=[2,3,4]) # [nF] Take average over channels and pixels.
+ y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions.
+ y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels.
+ x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels.
+ return x
+
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class DiscriminatorEpilogue(torch.nn.Module):
+ def __init__(self,
+ in_channels, # Number of input channels.
+ cmap_dim, # Dimensionality of mapped conditioning label, 0 = no label.
+ resolution, # Resolution of this block.
+ img_channels, # Number of input color channels.
+ architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
+ mbstd_group_size = 4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
+ mbstd_num_channels = 1, # Number of features for the minibatch standard deviation layer, 0 = disable.
+ activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
+ conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ ):
+ assert architecture in ['orig', 'skip', 'resnet']
+ super().__init__()
+ self.in_channels = in_channels
+ self.cmap_dim = cmap_dim
+ self.resolution = resolution
+ self.img_channels = img_channels
+ self.architecture = architecture
+
+ if architecture == 'skip':
+ self.fromrgb = Conv2dLayer(img_channels, in_channels, kernel_size=1, activation=activation)
+ self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None
+ self.conv = Conv2dLayer(in_channels + mbstd_num_channels, in_channels, kernel_size=3, activation=activation, conv_clamp=conv_clamp)
+ self.fc = FullyConnectedLayer(in_channels * (resolution ** 2), in_channels, activation=activation)
+ self.out = FullyConnectedLayer(in_channels, 1 if cmap_dim == 0 else cmap_dim)
+
+ def forward(self, x, img, cmap, force_fp32=False):
+ misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution]) # [NCHW]
+ _ = force_fp32 # unused
+ dtype = torch.float32
+ memory_format = torch.contiguous_format
+
+ # FromRGB.
+ x = x.to(dtype=dtype, memory_format=memory_format)
+ if self.architecture == 'skip':
+ misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
+ img = img.to(dtype=dtype, memory_format=memory_format)
+ x = x + self.fromrgb(img)
+
+ # Main layers.
+ if self.mbstd is not None:
+ x = self.mbstd(x)
+ x = self.conv(x)
+ x = self.fc(x.flatten(1))
+ x = self.out(x)
+
+ # Conditioning.
+ if self.cmap_dim > 0:
+ misc.assert_shape(cmap, [None, self.cmap_dim])
+ x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
+
+ assert x.dtype == dtype
+ return x
+
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class Discriminator(torch.nn.Module):
+ def __init__(self,
+ c_dim, # Conditioning label (C) dimensionality.
+ img_resolution, # Input resolution.
+ img_channels, # Number of input color channels.
+ t_dim = 1, # Diffusion timestep dimensionality
+ architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
+ channel_base = 32768, # Overall multiplier for the number of channels.
+ channel_max = 512, # Maximum number of channels in any layer.
+ num_fp16_res = 0, # Use FP16 for the N highest resolutions.
+ conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ cmap_dim = None, # Dimensionality of mapped conditioning label, None = default.
+ block_kwargs = {}, # Arguments for DiscriminatorBlock.
+ mapping_kwargs = {}, # Arguments for MappingNetwork.
+ epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue.
+ ):
+ super().__init__()
+ c_dim = c_dim + t_dim
+
+ self.c_dim = c_dim
+ self.t_dim = t_dim
+ self.img_resolution = img_resolution
+ self.img_resolution_log2 = int(np.log2(img_resolution))
+ self.img_channels = img_channels
+ self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
+ channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
+
+ if cmap_dim is None:
+ cmap_dim = channels_dict[4]
+ if c_dim == 0:
+ cmap_dim = 0
+
+ common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
+ cur_layer_idx = 0
+ for res in self.block_resolutions:
+ in_channels = channels_dict[res] if res < img_resolution else 0
+ tmp_channels = channels_dict[res]
+ out_channels = channels_dict[res // 2]
+ use_fp16 = (res >= fp16_resolution)
+ block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
+ first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
+ setattr(self, f'b{res}', block)
+ cur_layer_idx += block.num_layers
+ if c_dim > 0:
+ self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
+ self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
+
+ def forward(self, img, c, t, return_feats=False, **block_kwargs):
+ x = None
+ for res in self.block_resolutions:
+ block = getattr(self, f'b{res}')
+ x, img = block(x, img, **block_kwargs)
+
+ # Store features for InsGen
+ feats = x
+
+ cmap = None
+ if self.c_dim > 0:
+ c = torch.cat((c, t), dim=1) if c is not None else t
+ cmap = self.mapping(None, c)
+
+ out = self.b4(x, img, cmap)
+ if return_feats:
+ return out, feats
+ return out
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-insgen/training/training_loop.py b/diffusion-insgen/training/training_loop.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a119e8c61b0e45098de6e7662f7955e73576b73
--- /dev/null
+++ b/diffusion-insgen/training/training_loop.py
@@ -0,0 +1,519 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import os
+import time
+import copy
+import json
+import pickle
+import psutil
+import PIL.Image
+import numpy as np
+import torch
+import dnnlib
+from torch_utils import misc
+from torch_utils import training_stats
+from torch_utils.ops import conv2d_gradfix
+from torch_utils.ops import grid_sample_gradfix
+
+import legacy
+from metrics import metric_main
+
+#----------------------------------------------------------------------------
+
+def setup_snapshot_image_grid(training_set, random_seed=0):
+ rnd = np.random.RandomState(random_seed)
+ gw = np.clip(7680 // training_set.image_shape[2], 7, 32)
+ gh = np.clip(4320 // training_set.image_shape[1], 4, 32)
+
+ # No labels => show random subset of training samples.
+ if not training_set.has_labels:
+ all_indices = list(range(len(training_set)))
+ rnd.shuffle(all_indices)
+ grid_indices = [all_indices[i % len(all_indices)] for i in range(gw * gh)]
+
+ else:
+ # Group training samples by label.
+ label_groups = dict() # label => [idx, ...]
+ for idx in range(len(training_set)):
+ label = tuple(training_set.get_details(idx).raw_label.flat[::-1])
+ if label not in label_groups:
+ label_groups[label] = []
+ label_groups[label].append(idx)
+
+ # Reorder.
+ label_order = sorted(label_groups.keys())
+ for label in label_order:
+ rnd.shuffle(label_groups[label])
+
+ # Organize into grid.
+ grid_indices = []
+ for y in range(gh):
+ label = label_order[y % len(label_order)]
+ indices = label_groups[label]
+ grid_indices += [indices[x % len(indices)] for x in range(gw)]
+ label_groups[label] = [indices[(i + gw) % len(indices)] for i in range(len(indices))]
+
+ # Load data.
+ images, labels = zip(*[training_set[i] for i in grid_indices])
+ return (gw, gh), np.stack(images), np.stack(labels)
+
+#----------------------------------------------------------------------------
+
+def save_image_grid(img, fname, drange, grid_size):
+ lo, hi = drange
+ img = np.asarray(img, dtype=np.float32)
+ img = (img - lo) * (255 / (hi - lo))
+ img = np.rint(img).clip(0, 255).astype(np.uint8)
+
+ gw, gh = grid_size
+ _N, C, H, W = img.shape
+ img = img.reshape(gh, gw, C, H, W)
+ img = img.transpose(0, 3, 1, 4, 2)
+ img = img.reshape(gh * H, gw * W, C)
+
+ assert C in [1, 3]
+ if C == 1:
+ PIL.Image.fromarray(img[:, :, 0], 'L').save(fname)
+ if C == 3:
+ PIL.Image.fromarray(img, 'RGB').save(fname)
+
+#----------------------------------------------------------------------------
+
+def training_loop(
+ run_dir = '.', # Output directory.
+ training_set_kwargs = {}, # Options for training set.
+ data_loader_kwargs = {}, # Options for torch.utils.data.DataLoader.
+ G_kwargs = {}, # Options for generator network.
+ D_kwargs = {}, # Options for discriminator network.
+ G_opt_kwargs = {}, # Options for generator optimizer.
+ D_opt_kwargs = {}, # Options for discriminator optimizer.
+ DHead_kwargs = None, # Options for real contrastive head.
+ GHead_kwargs = None, # Options for fake contrastive head.
+ no_cl_on_g = False, # Options for fake instance discrmination for generator.
+ cl_loss_weight = {}, # Options for multiple loss weights for InsGen.
+ augment_kwargs = None, # Options for augmentation pipeline. None = disable.
+ loss_kwargs = {}, # Options for loss function.
+ metrics = [], # Metrics to evaluate during training.
+ random_seed = 0, # Global random seed.
+ num_gpus = 1, # Number of GPUs participating in the training.
+ rank = 0, # Rank of the current process in [0, num_gpus[.
+ batch_size = 4, # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus.
+ batch_gpu = 4, # Number of samples processed at a time by one GPU.
+ ema_kimg = 10, # Half-life of the exponential moving average (EMA) of generator weights.
+ ema_rampup = None, # EMA ramp-up coefficient.
+ G_reg_interval = 4, # How often to perform regularization for G? None = disable lazy regularization.
+ D_reg_interval = 16, # How often to perform regularization for D? None = disable lazy regularization.
+ augment_p = 0, # Initial value of augmentation probability.
+ ada_target = None, # ADA target value. None = fixed p.
+ ada_interval = 4, # How often to perform ADA adjustment?
+ ada_kimg = 100, # ADA adjustment speed, measured in how many kimg it takes for p to increase/decrease by one unit.
+ ada_linear = False, # Whether to linearly increase the strength of ADA.
+ total_kimg = 25000, # Total length of the training, measured in thousands of real images.
+ kimg_per_tick = 4, # Progress snapshot interval.
+ image_snapshot_ticks = 50, # How often to save image snapshots? None = disable.
+ network_snapshot_ticks = 50, # How often to save network snapshots? None = disable.
+ resume_pkl = None, # Network pickle to resume training from.
+ cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark?
+ allow_tf32 = False, # Enable torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnn.allow_tf32?
+ abort_fn = None, # Callback function for determining whether to abort training. Must return consistent results across ranks.
+ progress_fn = None, # Callback function for updating training progress. Called for all ranks.
+):
+ # Initialize.
+ start_time = time.time()
+ device = torch.device('cuda', rank)
+ np.random.seed(random_seed * num_gpus + rank)
+ torch.manual_seed(random_seed * num_gpus + rank)
+ torch.backends.cudnn.benchmark = cudnn_benchmark # Improves training speed.
+ torch.backends.cuda.matmul.allow_tf32 = allow_tf32 # Allow PyTorch to internally use tf32 for matmul
+ torch.backends.cudnn.allow_tf32 = allow_tf32 # Allow PyTorch to internally use tf32 for convolutions
+ conv2d_gradfix.enabled = True # Improves training speed.
+ grid_sample_gradfix.enabled = True # Avoids errors with the augmentation pipe.
+ __CUR_NIMG__ = torch.tensor(0, dtype=torch.long, device=device)
+ __CUR_TICK__ = torch.tensor(0, dtype=torch.long, device=device)
+ __BATCH_IDX__ = torch.tensor(0, dtype=torch.long, device=device)
+ best_fid = 9999
+
+ # Load training set.
+ if rank == 0:
+ print('Loading training set...')
+ training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs) # subclass of training.dataset.Dataset
+ training_set_sampler = misc.InfiniteSampler(dataset=training_set, rank=rank, num_replicas=num_gpus, seed=random_seed)
+ training_set_iterator = iter(torch.utils.data.DataLoader(dataset=training_set, sampler=training_set_sampler, batch_size=batch_size//num_gpus, **data_loader_kwargs))
+ if rank == 0:
+ print()
+ print('Num images: ', len(training_set))
+ print('Image shape:', training_set.image_shape)
+ print('Label shape:', training_set.label_shape)
+ print()
+
+ # Construct networks.
+ if rank == 0:
+ print('Constructing networks...')
+ common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution, img_channels=training_set.num_channels)
+ G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
+ D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
+ G_ema = copy.deepcopy(G).eval()
+
+ # Construct contrastive heads.
+ DHead = dnnlib.util.construct_class_by_name(**DHead_kwargs).train().to(device) if DHead_kwargs is not None else None
+ GHead = dnnlib.util.construct_class_by_name(**GHead_kwargs).train().to(device) if GHead_kwargs is not None else None
+ D_ema = copy.deepcopy(D).eval()
+
+ # Setup augmentation.
+ if rank == 0:
+ print('Setting up augmentation...')
+ augment_pipe = None
+ ada_stats = None
+ if (augment_kwargs is not None) and (augment_p > 0 or ada_target is not None):
+ augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs).train().requires_grad_(False).to(
+ device) # subclass of torch.nn.Module
+ augment_pipe.p = augment_p
+ if ada_target is not None:
+ ada_stats = training_stats.Collector(regex='Loss/signs/real')
+
+ # Check for existing checkpoint
+ ckpt_pkl = None
+ if os.path.isfile(misc.get_ckpt_path(run_dir)):
+ ckpt_pkl = resume_pkl = misc.get_ckpt_path(run_dir)
+
+ # Resume from existing pickle.
+ if (resume_pkl is not None) and (rank == 0):
+ print(f'Resuming from "{resume_pkl}"')
+ with dnnlib.util.open_url(resume_pkl) as f:
+ resume_data = legacy.load_network_pkl(f)
+ for name, module in [('G', G), ('D', D), ('G_ema', G_ema), ('D_ema', D_ema), ('DHead', DHead), ('GHead', GHead)]:
+ if module is None:
+ continue
+ misc.copy_params_and_buffers(resume_data[name], module, require_all=False)
+
+ __CUR_NIMG__ = resume_data['progress']['cur_nimg'].to(device)
+ __CUR_TICK__ = resume_data['progress']['cur_tick'].to(device)
+ __BATCH_IDX__ = resume_data['progress']['batch_idx'].to(device)
+ best_fid = resume_data['progress']['best_fid'] # only needed for rank == 0
+ augment_pipe.p = float(resume_data['progress']['cur_p'][0])
+
+ del resume_data
+
+ # Print network summary tables.
+ if rank == 0:
+ z = torch.empty([batch_gpu, G.z_dim], device=device)
+ c = torch.empty([batch_gpu, G.c_dim], device=device)
+ img = misc.print_module_summary(G, [z, c])
+ t = torch.empty([batch_gpu, D.t_dim], device=device)
+ misc.print_module_summary(D, [img, c, t])
+
+ # Distribute across GPUs.
+ if rank == 0:
+ print(f'Distributing across {num_gpus} GPUs...')
+ ddp_modules = dict()
+ for name, module in [('G_mapping', G.mapping), ('G_synthesis', G.synthesis), ('D', D), (None, G_ema), ('augment_pipe', augment_pipe), (None, D_ema)]:
+ if (num_gpus > 1) and (module is not None) and len(list(module.parameters())) != 0:
+ module.requires_grad_(True)
+ module = torch.nn.parallel.DistributedDataParallel(module, device_ids=[device], broadcast_buffers=False)
+ module.requires_grad_(False)
+ if name is not None:
+ ddp_modules[name] = module
+
+ # Distribute Heads across GPUs.
+ if rank == 0:
+ print(f'Distributing Contrastive Heads across {num_gpus} GPUS...')
+ if num_gpus > 1:
+ if DHead is not None:
+ DHead = torch.nn.parallel.DistributedDataParallel(DHead, device_ids=[device], broadcast_buffers=True)
+ if GHead is not None:
+ GHead = torch.nn.parallel.DistributedDataParallel(GHead, device_ids=[device], broadcast_buffers=True)
+
+ # Setup training phases.
+ if rank == 0:
+ print('Setting up training phases...')
+ loss = dnnlib.util.construct_class_by_name(device=device, **ddp_modules, **loss_kwargs) # subclass of training.loss.Loss
+ phases = []
+ for name, module, opt_kwargs, reg_interval in [('G', G, G_opt_kwargs, G_reg_interval), ('D', D, D_opt_kwargs, D_reg_interval)]:
+ if reg_interval is None:
+ opt = dnnlib.util.construct_class_by_name(params=module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer
+ phases += [dnnlib.EasyDict(name=name+'both', module=module, opt=opt, interval=1)]
+ else: # Lazy regularization.
+ mb_ratio = reg_interval / (reg_interval + 1)
+ opt_kwargs = dnnlib.EasyDict(opt_kwargs)
+ opt_kwargs.lr = opt_kwargs.lr * mb_ratio
+ opt_kwargs.betas = [beta ** mb_ratio for beta in opt_kwargs.betas]
+ opt = dnnlib.util.construct_class_by_name(module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer
+ phases += [dnnlib.EasyDict(name=name+'main', module=module, opt=opt, interval=1)]
+ phases += [dnnlib.EasyDict(name=name+'reg', module=module, opt=opt, interval=reg_interval)]
+ for phase in phases:
+ phase.start_event = None
+ phase.end_event = None
+ if rank == 0:
+ phase.start_event = torch.cuda.Event(enable_timing=True)
+ phase.end_event = torch.cuda.Event(enable_timing=True)
+
+ # Setup contrastive training phases.
+ if rank == 0:
+ print('Setting up contrastive training phases...')
+ cl_phases = dict()
+ for name, module, opt_kwargs, reg_interval in [('GHead', GHead, G_opt_kwargs, G_reg_interval), ('DHead', DHead, D_opt_kwargs, D_reg_interval)]:
+ if module is None:
+ continue
+ assert (reg_interval is not None)
+ # Lazy regularization.
+ mb_ratio = reg_interval / (reg_interval + 1)
+ opt_kwargs = dnnlib.EasyDict(opt_kwargs)
+ opt_kwargs.lr = opt_kwargs.lr * mb_ratio
+ opt_kwargs.betas = [beta ** mb_ratio for beta in opt_kwargs.betas]
+ opt = dnnlib.util.construct_class_by_name(module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer
+ cl_phases.update({name+'main': dnnlib.EasyDict(name=name+'main', module=module, opt=opt, interval=1)})
+
+ # Export sample images.
+ grid_size = None
+ grid_z = None
+ grid_c = None
+ if rank == 0:
+ print('Exporting sample images...')
+ grid_size, images, labels = setup_snapshot_image_grid(training_set=training_set)
+ save_image_grid(images, os.path.join(run_dir, 'reals.png'), drange=[0,255], grid_size=grid_size)
+ grid_z = torch.randn([labels.shape[0], G.z_dim], device=device).split(batch_gpu)
+ grid_c = torch.from_numpy(labels).to(device).split(batch_gpu)
+ images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
+ save_image_grid(images, os.path.join(run_dir, 'fakes_init.png'), drange=[-1,1], grid_size=grid_size)
+
+ # Initialize logs.
+ if rank == 0:
+ print('Initializing logs...')
+ stats_collector = training_stats.Collector(regex='.*')
+ stats_metrics = dict()
+ stats_jsonl = None
+ stats_tfevents = None
+ if rank == 0:
+ stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'wt')
+ try:
+ import torch.utils.tensorboard as tensorboard
+ stats_tfevents = tensorboard.SummaryWriter(run_dir)
+ except ImportError as err:
+ print('Skipping tfevents export:', err)
+
+ # Train.
+ if rank == 0:
+ print(f'Training for {total_kimg} kimg...')
+ print()
+ if num_gpus > 1: # broadcast loaded states to all
+ torch.distributed.broadcast(__CUR_NIMG__, 0)
+ torch.distributed.broadcast(__CUR_TICK__, 0)
+ torch.distributed.broadcast(__BATCH_IDX__, 0)
+ torch.distributed.barrier() # ensure all processes received this info
+ cur_nimg = __CUR_NIMG__.item()
+ cur_tick = __CUR_TICK__.item()
+ tick_start_nimg = cur_nimg
+ tick_start_time = time.time()
+ maintenance_time = tick_start_time - start_time
+ batch_idx = 0
+ if progress_fn is not None:
+ progress_fn(0, total_kimg)
+ while True:
+
+ # Fetch training data.
+ with torch.autograd.profiler.record_function('data_fetch'):
+ phase_real_img, phase_real_c = next(training_set_iterator)
+ phase_real_img = (phase_real_img.to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu)
+ phase_real_c = phase_real_c.to(device).split(batch_gpu)
+ all_gen_z = torch.randn([len(phases) * batch_size, G.z_dim], device=device)
+ all_gen_z = [phase_gen_z.split(batch_gpu) for phase_gen_z in all_gen_z.split(batch_size)]
+ all_gen_c = [training_set.get_label(np.random.randint(len(training_set))) for _ in range(len(phases) * batch_size)]
+ all_gen_c = torch.from_numpy(np.stack(all_gen_c)).pin_memory().to(device)
+ all_gen_c = [phase_gen_c.split(batch_gpu) for phase_gen_c in all_gen_c.split(batch_size)]
+
+ # Update D_ema
+ with torch.autograd.profiler.record_function('Dema'):
+ momentum = 0.999 if DHead_kwargs is None else DHead_kwargs.momentum
+ for p_ema, p in zip(D_ema.parameters(), D.parameters()):
+ p_ema.data = p_ema.data * momentum + p.data * (1. - momentum)
+
+ # Execute training phases.
+ for phase, phase_gen_z, phase_gen_c in zip(phases, all_gen_z, all_gen_c):
+ if batch_idx % phase.interval != 0:
+ continue
+
+ # Initialize gradient accumulation.
+ if phase.start_event is not None:
+ phase.start_event.record(torch.cuda.current_stream(device))
+ phase.opt.zero_grad(set_to_none=True)
+ phase.module.requires_grad_(True)
+
+ # Accumulate gradients over multiple rounds.
+ for round_idx, (real_img, real_c, gen_z, gen_c) in enumerate(zip(phase_real_img, phase_real_c, phase_gen_z, phase_gen_c)):
+ sync = (round_idx == batch_size // (batch_gpu * num_gpus) - 1)
+ gain = phase.interval
+ loss.accumulate_gradients(phase=phase.name, real_img=real_img, real_c=real_c, gen_z=gen_z, gen_c=gen_c, sync=sync, gain=gain, cl_phases=cl_phases, D_ema=D_ema, g_fake_cl=not no_cl_on_g, **cl_loss_weight)
+
+ # Update weights.
+ phase.module.requires_grad_(False)
+ with torch.autograd.profiler.record_function(phase.name + '_opt'):
+ for param in phase.module.parameters():
+ if param.grad is not None:
+ misc.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
+ phase.opt.step()
+ if phase.end_event is not None:
+ phase.end_event.record(torch.cuda.current_stream(device))
+
+ # Update G_ema.
+ with torch.autograd.profiler.record_function('Gema'):
+ ema_nimg = ema_kimg * 1000
+ if ema_rampup is not None:
+ ema_nimg = min(ema_nimg, cur_nimg * ema_rampup)
+ ema_beta = 0.5 ** (batch_size / max(ema_nimg, 1e-8))
+ for p_ema, p in zip(G_ema.parameters(), G.parameters()):
+ p_ema.copy_(p.lerp(p_ema, ema_beta))
+ for b_ema, b in zip(G_ema.buffers(), G.buffers()):
+ b_ema.copy_(b)
+
+ # Update state.
+ cur_nimg += batch_size
+ batch_idx += 1
+
+ # Execute ADA heuristic.
+ if (ada_stats is not None) and (batch_idx % ada_interval == 0):
+ ada_stats.update()
+ adjust = np.sign(ada_stats['Loss/signs/real'] - ada_target) * (batch_size * ada_interval) / (ada_kimg * 1000)
+ augment_pipe.p = (augment_pipe.p + adjust).clip(min=0., max=1.)
+ # augment_pipe.p = (augment_pipe.p + adjust).clip(min=0.)
+ augment_pipe.update_T()
+
+ # Perform maintenance tasks once per tick.
+ done = (cur_nimg >= total_kimg * 1000)
+ if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
+ continue
+
+ # Print status line, accumulating the same information in stats_collector.
+ tick_end_time = time.time()
+ fields = []
+ fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"]
+ fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}"]
+ fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"]
+ fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"]
+ fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"]
+ fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"]
+ fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"]
+ fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"]
+ torch.cuda.reset_peak_memory_stats()
+ fields += [f"augment {training_stats.report0('Progress/augment', float(augment_pipe.p) if augment_pipe is not None else 0):.3f}"]
+ fields += [f"T {training_stats.report0('Progress/augment_T', float(augment_pipe.num_timesteps) if augment_pipe is not None else 0)}"]
+ training_stats.report0('Timing/total_hours', (tick_end_time - start_time) / (60 * 60))
+ training_stats.report0('Timing/total_days', (tick_end_time - start_time) / (24 * 60 * 60))
+ if rank == 0:
+ print(' '.join(fields))
+
+ # Check for abort.
+ if (not done) and (abort_fn is not None) and abort_fn():
+ done = True
+ if rank == 0:
+ print()
+ print('Aborting...')
+
+ # Save image snapshot.
+ if (rank == 0) and (image_snapshot_ticks is not None) and (done or cur_tick % image_snapshot_ticks == 0):
+ images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
+ save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.png'), drange=[-1,1], grid_size=grid_size)
+
+ # Save network snapshot.
+ snapshot_pkl = None
+ snapshot_data = None
+ if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0):
+ snapshot_data = dict(training_set_kwargs=dict(training_set_kwargs))
+ for name, module in [('G', G), ('D', D), ('G_ema', G_ema), ('augment_pipe', augment_pipe), ('D_ema', D_ema), ('DHead', DHead), ('GHead', GHead)]:
+ if module is not None:
+ if num_gpus > 1:
+ misc.check_ddp_consistency(module, ignore_regex=r'.*\.w_avg')
+ module = copy.deepcopy(module).eval().requires_grad_(False).cpu()
+ snapshot_data[name] = module
+ del module # conserve memory
+
+ # Save Checkpoint if needed
+ if (rank == 0) and (network_snapshot_ticks is not None) and (
+ done or cur_tick % network_snapshot_ticks == 0):
+ snapshot_pkl = misc.get_ckpt_path(run_dir)
+ # save as tensors to avoid error for multi GPU
+ snapshot_data['progress'] = {
+ 'cur_nimg': torch.LongTensor([cur_nimg]),
+ 'cur_tick': torch.LongTensor([cur_tick]),
+ 'cur_p': torch.FloatTensor([augment_pipe.p]),
+ 'batch_idx': torch.LongTensor([batch_idx]),
+ 'best_fid': best_fid,
+ }
+ if hasattr(loss, 'pl_mean'):
+ snapshot_data['progress']['pl_mean'] = loss.pl_mean.cpu()
+
+ with open(snapshot_pkl, 'wb') as f:
+ pickle.dump(snapshot_data, f)
+
+ # Evaluate metrics.
+ if (snapshot_data is not None) and (len(metrics) > 0):
+ if rank == 0:
+ print('Evaluating metrics...')
+ for metric in metrics:
+ result_dict = metric_main.calc_metric(metric=metric, G=snapshot_data['G_ema'],
+ dataset_kwargs=training_set_kwargs, num_gpus=num_gpus,
+ rank=rank, device=device)
+ if rank == 0:
+ metric_main.report_metric(result_dict, run_dir=run_dir, snapshot_pkl=snapshot_pkl)
+ stats_metrics.update(result_dict.results)
+
+ # save best fid ckpt
+ snapshot_pkl = os.path.join(run_dir, f'best_model.pkl')
+ cur_nimg_txt = os.path.join(run_dir, f'best_nimg.txt')
+ if rank == 0:
+ if 'fid50k_full' in stats_metrics and stats_metrics['fid50k_full'] < best_fid:
+ best_fid = stats_metrics['fid50k_full']
+
+ with open(snapshot_pkl, 'wb') as f:
+ pickle.dump(snapshot_data, f)
+ # save curr iteration number (directly saving it to pkl leads to problems with multi GPU)
+ with open(cur_nimg_txt, 'w') as f:
+ f.write(f"nimg: {cur_nimg} best_fid: {best_fid}")
+ del snapshot_data # conserve memory
+
+ # Collect statistics.
+ for phase in phases:
+ value = []
+ if (phase.start_event is not None) and (phase.end_event is not None):
+ phase.end_event.synchronize()
+ value = phase.start_event.elapsed_time(phase.end_event)
+ training_stats.report0('Timing/' + phase.name, value)
+ stats_collector.update()
+ stats_dict = stats_collector.as_dict()
+
+ # Update logs.
+ timestamp = time.time()
+ if stats_jsonl is not None:
+ fields = dict(stats_dict, timestamp=timestamp)
+ stats_jsonl.write(json.dumps(fields) + '\n')
+ stats_jsonl.flush()
+ if stats_tfevents is not None:
+ global_step = int(cur_nimg / 1e3)
+ walltime = timestamp - start_time
+ for name, value in stats_dict.items():
+ stats_tfevents.add_scalar(name, value.mean, global_step=global_step, walltime=walltime)
+ for name, value in stats_metrics.items():
+ stats_tfevents.add_scalar(f'Metrics/{name}', value, global_step=global_step, walltime=walltime)
+ stats_tfevents.flush()
+ if progress_fn is not None:
+ progress_fn(cur_nimg // 1000, total_kimg)
+
+ # Update state.
+ cur_tick += 1
+ tick_start_nimg = cur_nimg
+ tick_start_time = time.time()
+ maintenance_time = tick_start_time - tick_end_time
+ if done:
+ break
+
+ # Done.
+ if rank == 0:
+ print()
+ print('Exiting...')
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-projected-gan/calc_metrics.py b/diffusion-projected-gan/calc_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8b832631a25f2e33049eb3661e999bb06532b41
--- /dev/null
+++ b/diffusion-projected-gan/calc_metrics.py
@@ -0,0 +1,188 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Calculate quality metrics for previous training run or pretrained network pickle."""
+
+import os
+import click
+import json
+import tempfile
+import copy
+import torch
+
+import dnnlib
+import legacy
+from metrics import metric_main
+from metrics import metric_utils
+from torch_utils import training_stats
+from torch_utils import custom_ops
+from torch_utils import misc
+from torch_utils.ops import conv2d_gradfix
+
+#----------------------------------------------------------------------------
+
+def subprocess_fn(rank, args, temp_dir):
+ dnnlib.util.Logger(should_flush=True)
+
+ # Init torch.distributed.
+ if args.num_gpus > 1:
+ init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
+ if os.name == 'nt':
+ init_method = 'file:///' + init_file.replace('\\', '/')
+ torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus)
+ else:
+ init_method = f'file://{init_file}'
+ torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus)
+
+ # Init torch_utils.
+ sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None
+ training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
+ if rank != 0 or not args.verbose:
+ custom_ops.verbosity = 'none'
+
+ # Configure torch.
+ device = torch.device('cuda', rank)
+ torch.backends.cuda.matmul.allow_tf32 = False
+ torch.backends.cudnn.allow_tf32 = False
+ conv2d_gradfix.enabled = True
+
+ # Print network summary.
+ G = copy.deepcopy(args.G).eval().requires_grad_(False).to(device)
+ if rank == 0 and args.verbose:
+ z = torch.empty([1, G.z_dim], device=device)
+ c = torch.empty([1, G.c_dim], device=device)
+ misc.print_module_summary(G, [z, c])
+
+ # Calculate each metric.
+ for metric in args.metrics:
+ if rank == 0 and args.verbose:
+ print(f'Calculating {metric}...')
+ progress = metric_utils.ProgressMonitor(verbose=args.verbose)
+ result_dict = metric_main.calc_metric(metric=metric, G=G, dataset_kwargs=args.dataset_kwargs,
+ num_gpus=args.num_gpus, rank=rank, device=device, progress=progress, snapshot_pkl=args.network_pkl)
+ if rank == 0:
+ metric_main.report_metric(result_dict, run_dir=args.run_dir, snapshot_pkl=args.network_pkl)
+ if rank == 0 and args.verbose:
+ print()
+
+ # Done.
+ if rank == 0 and args.verbose:
+ print('Exiting...')
+
+#----------------------------------------------------------------------------
+
+def parse_comma_separated_list(s):
+ if isinstance(s, list):
+ return s
+ if s is None or s.lower() == 'none' or s == '':
+ return []
+ return s.split(',')
+
+#----------------------------------------------------------------------------
+
+@click.command()
+@click.pass_context
+@click.option('network_pkl', '--network', help='Network pickle filename or URL', metavar='PATH', required=True)
+@click.option('--metrics', help='Quality metrics', metavar='[NAME|A,B,C|none]', type=parse_comma_separated_list, default='fid50k_full', show_default=True)
+@click.option('--data', help='Dataset to evaluate against [default: look up]', metavar='[ZIP|DIR]')
+@click.option('--mirror', help='Enable dataset x-flips [default: look up]', type=bool, metavar='BOOL')
+@click.option('--gpus', help='Number of GPUs to use', type=int, default=1, metavar='INT', show_default=True)
+@click.option('--verbose', help='Print optional information', type=bool, default=True, metavar='BOOL', show_default=True)
+
+def calc_metrics(ctx, network_pkl, metrics, data, mirror, gpus, verbose):
+ """Calculate quality metrics for previous training run or pretrained network pickle.
+
+ Examples:
+
+ \b
+ # Previous training run: look up options automatically, save result to JSONL file.
+ python calc_metrics.py --metrics=eqt50k_int,eqr50k \\
+ --network=~/training-runs/00000-stylegan3-r-mydataset/network-snapshot-000000.pkl
+
+ \b
+ # Pre-trained network pickle: specify dataset explicitly, print result to stdout.
+ python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq-1024x1024.zip --mirror=1 \\
+ --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhq-1024x1024.pkl
+
+ \b
+ Recommended metrics:
+ fid50k_full Frechet inception distance against the full dataset.
+ kid50k_full Kernel inception distance against the full dataset.
+ pr50k3_full Precision and recall againt the full dataset.
+ ppl2_wend Perceptual path length in W, endpoints, full image.
+ eqt50k_int Equivariance w.r.t. integer translation (EQ-T).
+ eqt50k_frac Equivariance w.r.t. fractional translation (EQ-T_frac).
+ eqr50k Equivariance w.r.t. rotation (EQ-R).
+
+ \b
+ Legacy metrics:
+ fid50k Frechet inception distance against 50k real images.
+ kid50k Kernel inception distance against 50k real images.
+ pr50k3 Precision and recall against 50k real images.
+ is50k Inception score for CIFAR-10.
+ """
+ dnnlib.util.Logger(should_flush=True)
+
+ # Validate arguments.
+ args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, network_pkl=network_pkl, verbose=verbose)
+ if not all(metric_main.is_valid_metric(metric) for metric in args.metrics):
+ ctx.fail('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
+ if not args.num_gpus >= 1:
+ ctx.fail('--gpus must be at least 1')
+
+ # Load network.
+ if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl):
+ ctx.fail('--network must point to a file or URL')
+ if args.verbose:
+ print(f'Loading network from "{network_pkl}"...')
+ with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as f:
+ network_dict = legacy.load_network_pkl(f)
+ args.G = network_dict['G_ema'] # subclass of torch.nn.Module
+
+ # Initialize dataset options.
+ if data is not None:
+ args.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data)
+ elif network_dict['training_set_kwargs'] is not None:
+ args.dataset_kwargs = dnnlib.EasyDict(network_dict['training_set_kwargs'])
+ else:
+ ctx.fail('Could not look up dataset options; please specify --data')
+
+ # Finalize dataset options.
+ args.dataset_kwargs.resolution = args.G.img_resolution
+ args.dataset_kwargs.use_labels = (args.G.c_dim != 0)
+ if mirror is not None:
+ args.dataset_kwargs.xflip = mirror
+
+ # Print dataset options.
+ if args.verbose:
+ print('Dataset options:')
+ print(json.dumps(args.dataset_kwargs, indent=2))
+
+ # Locate run dir.
+ args.run_dir = None
+ if os.path.isfile(network_pkl):
+ pkl_dir = os.path.dirname(network_pkl)
+ if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')):
+ args.run_dir = pkl_dir
+
+ # Launch processes.
+ if args.verbose:
+ print('Launching processes...')
+ torch.multiprocessing.set_start_method('spawn')
+ with tempfile.TemporaryDirectory() as temp_dir:
+ if args.num_gpus == 1:
+ subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
+ else:
+ torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)
+
+#----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+ calc_metrics() # pylint: disable=no-value-for-parameter
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-projected-gan/dataset_tool.py b/diffusion-projected-gan/dataset_tool.py
new file mode 100644
index 0000000000000000000000000000000000000000..8103795dbcbfb15b5e570825e750887ce87184a3
--- /dev/null
+++ b/diffusion-projected-gan/dataset_tool.py
@@ -0,0 +1,463 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Tool for creating ZIP/PNG based datasets."""
+
+import functools
+import gzip
+import io
+import json
+import os
+import pickle
+import re
+import sys
+import tarfile
+import zipfile
+from pathlib import Path
+from typing import Callable, Optional, Tuple, Union
+import imageio
+
+import click
+import numpy as np
+import PIL.Image
+from tqdm import tqdm
+
+#----------------------------------------------------------------------------
+
+def error(msg):
+ print('Error: ' + msg)
+ sys.exit(1)
+
+#----------------------------------------------------------------------------
+
+def parse_tuple(s: str) -> Tuple[int, int]:
+ '''Parse a 'M,N' or 'MxN' integer tuple.
+
+ Example:
+ '4x2' returns (4,2)
+ '0,1' returns (0,1)
+ '''
+ m = re.match(r'^(\d+)[x,](\d+)$', s)
+ if m:
+ return (int(m.group(1)), int(m.group(2)))
+ raise ValueError(f'cannot parse tuple {s}')
+
+#----------------------------------------------------------------------------
+
+def maybe_min(a: int, b: Optional[int]) -> int:
+ if b is not None:
+ return min(a, b)
+ return a
+
+#----------------------------------------------------------------------------
+
+def file_ext(name: Union[str, Path]) -> str:
+ return str(name).split('.')[-1]
+
+#----------------------------------------------------------------------------
+
+def is_image_ext(fname: Union[str, Path]) -> bool:
+ ext = file_ext(fname).lower()
+ return f'.{ext}' in PIL.Image.EXTENSION # type: ignore
+
+#----------------------------------------------------------------------------
+
+def open_image_folder(source_dir, *, max_images: Optional[int]):
+ input_images = [str(f) for f in sorted(Path(source_dir).rglob('*')) if is_image_ext(f) and os.path.isfile(f)]
+
+ # Load labels.
+ labels = {}
+ meta_fname = os.path.join(source_dir, 'dataset.json')
+ if os.path.isfile(meta_fname):
+ with open(meta_fname, 'r') as file:
+ labels = json.load(file)['labels']
+ if labels is not None:
+ labels = { x[0]: x[1] for x in labels }
+ else:
+ labels = {}
+
+ max_idx = maybe_min(len(input_images), max_images)
+
+ def iterate_images():
+ for idx, fname in enumerate(input_images):
+ arch_fname = os.path.relpath(fname, source_dir)
+ arch_fname = arch_fname.replace('\\', '/')
+ img = imageio.imread(fname)
+
+ yield dict(img=img, label=labels.get(arch_fname))
+ if idx >= max_idx-1:
+ break
+ return max_idx, iterate_images()
+
+#----------------------------------------------------------------------------
+
+def open_image_zip(source, *, max_images: Optional[int]):
+ with zipfile.ZipFile(source, mode='r') as z:
+ input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)]
+
+ # Load labels.
+ labels = {}
+ if 'dataset.json' in z.namelist():
+ with z.open('dataset.json', 'r') as file:
+ labels = json.load(file)['labels']
+ if labels is not None:
+ labels = { x[0]: x[1] for x in labels }
+ else:
+ labels = {}
+
+ max_idx = maybe_min(len(input_images), max_images)
+
+ def iterate_images():
+ with zipfile.ZipFile(source, mode='r') as z:
+ for idx, fname in enumerate(input_images):
+ with z.open(fname, 'r') as file:
+ img = PIL.Image.open(file) # type: ignore
+ img = np.array(img)
+ yield dict(img=img, label=labels.get(fname))
+ if idx >= max_idx-1:
+ break
+ return max_idx, iterate_images()
+
+#----------------------------------------------------------------------------
+
+def open_lmdb(lmdb_dir: str, *, max_images: Optional[int]):
+ import cv2 # pip install opencv-python # pylint: disable=import-error
+ import lmdb # pip install lmdb # pylint: disable=import-error
+
+ with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
+ max_idx = maybe_min(txn.stat()['entries'], max_images)
+
+ def iterate_images():
+ with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
+ for idx, (_key, value) in enumerate(txn.cursor()):
+ try:
+ try:
+ img = cv2.imdecode(np.frombuffer(value, dtype=np.uint8), 1)
+ if img is None:
+ raise IOError('cv2.imdecode failed')
+ img = img[:, :, ::-1] # BGR => RGB
+ except IOError:
+ img = np.array(PIL.Image.open(io.BytesIO(value)))
+ yield dict(img=img, label=None)
+ if idx >= max_idx-1:
+ break
+ except:
+ print(sys.exc_info()[1])
+
+ return max_idx, iterate_images()
+
+#----------------------------------------------------------------------------
+
+def open_cifar10(tarball: str, *, max_images: Optional[int]):
+ images = []
+ labels = []
+
+ with tarfile.open(tarball, 'r:gz') as tar:
+ for batch in range(1, 6):
+ member = tar.getmember(f'cifar-10-batches-py/data_batch_{batch}')
+ with tar.extractfile(member) as file:
+ data = pickle.load(file, encoding='latin1')
+ images.append(data['data'].reshape(-1, 3, 32, 32))
+ labels.append(data['labels'])
+
+ images = np.concatenate(images)
+ labels = np.concatenate(labels)
+ images = images.transpose([0, 2, 3, 1]) # NCHW -> NHWC
+ assert images.shape == (50000, 32, 32, 3) and images.dtype == np.uint8
+ assert labels.shape == (50000,) and labels.dtype in [np.int32, np.int64]
+ assert np.min(images) == 0 and np.max(images) == 255
+ assert np.min(labels) == 0 and np.max(labels) == 9
+
+ max_idx = maybe_min(len(images), max_images)
+
+ def iterate_images():
+ for idx, img in enumerate(images):
+ yield dict(img=img, label=int(labels[idx]))
+ if idx >= max_idx-1:
+ break
+
+ return max_idx, iterate_images()
+
+#----------------------------------------------------------------------------
+
+def open_mnist(images_gz: str, *, max_images: Optional[int]):
+ labels_gz = images_gz.replace('-images-idx3-ubyte.gz', '-labels-idx1-ubyte.gz')
+ assert labels_gz != images_gz
+ images = []
+ labels = []
+
+ with gzip.open(images_gz, 'rb') as f:
+ images = np.frombuffer(f.read(), np.uint8, offset=16)
+ with gzip.open(labels_gz, 'rb') as f:
+ labels = np.frombuffer(f.read(), np.uint8, offset=8)
+
+ images = images.reshape(-1, 28, 28)
+ images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0)
+ assert images.shape == (60000, 32, 32) and images.dtype == np.uint8
+ assert labels.shape == (60000,) and labels.dtype == np.uint8
+ assert np.min(images) == 0 and np.max(images) == 255
+ assert np.min(labels) == 0 and np.max(labels) == 9
+
+ max_idx = maybe_min(len(images), max_images)
+
+ def iterate_images():
+ for idx, img in enumerate(images):
+ yield dict(img=img, label=int(labels[idx]))
+ if idx >= max_idx-1:
+ break
+
+ return max_idx, iterate_images()
+
+#----------------------------------------------------------------------------
+
+def make_transform(
+ transform: Optional[str],
+ output_width: Optional[int],
+ output_height: Optional[int]
+) -> Callable[[np.ndarray], Optional[np.ndarray]]:
+ def scale(width, height, img):
+ w = img.shape[1]
+ h = img.shape[0]
+ if width == w and height == h:
+ return img
+ img = PIL.Image.fromarray(img)
+ ww = width if width is not None else w
+ hh = height if height is not None else h
+ img = img.resize((ww, hh), PIL.Image.LANCZOS)
+ return np.array(img)
+
+ def center_crop(width, height, img):
+ crop = np.min(img.shape[:2])
+ img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]
+
+ if len(img.shape) == 2:
+ img = PIL.Image.fromarray(img, 'L').convert('RGB')
+ else:
+ img = PIL.Image.fromarray(img, 'RGB')
+
+ img = img.resize((width, height), PIL.Image.LANCZOS)
+ return np.array(img)
+
+ def center_crop_wide(width, height, img):
+ ch = int(np.round(width * img.shape[0] / img.shape[1]))
+ if img.shape[1] < width or ch < height:
+ return None
+
+ img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2]
+ img = PIL.Image.fromarray(img, 'RGB')
+ img = img.resize((width, height), PIL.Image.LANCZOS)
+ img = np.array(img)
+
+ canvas = np.zeros([width, width, 3], dtype=np.uint8)
+ canvas[(width - height) // 2 : (width + height) // 2, :] = img
+ return canvas
+
+ if transform is None:
+ return functools.partial(scale, output_width, output_height)
+ if transform == 'center-crop':
+ if (output_width is None) or (output_height is None):
+ error ('must specify --resolution=WxH when using ' + transform + 'transform')
+ return functools.partial(center_crop, output_width, output_height)
+ if transform == 'center-crop-wide':
+ if (output_width is None) or (output_height is None):
+ error ('must specify --resolution=WxH when using ' + transform + ' transform')
+ return functools.partial(center_crop_wide, output_width, output_height)
+ assert False, 'unknown transform'
+
+#----------------------------------------------------------------------------
+
+def open_dataset(source, *, max_images: Optional[int]):
+ if os.path.isdir(source):
+ if source.rstrip('/').endswith('_lmdb'):
+ return open_lmdb(source, max_images=max_images)
+ else:
+ return open_image_folder(source, max_images=max_images)
+ elif os.path.isfile(source):
+ if os.path.basename(source) == 'cifar-10-python.tar.gz':
+ return open_cifar10(source, max_images=max_images)
+ elif os.path.basename(source) == 'train-images-idx3-ubyte.gz':
+ return open_mnist(source, max_images=max_images)
+ elif file_ext(source) == 'zip':
+ return open_image_zip(source, max_images=max_images)
+ else:
+ assert False, 'unknown archive type'
+ else:
+ error(f'Missing input file or directory: {source}')
+
+#----------------------------------------------------------------------------
+
+def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]:
+ dest_ext = file_ext(dest)
+
+ if dest_ext == 'zip':
+ if os.path.dirname(dest) != '':
+ os.makedirs(os.path.dirname(dest), exist_ok=True)
+ zf = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED)
+ def zip_write_bytes(fname: str, data: Union[bytes, str]):
+ zf.writestr(fname, data)
+ return '', zip_write_bytes, zf.close
+ else:
+ # If the output folder already exists, check that is is
+ # empty.
+ #
+ # Note: creating the output directory is not strictly
+ # necessary as folder_write_bytes() also mkdirs, but it's better
+ # to give an error message earlier in case the dest folder
+ # somehow cannot be created.
+ if os.path.isdir(dest) and len(os.listdir(dest)) != 0:
+ error('--dest folder must be empty')
+ os.makedirs(dest, exist_ok=True)
+
+ def folder_write_bytes(fname: str, data: Union[bytes, str]):
+ os.makedirs(os.path.dirname(fname), exist_ok=True)
+ with open(fname, 'wb') as fout:
+ if isinstance(data, str):
+ data = data.encode('utf8')
+ fout.write(data)
+ return dest, folder_write_bytes, lambda: None
+
+#----------------------------------------------------------------------------
+
+@click.command()
+@click.pass_context
+@click.option('--source', help='Directory or archive name for input dataset', required=True, metavar='PATH')
+@click.option('--dest', help='Output directory or archive name for output dataset', required=True, metavar='PATH')
+@click.option('--max-images', help='Output only up to `max-images` images', type=int, default=None)
+@click.option('--transform', help='Input crop/resize mode', type=click.Choice(['center-crop', 'center-crop-wide']))
+@click.option('--resolution', help='Output resolution (e.g., \'512x512\')', metavar='WxH', type=parse_tuple)
+def convert_dataset(
+ ctx: click.Context,
+ source: str,
+ dest: str,
+ max_images: Optional[int],
+ transform: Optional[str],
+ resolution: Optional[Tuple[int, int]]
+):
+ """Convert an image dataset into a dataset archive usable with StyleGAN2 ADA PyTorch.
+
+ The input dataset format is guessed from the --source argument:
+
+ \b
+ --source *_lmdb/ Load LSUN dataset
+ --source cifar-10-python.tar.gz Load CIFAR-10 dataset
+ --source train-images-idx3-ubyte.gz Load MNIST dataset
+ --source path/ Recursively load all images from path/
+ --source dataset.zip Recursively load all images from dataset.zip
+
+ Specifying the output format and path:
+
+ \b
+ --dest /path/to/dir Save output files under /path/to/dir
+ --dest /path/to/dataset.zip Save output files into /path/to/dataset.zip
+
+ The output dataset format can be either an image folder or an uncompressed zip archive.
+ Zip archives makes it easier to move datasets around file servers and clusters, and may
+ offer better training performance on network file systems.
+
+ Images within the dataset archive will be stored as uncompressed PNG.
+ Uncompresed PNGs can be efficiently decoded in the training loop.
+
+ Class labels are stored in a file called 'dataset.json' that is stored at the
+ dataset root folder. This file has the following structure:
+
+ \b
+ {
+ "labels": [
+ ["00000/img00000000.png",6],
+ ["00000/img00000001.png",9],
+ ... repeated for every image in the datase
+ ["00049/img00049999.png",1]
+ ]
+ }
+
+ If the 'dataset.json' file cannot be found, the dataset is interpreted as
+ not containing class labels.
+
+ Image scale/crop and resolution requirements:
+
+ Output images must be square-shaped and they must all have the same power-of-two
+ dimensions.
+
+ To scale arbitrary input image size to a specific width and height, use the
+ --resolution option. Output resolution will be either the original
+ input resolution (if resolution was not specified) or the one specified with
+ --resolution option.
+
+ Use the --transform=center-crop or --transform=center-crop-wide options to apply a
+ center crop transform on the input image. These options should be used with the
+ --resolution option. For example:
+
+ \b
+ python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \\
+ --transform=center-crop-wide --resolution=512x384
+ """
+
+ PIL.Image.init() # type: ignore
+
+ if dest == '':
+ ctx.fail('--dest output filename or directory must not be an empty string')
+
+ num_files, input_iter = open_dataset(source, max_images=max_images)
+ archive_root_dir, save_bytes, close_dest = open_dest(dest)
+
+ if resolution is None: resolution = (None, None)
+ transform_image = make_transform(transform, *resolution)
+
+ dataset_attrs = None
+
+ labels = []
+ for idx, image in tqdm(enumerate(input_iter), total=num_files):
+ idx_str = f'{idx:08d}'
+ archive_fname = f'{idx_str[:5]}/img{idx_str}.png'
+
+ # Apply crop and resize.
+ img = transform_image(image['img'])
+
+ # Transform may drop images.
+ if img is None:
+ continue
+
+ # Error check to require uniform image attributes across
+ # the whole dataset.
+ channels = img.shape[2] if img.ndim == 3 else 1
+ cur_image_attrs = {
+ 'width': img.shape[1],
+ 'height': img.shape[0],
+ 'channels': channels
+ }
+ if dataset_attrs is None:
+ dataset_attrs = cur_image_attrs
+ width = dataset_attrs['width']
+ height = dataset_attrs['height']
+ if width != height:
+ error(f'Image dimensions after scale and crop are required to be square. Got {width}x{height}')
+ if dataset_attrs['channels'] not in [1, 3]:
+ error('Input images must be stored as RGB or grayscale')
+ if width != 2 ** int(np.floor(np.log2(width))):
+ error('Image width/height after scale and crop are required to be power-of-two')
+ elif dataset_attrs != cur_image_attrs:
+ err = [f' dataset {k}/cur image {k}: {dataset_attrs[k]}/{cur_image_attrs[k]}' for k in dataset_attrs.keys()] # pylint: disable=unsubscriptable-object
+ error(f'Image {archive_fname} attributes must be equal across all images of the dataset. Got:\n' + '\n'.join(err))
+
+ # Save the image as an uncompressed PNG.
+ img = PIL.Image.fromarray(img, { 1: 'L', 3: 'RGB' }[channels])
+ image_bits = io.BytesIO()
+ img.save(image_bits, format='png', compress_level=0, optimize=False)
+ save_bytes(os.path.join(archive_root_dir, archive_fname), image_bits.getbuffer())
+ labels.append([archive_fname, image['label']] if image['label'] is not None else None)
+
+ metadata = {
+ 'labels': labels if all(x is not None for x in labels) else None
+ }
+ save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata))
+ close_dest()
+
+#----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+ convert_dataset() # pylint: disable=no-value-for-parameter
diff --git a/diffusion-projected-gan/dnnlib/__init__.py b/diffusion-projected-gan/dnnlib/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7423bffe245d0ff3f32e8658aa67daae454e64e
--- /dev/null
+++ b/diffusion-projected-gan/dnnlib/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+from .util import EasyDict, make_cache_dir_path
diff --git a/diffusion-projected-gan/dnnlib/util.py b/diffusion-projected-gan/dnnlib/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bbdf3bd8fe1c138cd969d37dcc52190b45c4c16
--- /dev/null
+++ b/diffusion-projected-gan/dnnlib/util.py
@@ -0,0 +1,491 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Miscellaneous utility classes and functions."""
+
+import ctypes
+import fnmatch
+import importlib
+import inspect
+import numpy as np
+import os
+import shutil
+import sys
+import types
+import io
+import pickle
+import re
+import requests
+import html
+import hashlib
+import glob
+import tempfile
+import urllib
+import urllib.request
+import uuid
+
+from distutils.util import strtobool
+from typing import Any, List, Tuple, Union
+
+
+# Util classes
+# ------------------------------------------------------------------------------------------
+
+
+class EasyDict(dict):
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
+
+ def __getattr__(self, name: str) -> Any:
+ try:
+ return self[name]
+ except KeyError:
+ raise AttributeError(name)
+
+ def __setattr__(self, name: str, value: Any) -> None:
+ self[name] = value
+
+ def __delattr__(self, name: str) -> None:
+ del self[name]
+
+
+class Logger(object):
+ """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
+
+ def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
+ self.file = None
+
+ if file_name is not None:
+ self.file = open(file_name, file_mode)
+
+ self.should_flush = should_flush
+ self.stdout = sys.stdout
+ self.stderr = sys.stderr
+
+ sys.stdout = self
+ sys.stderr = self
+
+ def __enter__(self) -> "Logger":
+ return self
+
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+ self.close()
+
+ def write(self, text: Union[str, bytes]) -> None:
+ """Write text to stdout (and a file) and optionally flush."""
+ if isinstance(text, bytes):
+ text = text.decode()
+ if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
+ return
+
+ if self.file is not None:
+ self.file.write(text)
+
+ self.stdout.write(text)
+
+ if self.should_flush:
+ self.flush()
+
+ def flush(self) -> None:
+ """Flush written text to both stdout and a file, if open."""
+ if self.file is not None:
+ self.file.flush()
+
+ self.stdout.flush()
+
+ def close(self) -> None:
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
+ self.flush()
+
+ # if using multiple loggers, prevent closing in wrong order
+ if sys.stdout is self:
+ sys.stdout = self.stdout
+ if sys.stderr is self:
+ sys.stderr = self.stderr
+
+ if self.file is not None:
+ self.file.close()
+ self.file = None
+
+
+# Cache directories
+# ------------------------------------------------------------------------------------------
+
+_dnnlib_cache_dir = None
+
+def set_cache_dir(path: str) -> None:
+ global _dnnlib_cache_dir
+ _dnnlib_cache_dir = path
+
+def make_cache_dir_path(*paths: str) -> str:
+ if _dnnlib_cache_dir is not None:
+ return os.path.join(_dnnlib_cache_dir, *paths)
+ if 'DNNLIB_CACHE_DIR' in os.environ:
+ return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
+ if 'HOME' in os.environ:
+ return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
+ if 'USERPROFILE' in os.environ:
+ return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
+ return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
+
+# Small util functions
+# ------------------------------------------------------------------------------------------
+
+
+def format_time(seconds: Union[int, float]) -> str:
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
+ s = int(np.rint(seconds))
+
+ if s < 60:
+ return "{0}s".format(s)
+ elif s < 60 * 60:
+ return "{0}m {1:02}s".format(s // 60, s % 60)
+ elif s < 24 * 60 * 60:
+ return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
+ else:
+ return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
+
+
+def format_time_brief(seconds: Union[int, float]) -> str:
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
+ s = int(np.rint(seconds))
+
+ if s < 60:
+ return "{0}s".format(s)
+ elif s < 60 * 60:
+ return "{0}m {1:02}s".format(s // 60, s % 60)
+ elif s < 24 * 60 * 60:
+ return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
+ else:
+ return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
+
+
+def ask_yes_no(question: str) -> bool:
+ """Ask the user the question until the user inputs a valid answer."""
+ while True:
+ try:
+ print("{0} [y/n]".format(question))
+ return strtobool(input().lower())
+ except ValueError:
+ pass
+
+
+def tuple_product(t: Tuple) -> Any:
+ """Calculate the product of the tuple elements."""
+ result = 1
+
+ for v in t:
+ result *= v
+
+ return result
+
+
+_str_to_ctype = {
+ "uint8": ctypes.c_ubyte,
+ "uint16": ctypes.c_uint16,
+ "uint32": ctypes.c_uint32,
+ "uint64": ctypes.c_uint64,
+ "int8": ctypes.c_byte,
+ "int16": ctypes.c_int16,
+ "int32": ctypes.c_int32,
+ "int64": ctypes.c_int64,
+ "float32": ctypes.c_float,
+ "float64": ctypes.c_double
+}
+
+
+def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
+ """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
+ type_str = None
+
+ if isinstance(type_obj, str):
+ type_str = type_obj
+ elif hasattr(type_obj, "__name__"):
+ type_str = type_obj.__name__
+ elif hasattr(type_obj, "name"):
+ type_str = type_obj.name
+ else:
+ raise RuntimeError("Cannot infer type name from input")
+
+ assert type_str in _str_to_ctype.keys()
+
+ my_dtype = np.dtype(type_str)
+ my_ctype = _str_to_ctype[type_str]
+
+ assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
+
+ return my_dtype, my_ctype
+
+
+def is_pickleable(obj: Any) -> bool:
+ try:
+ with io.BytesIO() as stream:
+ pickle.dump(obj, stream)
+ return True
+ except:
+ return False
+
+
+# Functionality to import modules/objects by name, and call functions by name
+# ------------------------------------------------------------------------------------------
+
+def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
+ """Searches for the underlying module behind the name to some python object.
+ Returns the module and the object name (original name with module part removed)."""
+
+ # allow convenience shorthands, substitute them by full names
+ obj_name = re.sub("^np.", "numpy.", obj_name)
+ obj_name = re.sub("^tf.", "tensorflow.", obj_name)
+
+ # list alternatives for (module_name, local_obj_name)
+ parts = obj_name.split(".")
+ name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
+
+ # try each alternative in turn
+ for module_name, local_obj_name in name_pairs:
+ try:
+ module = importlib.import_module(module_name) # may raise ImportError
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
+ return module, local_obj_name
+ except:
+ pass
+
+ # maybe some of the modules themselves contain errors?
+ for module_name, _local_obj_name in name_pairs:
+ try:
+ importlib.import_module(module_name) # may raise ImportError
+ except ImportError:
+ if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
+ raise
+
+ # maybe the requested attribute is missing?
+ for module_name, local_obj_name in name_pairs:
+ try:
+ module = importlib.import_module(module_name) # may raise ImportError
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
+ except ImportError:
+ pass
+
+ # we are out of luck, but we have no idea why
+ raise ImportError(obj_name)
+
+
+def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
+ """Traverses the object name and returns the last (rightmost) python object."""
+ if obj_name == '':
+ return module
+ obj = module
+ for part in obj_name.split("."):
+ obj = getattr(obj, part)
+ return obj
+
+
+def get_obj_by_name(name: str) -> Any:
+ """Finds the python object with the given name."""
+ module, obj_name = get_module_from_obj_name(name)
+ return get_obj_from_module(module, obj_name)
+
+
+def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
+ """Finds the python object with the given name and calls it as a function."""
+ assert func_name is not None
+ func_obj = get_obj_by_name(func_name)
+ assert callable(func_obj)
+ return func_obj(*args, **kwargs)
+
+
+def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
+ """Finds the python class with the given name and constructs it with the given arguments."""
+ return call_func_by_name(*args, func_name=class_name, **kwargs)
+
+
+def get_module_dir_by_obj_name(obj_name: str) -> str:
+ """Get the directory path of the module containing the given object name."""
+ module, _ = get_module_from_obj_name(obj_name)
+ return os.path.dirname(inspect.getfile(module))
+
+
+def is_top_level_function(obj: Any) -> bool:
+ """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
+ return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
+
+
+def get_top_level_function_name(obj: Any) -> str:
+ """Return the fully-qualified name of a top-level function."""
+ assert is_top_level_function(obj)
+ module = obj.__module__
+ if module == '__main__':
+ module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
+ return module + "." + obj.__name__
+
+
+# File system helpers
+# ------------------------------------------------------------------------------------------
+
+def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
+ """List all files recursively in a given directory while ignoring given file and directory names.
+ Returns list of tuples containing both absolute and relative paths."""
+ assert os.path.isdir(dir_path)
+ base_name = os.path.basename(os.path.normpath(dir_path))
+
+ if ignores is None:
+ ignores = []
+
+ result = []
+
+ for root, dirs, files in os.walk(dir_path, topdown=True):
+ for ignore_ in ignores:
+ dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
+
+ # dirs need to be edited in-place
+ for d in dirs_to_remove:
+ dirs.remove(d)
+
+ files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
+
+ absolute_paths = [os.path.join(root, f) for f in files]
+ relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
+
+ if add_base_to_relative:
+ relative_paths = [os.path.join(base_name, p) for p in relative_paths]
+
+ assert len(absolute_paths) == len(relative_paths)
+ result += zip(absolute_paths, relative_paths)
+
+ return result
+
+
+def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
+ """Takes in a list of tuples of (src, dst) paths and copies files.
+ Will create all necessary directories."""
+ for file in files:
+ target_dir_name = os.path.dirname(file[1])
+
+ # will create all intermediate-level directories
+ if not os.path.exists(target_dir_name):
+ os.makedirs(target_dir_name)
+
+ shutil.copyfile(file[0], file[1])
+
+
+# URL helpers
+# ------------------------------------------------------------------------------------------
+
+def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
+ """Determine whether the given object is a valid URL string."""
+ if not isinstance(obj, str) or not "://" in obj:
+ return False
+ if allow_file_urls and obj.startswith('file://'):
+ return True
+ try:
+ res = requests.compat.urlparse(obj)
+ if not res.scheme or not res.netloc or not "." in res.netloc:
+ return False
+ res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
+ if not res.scheme or not res.netloc or not "." in res.netloc:
+ return False
+ except:
+ return False
+ return True
+
+
+def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
+ """Download the given URL and return a binary-mode file object to access the data."""
+ assert num_attempts >= 1
+ assert not (return_filename and (not cache))
+
+ # Doesn't look like an URL scheme so interpret it as a local filename.
+ if not re.match('^[a-z]+://', url):
+ return url if return_filename else open(url, "rb")
+
+ # Handle file URLs. This code handles unusual file:// patterns that
+ # arise on Windows:
+ #
+ # file:///c:/foo.txt
+ #
+ # which would translate to a local '/c:/foo.txt' filename that's
+ # invalid. Drop the forward slash for such pathnames.
+ #
+ # If you touch this code path, you should test it on both Linux and
+ # Windows.
+ #
+ # Some internet resources suggest using urllib.request.url2pathname() but
+ # but that converts forward slashes to backslashes and this causes
+ # its own set of problems.
+ if url.startswith('file://'):
+ filename = urllib.parse.urlparse(url).path
+ if re.match(r'^/[a-zA-Z]:', filename):
+ filename = filename[1:]
+ return filename if return_filename else open(filename, "rb")
+
+ assert is_url(url)
+
+ # Lookup from cache.
+ if cache_dir is None:
+ cache_dir = make_cache_dir_path('downloads')
+
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
+ if cache:
+ cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
+ if len(cache_files) == 1:
+ filename = cache_files[0]
+ return filename if return_filename else open(filename, "rb")
+
+ # Download.
+ url_name = None
+ url_data = None
+ with requests.Session() as session:
+ if verbose:
+ print("Downloading %s ..." % url, end="", flush=True)
+ for attempts_left in reversed(range(num_attempts)):
+ try:
+ with session.get(url) as res:
+ res.raise_for_status()
+ if len(res.content) == 0:
+ raise IOError("No data received")
+
+ if len(res.content) < 8192:
+ content_str = res.content.decode("utf-8")
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
+ links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
+ if len(links) == 1:
+ url = requests.compat.urljoin(url, links[0])
+ raise IOError("Google Drive virus checker nag")
+ if "Google Drive - Quota exceeded" in content_str:
+ raise IOError("Google Drive download quota exceeded -- please try again later")
+
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
+ url_name = match[1] if match else url
+ url_data = res.content
+ if verbose:
+ print(" done")
+ break
+ except KeyboardInterrupt:
+ raise
+ except:
+ if not attempts_left:
+ if verbose:
+ print(" failed")
+ raise
+ if verbose:
+ print(".", end="", flush=True)
+
+ # Save to cache.
+ if cache:
+ safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
+ cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
+ temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
+ os.makedirs(cache_dir, exist_ok=True)
+ with open(temp_file, "wb") as f:
+ f.write(url_data)
+ os.replace(temp_file, cache_file) # atomic
+ if return_filename:
+ return cache_file
+
+ # Return data as file object.
+ assert not return_filename
+ return io.BytesIO(url_data)
diff --git a/diffusion-projected-gan/environment.yml b/diffusion-projected-gan/environment.yml
new file mode 100644
index 0000000000000000000000000000000000000000..9979710d6cff562ed468dc9192106d51c0df8af1
--- /dev/null
+++ b/diffusion-projected-gan/environment.yml
@@ -0,0 +1,142 @@
+name: pg
+channels:
+ - anaconda
+ - nvidia
+ - conda-forge
+ - defaults
+dependencies:
+ - _libgcc_mutex=0.1=conda_forge
+ - _openmp_mutex=4.5=1_gnu
+ - absl-py=1.0.0=pyhd8ed1ab_0
+ - aiohttp=3.7.0=py39h07f9747_0
+ - async-timeout=3.0.1=py_1000
+ - attrs=21.2.0=pyhd8ed1ab_0
+ - blas=1.0=mkl
+ - blinker=1.4=py_1
+ - brotli=1.0.9=he6710b0_2
+ - brotlipy=0.7.0=py39h27cfd23_1003
+ - c-ares=1.18.1=h7f98852_0
+ - ca-certificates=2021.10.8=ha878542_0
+ - cachetools=4.2.4=pyhd8ed1ab_0
+ - certifi=2021.10.8=py39hf3d152e_1
+ - cffi=1.14.6=py39h400218f_0
+ - chardet=3.0.4=py39h079e4ff_1008
+ - charset-normalizer=2.0.4=pyhd3eb1b0_0
+ - click=8.0.3=pyhd3eb1b0_0
+ - cryptography=35.0.0=py39hd23ed53_0
+ - cudatoolkit=11.1.74=h6bb024c_0
+ - cudnn=8.2.1.32=h86fa8c9_0
+ - cycler=0.10.0=py39h06a4308_0
+ - dataclasses=0.8=pyhc8e2a94_3
+ - dbus=1.13.18=hb2f20db_0
+ - dill=0.3.2=py_0
+ - expat=2.4.1=h2531618_2
+ - fontconfig=2.13.1=h6c09931_0
+ - fonttools=4.25.0=pyhd3eb1b0_0
+ - freetype=2.11.0=h70c0345_0
+ - future=0.18.2=py39hf3d152e_4
+ - glib=2.69.1=h5202010_0
+ - google-auth=2.3.3=pyh6c4a22f_0
+ - google-auth-oauthlib=0.4.6=pyhd8ed1ab_0
+ - grpcio=1.38.1=py39hff7568b_0
+ - gst-plugins-base=1.14.0=h8213a91_2
+ - gstreamer=1.14.0=h28cd5cc_2
+ - icu=58.2=he6710b0_3
+ - idna=3.3=pyhd3eb1b0_0
+ - imageio=2.9.0=pyhd3eb1b0_0
+ - importlib-metadata=4.8.2=py39hf3d152e_0
+ - intel-openmp=2021.4.0=h06a4308_3561
+ - jpeg=9d=h7f8727e_0
+ - kiwisolver=1.3.1=py39h2531618_0
+ - lcms2=2.12=h3be6417_0
+ - ld_impl_linux-64=2.35.1=h7274673_9
+ - libblas=3.9.0=12_linux64_mkl
+ - libffi=3.3=he6710b0_2
+ - libgcc-ng=11.2.0=h1d223b6_11
+ - libgfortran-ng=7.5.0=ha8ba4b0_17
+ - libgfortran4=7.5.0=ha8ba4b0_17
+ - libgomp=11.2.0=h1d223b6_11
+ - liblapack=3.9.0=12_linux64_mkl
+ - libpng=1.6.37=hbc83047_0
+ - libprotobuf=3.18.0=h780b84a_1
+ - libstdcxx-ng=11.2.0=he4da1e4_11
+ - libtiff=4.2.0=h85742a9_0
+ - libuuid=1.0.3=h7f8727e_2
+ - libuv=1.40.0=h7b6447c_0
+ - libwebp-base=1.2.0=h27cfd23_0
+ - libxcb=1.14=h7b6447c_0
+ - libxml2=2.9.12=h03d6c58_0
+ - lz4-c=1.9.3=h295c915_1
+ - magma=2.5.4=ha9b7cf9_2
+ - markdown=3.3.6=pyhd8ed1ab_0
+ - matplotlib=3.4.2=py39h06a4308_0
+ - matplotlib-base=3.4.2=py39hab158f2_0
+ - mkl=2021.4.0=h06a4308_640
+ - mkl-service=2.4.0=py39h7f8727e_0
+ - mkl_fft=1.3.1=py39hd3c417c_0
+ - mkl_random=1.2.2=py39h51133e4_0
+ - multidict=5.2.0=py39h3811e60_1
+ - munkres=1.1.4=py_0
+ - nccl=2.11.4.1=h97a9cb7_0
+ - ncurses=6.3=h7f8727e_2
+ - ninja=1.10.2=py39hd09550d_3
+ - numpy=1.21.2=py39h20f2e39_0
+ - numpy-base=1.21.2=py39h79a1101_0
+ - oauthlib=3.1.1=pyhd8ed1ab_0
+ - olefile=0.46=pyhd3eb1b0_0
+ - openjpeg=2.4.0=h3ad879b_0
+ - openssl=1.1.1l=h7f98852_0
+ - pcre=8.45=h295c915_0
+ - pillow=8.3.1=py39h2c7a002_0
+ - pip=21.2.4=py39h06a4308_0
+ - protobuf=3.18.0=py39he80948d_0
+ - psutil=5.8.0=py39h3811e60_1
+ - pyasn1=0.4.8=py_0
+ - pyasn1-modules=0.2.7=py_0
+ - pycparser=2.21=pyhd3eb1b0_0
+ - pyjwt=2.3.0=pyhd8ed1ab_0
+ - pyopenssl=21.0.0=pyhd3eb1b0_1
+ - pyparsing=3.0.4=pyhd3eb1b0_0
+ - pyqt=5.9.2=py39h2531618_6
+ - pysocks=1.7.1=py39h06a4308_0
+ - python=3.9.7=h12debd9_1
+ - python-dateutil=2.8.2=pyhd3eb1b0_0
+ - python_abi=3.9=2_cp39
+ - pytorch=1.9.1=cuda111py39hb4a4491_3
+ - pytorch-gpu=1.9.1=cuda111py39h788eb59_3
+ - pyu2f=0.1.5=pyhd8ed1ab_0
+ - qt=5.9.7=h5867ecd_1
+ - readline=8.1=h27cfd23_0
+ - requests=2.26.0=pyhd3eb1b0_0
+ - requests-oauthlib=1.3.0=pyh9f0ad1d_0
+ - rsa=4.8=pyhd8ed1ab_0
+ - scipy=1.7.1=py39h292c36d_2
+ - setuptools=58.0.4=py39h06a4308_0
+ - sip=4.19.13=py39h2531618_0
+ - six=1.16.0=pyhd3eb1b0_0
+ - sleef=3.5.1=h9b69904_2
+ - sqlite=3.36.0=hc218d9a_0
+ - tensorboard=2.7.0=pyhd8ed1ab_0
+ - tensorboard-data-server=0.6.0=py39h95dcef6_1
+ - tensorboard-plugin-wit=1.8.0=pyh44b312d_0
+ - timm=0.4.12=pyhd8ed1ab_0
+ - tk=8.6.11=h1ccaba5_0
+ - torchvision=0.10.1=py39cuda111hcd06603_0_cuda
+ - tornado=6.1=py39h27cfd23_0
+ - tqdm=4.62.2=pyhd3eb1b0_1
+ - typing_extensions=3.10.0.2=pyh06a4308_0
+ - tzdata=2021e=hda174b7_0
+ - urllib3=1.26.7=pyhd3eb1b0_0
+ - werkzeug=2.0.1=pyhd8ed1ab_0
+ - wheel=0.37.0=pyhd3eb1b0_1
+ - xz=5.2.5=h7b6447c_0
+ - yarl=1.7.2=py39h3811e60_1
+ - zipp=3.6.0=pyhd8ed1ab_0
+ - zlib=1.2.11=h7b6447c_3
+ - zstd=1.4.9=haebb681_0
+ - pip:
+ - glfw==2.2.0
+ - imageio-ffmpeg==0.4.3
+ - imgui==1.3.0
+ - pyopengl==3.1.5
+ - pyspng==0.1.0
diff --git a/diffusion-projected-gan/gen_images.py b/diffusion-projected-gan/gen_images.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6cab990e088817bfc4f49c54e0cb8382851c12c
--- /dev/null
+++ b/diffusion-projected-gan/gen_images.py
@@ -0,0 +1,145 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Generate images using pretrained network pickle."""
+
+import os
+import re
+from typing import List, Optional, Tuple, Union
+
+import click
+import dnnlib
+import numpy as np
+import PIL.Image
+import torch
+
+import legacy
+
+#----------------------------------------------------------------------------
+
+def parse_range(s: Union[str, List]) -> List[int]:
+ '''Parse a comma separated list of numbers or ranges and return a list of ints.
+
+ Example: '1,2,5-10' returns [1, 2, 5, 6, 7]
+ '''
+ if isinstance(s, list): return s
+ ranges = []
+ range_re = re.compile(r'^(\d+)-(\d+)$')
+ for p in s.split(','):
+ m = range_re.match(p)
+ if m:
+ ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
+ else:
+ ranges.append(int(p))
+ return ranges
+
+#----------------------------------------------------------------------------
+
+def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]:
+ '''Parse a floating point 2-vector of syntax 'a,b'.
+
+ Example:
+ '0,1' returns (0,1)
+ '''
+ if isinstance(s, tuple): return s
+ parts = s.split(',')
+ if len(parts) == 2:
+ return (float(parts[0]), float(parts[1]))
+ raise ValueError(f'cannot parse 2-vector {s}')
+
+#----------------------------------------------------------------------------
+
+def make_transform(translate: Tuple[float,float], angle: float):
+ m = np.eye(3)
+ s = np.sin(angle/360.0*np.pi*2)
+ c = np.cos(angle/360.0*np.pi*2)
+ m[0][0] = c
+ m[0][1] = s
+ m[0][2] = translate[0]
+ m[1][0] = -s
+ m[1][1] = c
+ m[1][2] = translate[1]
+ return m
+
+#----------------------------------------------------------------------------
+
+@click.command()
+@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
+@click.option('--seeds', type=parse_range, help='List of random seeds (e.g., \'0,1,4-6\')', required=True)
+@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
+@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
+@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
+@click.option('--translate', help='Translate XY-coordinate (e.g. \'0.3,1\')', type=parse_vec2, default='0,0', show_default=True, metavar='VEC2')
+@click.option('--rotate', help='Rotation angle in degrees', type=float, default=0, show_default=True, metavar='ANGLE')
+@click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
+def generate_images(
+ network_pkl: str,
+ seeds: List[int],
+ truncation_psi: float,
+ noise_mode: str,
+ outdir: str,
+ translate: Tuple[float,float],
+ rotate: float,
+ class_idx: Optional[int]
+):
+ """Generate images using pretrained network pickle.
+
+ Examples:
+
+ \b
+ # Generate an image using pre-trained AFHQv2 model ("Ours" in Figure 1, left).
+ python gen_images.py --outdir=out --trunc=1 --seeds=2 \\
+ --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl
+
+ \b
+ # Generate uncurated images with truncation using the MetFaces-U dataset
+ python gen_images.py --outdir=out --trunc=0.7 --seeds=600-605 \\
+ --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfacesu-1024x1024.pkl
+ """
+
+ print('Loading networks from "%s"...' % network_pkl)
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+ with dnnlib.util.open_url(network_pkl) as f:
+ G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
+
+ os.makedirs(outdir, exist_ok=True)
+
+ # Labels.
+ label = torch.zeros([1, G.c_dim], device=device)
+ if G.c_dim != 0:
+ if class_idx is None:
+ raise click.ClickException('Must specify class label with --class when using a conditional network')
+ label[:, class_idx] = 1
+ else:
+ if class_idx is not None:
+ print ('warn: --class=lbl ignored when running on an unconditional network')
+
+ # Generate images.
+ for seed_idx, seed in enumerate(seeds):
+ print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
+ z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device).float()
+
+ # Construct an inverse rotation/translation matrix and pass to the generator. The
+ # generator expects this matrix as an inverse to avoid potentially failing numerical
+ # operations in the network.
+ if hasattr(G.synthesis, 'input'):
+ m = make_transform(translate, rotate)
+ m = np.linalg.inv(m)
+ G.synthesis.input.transform.copy_(torch.from_numpy(m))
+
+ img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
+ img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+ PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png')
+
+
+#----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+ generate_images() # pylint: disable=no-value-for-parameter
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-projected-gan/gen_video.py b/diffusion-projected-gan/gen_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..0eb6f3552bf9c83ae04709801fa0facce84f45d2
--- /dev/null
+++ b/diffusion-projected-gan/gen_video.py
@@ -0,0 +1,192 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Generate lerp videos using pretrained network pickle."""
+
+import copy
+import os
+import re
+from typing import List, Optional, Tuple, Union
+
+import click
+import dnnlib
+import imageio
+import numpy as np
+import scipy.interpolate
+import torch
+from tqdm import tqdm
+
+import legacy
+
+#----------------------------------------------------------------------------
+
+def layout_grid(img, grid_w=None, grid_h=1, float_to_uint8=True, chw_to_hwc=True, to_numpy=True):
+ batch_size, channels, img_h, img_w = img.shape
+ if grid_w is None:
+ grid_w = batch_size // grid_h
+ assert batch_size == grid_w * grid_h
+ if float_to_uint8:
+ img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+ img = img.reshape(grid_h, grid_w, channels, img_h, img_w)
+ img = img.permute(2, 0, 3, 1, 4)
+ img = img.reshape(channels, grid_h * img_h, grid_w * img_w)
+ if chw_to_hwc:
+ img = img.permute(1, 2, 0)
+ if to_numpy:
+ img = img.cpu().numpy()
+ return img
+
+#----------------------------------------------------------------------------
+
+def gen_interp_video(G, mp4: str, seeds, shuffle_seed=None, w_frames=60*4, kind='cubic', grid_dims=(1,1), num_keyframes=None, wraps=2, psi=1, device=torch.device('cuda'), class_idx=None, **video_kwargs):
+ grid_w = grid_dims[0]
+ grid_h = grid_dims[1]
+
+ if num_keyframes is None:
+ if len(seeds) % (grid_w*grid_h) != 0:
+ raise ValueError('Number of input seeds must be divisible by grid W*H')
+ num_keyframes = len(seeds) // (grid_w*grid_h)
+
+ all_seeds = np.zeros(num_keyframes*grid_h*grid_w, dtype=np.int64)
+ for idx in range(num_keyframes*grid_h*grid_w):
+ all_seeds[idx] = seeds[idx % len(seeds)]
+
+ if shuffle_seed is not None:
+ rng = np.random.RandomState(seed=shuffle_seed)
+ rng.shuffle(all_seeds)
+
+ zs = torch.from_numpy(np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])).to(device).float()
+ # Labels.
+ label = torch.zeros([zs.size(0), G.c_dim], device=device)
+ if G.c_dim != 0:
+ if class_idx is None:
+ raise click.ClickException('Must specify class label with --class when using a conditional network')
+ label[:, class_idx] = 1
+ else:
+ if class_idx is not None:
+ print ('warn: --class=lbl ignored when running on an unconditional network')
+
+ ws = G.mapping(z=zs, c=label, truncation_psi=psi)
+ _ = G.synthesis(ws[:1], c=label) # warm up
+ ws = ws.reshape(grid_h, grid_w, num_keyframes, *ws.shape[1:])
+
+ # Interpolation.
+ grid = []
+ for yi in range(grid_h):
+ row = []
+ for xi in range(grid_w):
+ x = np.arange(-num_keyframes * wraps, num_keyframes * (wraps + 1))
+ y = np.tile(ws[yi][xi].cpu().numpy(), [wraps * 2 + 1, 1, 1])
+ interp = scipy.interpolate.interp1d(x, y, kind=kind, axis=0)
+ row.append(interp)
+ grid.append(row)
+
+ # Render video.
+ video_out = imageio.get_writer(mp4, mode='I', fps=60, codec='libx264', **video_kwargs)
+ for frame_idx in tqdm(range(num_keyframes * w_frames)):
+ imgs = []
+ for yi in range(grid_h):
+ for xi in range(grid_w):
+ interp = grid[yi][xi]
+ w = torch.from_numpy(interp(frame_idx / w_frames)).to(device).float()
+ img = G.synthesis(w.unsqueeze(0), c=label, noise_mode='const')[0]
+ imgs.append(img)
+ video_out.append_data(layout_grid(torch.stack(imgs), grid_w=grid_w, grid_h=grid_h))
+ video_out.close()
+
+#----------------------------------------------------------------------------
+
+def parse_range(s: Union[str, List[int]]) -> List[int]:
+ '''Parse a comma separated list of numbers or ranges and return a list of ints.
+
+ Example: '1,2,5-10' returns [1, 2, 5, 6, 7]
+ '''
+ if isinstance(s, list): return s
+ ranges = []
+ range_re = re.compile(r'^(\d+)-(\d+)$')
+ for p in s.split(','):
+ m = range_re.match(p)
+ if m:
+ ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
+ else:
+ ranges.append(int(p))
+ return ranges
+
+#----------------------------------------------------------------------------
+
+def parse_tuple(s: Union[str, Tuple[int,int]]) -> Tuple[int, int]:
+ '''Parse a 'M,N' or 'MxN' integer tuple.
+
+ Example:
+ '4x2' returns (4,2)
+ '0,1' returns (0,1)
+ '''
+ if isinstance(s, tuple): return s
+ m = re.match(r'^(\d+)[x,](\d+)$', s)
+ if m:
+ return (int(m.group(1)), int(m.group(2)))
+ raise ValueError(f'cannot parse tuple {s}')
+
+#----------------------------------------------------------------------------
+
+@click.command()
+@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
+@click.option('--seeds', type=parse_range, help='List of random seeds', required=True)
+@click.option('--shuffle-seed', type=int, help='Random seed to use for shuffling seed order', default=None)
+@click.option('--grid', type=parse_tuple, help='Grid width/height, e.g. \'4x3\' (default: 1x1)', default=(1,1))
+@click.option('--num-keyframes', type=int, help='Number of seeds to interpolate through. If not specified, determine based on the length of the seeds array given by --seeds.', default=None)
+@click.option('--w-frames', type=int, help='Number of frames to interpolate between latents', default=120)
+@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
+@click.option('--output', help='Output .mp4 filename', type=str, required=True, metavar='FILE')
+@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
+def generate_images(
+ network_pkl: str,
+ seeds: List[int],
+ shuffle_seed: Optional[int],
+ truncation_psi: float,
+ grid: Tuple[int,int],
+ num_keyframes: Optional[int],
+ w_frames: int,
+ output: str,
+ class_idx: Optional[int],
+):
+ """Render a latent vector interpolation video.
+
+ Examples:
+
+ \b
+ # Render a 4x2 grid of interpolations for seeds 0 through 31.
+ python gen_video.py --output=lerp.mp4 --trunc=1 --seeds=0-31 --grid=4x2 \\
+ --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl
+
+ Animation length and seed keyframes:
+
+ The animation length is either determined based on the --seeds value or explicitly
+ specified using the --num-keyframes option.
+
+ When num keyframes is specified with --num-keyframes, the output video length
+ will be 'num_keyframes*w_frames' frames.
+
+ If --num-keyframes is not specified, the number of seeds given with
+ --seeds must be divisible by grid size W*H (--grid). In this case the
+ output video length will be '# seeds/(w*h)*w_frames' frames.
+ """
+
+ print('Loading networks from "%s"...' % network_pkl)
+ device = torch.device('cuda')
+ with dnnlib.util.open_url(network_pkl) as f:
+ G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
+
+ gen_interp_video(G=G, mp4=output, bitrate='12M', grid_dims=grid, num_keyframes=num_keyframes, w_frames=w_frames, seeds=seeds, shuffle_seed=shuffle_seed, psi=truncation_psi, class_idx=class_idx)
+
+#----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+ generate_images() # pylint: disable=no-value-for-parameter
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-projected-gan/legacy.py b/diffusion-projected-gan/legacy.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd2ed5da622c49fca3ee2db6357da65b0868a16c
--- /dev/null
+++ b/diffusion-projected-gan/legacy.py
@@ -0,0 +1,328 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Converting legacy network pickle into the new format."""
+
+import click
+import pickle
+import re
+import copy
+import numpy as np
+import torch
+import io
+import dnnlib
+from torch_utils import misc
+
+#----------------------------------------------------------------------------
+
+def load_network_pkl(f, force_fp16=False):
+ data = _LegacyUnpickler(f).load()
+
+ # Legacy TensorFlow pickle => convert.
+ if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
+ tf_G, tf_D, tf_Gs = data
+ G = convert_tf_generator(tf_G)
+ D = convert_tf_discriminator(tf_D)
+ G_ema = convert_tf_generator(tf_Gs)
+ data = dict(G=G, D=D, G_ema=G_ema)
+
+ # Add missing fields.
+ if 'training_set_kwargs' not in data:
+ data['training_set_kwargs'] = None
+ if 'augment_pipe' not in data:
+ data['augment_pipe'] = None
+
+ # Validate contents.
+ assert isinstance(data['G'], torch.nn.Module)
+ assert isinstance(data['D'], torch.nn.Module)
+ assert isinstance(data['G_ema'], torch.nn.Module)
+ assert isinstance(data['training_set_kwargs'], (dict, type(None)))
+ assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))
+
+ # Force FP16.
+ if force_fp16:
+ for key in ['G', 'D', 'G_ema']:
+ old = data[key]
+ kwargs = copy.deepcopy(old.init_kwargs)
+ fp16_kwargs = kwargs.get('synthesis_kwargs', kwargs)
+ fp16_kwargs.num_fp16_res = 4
+ fp16_kwargs.conv_clamp = 256
+ if kwargs != old.init_kwargs:
+ new = type(old)(**kwargs).eval().requires_grad_(False)
+ misc.copy_params_and_buffers(old, new, require_all=True)
+ data[key] = new
+ return data
+
+#----------------------------------------------------------------------------
+
+class _TFNetworkStub(dnnlib.EasyDict):
+ pass
+
+class _LegacyUnpickler(pickle.Unpickler):
+ def find_class(self, module, name):
+ if module == 'dnnlib.tflib.network' and name == 'Network':
+ return _TFNetworkStub
+ if module == 'torch.storage' and name == '_load_from_bytes':
+ return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
+ return super().find_class(module, name)
+
+#----------------------------------------------------------------------------
+
+def _collect_tf_params(tf_net):
+ # pylint: disable=protected-access
+ tf_params = dict()
+ def recurse(prefix, tf_net):
+ for name, value in tf_net.variables:
+ tf_params[prefix + name] = value
+ for name, comp in tf_net.components.items():
+ recurse(prefix + name + '/', comp)
+ recurse('', tf_net)
+ return tf_params
+
+#----------------------------------------------------------------------------
+
+def _populate_module_params(module, *patterns):
+ for name, tensor in misc.named_params_and_buffers(module):
+ found = False
+ value = None
+ for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
+ match = re.fullmatch(pattern, name)
+ if match:
+ found = True
+ if value_fn is not None:
+ value = value_fn(*match.groups())
+ break
+ try:
+ assert found
+ if value is not None:
+ tensor.copy_(torch.from_numpy(np.array(value)))
+ except:
+ print(name, list(tensor.shape))
+ raise
+
+#----------------------------------------------------------------------------
+
+def convert_tf_generator(tf_G):
+ if tf_G.version < 4:
+ raise ValueError('TensorFlow pickle version too low')
+
+ # Collect kwargs.
+ tf_kwargs = tf_G.static_kwargs
+ known_kwargs = set()
+ def kwarg(tf_name, default=None, none=None):
+ known_kwargs.add(tf_name)
+ val = tf_kwargs.get(tf_name, default)
+ return val if val is not None else none
+
+ # Convert kwargs.
+ from pg_modules import networks_stylegan2
+ network_class = networks_stylegan2.Generator
+ kwargs = dnnlib.EasyDict(
+ z_dim = kwarg('latent_size', 512),
+ c_dim = kwarg('label_size', 0),
+ w_dim = kwarg('dlatent_size', 512),
+ img_resolution = kwarg('resolution', 1024),
+ img_channels = kwarg('num_channels', 3),
+ channel_base = kwarg('fmap_base', 16384) * 2,
+ channel_max = kwarg('fmap_max', 512),
+ num_fp16_res = kwarg('num_fp16_res', 0),
+ conv_clamp = kwarg('conv_clamp', None),
+ architecture = kwarg('architecture', 'skip'),
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
+ use_noise = kwarg('use_noise', True),
+ activation = kwarg('nonlinearity', 'lrelu'),
+ mapping_kwargs = dnnlib.EasyDict(
+ num_layers = kwarg('mapping_layers', 8),
+ embed_features = kwarg('label_fmaps', None),
+ layer_features = kwarg('mapping_fmaps', None),
+ activation = kwarg('mapping_nonlinearity', 'lrelu'),
+ lr_multiplier = kwarg('mapping_lrmul', 0.01),
+ w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
+ ),
+ )
+
+ # Check for unknown kwargs.
+ kwarg('truncation_psi')
+ kwarg('truncation_cutoff')
+ kwarg('style_mixing_prob')
+ kwarg('structure')
+ kwarg('conditioning')
+ kwarg('fused_modconv')
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
+ if len(unknown_kwargs) > 0:
+ raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
+
+ # Collect params.
+ tf_params = _collect_tf_params(tf_G)
+ for name, value in list(tf_params.items()):
+ match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
+ if match:
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
+ tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
+ kwargs.synthesis.kwargs.architecture = 'orig'
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
+
+ # Convert params.
+ G = network_class(**kwargs).eval().requires_grad_(False)
+ # pylint: disable=unnecessary-lambda
+ # pylint: disable=f-string-without-interpolation
+ _populate_module_params(G,
+ r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
+ r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
+ r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
+ r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
+ r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
+ r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
+ r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
+ r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
+ r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
+ r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
+ r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
+ r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
+ r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
+ r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
+ r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
+ r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
+ r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
+ r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
+ r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
+ r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
+ r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
+ r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
+ r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
+ r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
+ r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
+ r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
+ r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
+ r'.*\.resample_filter', None,
+ r'.*\.act_filter', None,
+ )
+ return G
+
+#----------------------------------------------------------------------------
+
+def convert_tf_discriminator(tf_D):
+ if tf_D.version < 4:
+ raise ValueError('TensorFlow pickle version too low')
+
+ # Collect kwargs.
+ tf_kwargs = tf_D.static_kwargs
+ known_kwargs = set()
+ def kwarg(tf_name, default=None):
+ known_kwargs.add(tf_name)
+ return tf_kwargs.get(tf_name, default)
+
+ # Convert kwargs.
+ kwargs = dnnlib.EasyDict(
+ c_dim = kwarg('label_size', 0),
+ img_resolution = kwarg('resolution', 1024),
+ img_channels = kwarg('num_channels', 3),
+ architecture = kwarg('architecture', 'resnet'),
+ channel_base = kwarg('fmap_base', 16384) * 2,
+ channel_max = kwarg('fmap_max', 512),
+ num_fp16_res = kwarg('num_fp16_res', 0),
+ conv_clamp = kwarg('conv_clamp', None),
+ cmap_dim = kwarg('mapping_fmaps', None),
+ block_kwargs = dnnlib.EasyDict(
+ activation = kwarg('nonlinearity', 'lrelu'),
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
+ freeze_layers = kwarg('freeze_layers', 0),
+ ),
+ mapping_kwargs = dnnlib.EasyDict(
+ num_layers = kwarg('mapping_layers', 0),
+ embed_features = kwarg('mapping_fmaps', None),
+ layer_features = kwarg('mapping_fmaps', None),
+ activation = kwarg('nonlinearity', 'lrelu'),
+ lr_multiplier = kwarg('mapping_lrmul', 0.1),
+ ),
+ epilogue_kwargs = dnnlib.EasyDict(
+ mbstd_group_size = kwarg('mbstd_group_size', None),
+ mbstd_num_channels = kwarg('mbstd_num_features', 1),
+ activation = kwarg('nonlinearity', 'lrelu'),
+ ),
+ )
+
+ # Check for unknown kwargs.
+ kwarg('structure')
+ kwarg('conditioning')
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
+ if len(unknown_kwargs) > 0:
+ raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
+
+ # Collect params.
+ tf_params = _collect_tf_params(tf_D)
+ for name, value in list(tf_params.items()):
+ match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
+ if match:
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
+ tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
+ kwargs.architecture = 'orig'
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
+
+ # Convert params.
+ #from pg_modules import networks_stylegan2
+ from pg_modules.discriminator import ProjectedDiscriminator
+
+ D = ProjectedDiscriminator(**kwargs).eval().requires_grad_(False)
+ # pylint: disable=unnecessary-lambda
+ # pylint: disable=f-string-without-interpolation
+ _populate_module_params(D,
+ r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
+ r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
+ r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
+ r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
+ r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
+ r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
+ r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
+ r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
+ r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
+ r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
+ r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
+ r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
+ r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
+ r'.*\.resample_filter', None,
+ )
+ return D
+
+#----------------------------------------------------------------------------
+
+@click.command()
+@click.option('--source', help='Input pickle', required=True, metavar='PATH')
+@click.option('--dest', help='Output pickle', required=True, metavar='PATH')
+@click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
+def convert_network_pickle(source, dest, force_fp16):
+ """Convert legacy network pickle into the native PyTorch format.
+
+ The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
+ It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.
+
+ Example:
+
+ \b
+ python legacy.py \\
+ --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
+ --dest=stylegan2-cat-config-f.pkl
+ """
+ print(f'Loading "{source}"...')
+ with dnnlib.util.open_url(source) as f:
+ data = load_network_pkl(f, force_fp16=force_fp16)
+ print(f'Saving "{dest}"...')
+ with open(dest, 'wb') as f:
+ pickle.dump(data, f)
+ print('Done.')
+
+#----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+ convert_network_pickle() # pylint: disable=no-value-for-parameter
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-projected-gan/metrics/equivariance.py b/diffusion-projected-gan/metrics/equivariance.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5559aca37b45e246a2ed279193cf4d59011f0b9
--- /dev/null
+++ b/diffusion-projected-gan/metrics/equivariance.py
@@ -0,0 +1,267 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Equivariance metrics (EQ-T, EQ-T_frac, and EQ-R) from the paper
+"Alias-Free Generative Adversarial Networks"."""
+
+import copy
+import numpy as np
+import torch
+import torch.fft
+from torch_utils.ops import upfirdn2d
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+# Utilities.
+
+def sinc(x):
+ y = (x * np.pi).abs()
+ z = torch.sin(y) / y.clamp(1e-30, float('inf'))
+ return torch.where(y < 1e-30, torch.ones_like(x), z)
+
+def lanczos_window(x, a):
+ x = x.abs() / a
+ return torch.where(x < 1, sinc(x), torch.zeros_like(x))
+
+def rotation_matrix(angle):
+ angle = torch.as_tensor(angle).to(torch.float32)
+ mat = torch.eye(3, device=angle.device)
+ mat[0, 0] = angle.cos()
+ mat[0, 1] = angle.sin()
+ mat[1, 0] = -angle.sin()
+ mat[1, 1] = angle.cos()
+ return mat
+
+#----------------------------------------------------------------------------
+# Apply integer translation to a batch of 2D images. Corresponds to the
+# operator T_x in Appendix E.1.
+
+def apply_integer_translation(x, tx, ty):
+ _N, _C, H, W = x.shape
+ tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device)
+ ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device)
+ ix = tx.round().to(torch.int64)
+ iy = ty.round().to(torch.int64)
+
+ z = torch.zeros_like(x)
+ m = torch.zeros_like(x)
+ if abs(ix) < W and abs(iy) < H:
+ y = x[:, :, max(-iy,0) : H+min(-iy,0), max(-ix,0) : W+min(-ix,0)]
+ z[:, :, max(iy,0) : H+min(iy,0), max(ix,0) : W+min(ix,0)] = y
+ m[:, :, max(iy,0) : H+min(iy,0), max(ix,0) : W+min(ix,0)] = 1
+ return z, m
+
+#----------------------------------------------------------------------------
+# Apply integer translation to a batch of 2D images. Corresponds to the
+# operator T_x in Appendix E.2.
+
+def apply_fractional_translation(x, tx, ty, a=3):
+ _N, _C, H, W = x.shape
+ tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device)
+ ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device)
+ ix = tx.floor().to(torch.int64)
+ iy = ty.floor().to(torch.int64)
+ fx = tx - ix
+ fy = ty - iy
+ b = a - 1
+
+ z = torch.zeros_like(x)
+ zx0 = max(ix - b, 0)
+ zy0 = max(iy - b, 0)
+ zx1 = min(ix + a, 0) + W
+ zy1 = min(iy + a, 0) + H
+ if zx0 < zx1 and zy0 < zy1:
+ taps = torch.arange(a * 2, device=x.device) - b
+ filter_x = (sinc(taps - fx) * sinc((taps - fx) / a)).unsqueeze(0)
+ filter_y = (sinc(taps - fy) * sinc((taps - fy) / a)).unsqueeze(1)
+ y = x
+ y = upfirdn2d.filter2d(y, filter_x / filter_x.sum(), padding=[b,a,0,0])
+ y = upfirdn2d.filter2d(y, filter_y / filter_y.sum(), padding=[0,0,b,a])
+ y = y[:, :, max(b-iy,0) : H+b+a+min(-iy-a,0), max(b-ix,0) : W+b+a+min(-ix-a,0)]
+ z[:, :, zy0:zy1, zx0:zx1] = y
+
+ m = torch.zeros_like(x)
+ mx0 = max(ix + a, 0)
+ my0 = max(iy + a, 0)
+ mx1 = min(ix - b, 0) + W
+ my1 = min(iy - b, 0) + H
+ if mx0 < mx1 and my0 < my1:
+ m[:, :, my0:my1, mx0:mx1] = 1
+ return z, m
+
+#----------------------------------------------------------------------------
+# Construct an oriented low-pass filter that applies the appropriate
+# bandlimit with respect to the input and output of the given affine 2D
+# image transformation.
+
+def construct_affine_bandlimit_filter(mat, a=3, amax=16, aflt=64, up=4, cutoff_in=1, cutoff_out=1):
+ assert a <= amax < aflt
+ mat = torch.as_tensor(mat).to(torch.float32)
+
+ # Construct 2D filter taps in input & output coordinate spaces.
+ taps = ((torch.arange(aflt * up * 2 - 1, device=mat.device) + 1) / up - aflt).roll(1 - aflt * up)
+ yi, xi = torch.meshgrid(taps, taps)
+ xo, yo = (torch.stack([xi, yi], dim=2) @ mat[:2, :2].t()).unbind(2)
+
+ # Convolution of two oriented 2D sinc filters.
+ fi = sinc(xi * cutoff_in) * sinc(yi * cutoff_in)
+ fo = sinc(xo * cutoff_out) * sinc(yo * cutoff_out)
+ f = torch.fft.ifftn(torch.fft.fftn(fi) * torch.fft.fftn(fo)).real
+
+ # Convolution of two oriented 2D Lanczos windows.
+ wi = lanczos_window(xi, a) * lanczos_window(yi, a)
+ wo = lanczos_window(xo, a) * lanczos_window(yo, a)
+ w = torch.fft.ifftn(torch.fft.fftn(wi) * torch.fft.fftn(wo)).real
+
+ # Construct windowed FIR filter.
+ f = f * w
+
+ # Finalize.
+ c = (aflt - amax) * up
+ f = f.roll([aflt * up - 1] * 2, dims=[0,1])[c:-c, c:-c]
+ f = torch.nn.functional.pad(f, [0, 1, 0, 1]).reshape(amax * 2, up, amax * 2, up)
+ f = f / f.sum([0,2], keepdim=True) / (up ** 2)
+ f = f.reshape(amax * 2 * up, amax * 2 * up)[:-1, :-1]
+ return f
+
+#----------------------------------------------------------------------------
+# Apply the given affine transformation to a batch of 2D images.
+
+def apply_affine_transformation(x, mat, up=4, **filter_kwargs):
+ _N, _C, H, W = x.shape
+ mat = torch.as_tensor(mat).to(dtype=torch.float32, device=x.device)
+
+ # Construct filter.
+ f = construct_affine_bandlimit_filter(mat, up=up, **filter_kwargs)
+ assert f.ndim == 2 and f.shape[0] == f.shape[1] and f.shape[0] % 2 == 1
+ p = f.shape[0] // 2
+
+ # Construct sampling grid.
+ theta = mat.inverse()
+ theta[:2, 2] *= 2
+ theta[0, 2] += 1 / up / W
+ theta[1, 2] += 1 / up / H
+ theta[0, :] *= W / (W + p / up * 2)
+ theta[1, :] *= H / (H + p / up * 2)
+ theta = theta[:2, :3].unsqueeze(0).repeat([x.shape[0], 1, 1])
+ g = torch.nn.functional.affine_grid(theta, x.shape, align_corners=False)
+
+ # Resample image.
+ y = upfirdn2d.upsample2d(x=x, f=f, up=up, padding=p)
+ z = torch.nn.functional.grid_sample(y, g, mode='bilinear', padding_mode='zeros', align_corners=False)
+
+ # Form mask.
+ m = torch.zeros_like(y)
+ c = p * 2 + 1
+ m[:, :, c:-c, c:-c] = 1
+ m = torch.nn.functional.grid_sample(m, g, mode='nearest', padding_mode='zeros', align_corners=False)
+ return z, m
+
+#----------------------------------------------------------------------------
+# Apply fractional rotation to a batch of 2D images. Corresponds to the
+# operator R_\alpha in Appendix E.3.
+
+def apply_fractional_rotation(x, angle, a=3, **filter_kwargs):
+ angle = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
+ mat = rotation_matrix(angle)
+ return apply_affine_transformation(x, mat, a=a, amax=a*2, **filter_kwargs)
+
+#----------------------------------------------------------------------------
+# Modify the frequency content of a batch of 2D images as if they had undergo
+# fractional rotation -- but without actually rotating them. Corresponds to
+# the operator R^*_\alpha in Appendix E.3.
+
+def apply_fractional_pseudo_rotation(x, angle, a=3, **filter_kwargs):
+ angle = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
+ mat = rotation_matrix(-angle)
+ f = construct_affine_bandlimit_filter(mat, a=a, amax=a*2, up=1, **filter_kwargs)
+ y = upfirdn2d.filter2d(x=x, f=f)
+ m = torch.zeros_like(y)
+ c = f.shape[0] // 2
+ m[:, :, c:-c, c:-c] = 1
+ return y, m
+
+#----------------------------------------------------------------------------
+# Compute the selected equivariance metrics for the given generator.
+
+def compute_equivariance_metrics(opts, num_samples, batch_size, translate_max=0.125, rotate_max=1, compute_eqt_int=False, compute_eqt_frac=False, compute_eqr=False):
+ assert compute_eqt_int or compute_eqt_frac or compute_eqr
+
+ # Setup generator and labels.
+ G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
+ I = torch.eye(3, device=opts.device)
+ M = getattr(getattr(getattr(G, 'synthesis', None), 'input', None), 'transform', None)
+ if M is None:
+ raise ValueError('Cannot compute equivariance metrics; the given generator does not support user-specified image transformations')
+ c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size)
+
+ # Sampling loop.
+ sums = None
+ progress = opts.progress.sub(tag='eq sampling', num_items=num_samples)
+ for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
+ progress.update(batch_start)
+ s = []
+
+ # Randomize noise buffers, if any.
+ for name, buf in G.named_buffers():
+ if name.endswith('.noise_const'):
+ buf.copy_(torch.randn_like(buf))
+
+ # Run mapping network.
+ z = torch.randn([batch_size, G.z_dim], device=opts.device)
+ c = next(c_iter)
+ ws = G.mapping(z=z, c=c)
+
+ # Generate reference image.
+ M[:] = I
+ orig = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
+
+ # Integer translation (EQ-T).
+ if compute_eqt_int:
+ t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max
+ t = (t * G.img_resolution).round() / G.img_resolution
+ M[:] = I
+ M[:2, 2] = -t
+ img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
+ ref, mask = apply_integer_translation(orig, t[0], t[1])
+ s += [(ref - img).square() * mask, mask]
+
+ # Fractional translation (EQ-T_frac).
+ if compute_eqt_frac:
+ t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max
+ M[:] = I
+ M[:2, 2] = -t
+ img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
+ ref, mask = apply_fractional_translation(orig, t[0], t[1])
+ s += [(ref - img).square() * mask, mask]
+
+ # Rotation (EQ-R).
+ if compute_eqr:
+ angle = (torch.rand([], device=opts.device) * 2 - 1) * (rotate_max * np.pi)
+ M[:] = rotation_matrix(-angle)
+ img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
+ ref, ref_mask = apply_fractional_rotation(orig, angle)
+ pseudo, pseudo_mask = apply_fractional_pseudo_rotation(img, angle)
+ mask = ref_mask * pseudo_mask
+ s += [(ref - pseudo).square() * mask, mask]
+
+ # Accumulate results.
+ s = torch.stack([x.to(torch.float64).sum() for x in s])
+ sums = sums + s if sums is not None else s
+ progress.update(num_samples)
+
+ # Compute PSNRs.
+ if opts.num_gpus > 1:
+ torch.distributed.all_reduce(sums)
+ sums = sums.cpu()
+ mses = sums[0::2] / sums[1::2]
+ psnrs = np.log10(2) * 20 - mses.log10() * 10
+ psnrs = tuple(psnrs.numpy())
+ return psnrs[0] if len(psnrs) == 1 else psnrs
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-projected-gan/metrics/frechet_inception_distance.py b/diffusion-projected-gan/metrics/frechet_inception_distance.py
new file mode 100644
index 0000000000000000000000000000000000000000..f99c828e8bde37d46b09c64217977fd485781460
--- /dev/null
+++ b/diffusion-projected-gan/metrics/frechet_inception_distance.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Frechet Inception Distance (FID) from the paper
+"GANs trained by a two time-scale update rule converge to a local Nash
+equilibrium". Matches the original implementation by Heusel et al. at
+https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""
+
+import numpy as np
+import scipy.linalg
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+def compute_fid(opts, max_real, num_gen, swav=False, sfid=False):
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+ detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
+ detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
+
+ mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real, swav=swav, sfid=sfid).get_mean_cov()
+
+ mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen, swav=swav, sfid=sfid).get_mean_cov()
+
+ if opts.rank != 0:
+ return float('nan')
+
+ m = np.square(mu_gen - mu_real).sum()
+ s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
+ fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
+ return float(fid)
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-projected-gan/metrics/inception_score.py b/diffusion-projected-gan/metrics/inception_score.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0a3a442b4cd7f993773ee77c5394796c28c2ef8
--- /dev/null
+++ b/diffusion-projected-gan/metrics/inception_score.py
@@ -0,0 +1,38 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Inception Score (IS) from the paper "Improved techniques for training
+GANs". Matches the original implementation by Salimans et al. at
+https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""
+
+import numpy as np
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+def compute_is(opts, num_gen, num_splits):
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+ detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
+ detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer.
+
+ gen_probs = metric_utils.compute_feature_stats_for_generator(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ capture_all=True, max_items=num_gen).get_all()
+
+ if opts.rank != 0:
+ return float('nan'), float('nan')
+
+ scores = []
+ for i in range(num_splits):
+ part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
+ kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
+ kl = np.mean(np.sum(kl, axis=1))
+ scores.append(np.exp(kl))
+ return float(np.mean(scores)), float(np.std(scores))
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-projected-gan/metrics/kernel_inception_distance.py b/diffusion-projected-gan/metrics/kernel_inception_distance.py
new file mode 100644
index 0000000000000000000000000000000000000000..d69325c1ef4e2894817ef6003e9335c4de657199
--- /dev/null
+++ b/diffusion-projected-gan/metrics/kernel_inception_distance.py
@@ -0,0 +1,46 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Kernel Inception Distance (KID) from the paper "Demystifying MMD
+GANs". Matches the original implementation by Binkowski et al. at
+https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""
+
+import numpy as np
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+ detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
+ detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
+
+ real_features = metric_utils.compute_feature_stats_for_dataset(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all()
+
+ gen_features = metric_utils.compute_feature_stats_for_generator(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all()
+
+ if opts.rank != 0:
+ return float('nan')
+
+ n = real_features.shape[1]
+ m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size)
+ t = 0
+ for _subset_idx in range(num_subsets):
+ x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)]
+ y = real_features[np.random.choice(real_features.shape[0], m, replace=False)]
+ a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
+ b = (x @ y.T / n + 1) ** 3
+ t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
+ kid = t / num_subsets / m
+ return float(kid)
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-projected-gan/metrics/metric_main.py b/diffusion-projected-gan/metrics/metric_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..27adc6e7e81ae39ab9c8cc18e9a5a490a744dc80
--- /dev/null
+++ b/diffusion-projected-gan/metrics/metric_main.py
@@ -0,0 +1,151 @@
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Main API for computing and reporting quality metrics."""
+
+import os
+import time
+import json
+import torch
+import dnnlib
+
+from . import metric_utils
+from . import frechet_inception_distance
+from . import kernel_inception_distance
+from . import precision_recall
+from . import perceptual_path_length
+from . import inception_score
+from . import equivariance
+
+#----------------------------------------------------------------------------
+
+_metric_dict = dict() # name => fn
+
+def register_metric(fn):
+ assert callable(fn)
+ _metric_dict[fn.__name__] = fn
+ return fn
+
+def is_valid_metric(metric):
+ return metric in _metric_dict
+
+def list_valid_metrics():
+ return list(_metric_dict.keys())
+
+#----------------------------------------------------------------------------
+
+def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments.
+ assert is_valid_metric(metric)
+ opts = metric_utils.MetricOptions(**kwargs)
+
+ # Calculate.
+ start_time = time.time()
+ results = _metric_dict[metric](opts)
+ total_time = time.time() - start_time
+
+ # Broadcast results.
+ for key, value in list(results.items()):
+ if opts.num_gpus > 1:
+ value = torch.as_tensor(value, dtype=torch.float64, device=opts.device)
+ torch.distributed.broadcast(tensor=value, src=0)
+ value = float(value.cpu())
+ results[key] = value
+
+ # Decorate with metadata.
+ return dnnlib.EasyDict(
+ results = dnnlib.EasyDict(results),
+ metric = metric,
+ total_time = total_time,
+ total_time_str = dnnlib.util.format_time(total_time),
+ num_gpus = opts.num_gpus,
+ )
+
+#----------------------------------------------------------------------------
+
+def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
+ metric = result_dict['metric']
+ assert is_valid_metric(metric)
+ if run_dir is not None and snapshot_pkl is not None:
+ snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)
+
+ jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time()))
+ print(jsonl_line)
+ if run_dir is not None and os.path.isdir(run_dir):
+ with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f:
+ f.write(jsonl_line + '\n')
+
+#----------------------------------------------------------------------------
+# Recommended metrics.
+
+@register_metric
+def fid50k_full(opts):
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
+ fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000)
+ return dict(fid50k_full=fid)
+
+@register_metric
+def fid10k_full(opts):
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
+ fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=10000)
+ return dict(fid10k_full=fid)
+
+@register_metric
+def kid50k_full(opts):
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
+ kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000)
+ return dict(kid50k_full=kid)
+
+@register_metric
+def pr50k3_full(opts):
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
+ precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
+ return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall)
+
+@register_metric
+def ppl2_wend(opts):
+ ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2)
+ return dict(ppl2_wend=ppl)
+
+@register_metric
+def eqt50k_int(opts):
+ opts.G_kwargs.update(force_fp32=True)
+ psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_int=True)
+ return dict(eqt50k_int=psnr)
+
+@register_metric
+def eqt50k_frac(opts):
+ opts.G_kwargs.update(force_fp32=True)
+ psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_frac=True)
+ return dict(eqt50k_frac=psnr)
+
+@register_metric
+def eqr50k(opts):
+ opts.G_kwargs.update(force_fp32=True)
+ psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqr=True)
+ return dict(eqr50k=psnr)
+
+# Legacy metrics.
+
+@register_metric
+def fid50k(opts):
+ opts.dataset_kwargs.update(max_size=None)
+ fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
+ return dict(fid50k=fid)
+
+@register_metric
+def kid50k(opts):
+ opts.dataset_kwargs.update(max_size=None)
+ kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000)
+ return dict(kid50k=kid)
+
+@register_metric
+def pr50k3(opts):
+ opts.dataset_kwargs.update(max_size=None)
+ precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
+ return dict(pr50k3_precision=precision, pr50k3_recall=recall)
+
+@register_metric
+def is50k(opts):
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
+ mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10)
+ return dict(is50k_mean=mean, is50k_std=std)
diff --git a/diffusion-projected-gan/metrics/metric_utils.py b/diffusion-projected-gan/metrics/metric_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7e3960b7cee3d49ee3c3b427a67201bb6b9c1d5
--- /dev/null
+++ b/diffusion-projected-gan/metrics/metric_utils.py
@@ -0,0 +1,298 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Miscellaneous utilities used internally by the quality metrics."""
+
+import os
+import time
+import hashlib
+import pickle
+import copy
+import uuid
+import numpy as np
+import torch
+import dnnlib
+from tqdm import tqdm
+
+#----------------------------------------------------------------------------
+
+class MetricOptions:
+ def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True, run_dir=None, cur_nimg=None, snapshot_pkl=None):
+ assert 0 <= rank < num_gpus
+ self.G = G
+ self.G_kwargs = dnnlib.EasyDict(G_kwargs)
+ self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs)
+ self.num_gpus = num_gpus
+ self.rank = rank
+ self.device = device if device is not None else torch.device('cuda', rank)
+ self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor()
+ self.cache = cache
+ self.run_dir = run_dir
+ self.cur_nimg = cur_nimg
+ self.snapshot_pkl = snapshot_pkl
+
+#----------------------------------------------------------------------------
+
+_feature_detector_cache = dict()
+
+def get_feature_detector_name(url):
+ return os.path.splitext(url.split('/')[-1])[0]
+
+def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
+ assert 0 <= rank < num_gpus
+ key = (url, device)
+ if key not in _feature_detector_cache:
+ is_leader = (rank == 0)
+ if not is_leader and num_gpus > 1:
+ torch.distributed.barrier() # leader goes first
+ with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f:
+ _feature_detector_cache[key] = pickle.load(f).to(device)
+ if is_leader and num_gpus > 1:
+ torch.distributed.barrier() # others follow
+ return _feature_detector_cache[key]
+
+#----------------------------------------------------------------------------
+
+def iterate_random_labels(opts, batch_size):
+ if opts.G.c_dim == 0:
+ c = torch.zeros([batch_size, opts.G.c_dim], device=opts.device)
+ while True:
+ yield c
+ else:
+ dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
+ while True:
+ c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_size)]
+ c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
+ yield c
+
+#----------------------------------------------------------------------------
+
+class FeatureStats:
+ def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
+ self.capture_all = capture_all
+ self.capture_mean_cov = capture_mean_cov
+ self.max_items = max_items
+ self.num_items = 0
+ self.num_features = None
+ self.all_features = None
+ self.raw_mean = None
+ self.raw_cov = None
+
+ def set_num_features(self, num_features):
+ if self.num_features is not None:
+ assert num_features == self.num_features
+ else:
+ self.num_features = num_features
+ self.all_features = []
+ self.raw_mean = np.zeros([num_features], dtype=np.float64)
+ self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
+
+ def is_full(self):
+ return (self.max_items is not None) and (self.num_items >= self.max_items)
+
+ def append(self, x):
+ x = np.asarray(x, dtype=np.float32)
+ assert x.ndim == 2
+ if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
+ if self.num_items >= self.max_items:
+ return
+ x = x[:self.max_items - self.num_items]
+
+ self.set_num_features(x.shape[1])
+ self.num_items += x.shape[0]
+ if self.capture_all:
+ self.all_features.append(x)
+ if self.capture_mean_cov:
+ x64 = x.astype(np.float64)
+ self.raw_mean += x64.sum(axis=0)
+ self.raw_cov += x64.T @ x64
+
+ def append_torch(self, x, num_gpus=1, rank=0):
+ assert isinstance(x, torch.Tensor) and x.ndim == 2
+ assert 0 <= rank < num_gpus
+ if num_gpus > 1:
+ ys = []
+ for src in range(num_gpus):
+ y = x.clone()
+ torch.distributed.broadcast(y, src=src)
+ ys.append(y)
+ x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
+ self.append(x.cpu().numpy())
+
+ def get_all(self):
+ assert self.capture_all
+ return np.concatenate(self.all_features, axis=0)
+
+ def get_all_torch(self):
+ return torch.from_numpy(self.get_all())
+
+ def get_mean_cov(self):
+ assert self.capture_mean_cov
+ mean = self.raw_mean / self.num_items
+ cov = self.raw_cov / self.num_items
+ cov = cov - np.outer(mean, mean)
+ return mean, cov
+
+ def save(self, pkl_file):
+ with open(pkl_file, 'wb') as f:
+ pickle.dump(self.__dict__, f)
+
+ @staticmethod
+ def load(pkl_file):
+ with open(pkl_file, 'rb') as f:
+ s = dnnlib.EasyDict(pickle.load(f))
+ obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items)
+ obj.__dict__.update(s)
+ return obj
+
+#----------------------------------------------------------------------------
+
+class ProgressMonitor:
+ def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000):
+ self.tag = tag
+ self.num_items = num_items
+ self.verbose = verbose
+ self.flush_interval = flush_interval
+ self.progress_fn = progress_fn
+ self.pfn_lo = pfn_lo
+ self.pfn_hi = pfn_hi
+ self.pfn_total = pfn_total
+ self.start_time = time.time()
+ self.batch_time = self.start_time
+ self.batch_items = 0
+ if self.progress_fn is not None:
+ self.progress_fn(self.pfn_lo, self.pfn_total)
+
+ def update(self, cur_items):
+ assert (self.num_items is None) or (cur_items <= self.num_items)
+ if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items):
+ return
+ cur_time = time.time()
+ total_time = cur_time - self.start_time
+ time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1)
+ if (self.verbose) and (self.tag is not None):
+ print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}')
+ self.batch_time = cur_time
+ self.batch_items = cur_items
+
+ if (self.progress_fn is not None) and (self.num_items is not None):
+ self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total)
+
+ def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1):
+ return ProgressMonitor(
+ tag = tag,
+ num_items = num_items,
+ flush_interval = flush_interval,
+ verbose = self.verbose,
+ progress_fn = self.progress_fn,
+ pfn_lo = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo,
+ pfn_hi = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi,
+ pfn_total = self.pfn_total,
+ )
+
+#----------------------------------------------------------------------------
+
+def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, swav=False, sfid=False, **stats_kwargs):
+ dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
+ if data_loader_kwargs is None:
+ data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
+
+ # Try to lookup from cache.
+ cache_file = None
+ if opts.cache:
+ det_name = get_feature_detector_name(detector_url)
+
+ # Choose cache file name.
+ args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs)
+ md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8'))
+ cache_tag = f'{dataset.name}-{det_name}-{md5.hexdigest()}'
+ cache_file = os.path.join('.', 'dnnlib', 'gan-metrics', cache_tag + '.pkl')
+ # cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl')
+
+ # Check if the file exists (all processes must agree).
+ flag = os.path.isfile(cache_file) if opts.rank == 0 else False
+ if opts.num_gpus > 1:
+ flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device)
+ torch.distributed.broadcast(tensor=flag, src=0)
+ flag = (float(flag.cpu()) != 0)
+
+ # Load.
+ if flag:
+ return FeatureStats.load(cache_file)
+
+ print('Calculating the stats for this dataset the first time\n')
+ print(f'Saving them to {cache_file}')
+
+ # Initialize.
+ num_items = len(dataset)
+ if max_items is not None:
+ num_items = min(num_items, max_items)
+ stats = FeatureStats(max_items=num_items, **stats_kwargs)
+ progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi)
+
+ # get detector
+ detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
+
+ # Main loop.
+ item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
+ for images, _labels in tqdm(torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs)):
+ if images.shape[1] == 1:
+ images = images.repeat([1, 3, 1, 1])
+
+ with torch.no_grad():
+ features = detector(images.to(opts.device), **detector_kwargs)
+
+ stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
+ progress.update(stats.num_items)
+
+ # Save to cache.
+ if cache_file is not None and opts.rank == 0:
+ os.makedirs(os.path.dirname(cache_file), exist_ok=True)
+ temp_file = cache_file + '.' + uuid.uuid4().hex
+ stats.save(temp_file)
+ os.replace(temp_file, cache_file) # atomic
+ return stats
+
+#----------------------------------------------------------------------------
+
+def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, swav=False, sfid=False, **stats_kwargs):
+ if batch_gen is None:
+ batch_gen = min(batch_size, 4)
+ assert batch_size % batch_gen == 0
+
+ # Setup generator and labels.
+ G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
+ c_iter = iterate_random_labels(opts=opts, batch_size=batch_gen)
+
+ # Initialize.
+ stats = FeatureStats(**stats_kwargs)
+ assert stats.max_items is not None
+ progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
+
+ # get detector
+ detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
+
+ # Main loop.
+ while not stats.is_full():
+ images = []
+ for _i in range(batch_size // batch_gen):
+ z = torch.randn([batch_gen, G.z_dim], device=opts.device)
+ # img = G(z=z, c=next(c_iter), truncation_psi=0.1, **opts.G_kwargs)
+ img = G(z=z, c=next(c_iter), **opts.G_kwargs)
+ img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+ images.append(img)
+ images = torch.cat(images)
+ if images.shape[1] == 1:
+ images = images.repeat([1, 3, 1, 1])
+
+ with torch.no_grad():
+ features = detector(images.to(opts.device), **detector_kwargs)
+
+ stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
+ progress.update(stats.num_items)
+ return stats
diff --git a/diffusion-projected-gan/metrics/perceptual_path_length.py b/diffusion-projected-gan/metrics/perceptual_path_length.py
new file mode 100644
index 0000000000000000000000000000000000000000..c68519fea298b076ef317b5ea75e22a77225baaf
--- /dev/null
+++ b/diffusion-projected-gan/metrics/perceptual_path_length.py
@@ -0,0 +1,125 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Perceptual Path Length (PPL) from the paper "A Style-Based Generator
+Architecture for Generative Adversarial Networks". Matches the original
+implementation by Karras et al. at
+https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py"""
+
+import copy
+import numpy as np
+import torch
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+# Spherical interpolation of a batch of vectors.
+def slerp(a, b, t):
+ a = a / a.norm(dim=-1, keepdim=True)
+ b = b / b.norm(dim=-1, keepdim=True)
+ d = (a * b).sum(dim=-1, keepdim=True)
+ p = t * torch.acos(d)
+ c = b - d * a
+ c = c / c.norm(dim=-1, keepdim=True)
+ d = a * torch.cos(p) + c * torch.sin(p)
+ d = d / d.norm(dim=-1, keepdim=True)
+ return d
+
+#----------------------------------------------------------------------------
+
+class PPLSampler(torch.nn.Module):
+ def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16):
+ assert space in ['z', 'w']
+ assert sampling in ['full', 'end']
+ super().__init__()
+ self.G = copy.deepcopy(G)
+ self.G_kwargs = G_kwargs
+ self.epsilon = epsilon
+ self.space = space
+ self.sampling = sampling
+ self.crop = crop
+ self.vgg16 = copy.deepcopy(vgg16)
+
+ def forward(self, c):
+ # Generate random latents and interpolation t-values.
+ t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0)
+ z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2)
+
+ # Interpolate in W or Z.
+ if self.space == 'w':
+ w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2)
+ wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2))
+ wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon)
+ else: # space == 'z'
+ zt0 = slerp(z0, z1, t.unsqueeze(1))
+ zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon)
+ wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2)
+
+ # Randomize noise buffers.
+ for name, buf in self.G.named_buffers():
+ if name.endswith('.noise_const'):
+ buf.copy_(torch.randn_like(buf))
+
+ # Generate images.
+ img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs)
+
+ # Center crop.
+ if self.crop:
+ assert img.shape[2] == img.shape[3]
+ c = img.shape[2] // 8
+ img = img[:, :, c*3 : c*7, c*2 : c*6]
+
+ # Downsample to 256x256.
+ factor = self.G.img_resolution // 256
+ if factor > 1:
+ img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5])
+
+ # Scale dynamic range from [-1,1] to [0,255].
+ img = (img + 1) * (255 / 2)
+ if self.G.img_channels == 1:
+ img = img.repeat([1, 3, 1, 1])
+
+ # Evaluate differential LPIPS.
+ lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2)
+ dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2
+ return dist
+
+#----------------------------------------------------------------------------
+
+def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size):
+ vgg16_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
+ vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose)
+
+ # Setup sampler and labels.
+ sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16)
+ sampler.eval().requires_grad_(False).to(opts.device)
+ c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size)
+
+ # Sampling loop.
+ dist = []
+ progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples)
+ for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
+ progress.update(batch_start)
+ x = sampler(next(c_iter))
+ for src in range(opts.num_gpus):
+ y = x.clone()
+ if opts.num_gpus > 1:
+ torch.distributed.broadcast(y, src=src)
+ dist.append(y)
+ progress.update(num_samples)
+
+ # Compute PPL.
+ if opts.rank != 0:
+ return float('nan')
+ dist = torch.cat(dist)[:num_samples].cpu().numpy()
+ lo = np.percentile(dist, 1, interpolation='lower')
+ hi = np.percentile(dist, 99, interpolation='higher')
+ ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean()
+ return float(ppl)
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-projected-gan/metrics/precision_recall.py b/diffusion-projected-gan/metrics/precision_recall.py
new file mode 100644
index 0000000000000000000000000000000000000000..120ef801ba488ae5288d14cbee49b566492b6695
--- /dev/null
+++ b/diffusion-projected-gan/metrics/precision_recall.py
@@ -0,0 +1,62 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Precision/Recall (PR) from the paper "Improved Precision and Recall
+Metric for Assessing Generative Models". Matches the original implementation
+by Kynkaanniemi et al. at
+https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py"""
+
+import torch
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size):
+ assert 0 <= rank < num_gpus
+ num_cols = col_features.shape[0]
+ num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus
+ col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches)
+ dist_batches = []
+ for col_batch in col_batches[rank :: num_gpus]:
+ dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0]
+ for src in range(num_gpus):
+ dist_broadcast = dist_batch.clone()
+ if num_gpus > 1:
+ torch.distributed.broadcast(dist_broadcast, src=src)
+ dist_batches.append(dist_broadcast.cpu() if rank == 0 else None)
+ return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None
+
+#----------------------------------------------------------------------------
+
+def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size):
+ detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
+ detector_kwargs = dict(return_features=True)
+
+ real_features = metric_utils.compute_feature_stats_for_dataset(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device)
+
+ gen_features = metric_utils.compute_feature_stats_for_generator(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device)
+
+ results = dict()
+ for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]:
+ kth = []
+ for manifold_batch in manifold.split(row_batch_size):
+ dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
+ kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None)
+ kth = torch.cat(kth) if opts.rank == 0 else None
+ pred = []
+ for probes_batch in probes.split(row_batch_size):
+ dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
+ pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None)
+ results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan')
+ return results['precision'], results['recall']
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-projected-gan/pg_modules/blocks.py b/diffusion-projected-gan/pg_modules/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..78bd113bac1cd6486ede92b1ae8d5adfb678eb81
--- /dev/null
+++ b/diffusion-projected-gan/pg_modules/blocks.py
@@ -0,0 +1,325 @@
+import functools
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils import spectral_norm
+
+
+### single layers
+
+
+def conv2d(*args, **kwargs):
+ return spectral_norm(nn.Conv2d(*args, **kwargs))
+
+
+def convTranspose2d(*args, **kwargs):
+ return spectral_norm(nn.ConvTranspose2d(*args, **kwargs))
+
+
+def embedding(*args, **kwargs):
+ return spectral_norm(nn.Embedding(*args, **kwargs))
+
+
+def linear(*args, **kwargs):
+ return spectral_norm(nn.Linear(*args, **kwargs))
+
+
+def NormLayer(c, mode='batch'):
+ if mode == 'group':
+ return nn.GroupNorm(c//2, c)
+ elif mode == 'batch':
+ return nn.BatchNorm2d(c)
+
+
+### Activations
+
+
+class GLU(nn.Module):
+ def forward(self, x):
+ nc = x.size(1)
+ assert nc % 2 == 0, 'channels dont divide 2!'
+ nc = int(nc/2)
+ return x[:, :nc] * torch.sigmoid(x[:, nc:])
+
+
+class Swish(nn.Module):
+ def forward(self, feat):
+ return feat * torch.sigmoid(feat)
+
+
+### Upblocks
+
+
+class InitLayer(nn.Module):
+ def __init__(self, nz, channel, sz=4):
+ super().__init__()
+
+ self.init = nn.Sequential(
+ convTranspose2d(nz, channel*2, sz, 1, 0, bias=False),
+ NormLayer(channel*2),
+ GLU(),
+ )
+
+ def forward(self, noise):
+ noise = noise.view(noise.shape[0], -1, 1, 1)
+ return self.init(noise)
+
+
+def UpBlockSmall(in_planes, out_planes):
+ block = nn.Sequential(
+ nn.Upsample(scale_factor=2, mode='nearest'),
+ conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False),
+ NormLayer(out_planes*2), GLU())
+ return block
+
+
+class UpBlockSmallCond(nn.Module):
+ def __init__(self, in_planes, out_planes, z_dim):
+ super().__init__()
+ self.in_planes = in_planes
+ self.out_planes = out_planes
+ self.up = nn.Upsample(scale_factor=2, mode='nearest')
+ self.conv = conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False)
+
+ which_bn = functools.partial(CCBN, which_linear=linear, input_size=z_dim)
+ self.bn = which_bn(2*out_planes)
+ self.act = GLU()
+
+ def forward(self, x, c):
+ x = self.up(x)
+ x = self.conv(x)
+ x = self.bn(x, c)
+ x = self.act(x)
+ return x
+
+
+def UpBlockBig(in_planes, out_planes):
+ block = nn.Sequential(
+ nn.Upsample(scale_factor=2, mode='nearest'),
+ conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False),
+ NoiseInjection(),
+ NormLayer(out_planes*2), GLU(),
+ conv2d(out_planes, out_planes*2, 3, 1, 1, bias=False),
+ NoiseInjection(),
+ NormLayer(out_planes*2), GLU()
+ )
+ return block
+
+
+class UpBlockBigCond(nn.Module):
+ def __init__(self, in_planes, out_planes, z_dim):
+ super().__init__()
+ self.in_planes = in_planes
+ self.out_planes = out_planes
+ self.up = nn.Upsample(scale_factor=2, mode='nearest')
+ self.conv1 = conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False)
+ self.conv2 = conv2d(out_planes, out_planes*2, 3, 1, 1, bias=False)
+
+ which_bn = functools.partial(CCBN, which_linear=linear, input_size=z_dim)
+ self.bn1 = which_bn(2*out_planes)
+ self.bn2 = which_bn(2*out_planes)
+ self.act = GLU()
+ self.noise = NoiseInjection()
+
+ def forward(self, x, c):
+ # block 1
+ x = self.up(x)
+ x = self.conv1(x)
+ x = self.noise(x)
+ x = self.bn1(x, c)
+ x = self.act(x)
+
+ # block 2
+ x = self.conv2(x)
+ x = self.noise(x)
+ x = self.bn2(x, c)
+ x = self.act(x)
+
+ return x
+
+
+class SEBlock(nn.Module):
+ def __init__(self, ch_in, ch_out):
+ super().__init__()
+ self.main = nn.Sequential(
+ nn.AdaptiveAvgPool2d(4),
+ conv2d(ch_in, ch_out, 4, 1, 0, bias=False),
+ Swish(),
+ conv2d(ch_out, ch_out, 1, 1, 0, bias=False),
+ nn.Sigmoid(),
+ )
+
+ def forward(self, feat_small, feat_big):
+ return feat_big * self.main(feat_small)
+
+
+### Downblocks
+
+
+class SeparableConv2d(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, bias=False):
+ super(SeparableConv2d, self).__init__()
+ self.depthwise = conv2d(in_channels, in_channels, kernel_size=kernel_size,
+ groups=in_channels, bias=bias, padding=1)
+ self.pointwise = conv2d(in_channels, out_channels,
+ kernel_size=1, bias=bias)
+
+ def forward(self, x):
+ out = self.depthwise(x)
+ out = self.pointwise(out)
+ return out
+
+
+class DownBlock(nn.Module):
+ def __init__(self, in_planes, out_planes, separable=False):
+ super().__init__()
+ if not separable:
+ self.main = nn.Sequential(
+ conv2d(in_planes, out_planes, 4, 2, 1),
+ NormLayer(out_planes),
+ nn.LeakyReLU(0.2, inplace=True),
+ )
+ else:
+ self.main = nn.Sequential(
+ SeparableConv2d(in_planes, out_planes, 3),
+ NormLayer(out_planes),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.AvgPool2d(2, 2),
+ )
+
+ def forward(self, feat):
+ return self.main(feat)
+
+
+class DownBlockPatch(nn.Module):
+ def __init__(self, in_planes, out_planes, separable=False):
+ super().__init__()
+ self.main = nn.Sequential(
+ DownBlock(in_planes, out_planes, separable),
+ conv2d(out_planes, out_planes, 1, 1, 0, bias=False),
+ NormLayer(out_planes),
+ nn.LeakyReLU(0.2, inplace=True),
+ )
+
+ def forward(self, feat):
+ return self.main(feat)
+
+
+### CSM
+
+
+class ResidualConvUnit(nn.Module):
+ def __init__(self, cin, activation, bn):
+ super().__init__()
+ self.conv = nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1, bias=True)
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ return self.skip_add.add(self.conv(x), x)
+
+
+class FeatureFusionBlock(nn.Module):
+ def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, lowest=False):
+ super().__init__()
+
+ self.deconv = deconv
+ self.align_corners = align_corners
+
+ self.expand = expand
+ out_features = features
+ if self.expand==True:
+ out_features = features//2
+
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, *xs):
+ output = xs[0]
+
+ if len(xs) == 2:
+ output = self.skip_add.add(output, xs[1])
+
+ output = nn.functional.interpolate(
+ output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
+ )
+
+ output = self.out_conv(output)
+
+ return output
+
+
+### Misc
+
+
+class NoiseInjection(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.weight = nn.Parameter(torch.zeros(1), requires_grad=True)
+
+ def forward(self, feat, noise=None):
+ if noise is None:
+ batch, _, height, width = feat.shape
+ noise = torch.randn(batch, 1, height, width).to(feat.device)
+
+ return feat + self.weight * noise
+
+
+class CCBN(nn.Module):
+ ''' conditional batchnorm '''
+ def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1):
+ super().__init__()
+ self.output_size, self.input_size = output_size, input_size
+
+ # Prepare gain and bias layers
+ self.gain = which_linear(input_size, output_size)
+ self.bias = which_linear(input_size, output_size)
+
+ # epsilon to avoid dividing by 0
+ self.eps = eps
+ # Momentum
+ self.momentum = momentum
+
+ self.register_buffer('stored_mean', torch.zeros(output_size))
+ self.register_buffer('stored_var', torch.ones(output_size))
+
+ def forward(self, x, y):
+ # Calculate class-conditional gains and biases
+ gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
+ bias = self.bias(y).view(y.size(0), -1, 1, 1)
+ out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
+ self.training, 0.1, self.eps)
+ return out * gain + bias
+
+
+class Interpolate(nn.Module):
+ """Interpolation module."""
+
+ def __init__(self, size, mode='bilinear', align_corners=False):
+ """Init.
+ Args:
+ scale_factor (float): scaling
+ mode (str): interpolation mode
+ """
+ super(Interpolate, self).__init__()
+
+ self.interp = nn.functional.interpolate
+ self.size = size
+ self.mode = mode
+ self.align_corners = align_corners
+
+ def forward(self, x):
+ """Forward pass.
+ Args:
+ x (tensor): input
+ Returns:
+ tensor: interpolated data
+ """
+
+ x = self.interp(
+ x,
+ size=self.size,
+ mode=self.mode,
+ align_corners=self.align_corners,
+ )
+
+ return x
diff --git a/diffusion-projected-gan/pg_modules/diffaug.py b/diffusion-projected-gan/pg_modules/diffaug.py
new file mode 100644
index 0000000000000000000000000000000000000000..54020be64733f1a92e454b2959f6297f7dcb9e22
--- /dev/null
+++ b/diffusion-projected-gan/pg_modules/diffaug.py
@@ -0,0 +1,76 @@
+# Differentiable Augmentation for Data-Efficient GAN Training
+# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
+# https://arxiv.org/pdf/2006.10738
+
+import torch
+import torch.nn.functional as F
+
+
+def DiffAugment(x, policy='', channels_first=True):
+ if policy:
+ if not channels_first:
+ x = x.permute(0, 3, 1, 2)
+ for p in policy.split(','):
+ for f in AUGMENT_FNS[p]:
+ x = f(x)
+ if not channels_first:
+ x = x.permute(0, 2, 3, 1)
+ x = x.contiguous()
+ return x
+
+
+def rand_brightness(x):
+ x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
+ return x
+
+
+def rand_saturation(x):
+ x_mean = x.mean(dim=1, keepdim=True)
+ x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
+ return x
+
+
+def rand_contrast(x):
+ x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
+ x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
+ return x
+
+
+def rand_translation(x, ratio=0.125):
+ shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
+ translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
+ translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
+ grid_batch, grid_x, grid_y = torch.meshgrid(
+ torch.arange(x.size(0), dtype=torch.long, device=x.device),
+ torch.arange(x.size(2), dtype=torch.long, device=x.device),
+ torch.arange(x.size(3), dtype=torch.long, device=x.device),
+ )
+ grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
+ grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
+ x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
+ x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
+ return x
+
+
+def rand_cutout(x, ratio=0.2):
+ cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
+ offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
+ offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
+ grid_batch, grid_x, grid_y = torch.meshgrid(
+ torch.arange(x.size(0), dtype=torch.long, device=x.device),
+ torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
+ torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
+ )
+ grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
+ grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
+ mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
+ mask[grid_batch, grid_x, grid_y] = 0
+ x = x * mask.unsqueeze(1)
+ return x
+
+
+AUGMENT_FNS = {
+ 'color': [rand_brightness, rand_saturation, rand_contrast],
+ 'translation': [rand_translation],
+ 'cutout': [rand_cutout],
+}
diff --git a/diffusion-projected-gan/pg_modules/diffusion.py b/diffusion-projected-gan/pg_modules/diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..0dcc879895239fd98129e4f47f19cf4c2c1878e2
--- /dev/null
+++ b/diffusion-projected-gan/pg_modules/diffusion.py
@@ -0,0 +1,141 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import numpy as np
+import scipy.signal
+import torch
+from torch_utils import persistence
+from torch_utils import misc
+from torch_utils.ops import upfirdn2d
+from torch_utils.ops import grid_sample_gradfix
+from torch_utils.ops import conv2d_gradfix
+
+#----------------------------------------------------------------------------
+# Helpers for doing diffusion process.
+
+
+def get_beta_schedule(beta_schedule, beta_start, beta_end, num_diffusion_timesteps):
+ def sigmoid(x):
+ return 1 / (np.exp(-x) + 1)
+
+ def continuous_t_beta(t, T):
+ b_max = 5.
+ b_min = 0.1
+ alpha = np.exp(-b_min / T - 0.5 * (b_max - b_min) * (2 * t - 1) / T ** 2)
+ return 1 - alpha
+
+ if beta_schedule == "continuous_t":
+ betas = continuous_t_beta(np.arange(1, num_diffusion_timesteps+1), num_diffusion_timesteps)
+ elif beta_schedule == "quad":
+ betas = (
+ np.linspace(
+ beta_start ** 0.5,
+ beta_end ** 0.5,
+ num_diffusion_timesteps,
+ dtype=np.float64,
+ )
+ ** 2
+ )
+ elif beta_schedule == "linear":
+ betas = np.linspace(
+ beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
+ )
+ elif beta_schedule == "const":
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
+ elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
+ betas = 1.0 / np.linspace(
+ num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
+ )
+ elif beta_schedule == "sigmoid":
+ betas = np.linspace(-6, 6, num_diffusion_timesteps)
+ betas = sigmoid(betas) * (beta_end - beta_start) + beta_start
+ else:
+ raise NotImplementedError(beta_schedule)
+ assert betas.shape == (num_diffusion_timesteps,)
+ return betas
+
+
+def q_sample(x_0, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, t, noise_type='gauss', noise_std=1.0):
+ batch_size, num_channels, _, _ = x_0.shape
+ if noise_type == 'gauss':
+ noise = torch.randn_like(x_0, device=x_0.device) * noise_std
+ elif noise_type == 'bernoulli':
+ noise = (torch.bernoulli(torch.ones_like(x_0) * 0.5) * 2 - 1.) * noise_std
+ else:
+ raise NotImplementedError(noise_type)
+ alphas_t_sqrt = alphas_bar_sqrt[t].view(batch_size, num_channels, 1, 1)
+ one_minus_alphas_bar_t_sqrt = one_minus_alphas_bar_sqrt[t].view(batch_size, num_channels, 1, 1)
+ x_t = alphas_t_sqrt * x_0 + one_minus_alphas_bar_t_sqrt * noise
+ return x_t
+
+
+@persistence.persistent_class
+class Diffusion(torch.nn.Module):
+ def __init__(self,
+ beta_schedule='linear', beta_start=1e-4, beta_end=1e-2,
+ t_min=5, t_max=500, noise_std=0.5,
+ ):
+ super().__init__()
+ self.p = 0.0 # Overall multiplier for augmentation probability.
+ self.noise_type = self.base_noise_type = 'gauss'
+ self.base_schedule = beta_schedule
+ self.beta_start = beta_start
+ self.beta_end = beta_end
+ self.t_min = t_min
+ self.t_max = t_max
+ self.t_add = t_max - t_min
+ self.update_T()
+
+ # Image-space corruptions.
+ self.noise_std = float(noise_std) # Standard deviation of additive RGB noise.
+
+ def set_diffusion_process(self, t, beta_schedule):
+
+ betas = get_beta_schedule(
+ beta_schedule=beta_schedule,
+ beta_start=self.beta_start,
+ beta_end=self.beta_end,
+ num_diffusion_timesteps=t,
+ )
+
+ betas = self.betas = torch.from_numpy(betas).float()
+ self.num_timesteps = betas.shape[0]
+
+ alphas = self.alphas = 1.0 - betas
+ alphas_cumprod = torch.cat([torch.tensor([1.]), alphas.cumprod(dim=0)])
+ self.alphas_bar_sqrt = torch.sqrt(alphas_cumprod)
+ self.one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_cumprod)
+
+ def update_T(self):
+ t_adjust = round(self.p * self.t_add)
+ t = np.clip(int(self.t_min + t_adjust), a_min=self.t_min, a_max=self.t_max)
+ self.set_diffusion_process(t, "linear")
+
+ # sampling t
+ self.t_epl = np.zeros(64, dtype=np.int)
+ diffusion_ind = min(round(self.p * 64), 48) # 48
+ prob_t = np.arange(t) / np.arange(t).sum()
+ t_diffusion = np.random.choice(np.arange(1, t+1), size=diffusion_ind, p=prob_t)
+ self.t_epl[:diffusion_ind] = t_diffusion
+
+ def forward(self, x_0, noise_std=1.0):
+ assert isinstance(x_0, torch.Tensor) and x_0.ndim == 4
+ batch_size, num_channels, height, width = x_0.shape
+ device = x_0.device
+
+ alphas_bar_sqrt = self.alphas_bar_sqrt.to(device)
+ one_minus_alphas_bar_sqrt = self.one_minus_alphas_bar_sqrt.to(device)
+
+ t = torch.from_numpy(np.random.choice(self.t_epl, size=batch_size * num_channels, replace=True)).to(device)
+
+ x_t = q_sample(x_0, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, t,
+ noise_type=self.noise_type,
+ noise_std=noise_std)
+ return x_t
+
+#----------------------------------------------------------------------------
\ No newline at end of file
diff --git a/diffusion-projected-gan/pg_modules/discriminator.py b/diffusion-projected-gan/pg_modules/discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb57eb1e84505ed2285656ff7fa2d4fff607a69b
--- /dev/null
+++ b/diffusion-projected-gan/pg_modules/discriminator.py
@@ -0,0 +1,186 @@
+from functools import partial
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from pg_modules.blocks import DownBlock, DownBlockPatch, conv2d
+from pg_modules.projector import F_RandomProj
+from pg_modules.diffaug import DiffAugment
+
+
+class SingleDisc(nn.Module):
+ def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, separable=False, patch=False):
+ super().__init__()
+ channel_dict = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 64,
+ 256: 32, 512: 16, 1024: 8}
+
+ # interpolate for start sz that are not powers of two
+ if start_sz not in channel_dict.keys():
+ sizes = np.array(list(channel_dict.keys()))
+ start_sz = sizes[np.argmin(abs(sizes - start_sz))]
+ self.start_sz = start_sz
+
+ # if given ndf, allocate all layers with the same ndf
+ if ndf is None:
+ nfc = channel_dict
+ else:
+ nfc = {k: ndf for k, v in channel_dict.items()}
+
+ # for feature map discriminators with nfc not in channel_dict
+ # this is the case for the pretrained backbone (midas.pretrained)
+ if nc is not None and head is None:
+ nfc[start_sz] = nc
+
+ layers = []
+
+ # Head if the initial input is the full modality
+ if head:
+ layers += [conv2d(nc, nfc[256], 3, 1, 1, bias=False),
+ nn.LeakyReLU(0.2, inplace=True)]
+
+ # Down Blocks
+ DB = partial(DownBlockPatch, separable=separable) if patch else partial(DownBlock, separable=separable)
+ while start_sz > end_sz:
+ layers.append(DB(nfc[start_sz], nfc[start_sz//2]))
+ start_sz = start_sz // 2
+
+ layers.append(conv2d(nfc[end_sz], 1, 4, 1, 0, bias=False))
+ self.main = nn.Sequential(*layers)
+
+ def forward(self, x, c):
+ return self.main(x)
+
+
+class SingleDiscCond(nn.Module):
+ def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, separable=False, patch=False, c_dim=1000, cmap_dim=64, embedding_dim=128):
+ super().__init__()
+ self.cmap_dim = cmap_dim
+
+ # midas channels
+ channel_dict = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 64,
+ 256: 32, 512: 16, 1024: 8}
+
+ # interpolate for start sz that are not powers of two
+ if start_sz not in channel_dict.keys():
+ sizes = np.array(list(channel_dict.keys()))
+ start_sz = sizes[np.argmin(abs(sizes - start_sz))]
+ self.start_sz = start_sz
+
+ # if given ndf, allocate all layers with the same ndf
+ if ndf is None:
+ nfc = channel_dict
+ else:
+ nfc = {k: ndf for k, v in channel_dict.items()}
+
+ # for feature map discriminators with nfc not in channel_dict
+ # this is the case for the pretrained backbone (midas.pretrained)
+ if nc is not None and head is None:
+ nfc[start_sz] = nc
+
+ layers = []
+
+ # Head if the initial input is the full modality
+ if head:
+ layers += [conv2d(nc, nfc[256], 3, 1, 1, bias=False),
+ nn.LeakyReLU(0.2, inplace=True)]
+
+ # Down Blocks
+ DB = partial(DownBlockPatch, separable=separable) if patch else partial(DownBlock, separable=separable)
+ while start_sz > end_sz:
+ layers.append(DB(nfc[start_sz], nfc[start_sz//2]))
+ start_sz = start_sz // 2
+ self.main = nn.Sequential(*layers)
+
+ # additions for conditioning on class information
+ self.cls = conv2d(nfc[end_sz], self.cmap_dim, 4, 1, 0, bias=False)
+ self.embed = nn.Embedding(num_embeddings=c_dim, embedding_dim=embedding_dim)
+ self.embed_proj = nn.Sequential(
+ nn.Linear(self.embed.embedding_dim, self.cmap_dim),
+ nn.LeakyReLU(0.2, inplace=True),
+ )
+
+ def forward(self, x, c):
+ h = self.main(x)
+ out = self.cls(h)
+
+ # conditioning via projection
+ cmap = self.embed_proj(self.embed(c.argmax(1))).unsqueeze(-1).unsqueeze(-1)
+ out = (out * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
+
+ return out
+
+
+class MultiScaleD(nn.Module):
+ def __init__(
+ self,
+ channels,
+ resolutions,
+ num_discs=1,
+ proj_type=2, # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing
+ cond=0,
+ separable=False,
+ patch=False,
+ **kwargs,
+ ):
+ super().__init__()
+
+ assert num_discs in [1, 2, 3, 4]
+
+ # the first disc is on the lowest level of the backbone
+ self.disc_in_channels = channels[:num_discs]
+ self.disc_in_res = resolutions[:num_discs]
+ Disc = SingleDiscCond if cond else SingleDisc
+
+ mini_discs = []
+ for i, (cin, res) in enumerate(zip(self.disc_in_channels, self.disc_in_res)):
+ start_sz = res if not patch else 16
+ mini_discs += [str(i), Disc(nc=cin, start_sz=start_sz, end_sz=8, separable=separable, patch=patch)],
+ self.mini_discs = nn.ModuleDict(mini_discs)
+
+ def forward(self, features, c):
+ all_logits = []
+ for k, disc in self.mini_discs.items():
+ all_logits.append(disc(features[k], c).view(features[k].size(0), -1))
+
+ all_logits = torch.cat(all_logits, dim=1)
+ return all_logits
+
+
+class ProjectedDiscriminator(torch.nn.Module):
+ def __init__(
+ self,
+ diffaug=True,
+ interp224=True,
+ backbone_kwargs={},
+ **kwargs
+ ):
+ super().__init__()
+ self.diffaug = diffaug
+ self.interp224 = interp224
+ self.feature_network = F_RandomProj(**backbone_kwargs)
+ self.discriminator = MultiScaleD(
+ channels=self.feature_network.CHANNELS,
+ resolutions=self.feature_network.RESOLUTIONS,
+ **backbone_kwargs,
+ )
+
+ def train(self, mode=True):
+ self.feature_network = self.feature_network.train(False)
+ self.discriminator = self.discriminator.train(mode)
+ return self
+
+ def eval(self):
+ return self.train(False)
+
+ def forward(self, x, c):
+ if self.diffaug:
+ x = DiffAugment(x, policy='color,translation,cutout')
+
+ if self.interp224:
+ x = F.interpolate(x, 224, mode='bilinear', align_corners=False)
+
+ features = self.feature_network(x)
+ logits = self.discriminator(features, c)
+
+ return logits
\ No newline at end of file
diff --git a/diffusion-projected-gan/pg_modules/networks_fastgan.py b/diffusion-projected-gan/pg_modules/networks_fastgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..5768edc08881e794989eb5ce0d63cc15dfb8ebfa
--- /dev/null
+++ b/diffusion-projected-gan/pg_modules/networks_fastgan.py
@@ -0,0 +1,190 @@
+# original implementation: https://github.com/odegeasslbc/FastGAN-pytorch/blob/main/models.py
+#
+# modified by Axel Sauer for "Projected GANs Converge Faster"
+#
+import torch.nn as nn
+from pg_modules.blocks import (InitLayer, UpBlockBig, UpBlockBigCond, UpBlockSmall, UpBlockSmallCond, SEBlock, conv2d)
+
+
+def normalize_second_moment(x, dim=1, eps=1e-8):
+ return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt()
+
+
+class DummyMapping(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, z, c, **kwargs):
+ return z.unsqueeze(1) # to fit the StyleGAN API
+
+
+class FastganSynthesis(nn.Module):
+ def __init__(self, ngf=128, z_dim=256, nc=3, img_resolution=256, lite=False):
+ super().__init__()
+ self.img_resolution = img_resolution
+ self.z_dim = z_dim
+
+ # channel multiplier
+ nfc_multi = {2: 16, 4:16, 8:8, 16:4, 32:2, 64:2, 128:1, 256:0.5,
+ 512:0.25, 1024:0.125}
+ nfc = {}
+ for k, v in nfc_multi.items():
+ nfc[k] = int(v*ngf)
+
+ # layers
+ self.init = InitLayer(z_dim, channel=nfc[2], sz=4)
+
+ UpBlock = UpBlockSmall if lite else UpBlockBig
+
+ self.feat_8 = UpBlock(nfc[4], nfc[8])
+ self.feat_16 = UpBlock(nfc[8], nfc[16])
+ self.feat_32 = UpBlock(nfc[16], nfc[32])
+ self.feat_64 = UpBlock(nfc[32], nfc[64])
+ self.feat_128 = UpBlock(nfc[64], nfc[128])
+ self.feat_256 = UpBlock(nfc[128], nfc[256])
+
+ self.se_64 = SEBlock(nfc[4], nfc[64])
+ self.se_128 = SEBlock(nfc[8], nfc[128])
+ self.se_256 = SEBlock(nfc[16], nfc[256])
+
+ self.to_big = conv2d(nfc[img_resolution], nc, 3, 1, 1, bias=True)
+
+ if img_resolution > 256:
+ self.feat_512 = UpBlock(nfc[256], nfc[512])
+ self.se_512 = SEBlock(nfc[32], nfc[512])
+ if img_resolution > 512:
+ self.feat_1024 = UpBlock(nfc[512], nfc[1024])
+
+ def forward(self, input, c, **kwargs):
+ # map noise to hypersphere as in "Progressive Growing of GANS"
+ input = normalize_second_moment(input[:, 0])
+
+ feat_4 = self.init(input)
+ feat_8 = self.feat_8(feat_4)
+ feat_16 = self.feat_16(feat_8)
+ feat_32 = self.feat_32(feat_16)
+
+ if self.img_resolution == 32:
+ return self.to_big(feat_32)
+
+ feat_64 = self.se_64(feat_4, self.feat_64(feat_32))
+ if self.img_resolution == 64:
+ return self.to_big(feat_64)
+
+ feat_128 = self.se_128(feat_8, self.feat_128(feat_64))
+ if self.img_resolution >= 128:
+ feat_last = feat_128
+
+ if self.img_resolution >= 256:
+ feat_last = self.se_256(feat_16, self.feat_256(feat_last))
+
+ if self.img_resolution >= 512:
+ feat_last = self.se_512(feat_32, self.feat_512(feat_last))
+
+ if self.img_resolution >= 1024:
+ feat_last = self.feat_1024(feat_last)
+
+ return self.to_big(feat_last)
+
+
+class FastganSynthesisCond(nn.Module):
+ def __init__(self, ngf=64, z_dim=256, nc=3, img_resolution=256, num_classes=1000, lite=False):
+ super().__init__()
+
+ self.z_dim = z_dim
+ nfc_multi = {2: 16, 4:16, 8:8, 16:4, 32:2, 64:2, 128:1, 256:0.5,
+ 512:0.25, 1024:0.125, 2048:0.125}
+ nfc = {}
+ for k, v in nfc_multi.items():
+ nfc[k] = int(v*ngf)
+
+ self.img_resolution = img_resolution
+
+ self.init = InitLayer(z_dim, channel=nfc[2], sz=4)
+
+ UpBlock = UpBlockSmallCond if lite else UpBlockBigCond
+
+ self.feat_8 = UpBlock(nfc[4], nfc[8], z_dim)
+ self.feat_16 = UpBlock(nfc[8], nfc[16], z_dim)
+ self.feat_32 = UpBlock(nfc[16], nfc[32], z_dim)
+ self.feat_64 = UpBlock(nfc[32], nfc[64], z_dim)
+ self.feat_128 = UpBlock(nfc[64], nfc[128], z_dim)
+ self.feat_256 = UpBlock(nfc[128], nfc[256], z_dim)
+
+ self.se_64 = SEBlock(nfc[4], nfc[64])
+ self.se_128 = SEBlock(nfc[8], nfc[128])
+ self.se_256 = SEBlock(nfc[16], nfc[256])
+
+ self.to_big = conv2d(nfc[img_resolution], nc, 3, 1, 1, bias=True)
+
+ if img_resolution > 256:
+ self.feat_512 = UpBlock(nfc[256], nfc[512])
+ self.se_512 = SEBlock(nfc[32], nfc[512])
+ if img_resolution > 512:
+ self.feat_1024 = UpBlock(nfc[512], nfc[1024])
+
+ self.embed = nn.Embedding(num_classes, z_dim)
+
+ def forward(self, input, c, update_emas=False):
+ c = self.embed(c.argmax(1))
+
+ # map noise to hypersphere as in "Progressive Growing of GANS"
+ input = normalize_second_moment(input[:, 0])
+
+ feat_4 = self.init(input)
+ feat_8 = self.feat_8(feat_4, c)
+ feat_16 = self.feat_16(feat_8, c)
+ feat_32 = self.feat_32(feat_16, c)
+
+ if self.img_resolution == 32:
+ return self.to_big(feat_32)
+
+ feat_64 = self.se_64(feat_4, self.feat_64(feat_32, c))
+ if self.img_resolution == 64:
+ return self.to_big(feat_64)
+
+ feat_128 = self.se_128(feat_8, self.feat_128(feat_64, c))
+ if self.img_resolution >= 128:
+ feat_last = feat_128
+
+ if self.img_resolution >= 256:
+ feat_last = self.se_256(feat_16, self.feat_256(feat_last, c))
+
+ if self.img_resolution >= 512:
+ feat_last = self.se_512(feat_32, self.feat_512(feat_last, c))
+
+ if self.img_resolution >= 1024:
+ feat_last = self.feat_1024(feat_last, c)
+
+ return self.to_big(feat_last)
+
+
+class Generator(nn.Module):
+ def __init__(
+ self,
+ z_dim=256,
+ c_dim=0,
+ w_dim=0,
+ img_resolution=256,
+ img_channels=3,
+ ngf=128,
+ cond=0,
+ mapping_kwargs={},
+ synthesis_kwargs={}
+ ):
+ super().__init__()
+ self.z_dim = z_dim
+ self.c_dim = c_dim
+ self.w_dim = w_dim
+ self.img_resolution = img_resolution
+ self.img_channels = img_channels
+
+ # Mapping and Synthesis Networks
+ self.mapping = DummyMapping() # to fit the StyleGAN API
+ Synthesis = FastganSynthesisCond if cond else FastganSynthesis
+ self.synthesis = Synthesis(ngf=ngf, z_dim=z_dim, nc=img_channels, img_resolution=img_resolution, **synthesis_kwargs)
+
+ def forward(self, z, c, **kwargs):
+ w = self.mapping(z, c)
+ img = self.synthesis(w, c)
+ return img
diff --git a/diffusion-projected-gan/pg_modules/networks_stylegan2.py b/diffusion-projected-gan/pg_modules/networks_stylegan2.py
new file mode 100644
index 0000000000000000000000000000000000000000..c554a2fda2ab39e881a8248555e6c93ba939ffa5
--- /dev/null
+++ b/diffusion-projected-gan/pg_modules/networks_stylegan2.py
@@ -0,0 +1,537 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+#
+# modified by Axel Sauer for "Projected GANs Converge Faster"
+#
+import numpy as np
+import torch
+from torch_utils import misc
+from torch_utils import persistence
+from torch_utils.ops import conv2d_resample
+from torch_utils.ops import upfirdn2d
+from torch_utils.ops import bias_act
+from torch_utils.ops import fma
+
+
+@misc.profiled_function
+def normalize_2nd_moment(x, dim=1, eps=1e-8):
+ return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt()
+
+
+@misc.profiled_function
+def modulated_conv2d(
+ x, # Input tensor of shape [batch_size, in_channels, in_height, in_width].
+ weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
+ styles, # Modulation coefficients of shape [batch_size, in_channels].
+ noise = None, # Optional noise tensor to add to the output activations.
+ up = 1, # Integer upsampling factor.
+ down = 1, # Integer downsampling factor.
+ padding = 0, # Padding with respect to the upsampled image.
+ resample_filter = None, # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
+ demodulate = True, # Apply weight demodulation?
+ flip_weight = True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
+ fused_modconv = True, # Perform modulation, convolution, and demodulation as a single fused operation?
+):
+ batch_size = x.shape[0]
+ out_channels, in_channels, kh, kw = weight.shape
+ misc.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk]
+ misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW]
+ misc.assert_shape(styles, [batch_size, in_channels]) # [NI]
+
+ # Pre-normalize inputs to avoid FP16 overflow.
+ if x.dtype == torch.float16 and demodulate:
+ weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1,2,3], keepdim=True)) # max_Ikk
+ styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I
+
+ # Calculate per-sample weights and demodulation coefficients.
+ w = None
+ dcoefs = None
+ if demodulate or fused_modconv:
+ w = weight.unsqueeze(0) # [NOIkk]
+ w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk]
+ if demodulate:
+ dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO]
+ if demodulate and fused_modconv:
+ w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk]
+
+ # Execute by scaling the activations before and after the convolution.
+ if not fused_modconv:
+ x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
+ x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight)
+ if demodulate and noise is not None:
+ x = fma.fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype))
+ elif demodulate:
+ x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)
+ elif noise is not None:
+ x = x.add_(noise.to(x.dtype))
+ return x
+
+ # Execute as one fused op using grouped convolution.
+ with misc.suppress_tracer_warnings(): # this value will be treated as a constant
+ batch_size = int(batch_size)
+ misc.assert_shape(x, [batch_size, in_channels, None, None])
+ x = x.reshape(1, -1, *x.shape[2:])
+ w = w.reshape(-1, in_channels, kh, kw)
+ x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight)
+ x = x.reshape(batch_size, -1, *x.shape[2:])
+ if noise is not None:
+ x = x.add_(noise)
+ return x
+
+
+@persistence.persistent_class
+class FullyConnectedLayer(torch.nn.Module):
+ def __init__(self,
+ in_features, # Number of input features.
+ out_features, # Number of output features.
+ bias = True, # Apply additive bias before the activation function?
+ activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
+ lr_multiplier = 1, # Learning rate multiplier.
+ bias_init = 0, # Initial value for the additive bias.
+ ):
+ super().__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.activation = activation
+ self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
+ self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
+ self.weight_gain = lr_multiplier / np.sqrt(in_features)
+ self.bias_gain = lr_multiplier
+
+ def forward(self, x):
+ w = self.weight.to(x.dtype) * self.weight_gain
+ b = self.bias
+ if b is not None:
+ b = b.to(x.dtype)
+ if self.bias_gain != 1:
+ b = b * self.bias_gain
+
+ if self.activation == 'linear' and b is not None:
+ x = torch.addmm(b.unsqueeze(0), x, w.t())
+ else:
+ x = x.matmul(w.t())
+ x = bias_act.bias_act(x, b, act=self.activation)
+ return x
+
+ def extra_repr(self):
+ return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}'
+
+
+@persistence.persistent_class
+class Conv2dLayer(torch.nn.Module):
+ def __init__(self,
+ in_channels, # Number of input channels.
+ out_channels, # Number of output channels.
+ kernel_size, # Width and height of the convolution kernel.
+ bias = True, # Apply additive bias before the activation function?
+ activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
+ up = 1, # Integer upsampling factor.
+ down = 1, # Integer downsampling factor.
+ resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
+ conv_clamp = None, # Clamp the output to +-X, None = disable clamping.
+ channels_last = False, # Expect the input to have memory_format=channels_last?
+ trainable = True, # Update the weights of this layer during training?
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.activation = activation
+ self.up = up
+ self.down = down
+ self.conv_clamp = conv_clamp
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
+ self.padding = kernel_size // 2
+ self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
+ self.act_gain = bias_act.activation_funcs[activation].def_gain
+
+ memory_format = torch.channels_last if channels_last else torch.contiguous_format
+ weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)
+ bias = torch.zeros([out_channels]) if bias else None
+ if trainable:
+ self.weight = torch.nn.Parameter(weight)
+ self.bias = torch.nn.Parameter(bias) if bias is not None else None
+ else:
+ self.register_buffer('weight', weight)
+ if bias is not None:
+ self.register_buffer('bias', bias)
+ else:
+ self.bias = None
+
+ def forward(self, x, gain=1):
+ w = self.weight * self.weight_gain
+ b = self.bias.to(x.dtype) if self.bias is not None else None
+ flip_weight = (self.up == 1) # slightly faster
+ x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=self.resample_filter, up=self.up, down=self.down, padding=self.padding, flip_weight=flip_weight)
+
+ act_gain = self.act_gain * gain
+ act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
+ x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp)
+ return x
+
+ def extra_repr(self):
+ return ' '.join([
+ f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, activation={self.activation:s},',
+ f'up={self.up}, down={self.down}'])
+
+
+@persistence.persistent_class
+class MappingNetwork(torch.nn.Module):
+ def __init__(self,
+ z_dim, # Input latent (Z) dimensionality, 0 = no latent.
+ c_dim, # Conditioning label (C) dimensionality, 0 = no label.
+ w_dim, # Intermediate latent (W) dimensionality.
+ num_ws, # Number of intermediate latents to output, None = do not broadcast.
+ num_layers = 8, # Number of mapping layers.
+ embed_features = None, # Label embedding dimensionality, None = same as w_dim.
+ layer_features = None, # Number of intermediate features in the mapping layers, None = same as w_dim.
+ activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
+ lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers.
+ w_avg_beta = 0.998, # Decay for tracking the moving average of W during training, None = do not track.
+ ):
+ super().__init__()
+ self.z_dim = z_dim
+ self.c_dim = c_dim
+ self.w_dim = w_dim
+ self.num_ws = num_ws
+ self.num_layers = num_layers
+ self.w_avg_beta = w_avg_beta
+
+ if embed_features is None:
+ embed_features = w_dim
+ if c_dim == 0:
+ embed_features = 0
+ if layer_features is None:
+ layer_features = w_dim
+ features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]
+
+ if c_dim > 0:
+ self.embed = FullyConnectedLayer(c_dim, embed_features)
+ for idx in range(num_layers):
+ in_features = features_list[idx]
+ out_features = features_list[idx + 1]
+ layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
+ setattr(self, f'fc{idx}', layer)
+
+ if num_ws is not None and w_avg_beta is not None:
+ self.register_buffer('w_avg', torch.zeros([w_dim]))
+
+ def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False):
+ # Embed, normalize, and concat inputs.
+ x = None
+ with torch.autograd.profiler.record_function('input'):
+ if self.z_dim > 0:
+ misc.assert_shape(z, [None, self.z_dim])
+ x = normalize_2nd_moment(z.to(torch.float32))
+ if self.c_dim > 0:
+ misc.assert_shape(c, [None, self.c_dim])
+ y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
+ x = torch.cat([x, y], dim=1) if x is not None else y
+
+ # Main layers.
+ for idx in range(self.num_layers):
+ layer = getattr(self, f'fc{idx}')
+ x = layer(x)
+
+ # Update moving average of W.
+ if update_emas and self.w_avg_beta is not None:
+ with torch.autograd.profiler.record_function('update_w_avg'):
+ self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
+
+ # Broadcast.
+ if self.num_ws is not None:
+ with torch.autograd.profiler.record_function('broadcast'):
+ x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
+
+ # Apply truncation.
+ if truncation_psi != 1:
+ with torch.autograd.profiler.record_function('truncate'):
+ assert self.w_avg_beta is not None
+ if self.num_ws is None or truncation_cutoff is None:
+ x = self.w_avg.lerp(x, truncation_psi)
+ else:
+ x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
+ return x
+
+ def extra_repr(self):
+ return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}, num_ws={self.num_ws:d}'
+
+
+@persistence.persistent_class
+class SynthesisLayer(torch.nn.Module):
+ def __init__(self,
+ in_channels, # Number of input channels.
+ out_channels, # Number of output channels.
+ w_dim, # Intermediate latent (W) dimensionality.
+ resolution, # Resolution of this layer.
+ kernel_size = 3, # Convolution kernel size.
+ up = 1, # Integer upsampling factor.
+ use_noise = True, # Enable noise input?
+ activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
+ resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
+ conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ channels_last = False, # Use channels_last format for the weights?
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.w_dim = w_dim
+ self.resolution = resolution
+ self.up = up
+ self.use_noise = use_noise
+ self.activation = activation
+ self.conv_clamp = conv_clamp
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
+ self.padding = kernel_size // 2
+ self.act_gain = bias_act.activation_funcs[activation].def_gain
+
+ self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
+ memory_format = torch.channels_last if channels_last else torch.contiguous_format
+ self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
+ if use_noise:
+ self.register_buffer('noise_const', torch.randn([resolution, resolution]))
+ self.noise_strength = torch.nn.Parameter(torch.zeros([]))
+ self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
+
+ def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1):
+ assert noise_mode in ['random', 'const', 'none']
+ in_resolution = self.resolution // self.up
+ misc.assert_shape(x, [None, self.in_channels, in_resolution, in_resolution])
+ styles = self.affine(w)
+
+ noise = None
+ if self.use_noise and noise_mode == 'random':
+ noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength
+ if self.use_noise and noise_mode == 'const':
+ noise = self.noise_const * self.noise_strength
+
+ flip_weight = (self.up == 1) # slightly faster
+ x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up,
+ padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight, fused_modconv=fused_modconv)
+
+ act_gain = self.act_gain * gain
+ act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
+ x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp)
+ return x
+
+ def extra_repr(self):
+ return ' '.join([
+ f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, w_dim={self.w_dim:d},',
+ f'resolution={self.resolution:d}, up={self.up}, activation={self.activation:s}'])
+
+
+@persistence.persistent_class
+class ToRGBLayer(torch.nn.Module):
+ def __init__(self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.w_dim = w_dim
+ self.conv_clamp = conv_clamp
+ self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
+ memory_format = torch.channels_last if channels_last else torch.contiguous_format
+ self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
+ self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
+ self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
+
+ def forward(self, x, w, fused_modconv=True):
+ styles = self.affine(w) * self.weight_gain
+ x = modulated_conv2d(x=x, weight=self.weight, styles=styles, demodulate=False, fused_modconv=fused_modconv)
+ x = bias_act.bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp)
+ return x
+
+ def extra_repr(self):
+ return f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, w_dim={self.w_dim:d}'
+
+
+@persistence.persistent_class
+class SynthesisBlock(torch.nn.Module):
+ def __init__(self,
+ in_channels, # Number of input channels, 0 = first block.
+ out_channels, # Number of output channels.
+ w_dim, # Intermediate latent (W) dimensionality.
+ resolution, # Resolution of this block.
+ img_channels, # Number of output color channels.
+ is_last, # Is this the last block?
+ architecture = 'skip', # Architecture: 'orig', 'skip', 'resnet'.
+ resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
+ conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ use_fp16 = False, # Use FP16 for this block?
+ fp16_channels_last = False, # Use channels-last memory format with FP16?
+ fused_modconv_default = True, # Default value of fused_modconv. 'inference_only' = True for inference, False for training.
+ **layer_kwargs, # Arguments for SynthesisLayer.
+ ):
+ assert architecture in ['orig', 'skip', 'resnet']
+ super().__init__()
+ self.in_channels = in_channels
+ self.w_dim = w_dim
+ self.resolution = resolution
+ self.img_channels = img_channels
+ self.is_last = is_last
+ self.architecture = architecture
+ self.use_fp16 = use_fp16
+ self.channels_last = (use_fp16 and fp16_channels_last)
+ self.fused_modconv_default = fused_modconv_default
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
+ self.num_conv = 0
+ self.num_torgb = 0
+
+ if in_channels == 0:
+ self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution]))
+
+ if in_channels != 0:
+ self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim, resolution=resolution, up=2,
+ resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
+ self.num_conv += 1
+
+ self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim, resolution=resolution,
+ conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
+ self.num_conv += 1
+
+ if is_last or architecture == 'skip':
+ self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim,
+ conv_clamp=conv_clamp, channels_last=self.channels_last)
+ self.num_torgb += 1
+
+ if in_channels != 0 and architecture == 'resnet':
+ self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2,
+ resample_filter=resample_filter, channels_last=self.channels_last)
+
+ def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, update_emas=False, **layer_kwargs):
+ _ = update_emas # unused
+ misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
+ w_iter = iter(ws.unbind(dim=1))
+ if ws.device.type != 'cuda':
+ force_fp32 = True
+ dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
+ memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
+ if fused_modconv is None:
+ fused_modconv = self.fused_modconv_default
+ if fused_modconv == 'inference_only':
+ fused_modconv = (not self.training)
+
+ # Input.
+ if self.in_channels == 0:
+ x = self.const.to(dtype=dtype, memory_format=memory_format)
+ x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
+ else:
+ misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2])
+ x = x.to(dtype=dtype, memory_format=memory_format)
+
+ # Main layers.
+ if self.in_channels == 0:
+ x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
+ elif self.architecture == 'resnet':
+ y = self.skip(x, gain=np.sqrt(0.5))
+ x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
+ x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
+ x = y.add_(x)
+ else:
+ x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
+ x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
+
+ # ToRGB.
+ if img is not None:
+ misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
+ img = upfirdn2d.upsample2d(img, self.resample_filter)
+ if self.is_last or self.architecture == 'skip':
+ y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
+ y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
+ img = img.add_(y) if img is not None else y
+
+ assert x.dtype == dtype
+ assert img is None or img.dtype == torch.float32
+ return x, img
+
+ def extra_repr(self):
+ return f'resolution={self.resolution:d}, architecture={self.architecture:s}'
+
+
+@persistence.persistent_class
+class SynthesisNetwork(torch.nn.Module):
+ def __init__(self,
+ w_dim, # Intermediate latent (W) dimensionality.
+ img_resolution, # Output image resolution.
+ img_channels, # Number of color channels.
+ channel_base = 32768, # Overall multiplier for the number of channels.
+ channel_max = 512, # Maximum number of channels in any layer.
+ num_fp16_res = 4, # Use FP16 for the N highest resolutions.
+ **block_kwargs, # Arguments for SynthesisBlock.
+ ):
+ assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
+ super().__init__()
+ self.w_dim = w_dim
+ self.img_resolution = img_resolution
+ self.img_resolution_log2 = int(np.log2(img_resolution))
+ self.img_channels = img_channels
+ self.num_fp16_res = num_fp16_res
+ self.block_resolutions = [2 ** i for i in range(2, self.img_resolution_log2 + 1)]
+ channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions}
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
+
+ self.num_ws = 0
+ for res in self.block_resolutions:
+ in_channels = channels_dict[res // 2] if res > 4 else 0
+ out_channels = channels_dict[res]
+ use_fp16 = (res >= fp16_resolution)
+ is_last = (res == self.img_resolution)
+ block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res,
+ img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs)
+ self.num_ws += block.num_conv
+ if is_last:
+ self.num_ws += block.num_torgb
+ setattr(self, f'b{res}', block)
+
+ def forward(self, ws, c=None, **block_kwargs):
+ block_ws = []
+ with torch.autograd.profiler.record_function('split_ws'):
+ misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
+ ws = ws.to(torch.float32)
+ w_idx = 0
+ for res in self.block_resolutions:
+ block = getattr(self, f'b{res}')
+ block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb))
+ w_idx += block.num_conv
+
+ x = img = None
+ for res, cur_ws in zip(self.block_resolutions, block_ws):
+ block = getattr(self, f'b{res}')
+ x, img = block(x, img, cur_ws, **block_kwargs)
+ return img
+
+ def extra_repr(self):
+ return ' '.join([
+ f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},',
+ f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},',
+ f'num_fp16_res={self.num_fp16_res:d}'])
+
+
+@persistence.persistent_class
+class Generator(torch.nn.Module):
+ def __init__(self,
+ z_dim, # Input latent (Z) dimensionality.
+ c_dim, # Conditioning label (C) dimensionality.
+ w_dim, # Intermediate latent (W) dimensionality.
+ img_resolution, # Output resolution.
+ img_channels, # Number of output color channels.
+ mapping_kwargs = {}, # Arguments for MappingNetwork.
+ **synthesis_kwargs, # Arguments for SynthesisNetwork.
+ ):
+ super().__init__()
+ self.z_dim = z_dim
+ self.c_dim = c_dim
+ self.w_dim = w_dim
+ self.img_resolution = img_resolution
+ self.img_channels = img_channels
+ self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs)
+ self.num_ws = self.synthesis.num_ws
+ self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)
+
+ def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
+ ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
+ img = self.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
+ return img
diff --git a/diffusion-projected-gan/pg_modules/projector.py b/diffusion-projected-gan/pg_modules/projector.py
new file mode 100644
index 0000000000000000000000000000000000000000..2280b01bda9a88a58caa4d564b8d5f5f0e22f527
--- /dev/null
+++ b/diffusion-projected-gan/pg_modules/projector.py
@@ -0,0 +1,190 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import timm
+from pg_modules.blocks import FeatureFusionBlock
+from pg_modules.diffusion import Diffusion
+
+
+def _make_scratch_ccm(scratch, in_channels, cout, expand=False):
+ # shapes
+ out_channels = [cout, cout*2, cout*4, cout*8] if expand else [cout]*4
+
+ scratch.layer0_ccm = nn.Conv2d(in_channels[0], out_channels[0], kernel_size=1, stride=1, padding=0, bias=True)
+ scratch.layer1_ccm = nn.Conv2d(in_channels[1], out_channels[1], kernel_size=1, stride=1, padding=0, bias=True)
+ scratch.layer2_ccm = nn.Conv2d(in_channels[2], out_channels[2], kernel_size=1, stride=1, padding=0, bias=True)
+ scratch.layer3_ccm = nn.Conv2d(in_channels[3], out_channels[3], kernel_size=1, stride=1, padding=0, bias=True)
+
+ scratch.CHANNELS = out_channels
+
+ return scratch
+
+
+def _make_scratch_csm(scratch, in_channels, cout, expand):
+ scratch.layer3_csm = FeatureFusionBlock(in_channels[3], nn.ReLU(False), expand=expand, lowest=True)
+ scratch.layer2_csm = FeatureFusionBlock(in_channels[2], nn.ReLU(False), expand=expand)
+ scratch.layer1_csm = FeatureFusionBlock(in_channels[1], nn.ReLU(False), expand=expand)
+ scratch.layer0_csm = FeatureFusionBlock(in_channels[0], nn.ReLU(False))
+
+ # last refinenet does not expand to save channels in higher dimensions
+ scratch.CHANNELS = [cout, cout, cout*2, cout*4] if expand else [cout]*4
+
+ return scratch
+
+
+def _make_efficientnet(model):
+ pretrained = nn.Module()
+ pretrained.layer0 = nn.Sequential(model.conv_stem, model.bn1, model.act1, *model.blocks[0:2])
+ pretrained.layer1 = nn.Sequential(*model.blocks[2:3])
+ pretrained.layer2 = nn.Sequential(*model.blocks[3:5])
+ pretrained.layer3 = nn.Sequential(*model.blocks[5:9])
+ return pretrained
+
+
+def calc_channels(pretrained, inp_res=224):
+ channels = []
+ tmp = torch.zeros(1, 3, inp_res, inp_res)
+
+ # forward pass
+ tmp = pretrained.layer0(tmp)
+ channels.append(tmp.shape[1])
+ tmp = pretrained.layer1(tmp)
+ channels.append(tmp.shape[1])
+ tmp = pretrained.layer2(tmp)
+ channels.append(tmp.shape[1])
+ tmp = pretrained.layer3(tmp)
+ channels.append(tmp.shape[1])
+
+ return channels
+
+
+def _make_projector(im_res, cout, proj_type, expand=False):
+ assert proj_type in [0, 1, 2], "Invalid projection type"
+
+ ### Build pretrained feature network
+ model = timm.create_model('tf_efficientnet_lite0', pretrained=True)
+ pretrained = _make_efficientnet(model)
+
+ # determine resolution of feature maps, this is later used to calculate the number
+ # of down blocks in the discriminators. Interestingly, the best results are achieved
+ # by fixing this to 256, ie., we use the same number of down blocks per discriminator
+ # independent of the dataset resolution
+ im_res = 256
+ pretrained.RESOLUTIONS = [im_res//4, im_res//8, im_res//16, im_res//32]
+ pretrained.CHANNELS = calc_channels(pretrained)
+
+ if proj_type == 0: return pretrained, None
+
+ ### Build CCM
+ scratch = nn.Module()
+ scratch = _make_scratch_ccm(scratch, in_channels=pretrained.CHANNELS, cout=cout, expand=expand)
+ pretrained.CHANNELS = scratch.CHANNELS
+
+ if proj_type == 1: return pretrained, scratch
+
+ ### build CSM
+ scratch = _make_scratch_csm(scratch, in_channels=scratch.CHANNELS, cout=cout, expand=expand)
+
+ # CSM upsamples x2 so the feature map resolution doubles
+ pretrained.RESOLUTIONS = [res*2 for res in pretrained.RESOLUTIONS]
+ pretrained.CHANNELS = scratch.CHANNELS
+
+ return pretrained, scratch
+
+
+def rescale(out):
+ out_min, out_max = out.min(), out.max()
+ return (out - out_min) / (out_max - out_min) * 2 - 1
+
+
+class F_RandomProj(nn.Module):
+ def __init__(
+ self,
+ im_res=256,
+ cout=64,
+ expand=True,
+ proj_type=2, # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing
+ d_pos='first',
+ noise_sd=0.5,
+ **kwargs,
+ ):
+ super().__init__()
+ self.proj_type = proj_type
+ self.cout = cout
+ self.expand = expand
+
+ self.d_pos = d_pos
+ self.noise_sd = noise_sd
+ # self.diffusion = AugmentPipe(t_max=1000)
+ self.diffusion = Diffusion(t_min=5, t_max=500, beta_start=1e-4, beta_end=1e-2)
+ # build pretrained feature network and random decoder (scratch)
+ self.pretrained, self.scratch = _make_projector(im_res=im_res, cout=self.cout, proj_type=self.proj_type, expand=self.expand)
+ self.CHANNELS = self.pretrained.CHANNELS
+ self.RESOLUTIONS = self.pretrained.RESOLUTIONS
+
+ def forward(self, x):
+ # x = self.diffusion(x, noise_std=0.05)
+ # predict feature maps
+ out0 = self.pretrained.layer0(x)
+ out1 = self.pretrained.layer1(out0)
+ out2 = self.pretrained.layer2(out1)
+ out3 = self.pretrained.layer3(out2)
+
+ # start enumerating at the lowest layer (this is where we put the first discriminator)
+ out = {
+ '0': out0,
+ '1': out1,
+ '2': out2,
+ '3': out3,
+ }
+
+ if self.d_pos == 'first':
+ out['0'] = self.diffusion(out['0'], noise_std=self.noise_sd)
+ out['1'] = self.diffusion(out['1'], noise_std=self.noise_sd)
+ out['2'] = self.diffusion(out['2'], noise_std=self.noise_sd)
+ out['3'] = self.diffusion(out['3'], noise_std=self.noise_sd)
+
+ if self.proj_type == 0: return out
+
+ out0_channel_mixed = self.scratch.layer0_ccm(out['0'])
+ out1_channel_mixed = self.scratch.layer1_ccm(out['1'])
+ out2_channel_mixed = self.scratch.layer2_ccm(out['2'])
+ out3_channel_mixed = self.scratch.layer3_ccm(out['3'])
+
+ out = {
+ '0': out0_channel_mixed,
+ '1': out1_channel_mixed,
+ '2': out2_channel_mixed,
+ '3': out3_channel_mixed,
+ }
+
+ if self.proj_type == 1: return out
+
+ # from bottom to top
+ out3_scale_mixed = self.scratch.layer3_csm(out3_channel_mixed)
+ out2_scale_mixed = self.scratch.layer2_csm(out3_scale_mixed, out2_channel_mixed)
+ out1_scale_mixed = self.scratch.layer1_csm(out2_scale_mixed, out1_channel_mixed)
+ out0_scale_mixed = self.scratch.layer0_csm(out1_scale_mixed, out0_channel_mixed)
+
+ out = {
+ '0': out0_scale_mixed,
+ '1': out1_scale_mixed,
+ '2': out2_scale_mixed,
+ '3': out3_scale_mixed,
+ }
+
+ if self.d_pos == 'last':
+ out['0'] = self.diffusion(out['0'], noise_std=self.noise_sd)
+ out['1'] = self.diffusion(out['1'], noise_std=self.noise_sd)
+ out['2'] = self.diffusion(out['2'], noise_std=self.noise_sd)
+ out['3'] = self.diffusion(out['3'], noise_std=self.noise_sd)
+ # CDA
+ # n_sd1, n_sd2 = 0.5, 0.25
+ # n_sd1, n_sd2 = 0.25, 0.1
+ # out['0'], t0 = self.diffusion(out['0'], noise_std=n_sd1)
+ # out['1'], t1 = self.diffusion(out['1'], noise_std=n_sd1)
+ # out['2'], t2 = self.diffusion(out['2'], noise_std=n_sd2)
+ # out['3'], t3 = self.diffusion(out['3'], noise_std=n_sd2)
+ # diffusion_t = {'0': t0, '1': t1, '2': t2, '3': t3}
+
+ return out
diff --git a/diffusion-projected-gan/torch_utils/__init__.py b/diffusion-projected-gan/torch_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..939e7c6c8f94c4ea1141885c3c3295fe083b06aa
--- /dev/null
+++ b/diffusion-projected-gan/torch_utils/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+# empty
diff --git a/diffusion-projected-gan/torch_utils/custom_ops.py b/diffusion-projected-gan/torch_utils/custom_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd7cc046e925f58602154be9bdf678ca9d76f59f
--- /dev/null
+++ b/diffusion-projected-gan/torch_utils/custom_ops.py
@@ -0,0 +1,157 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import glob
+import hashlib
+import importlib
+import os
+import re
+import shutil
+import uuid
+
+import torch
+import torch.utils.cpp_extension
+from torch.utils.file_baton import FileBaton
+
+#----------------------------------------------------------------------------
+# Global options.
+
+verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
+
+#----------------------------------------------------------------------------
+# Internal helper funcs.
+
+def _find_compiler_bindir():
+ patterns = [
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
+ 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
+ ]
+ for pattern in patterns:
+ matches = sorted(glob.glob(pattern))
+ if len(matches):
+ return matches[-1]
+ return None
+
+#----------------------------------------------------------------------------
+
+def _get_mangled_gpu_name():
+ name = torch.cuda.get_device_name().lower()
+ out = []
+ for c in name:
+ if re.match('[a-z0-9_-]+', c):
+ out.append(c)
+ else:
+ out.append('-')
+ return ''.join(out)
+
+#----------------------------------------------------------------------------
+# Main entry point for compiling and loading C++/CUDA plugins.
+
+_cached_plugins = dict()
+
+def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):
+ assert verbosity in ['none', 'brief', 'full']
+ if headers is None:
+ headers = []
+ if source_dir is not None:
+ sources = [os.path.join(source_dir, fname) for fname in sources]
+ headers = [os.path.join(source_dir, fname) for fname in headers]
+
+ # Already cached?
+ if module_name in _cached_plugins:
+ return _cached_plugins[module_name]
+
+ # Print status.
+ if verbosity == 'full':
+ print(f'Setting up PyTorch plugin "{module_name}"...')
+ elif verbosity == 'brief':
+ print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
+ verbose_build = (verbosity == 'full')
+
+ # Compile and load.
+ try: # pylint: disable=too-many-nested-blocks
+ # Make sure we can find the necessary compiler binaries.
+ if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
+ compiler_bindir = _find_compiler_bindir()
+ if compiler_bindir is None:
+ raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
+ os.environ['PATH'] += ';' + compiler_bindir
+
+ # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
+ # break the build or unnecessarily restrict what's available to nvcc.
+ # Unset it to let nvcc decide based on what's available on the
+ # machine.
+ os.environ['TORCH_CUDA_ARCH_LIST'] = ''
+
+ # Incremental build md5sum trickery. Copies all the input source files
+ # into a cached build directory under a combined md5 digest of the input
+ # source files. Copying is done only if the combined digest has changed.
+ # This keeps input file timestamps and filenames the same as in previous
+ # extension builds, allowing for fast incremental rebuilds.
+ #
+ # This optimization is done only in case all the source files reside in
+ # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
+ # environment variable is set (we take this as a signal that the user
+ # actually cares about this.)
+ #
+ # EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work
+ # around the *.cu dependency bug in ninja config.
+ #
+ all_source_files = sorted(sources + headers)
+ all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)
+ if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):
+
+ # Compute combined hash digest for all source files.
+ hash_md5 = hashlib.md5()
+ for src in all_source_files:
+ with open(src, 'rb') as f:
+ hash_md5.update(f.read())
+
+ # Select cached build directory name.
+ source_digest = hash_md5.hexdigest()
+ build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
+ cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')
+
+ if not os.path.isdir(cached_build_dir):
+ tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
+ os.makedirs(tmpdir)
+ for src in all_source_files:
+ shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
+ try:
+ os.replace(tmpdir, cached_build_dir) # atomic
+ except OSError:
+ # source directory already exists, delete tmpdir and its contents.
+ shutil.rmtree(tmpdir)
+ if not os.path.isdir(cached_build_dir): raise
+
+ # Compile.
+ cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
+ torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
+ verbose=verbose_build, sources=cached_sources, **build_kwargs)
+ else:
+ torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
+
+ # Load.
+ module = importlib.import_module(module_name)
+
+ except:
+ if verbosity == 'brief':
+ print('Failed!')
+ raise
+
+ # Print status and add to cache dict.
+ if verbosity == 'full':
+ print(f'Done setting up PyTorch plugin "{module_name}".')
+ elif verbosity == 'brief':
+ print('Done.')
+ _cached_plugins[module_name] = module
+ return module
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-projected-gan/torch_utils/misc.py b/diffusion-projected-gan/torch_utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..3173c4824788797af5b3d7a5b8d62b229c957156
--- /dev/null
+++ b/diffusion-projected-gan/torch_utils/misc.py
@@ -0,0 +1,272 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import re
+import contextlib
+import numpy as np
+import torch
+import warnings
+import dnnlib
+
+#----------------------------------------------------------------------------
+# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
+# same constant is used multiple times.
+
+_constant_cache = dict()
+
+def constant(value, shape=None, dtype=None, device=None, memory_format=None):
+ value = np.asarray(value)
+ if shape is not None:
+ shape = tuple(shape)
+ if dtype is None:
+ dtype = torch.get_default_dtype()
+ if device is None:
+ device = torch.device('cpu')
+ if memory_format is None:
+ memory_format = torch.contiguous_format
+
+ key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
+ tensor = _constant_cache.get(key, None)
+ if tensor is None:
+ tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
+ if shape is not None:
+ tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
+ tensor = tensor.contiguous(memory_format=memory_format)
+ _constant_cache[key] = tensor
+ return tensor
+
+#----------------------------------------------------------------------------
+# Replace NaN/Inf with specified numerical values.
+
+try:
+ nan_to_num = torch.nan_to_num # 1.8.0a0
+except AttributeError:
+ def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
+ assert isinstance(input, torch.Tensor)
+ if posinf is None:
+ posinf = torch.finfo(input.dtype).max
+ if neginf is None:
+ neginf = torch.finfo(input.dtype).min
+ assert nan == 0
+ return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
+
+#----------------------------------------------------------------------------
+# Symbolic assert.
+
+try:
+ symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
+except AttributeError:
+ symbolic_assert = torch.Assert # 1.7.0
+
+#----------------------------------------------------------------------------
+# Context manager to temporarily suppress known warnings in torch.jit.trace().
+# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
+
+@contextlib.contextmanager
+def suppress_tracer_warnings():
+ flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
+ warnings.filters.insert(0, flt)
+ yield
+ warnings.filters.remove(flt)
+
+#----------------------------------------------------------------------------
+# Assert that the shape of a tensor matches the given list of integers.
+# None indicates that the size of a dimension is allowed to vary.
+# Performs symbolic assertion when used in torch.jit.trace().
+
+def assert_shape(tensor, ref_shape):
+ if tensor.ndim != len(ref_shape):
+ raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
+ for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
+ if ref_size is None:
+ pass
+ elif isinstance(ref_size, torch.Tensor):
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
+ symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
+ elif isinstance(size, torch.Tensor):
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
+ symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
+ elif size != ref_size:
+ raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
+
+#----------------------------------------------------------------------------
+# Function decorator that calls torch.autograd.profiler.record_function().
+
+def profiled_function(fn):
+ def decorator(*args, **kwargs):
+ with torch.autograd.profiler.record_function(fn.__name__):
+ return fn(*args, **kwargs)
+ decorator.__name__ = fn.__name__
+ return decorator
+
+#----------------------------------------------------------------------------
+# Sampler for torch.utils.data.DataLoader that loops over the dataset
+# indefinitely, shuffling items as it goes.
+
+class InfiniteSampler(torch.utils.data.Sampler):
+ def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
+ assert len(dataset) > 0
+ assert num_replicas > 0
+ assert 0 <= rank < num_replicas
+ assert 0 <= window_size <= 1
+ super().__init__(dataset)
+ self.dataset = dataset
+ self.rank = rank
+ self.num_replicas = num_replicas
+ self.shuffle = shuffle
+ self.seed = seed
+ self.window_size = window_size
+
+ def __iter__(self):
+ order = np.arange(len(self.dataset))
+ rnd = None
+ window = 0
+ if self.shuffle:
+ rnd = np.random.RandomState(self.seed)
+ rnd.shuffle(order)
+ window = int(np.rint(order.size * self.window_size))
+
+ idx = 0
+ while True:
+ i = idx % order.size
+ if idx % self.num_replicas == self.rank:
+ yield order[i]
+ if window >= 2:
+ j = (i - rnd.randint(window)) % order.size
+ order[i], order[j] = order[j], order[i]
+ idx += 1
+
+#----------------------------------------------------------------------------
+# Utilities for operating with torch.nn.Module parameters and buffers.
+
+def params_and_buffers(module):
+ assert isinstance(module, torch.nn.Module)
+ return list(module.parameters()) + list(module.buffers())
+
+def named_params_and_buffers(module):
+ assert isinstance(module, torch.nn.Module)
+ return list(module.named_parameters()) + list(module.named_buffers())
+
+def copy_params_and_buffers(src_module, dst_module, require_all=False):
+ assert isinstance(src_module, torch.nn.Module)
+ assert isinstance(dst_module, torch.nn.Module)
+ src_tensors = dict(named_params_and_buffers(src_module))
+ for name, tensor in named_params_and_buffers(dst_module):
+ assert (name in src_tensors) or (not require_all)
+ if name in src_tensors:
+ tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
+
+#----------------------------------------------------------------------------
+# Context manager for easily enabling/disabling DistributedDataParallel
+# synchronization.
+
+@contextlib.contextmanager
+def ddp_sync(module, sync):
+ assert isinstance(module, torch.nn.Module)
+ if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
+ yield
+ else:
+ with module.no_sync():
+ yield
+
+#----------------------------------------------------------------------------
+# Check DistributedDataParallel consistency across processes.
+
+def check_ddp_consistency(module, ignore_regex=None):
+ assert isinstance(module, torch.nn.Module)
+ for name, tensor in named_params_and_buffers(module):
+ fullname = type(module).__name__ + '.' + name
+ if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
+ continue
+ tensor = tensor.detach()
+ if tensor.is_floating_point():
+ tensor = nan_to_num(tensor)
+ other = tensor.clone()
+ torch.distributed.broadcast(tensor=other, src=0)
+ assert (tensor == other).all(), fullname
+
+#----------------------------------------------------------------------------
+# Print summary table of module hierarchy.
+
+def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
+ assert isinstance(module, torch.nn.Module)
+ assert not isinstance(module, torch.jit.ScriptModule)
+ assert isinstance(inputs, (tuple, list))
+
+ # Register hooks.
+ entries = []
+ nesting = [0]
+ def pre_hook(_mod, _inputs):
+ nesting[0] += 1
+ def post_hook(mod, _inputs, outputs):
+ nesting[0] -= 1
+ if nesting[0] <= max_nesting:
+ outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
+ outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
+ entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
+ hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
+ hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
+
+ # Run module.
+ outputs = module(*inputs)
+ for hook in hooks:
+ hook.remove()
+
+ # Identify unique outputs, parameters, and buffers.
+ tensors_seen = set()
+ for e in entries:
+ e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
+ e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
+ e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
+ tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
+
+ # Filter out redundant entries.
+ if skip_redundant:
+ entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
+
+ # Construct table.
+ rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
+ rows += [['---'] * len(rows[0])]
+ param_total = 0
+ buffer_total = 0
+ submodule_names = {mod: name for name, mod in module.named_modules()}
+ for e in entries:
+ name = '' if e.mod is module else submodule_names[e.mod]
+ param_size = sum(t.numel() for t in e.unique_params)
+ buffer_size = sum(t.numel() for t in e.unique_buffers)
+ output_shapes = [str(list(t.shape)) for t in e.outputs]
+ output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
+ rows += [[
+ name + (':0' if len(e.outputs) >= 2 else ''),
+ str(param_size) if param_size else '-',
+ str(buffer_size) if buffer_size else '-',
+ (output_shapes + ['-'])[0],
+ (output_dtypes + ['-'])[0],
+ ]]
+ for idx in range(1, len(e.outputs)):
+ rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
+ param_total += param_size
+ buffer_total += buffer_size
+ rows += [['---'] * len(rows[0])]
+ rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
+
+ # Print table.
+ widths = [max(len(cell) for cell in column) for column in zip(*rows)]
+ print()
+ for row in rows:
+ print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
+ print()
+ return outputs
+
+#----------------------------------------------------------------------------
+
+# Added by Katja
+import os
+
+def get_ckpt_path(run_dir):
+ return os.path.join(run_dir, f'network-snapshot.pkl')
diff --git a/diffusion-projected-gan/torch_utils/ops/__init__.py b/diffusion-projected-gan/torch_utils/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..939e7c6c8f94c4ea1141885c3c3295fe083b06aa
--- /dev/null
+++ b/diffusion-projected-gan/torch_utils/ops/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+# empty
diff --git a/diffusion-projected-gan/torch_utils/ops/bias_act.cpp b/diffusion-projected-gan/torch_utils/ops/bias_act.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..3adaeee2ae44e96655d354c2bdfb81de8ebfe6c6
--- /dev/null
+++ b/diffusion-projected-gan/torch_utils/ops/bias_act.cpp
@@ -0,0 +1,99 @@
+// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include
+#include
+#include
+#include "bias_act.h"
+
+//------------------------------------------------------------------------
+
+static bool has_same_layout(torch::Tensor x, torch::Tensor y)
+{
+ if (x.dim() != y.dim())
+ return false;
+ for (int64_t i = 0; i < x.dim(); i++)
+ {
+ if (x.size(i) != y.size(i))
+ return false;
+ if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
+ return false;
+ }
+ return true;
+}
+
+//------------------------------------------------------------------------
+
+static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
+{
+ // Validate arguments.
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+ TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
+ TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
+ TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
+ TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
+ TORCH_CHECK(b.dim() == 1, "b must have rank 1");
+ TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
+ TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
+ TORCH_CHECK(grad >= 0, "grad must be non-negative");
+
+ // Validate layout.
+ TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
+ TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
+ TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
+ TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
+ TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
+
+ // Create output tensor.
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+ torch::Tensor y = torch::empty_like(x);
+ TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
+
+ // Initialize CUDA kernel parameters.
+ bias_act_kernel_params p;
+ p.x = x.data_ptr();
+ p.b = (b.numel()) ? b.data_ptr() : NULL;
+ p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
+ p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
+ p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
+ p.y = y.data_ptr();
+ p.grad = grad;
+ p.act = act;
+ p.alpha = alpha;
+ p.gain = gain;
+ p.clamp = clamp;
+ p.sizeX = (int)x.numel();
+ p.sizeB = (int)b.numel();
+ p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
+
+ // Choose CUDA kernel.
+ void* kernel;
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
+ {
+ kernel = choose_bias_act_kernel(p);
+ });
+ TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
+
+ // Launch CUDA kernel.
+ p.loopX = 4;
+ int blockSize = 4 * 32;
+ int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
+ void* args[] = {&p};
+ AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
+ return y;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("bias_act", &bias_act);
+}
+
+//------------------------------------------------------------------------
diff --git a/diffusion-projected-gan/torch_utils/ops/bias_act.cu b/diffusion-projected-gan/torch_utils/ops/bias_act.cu
new file mode 100644
index 0000000000000000000000000000000000000000..ed1d16f14eadd1344939e074ace1375cfd936cea
--- /dev/null
+++ b/diffusion-projected-gan/torch_utils/ops/bias_act.cu
@@ -0,0 +1,173 @@
+// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include
+#include "bias_act.h"
+
+//------------------------------------------------------------------------
+// Helpers.
+
+template struct InternalType;
+template <> struct InternalType { typedef double scalar_t; };
+template <> struct InternalType { typedef float scalar_t; };
+template <> struct InternalType { typedef float scalar_t; };
+
+//------------------------------------------------------------------------
+// CUDA kernel.
+
+template
+__global__ void bias_act_kernel(bias_act_kernel_params p)
+{
+ typedef typename InternalType::scalar_t scalar_t;
+ int G = p.grad;
+ scalar_t alpha = (scalar_t)p.alpha;
+ scalar_t gain = (scalar_t)p.gain;
+ scalar_t clamp = (scalar_t)p.clamp;
+ scalar_t one = (scalar_t)1;
+ scalar_t two = (scalar_t)2;
+ scalar_t expRange = (scalar_t)80;
+ scalar_t halfExpRange = (scalar_t)40;
+ scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
+ scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
+
+ // Loop over elements.
+ int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
+ for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
+ {
+ // Load.
+ scalar_t x = (scalar_t)((const T*)p.x)[xi];
+ scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
+ scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
+ scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
+ scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
+ scalar_t yy = (gain != 0) ? yref / gain : 0;
+ scalar_t y = 0;
+
+ // Apply bias.
+ ((G == 0) ? x : xref) += b;
+
+ // linear
+ if (A == 1)
+ {
+ if (G == 0) y = x;
+ if (G == 1) y = x;
+ }
+
+ // relu
+ if (A == 2)
+ {
+ if (G == 0) y = (x > 0) ? x : 0;
+ if (G == 1) y = (yy > 0) ? x : 0;
+ }
+
+ // lrelu
+ if (A == 3)
+ {
+ if (G == 0) y = (x > 0) ? x : x * alpha;
+ if (G == 1) y = (yy > 0) ? x : x * alpha;
+ }
+
+ // tanh
+ if (A == 4)
+ {
+ if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
+ if (G == 1) y = x * (one - yy * yy);
+ if (G == 2) y = x * (one - yy * yy) * (-two * yy);
+ }
+
+ // sigmoid
+ if (A == 5)
+ {
+ if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
+ if (G == 1) y = x * yy * (one - yy);
+ if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
+ }
+
+ // elu
+ if (A == 6)
+ {
+ if (G == 0) y = (x >= 0) ? x : exp(x) - one;
+ if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
+ }
+
+ // selu
+ if (A == 7)
+ {
+ if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
+ if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
+ }
+
+ // softplus
+ if (A == 8)
+ {
+ if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
+ if (G == 1) y = x * (one - exp(-yy));
+ if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
+ }
+
+ // swish
+ if (A == 9)
+ {
+ if (G == 0)
+ y = (x < -expRange) ? 0 : x / (exp(-x) + one);
+ else
+ {
+ scalar_t c = exp(xref);
+ scalar_t d = c + one;
+ if (G == 1)
+ y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
+ else
+ y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
+ yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
+ }
+ }
+
+ // Apply gain.
+ y *= gain * dy;
+
+ // Clamp.
+ if (clamp >= 0)
+ {
+ if (G == 0)
+ y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
+ else
+ y = (yref > -clamp & yref < clamp) ? y : 0;
+ }
+
+ // Store.
+ ((T*)p.y)[xi] = (T)y;
+ }
+}
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template void* choose_bias_act_kernel(const bias_act_kernel_params& p)
+{
+ if (p.act == 1) return (void*)bias_act_kernel;
+ if (p.act == 2) return (void*)bias_act_kernel;
+ if (p.act == 3) return (void*)bias_act_kernel;
+ if (p.act == 4) return (void*)bias_act_kernel;
+ if (p.act == 5) return (void*)bias_act_kernel;
+ if (p.act == 6) return (void*)bias_act_kernel;
+ if (p.act == 7) return (void*)bias_act_kernel;
+ if (p.act == 8) return (void*)bias_act_kernel;
+ if (p.act == 9) return (void*)bias_act_kernel;
+ return NULL;
+}
+
+//------------------------------------------------------------------------
+// Template specializations.
+
+template void* choose_bias_act_kernel (const bias_act_kernel_params& p);
+template void* choose_bias_act_kernel (const bias_act_kernel_params& p);
+template void* choose_bias_act_kernel (const bias_act_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/diffusion-projected-gan/torch_utils/ops/bias_act.h b/diffusion-projected-gan/torch_utils/ops/bias_act.h
new file mode 100644
index 0000000000000000000000000000000000000000..60b81c6058d54638a6d74a13046fa388442d767d
--- /dev/null
+++ b/diffusion-projected-gan/torch_utils/ops/bias_act.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+//------------------------------------------------------------------------
+// CUDA kernel parameters.
+
+struct bias_act_kernel_params
+{
+ const void* x; // [sizeX]
+ const void* b; // [sizeB] or NULL
+ const void* xref; // [sizeX] or NULL
+ const void* yref; // [sizeX] or NULL
+ const void* dy; // [sizeX] or NULL
+ void* y; // [sizeX]
+
+ int grad;
+ int act;
+ float alpha;
+ float gain;
+ float clamp;
+
+ int sizeX;
+ int sizeB;
+ int stepB;
+ int loopX;
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template void* choose_bias_act_kernel(const bias_act_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/diffusion-projected-gan/torch_utils/ops/bias_act.py b/diffusion-projected-gan/torch_utils/ops/bias_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c485c0027570decab26f0b6602a363a432b851f
--- /dev/null
+++ b/diffusion-projected-gan/torch_utils/ops/bias_act.py
@@ -0,0 +1,209 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Custom PyTorch ops for efficient bias and activation."""
+
+import os
+import numpy as np
+import torch
+import dnnlib
+
+from .. import custom_ops
+from .. import misc
+
+#----------------------------------------------------------------------------
+
+activation_funcs = {
+ 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
+ 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
+ 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
+ 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
+ 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
+ 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
+ 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
+ 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
+ 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
+}
+
+#----------------------------------------------------------------------------
+
+_plugin = None
+_null_tensor = torch.empty([0])
+
+def _init():
+ global _plugin
+ if _plugin is None:
+ _plugin = custom_ops.get_plugin(
+ module_name='bias_act_plugin',
+ sources=['bias_act.cpp', 'bias_act.cu'],
+ headers=['bias_act.h'],
+ source_dir=os.path.dirname(__file__),
+ extra_cuda_cflags=['--use_fast_math'],
+ )
+ return True
+
+#----------------------------------------------------------------------------
+
+def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
+ r"""Fused bias and activation function.
+
+ Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
+ and scales the result by `gain`. Each of the steps is optional. In most cases,
+ the fused op is considerably more efficient than performing the same calculation
+ using standard PyTorch ops. It supports first and second order gradients,
+ but not third order gradients.
+
+ Args:
+ x: Input activation tensor. Can be of any shape.
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
+ as `x`. The shape must be known, and it must match the dimension of `x`
+ corresponding to `dim`.
+ dim: The dimension in `x` corresponding to the elements of `b`.
+ The value of `dim` is ignored if `b` is not specified.
+ act: Name of the activation function to evaluate, or `"linear"` to disable.
+ Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
+ See `activation_funcs` for a full list. `None` is not allowed.
+ alpha: Shape parameter for the activation function, or `None` to use the default.
+ gain: Scaling factor for the output tensor, or `None` to use default.
+ See `activation_funcs` for the default scaling of each activation function.
+ If unsure, consider specifying 1.
+ clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
+ the clamping (default).
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
+
+ Returns:
+ Tensor of the same shape and datatype as `x`.
+ """
+ assert isinstance(x, torch.Tensor)
+ assert impl in ['ref', 'cuda']
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
+ return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
+ return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
+ """Slow reference implementation of `bias_act()` using standard TensorFlow ops.
+ """
+ assert isinstance(x, torch.Tensor)
+ assert clamp is None or clamp >= 0
+ spec = activation_funcs[act]
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
+ gain = float(gain if gain is not None else spec.def_gain)
+ clamp = float(clamp if clamp is not None else -1)
+
+ # Add bias.
+ if b is not None:
+ assert isinstance(b, torch.Tensor) and b.ndim == 1
+ assert 0 <= dim < x.ndim
+ assert b.shape[0] == x.shape[dim]
+ x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
+
+ # Evaluate activation function.
+ alpha = float(alpha)
+ x = spec.func(x, alpha=alpha)
+
+ # Scale by gain.
+ gain = float(gain)
+ if gain != 1:
+ x = x * gain
+
+ # Clamp.
+ if clamp >= 0:
+ x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
+ return x
+
+#----------------------------------------------------------------------------
+
+_bias_act_cuda_cache = dict()
+
+def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
+ """Fast CUDA implementation of `bias_act()` using custom ops.
+ """
+ # Parse arguments.
+ assert clamp is None or clamp >= 0
+ spec = activation_funcs[act]
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
+ gain = float(gain if gain is not None else spec.def_gain)
+ clamp = float(clamp if clamp is not None else -1)
+
+ # Lookup from cache.
+ key = (dim, act, alpha, gain, clamp)
+ if key in _bias_act_cuda_cache:
+ return _bias_act_cuda_cache[key]
+
+ # Forward op.
+ class BiasActCuda(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, b): # pylint: disable=arguments-differ
+ ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format
+ x = x.contiguous(memory_format=ctx.memory_format)
+ b = b.contiguous() if b is not None else _null_tensor
+ y = x
+ if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
+ y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
+ ctx.save_for_backward(
+ x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
+ b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
+ y if 'y' in spec.ref else _null_tensor)
+ return y
+
+ @staticmethod
+ def backward(ctx, dy): # pylint: disable=arguments-differ
+ dy = dy.contiguous(memory_format=ctx.memory_format)
+ x, b, y = ctx.saved_tensors
+ dx = None
+ db = None
+
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+ dx = dy
+ if act != 'linear' or gain != 1 or clamp >= 0:
+ dx = BiasActCudaGrad.apply(dy, x, b, y)
+
+ if ctx.needs_input_grad[1]:
+ db = dx.sum([i for i in range(dx.ndim) if i != dim])
+
+ return dx, db
+
+ # Backward op.
+ class BiasActCudaGrad(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
+ ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format
+ dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
+ ctx.save_for_backward(
+ dy if spec.has_2nd_grad else _null_tensor,
+ x, b, y)
+ return dx
+
+ @staticmethod
+ def backward(ctx, d_dx): # pylint: disable=arguments-differ
+ d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
+ dy, x, b, y = ctx.saved_tensors
+ d_dy = None
+ d_x = None
+ d_b = None
+ d_y = None
+
+ if ctx.needs_input_grad[0]:
+ d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
+
+ if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
+ d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
+
+ if spec.has_2nd_grad and ctx.needs_input_grad[2]:
+ d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
+
+ return d_dy, d_x, d_b, d_y
+
+ # Add to cache.
+ _bias_act_cuda_cache[key] = BiasActCuda
+ return BiasActCuda
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-projected-gan/torch_utils/ops/conv2d_gradfix.py b/diffusion-projected-gan/torch_utils/ops/conv2d_gradfix.py
new file mode 100644
index 0000000000000000000000000000000000000000..388778fa971d7bc5c64b5fd6c0e5492863ee1c5f
--- /dev/null
+++ b/diffusion-projected-gan/torch_utils/ops/conv2d_gradfix.py
@@ -0,0 +1,198 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Custom replacement for `torch.nn.functional.conv2d` that supports
+arbitrarily high order gradients with zero performance penalty."""
+
+import contextlib
+import torch
+
+# pylint: disable=redefined-builtin
+# pylint: disable=arguments-differ
+# pylint: disable=protected-access
+
+#----------------------------------------------------------------------------
+
+enabled = False # Enable the custom op by setting this to true.
+weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
+
+@contextlib.contextmanager
+def no_weight_gradients(disable=True):
+ global weight_gradients_disabled
+ old = weight_gradients_disabled
+ if disable:
+ weight_gradients_disabled = True
+ yield
+ weight_gradients_disabled = old
+
+#----------------------------------------------------------------------------
+
+def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
+ if _should_use_custom_op(input):
+ return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
+ return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
+
+def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
+ if _should_use_custom_op(input):
+ return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
+ return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
+
+#----------------------------------------------------------------------------
+
+def _should_use_custom_op(input):
+ assert isinstance(input, torch.Tensor)
+ if (not enabled) or (not torch.backends.cudnn.enabled):
+ return False
+ if input.device.type != 'cuda':
+ return False
+ return True
+
+def _tuple_of_ints(xs, ndim):
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
+ assert len(xs) == ndim
+ assert all(isinstance(x, int) for x in xs)
+ return xs
+
+#----------------------------------------------------------------------------
+
+_conv2d_gradfix_cache = dict()
+_null_tensor = torch.empty([0])
+
+def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
+ # Parse arguments.
+ ndim = 2
+ weight_shape = tuple(weight_shape)
+ stride = _tuple_of_ints(stride, ndim)
+ padding = _tuple_of_ints(padding, ndim)
+ output_padding = _tuple_of_ints(output_padding, ndim)
+ dilation = _tuple_of_ints(dilation, ndim)
+
+ # Lookup from cache.
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
+ if key in _conv2d_gradfix_cache:
+ return _conv2d_gradfix_cache[key]
+
+ # Validate arguments.
+ assert groups >= 1
+ assert len(weight_shape) == ndim + 2
+ assert all(stride[i] >= 1 for i in range(ndim))
+ assert all(padding[i] >= 0 for i in range(ndim))
+ assert all(dilation[i] >= 0 for i in range(ndim))
+ if not transpose:
+ assert all(output_padding[i] == 0 for i in range(ndim))
+ else: # transpose
+ assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
+
+ # Helpers.
+ common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
+ def calc_output_padding(input_shape, output_shape):
+ if transpose:
+ return [0, 0]
+ return [
+ input_shape[i + 2]
+ - (output_shape[i + 2] - 1) * stride[i]
+ - (1 - 2 * padding[i])
+ - dilation[i] * (weight_shape[i + 2] - 1)
+ for i in range(ndim)
+ ]
+
+ # Forward & backward.
+ class Conv2d(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, weight, bias):
+ assert weight.shape == weight_shape
+ ctx.save_for_backward(
+ input if weight.requires_grad else _null_tensor,
+ weight if input.requires_grad else _null_tensor,
+ )
+ ctx.input_shape = input.shape
+
+ # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere).
+ if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0):
+ a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1])
+ b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1)
+ c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2)
+ c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1)
+ c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
+ return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
+
+ # General case => cuDNN.
+ if transpose:
+ return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
+ return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, weight = ctx.saved_tensors
+ input_shape = ctx.input_shape
+ grad_input = None
+ grad_weight = None
+ grad_bias = None
+
+ if ctx.needs_input_grad[0]:
+ p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape)
+ op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
+ grad_input = op.apply(grad_output, weight, None)
+ assert grad_input.shape == input_shape
+
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled:
+ grad_weight = Conv2dGradWeight.apply(grad_output, input)
+ assert grad_weight.shape == weight_shape
+
+ if ctx.needs_input_grad[2]:
+ grad_bias = grad_output.sum([0, 2, 3])
+
+ return grad_input, grad_weight, grad_bias
+
+ # Gradient with respect to the weights.
+ class Conv2dGradWeight(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, grad_output, input):
+ ctx.save_for_backward(
+ grad_output if input.requires_grad else _null_tensor,
+ input if grad_output.requires_grad else _null_tensor,
+ )
+ ctx.grad_output_shape = grad_output.shape
+ ctx.input_shape = input.shape
+
+ # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere).
+ if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0):
+ a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
+ b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
+ c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape)
+ return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
+
+ # General case => cuDNN.
+ name = 'aten::cudnn_convolution_transpose_backward_weight' if transpose else 'aten::cudnn_convolution_backward_weight'
+ flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
+ return torch._C._jit_get_operation(name)(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
+
+ @staticmethod
+ def backward(ctx, grad2_grad_weight):
+ grad_output, input = ctx.saved_tensors
+ grad_output_shape = ctx.grad_output_shape
+ input_shape = ctx.input_shape
+ grad2_grad_output = None
+ grad2_input = None
+
+ if ctx.needs_input_grad[0]:
+ grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
+ assert grad2_grad_output.shape == grad_output_shape
+
+ if ctx.needs_input_grad[1]:
+ p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape)
+ op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
+ grad2_input = op.apply(grad_output, grad2_grad_weight, None)
+ assert grad2_input.shape == input_shape
+
+ return grad2_grad_output, grad2_input
+
+ _conv2d_gradfix_cache[key] = Conv2d
+ return Conv2d
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-projected-gan/torch_utils/ops/conv2d_resample.py b/diffusion-projected-gan/torch_utils/ops/conv2d_resample.py
new file mode 100644
index 0000000000000000000000000000000000000000..5eb5877d7ffe4af74a2165f1d8d8c39dfac2476b
--- /dev/null
+++ b/diffusion-projected-gan/torch_utils/ops/conv2d_resample.py
@@ -0,0 +1,143 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""2D convolution with optional up/downsampling."""
+
+import torch
+
+from .. import misc
+from . import conv2d_gradfix
+from . import upfirdn2d
+from .upfirdn2d import _parse_padding
+from .upfirdn2d import _get_filter_size
+
+#----------------------------------------------------------------------------
+
+def _get_weight_shape(w):
+ with misc.suppress_tracer_warnings(): # this value will be treated as a constant
+ shape = [int(sz) for sz in w.shape]
+ misc.assert_shape(w, shape)
+ return shape
+
+#----------------------------------------------------------------------------
+
+def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
+ """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
+ """
+ _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w)
+
+ # Flip weight if requested.
+ # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
+ if not flip_weight and (kw > 1 or kh > 1):
+ w = w.flip([2, 3])
+
+ # Execute using conv2d_gradfix.
+ op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
+ return op(x, w, stride=stride, padding=padding, groups=groups)
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
+ r"""2D convolution with optional up/downsampling.
+
+ Padding is performed only once at the beginning, not between the operations.
+
+ Args:
+ x: Input tensor of shape
+ `[batch_size, in_channels, in_height, in_width]`.
+ w: Weight tensor of shape
+ `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
+ f: Low-pass filter for up/downsampling. Must be prepared beforehand by
+ calling upfirdn2d.setup_filter(). None = identity (default).
+ up: Integer upsampling factor (default: 1).
+ down: Integer downsampling factor (default: 1).
+ padding: Padding with respect to the upsampled image. Can be a single number
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+ (default: 0).
+ groups: Split input channels into N groups (default: 1).
+ flip_weight: False = convolution, True = correlation (default: True).
+ flip_filter: False = convolution, True = correlation (default: False).
+
+ Returns:
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+ """
+ # Validate arguments.
+ assert isinstance(x, torch.Tensor) and (x.ndim == 4)
+ assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
+ assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
+ assert isinstance(up, int) and (up >= 1)
+ assert isinstance(down, int) and (down >= 1)
+ assert isinstance(groups, int) and (groups >= 1)
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
+ fw, fh = _get_filter_size(f)
+ px0, px1, py0, py1 = _parse_padding(padding)
+
+ # Adjust padding to account for up/downsampling.
+ if up > 1:
+ px0 += (fw + up - 1) // 2
+ px1 += (fw - up) // 2
+ py0 += (fh + up - 1) // 2
+ py1 += (fh - up) // 2
+ if down > 1:
+ px0 += (fw - down + 1) // 2
+ px1 += (fw - down) // 2
+ py0 += (fh - down + 1) // 2
+ py1 += (fh - down) // 2
+
+ # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
+ if kw == 1 and kh == 1 and (down > 1 and up == 1):
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
+ return x
+
+ # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
+ if kw == 1 and kh == 1 and (up > 1 and down == 1):
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
+ x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
+ return x
+
+ # Fast path: downsampling only => use strided convolution.
+ if down > 1 and up == 1:
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
+ x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
+ return x
+
+ # Fast path: upsampling with optional downsampling => use transpose strided convolution.
+ if up > 1:
+ if groups == 1:
+ w = w.transpose(0, 1)
+ else:
+ w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
+ w = w.transpose(1, 2)
+ w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
+ px0 -= kw - 1
+ px1 -= kw - up
+ py0 -= kh - 1
+ py1 -= kh - up
+ pxt = max(min(-px0, -px1), 0)
+ pyt = max(min(-py0, -py1), 0)
+ x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
+ if down > 1:
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
+ return x
+
+ # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
+ if up == 1 and down == 1:
+ if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
+ return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
+
+ # Fallback: Generic reference implementation.
+ x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
+ if down > 1:
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
+ return x
+
+#----------------------------------------------------------------------------
diff --git a/diffusion-projected-gan/torch_utils/ops/filtered_lrelu.cpp b/diffusion-projected-gan/torch_utils/ops/filtered_lrelu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ff4149b8b46b54d2f400ae10e44d19f20503ba1f
--- /dev/null
+++ b/diffusion-projected-gan/torch_utils/ops/filtered_lrelu.cpp
@@ -0,0 +1,300 @@
+// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include
+#include
+#include
+#include "filtered_lrelu.h"
+
+//------------------------------------------------------------------------
+
+static std::tuple filtered_lrelu(
+ torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si,
+ int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns)
+{
+ // Set CUDA device.
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+
+ // Validate arguments.
+ TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), "all input tensors must reside on the same device");
+ TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, "fu and fd must be float32");
+ TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype");
+ TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, "x and b must be float16 or float32");
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
+ TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
+ TORCH_CHECK(x.numel() > 0, "x is empty");
+ TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), "fu and fd must be rank 1 or 2");
+ TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, "fu is too large");
+ TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, "fd is too large");
+ TORCH_CHECK(fu.numel() > 0, "fu is empty");
+ TORCH_CHECK(fd.numel() > 0, "fd is empty");
+ TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), "b must be a vector with the same number of channels as x");
+ TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1");
+
+ // Figure out how much shared memory is available on the device.
+ int maxSharedBytes = 0;
+ AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index()));
+ int sharedKB = maxSharedBytes >> 10;
+
+ // Populate enough launch parameters to check if a CUDA kernel exists.
+ filtered_lrelu_kernel_params p;
+ p.up = up;
+ p.down = down;
+ p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter.
+ p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0);
+ filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel(p, sharedKB);
+ if (!test_spec.exec)
+ {
+ // No kernel found - return empty tensors and indicate missing kernel with return code of -1.
+ return std::make_tuple(torch::Tensor(), torch::Tensor(), -1);
+ }
+
+ // Input/output element size.
+ int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4;
+
+ // Input sizes.
+ int64_t xw = (int)x.size(3);
+ int64_t xh = (int)x.size(2);
+ int64_t fut_w = (int)fu.size(-1) - 1;
+ int64_t fut_h = (int)fu.size(0) - 1;
+ int64_t fdt_w = (int)fd.size(-1) - 1;
+ int64_t fdt_h = (int)fd.size(0) - 1;
+
+ // Logical size of upsampled buffer.
+ int64_t cw = xw * up + (px0 + px1) - fut_w;
+ int64_t ch = xh * up + (py0 + py1) - fut_h;
+ TORCH_CHECK(cw > fdt_w && ch > fdt_h, "upsampled buffer must be at least the size of downsampling filter");
+ TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large");
+
+ // Compute output size and allocate.
+ int64_t yw = (cw - fdt_w + (down - 1)) / down;
+ int64_t yh = (ch - fdt_h + (down - 1)) / down;
+ TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1");
+ TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large");
+ torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format());
+
+ // Allocate sign tensor.
+ torch::Tensor so;
+ torch::Tensor s = si;
+ bool readSigns = !!s.numel();
+ int64_t sw_active = 0; // Active width of sign tensor.
+ if (writeSigns)
+ {
+ sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements.
+ int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height.
+ int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements, rounded up to multiple of 16.
+ TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large");
+ s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
+ }
+ else if (readSigns)
+ sw_active = s.size(3) << 2;
+
+ // Validate sign tensor if in use.
+ if (readSigns || writeSigns)
+ {
+ TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
+ TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
+ TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
+ TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
+ TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
+ TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, "signs is too large");
+ }
+
+ // Populate rest of CUDA kernel parameters.
+ p.x = x.data_ptr();
+ p.y = y.data_ptr();
+ p.b = b.data_ptr();
+ p.s = (readSigns || writeSigns) ? s.data_ptr() : 0;
+ p.fu = fu.data_ptr();
+ p.fd = fd.data_ptr();
+ p.pad0 = make_int2(px0, py0);
+ p.gain = gain;
+ p.slope = slope;
+ p.clamp = clamp;
+ p.flip = (flip_filters) ? 1 : 0;
+ p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
+ p.yShape = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
+ p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous.
+ p.sOfs = make_int2(sx, sy);
+ p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes.
+
+ // x, y, b strides are in bytes.
+ p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0));
+ p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0));
+ p.bStride = sz * b.stride(0);
+
+ // fu, fd strides are in elements.
+ p.fuStride = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0);
+ p.fdStride = make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0);
+
+ // Determine if indices don't fit in int32. Support negative strides although Torch currently never produces those.
+ bool index64b = false;
+ if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true;
+ if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true;
+ if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) > INT_MAX) index64b = true;
+ if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true;
+ if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) > INT_MAX) index64b = true;
+ if (s.numel() > INT_MAX) index64b = true;
+
+ // Choose CUDA kernel.
+ filtered_lrelu_kernel_spec spec = { 0 };
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_cuda", [&]
+ {
+ if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. constexpr prevents template instantiation.
+ {
+ // Choose kernel based on index type, datatype and sign read/write modes.
+ if (!index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
+ else if (!index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
+ else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
+ else if ( index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
+ else if ( index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
+ else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
+ }
+ });
+ TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found") // This should not happen because we tested earlier that kernel exists.
+
+ // Launch CUDA kernel.
+ void* args[] = {&p};
+ int bx = spec.numWarps * 32;
+ int gx = (p.yShape.x - 1) / spec.tileOut.x + 1;
+ int gy = (p.yShape.y - 1) / spec.tileOut.y + 1;
+ int gz = p.yShape.z * p.yShape.w;
+
+ // Repeat multiple horizontal tiles in a CTA?
+ if (spec.xrep)
+ {
+ p.tilesXrep = spec.xrep;
+ p.tilesXdim = gx;
+
+ gx = (gx + p.tilesXrep - 1) / p.tilesXrep;
+ std::swap(gx, gy);
+ }
+ else
+ {
+ p.tilesXrep = 0;
+ p.tilesXdim = 0;
+ }
+
+ // Launch filter setup kernel.
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream()));
+
+ // Copy kernels to constant memory.
+ if ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream())));
+ else if (!writeSigns && readSigns) AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream())));
+ else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream())));
+
+ // Set cache and shared memory configurations for main kernel.
+ AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared));
+ if (spec.dynamicSharedKB) // Need dynamically allocated shared memory?
+ AT_CUDA_CHECK(cudaFuncSetAttribute(spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10));
+ AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte));
+
+ // Launch main kernel.
+ const int maxSubGz = 65535; // CUDA maximum for block z dimension.
+ for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big.
+ {
+ p.blockZofs = zofs;
+ int subGz = std::min(maxSubGz, gz - zofs);
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream()));
+ }
+
+ // Done.
+ return std::make_tuple(y, so, 0);
+}
+
+//------------------------------------------------------------------------
+
+static torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns)
+{
+ // Set CUDA device.
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+
+ // Validate arguments.
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
+ TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
+ TORCH_CHECK(x.numel() > 0, "x is empty");
+ TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || x.dtype() == torch::kDouble, "x must be float16, float32 or float64");
+
+ // Output signs if we don't have sign input.
+ torch::Tensor so;
+ torch::Tensor s = si;
+ bool readSigns = !!s.numel();
+ if (writeSigns)
+ {
+ int64_t sw = x.size(3);
+ sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing.
+ s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
+ }
+
+ // Validate sign tensor if in use.
+ if (readSigns || writeSigns)
+ {
+ TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
+ TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
+ TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
+ TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
+ TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
+ TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, "signs tensor is too large");
+ }
+
+ // Initialize CUDA kernel parameters.
+ filtered_lrelu_act_kernel_params p;
+ p.x = x.data_ptr();
+ p.s = (readSigns || writeSigns) ? s.data_ptr() : 0;
+ p.gain = gain;
+ p.slope = slope;
+ p.clamp = clamp;
+ p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
+ p.xStride = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0));
+ p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3) << 2, (int)s.size(2)) : make_int2(0, 0); // Width is in elements. Contiguous.
+ p.sOfs = make_int2(sx, sy);
+
+ // Choose CUDA kernel.
+ void* func = 0;
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_act_cuda", [&]
+ {
+ if (writeSigns)
+ func = choose_filtered_lrelu_act_kernel();
+ else if (readSigns)
+ func = choose_filtered_lrelu_act_kernel();
+ else
+ func = choose_filtered_lrelu_act_kernel();
+ });
+ TORCH_CHECK(func, "internal error - CUDA kernel not found");
+
+ // Launch CUDA kernel.
+ void* args[] = {&p};
+ int bx = 128; // 4 warps per block.
+
+ // Logical size of launch = writeSigns ? p.s : p.x
+ uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x;
+ uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y;
+ uint32_t gz = p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use.
+ gx = (gx - 1) / bx + 1;
+
+ // Make sure grid y and z dimensions are within CUDA launch limits. Kernel loops internally to do the rest.
+ const uint32_t gmax = 65535;
+ gy = std::min(gy, gmax);
+ gz = std::min(gz, gmax);
+
+ // Launch.
+ AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, at::cuda::getCurrentCUDAStream()));
+ return so;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("filtered_lrelu", &filtered_lrelu); // The whole thing.
+ m.def("filtered_lrelu_act_", &filtered_lrelu_act); // Activation and sign tensor handling only. Modifies data tensor in-place.
+}
+
+//------------------------------------------------------------------------
diff --git a/diffusion-projected-gan/torch_utils/ops/filtered_lrelu.cu b/diffusion-projected-gan/torch_utils/ops/filtered_lrelu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..8e6f47f873d42f7181a0faf64779377e70be3012
--- /dev/null
+++ b/diffusion-projected-gan/torch_utils/ops/filtered_lrelu.cu
@@ -0,0 +1,1284 @@
+// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include
+#include "filtered_lrelu.h"
+#include
+
+//------------------------------------------------------------------------
+// Helpers.
+
+enum // Filter modes.
+{
+ MODE_SUSD = 0, // Separable upsampling, separable downsampling.
+ MODE_FUSD = 1, // Full upsampling, separable downsampling.
+ MODE_SUFD = 2, // Separable upsampling, full downsampling.
+ MODE_FUFD = 3, // Full upsampling, full downsampling.
+};
+
+template struct InternalType;
+template <> struct InternalType
+{
+ typedef double scalar_t; typedef double2 vec2_t; typedef double4 vec4_t;
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_double2(0, 0); }
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_double4(0, 0, 0, 0); }
+ __device__ __forceinline__ static double clamp(double x, double c) { return fmin(fmax(x, -c), c); }
+};
+template <> struct InternalType
+{
+ typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
+ __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
+};
+template <> struct InternalType
+{
+ typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
+ __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
+};
+
+#define MIN(A, B) ((A) < (B) ? (A) : (B))
+#define MAX(A, B) ((A) > (B) ? (A) : (B))
+#define CEIL_DIV(A, B) (((B)==1) ? (A) : \
+ ((B)==2) ? ((int)((A)+1) >> 1) : \
+ ((B)==4) ? ((int)((A)+3) >> 2) : \
+ (((A) + ((A) > 0 ? (B) - 1 : 0)) / (B)))
+
+// This works only up to blocks of size 256 x 256 and for all N that are powers of two.
+template __device__ __forceinline__ void fast_div_mod(int& x, int& y, unsigned int i)
+{
+ if ((N & (N-1)) && N <= 256)
+ y = (i * ((1<<24)/N + 1)) >> 24; // Assumes N <= 256, i < N*256.
+ else
+ y = i/N;
+
+ x = i - y*N;
+}
+
+// Type cast stride before reading it.
+template __device__ __forceinline__ T get_stride(const int64_t& x)
+{
+ return *reinterpret_cast(&x);
+}
+
+//------------------------------------------------------------------------
+// Filters, setup kernel, copying function.
+
+#define MAX_FILTER_SIZE 32
+
+// Combined up/down filter buffers so that transfer can be done with one copy.
+__device__ float g_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in global memory, written by setup kernel.
+__device__ __constant__ float c_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in constant memory, read by main kernel.
+
+// Accessors to combined buffers to index up/down filters individually.
+#define c_fu (c_fbuf)
+#define c_fd (c_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
+#define g_fu (g_fbuf)
+#define g_fd (g_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
+
+// Set up filters into global memory buffer.
+static __global__ void setup_filters_kernel(filtered_lrelu_kernel_params p)
+{
+ for (int idx = threadIdx.x; idx < MAX_FILTER_SIZE * MAX_FILTER_SIZE; idx += blockDim.x)
+ {
+ int x, y;
+ fast_div_mod(x, y, idx);
+
+ int fu_x = p.flip ? x : (p.fuShape.x - 1 - x);
+ int fu_y = p.flip ? y : (p.fuShape.y - 1 - y);
+ if (p.fuShape.y > 0)
+ g_fu[idx] = (x >= p.fuShape.x || y >= p.fuShape.y) ? 0.0f : p.fu[fu_x * p.fuStride.x + fu_y * p.fuStride.y];
+ else
+ g_fu[idx] = (x >= p.fuShape.x || y > 0) ? 0.0f : p.fu[fu_x * p.fuStride.x];
+
+ int fd_x = p.flip ? x : (p.fdShape.x - 1 - x);
+ int fd_y = p.flip ? y : (p.fdShape.y - 1 - y);
+ if (p.fdShape.y > 0)
+ g_fd[idx] = (x >= p.fdShape.x || y >= p.fdShape.y) ? 0.0f : p.fd[fd_x * p.fdStride.x + fd_y * p.fdStride.y];
+ else
+ g_fd[idx] = (x >= p.fdShape.x || y > 0) ? 0.0f : p.fd[fd_x * p.fdStride.x];
+ }
+}
+
+// Host function to copy filters written by setup kernel into constant buffer for main kernel.
+template static cudaError_t copy_filters(cudaStream_t stream)
+{
+ void* src = 0;
+ cudaError_t err = cudaGetSymbolAddress(&src, g_fbuf);
+ if (err) return err;
+ return cudaMemcpyToSymbolAsync(c_fbuf, src, 2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE * sizeof(float), 0, cudaMemcpyDeviceToDevice, stream);
+}
+
+//------------------------------------------------------------------------
+// Coordinate spaces:
+// - Relative to input tensor: inX, inY, tileInX, tileInY
+// - Relative to input tile: relInX, relInY, tileInW, tileInH
+// - Relative to upsampled tile: relUpX, relUpY, tileUpW, tileUpH
+// - Relative to output tile: relOutX, relOutY, tileOutW, tileOutH
+// - Relative to output tensor: outX, outY, tileOutX, tileOutY
+//
+// Relationships between coordinate spaces:
+// - inX = tileInX + relInX
+// - inY = tileInY + relInY
+// - relUpX = relInX * up + phaseInX
+// - relUpY = relInY * up + phaseInY
+// - relUpX = relOutX * down
+// - relUpY = relOutY * down
+// - outX = tileOutX + relOutX
+// - outY = tileOutY + relOutY
+
+extern __shared__ char s_buf_raw[]; // When sharedKB <= 48, allocate shared memory statically inside the kernel, otherwise use the externally allocated shared memory buffer.
+
+template
+static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p)
+{
+ // Check that we don't try to support non-existing filter modes.
+ static_assert(up == 1 || up == 2 || up == 4, "only up=1, up=2, up=4 scales supported");
+ static_assert(down == 1 || down == 2 || down == 4, "only down=1, down=2, down=4 scales supported");
+ static_assert(fuSize >= up, "upsampling filter size must be at least upsampling factor");
+ static_assert(fdSize >= down, "downsampling filter size must be at least downsampling factor");
+ static_assert(fuSize % up == 0, "upsampling filter size must be divisible with upsampling factor");
+ static_assert(fdSize % down == 0, "downsampling filter size must be divisible with downsampling factor");
+ static_assert(fuSize <= MAX_FILTER_SIZE && fdSize <= MAX_FILTER_SIZE, "filter size greater than MAX_FILTER_SIZE");
+ static_assert(up != 1 || (fuSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "up=1 supported only for 1x1 full filters");
+ static_assert(down != 1 || (fdSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "down=1 supported only for 1x1 full filters");
+ static_assert(!(up == 4 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "full filters not supported for up=4");
+ static_assert(!(down == 4 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "full filters not supported for down=4");
+
+ // Static definitions.
+ typedef typename InternalType::scalar_t scalar_t;
+ typedef typename InternalType::vec2_t vec2_t;
+ typedef typename InternalType::vec4_t vec4_t;
+ const int tileUpW = (tileOutW * down + (fdSize - 1) - (down - 1) + 3) & ~3; // Upsampled tile width, rounded up to multiple of 4.
+ const int tileUpH = tileOutH * down + (fdSize - 1) - (down - 1); // Upsampled tile height.
+ const int tileInW = CEIL_DIV(tileUpW + (fuSize - 1), up); // Input tile width.
+ const int tileInH = CEIL_DIV(tileUpH + (fuSize - 1), up); // Input tile height.
+ const int tileUpH_up = CEIL_DIV(tileUpH, up) * up; // Upsampled tile height rounded up to a multiple of up.
+ const int tileInH_up = CEIL_DIV(tileUpH_up + (fuSize - 1), up); // For allocations only, to avoid shared memory read overruns with up=2 and up=4.
+
+ // Merge 1x1 downsampling into last upsampling step for upf1 and ups2.
+ const bool downInline = (down == 1) && ((up == 1 && filterMode == MODE_FUFD) || (up == 2 && filterMode == MODE_SUFD));
+
+ // Sizes of logical buffers.
+ const int szIn = tileInH_up * tileInW;
+ const int szUpX = tileInH_up * tileUpW;
+ const int szUpXY = downInline ? 0 : (tileUpH * tileUpW);
+ const int szDownX = tileUpH * tileOutW;
+
+ // Sizes for shared memory arrays.
+ const int s_buf0_size_base =
+ (filterMode == MODE_SUSD) ? MAX(szIn, szUpXY) :
+ (filterMode == MODE_FUSD) ? MAX(szIn, szDownX) :
+ (filterMode == MODE_SUFD) ? MAX(szIn, szUpXY) :
+ (filterMode == MODE_FUFD) ? szIn :
+ -1;
+ const int s_buf1_size_base =
+ (filterMode == MODE_SUSD) ? MAX(szUpX, szDownX) :
+ (filterMode == MODE_FUSD) ? szUpXY :
+ (filterMode == MODE_SUFD) ? szUpX :
+ (filterMode == MODE_FUFD) ? szUpXY :
+ -1;
+
+ // Ensure U128 alignment.
+ const int s_buf0_size = (s_buf0_size_base + 3) & ~3;
+ const int s_buf1_size = (s_buf1_size_base + 3) & ~3;
+
+ // Check at compile time that we don't use too much shared memory.
+ static_assert((s_buf0_size + s_buf1_size) * sizeof(scalar_t) <= (sharedKB << 10), "shared memory overflow");
+
+ // Declare shared memory arrays.
+ scalar_t* s_buf0;
+ scalar_t* s_buf1;
+ if (sharedKB <= 48)
+ {
+ // Allocate shared memory arrays here.
+ __shared__ scalar_t s_buf0_st[(sharedKB > 48) ? (1<<24) : (s_buf0_size + s_buf1_size)]; // Prevent launching if this isn't optimized away when unused.
+ s_buf0 = s_buf0_st;
+ s_buf1 = s_buf0 + s_buf0_size;
+ }
+ else
+ {
+ // Use the dynamically allocated shared memory array.
+ s_buf0 = (scalar_t*)s_buf_raw;
+ s_buf1 = s_buf0 + s_buf0_size;
+ }
+
+ // Pointers to the buffers.
+ scalar_t* s_tileIn; // Input tile: [relInX * tileInH + relInY]
+ scalar_t* s_tileUpX; // After horizontal upsampling: [relInY * tileUpW + relUpX]
+ scalar_t* s_tileUpXY; // After upsampling: [relUpY * tileUpW + relUpX]
+ scalar_t* s_tileDownX; // After horizontal downsampling: [relUpY * tileOutW + relOutX]
+ if (filterMode == MODE_SUSD)
+ {
+ s_tileIn = s_buf0;
+ s_tileUpX = s_buf1;
+ s_tileUpXY = s_buf0;
+ s_tileDownX = s_buf1;
+ }
+ else if (filterMode == MODE_FUSD)
+ {
+ s_tileIn = s_buf0;
+ s_tileUpXY = s_buf1;
+ s_tileDownX = s_buf0;
+ }
+ else if (filterMode == MODE_SUFD)
+ {
+ s_tileIn = s_buf0;
+ s_tileUpX = s_buf1;
+ s_tileUpXY = s_buf0;
+ }
+ else if (filterMode == MODE_FUFD)
+ {
+ s_tileIn = s_buf0;
+ s_tileUpXY = s_buf1;
+ }
+
+ // Allow large grids in z direction via per-launch offset.
+ int channelIdx = blockIdx.z + p.blockZofs;
+ int batchIdx = channelIdx / p.yShape.z;
+ channelIdx -= batchIdx * p.yShape.z;
+
+ // Offset to output feature map. In bytes.
+ index_t mapOfsOut = channelIdx * get_stride(p.yStride.z) + batchIdx * get_stride(p.yStride.w);
+
+ // Sign shift amount.
+ uint32_t signXo = ((threadIdx.x + p.sOfs.x) << 1) & 6;
+
+ // Inner tile loop.
+ #pragma unroll 1
+ for (int tileIdx = 0; !enableXrep || (tileIdx < MIN(p.tilesXrep, p.tilesXdim - p.tilesXrep * blockIdx.y)); tileIdx++)
+ {
+ // Locate output tile.
+ int tileX = enableXrep ? blockIdx.y * p.tilesXrep + tileIdx : blockIdx.x;
+ int tileOutX = tileX * tileOutW;
+ int tileOutY = (enableXrep ? blockIdx.x : blockIdx.y) * tileOutH;
+
+ // Locate input tile.
+ int tmpX = tileOutX * down - p.pad0.x;
+ int tmpY = tileOutY * down - p.pad0.y;
+ int tileInX = CEIL_DIV(tmpX, up);
+ int tileInY = CEIL_DIV(tmpY, up);
+ const int phaseInX = tileInX * up - tmpX;
+ const int phaseInY = tileInY * up - tmpY;
+
+ // Extra sync if input and output buffers are the same and we are not on first tile.
+ if (enableXrep && tileIdx > 0 && (filterMode == MODE_FUSD || (filterMode == MODE_SUFD && !downInline) || (filterMode == MODE_FUFD && downInline)))
+ __syncthreads();
+
+ // Load input tile & apply bias. Unrolled.
+ scalar_t b = (scalar_t)*(const T*)((const char*)p.b + (channelIdx * get_stride(p.bStride)));
+ index_t mapOfsIn = channelIdx * get_stride(p.xStride.z) + batchIdx * get_stride(p.xStride.w);
+ int idx = threadIdx.x;
+ const int loopCountIN = CEIL_DIV(tileInW * tileInH, threadsPerBlock);
+ #pragma unroll
+ for (int loop = 0; loop < loopCountIN; loop++)
+ {
+ int relInX, relInY;
+ fast_div_mod(relInX, relInY, idx);
+ int inX = tileInX + relInX;
+ int inY = tileInY + relInY;
+ scalar_t v = 0;
+
+ if ((uint32_t)inX < p.xShape.x && (uint32_t)inY < p.xShape.y)
+ v = (scalar_t)*((const T*)((const char*)p.x + (inX * get_stride(p.xStride.x) + inY * get_stride(p.xStride.y) + mapOfsIn))) + b;
+
+ bool skip = (loop == loopCountIN-1) && (idx >= tileInW * tileInH);
+ if (!skip)
+ s_tileIn[idx] = v;
+
+ idx += threadsPerBlock;
+ }
+
+ if (filterMode == MODE_SUSD || filterMode == MODE_SUFD) // Separable upsampling filter.
+ {
+ // Horizontal upsampling.
+ __syncthreads();
+ if (up == 4)
+ {
+ for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
+ {
+ int relUpX0, relInY;
+ fast_div_mod(relUpX0, relInY, idx);
+ int relInX0 = relUpX0 / up;
+ int src0 = relInX0 + tileInW * relInY;
+ int dst = relInY * tileUpW + relUpX0;
+ vec4_t v = InternalType::zero_vec4();
+ scalar_t a = s_tileIn[src0];
+ if (phaseInX == 0)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 3];
+ v.z += a * (scalar_t)c_fu[step * up + 2];
+ v.w += a * (scalar_t)c_fu[step * up + 1];
+ }
+ }
+ else if (phaseInX == 1)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ v.z += a * (scalar_t)c_fu[step * up + 3];
+ v.w += a * (scalar_t)c_fu[step * up + 2];
+ }
+ }
+ else if (phaseInX == 2)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 2];
+ v.y += a * (scalar_t)c_fu[step * up + 1];
+ v.z += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ v.w += a * (scalar_t)c_fu[step * up + 3];
+ }
+ }
+ else // (phaseInX == 3)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 3];
+ v.y += a * (scalar_t)c_fu[step * up + 2];
+ v.z += a * (scalar_t)c_fu[step * up + 1];
+ v.w += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ }
+ }
+ s_tileUpX[dst+0] = v.x;
+ s_tileUpX[dst+1] = v.y;
+ s_tileUpX[dst+2] = v.z;
+ s_tileUpX[dst+3] = v.w;
+ }
+ }
+ else if (up == 2)
+ {
+ bool p0 = (phaseInX == 0);
+ for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
+ {
+ int relUpX0, relInY;
+ fast_div_mod(relUpX0, relInY, idx);
+ int relInX0 = relUpX0 / up;
+ int src0 = relInX0 + tileInW * relInY;
+ int dst = relInY * tileUpW + relUpX0;
+ vec2_t v = InternalType::zero_vec2();
+ scalar_t a = s_tileIn[src0];
+ if (p0) // (phaseInX == 0)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 1];
+ }
+ }
+ else // (phaseInX == 1)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ }
+ }
+ s_tileUpX[dst+0] = v.x;
+ s_tileUpX[dst+1] = v.y;
+ }
+ }
+
+ // Vertical upsampling & nonlinearity.
+
+ __syncthreads();
+ int groupMask = 15 << ((threadIdx.x & 31) & ~3);
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
+ int sShapeMaxY = MIN(p.sShape.y, tileOutY * down + tileUpH); // Avoid out-of-tile sign writes.
+ if (up == 4)
+ {
+ minY -= 3; // Adjust according to block height.
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
+ {
+ int relUpX, relInY0;
+ fast_div_mod(relUpX, relInY0, idx);
+ int relUpY0 = relInY0 * up;
+ int src0 = relInY0 * tileUpW + relUpX;
+ int dst = relUpY0 * tileUpW + relUpX;
+ vec4_t v = InternalType::zero_vec4();
+
+ scalar_t a = s_tileUpX[src0];
+ if (phaseInY == 0)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ v.y += a * (scalar_t)c_fu[step * up + 3];
+ v.z += a * (scalar_t)c_fu[step * up + 2];
+ v.w += a * (scalar_t)c_fu[step * up + 1];
+ }
+ }
+ else if (phaseInY == 1)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ v.z += a * (scalar_t)c_fu[step * up + 3];
+ v.w += a * (scalar_t)c_fu[step * up + 2];
+ }
+ }
+ else if (phaseInY == 2)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 2];
+ v.y += a * (scalar_t)c_fu[step * up + 1];
+ v.z += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ v.w += a * (scalar_t)c_fu[step * up + 3];
+ }
+ }
+ else // (phaseInY == 3)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 3];
+ v.y += a * (scalar_t)c_fu[step * up + 2];
+ v.z += a * (scalar_t)c_fu[step * up + 1];
+ v.w += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ }
+ }
+
+ int x = tileOutX * down + relUpX;
+ int y = tileOutY * down + relUpY0;
+ int signX = x + p.sOfs.x;
+ int signY = y + p.sOfs.y;
+ int signZ = blockIdx.z + p.blockZofs;
+ int signXb = signX >> 2;
+ index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
+ index_t si1 = si0 + p.sShape.x;
+ index_t si2 = si0 + p.sShape.x * 2;
+ index_t si3 = si0 + p.sShape.x * 3;
+
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
+ v.z *= (scalar_t)((float)up * (float)up * p.gain);
+ v.w *= (scalar_t)((float)up * (float)up * p.gain);
+
+ if (signWrite)
+ {
+ if (!enableWriteSkip)
+ {
+ // Determine and write signs.
+ int sx = __float_as_uint(v.x) >> 31 << 0;
+ int sy = __float_as_uint(v.y) >> 31 << 8;
+ int sz = __float_as_uint(v.z) >> 31 << 16;
+ int sw = __float_as_uint(v.w) >> 31 << 24;
+ if (sx) v.x *= p.slope;
+ if (sy) v.y *= p.slope;
+ if (sz) v.z *= p.slope;
+ if (sw) v.w *= p.slope;
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); }
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); }
+ if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType::clamp(v.z, p.clamp); }
+ if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType::clamp(v.w, p.clamp); }
+
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
+ {
+ // Combine signs.
+ uint32_t s = sx + sy + sw + sz;
+ s <<= (signX & 3) << 1;
+ s |= __shfl_xor_sync(groupMask, s, 1);
+ s |= __shfl_xor_sync(groupMask, s, 2);
+
+ // Write signs.
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
+ if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
+ if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
+ }
+ }
+ else
+ {
+ // Determine and write signs.
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
+ {
+ int sx = __float_as_uint(v.x) >> 31 << 0;
+ int sy = __float_as_uint(v.y) >> 31 << 8;
+ int sz = __float_as_uint(v.z) >> 31 << 16;
+ int sw = __float_as_uint(v.w) >> 31 << 24;
+ if (sx) v.x *= p.slope;
+ if (sy) v.y *= p.slope;
+ if (sz) v.z *= p.slope;
+ if (sw) v.w *= p.slope;
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); }
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); }
+ if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType::clamp(v.z, p.clamp); }
+ if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType::clamp(v.w, p.clamp); }
+
+ // Combine signs.
+ uint32_t s = sx + sy + sw + sz;
+ s <<= (signX & 3) << 1;
+ s |= __shfl_xor_sync(groupMask, s, 1);
+ s |= __shfl_xor_sync(groupMask, s, 2);
+
+ // Write signs.
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
+ if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
+ if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
+ }
+ else
+ {
+ // Just compute the values.
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp);
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp);
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp);
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp);
+ }
+ }
+ }
+ else if (signRead) // Read signs and apply.
+ {
+ if ((uint32_t)signXb < p.swLimit)
+ {
+ int ss = (signX & 3) << 1;
+ if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> ss; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
+ if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> ss; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
+ if ((uint32_t)(signY + 2) < p.sShape.y) { int s = p.s[si2] >> ss; if (s & 1) v.z *= p.slope; if (s & 2) v.z = 0.f; }
+ if ((uint32_t)(signY + 3) < p.sShape.y) { int s = p.s[si3] >> ss; if (s & 1) v.w *= p.slope; if (s & 2) v.w = 0.f; }
+ }
+ }
+ else // Forward pass with no sign write.
+ {
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp);
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp);
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp);
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp);
+ }
+
+ s_tileUpXY[dst + 0 * tileUpW] = v.x;
+ if (relUpY0 + 1 < tileUpH) s_tileUpXY[dst + 1 * tileUpW] = v.y;
+ if (relUpY0 + 2 < tileUpH) s_tileUpXY[dst + 2 * tileUpW] = v.z;
+ if (relUpY0 + 3 < tileUpH) s_tileUpXY[dst + 3 * tileUpW] = v.w;
+ }
+ }
+ else if (up == 2)
+ {
+ minY -= 1; // Adjust according to block height.
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
+ {
+ int relUpX, relInY0;
+ fast_div_mod(relUpX, relInY0, idx);
+ int relUpY0 = relInY0 * up;
+ int src0 = relInY0 * tileUpW + relUpX;
+ int dst = relUpY0 * tileUpW + relUpX;
+ vec2_t v = InternalType::zero_vec2();
+
+ scalar_t a = s_tileUpX[src0];
+ if (phaseInY == 0)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ v.y += a * (scalar_t)c_fu[step * up + 1];
+ }
+ }
+ else // (phaseInY == 1)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ }
+ }
+
+ int x = tileOutX * down + relUpX;
+ int y = tileOutY * down + relUpY0;
+ int signX = x + p.sOfs.x;
+ int signY = y + p.sOfs.y;
+ int signZ = blockIdx.z + p.blockZofs;
+ int signXb = signX >> 2;
+ index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
+ index_t si1 = si0 + p.sShape.x;
+
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
+
+ if (signWrite)
+ {
+ if (!enableWriteSkip)
+ {
+ // Determine and write signs.
+ int sx = __float_as_uint(v.x) >> 31 << 0;
+ int sy = __float_as_uint(v.y) >> 31 << 8;
+ if (sx) v.x *= p.slope;
+ if (sy) v.y *= p.slope;
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType