"""
Title: Data-efficient GANs with Adaptive Discriminator Augmentation
Author: [András Béres](https://www.linkedin.com/in/andras-beres-789190210)
Date created: 2021/10/28
Last modified: 2025/01/23
Description: Generating images from limited data using the Caltech Birds dataset.
Accelerator: GPU
"""

"""
## Introduction

### GANs

[Generative Adversarial Networks (GANs)](https://arxiv.org/abs/1406.2661) are a popular
class of generative deep learning models, commonly used for image generation. They
consist of a pair of dueling neural networks, called the discriminator and the generator.
The discriminator's task is to distinguish real images from generated (fake) ones, while
the generator network tries to fool the discriminator by generating more and more
realistic images. If the discriminator, however, is too easy or too hard to fool, it may
fail to provide a useful learning signal for the generator, which is why training GANs is
usually considered a difficult task.

### Data augmentation for GANs

Data augmentation, a popular technique in deep learning, is the process of randomly
applying semantics-preserving transformations to the input data to generate multiple
realistic versions of it, thereby effectively multiplying the amount of training data
available. The simplest example is left-right flipping an image, which preserves its
contents while generating a second unique training sample. Data augmentation is commonly
used in supervised learning to prevent overfitting and enhance generalization.

The authors of [StyleGAN2-ADA](https://arxiv.org/abs/2006.06676) show that discriminator
overfitting can be an issue in GANs, especially when only a small amount of training data is
available. They propose Adaptive Discriminator Augmentation to mitigate this issue.

Applying data augmentation to GANs however is not straightforward. Since the generator is
updated using the discriminator's gradients, if the generated images are augmented, the
augmentation pipeline has to be differentiable and also has to be GPU-compatible for
computational efficiency. Luckily, the
[Keras image augmentation layers](https://keras.io/api/layers/preprocessing_layers/image_augmentation/)
fulfill both these requirements, and are therefore very well suited for this task.

### Invertible data augmentation

A possible difficulty when using data augmentation in generative models is the issue of
["leaky augmentations" (section 2.2)](https://arxiv.org/abs/2006.06676), namely when the
model generates images that are already augmented. This would mean that it was not able
to separate the augmentation from the underlying data distribution, which can be caused
by using non-invertible data transformations. For example, if either 0, 90, 180 or 270
degree rotations are performed with equal probability, the original orientation of the
images is impossible to infer, and this information is destroyed.

A simple trick to make data augmentations invertible is to only apply them with some
probability. That way the original version of the images will be more common, and the
data distribution can be inferred. By properly choosing this probability, one can
effectively regularize the discriminator without making the augmentations leaky.
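This can be illustrated with a tiny standalone sketch (the numbers are purely
illustrative): if a uniform 0/90/180/270-degree rotation is applied only with probability
`p`, the original orientation remains the single most likely one for any `p < 1`, so it
can still be inferred from the data distribution.

```python
# Orientation distribution after applying a uniform 0/90/180/270-degree
# rotation with probability p (illustrative helper, not part of the
# training code below).
def orientation_distribution(p):
    untouched = 1.0 - p  # image left as-is
    per_rotation = p / 4.0  # each of the four rotations, 0 degrees included
    return {
        0: untouched + per_rotation,
        90: per_rotation,
        180: per_rotation,
        270: per_rotation,
    }


# for any p < 1 the original orientation stays strictly the most likely,
# so the augmentation remains invertible at the distribution level
dist = orientation_distribution(0.8)
assert dist[0] > dist[90]
```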

"""

"""
## Setup
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds

import keras
from keras import ops
from keras import layers

"""
## Hyperparameters
"""

# data
num_epochs = 10  # train for 400 epochs for good results
image_size = 64
# resolution of Kernel Inception Distance measurement, see related section
kid_image_size = 75
padding = 0.25
dataset_name = "caltech_birds2011"

# adaptive discriminator augmentation
max_translation = 0.125
max_rotation = 0.125
max_zoom = 0.25
target_accuracy = 0.85
integration_steps = 1000

# architecture
noise_size = 64
depth = 4
width = 128
leaky_relu_slope = 0.2
dropout_rate = 0.4

# optimization
batch_size = 128
learning_rate = 2e-4
beta_1 = 0.5  # not using the default value of 0.9 is important
ema = 0.99

"""
## Data pipeline

In this example, we will use the
[Caltech Birds (2011)](https://www.tensorflow.org/datasets/catalog/caltech_birds2011) dataset for
generating images of birds, which is a diverse natural dataset containing fewer than 6000
images for training. When working with such a small amount of data, one has to take extra
care to keep the data quality as high as possible. In this example, we use the provided
bounding boxes of the birds to cut them out with square crops while preserving their
aspect ratios when possible.
"""


def round_to_int(float_value):
    return ops.cast(ops.round(float_value), "int32")


def preprocess_image(data):
    # unnormalize bounding box coordinates
    height = ops.cast(ops.shape(data["image"])[0], "float32")
    width = ops.cast(ops.shape(data["image"])[1], "float32")
    bounding_box = data["bbox"] * ops.stack([height, width, height, width])

    # calculate center and length of longer side, add padding
    target_center_y = 0.5 * (bounding_box[0] + bounding_box[2])
    target_center_x = 0.5 * (bounding_box[1] + bounding_box[3])
    target_size = ops.maximum(
        (1.0 + padding) * (bounding_box[2] - bounding_box[0]),
        (1.0 + padding) * (bounding_box[3] - bounding_box[1]),
    )

    # modify crop size to fit into image
    target_height = ops.min(
        [target_size, 2.0 * target_center_y, 2.0 * (height - target_center_y)]
    )
    target_width = ops.min(
        [target_size, 2.0 * target_center_x, 2.0 * (width - target_center_x)]
    )

    # crop image, `ops.image.crop_images` only works with non-tensor croppings
    image = ops.slice(
        data["image"],
        start_indices=(
            round_to_int(target_center_y - 0.5 * target_height),
            round_to_int(target_center_x - 0.5 * target_width),
            0,
        ),
        shape=(round_to_int(target_height), round_to_int(target_width), 3),
    )

    # resize and clip
    image = ops.cast(image, "float32")
    image = ops.image.resize(image, [image_size, image_size])

    return ops.clip(image / 255.0, 0.0, 1.0)


def prepare_dataset(split):
    # the validation dataset is shuffled as well, because data order matters
    # for the KID calculation
    return (
        tfds.load(dataset_name, split=split, shuffle_files=True)
        .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
        .cache()
        .shuffle(10 * batch_size)
        .batch(batch_size, drop_remainder=True)
        .prefetch(buffer_size=tf.data.AUTOTUNE)
    )


train_dataset = prepare_dataset("train")
val_dataset = prepare_dataset("test")

"""
After preprocessing the training images look like the following:
![birds dataset](https://i.imgur.com/Ru5HgBM.png)
"""

"""
## Kernel inception distance

[Kernel Inception Distance (KID)](https://arxiv.org/abs/1801.01401) was proposed as a
replacement for the popular
[Frechet Inception Distance (FID)](https://arxiv.org/abs/1706.08500)
metric for measuring image generation quality.
Both metrics measure the difference in the generated and training distributions in the
representation space of an [InceptionV3](https://keras.io/api/applications/inceptionv3/)
network pretrained on
[ImageNet](https://www.tensorflow.org/datasets/catalog/imagenet2012).

According to the paper, KID was proposed because FID has no unbiased estimator, its
expected value is higher when it is measured on fewer images. KID is more suitable for
small datasets because its expected value does not depend on the number of samples it is
measured on. In my experience it is also computationally lighter, numerically more
stable, and simpler to implement because it can be estimated in a per-batch manner.

In this example, the images are evaluated at the minimal possible resolution of the
Inception network (75x75 instead of 299x299), and the metric is only measured on the
validation set for computational efficiency.
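To make the per-batch estimation concrete, here is a small NumPy sketch of the same
computation; the random 2048-dimensional features are hypothetical stand-ins for
Inception activations, and the function names are my own:

```python
import numpy as np

rng = np.random.default_rng(seed=0)
# random stand-ins for InceptionV3 feature batches (hypothetical data)
features_real = rng.normal(size=(8, 2048))
features_generated = rng.normal(size=(8, 2048))


def polynomial_kernel(a, b):
    return (a @ b.T / a.shape[1] + 1.0) ** 3.0


def kid_estimate(a, b):
    n = a.shape[0]
    kernel_aa = polynomial_kernel(a, a)
    kernel_bb = polynomial_kernel(b, b)
    kernel_ab = polynomial_kernel(a, b)
    # the diagonal (self-similarity) terms are excluded so the squared
    # maximum mean discrepancy estimate is unbiased
    mean_aa = (kernel_aa.sum() - np.trace(kernel_aa)) / (n * (n - 1))
    mean_bb = (kernel_bb.sum() - np.trace(kernel_bb)) / (n * (n - 1))
    return mean_aa + mean_bb - 2.0 * kernel_ab.mean()


# two samples from the same distribution score near zero, while a shifted
# "generated" distribution scores clearly higher
assert abs(kid_estimate(features_real, features_generated)) < 0.1
assert kid_estimate(features_real, features_generated + 1.0) > 1.0
```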


"""


class KID(keras.metrics.Metric):
    def __init__(self, name="kid", **kwargs):
        super().__init__(name=name, **kwargs)

        # KID is estimated per batch and is averaged across batches
        self.kid_tracker = keras.metrics.Mean()

        # a pretrained InceptionV3 is used without its classification layer
        # transform the pixel values to the 0-255 range, then use the same
        # preprocessing as during pretraining
        self.encoder = keras.Sequential(
            [
                layers.InputLayer(input_shape=(image_size, image_size, 3)),
                layers.Rescaling(255.0),
                layers.Resizing(height=kid_image_size, width=kid_image_size),
                layers.Lambda(keras.applications.inception_v3.preprocess_input),
                keras.applications.InceptionV3(
                    include_top=False,
                    input_shape=(kid_image_size, kid_image_size, 3),
                    weights="imagenet",
                ),
                layers.GlobalAveragePooling2D(),
            ],
            name="inception_encoder",
        )

    def polynomial_kernel(self, features_1, features_2):
        feature_dimensions = ops.cast(ops.shape(features_1)[1], "float32")
        return (
            features_1 @ ops.transpose(features_2) / feature_dimensions + 1.0
        ) ** 3.0

    def update_state(self, real_images, generated_images, sample_weight=None):
        real_features = self.encoder(real_images, training=False)
        generated_features = self.encoder(generated_images, training=False)

        # compute polynomial kernels using the two sets of features
        kernel_real = self.polynomial_kernel(real_features, real_features)
        kernel_generated = self.polynomial_kernel(
            generated_features, generated_features
        )
        kernel_cross = self.polynomial_kernel(real_features, generated_features)

        # estimate the squared maximum mean discrepancy using the average kernel values
        batch_size = ops.shape(real_features)[0]
        batch_size_f = ops.cast(batch_size, "float32")
        mean_kernel_real = ops.sum(kernel_real * (1.0 - ops.eye(batch_size))) / (
            batch_size_f * (batch_size_f - 1.0)
        )
        mean_kernel_generated = ops.sum(
            kernel_generated * (1.0 - ops.eye(batch_size))
        ) / (batch_size_f * (batch_size_f - 1.0))
        mean_kernel_cross = ops.mean(kernel_cross)
        kid = mean_kernel_real + mean_kernel_generated - 2.0 * mean_kernel_cross

        # update the average KID estimate
        self.kid_tracker.update_state(kid)

    def result(self):
        return self.kid_tracker.result()

    def reset_state(self):
        self.kid_tracker.reset_state()


"""

## Adaptive discriminator augmentation

The authors of [StyleGAN2-ADA](https://arxiv.org/abs/2006.06676) propose to change the
augmentation probability adaptively during training. Though it is explained differently
in the paper, they use [integral control](https://en.wikipedia.org/wiki/PID_controller#Integral) on the augmentation
probability to keep the discriminator's accuracy on real images close to a target value.
Note that their controlled variable is actually the average sign of the discriminator
logits (r_t in the paper), which corresponds to 2 * accuracy - 1.

This method requires two hyperparameters:

1. `target_accuracy`: the target value for the discriminator's accuracy on real images. I
recommend selecting its value from the 80-90% range.
2. [`integration_steps`](https://en.wikipedia.org/wiki/PID_controller#Mathematical_form):
the number of update steps required for an accuracy error of 100% to transform into an
augmentation probability increase of 100%. To give an intuition, this defines how slowly
the augmentation probability is changed. I recommend setting this to a relatively high
value (1000 in this case) so that the augmentation strength is only adjusted slowly.

The main motivation for this procedure is that the optimal value of the target accuracy
is similar across different dataset sizes (see [figure 4 and 5 in the paper](https://arxiv.org/abs/2006.06676)),
so it does not have to be re-tuned, because the
process automatically applies stronger data augmentation when it is needed.
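To build intuition for these two hyperparameters, the following self-contained sketch
simulates the integral controller; the linear `simulated_accuracy` response is a made-up
assumption standing in for the real discriminator:

```python
# Standalone simulation of the integral controller described above; the
# linear accuracy response is an illustrative stand-in for the discriminator.
target = 0.85  # target_accuracy
steps = 1000  # integration_steps

p = 0.0  # augmentation probability
for _ in range(10_000):
    # assume the real accuracy drops linearly as augmentation gets stronger
    simulated_accuracy = 1.0 - 0.5 * p
    p += (simulated_accuracy - target) / steps
    p = min(max(p, 0.0), 1.0)

# the probability settles where the simulated accuracy matches the target
assert abs((1.0 - 0.5 * p) - target) < 0.01
```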

"""


# "hard sigmoid", useful for binary accuracy calculation from logits
def step(values):
    # negative values -> 0.0, positive values -> 1.0
    return 0.5 * (1.0 + ops.sign(values))


# augments images with a probability that is dynamically updated during training
class AdaptiveAugmenter(keras.Model):
    def __init__(self):
        super().__init__()

        # stores the current probability of an image being augmented
        self.probability = keras.Variable(0.0)
        self.seed_generator = keras.random.SeedGenerator(42)

        # the corresponding augmentation names from the paper are shown above each layer
        # the authors show (see figure 4), that the blitting and geometric augmentations
        # are the most helpful in the low-data regime
        self.augmenter = keras.Sequential(
            [
                layers.InputLayer(input_shape=(image_size, image_size, 3)),
                # blitting/x-flip:
                layers.RandomFlip("horizontal"),
                # blitting/integer translation:
                layers.RandomTranslation(
                    height_factor=max_translation,
                    width_factor=max_translation,
                    interpolation="nearest",
                ),
                # geometric/rotation:
                layers.RandomRotation(factor=max_rotation),
                # geometric/isotropic and anisotropic scaling:
                layers.RandomZoom(
                    height_factor=(-max_zoom, 0.0), width_factor=(-max_zoom, 0.0)
                ),
            ],
            name="adaptive_augmenter",
        )

    def call(self, images, training):
        if training:
            augmented_images = self.augmenter(images, training=training)

            # during training either the original or the augmented images are selected
            # based on self.probability
            augmentation_values = keras.random.uniform(
                shape=(batch_size, 1, 1, 1), seed=self.seed_generator
            )
            augmentation_bools = ops.less(augmentation_values, self.probability)

            images = ops.where(augmentation_bools, augmented_images, images)
        return images

    def update(self, real_logits):
        current_accuracy = ops.mean(step(real_logits))

        # the augmentation probability is updated based on the discriminator's
        # accuracy on real images
        accuracy_error = current_accuracy - target_accuracy
        self.probability.assign(
            ops.clip(self.probability + accuracy_error / integration_steps, 0.0, 1.0)
        )


"""
## Network architecture

Here we specify the architecture of the two networks:

* generator: maps a random vector to an image, which should be as realistic as possible
* discriminator: maps an image to a scalar score, which should be high for real and low
for generated images

GANs tend to be sensitive to the network architecture, so I implemented a DCGAN
architecture in this example because it is relatively stable during training while being
simple to implement. We use a constant number of filters throughout the network, a sigmoid
instead of tanh in the last layer of the generator, and use default initialization
instead of random normal as further simplifications.

As a good practice, we disable the learnable scale parameter in the batch normalization
layers: on one hand, the following ReLU + convolutional layers make it redundant (as noted
in the
[documentation](https://keras.io/api/layers/normalization_layers/batch_normalization/));
on the other hand, theory suggests it should be disabled when using
[spectral normalization (section 4.1)](https://arxiv.org/abs/1802.05957), which is not
used here, but is common in GANs. We also disable the bias in the fully connected and
convolutional layers, because the following batch normalization makes it redundant.
"""


# DCGAN generator
def get_generator():
    noise_input = keras.Input(shape=(noise_size,))
    x = layers.Dense(4 * 4 * width, use_bias=False)(noise_input)
    x = layers.BatchNormalization(scale=False)(x)
    x = layers.ReLU()(x)
    x = layers.Reshape(target_shape=(4, 4, width))(x)
    for _ in range(depth - 1):
        x = layers.Conv2DTranspose(
            width,
            kernel_size=4,
            strides=2,
            padding="same",
            use_bias=False,
        )(x)
        x = layers.BatchNormalization(scale=False)(x)
        x = layers.ReLU()(x)
    image_output = layers.Conv2DTranspose(
        3,
        kernel_size=4,
        strides=2,
        padding="same",
        activation="sigmoid",
    )(x)

    return keras.Model(noise_input, image_output, name="generator")


# DCGAN discriminator
def get_discriminator():
    image_input = keras.Input(shape=(image_size, image_size, 3))
    x = image_input
    for _ in range(depth):
        x = layers.Conv2D(
            width,
            kernel_size=4,
            strides=2,
            padding="same",
            use_bias=False,
        )(x)
        x = layers.BatchNormalization(scale=False)(x)
        x = layers.LeakyReLU(alpha=leaky_relu_slope)(x)
    x = layers.Flatten()(x)
    x = layers.Dropout(dropout_rate)(x)
    output_score = layers.Dense(1)(x)

    return keras.Model(image_input, output_score, name="discriminator")


"""
## GAN model
"""


class GAN_ADA(keras.Model):
    def __init__(self):
        super().__init__()

        self.seed_generator = keras.random.SeedGenerator(seed=42)
        self.augmenter = AdaptiveAugmenter()
        self.generator = get_generator()
        self.ema_generator = keras.models.clone_model(self.generator)
        self.discriminator = get_discriminator()

        self.generator.summary()
        self.discriminator.summary()
        # we have created all layers at this point, so we can mark the model
        # as having been built
        self.built = True

    def compile(self, generator_optimizer, discriminator_optimizer, **kwargs):
        super().compile(**kwargs)

        # separate optimizers for the two networks
        self.generator_optimizer = generator_optimizer
        self.discriminator_optimizer = discriminator_optimizer

        self.generator_loss_tracker = keras.metrics.Mean(name="g_loss")
        self.discriminator_loss_tracker = keras.metrics.Mean(name="d_loss")
        self.real_accuracy = keras.metrics.BinaryAccuracy(name="real_acc")
        self.generated_accuracy = keras.metrics.BinaryAccuracy(name="gen_acc")
        self.augmentation_probability_tracker = keras.metrics.Mean(name="aug_p")
        self.kid = KID()

    @property
    def metrics(self):
        return [
            self.generator_loss_tracker,
            self.discriminator_loss_tracker,
            self.real_accuracy,
            self.generated_accuracy,
            self.augmentation_probability_tracker,
            self.kid,
        ]

    def generate(self, batch_size, training):
        latent_samples = keras.random.normal(
            shape=(batch_size, noise_size), seed=self.seed_generator
        )
        # use ema_generator during inference
        if training:
            generated_images = self.generator(latent_samples, training=training)
        else:
            generated_images = self.ema_generator(latent_samples, training=training)
        return generated_images

    def adversarial_loss(self, real_logits, generated_logits):
        # this is usually called the non-saturating GAN loss

        real_labels = ops.ones(shape=(batch_size, 1))
        generated_labels = ops.zeros(shape=(batch_size, 1))

        # the generator tries to produce images that the discriminator considers as real
        generator_loss = keras.losses.binary_crossentropy(
            real_labels, generated_logits, from_logits=True
        )
        # the discriminator tries to determine if images are real or generated
        discriminator_loss = keras.losses.binary_crossentropy(
            ops.concatenate([real_labels, generated_labels], axis=0),
            ops.concatenate([real_logits, generated_logits], axis=0),
            from_logits=True,
        )

        return ops.mean(generator_loss), ops.mean(discriminator_loss)

    def train_step(self, real_images):
        real_images = self.augmenter(real_images, training=True)

        # use persistent gradient tape because gradients will be calculated twice
        with tf.GradientTape(persistent=True) as tape:
            generated_images = self.generate(batch_size, training=True)
            # gradient is calculated through the image augmentation
            generated_images = self.augmenter(generated_images, training=True)

            # separate forward passes for the real and generated images, meaning
            # that batch normalization is applied separately
            real_logits = self.discriminator(real_images, training=True)
            generated_logits = self.discriminator(generated_images, training=True)

            generator_loss, discriminator_loss = self.adversarial_loss(
                real_logits, generated_logits
            )

        # calculate gradients and update weights
        generator_gradients = tape.gradient(
            generator_loss, self.generator.trainable_weights
        )
        discriminator_gradients = tape.gradient(
            discriminator_loss, self.discriminator.trainable_weights
        )
        self.generator_optimizer.apply_gradients(
            zip(generator_gradients, self.generator.trainable_weights)
        )
        self.discriminator_optimizer.apply_gradients(
            zip(discriminator_gradients, self.discriminator.trainable_weights)
        )

        # update the augmentation probability based on the discriminator's performance
        self.augmenter.update(real_logits)

        self.generator_loss_tracker.update_state(generator_loss)
        self.discriminator_loss_tracker.update_state(discriminator_loss)
        self.real_accuracy.update_state(1.0, step(real_logits))
        self.generated_accuracy.update_state(0.0, step(generated_logits))
        self.augmentation_probability_tracker.update_state(self.augmenter.probability)

        # track the exponential moving average of the generator's weights to decrease
        # variance in the generation quality
        for weight, ema_weight in zip(
            self.generator.weights, self.ema_generator.weights
        ):
            ema_weight.assign(ema * ema_weight + (1 - ema) * weight)

        # KID is not measured during the training phase for computational efficiency
        return {m.name: m.result() for m in self.metrics[:-1]}

    def test_step(self, real_images):
        generated_images = self.generate(batch_size, training=False)

        self.kid.update_state(real_images, generated_images)

        # only KID is measured during the evaluation phase for computational efficiency
        return {self.kid.name: self.kid.result()}

    def plot_images(self, epoch=None, logs=None, num_rows=3, num_cols=6, interval=5):
        # plot random generated images for visual evaluation of generation quality
        if epoch is None or (epoch + 1) % interval == 0:
            num_images = num_rows * num_cols
            generated_images = self.generate(num_images, training=False)

            plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0))
            for row in range(num_rows):
                for col in range(num_cols):
                    index = row * num_cols + col
                    plt.subplot(num_rows, num_cols, index + 1)
                    plt.imshow(generated_images[index])
                    plt.axis("off")
            plt.tight_layout()
            plt.show()
            plt.close()


"""
## Training

One should see from the metrics during training that if the real accuracy
(the discriminator's accuracy on real images) is below the target accuracy, the
augmentation probability is increased, and vice versa. In my experience, during a healthy
GAN training, the discriminator accuracy should stay in the 80-95% range. Below that, the
discriminator is too weak; above that, it is too strong.

Note that we track the exponential moving average of the generator's weights, and use that
for image generation and KID evaluation.
"""

# create and compile the model
model = GAN_ADA()
model.compile(
    generator_optimizer=keras.optimizers.Adam(learning_rate, beta_1),
    discriminator_optimizer=keras.optimizers.Adam(learning_rate, beta_1),
)

# save the best model based on the validation KID metric
checkpoint_path = "gan_model.weights.h5"
checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    monitor="val_kid",
    mode="min",
    save_best_only=True,
)

# run training and plot generated images periodically
model.fit(
    train_dataset,
    epochs=num_epochs,
    validation_data=val_dataset,
    callbacks=[
        keras.callbacks.LambdaCallback(on_epoch_end=model.plot_images),
        checkpoint_callback,
    ],
)

"""
## Inference
"""

# load the best model and generate images
model.load_weights(checkpoint_path)
model.plot_images()

"""
## Results

By running the training for 400 epochs (which takes 2-3 hours in a Colab notebook), one
can get high-quality image generations with this code example.

The evolution of a random batch of images over a 400 epoch training (ema=0.999 for
animation smoothness):
![birds evolution gif](https://i.imgur.com/ecGuCcz.gif)

Latent-space interpolation between a batch of selected images:
![birds interpolation gif](https://i.imgur.com/nGvzlsC.gif)

I also recommend trying out training on other datasets, such as
[CelebA](https://www.tensorflow.org/datasets/catalog/celeb_a) for example. In my
experience good results can be achieved without changing any hyperparameters (though
discriminator augmentation might not be necessary).
"""

"""
## GAN tips and tricks

My goal with this example was to find a good tradeoff between ease of implementation and
generation quality for GANs. During preparation, I have run numerous ablations using
[this repository](https://github.com/beresandras/gan-flavours-keras).

In this section I list the lessons learned and my recommendations in my subjective order
of importance.

I recommend checking out the [DCGAN paper](https://arxiv.org/abs/1511.06434), this
[NeurIPS talk](https://www.youtube.com/watch?v=myGAju4L7O8), and this
[large scale GAN study](https://arxiv.org/abs/1711.10337) for others' takes on this subject.

### Architectural tips

* **resolution**: Training GANs at higher resolutions tends to be more difficult, so I
recommend experimenting at 32x32 or 64x64 resolutions initially.
* **initialization**: If you see strong colorful patterns early on in the training, the
initialization might be the issue. Set the kernel_initializer parameters of layers to
[random normal](https://keras.io/api/layers/initializers/#randomnormal-class), and
decrease the standard deviation (recommended value: 0.02, following DCGAN) until the
issue disappears.
* **upsampling**: There are two main methods for upsampling in the generator.
[Transposed convolution](https://keras.io/api/layers/convolution_layers/convolution2d_transpose/)
is faster, but can lead to
[checkerboard artifacts](https://distill.pub/2016/deconv-checkerboard/), which can be reduced by using
a kernel size that is divisible by the stride (recommended kernel size is 4 for a stride of 2).
[Upsampling](https://keras.io/api/layers/reshaping_layers/up_sampling2d/) +
[standard convolution](https://keras.io/api/layers/convolution_layers/convolution2d/) can have slightly
lower quality, but checkerboard artifacts are not an issue. I recommend using nearest-neighbor
interpolation over bilinear for it.
* **batch normalization in discriminator**: Sometimes has a high impact; I recommend
trying both with and without it.
* **[spectral normalization](https://www.tensorflow.org/addons/api_docs/python/tfa/layers/SpectralNormalization)**:
A popular technique for training GANs, can help with stability. I recommend
disabling batch normalization's learnable scale parameters along with it.
* **[residual connections](https://keras.io/guides/functional_api/#a-toy-resnet-model)**:
While residual discriminators behave similarly to non-residual ones, residual generators
are more difficult to train in my experience. They are, however, necessary for training
large and deep architectures. I recommend starting with non-residual architectures.
* **dropout**: Using dropout before the last layer of the discriminator improves
generation quality in my experience. Recommended dropout rate is below 0.5.
* **[leaky ReLU](https://keras.io/api/layers/activation_layers/leaky_relu/)**: Use leaky
ReLU activations in the discriminator to make its gradients less sparse. Recommended
slope/alpha is 0.2 following DCGAN.
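To make the architectural tips above concrete, here is a minimal sketch of a generator
and discriminator that follow them. The layer widths, depths, and the 32x32 resolution
are illustrative choices, not the exact architecture of this example:

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# random normal initialization with a small standard deviation, following DCGAN
init = keras.initializers.RandomNormal(stddev=0.02)


def build_generator(latent_dim=64, width=128):
    return keras.Sequential(
        [
            keras.Input(shape=(latent_dim,)),
            layers.Dense(8 * 8 * width, kernel_initializer=init),
            layers.Reshape((8, 8, width)),
            # nearest-neighbor upsampling + standard convolution
            # avoids checkerboard artifacts
            layers.UpSampling2D(interpolation="nearest"),
            layers.Conv2D(width, kernel_size=3, padding="same", kernel_initializer=init),
            layers.BatchNormalization(),
            layers.ReLU(),
            layers.UpSampling2D(interpolation="nearest"),
            layers.Conv2D(width, kernel_size=3, padding="same", kernel_initializer=init),
            layers.BatchNormalization(),
            layers.ReLU(),
            layers.Conv2D(
                3, kernel_size=3, padding="same", activation="sigmoid",
                kernel_initializer=init,
            ),
        ],
        name="generator",
    )


def build_discriminator(width=64):
    return keras.Sequential(
        [
            keras.Input(shape=(32, 32, 3)),
            # kernel size 4 is divisible by the stride of 2
            layers.Conv2D(
                width, kernel_size=4, strides=2, padding="same",
                kernel_initializer=init,
            ),
            layers.LeakyReLU(0.2),  # less sparse gradients than ReLU
            layers.Conv2D(
                width, kernel_size=4, strides=2, padding="same",
                kernel_initializer=init,
            ),
            layers.LeakyReLU(0.2),
            layers.Flatten(),
            layers.Dropout(0.4),  # dropout before the last layer
            layers.Dense(1),  # raw logit output
        ],
        name="discriminator",
    )
```

If the initialization tip applies, lowering `stddev` in `init` is the single knob to turn.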

### Algorithmic tips

* **loss functions**: Numerous losses have been proposed over the years for training
GANs, promising improved performance and stability. I have implemented 5 of them in
[this repository](https://github.com/beresandras/gan-flavours-keras), and my experience is in
line with [this GAN study](https://arxiv.org/abs/1711.10337): no loss seems to
consistently outperform the default non-saturating GAN loss. I recommend using that as a
default.
* **Adam's beta_1 parameter**: The beta_1 parameter of Adam can be interpreted as the
momentum of the mean gradient estimate. Using 0.5, or even 0.0, instead of the default
0.9 was proposed in DCGAN and is important: this example would not work with the default
value.
* **separate batch normalization for generated and real images**: The discriminator
should do separate forward passes for the generated and real images, so that batch
normalization statistics are not computed over a mixed batch. Doing otherwise can
lead to artifacts (45-degree stripes in my case) and decreased performance.
* **exponential moving average of generator's weights**: This helps to reduce the
variance of the KID measurement, and helps in averaging out the rapid color palette
changes during training.
* **[different learning rate for generator and discriminator](https://arxiv.org/abs/1706.08500)**:
If one has the resources, it can help
to tune the learning rates of the two networks separately. A similar idea is to update
either network's (usually the discriminator's) weights multiple times for each update of
the other. As a default, I recommend using the same learning rate of 2e-4 (Adam) for both
networks, following DCGAN, and updating each of them once per step.
* **label noise**: [One-sided label smoothing](https://arxiv.org/abs/1606.03498) (using
less than 1.0 for real labels) or adding noise to the labels can regularize the
discriminator and keep it from becoming overconfident; however, in my case they did not
improve performance.
* **adaptive data augmentation**: Since it adds another dynamic component to the training
process, keep it disabled by default, and only enable it once the other components
already work well.
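A few of the algorithmic tips above can be sketched in code. The following is a minimal
illustration (not this example's exact training loop) of the non-saturating loss, the
DCGAN Adam settings, and the exponential moving average update; `ema_generator` is
assumed to be a clone of the generator with weights in the same order:

```python
import tensorflow as tf
from tensorflow import keras

# non-saturating GAN loss, computed on raw discriminator logits
bce = keras.losses.BinaryCrossentropy(from_logits=True)


def generator_loss(generated_logits):
    # the generator maximizes log D(G(z)) instead of minimizing
    # log(1 - D(G(z))), which would saturate early in training
    return bce(tf.ones_like(generated_logits), generated_logits)


def discriminator_loss(real_logits, generated_logits):
    # real_logits and generated_logits should come from two separate
    # forward passes of the discriminator, so that batch normalization
    # statistics are not mixed between real and generated images
    return bce(tf.ones_like(real_logits), real_logits) + bce(
        tf.zeros_like(generated_logits), generated_logits
    )


# same learning rate for both networks, beta_1 = 0.5 following DCGAN
generator_optimizer = keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
discriminator_optimizer = keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)


def update_ema_generator(ema_generator, generator, ema=0.99):
    # exponential moving average of the generator's weights, used for
    # evaluation and image generation instead of the raw weights
    for ema_weight, weight in zip(ema_generator.weights, generator.weights):
        ema_weight.assign(ema * ema_weight + (1.0 - ema) * weight)
```

For one-sided label smoothing, the `tf.ones_like(real_logits)` target in the
discriminator loss would be scaled down (e.g. to 0.9), leaving the fake-label target
at zero.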
"""

"""
## Related works

Other GAN-related Keras code examples:

* [DCGAN + CelebA](https://keras.io/examples/generative/dcgan_overriding_train_step/)
* [WGAN + FashionMNIST](https://keras.io/examples/generative/wgan_gp/)
* [WGAN + Molecules](https://keras.io/examples/generative/wgan-graphs/)
* [ConditionalGAN + MNIST](https://keras.io/examples/generative/conditional_gan/)
* [CycleGAN + Horse2Zebra](https://keras.io/examples/generative/cyclegan/)
* [StyleGAN](https://keras.io/examples/generative/stylegan/)

Modern GAN architecture families:

* [SAGAN](https://arxiv.org/abs/1805.08318), [BigGAN](https://arxiv.org/abs/1809.11096)
* [ProgressiveGAN](https://arxiv.org/abs/1710.10196),
[StyleGAN](https://arxiv.org/abs/1812.04948),
[StyleGAN2](https://arxiv.org/abs/1912.04958),
[StyleGAN2-ADA](https://arxiv.org/abs/2006.06676),
[AliasFreeGAN](https://arxiv.org/abs/2106.12423)

Concurrent papers on discriminator data augmentation:
[1](https://arxiv.org/abs/2006.02595), [2](https://arxiv.org/abs/2006.05338), [3](https://arxiv.org/abs/2006.10738)

Recent literature overview on GANs: [talk](https://www.youtube.com/watch?v=3ktD752xq5k)
"""