KublaiKhan1 commited on
Commit
93a9c9c
·
verified ·
1 Parent(s): 8f715bd

Upload folder using huggingface_hub

Browse files
f16c16/all_stats.py ADDED
@@ -0,0 +1,429 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ try: # For debugging
2
+ from localutils.debugger import enable_debug
3
+ enable_debug()
4
+ except ImportError:
5
+ pass
6
+
7
+
8
+ #import jax
9
+ #jax.config.update('jax_platform_name', 'cpu')
10
+ import os
11
+ #Apparently we've always been running this code on cpu.
12
+
13
+ # os.environ["JAX_PLATFORMS"] = 'cpu'
14
+
15
+ import jax
16
+ import lpips
17
+
18
+ loss_fn_alex = lpips.LPIPS(net='alex') # best forward scores
19
+ loss_fn_alex = loss_fn_alex.cuda()
20
+
21
+
22
+ from dadapy.data import Data
23
+
24
+ import numpy as np
25
+ import flax.linen as nn
26
+ import jax.numpy as jnp
27
+ from absl import app, flags
28
+ from functools import partial
29
+ import numpy as np
30
+ import tqdm
31
+ import flax
32
+ import optax
33
+ import wandb
34
+ from ml_collections import config_flags
35
+ #import elements
36
+ import ml_collections
37
+ import tensorflow_datasets as tfds
38
+ import tensorflow as tf
39
+ tf.config.set_visible_devices([], "GPU")
40
+ tf.config.set_visible_devices([], "TPU")
41
+ import matplotlib.pyplot as plt
42
+ from typing import Any
43
+
44
+ from utils.train_state import TrainState, target_update
45
+ from utils.checkpoint import Checkpoint
46
+ from utils.fid import get_fid_network, fid_from_stats
47
+
48
+ from train import VQGANModel
49
+ from models.vqvae import VQVAE
50
+ from models.discriminator import Discriminator
51
+
52
+ from PIL import Image
53
+ import torch
54
+
55
+ delattr(flags.FLAGS, 'dataset_name')
56
+ delattr(flags.FLAGS, 'load_dir')
57
+ delattr(flags.FLAGS, 'batch_size')
58
+
59
+ FLAGS = flags.FLAGS
60
+ flags.DEFINE_string('dataset_name', 'imagenet256', 'Environment name.')
61
+ flags.DEFINE_string('load_dir', "/home/dkaplan/Downloads/Models/checkpoint(1).tmp", 'Load dir (if not None, load params from here).')
62
+
63
+
64
+ flags.DEFINE_integer('batch_size', 2, 'Total Batch size.')
65
+ # Flags are inherited from train.py, so pass your model parameters again here to evaluate.
66
+
67
+ import gc
68
+
69
+ from scipy.spatial.distance import cdist
70
+ #
71
def relative(images, latents):
    """Measure how well the encoder preserves relative distances ("relativemetry").

    For an encoder F with x' = F(x), checks whether there is a single constant C
    such that |x - y| ≈ C * |x' - y'| for all pairs (x, y) in the batch: it
    computes the ratio of pairwise image distances to pairwise latent distances
    and reports the mean and spread of that ratio. A small std relative to the
    mean indicates the map is close to a scaled isometry.

    Args:
        images:  array of shape (N, ...); flattened to (N, -1) internally.
        latents: array of shape (N, ...) with the same leading N; flattened likewise.

    Returns:
        (mean_c, std_c): mean and standard deviation of the off-diagonal
        distance ratios (also printed, matching the original output).
    """
    images = images.reshape(images.shape[0], -1)
    latents = latents.reshape(latents.shape[0], -1)

    image_distances = cdist(images, images, metric='euclidean')
    latent_distances = cdist(latents, latents, metric='euclidean')

    # Bug fix: the diagonal of both distance matrices is zero, so dividing the
    # full matrices produced 0/0 = NaN on the diagonal and NaN statistics.
    # Restrict the ratio to off-diagonal pairs only.
    off_diag = ~np.eye(image_distances.shape[0], dtype=bool)
    c = image_distances[off_diag] / latent_distances[off_diag]
    mean_c = np.mean(c)
    std_c = np.std(c)
    print("mean C", mean_c)
    print("C std", std_c)
    return mean_c, std_c
87
+
88
+
89
+
90
def operations(reconstructed_images, decoded):
    """Compute a scaled LPIPS distance between two batches of decoded images.

    Both inputs are JAX arrays in [0, 1] with a leading pmap device axis
    (e.g. [1, B, 256, 256, 3]); the result is a scalar torch tensor on CPU.
    Uses the module-level `loss_fn_alex` (AlexNet LPIPS model on CUDA).
    """
    # Map [0, 1] images to [-1, 1], the input range LPIPS expects.
    reconstructed_images = reconstructed_images * 2 - 1
    decoded = decoded * 2 - 1

    # Move from JAX to torch without a host copy, via DLPack.
    reconstructed_images = jax.dlpack.to_dlpack(reconstructed_images)
    reconstructed_images = torch.utils.dlpack.from_dlpack(reconstructed_images)

    decoded = jax.dlpack.to_dlpack(decoded)
    decoded = torch.utils.dlpack.from_dlpack(decoded)

    # Drop the size-1 device axis: [1, B, H, W, C] -> [B, H, W, C].
    reconstructed_images = reconstructed_images.squeeze()
    decoded = decoded.squeeze()

    # NHWC -> NCHW, the layout torch/LPIPS expects.
    reconstructed_images = reconstructed_images.permute(0, 3, 1, 2)
    decoded = decoded.permute(0, 3, 1, 2)

    lpips_loss = loss_fn_alex(reconstructed_images, decoded)
    lpips_cpu = lpips_loss.detach().cpu().squeeze().mean()
    # PPL-style normalization by epsilon^2 with epsilon = 1e-4 — presumably
    # matching the perturbation step used inside the model's PPL methods.
    # NOTE(review): confirm this matches the actual step size in train.py.
    lpips_cpu = lpips_cpu / (.0001 ** 2)

    return lpips_cpu
117
+
118
def main(_):
    """Run a battery of latent-space statistics for a trained VQGAN/VQVAE.

    Loads a checkpoint, then iterates over up to 500 batches of the ImageNet
    validation split and accumulates:
      * LPIPS-based perceptual path length (PPL), one-sided and two-sided,
      * per-channel mean/std of the latents (clean and noise-perturbed),
      * the intrinsic dimension of the latents (dadapy 2NN estimator).
    All results are printed; nothing is returned or written to disk.
    """
    device_count = len(jax.local_devices())
    global_device_count = jax.device_count()
    # Per-process batch size; equals FLAGS.batch_size on a single-host run.
    local_batch_size = FLAGS.batch_size // (global_device_count // device_count)

    def get_dataset(is_train):
        """Return an iterator of image batches in [0, 1], square-cropped and resized."""
        if 'imagenet' in FLAGS.dataset_name:
            def deserialization_fn(data):
                image = data['image']
                # Center-crop to the shorter side before resizing, to preserve aspect.
                min_side = tf.minimum(tf.shape(image)[0], tf.shape(image)[1])
                image = tf.image.resize_with_crop_or_pad(image, min_side, min_side)
                if 'imagenet256' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (256, 256))
                elif 'imagenet128' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (128, 128))
                else:
                    raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
                if is_train:
                    image = tf.image.random_flip_left_right(image)
                image = tf.cast(image, tf.float32) / 255.0
                return image

            split = tfds.split_for_jax_process('train' if is_train else 'validation', drop_remainder=True)
            dataset = tfds.load('imagenet2012', data_dir="/data/inet", split=split)
            dataset = dataset.map(deserialization_fn, num_parallel_calls=tf.data.AUTOTUNE)
            dataset = dataset.shuffle(10000, seed=42, reshuffle_each_iteration=True)
            dataset = dataset.batch(local_batch_size)
            dataset = dataset.prefetch(tf.data.AUTOTUNE)
            dataset = tfds.as_numpy(dataset)
            dataset = iter(dataset)
            return dataset
        else:
            raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")

    dataset = get_dataset(is_train=True)
    dataset_valid = get_dataset(is_train=False)

    # One example batch of size 1, used only to infer shapes and init params.
    example_obs = next(dataset)[:1]

    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, param_key = jax.random.split(rng)
    print("Total devices", jax.local_devices()[0])

    ###################################
    # Creating Model and put on devices.
    ###################################
    FLAGS.model.image_channels = example_obs.shape[-1]
    FLAGS.model.image_size = example_obs.shape[1]
    vqvae_def = VQVAE(FLAGS.model, train=True)
    vqvae_params = vqvae_def.init({'params': param_key, 'noise': param_key}, example_obs)['params']
    # tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    vqvae_ts = TrainState.create(vqvae_def, vqvae_params)  # no tx: evaluation only, no optimizer needed
    vqvae_def_eps = VQVAE(FLAGS.model, train=False)
    vqvae_eps_ts = TrainState.create(vqvae_def_eps, vqvae_params)
    print("Total num of VQVAE parameters:", sum(x.size for x in jax.tree_util.tree_leaves(vqvae_params)))

    discriminator_def = Discriminator(FLAGS.model)
    discriminator_params = discriminator_def.init(param_key, example_obs)['params']
    # tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    discriminator_ts = TrainState.create(discriminator_def, discriminator_params)  # no tx again
    print("Total num of Discriminator parameters:", sum(x.size for x in jax.tree_util.tree_leaves(discriminator_params)))

    model = VQGANModel(rng=rng, vqvae=vqvae_ts, vqvae_eps=vqvae_eps_ts, discriminator=discriminator_ts, config=FLAGS.model)

    assert FLAGS.load_dir is not None
    cp = Checkpoint(FLAGS.load_dir)
    model = cp.load_model(model)
    print("Loaded model with step", model.vqvae.step)

    # Replicate across local devices for pmapped model methods.
    model = flax.jax_utils.replicate(model, devices=jax.local_devices())
    jax.debug.visualize_array_sharding(model.vqvae.params['decoder']['Conv_0']['bias'])

    ####################################
    # Noise stuff
    ###################################

    cpus = jax.devices("cpu")  # NOTE(review): unused below.

    i = 0
    lpips_list = []           # one-sided PPL per batch
    lpips_list_ppl_two = []   # two-sided PPL per batch
    means = []                # per-channel latent means, per batch
    stds = []                 # per-channel latent stds, per batch

    noisy_means = []          # same, for image-space-noised inputs
    noisy_stds = []

    predicted_stds = []       # see NOTE inside the loop

    noisy_predicted_stds = []

    latent_list = []
    # TODO: equivariance loss, DCT stats, PSNR, SSIM, Gini coefficient,
    # density CV, normalized entropy, "uniformity" (spread of points).
    #
    # "Relativemetry": given F that turns x into x' = F(x), check whether
    # |x - y| = C |x' - y'| for all x, y in X — i.e. whether the encoder is a
    # scaled isometry. Unclear if desirable; we compute it anyway.
    #
    # Also: try our own f16c16 (same compression as f8c4) with channel
    # multiplier schedules 1,1,2,2,4 / 1,2,2,4,4 / 1,2,4,8,8 / 1,2,4,4,4.

    for valid_images in dataset_valid:

        # [batch, H, W, C] -> [devices, batch // devices, H, W, C] for pmap.
        valid_images = valid_images.reshape((len(jax.local_devices()), -1, *valid_images.shape[1:]))

        # Regular (one-sided) PPL.
        reconstructed_images, decoded, std, latents = model.reconstruction_ppl(valid_images)  # [devices, B, 256, 256, 3]
        # Reduce over every axis except the channel axis.
        mean = jnp.mean(latents, axis=[0, 1, 2, 3])
        # NOTE(review): this overwrites the `std` the model just returned, so
        # `predicted_stds` below actually records the empirical latent std —
        # the "Average noise added to image" printout may not show what its
        # label claims. Confirm intent.
        std = jnp.std(latents, axis=[0, 1, 2, 3])

        # TODO: maybe move these to CPU to bound device memory.
        latent_list.append(latents)

        means.append(mean)
        stds.append(std)

        predicted_stds.append(std)

        lpips_list.append(operations(reconstructed_images, decoded))

        # PPL two: walk the latent in both directions; only the two decodes matter.
        reconstructed_images, decoded, std, latents, decoded_2 = model.reconstruction_ppl_two(valid_images)  # [devices, B, 256, 256, 3]
        lpips_list_ppl_two.append(operations(decoded, decoded_2))

        # PPL but with the perturbation applied in image space.
        reconstructed_images, decoded, std, latents, std_noisy, latents_noisy = model.reconstruction_ppl_image(valid_images)  # [devices, B, 256, 256, 3]
        noisy_means.append(latents_noisy.mean(axis=[0, 1, 2, 3]))
        noisy_stds.append(latents_noisy.std(axis=[0, 1, 2, 3]))
        noisy_predicted_stds.append(std_noisy)

        # TODO: decide the proper loss for the image-noise variant — one part is
        # the LPIPS difference of the final images, the other is whether latent
        # distances move consistently.

        i += 1
        if i == 500:  # cap evaluation at 500 batches
            break

    mean_lpips = jnp.mean(jnp.asarray(lpips_list))
    std_lpips = jnp.std(jnp.asarray(lpips_list))
    print("PPL Regular", mean_lpips)
    print("C std", std_lpips)

    # Aggregate the per-batch, per-channel stats (shape [num_batches, C]).
    print("mean of means", jnp.asarray(means).mean(axis=[0]))
    print("stds of means", jnp.asarray(means).std(axis=[0]))

    print("mean of stds", jnp.asarray(stds).mean(axis=[0]))
    print("std of stds", jnp.asarray(stds).std(axis=[0]))

    mean_lpips = jnp.mean(jnp.asarray(lpips_list_ppl_two))
    std_lpips = jnp.std(jnp.asarray(lpips_list_ppl_two))

    print("PPL Two", mean_lpips)
    print("C std Two", std_lpips)

    print("noisy mean of means", jnp.asarray(noisy_means).mean(axis=[0]))
    print("noisy stds of means", jnp.asarray(noisy_means).std(axis=[0]))
    print("noisy mean of stds", jnp.asarray(noisy_stds).mean(axis=[0]))
    print("noisy std of stds", jnp.asarray(noisy_stds).std(axis=[0]))

    print("Average noise added to image", jnp.asarray(predicted_stds).mean(axis=[0]))
    print("Average noise added to image std", jnp.asarray(predicted_stds).std(axis=[0]))

    print("Average noise added to noisy image", jnp.asarray(noisy_predicted_stds).mean(axis=[0, 1, 2, 3, 4]))
    print("Average noise added to noisy image std", jnp.asarray(noisy_predicted_stds).std(axis=[0, 1, 2, 3, 4]))

    print("Effective new variance (sqrt it)", jnp.asarray(noisy_predicted_stds).std(axis=[0, 1, 2, 3, 4]) ** 2 + jnp.asarray(stds).mean(axis=[0]) ** 2)

    # Intrinsic dimension of the latents via dadapy's 2NN estimator.
    latent_list = np.asarray(latent_list).squeeze()
    print(latent_list.shape)  # expected roughly [500, batch, 32, 32, 4]
    # NOTE(review): hard-codes a 32x32x4 latent grid — confirm for this model.
    latent_list = latent_list.reshape(-1, 32, 32, 4)
    latent_list = latent_list.reshape(latent_list.shape[0], -1)
    latent_list = Data(latent_list)
    latent_list.compute_distances(maxk=100)

    # compute the intrinsic dimension using the 2NN estimator
    id, id_error, id_distance = latent_list.compute_id_2NN()
    print(id, id_error, id_distance)

    # Reference runs for past models (raw numbers — no normalization applied):
    """PL 100
    PPL Regular 6.3766294
    C std 0.9229477
    mean of means 0.16227543
    stds of means 0.53616405
    mean of stds 4.4914503
    std of stds 0.6015057
    PPL Two 6.3642726
    C std Two 0.92391133
    """

    """1e-4
    PPL Regular 12.521122
    C std 2.3125298
    mean of means 0.0065882676
    stds of means 0.042861093
    mean of stds 0.7608507
    std of stds 0.05846726
    PPL Two 12.581134
    C std Two 2.5102239
    Average noise added to image 0.5992337
    Average noise added to image std 0.25218853
    """

    """1e-5
    PPL Regular 13.183324
    C std 2.9292953
    mean of means 0.0065166513
    stds of means 0.06983645
    mean of stds 0.9855982
    std of stds 0.05810356
    PPL Two 13.193566
    C std Two 2.9465785
    Average noise added to image 0.16906397
    Average noise added to image std 0.12756345
    """

    """1e-6
    PPL Regular 14.146276
    C std 3.6374733
    mean of means -0.018107202
    stds of means 0.11694455
    mean of stds 1.0860059
    std of stds 0.09732369
    PPL Two 14.116948
    C std Two 3.547216
    Average noise added to image 0.039256155
    Average noise added to image std 0.026851926
    """

    """AE
    PPL Regular 10.103417
    C std 2.2966182
    mean of means 0.35234922
    stds of means 0.4036692
    mean of stds 2.6363409
    std of stds 0.30666474
    PPL Two 10.075436
    C std Two 2.2949345
    No noise added to image
    """

    """Dino 1e-5
    PPL Regular 2.373527
    C std 0.45295972
    mean of means 2.5987418
    stds of means 3.097953
    mean of stds 49.437305
    std of stds 2.5111952
    PPL Two 2.3797483
    C std Two 0.49930122
    noisy mean of means 2.598704
    noisy stds of means 3.0979395
    noisy mean of stds 49.437298
    noisy std of stds 2.5112264

    """

    # Intrinsic-dimension reference: id, id_error, id_distance
    #58.344119061134336 0.0 57.78905382129868
427
+
428
+ if __name__ == '__main__':
429
+ app.run(main)
f16c16/decode_only.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ try: # For debugging
2
+ from localutils.debugger import enable_debug
3
+ enable_debug()
4
+ except ImportError:
5
+ pass
6
+
7
+
8
+ #import jax
9
+ #jax.config.update('jax_platform_name', 'cpu')
10
+ import os
11
+ import jax
12
+
13
+ import flax.linen as nn
14
+ import jax.numpy as jnp
15
+ from absl import app, flags
16
+ from functools import partial
17
+ import numpy as np
18
+ import tqdm
19
+ import flax
20
+ import optax
21
+ import wandb
22
+ from ml_collections import config_flags
23
+ #import elements
24
+ import ml_collections
25
+ import tensorflow_datasets as tfds
26
+ import tensorflow as tf
27
+ tf.config.set_visible_devices([], "GPU")
28
+ tf.config.set_visible_devices([], "TPU")
29
+ import matplotlib.pyplot as plt
30
+ from typing import Any
31
+
32
+ from utils.train_state import TrainState, target_update
33
+ from utils.checkpoint import Checkpoint
34
+ from utils.fid import get_fid_network, fid_from_stats
35
+
36
+ from train import VQGANModel
37
+ from models.vqvae import VQVAE
38
+ from models.discriminator import Discriminator
39
+
40
+ from PIL import Image
41
+
42
+
43
+ delattr(flags.FLAGS, 'dataset_name')
44
+ delattr(flags.FLAGS, 'load_dir')
45
+ delattr(flags.FLAGS, 'batch_size')
46
+
47
+ FLAGS = flags.FLAGS
48
+ flags.DEFINE_string('dataset_name', 'imagenet256', 'Environment name.')
49
+ flags.DEFINE_string('load_dir', "/home/dkaplan/Documents/LiClipse Workspace/VAE/jax-vqvae-vqgan/7e-5_sdlike_sym/checkpoint.tmp", 'Load dir (if not None, load params from here).')
50
+ flags.DEFINE_integer('batch_size', 2, 'Total Batch size.')
51
+ # Flags are inherited from train.py, so pass your model parameters again here to evaluate.
52
+
53
def main(_):
    """Build the VQGAN model, load its checkpoint, replicate it, and return it.

    This is a decode-only helper: it never touches the dataset, only
    reconstructs the model from FLAGS and the checkpoint in FLAGS.load_dir.

    Returns:
        The replicated VQGANModel (one copy per local device).
    """
    device_count = len(jax.local_devices())
    global_device_count = jax.device_count()
    local_batch_size = FLAGS.batch_size // (global_device_count // device_count)

    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, param_key = jax.random.split(rng)
    print("Total devices", jax.local_devices()[0])

    ###################################
    # Creating Model and put on devices.
    ###################################
    # Bug fix: `example_obs` was referenced below but never defined in this
    # file (the sibling scripts derive it from the dataset), so init raised a
    # NameError. Since this script never looks at real data, initialize the
    # parameters with a dummy batch shaped from the model config instead.
    # NOTE(review): assumes FLAGS.model.image_size / image_channels are set by
    # the config inherited from train.py — confirm.
    example_obs = jnp.zeros(
        (1, FLAGS.model.image_size, FLAGS.model.image_size, FLAGS.model.image_channels),
        dtype=jnp.float32,
    )

    vqvae_def = VQVAE(FLAGS.model, train=True)
    vqvae_params = vqvae_def.init({'params': param_key, 'noise': param_key}, example_obs)['params']
    tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    vqvae_ts = TrainState.create(vqvae_def, vqvae_params, tx=tx)
    vqvae_def_eps = VQVAE(FLAGS.model, train=False)
    vqvae_eps_ts = TrainState.create(vqvae_def_eps, vqvae_params)
    print("Total num of VQVAE parameters:", sum(x.size for x in jax.tree_util.tree_leaves(vqvae_params)))

    discriminator_def = Discriminator(FLAGS.model)
    discriminator_params = discriminator_def.init(param_key, example_obs)['params']
    tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    discriminator_ts = TrainState.create(discriminator_def, discriminator_params, tx=tx)
    print("Total num of Discriminator parameters:", sum(x.size for x in jax.tree_util.tree_leaves(discriminator_params)))

    model = VQGANModel(rng=rng, vqvae=vqvae_ts, vqvae_eps=vqvae_eps_ts, discriminator=discriminator_ts, config=FLAGS.model)

    assert FLAGS.load_dir is not None
    cp = Checkpoint(FLAGS.load_dir)
    model = cp.load_model(model)
    print("Loaded model with step", model.vqvae.step)

    # Replicate across local devices for pmapped model methods.
    model = flax.jax_utils.replicate(model, devices=jax.local_devices())
    jax.debug.visualize_array_sharding(model.vqvae.params['decoder']['Conv_0']['bias'])

    return model
93
+
94
+ #Stuff and things.
95
+ # image2 = valid_reconstructed_images[0,0,:,:,:]
96
+ # image2 = (image2 * 255).astype(np.uint8)
97
+ # image2 = np.array(image2)
98
+ # image2 = Image.fromarray(image2)
99
+ # image2.save("recon" + str(i) + ".png")
100
+
101
+
102
+
103
+
104
+
105
+ # images.append((valid_reconstructed_images*255).astype(np.uint8))
106
+
107
+ if __name__ == '__main__':
108
+ app.run(main)
f16c16/encode_latents.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ try: # For debugging
2
+ from localutils.debugger import enable_debug
3
+ enable_debug()
4
+ except ImportError:
5
+ pass
6
+
7
+ #GPU, batch 16, latent:
8
+ """[[[[-9.51360688e-02 -6.00612536e-02 -6.76547512e-02 -3.73330832e-01]
9
+ [-3.10049266e-01 -6.82027787e-02 1.09544434e-01 -1.51526511e-01]
10
+ [-1.63606599e-01 1.52324408e-01 1.03230253e-01 -3.34064662e-01]
11
+ ...
12
+ [-9.08230543e-02 2.53294855e-01 6.09488077e-02 -3.55355501e-01]
13
+ [-2.16098756e-01 -3.44716787e-01 5.68981618e-02 -1.19108176e+00]
14
+ [ 9.24487635e-02 2.20324457e-01 1.84478119e-01 4.46850598e-01]]
15
+
16
+ [[-1.60119295e-01 2.00234763e-02 -1.43943653e-01 -2.22745568e-01]
17
+ [-2.55345762e-01 1.55626327e-01 4.85354941e-03 -1.33636221e-01]
18
+ [-1.64813206e-01 1.63652197e-01 -6.96032941e-02 -3.96138221e-01]
19
+ ...
20
+ [-1.74221992e-01 2.78679162e-01 -1.02342315e-01 -4.71356630e-01]
21
+ [-9.72934887e-02 2.24700689e-01 -1.54692575e-01 -8.07371676e-01]
22
+ [ 1.58384442e-02 9.63119492e-02 4.84653771e-01 8.73409092e-01]]
23
+
24
+ [[-1.16939977e-01 2.56956398e-01 -1.04373530e-01 -1.33346528e-01]
25
+ [-1.52860105e-01 1.76005200e-01 -1.16914781e-02 -1.92210004e-01]
26
+ [-5.50103635e-02 2.04600886e-01 -1.73305750e-01 -4.94984031e-01]
27
+ ...
28
+ [-3.88413459e-01 3.15461606e-01 -1.25539899e-01 -5.62439263e-01]
29
+ [-1.97147772e-01 -2.31708195e-02 -1.44041494e-01 -8.99005592e-01]
30
+ [ 3.42922032e-01 2.24075779e-01 4.25257713e-01 5.85853398e-01]]
31
+ """
32
+
33
+
34
+ #CPU, batch 16, latent
35
+
36
+ """
37
+ [[[[-8.47917721e-02 -8.92071351e-02 -1.05532585e-02 -3.59174877e-01]
38
+ [-1.11725748e-01 -1.22415572e-01 3.33435684e-02 -3.60438257e-01]
39
+ [-1.36060238e-01 -1.37327328e-01 3.79590057e-02 -3.73947173e-01]
40
+ ...
41
+ [ 7.88694695e-02 -5.03079742e-02 6.75498620e-02 -3.39441150e-01]
42
+ [-1.63178548e-01 -3.21848512e-01 1.72039792e-02 -9.50528085e-01]
43
+ [ 2.21429523e-02 1.48582339e-01 1.54685006e-01 6.86266243e-01]]
44
+
45
+ [[-1.69139117e-01 7.81316869e-03 4.33448888e-02 -3.37453634e-01]
46
+ [-1.96011692e-01 -4.98509258e-02 3.32896858e-02 -3.53303224e-01]
47
+ [-9.82111022e-02 -1.94629002e-02 -1.63653865e-02 -3.32124978e-01]
48
+ ...
49
+ [-7.72062615e-02 2.95878220e-02 -7.62912910e-03 -3.61496925e-01]
50
+ [-2.26189673e-01 -5.97889721e-02 -1.16483821e-02 -7.82557964e-01]
51
+ [-6.18810430e-02 7.75512159e-02 2.37205133e-01 8.39313030e-01]]
52
+
53
+ [[-9.37198251e-02 -4.58365604e-02 -2.44572274e-02 -3.00568134e-01]
54
+ [-1.32911175e-01 -9.60890502e-02 -4.78822738e-04 -3.28105956e-01]
55
+ [-7.67295957e-02 -6.57245517e-02 -3.78448963e-02 -3.29079330e-01]
56
+ ...
57
+ [-1.21173687e-01 4.07976359e-02 4.05129045e-02 -3.48512828e-01]
58
+ [-1.64501339e-01 -9.52737629e-02 -1.06653105e-03 -8.39630961e-01]
59
+ [ 2.64041096e-01 2.43525319e-02 3.05205405e-01 4.92310941e-01]]
60
+ """
61
+
62
+ #CPU, 8 vs GPU 8
63
+ """
64
+ [[[[[-3.18646997e-01 -4.77920741e-01 1.07763827e+00 1.70530510e+00]
65
+ [-6.31720126e-01 -2.49106735e-01 1.66874206e+00 -5.45821428e-01]
66
+ [-4.03593808e-01 2.76418477e-01 1.29216135e+00 8.79887521e-01]
67
+ ...
68
+ [-2.03093603e-01 -7.97204554e-01 3.61778885e-01 -3.68656218e-01]
69
+ [-2.61139393e-01 1.64036989e+00 -2.22024798e-01 3.49313989e-02]
70
+ [ 6.32668972e-01 -4.74448204e-01 1.55093277e+00 5.57837903e-01]]
71
+
72
+ [[-7.24952042e-01 4.80744302e-01 3.05105478e-01 1.06132841e+00]
73
+ [ 8.95307362e-02 1.45687327e-01 1.57945228e+00 -1.11452961e+00]
74
+ [-4.61988777e-01 -4.11880344e-01 1.70428991e+00 4.31171536e-01]
75
+ ...
76
+ [-1.17851949e+00 2.03509808e-01 1.84925032e+00 -5.68852723e-01]
77
+ [ 5.74628949e-01 -8.48990500e-01 -2.50778824e-01 1.92248678e+00]
78
+ [-2.69778688e-02 -8.46022546e-01 -7.89667487e-01 9.26319182e-01]]
79
+
80
+ [[-3.10738117e-01 6.01165593e-02 1.57032907e-01 1.53192639e+00]
81
+ [ 6.55903339e-01 7.50707746e-01 6.03949744e-03 1.31769347e+00]
82
+ [ 3.26834202e-01 -2.33611539e-01 1.35725603e-01 -2.39371091e-01]
83
+ ...
84
+ [ 2.19290599e-01 -2.21653271e+00 -2.21055865e+00 1.49363160e+00]
85
+ [-1.45460200e+00 1.18737824e-01 1.56015289e+00 8.23014230e-03]
86
+ [ 3.44308168e-01 1.08958745e+00 -1.23330317e-01 5.41093886e-01]]
87
+
88
+ #GPU
89
+
90
+
91
+ """
92
+
93
+ #import jax
94
+ #jax.config.update('jax_platform_name', 'cpu')
95
+ import os
96
+
97
+ # os.environ["JAX_PLATFORMS"] = 'cpu'
98
+
99
+ import jax
100
+ import lpips
101
+
102
+ loss_fn_alex = lpips.LPIPS(net='alex') # best forward scores
103
+ loss_fn_alex = loss_fn_alex.cuda()
104
+
105
+
106
+ import numpy as np
107
+ import flax.linen as nn
108
+ import jax.numpy as jnp
109
+ from absl import app, flags
110
+ from functools import partial
111
+ import numpy as np
112
+ import tqdm
113
+ import flax
114
+ import optax
115
+ import wandb
116
+ from ml_collections import config_flags
117
+ #import elements
118
+ import ml_collections
119
+ import tensorflow_datasets as tfds
120
+ import tensorflow as tf
121
+ tf.config.set_visible_devices([], "GPU")
122
+ tf.config.set_visible_devices([], "TPU")
123
+ import matplotlib.pyplot as plt
124
+ from typing import Any
125
+
126
+ from utils.train_state import TrainState, target_update
127
+ from utils.checkpoint import Checkpoint
128
+ from utils.fid import get_fid_network, fid_from_stats
129
+
130
+ from train import VQGANModel
131
+ from models.vqvae import VQVAE
132
+ from models.discriminator import Discriminator
133
+
134
+ from PIL import Image
135
+ import torch
136
+
137
+ delattr(flags.FLAGS, 'dataset_name')
138
+ delattr(flags.FLAGS, 'load_dir')
139
+ delattr(flags.FLAGS, 'batch_size')
140
+
141
+ FLAGS = flags.FLAGS
142
+ flags.DEFINE_string('dataset_name', 'imagenet256', 'Environment name.')
143
+ flags.DEFINE_string('load_dir', "/home/dkaplan/Downloads/Models/checkpoint(1).tmp", 'Load dir (if not None, load params from here).')
144
+
145
+ from safetensors.torch import save_file
146
+
147
+ flags.DEFINE_integer('batch_size', 8, 'Total Batch size.')
148
+ # Flags are inhereited from train.py, so pass your model parameters again here to evaluate.
149
+
150
+ import gc
151
+
152
def main(_):
    """Encode the ImageNet training split into VQVAE latents and shard to disk.

    For every training image, computes the latent of the image and of its
    horizontal flip, accumulates (latents, latents_flip, labels), and every
    5000 batches writes one .safetensors shard to /data/inet_latents.
    """
    device_count = len(jax.local_devices())
    global_device_count = jax.device_count()
    # Per-process batch size; equals FLAGS.batch_size on a single-host run.
    local_batch_size = FLAGS.batch_size // (global_device_count // device_count)

    def get_dataset(is_train):
        """Iterator yielding (image, image_flip, label) for train, (image, label) for validation."""
        if 'imagenet' in FLAGS.dataset_name:
            def deserialization_fn(data):
                image = data['image']
                label = data["label"]
                # Center-crop to the shorter side before resizing, to preserve aspect.
                min_side = tf.minimum(tf.shape(image)[0], tf.shape(image)[1])
                image = tf.image.resize_with_crop_or_pad(image, min_side, min_side)
                if 'imagenet256' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (256, 256))
                elif 'imagenet128' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (128, 128))
                else:
                    raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
                if is_train:
                    # Deterministic flip (not random_flip_left_right): we want
                    # BOTH orientations of every image so the flipped latent
                    # can be stored alongside the original.
                    # image = tf.image.random_flip_left_right(image)
                    image_flip = tf.image.flip_left_right(image)
                    image_flip = tf.cast(image_flip, tf.float32) / 255.0
                    image = tf.cast(image, tf.float32) / 255.0
                    return image, image_flip, label

                image = tf.cast(image, tf.float32) / 255.0
                return image, label

            split = tfds.split_for_jax_process('train' if is_train else 'validation', drop_remainder=True)
            dataset = tfds.load('imagenet2012', data_dir="/data/inet", split=split)
            dataset = dataset.map(deserialization_fn, num_parallel_calls=tf.data.AUTOTUNE)
            dataset = dataset.shuffle(10000, seed=42, reshuffle_each_iteration=True)
            dataset = dataset.batch(local_batch_size)
            dataset = dataset.prefetch(tf.data.AUTOTUNE)
            dataset = tfds.as_numpy(dataset)
            dataset = iter(dataset)
            return dataset
        else:
            raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")

    dataset = get_dataset(is_train=True)
    dataset_valid = get_dataset(is_train=False)  # NOTE(review): unused below.

    # (Disused experiment: loading "osman.png" and reconstructing it.)
    # image = Image.open("osman.png")
    # image = np.array(image) / 255.0
    # image = jnp.expand_dims(jnp.expand_dims(jnp.array(image), 0), 0)

    # Train batches are (image, image_flip, label) tuples; [:1] keeps the
    # first element of the tuple and [0] unwraps it -> just the image batch,
    # used only to infer shapes and init params.
    example_obs = next(dataset)[:1][0]

    # (Disused reconstruction loop for the experiment above.)
    # image = model.reconstruction(image)
    # ... save as "osman<i>.png"

    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, param_key = jax.random.split(rng)
    print("Total devices", jax.local_devices()[0])

    ###################################
    # Creating Model and put on devices.
    ###################################
    FLAGS.model.image_channels = example_obs.shape[-1]
    FLAGS.model.image_size = example_obs.shape[1]
    vqvae_def = VQVAE(FLAGS.model, train=True)
    vqvae_params = vqvae_def.init({'params': param_key, 'noise': param_key}, example_obs)['params']
    # tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    vqvae_ts = TrainState.create(vqvae_def, vqvae_params)  # no tx: encoding only, no optimizer needed
    vqvae_def_eps = VQVAE(FLAGS.model, train=False)
    vqvae_eps_ts = TrainState.create(vqvae_def_eps, vqvae_params)
    print("Total num of VQVAE parameters:", sum(x.size for x in jax.tree_util.tree_leaves(vqvae_params)))

    discriminator_def = Discriminator(FLAGS.model)
    discriminator_params = discriminator_def.init(param_key, example_obs)['params']
    # tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    discriminator_ts = TrainState.create(discriminator_def, discriminator_params)  # no tx again
    print("Total num of Discriminator parameters:", sum(x.size for x in jax.tree_util.tree_leaves(discriminator_params)))

    model = VQGANModel(rng=rng, vqvae=vqvae_ts, vqvae_eps=vqvae_eps_ts, discriminator=discriminator_ts, config=FLAGS.model)

    assert FLAGS.load_dir is not None
    cp = Checkpoint(FLAGS.load_dir)
    model = cp.load_model(model)
    print("Loaded model with step", model.vqvae.step)

    # Replicate across local devices for pmapped model methods.
    model = flax.jax_utils.replicate(model, devices=jax.local_devices())
    jax.debug.visualize_array_sharding(model.vqvae.params['decoder']['Conv_0']['bias'])

    latents = []
    latents_flip = []
    labels = []
    saved_files = 0
    for image, image_flip, label in dataset:
        # Encode the image and its horizontal flip.

        # [batch, H, W, C] -> [devices, batch // devices, H, W, C] for pmap.
        image = image.reshape((len(jax.local_devices()), -1, *image.shape[1:]))
        latent, result_dict = model.get_latent(image)

        image_flip = image_flip.reshape((len(jax.local_devices()), -1, *image_flip.shape[1:]))
        latent_flip, result_dict_flip = model.get_latent(image_flip)

        latents.append(latent.squeeze())
        latents_flip.append(latent_flip.squeeze())
        labels.append(label)

        # Flush one shard every 5000 batches (5000 * local_batch_size images).
        # NOTE(review): the original comment said "bs 2, should be 5k" but the
        # flag default is 8 — confirm the intended shard size.
        if len(latents) == 5000:

            latents = jnp.concatenate(latents, axis=0)
            latents_flip = jnp.concatenate(latents_flip, axis=0)
            labels = jnp.concatenate(labels, axis=0)

            # JAX -> numpy -> torch (copy to get a writable host buffer).
            latents_torch = np.asarray(latents)
            latents_torch = torch.from_numpy(np.copy(latents_torch))

            latents_flip_torch = np.asarray(latents_flip)
            latents_flip_torch = torch.from_numpy(np.copy(latents_flip_torch))

            labels_torch = np.asarray(labels)
            labels_torch = torch.from_numpy(np.copy(labels_torch))

            save_dict = {
                'latents': latents_torch,
                'latents_flip': latents_flip_torch,
                'labels': labels_torch
            }

            print(latents_torch.shape)       # [shard, h, w, c] — e.g. 32x32x4 latent grid
            print(latents_flip_torch.shape)  # same as above
            print(labels_torch.shape)        # [shard]

            # (Disused: per-class latent mean computation.)
            # print("Total mean", latents_torch.mean(axis = [0]))
            # class_means = {}
            # for label, tensor in zip(labels_torch, latents_torch):
            #     label = str(label.item())
            #     if label in class_means.keys():
            #         class_means[label].append(tensor)
            #     else:
            #         class_means[label] = [tensor]
            #
            # for iclass in class_means.keys():
            #     stacked_tensors = torch.stack(class_means[iclass])
            #     mean = stacked_tensors.mean(axis = [0])
            #     print(mean)
            #     print(iclass)
            # exit()

            output_dir = "/data/inet_latents"
            save_filename = os.path.join(output_dir, f'latents_shard{saved_files:03d}.safetensors')
            save_file(
                save_dict,
                save_filename,
                metadata={'total_size': f'{latents_torch.shape[0]}', 'dtype': f'{latents_torch.dtype}', 'device': f'{latents_torch.device}'}
            )

            # Reset the accumulators for the next shard.
            latents = []
            latents_flip = []
            labels = []
            saved_files += 1
            # Let's just run the kl2 first and not save the extra

    # print(latent.shape)
    # print(result_dict)
336
+
337
+ if __name__ == '__main__':
338
+ app.run(main)
f16c16/eval_fid.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ try: # For debugging
2
+ from localutils.debugger import enable_debug
3
+ enable_debug()
4
+ except ImportError:
5
+ pass
6
+
7
+ import flax.linen as nn
8
+ import jax.numpy as jnp
9
+ from absl import app, flags
10
+ from functools import partial
11
+ import numpy as np
12
+ import tqdm
13
+ import jax
14
+ import jax.numpy as jnp
15
+ import flax
16
+ import optax
17
+ import wandb
18
+ from ml_collections import config_flags
19
+ #import elements
20
+ import ml_collections
21
+ import tensorflow_datasets as tfds
22
+ import tensorflow as tf
23
+ tf.config.set_visible_devices([], "GPU")
24
+ tf.config.set_visible_devices([], "TPU")
25
+ import matplotlib.pyplot as plt
26
+ from typing import Any
27
+
28
+ from utils.train_state import TrainState, target_update
29
+ from utils.checkpoint import Checkpoint
30
+ from utils.fid import get_fid_network, fid_from_stats
31
+
32
+ from train import VQGANModel
33
+ from models.vqvae import VQVAE
34
+ from models.discriminator import Discriminator
35
+
36
+ delattr(flags.FLAGS, 'dataset_name')
37
+ delattr(flags.FLAGS, 'load_dir')
38
+ delattr(flags.FLAGS, 'batch_size')
39
+
40
+ FLAGS = flags.FLAGS
41
+ flags.DEFINE_string('dataset_name', 'imagenet256', 'Environment name.')
42
+ flags.DEFINE_string('load_dir', "./checkpointbest.tmp.tmp", 'Load dir (if not None, load params from here).')
43
+ flags.DEFINE_integer('batch_size', 128, 'Total Batch size.')
44
+ # Flags are inhereited from train.py, so pass your model parameters again here to evaluate.
45
+
46
+ def main(_):
47
+ device_count = len(jax.local_devices())
48
+ global_device_count = jax.device_count()
49
+ local_batch_size = FLAGS.batch_size // (global_device_count // device_count)
50
+
51
+ def get_dataset(is_train):
52
+ if 'imagenet' in FLAGS.dataset_name:
53
+ def deserialization_fn(data):
54
+ image = data['image']
55
+ min_side = tf.minimum(tf.shape(image)[0], tf.shape(image)[1])
56
+ image = tf.image.resize_with_crop_or_pad(image, min_side, min_side)
57
+ if 'imagenet256' in FLAGS.dataset_name:
58
+ image = tf.image.resize(image, (256, 256))
59
+ elif 'imagenet128' in FLAGS.dataset_name:
60
+ image = tf.image.resize(image, (128, 128))
61
+ else:
62
+ raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
63
+ if is_train:
64
+ image = tf.image.random_flip_left_right(image)
65
+ image = tf.cast(image, tf.float32) / 255.0
66
+ return image
67
+
68
+ split = tfds.split_for_jax_process('train' if is_train else 'validation', drop_remainder=True)
69
+ dataset = tfds.load('imagenet2012', data_dir="/dev/shm", split=split)
70
+ dataset = dataset.map(deserialization_fn, num_parallel_calls=tf.data.AUTOTUNE)
71
+ dataset = dataset.shuffle(10000, seed=42, reshuffle_each_iteration=True)
72
+ dataset = dataset.batch(local_batch_size)
73
+ dataset = dataset.prefetch(tf.data.AUTOTUNE)
74
+ dataset = tfds.as_numpy(dataset)
75
+ dataset = iter(dataset)
76
+ return dataset
77
+ else:
78
+ raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
79
+
80
+ dataset = get_dataset(is_train=False)
81
+ dataset_valid = get_dataset(is_train=False)
82
+ example_obs = next(dataset)[:1]
83
+
84
+
85
+ get_fid_activations = get_fid_network()
86
+ truth_fid_stats = np.load('data/imagenet256_fidstats_openai.npz')
87
+ # truth_fid_stats = np.load('base_stats.npz')
88
+
89
+ rng = jax.random.PRNGKey(FLAGS.seed)
90
+ rng, param_key = jax.random.split(rng)
91
+ print("Total Memory on device:", float(jax.local_devices()[0].memory_stats()['bytes_limit']) / 1024**3, "GB")
92
+
93
+ ###################################
94
+ # Creating Model and put on devices.
95
+ ###################################
96
+ FLAGS.model.image_channels = example_obs.shape[-1]
97
+ FLAGS.model.image_size = example_obs.shape[1]
98
+ vqvae_def = VQVAE(FLAGS.model, train=True)
99
+ vqvae_params = vqvae_def.init({'params': param_key, 'noise': param_key}, example_obs)['params']
100
+ tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
101
+ vqvae_ts = TrainState.create(vqvae_def, vqvae_params, tx=tx)
102
+ vqvae_def_eps = VQVAE(FLAGS.model, train=False)
103
+ vqvae_eps_ts = TrainState.create(vqvae_def_eps, vqvae_params)
104
+ print("Total num of VQVAE parameters:", sum(x.size for x in jax.tree_util.tree_leaves(vqvae_params)))
105
+
106
+ discriminator_def = Discriminator(FLAGS.model)
107
+ discriminator_params = discriminator_def.init(param_key, example_obs)['params']
108
+ tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
109
+ discriminator_ts = TrainState.create(discriminator_def, discriminator_params, tx=tx)
110
+ print("Total num of Discriminator parameters:", sum(x.size for x in jax.tree_util.tree_leaves(discriminator_params)))
111
+
112
+ model = VQGANModel(rng=rng, vqvae=vqvae_ts, vqvae_eps=vqvae_eps_ts, discriminator=discriminator_ts, config=FLAGS.model)
113
+
114
+ assert FLAGS.load_dir is not None
115
+ cp = Checkpoint(FLAGS.load_dir)
116
+ model = cp.load_model(model)
117
+ print("Loaded model with step", model.vqvae.step)
118
+
119
+ model = flax.jax_utils.replicate(model, devices=jax.local_devices())
120
+ jax.debug.visualize_array_sharding(model.vqvae.params['decoder']['Conv_0']['bias'])
121
+ #print(model.vqvae)
122
+
123
+
124
+ ###################################
125
+ # FID Evaluation.
126
+ ###################################
127
+
128
+ activations = []
129
+ activations_base = []
130
+
131
+ images = []
132
+ images_original = []
133
+ for valid_images in dataset_valid:
134
+
135
+ images_original.append((valid_images*255).astype(np.uint8))
136
+ if valid_images.shape[0] < local_batch_size:
137
+ zeros_added = local_batch_size - valid_images.shape[0]
138
+ valid_images = np.concatenate([valid_images, np.zeros((local_batch_size - valid_images.shape[0], *valid_images.shape[1:]))], axis=0)
139
+ else:
140
+ zeros_added = 0
141
+
142
+ print(len(jax.local_devices()))
143
+ print(valid_images.shape)
144
+ valid_images = valid_images.reshape((len(jax.local_devices()), -1, *valid_images.shape[1:])) # [devices, batch//devices, etc..]
145
+ print(valid_images.shape)
146
+ valid_reconstructed_images = model.reconstruction(valid_images) # [devices, 8, 256, 256, 3]
147
+ print(valid_reconstructed_images.shape)
148
+
149
+ #Whatever...
150
+ fig, axs = plt.subplots(2, 8, figsize=(30, 15))
151
+
152
+ for j in range(1):#fuck it
153
+ continue#Turn this off for now
154
+ axs[0, j].imshow(valid_images[j, 0], vmin=0, vmax=1)
155
+ axs[1, j].imshow(valid_reconstructed_images[j, 0], vmin=0, vmax=1)
156
+ #wandb.log({'reconstruction': wandb.Image(fig)}, step=i)
157
+
158
+
159
+ #We are not sure if we are 0-1 or if we are -1 to 1
160
+ #Let's try both
161
+ add_images = valid_reconstructed_images.reshape(-1,256,256,3)
162
+ if zeros_added > 0:
163
+ add_images = add_images[:-zeros_added, :, :, :]
164
+ images.append((add_images*255).astype(np.uint8))
165
+
166
+ #valid = (valid_reconstructed_images + 1 ) * 127.5
167
+ #images2.append(valid.clamp(0,255).astype(npuint8))
168
+
169
+ valid_reconstructed_images = jax.image.resize(valid_reconstructed_images, (valid_images.shape[0], valid_images.shape[1], 299, 299, 3),
170
+ method='bilinear', antialias=True)
171
+ valid_reconstructed_images = 2 * valid_reconstructed_images - 1
172
+ acts = np.array(get_fid_activations(valid_reconstructed_images))[..., 0, 0, :]
173
+ if zeros_added > 0:
174
+ acts = acts[:-zeros_added]
175
+ activations.append(acts)
176
+
177
+ #Used to grab baseline truths
178
+ if False:
179
+ valid_reconstructed_images = jax.image.resize(valid_images, (valid_images.shape[0], valid_images.shape[1], 299, 299, 3),
180
+ method='bilinear', antialias=True)
181
+ valid_reconstructed_images = 2 * valid_reconstructed_images - 1
182
+ acts = np.array(get_fid_activations(valid_reconstructed_images))[..., 0, 0, :]
183
+
184
+ if zeros_added > 0:
185
+ acts = acts[:-zeros_added]
186
+ activations_base.append(acts)
187
+ #This is fine because it's just length
188
+ print(len(activations) * FLAGS.batch_size)
189
+
190
+ images = np.concatenate(images, axis = 0)
191
+ #images_original = np.concatenate(images_original, axis = 0)
192
+ print(images.shape)#1564x32x256x256x3 #Old shape
193
+ #print(images_original.shape)
194
+ #new shape should just be 50k
195
+ #Reshape
196
+ images = images.reshape(-1, 256, 256, 3)
197
+ #images2 = images_original.reshape(-1,256,256,3)
198
+
199
+ activations = np.concatenate(activations, axis=0)
200
+ activations = activations.reshape((-1, activations.shape[-1]))
201
+ mu1 = np.mean(activations, axis=0)
202
+ sigma1 = np.cov(activations, rowvar=False)
203
+ #print(mu1)
204
+ #print(sigma1)
205
+ fid = fid_from_stats(mu1, sigma1, truth_fid_stats['mu'], truth_fid_stats['sigma'])
206
+
207
+ print("FID:", fid)
208
+
209
+ np.savez("./images_recon.npz", arr_0 = images)
210
+ #np.savez("./images_original.npz", arr_0 = images2)
211
+
212
+
213
+ if __name__ == '__main__':
214
+ app.run(main)
f16c16/evaluator.py ADDED
@@ -0,0 +1,654 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import io
3
+ import os
4
+ import random
5
+ import warnings
6
+ import zipfile
7
+ from abc import ABC, abstractmethod
8
+ from contextlib import contextmanager
9
+ from functools import partial
10
+ from multiprocessing import cpu_count
11
+ from multiprocessing.pool import ThreadPool
12
+ from typing import Iterable, Optional, Tuple
13
+
14
+ import numpy as np
15
+ import requests
16
+ import tensorflow.compat.v1 as tf
17
+ from scipy import linalg
18
+ from tqdm.auto import tqdm
19
+
20
+ INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb"
21
+ INCEPTION_V3_PATH = "classify_image_graph_def.pb"
22
+
23
+ FID_POOL_NAME = "pool_3:0"
24
+ FID_SPATIAL_NAME = "mixed_6/conv:0"
25
+
26
+
27
+ def main():
28
+ parser = argparse.ArgumentParser()
29
+ parser.add_argument("ref_batch", help="path to reference batch npz file")
30
+ parser.add_argument("sample_batch", help="path to sample batch npz file")
31
+ args = parser.parse_args()
32
+
33
+ config = tf.ConfigProto(
34
+ allow_soft_placement=True # allows DecodeJpeg to run on CPU in Inception graph
35
+ )
36
+ config.gpu_options.allow_growth = True
37
+ evaluator = Evaluator(tf.Session(config=config))
38
+
39
+ print("warming up TensorFlow...")
40
+ # This will cause TF to print a bunch of verbose stuff now rather
41
+ # than after the next print(), to help prevent confusion.
42
+ evaluator.warmup()
43
+
44
+ print("computing reference batch activations...")
45
+ ref_acts = evaluator.read_activations(args.ref_batch)
46
+ print("computing/reading reference batch statistics...")
47
+ ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)
48
+
49
+ print("computing sample batch activations...")
50
+ sample_acts = evaluator.read_activations(args.sample_batch)
51
+ print("computing/reading sample batch statistics...")
52
+ sample_stats, sample_stats_spatial = evaluator.read_statistics(args.sample_batch, sample_acts)
53
+
54
+ print("Computing evaluations...")
55
+ print("Inception Score:", evaluator.compute_inception_score(sample_acts[0]))
56
+ print("FID:", sample_stats.frechet_distance(ref_stats))
57
+ print("sFID:", sample_stats_spatial.frechet_distance(ref_stats_spatial))
58
+ prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
59
+ print("Precision:", prec)
60
+ print("Recall:", recall)
61
+
62
+
63
+ class InvalidFIDException(Exception):
64
+ pass
65
+
66
+
67
+ class FIDStatistics:
68
+ def __init__(self, mu: np.ndarray, sigma: np.ndarray):
69
+ self.mu = mu
70
+ self.sigma = sigma
71
+
72
+ def frechet_distance(self, other, eps=1e-6):
73
+ """
74
+ Compute the Frechet distance between two sets of statistics.
75
+ """
76
+ # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132
77
+ mu1, sigma1 = self.mu, self.sigma
78
+ mu2, sigma2 = other.mu, other.sigma
79
+
80
+ mu1 = np.atleast_1d(mu1)
81
+ mu2 = np.atleast_1d(mu2)
82
+
83
+ sigma1 = np.atleast_2d(sigma1)
84
+ sigma2 = np.atleast_2d(sigma2)
85
+
86
+ assert (
87
+ mu1.shape == mu2.shape
88
+ ), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}"
89
+ assert (
90
+ sigma1.shape == sigma2.shape
91
+ ), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}"
92
+
93
+ diff = mu1 - mu2
94
+
95
+ # product might be almost singular
96
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
97
+ if not np.isfinite(covmean).all():
98
+ msg = (
99
+ "fid calculation produces singular product; adding %s to diagonal of cov estimates"
100
+ % eps
101
+ )
102
+ warnings.warn(msg)
103
+ offset = np.eye(sigma1.shape[0]) * eps
104
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
105
+
106
+ # numerical error might give slight imaginary component
107
+ if np.iscomplexobj(covmean):
108
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
109
+ m = np.max(np.abs(covmean.imag))
110
+ raise ValueError("Imaginary component {}".format(m))
111
+ covmean = covmean.real
112
+
113
+ tr_covmean = np.trace(covmean)
114
+
115
+ return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
116
+
117
+
118
+ class Evaluator:
119
+ def __init__(
120
+ self,
121
+ session,
122
+ batch_size=64,
123
+ softmax_batch_size=512,
124
+ ):
125
+ self.sess = session
126
+ self.batch_size = batch_size
127
+ self.softmax_batch_size = softmax_batch_size
128
+ self.manifold_estimator = ManifoldEstimator(session)
129
+ with self.sess.graph.as_default():
130
+ self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
131
+ self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048])
132
+ self.pool_features, self.spatial_features = _create_feature_graph(self.image_input)
133
+ self.softmax = _create_softmax_graph(self.softmax_input)
134
+
135
+ def warmup(self):
136
+ self.compute_activations(np.zeros([1, 8, 64, 64, 3]))
137
+
138
+ def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]:
139
+ with open_npz_array(npz_path, "arr_0") as reader:
140
+ return self.compute_activations(reader.read_batches(self.batch_size))
141
+
142
+ def compute_activations(self, batches: Iterable[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
143
+ """
144
+ Compute image features for downstream evals.
145
+
146
+ :param batches: a iterator over NHWC numpy arrays in [0, 255].
147
+ :return: a tuple of numpy arrays of shape [N x X], where X is a feature
148
+ dimension. The tuple is (pool_3, spatial).
149
+ """
150
+ preds = []
151
+ spatial_preds = []
152
+ for batch in tqdm(batches):
153
+ batch = batch.astype(np.float32)
154
+ pred, spatial_pred = self.sess.run(
155
+ [self.pool_features, self.spatial_features], {self.image_input: batch}
156
+ )
157
+ preds.append(pred.reshape([pred.shape[0], -1]))
158
+ spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))
159
+ return (
160
+ np.concatenate(preds, axis=0),
161
+ np.concatenate(spatial_preds, axis=0),
162
+ )
163
+
164
+ def read_statistics(
165
+ self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray]
166
+ ) -> Tuple[FIDStatistics, FIDStatistics]:
167
+ obj = np.load(npz_path)
168
+ if "mu" in list(obj.keys()):
169
+ return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics(
170
+ obj["mu_s"], obj["sigma_s"]
171
+ )
172
+ return tuple(self.compute_statistics(x) for x in activations)
173
+
174
+ def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
175
+ mu = np.mean(activations, axis=0)
176
+ sigma = np.cov(activations, rowvar=False)
177
+ return FIDStatistics(mu, sigma)
178
+
179
+ def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float:
180
+ softmax_out = []
181
+ for i in range(0, len(activations), self.softmax_batch_size):
182
+ acts = activations[i : i + self.softmax_batch_size]
183
+ softmax_out.append(self.sess.run(self.softmax, feed_dict={self.softmax_input: acts}))
184
+ preds = np.concatenate(softmax_out, axis=0)
185
+ # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
186
+ scores = []
187
+ for i in range(0, len(preds), split_size):
188
+ part = preds[i : i + split_size]
189
+ kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
190
+ kl = np.mean(np.sum(kl, 1))
191
+ scores.append(np.exp(kl))
192
+ return float(np.mean(scores))
193
+
194
+ def compute_prec_recall(
195
+ self, activations_ref: np.ndarray, activations_sample: np.ndarray
196
+ ) -> Tuple[float, float]:
197
+ radii_1 = self.manifold_estimator.manifold_radii(activations_ref)
198
+ radii_2 = self.manifold_estimator.manifold_radii(activations_sample)
199
+ pr = self.manifold_estimator.evaluate_pr(
200
+ activations_ref, radii_1, activations_sample, radii_2
201
+ )
202
+ return (float(pr[0][0]), float(pr[1][0]))
203
+
204
+
205
+ class ManifoldEstimator:
206
+ """
207
+ A helper for comparing manifolds of feature vectors.
208
+
209
+ Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57
210
+ """
211
+
212
+ def __init__(
213
+ self,
214
+ session,
215
+ row_batch_size=10000,
216
+ col_batch_size=10000,
217
+ nhood_sizes=(3,),
218
+ clamp_to_percentile=None,
219
+ eps=1e-5,
220
+ ):
221
+ """
222
+ Estimate the manifold of given feature vectors.
223
+
224
+ :param session: the TensorFlow session.
225
+ :param row_batch_size: row batch size to compute pairwise distances
226
+ (parameter to trade-off between memory usage and performance).
227
+ :param col_batch_size: column batch size to compute pairwise distances.
228
+ :param nhood_sizes: number of neighbors used to estimate the manifold.
229
+ :param clamp_to_percentile: prune hyperspheres that have radius larger than
230
+ the given percentile.
231
+ :param eps: small number for numerical stability.
232
+ """
233
+ self.distance_block = DistanceBlock(session)
234
+ self.row_batch_size = row_batch_size
235
+ self.col_batch_size = col_batch_size
236
+ self.nhood_sizes = nhood_sizes
237
+ self.num_nhoods = len(nhood_sizes)
238
+ self.clamp_to_percentile = clamp_to_percentile
239
+ self.eps = eps
240
+
241
+ def warmup(self):
242
+ feats, radii = (
243
+ np.zeros([1, 2048], dtype=np.float32),
244
+ np.zeros([1, 1], dtype=np.float32),
245
+ )
246
+ self.evaluate_pr(feats, radii, feats, radii)
247
+
248
+ def manifold_radii(self, features: np.ndarray) -> np.ndarray:
249
+ num_images = len(features)
250
+
251
+ # Estimate manifold of features by calculating distances to k-NN of each sample.
252
+ radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
253
+ distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)
254
+ seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)
255
+
256
+ for begin1 in range(0, num_images, self.row_batch_size):
257
+ end1 = min(begin1 + self.row_batch_size, num_images)
258
+ row_batch = features[begin1:end1]
259
+
260
+ for begin2 in range(0, num_images, self.col_batch_size):
261
+ end2 = min(begin2 + self.col_batch_size, num_images)
262
+ col_batch = features[begin2:end2]
263
+
264
+ # Compute distances between batches.
265
+ distance_batch[
266
+ 0 : end1 - begin1, begin2:end2
267
+ ] = self.distance_block.pairwise_distances(row_batch, col_batch)
268
+
269
+ # Find the k-nearest neighbor from the current batch.
270
+ radii[begin1:end1, :] = np.concatenate(
271
+ [
272
+ x[:, self.nhood_sizes]
273
+ for x in _numpy_partition(distance_batch[0 : end1 - begin1, :], seq, axis=1)
274
+ ],
275
+ axis=0,
276
+ )
277
+
278
+ if self.clamp_to_percentile is not None:
279
+ max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)
280
+ radii[radii > max_distances] = 0
281
+ return radii
282
+
283
+ def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray):
284
+ """
285
+ Evaluate if new feature vectors are at the manifold.
286
+ """
287
+ num_eval_images = eval_features.shape[0]
288
+ num_ref_images = radii.shape[0]
289
+ distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32)
290
+ batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
291
+ max_realism_score = np.zeros([num_eval_images], dtype=np.float32)
292
+ nearest_indices = np.zeros([num_eval_images], dtype=np.int32)
293
+
294
+ for begin1 in range(0, num_eval_images, self.row_batch_size):
295
+ end1 = min(begin1 + self.row_batch_size, num_eval_images)
296
+ feature_batch = eval_features[begin1:end1]
297
+
298
+ for begin2 in range(0, num_ref_images, self.col_batch_size):
299
+ end2 = min(begin2 + self.col_batch_size, num_ref_images)
300
+ ref_batch = features[begin2:end2]
301
+
302
+ distance_batch[
303
+ 0 : end1 - begin1, begin2:end2
304
+ ] = self.distance_block.pairwise_distances(feature_batch, ref_batch)
305
+
306
+ # From the minibatch of new feature vectors, determine if they are in the estimated manifold.
307
+ # If a feature vector is inside a hypersphere of some reference sample, then
308
+ # the new sample lies at the estimated manifold.
309
+ # The radii of the hyperspheres are determined from distances of neighborhood size k.
310
+ samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii
311
+ batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)
312
+
313
+ max_realism_score[begin1:end1] = np.max(
314
+ radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1
315
+ )
316
+ nearest_indices[begin1:end1] = np.argmin(distance_batch[0 : end1 - begin1, :], axis=1)
317
+
318
+ return {
319
+ "fraction": float(np.mean(batch_predictions)),
320
+ "batch_predictions": batch_predictions,
321
+ "max_realisim_score": max_realism_score,
322
+ "nearest_indices": nearest_indices,
323
+ }
324
+
325
+ def evaluate_pr(
326
+ self,
327
+ features_1: np.ndarray,
328
+ radii_1: np.ndarray,
329
+ features_2: np.ndarray,
330
+ radii_2: np.ndarray,
331
+ ) -> Tuple[np.ndarray, np.ndarray]:
332
+ """
333
+ Evaluate precision and recall efficiently.
334
+
335
+ :param features_1: [N1 x D] feature vectors for reference batch.
336
+ :param radii_1: [N1 x K1] radii for reference vectors.
337
+ :param features_2: [N2 x D] feature vectors for the other batch.
338
+ :param radii_2: [N x K2] radii for other vectors.
339
+ :return: a tuple of arrays for (precision, recall):
340
+ - precision: an np.ndarray of length K1
341
+ - recall: an np.ndarray of length K2
342
+ """
343
+ features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=np.bool)
344
+ features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=np.bool)
345
+ for begin_1 in range(0, len(features_1), self.row_batch_size):
346
+ end_1 = begin_1 + self.row_batch_size
347
+ batch_1 = features_1[begin_1:end_1]
348
+ for begin_2 in range(0, len(features_2), self.col_batch_size):
349
+ end_2 = begin_2 + self.col_batch_size
350
+ batch_2 = features_2[begin_2:end_2]
351
+ batch_1_in, batch_2_in = self.distance_block.less_thans(
352
+ batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]
353
+ )
354
+ features_1_status[begin_1:end_1] |= batch_1_in
355
+ features_2_status[begin_2:end_2] |= batch_2_in
356
+ return (
357
+ np.mean(features_2_status.astype(np.float64), axis=0),
358
+ np.mean(features_1_status.astype(np.float64), axis=0),
359
+ )
360
+
361
+
362
+ class DistanceBlock:
363
+ """
364
+ Calculate pairwise distances between vectors.
365
+
366
+ Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34
367
+ """
368
+
369
+ def __init__(self, session):
370
+ self.session = session
371
+
372
+ # Initialize TF graph to calculate pairwise distances.
373
+ with session.graph.as_default():
374
+ self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])
375
+ self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])
376
+ distance_block_16 = _batch_pairwise_distances(
377
+ tf.cast(self._features_batch1, tf.float16),
378
+ tf.cast(self._features_batch2, tf.float16),
379
+ )
380
+ self.distance_block = tf.cond(
381
+ tf.reduce_all(tf.math.is_finite(distance_block_16)),
382
+ lambda: tf.cast(distance_block_16, tf.float32),
383
+ lambda: _batch_pairwise_distances(self._features_batch1, self._features_batch2),
384
+ )
385
+
386
+ # Extra logic for less thans.
387
+ self._radii1 = tf.placeholder(tf.float32, shape=[None, None])
388
+ self._radii2 = tf.placeholder(tf.float32, shape=[None, None])
389
+ dist32 = tf.cast(self.distance_block, tf.float32)[..., None]
390
+ self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1)
391
+ self._batch_2_in = tf.math.reduce_any(dist32 <= self._radii1[:, None], axis=0)
392
+
393
+ def pairwise_distances(self, U, V):
394
+ """
395
+ Evaluate pairwise distances between two batches of feature vectors.
396
+ """
397
+ return self.session.run(
398
+ self.distance_block,
399
+ feed_dict={self._features_batch1: U, self._features_batch2: V},
400
+ )
401
+
402
+ def less_thans(self, batch_1, radii_1, batch_2, radii_2):
403
+ return self.session.run(
404
+ [self._batch_1_in, self._batch_2_in],
405
+ feed_dict={
406
+ self._features_batch1: batch_1,
407
+ self._features_batch2: batch_2,
408
+ self._radii1: radii_1,
409
+ self._radii2: radii_2,
410
+ },
411
+ )
412
+
413
+
414
+ def _batch_pairwise_distances(U, V):
415
+ """
416
+ Compute pairwise distances between two batches of feature vectors.
417
+ """
418
+ with tf.variable_scope("pairwise_dist_block"):
419
+ # Squared norms of each row in U and V.
420
+ norm_u = tf.reduce_sum(tf.square(U), 1)
421
+ norm_v = tf.reduce_sum(tf.square(V), 1)
422
+
423
+ # norm_u as a column and norm_v as a row vectors.
424
+ norm_u = tf.reshape(norm_u, [-1, 1])
425
+ norm_v = tf.reshape(norm_v, [1, -1])
426
+
427
+ # Pairwise squared Euclidean distances.
428
+ D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0)
429
+
430
+ return D
431
+
432
+
433
+ class NpzArrayReader(ABC):
434
+ @abstractmethod
435
+ def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
436
+ pass
437
+
438
+ @abstractmethod
439
+ def remaining(self) -> int:
440
+ pass
441
+
442
+ def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
443
+ def gen_fn():
444
+ while True:
445
+ batch = self.read_batch(batch_size)
446
+ if batch is None:
447
+ break
448
+ yield batch
449
+
450
+ rem = self.remaining()
451
+ num_batches = rem // batch_size + int(rem % batch_size != 0)
452
+ return BatchIterator(gen_fn, num_batches)
453
+
454
+
455
+ class BatchIterator:
456
+ def __init__(self, gen_fn, length):
457
+ self.gen_fn = gen_fn
458
+ self.length = length
459
+
460
+ def __len__(self):
461
+ return self.length
462
+
463
+ def __iter__(self):
464
+ return self.gen_fn()
465
+
466
+
467
+ class StreamingNpzArrayReader(NpzArrayReader):
468
+ def __init__(self, arr_f, shape, dtype):
469
+ self.arr_f = arr_f
470
+ self.shape = shape
471
+ self.dtype = dtype
472
+ self.idx = 0
473
+
474
+ def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
475
+ if self.idx >= self.shape[0]:
476
+ return None
477
+
478
+ bs = min(batch_size, self.shape[0] - self.idx)
479
+ self.idx += bs
480
+
481
+ if self.dtype.itemsize == 0:
482
+ return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)
483
+
484
+ read_count = bs * np.prod(self.shape[1:])
485
+ read_size = int(read_count * self.dtype.itemsize)
486
+ data = _read_bytes(self.arr_f, read_size, "array data")
487
+ return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]])
488
+
489
+ def remaining(self) -> int:
490
+ return max(0, self.shape[0] - self.idx)
491
+
492
+
493
+ class MemoryNpzArrayReader(NpzArrayReader):
494
+ def __init__(self, arr):
495
+ self.arr = arr
496
+ self.idx = 0
497
+
498
+ @classmethod
499
+ def load(cls, path: str, arr_name: str):
500
+ with open(path, "rb") as f:
501
+ arr = np.load(f)[arr_name]
502
+ return cls(arr)
503
+
504
+ def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
505
+ if self.idx >= self.arr.shape[0]:
506
+ return None
507
+
508
+ res = self.arr[self.idx : self.idx + batch_size]
509
+ self.idx += batch_size
510
+ return res
511
+
512
+ def remaining(self) -> int:
513
+ return max(0, self.arr.shape[0] - self.idx)
514
+
515
+
516
+ @contextmanager
517
+ def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
518
+ with _open_npy_file(path, arr_name) as arr_f:
519
+ version = np.lib.format.read_magic(arr_f)
520
+ if version == (1, 0):
521
+ header = np.lib.format.read_array_header_1_0(arr_f)
522
+ elif version == (2, 0):
523
+ header = np.lib.format.read_array_header_2_0(arr_f)
524
+ else:
525
+ yield MemoryNpzArrayReader.load(path, arr_name)
526
+ return
527
+ shape, fortran, dtype = header
528
+ if fortran or dtype.hasobject:
529
+ yield MemoryNpzArrayReader.load(path, arr_name)
530
+ else:
531
+ yield StreamingNpzArrayReader(arr_f, shape, dtype)
532
+
533
+
534
+ def _read_bytes(fp, size, error_template="ran out of data"):
535
+ """
536
+ Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886
537
+
538
+ Read from file-like object until size bytes are read.
539
+ Raises ValueError if not EOF is encountered before size bytes are read.
540
+ Non-blocking objects only supported if they derive from io objects.
541
+ Required as e.g. ZipExtFile in python 2.6 can return less data than
542
+ requested.
543
+ """
544
+ data = bytes()
545
+ while True:
546
+ # io files (default in python3) return None or raise on
547
+ # would-block, python2 file will truncate, probably nothing can be
548
+ # done about that. note that regular files can't be non-blocking
549
+ try:
550
+ r = fp.read(size - len(data))
551
+ data += r
552
+ if len(r) == 0 or len(data) == size:
553
+ break
554
+ except io.BlockingIOError:
555
+ pass
556
+ if len(data) != size:
557
+ msg = "EOF: reading %s, expected %d bytes got %d"
558
+ raise ValueError(msg % (error_template, size, len(data)))
559
+ else:
560
+ return data
561
+
562
+
563
+ @contextmanager
564
+ def _open_npy_file(path: str, arr_name: str):
565
+ with open(path, "rb") as f:
566
+ with zipfile.ZipFile(f, "r") as zip_f:
567
+ if f"{arr_name}.npy" not in zip_f.namelist():
568
+ raise ValueError(f"missing {arr_name} in npz file")
569
+ with zip_f.open(f"{arr_name}.npy", "r") as arr_f:
570
+ yield arr_f
571
+
572
+
573
+ def _download_inception_model():
574
+ if os.path.exists(INCEPTION_V3_PATH):
575
+ return
576
+ print("downloading InceptionV3 model...")
577
+ with requests.get(INCEPTION_V3_URL, stream=True) as r:
578
+ r.raise_for_status()
579
+ tmp_path = INCEPTION_V3_PATH + ".tmp"
580
+ with open(tmp_path, "wb") as f:
581
+ for chunk in tqdm(r.iter_content(chunk_size=8192)):
582
+ f.write(chunk)
583
+ os.rename(tmp_path, INCEPTION_V3_PATH)
584
+
585
+
586
+ def _create_feature_graph(input_batch):
587
+ _download_inception_model()
588
+ prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
589
+ with open(INCEPTION_V3_PATH, "rb") as f:
590
+ graph_def = tf.GraphDef()
591
+ graph_def.ParseFromString(f.read())
592
+ pool3, spatial = tf.import_graph_def(
593
+ graph_def,
594
+ input_map={f"ExpandDims:0": input_batch},
595
+ return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME],
596
+ name=prefix,
597
+ )
598
+ _update_shapes(pool3)
599
+ spatial = spatial[..., :7]
600
+ return pool3, spatial
601
+
602
+
603
+ def _create_softmax_graph(input_batch):
604
+ _download_inception_model()
605
+ prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
606
+ with open(INCEPTION_V3_PATH, "rb") as f:
607
+ graph_def = tf.GraphDef()
608
+ graph_def.ParseFromString(f.read())
609
+ (matmul,) = tf.import_graph_def(
610
+ graph_def, return_elements=[f"softmax/logits/MatMul"], name=prefix
611
+ )
612
+ w = matmul.inputs[1]
613
+ logits = tf.matmul(input_batch, w)
614
+ return tf.nn.softmax(logits)
615
+
616
+
617
+ def _update_shapes(pool3):
618
+ # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63
619
+ ops = pool3.graph.get_operations()
620
+ for op in ops:
621
+ for o in op.outputs:
622
+ shape = o.get_shape()
623
+ if shape._dims is not None: # pylint: disable=protected-access
624
+ # shape = [s.value for s in shape] TF 1.x
625
+ shape = [s for s in shape] # TF 2.x
626
+ new_shape = []
627
+ for j, s in enumerate(shape):
628
+ if s == 1 and j == 0:
629
+ new_shape.append(None)
630
+ else:
631
+ new_shape.append(s)
632
+ o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
633
+ return pool3
634
+
635
+
636
+ def _numpy_partition(arr, kth, **kwargs):
637
+ num_workers = min(cpu_count(), len(arr))
638
+ chunk_size = len(arr) // num_workers
639
+ extra = len(arr) % num_workers
640
+
641
+ start_idx = 0
642
+ batches = []
643
+ for i in range(num_workers):
644
+ size = chunk_size + (1 if i < extra else 0)
645
+ batches.append(arr[start_idx : start_idx + size])
646
+ start_idx += size
647
+
648
+ with ThreadPool(num_workers) as pool:
649
+ return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))
650
+
651
+
652
+ if __name__ == "__main__":
653
+ main()
654
+
f16c16/graph-data.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import matplotlib.pyplot as plt
import numpy as np

# Noise levels swept by the evaluation runs: 0.00, 0.01, ..., 0.99.
noises = [float(number) for number in np.arange(0.00, 1.0, 0.01)]


def _read_stats(path):
    """Parse an evaluation log into ``(mean_l2, mean_lpips)`` lists.

    Every line of the form "Mean L2: <float>" or "Mean Lpips: <float>"
    contributes one entry, in file order (one entry per noise level).
    """
    mean_l2 = []
    mean_lpips = []
    with open(path, "r") as f:
        print("read")
        for line in f:
            print(line)
            if "Mean L2" in line:
                mean_l2.append(float(line.split(":")[1].strip()))
            elif "Mean Lpips" in line:
                mean_lpips.append(float(line.split(":")[1].strip()))
    return mean_l2, mean_lpips


# One log file per training configuration (learning rates / perceptual-loss
# weighted runs). Previously this parsing loop was copy-pasted eight times.
mean_l2, mean_lpips = _read_stats("./1e-4.txt")
mean_l2_2, mean_lpips_2 = _read_stats("./1e-5.txt")
mean_l2_3, mean_lpips_3 = _read_stats("./2e-5.txt")
mean_l2_4, mean_lpips_4 = _read_stats("./1e-6.txt")
mean_l2_5, mean_lpips_5 = _read_stats("./pl600.txt")
mean_l2_6, mean_lpips_6 = _read_stats("./100pl.txt")
mean_l2_7, mean_lpips_7 = _read_stats("./300pl.txt")
mean_l2_8, mean_lpips_8 = _read_stats("./1e-6_asym.txt")

plt.figure(figsize=(10, 6))

# Plot only the first `do` noise levels.
do = 100
mean_lpips = mean_lpips[0:do]
mean_lpips_2 = mean_lpips_2[0:do]
mean_lpips_3 = mean_lpips_3[0:do]
mean_lpips_4 = mean_lpips_4[0:do]
mean_lpips_5 = mean_lpips_5[0:do]
mean_lpips_6 = mean_lpips_6[0:do]
mean_lpips_7 = mean_lpips_7[0:do]
mean_lpips_8 = mean_lpips_8[0:do]
noises = noises[0:do]

# Mean LPIPS vs. noise level, one curve per configuration.
plt.plot(noises, mean_lpips, label='Mean Lpips 1e-4', marker='s', linestyle='--', color='r')
plt.plot(noises, mean_lpips_3, label='Mean Lpips 2e-5', marker='s', linestyle='--', color='b')
plt.plot(noises, mean_lpips_2, label='Mean Lpips 1e-5', marker='s', linestyle='--', color='g')
plt.plot(noises, mean_lpips_4, label='Mean Lpips 1e-6', marker='s', linestyle='--', color='y')
plt.plot(noises, mean_lpips_8, label='Mean Lpips 1e-6asym', marker='s', linestyle='--')
plt.plot(noises, mean_lpips_7, label='Mean Lpips Pl300', marker='s', linestyle='--')

plt.xlabel('Noise Level')
plt.ylabel('Value')
plt.title('Mean L2 and Mean Lpips vs. Noise Level')
plt.grid(True)
plt.legend()
plt.show()
f16c16/kl_test.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import jax
import jax.numpy as jnp


# Scratch script: sanity-check the diagonal-Gaussian KL term
#   KL(N(mu, sigma^2) || N(0, 1)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
# under two reduction strategies, with logvars fixed at 0 (unit variance).

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (2,32,32,4))
print(x.mean())
# Per-sample mean over all non-batch axes -> shape (2,).
means = jnp.mean(x, axis = [1,2,3])
#So this gives us the means of each individual one, cool
print(means)

logvars = 0.0

print("square of means shit", jnp.square(means))
print(means)

# Variant 1: KL of the per-sample MEAN only.
# NOTE(review): `means` is 1-D, so tuple(range(1, means.ndim)) == () and this
# jnp.sum reduces over no axes at all — it is an elementwise no-op here.
kl_loss = - 0.5 * jnp.sum(1 + logvars - jnp.square(means) - jnp.exp(logvars),axis=tuple(range(1, means.ndim)))
print(kl_loss)
kl_loss = jnp.mean(kl_loss)

print(kl_loss)

print("x mean again", x.mean())
print(x)
print(jnp.square(x))

# Variant 2: treat every latent element as its own unit-variance Gaussian and
# sum the KL over all non-batch axes, then average over the batch.
kl_loss = - 0.5 * jnp.sum(1 + logvars - jnp.square(x) - jnp.exp(logvars),axis=tuple(range(1, x.ndim)))
print(kl_loss)
kl_loss = jnp.mean(kl_loss)

print(kl_loss)
f16c16/latent_distances.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ try: # For debugging
2
+ from localutils.debugger import enable_debug
3
+ enable_debug()
4
+ except ImportError:
5
+ pass
6
+
7
+
8
+ #import jax
9
+ #jax.config.update('jax_platform_name', 'cpu')
10
+ import os
11
+ # os.environ["JAX_PLATFORMS"] = 'cpu'
12
+ import jax
13
+ import lpips
14
+
15
+ loss_fn_alex = lpips.LPIPS(net='alex') # best forward scores
16
+ loss_fn_alex = loss_fn_alex.cuda()
17
+
18
+
19
+ import numpy as np
20
+ import flax.linen as nn
21
+ import jax.numpy as jnp
22
+ from absl import app, flags
23
+ from functools import partial
24
+ import numpy as np
25
+ import tqdm
26
+ import flax
27
+ import optax
28
+ import wandb
29
+ from ml_collections import config_flags
30
+ #import elements
31
+ import ml_collections
32
+ import tensorflow_datasets as tfds
33
+ import tensorflow as tf
34
+ tf.config.set_visible_devices([], "GPU")
35
+ tf.config.set_visible_devices([], "TPU")
36
+ import matplotlib.pyplot as plt
37
+ from typing import Any
38
+
39
+ from utils.train_state import TrainState, target_update
40
+ from utils.checkpoint import Checkpoint
41
+ from utils.fid import get_fid_network, fid_from_stats
42
+
43
+ from train import VQGANModel
44
+ from models.vqvae import VQVAE
45
+ from models.discriminator import Discriminator
46
+
47
+ from PIL import Image
48
+ import torch
49
+
50
# These flags were already defined by train.py (imported above); remove the
# train-time definitions so they can be re-declared with evaluation defaults.
delattr(flags.FLAGS, 'dataset_name')
delattr(flags.FLAGS, 'load_dir')
delattr(flags.FLAGS, 'batch_size')

FLAGS = flags.FLAGS
flags.DEFINE_string('dataset_name', 'imagenet256', 'Environment name.')
flags.DEFINE_string('load_dir', "/home/dkaplan/Downloads/Models/checkpoint(1).tmp", 'Load dir (if not None, load params from here).')


flags.DEFINE_integer('batch_size', 2, 'Total Batch size.')
# Flags are inherited from train.py, so pass your model parameters again here to evaluate.

import gc

def main(_):
    """Estimate PPL-style latent-space statistics for a trained VQVAE/VQGAN.

    Loads a checkpoint, runs validation batches through the model's
    `latent_distances` routine (presumably decoding two nearby latent points —
    confirm against train.py), and reports the eps-normalized mean LPIPS
    between the two decodes plus latent mean/std statistics.
    """
    device_count = len(jax.local_devices())
    global_device_count = jax.device_count()
    # Per-process share of the global batch.
    local_batch_size = FLAGS.batch_size // (global_device_count // device_count)

    def get_dataset(is_train):
        # Center-crop to a square, resize to target resolution, scale to [0, 1].
        if 'imagenet' in FLAGS.dataset_name:
            def deserialization_fn(data):
                image = data['image']
                min_side = tf.minimum(tf.shape(image)[0], tf.shape(image)[1])
                image = tf.image.resize_with_crop_or_pad(image, min_side, min_side)
                if 'imagenet256' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (256, 256))
                elif 'imagenet128' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (128, 128))
                else:
                    raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
                if is_train:
                    image = tf.image.random_flip_left_right(image)
                image = tf.cast(image, tf.float32) / 255.0
                return image

            split = tfds.split_for_jax_process('train' if is_train else 'validation', drop_remainder=True)
            dataset = tfds.load('imagenet2012', data_dir="/data/inet", split=split)
            dataset = dataset.map(deserialization_fn, num_parallel_calls=tf.data.AUTOTUNE)
            dataset = dataset.shuffle(10000, seed=42, reshuffle_each_iteration=True)
            dataset = dataset.batch(local_batch_size)
            dataset = dataset.prefetch(tf.data.AUTOTUNE)
            dataset = tfds.as_numpy(dataset)
            dataset = iter(dataset)
            return dataset
        else:
            raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")

    dataset = get_dataset(is_train=True)
    dataset_valid = get_dataset(is_train=False)

    # Single example used only to infer image shape for model init.
    example_obs = next(dataset)[:1]

    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, param_key = jax.random.split(rng)
    print("Total devices", jax.local_devices()[0])


    ###################################
    # Creating Model and put on devices.
    ###################################
    FLAGS.model.image_channels = example_obs.shape[-1]
    FLAGS.model.image_size = example_obs.shape[1]
    vqvae_def = VQVAE(FLAGS.model, train=True)
    vqvae_params = vqvae_def.init({'params': param_key, 'noise': param_key}, example_obs)['params']
    # No optimizer (tx) — this is evaluation only; params are overwritten by
    # the checkpoint below.
    vqvae_ts = TrainState.create(vqvae_def, vqvae_params)
    vqvae_def_eps = VQVAE(FLAGS.model, train=False)
    vqvae_eps_ts = TrainState.create(vqvae_def_eps, vqvae_params)
    print("Total num of VQVAE parameters:", sum(x.size for x in jax.tree_util.tree_leaves(vqvae_params)))

    discriminator_def = Discriminator(FLAGS.model)
    discriminator_params = discriminator_def.init(param_key, example_obs)['params']
    discriminator_ts = TrainState.create(discriminator_def, discriminator_params)
    print("Total num of Discriminator parameters:", sum(x.size for x in jax.tree_util.tree_leaves(discriminator_params)))

    model = VQGANModel(rng=rng, vqvae=vqvae_ts, vqvae_eps=vqvae_eps_ts, discriminator=discriminator_ts, config=FLAGS.model)

    assert FLAGS.load_dir is not None
    cp = Checkpoint(FLAGS.load_dir)
    model = cp.load_model(model)
    print("Loaded model with step", model.vqvae.step)

    model = flax.jax_utils.replicate(model, devices=jax.local_devices())
    jax.debug.visualize_array_sharding(model.vqvae.params['decoder']['Conv_0']['bias'])


    ####################################
    # Noise stuff
    ###################################

    cpus = jax.devices("cpu")

    # PPL-style measurement: walk between latent points and measure LPIPS of
    # the decoded pair, normalized by the step size squared.

    i = 0
    lpips_list = []
    means = []
    stds = []
    for valid_images in dataset_valid:

        valid_images = valid_images.reshape((len(jax.local_devices()), -1, *valid_images.shape[1:]))  # [devices, batch//devices, etc..]

        reconstructed_images, decoded, std, latents = model.latent_distances(valid_images)  # [devices, 8, 256, 256, 3]

        means.append(latents.mean())
        stds.append(latents.std())
        print("latent mean", latents.mean())
        print("actual latent std", latents.std())

        # LPIPS expects inputs in [-1, 1]; the model outputs [0, 1].
        reconstructed_images = reconstructed_images * 2 - 1
        decoded = decoded * 2 -1

        # Rearrange [devices, batch, H, W, C] into the NCHW layout the torch
        # LPIPS network expects, then drop singleton axes.
        reconstructed_images = jnp.swapaxes(reconstructed_images, 0, 4)
        decoded = jnp.swapaxes(decoded, 0, 4)

        reconstructed_images = jnp.swapaxes(reconstructed_images, 0, 1)
        decoded = jnp.swapaxes(decoded, 0, 1)

        reconstructed_images = jnp.squeeze(reconstructed_images)
        decoded = jnp.squeeze(decoded)

        # Copy off-device into torch CUDA tensors for the LPIPS net.
        image_np = np.asarray(reconstructed_images)
        image_np_2 = torch.from_numpy(np.copy(image_np)).cuda()

        decoded_np = np.asarray(decoded)
        decoded_np_2 = torch.from_numpy(np.copy(decoded_np)).cuda()

        lpips_loss = loss_fn_alex(image_np_2, decoded_np_2)
        lpips_cpu = lpips_loss.detach().cpu().squeeze().mean()
        # PPL normalization: divide by eps^2 for the eps = 1e-4 latent step
        # (must match the step used inside model.latent_distances — confirm).
        lpips_cpu = lpips_cpu / (.0001 ** 2)

        print(lpips_cpu)
        lpips_list.append(lpips_cpu)

        i += 1
        if i == 500:
            break

    mean_lpips = jnp.mean(jnp.asarray(lpips_list))
    print(mean_lpips)
    print("mean of means", jnp.asarray(means).mean())
    print("stds of means", jnp.asarray(means).std())
    print("mean of stds", jnp.asarray(stds).mean())
    print("std of stds", jnp.asarray(stds).std())

    # Recorded results from previous runs, for reference:
    #
    # actual ae sym:
    #   mean of means 0.35234922 / stds of means 0.4036692
    #   mean of stds 2.6363409   / std of stds 0.30666474
    # 1e-6:
    #   mean of means -0.018107202 / stds of means 0.11694455
    #   mean of stds 1.0860059     / std of stds 0.09732369
    # 1e-5:
    #   mean of means 0.0065166513 / stds of means 0.06983645
    #   mean of stds 0.9855982     / std of stds 0.05810356
    # 1e-4:
    #   mean of means 0.0065882676 / stds of means 0.042861093
    #   mean of stds 0.7608507     / std of stds 0.05846726
    # pl300:
    #   mean of means 0.090131655 / stds of means 0.69894844
    #   mean of stds 5.5634923    / std of stds 0.6767279
    # pl100:
    #   mean of means 0.16227543 / stds of means 0.53616405
    #   mean of stds 4.4914503   / std of stds 0.6015057
    #
    # Maybe we want to do "std multiplied PPL"? Also grab the STD of the LPIPS.


if __name__ == '__main__':
    app.run(main)
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ try: # For debugging
2
+ from localutils.debugger import enable_debug
3
+ enable_debug()
4
+ except ImportError:
5
+ pass
6
+
7
+
8
+ #import jax
9
+ #jax.config.update('jax_platform_name', 'cpu')
10
+ import os
11
+
12
+ # os.environ["JAX_PLATFORMS"] = 'cpu'
13
+
14
+ import jax
15
+
16
+ import flax.linen as nn
17
+ import jax.numpy as jnp
18
+ from absl import app, flags
19
+ from functools import partial
20
+ import numpy as np
21
+ import tqdm
22
+ import flax
23
+ import optax
24
+ import wandb
25
+ from ml_collections import config_flags
26
+ #import elements
27
+ import ml_collections
28
+ import tensorflow_datasets as tfds
29
+ import tensorflow as tf
30
+ tf.config.set_visible_devices([], "GPU")
31
+ tf.config.set_visible_devices([], "TPU")
32
+ import matplotlib.pyplot as plt
33
+ from typing import Any
34
+
35
+ from utils.train_state import TrainState, target_update
36
+ from utils.checkpoint import Checkpoint
37
+ from utils.fid import get_fid_network, fid_from_stats
38
+
39
+ from train import VQGANModel
40
+ from models.vqvae import VQVAE
41
+ from models.discriminator import Discriminator
42
+
43
+ from PIL import Image
44
+
45
+
46
# These flags were already defined by train.py (imported above); remove the
# train-time definitions so they can be re-declared with evaluation defaults.
delattr(flags.FLAGS, 'dataset_name')
delattr(flags.FLAGS, 'load_dir')
delattr(flags.FLAGS, 'batch_size')

FLAGS = flags.FLAGS
flags.DEFINE_string('dataset_name', 'imagenet256', 'Environment name.')
flags.DEFINE_string('load_dir', "/home/dkaplan/Documents/LiClipse Workspace/VAE/jax-vqvae-vqgan/7e-5_sdlike_sym/checkpoint.tmp", 'Load dir (if not None, load params from here).')
flags.DEFINE_integer('batch_size', 16, 'Total Batch size.')
# Flags are inherited from train.py, so pass your model parameters again here to evaluate.

def main(_):
    """Save side-by-side original / reconstruction PNGs from a loaded checkpoint.

    Writes original<i>.png and recon<i>.png for the first 6 batches, then exits.
    """
    device_count = len(jax.local_devices())
    global_device_count = jax.device_count()
    # Per-process share of the global batch.
    local_batch_size = FLAGS.batch_size // (global_device_count // device_count)

    def get_dataset(is_train):
        # Center-crop to a square, resize to target resolution, scale to [0, 1].
        # Train split additionally yields a horizontally-flipped copy and the label.
        if 'imagenet' in FLAGS.dataset_name:
            def deserialization_fn(data):
                image = data['image']
                min_side = tf.minimum(tf.shape(image)[0], tf.shape(image)[1])
                image = tf.image.resize_with_crop_or_pad(image, min_side, min_side)
                if 'imagenet256' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (256, 256))
                elif 'imagenet128' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (128, 128))
                else:
                    raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
                if is_train:
                    image_flip = tf.image.flip_left_right(image)
                    image_flip = tf.cast(image_flip, tf.float32) / 255.0
                    image = tf.cast(image, tf.float32) / 255.0
                    return image, image_flip, data["label"]
                image = tf.cast(image, tf.float32) / 255.0
                return image

            split = tfds.split_for_jax_process('train' if is_train else 'validation', drop_remainder=True)
            dataset = tfds.load('imagenet2012', data_dir="/data/inet", split=split)
            dataset = dataset.map(deserialization_fn, num_parallel_calls=tf.data.AUTOTUNE)
            dataset = dataset.shuffle(10000, seed=42, reshuffle_each_iteration=True)
            dataset = dataset.batch(local_batch_size)
            dataset = dataset.prefetch(tf.data.AUTOTUNE)
            dataset = tfds.as_numpy(dataset)
            dataset = iter(dataset)
            return dataset
        else:
            raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")

    dataset = get_dataset(is_train=True)
    dataset_valid = get_dataset(is_train=False)

    # Single example used only to infer image shape for model init
    # (train split yields (image, flip, label) tuples — take the image).
    example_obs = next(dataset)[0][:1]

    get_fid_activations = get_fid_network()
    truth_fid_stats = np.load('data/imagenet256_fidstats_openai.npz')

    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, param_key = jax.random.split(rng)
    print("Total devices", jax.local_devices()[0])


    ###################################
    # Creating Model and put on devices.
    ###################################
    FLAGS.model.image_channels = example_obs.shape[-1]
    FLAGS.model.image_size = example_obs.shape[1]
    vqvae_def = VQVAE(FLAGS.model, train=True)
    vqvae_params = vqvae_def.init({'params': param_key, 'noise': param_key}, example_obs)['params']
    tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    vqvae_ts = TrainState.create(vqvae_def, vqvae_params, tx=tx)
    vqvae_def_eps = VQVAE(FLAGS.model, train=False)
    vqvae_eps_ts = TrainState.create(vqvae_def_eps, vqvae_params)
    print("Total num of VQVAE parameters:", sum(x.size for x in jax.tree_util.tree_leaves(vqvae_params)))

    discriminator_def = Discriminator(FLAGS.model)
    discriminator_params = discriminator_def.init(param_key, example_obs)['params']
    tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    discriminator_ts = TrainState.create(discriminator_def, discriminator_params, tx=tx)
    print("Total num of Discriminator parameters:", sum(x.size for x in jax.tree_util.tree_leaves(discriminator_params)))

    model = VQGANModel(rng=rng, vqvae=vqvae_ts, vqvae_eps=vqvae_eps_ts, discriminator=discriminator_ts, config=FLAGS.model)

    assert FLAGS.load_dir is not None
    cp = Checkpoint(FLAGS.load_dir)
    model = cp.load_model(model)
    print("Loaded model with step", model.vqvae.step)

    model = flax.jax_utils.replicate(model, devices=jax.local_devices())
    jax.debug.visualize_array_sharding(model.vqvae.params['decoder']['Conv_0']['bias'])


    ####################################
    # FID Evaluation.
    ###################################

    i = 0
    # NOTE(review): iterates the TRAIN split (dataset), not dataset_valid.
    for valid_images, image_flip, label in dataset:#dataset_valid:

        valid_images = valid_images.reshape((len(jax.local_devices()), -1, *valid_images.shape[1:]))  # [devices, batch//devices, etc..]
        valid_reconstructed_images = model.reconstruction(valid_images)  # [devices, 8, 256, 256, 3]

        # Save the first image of the first device shard and its reconstruction.
        image = valid_images[0,0,:,:,:]
        image = (image * 255).astype(np.uint8)
        img = Image.fromarray(image)
        img.save("original" + str(i) + ".png")

        image2 = valid_reconstructed_images[0,0,:,:,:]
        image2 = (image2 * 255).astype(np.uint8)
        image2 = np.array(image2)
        image2 = Image.fromarray(image2)

        image2.save("recon" + str(i) + ".png")


        i += 1

        # Stop after six sample pairs.
        if i == 6:
            exit()


if __name__ == '__main__':
    app.run(main)
f16c16/models/__pycache__/discriminator.cpython-310.pyc ADDED
Binary file (4.68 kB). View file
 
f16c16/models/__pycache__/discriminator.cpython-312.pyc ADDED
Binary file (8.13 kB). View file
 
f16c16/models/__pycache__/vqvae.cpython-310.pyc ADDED
Binary file (14.7 kB). View file
 
f16c16/models/__pycache__/vqvae.cpython-312.pyc ADDED
Binary file (26.9 kB). View file
 
f16c16/models/back_model.py ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+ import flax.linen as nn
3
+ import jax.numpy as jnp
4
+ import functools
5
+ import ml_collections
6
+ import jax
7
+
8
+ ###########################
9
+ ### Helper Modules
10
+ ### https://github.com/google-research/maskgit/blob/main/maskgit/nets/layers.py
11
+ ###########################
12
+
13
def get_norm_layer(norm_type):
    """Map a norm-type code to a normalization-layer factory.

    Supported codes: 'LN' (LayerNorm) and 'GN' (GroupNorm). 'BN' is
    recognized but deliberately unsupported; any other code is rejected.
    """
    if norm_type == 'BN':
        raise NotImplementedError
    if norm_type == 'LN':
        return functools.partial(nn.LayerNorm)
    if norm_type == 'GN':
        return functools.partial(nn.GroupNorm)
    raise NotImplementedError
24
+
25
+
26
def tensorflow_style_avg_pooling(x, window_shape, strides, padding: str):
    """Average-pool an NHWC batch with TensorFlow-style border handling.

    The window sum is divided by a per-position element count, so windows
    clipped at the border average only over the elements actually present
    (instead of counting padded zeros).
    """
    full_window = (1,) + window_shape + (1,)
    full_strides = (1,) + strides + (1,)
    window_sum = jax.lax.reduce_window(
        x, 0.0, jax.lax.add, full_window, full_strides, padding)
    window_count = jax.lax.reduce_window(
        jnp.ones_like(x), 0.0, jax.lax.add, full_window, full_strides, padding)
    return window_sum / window_count
34
+
35
def upsample(x, factor=2):
    """Nearest-neighbour upsample of an NHWC batch by `factor` per spatial axis."""
    batch, height, width, channels = x.shape
    target_shape = (batch, height * factor, width * factor, channels)
    return jax.image.resize(x, target_shape, method='nearest')
39
+
40
def dsample(x):
    # 2x spatial downsample: TF-style 2x2 average pooling with stride 2.
    return tensorflow_style_avg_pooling(x, (2, 2), strides=(2, 2), padding='same')
42
+
43
def squared_euclidean_distance(a: jnp.ndarray,
                               b: jnp.ndarray,
                               b2: jnp.ndarray = None) -> jnp.ndarray:
    """Pairwise squared Euclidean distances between rows of `a` and `b`.

    Uses the expansion ||u - v||^2 = ||u||^2 - 2 u.v + ||v||^2 so the whole
    computation is a single matmul plus broadcasting.

    Args:
      a: float32 (n, d) array of points.
      b: float32 (m, d) array of points.
      b2: optional precomputed (1, m) row of squared norms of b (pass this
        when b is fixed across many calls, e.g. a codebook).

    Returns:
      float32 (n, m) array whose (i, j) entry is ||a[i] - b[j]||^2.
    """
    if b2 is None:
        b2 = jnp.sum(b.T ** 2, axis=0, keepdims=True)
    a2 = jnp.sum(a ** 2, axis=1, keepdims=True)
    cross = jnp.matmul(a, b.T)
    return a2 - 2 * cross + b2
63
+
64
def entropy_loss_fn(affinity, loss_type="softmax", temperature=1.0):
    """Entropy loss that sharpens assignments while spreading code usage.

    `affinity` is a (..., codebook_size) similarity/distance matrix. The loss
    is the mean per-sample entropy minus the entropy of the batch-average
    code distribution: minimizing it makes individual assignments confident
    while encouraging use of the whole codebook.
    """
    logits = affinity.reshape(-1, affinity.shape[-1]) / temperature
    probs = jax.nn.softmax(logits, axis=-1)
    log_probs = jax.nn.log_softmax(logits + 1e-5, axis=-1)
    if loss_type == "softmax":
        target_probs = probs
    elif loss_type == "argmax":
        # Hard one-hot targets with a straight-through gradient to `probs`.
        codes = jnp.argmax(logits, axis=-1)
        hard = jax.nn.one_hot(codes, logits.shape[-1], dtype=logits.dtype)
        target_probs = probs - jax.lax.stop_gradient(probs - hard)
    else:
        raise ValueError(f"Entropy loss {loss_type} not supported")
    avg_probs = jnp.mean(target_probs, axis=0)
    avg_entropy = -jnp.sum(avg_probs * jnp.log(avg_probs + 1e-5))
    sample_entropy = -jnp.mean(jnp.sum(target_probs * log_probs, axis=-1))
    return sample_entropy - avg_entropy
85
+
86
def sg(x):
    # Shorthand for stop-gradient: the value passes through unchanged, but no
    # gradient flows back through this point.
    return jax.lax.stop_gradient(x)
88
+
89
+
90
+
91
+
92
+ ###########################
93
+ ### Modules
94
+ ###########################
95
+
96
class ResBlock(nn.Module):
    """Basic residual block: (norm -> activation -> conv3x3) x2 plus a skip path.

    Attributes:
      filters: number of output channels.
      norm_fn: factory producing a normalization layer (see `get_norm_layer`).
      activation_fn: elementwise activation applied after each norm.
    """
    filters: int
    norm_fn: Any
    activation_fn: Any

    @nn.compact
    def __call__(self, x):
        input_dim = x.shape[-1]
        residual = x
        x = self.norm_fn()(x)
        x = self.activation_fn(x)
        x = nn.Conv(self.filters, kernel_size=(3, 3), use_bias=False)(x)
        x = self.norm_fn()(x)
        x = self.activation_fn(x)
        x = nn.Conv(self.filters, kernel_size=(3, 3), use_bias=False)(x)

        if input_dim != self.filters:
            # Project the ORIGINAL input onto `filters` channels for the skip
            # connection, matching the maskgit reference implementation.
            # Bug fix: the previous code projected `x` (the conv-branch
            # output), which discarded the skip path entirely whenever the
            # channel count changed. NOTE: the fix changes the 1x1 kernel's
            # input dimension, so checkpoints trained with the buggy version
            # will not load for blocks where input_dim != filters.
            residual = nn.Conv(self.filters, kernel_size=(1, 1), use_bias=False)(residual)
        return x + residual
116
+
117
class Encoder(nn.Module):
    """From [H,W,D] image to [H',W',D'] embedding, via strided conv stages.

    Each stage widens channels per `channel_multipliers[i]` and halves the
    spatial resolution (except after the last stage), giving a total
    downsample factor of 2**(len(channel_multipliers) - 1).
    """
    config: ml_collections.ConfigDict

    def setup(self):
        # Cache frequently used config entries as attributes.
        self.filters = self.config.filters
        self.num_res_blocks = self.config.num_res_blocks
        self.channel_multipliers = self.config.channel_multipliers
        self.embedding_dim = self.config.embedding_dim
        self.norm_type = self.config.norm_type
        self.activation_fn = nn.swish

    @nn.compact
    def __call__(self, x):
        print("Initializing encoder.")
        norm_fn = get_norm_layer(norm_type=self.norm_type)
        block_args = dict(norm_fn=norm_fn, activation_fn=self.activation_fn)
        print("Incoming encoder shape", x.shape)
        # Stem conv to the base channel count.
        x = nn.Conv(self.filters, kernel_size=(3, 3), use_bias=False)(x)
        print('Encoder layer', x.shape)
        num_blocks = len(self.channel_multipliers)
        for i in range(num_blocks):
            filters = self.filters * self.channel_multipliers[i]
            for _ in range(self.num_res_blocks):
                x = ResBlock(filters, **block_args)(x)
            # Downsample between stages, but not after the final stage.
            if i < num_blocks - 1:
                x = dsample(x)
            print('Encoder layer', x.shape)

        # Extra residual tower at the bottleneck resolution.
        for _ in range(self.num_res_blocks):
            x = ResBlock(filters, **block_args)(x)
        print('Encoder layer', x.shape)
        x = norm_fn()(x)
        x = self.activation_fn(x)
        # KL quantizer needs double width: mean and logvar halves of the latent.
        last_dim = self.embedding_dim*2 if self.config['quantizer_type'] == 'kl' else self.embedding_dim
        x = nn.Conv(last_dim, kernel_size=(1, 1))(x)
        print("Before final", x.shape)
        # NOTE(review): hard-coded 8-channel projection right after the
        # embedding_dim projection above — presumably forces a fixed latent
        # width regardless of config; confirm this is intentional.
        x = nn.Conv(8, kernel_size=(1,1))(x)
        print("Final embeddings are size", x.shape)
        return x
157
+
158
class Decoder(nn.Module):
    """From [H',W',D'] embedding back to [H,W,D] image, mirroring the Encoder.

    Walks the channel multipliers in reverse, upsampling 2x between stages
    (except at the final, full-resolution stage).
    """

    config: ml_collections.ConfigDict

    def setup(self):
        # Cache frequently used config entries as attributes.
        self.filters = self.config.filters
        self.num_res_blocks = self.config.num_res_blocks
        self.channel_multipliers = self.config.channel_multipliers
        self.norm_type = self.config.norm_type
        self.image_channels = self.config.image_channels
        self.activation_fn = nn.swish

    @nn.compact
    def __call__(self, x):
        norm_fn = get_norm_layer(norm_type=self.norm_type)
        block_args = dict(norm_fn=norm_fn, activation_fn=self.activation_fn,)
        num_blocks = len(self.channel_multipliers)
        # Start at the widest stage (the encoder's bottleneck width).
        filters = self.filters * self.channel_multipliers[-1]
        print("Decoder incoming shape", x.shape)

        # Stem conv expands the narrow latent back to the bottleneck width.
        x = nn.Conv(filters, kernel_size=(3, 3), use_bias=True)(x)
        print("Decoder input", x.shape)

        for _ in range(self.num_res_blocks):
            x = ResBlock(filters, **block_args)(x)
        print('Decoder layer', x.shape)
        for i in reversed(range(num_blocks)):
            filters = self.filters * self.channel_multipliers[i]
            for _ in range(self.num_res_blocks):
                x = ResBlock(filters, **block_args)(x)
            # Upsample between stages, but not at the full-resolution stage.
            if i > 0:
                x = upsample(x, 2)
                x = nn.Conv(filters, kernel_size=(3, 3))(x)
            print('Decoder layer', x.shape)
        x = norm_fn()(x)
        x = self.activation_fn(x)
        # Final projection back to image channels (e.g. RGB).
        x = nn.Conv(self.image_channels, kernel_size=(3, 3))(x)
        return x
199
+
200
class VectorQuantizer(nn.Module):
    """Basic vector quantizer.

    Maps each input vector to its nearest codebook entry (L2 distance) and,
    in training mode, adds the standard VQ-VAE commitment/codebook losses plus
    an optional entropy regularizer on codebook usage.
    """
    config: ml_collections.ConfigDict
    train: bool

    @nn.compact
    def __call__(self, x):
        codebook_size = self.config.codebook_size
        emb_dim = x.shape[-1]
        codebook = self.param(
            "codebook",
            jax.nn.initializers.variance_scaling(scale=1.0, mode="fan_in", distribution="uniform"),
            (codebook_size, emb_dim))
        codebook = jnp.asarray(codebook)  # (codebook_size, emb_dim)
        # Pairwise distance from every input vector to every code.
        distances = jnp.reshape(
            squared_euclidean_distance(jnp.reshape(x, (-1, emb_dim)), codebook),
            x.shape[:-1] + (codebook_size,))  # [x, codebook_size] similarity matrix.
        encoding_indices = jnp.argmin(distances, axis=-1)
        encoding_onehot = jax.nn.one_hot(encoding_indices, codebook_size)
        quantized = self.quantize(encoding_onehot)
        result_dict = dict()
        if self.train:
            # Commitment loss pulls the encoder output toward its chosen code;
            # codebook loss pulls the code toward the (stopped) encoder output.
            e_latent_loss = jnp.mean((sg(quantized) - x)**2) * self.config.commitment_cost
            q_latent_loss = jnp.mean((quantized - sg(x))**2)
            entropy_loss = 0.0
            if self.config.entropy_loss_ratio != 0:
                # Negated distances act as affinity logits for the entropy term.
                entropy_loss = entropy_loss_fn(
                    -distances,
                    loss_type=self.config.entropy_loss_type,
                    temperature=self.config.entropy_temperature
                ) * self.config.entropy_loss_ratio
            e_latent_loss = jnp.asarray(e_latent_loss, jnp.float32)
            q_latent_loss = jnp.asarray(q_latent_loss, jnp.float32)
            entropy_loss = jnp.asarray(entropy_loss, jnp.float32)
            loss = e_latent_loss + q_latent_loss + entropy_loss
            result_dict = dict(
                quantizer_loss=loss,
                e_latent_loss=e_latent_loss,
                q_latent_loss=q_latent_loss,
                entropy_loss=entropy_loss)
            # Straight-through estimator: forward pass uses the code, gradients
            # flow to the encoder as if quantization were the identity.
            quantized = x + jax.lax.stop_gradient(quantized - x)

        result_dict.update({
            "z_ids": encoding_indices,
        })
        return quantized, result_dict

    def quantize(self, encoding_onehot: jnp.ndarray) -> jnp.ndarray:
        """Turn one-hot code assignments into the corresponding code vectors."""
        codebook = jnp.asarray(self.variables["params"]["codebook"])
        return jnp.dot(encoding_onehot, codebook)

    def decode_ids(self, ids: jnp.ndarray) -> jnp.ndarray:
        """Look up code vectors for integer code ids."""
        codebook = self.variables["params"]["codebook"]
        return jnp.take(codebook, ids, axis=0)
254
+
255
class KLQuantizer(nn.Module):
    """Variational bottleneck: the channel axis is split into (means, logvars).

    Training mode reparameterizes (z = mu + sigma * eps) and reports the
    KL(q || N(0, I)) penalty as `quantizer_loss`; eval mode deterministically
    returns the means with no loss terms.
    """
    config: ml_collections.ConfigDict
    train: bool

    @nn.compact
    def __call__(self, x):
        half = x.shape[-1] // 2  # first half = means, second half = logvars
        mu = x[..., :half]
        logvar = x[..., half:]
        if self.train:
            # Reparameterization trick keeps sampling differentiable.
            eps = jax.random.normal(self.make_rng("noise"), mu.shape)
            sigma = jnp.exp(0.5 * logvar)
            kl = -0.5 * jnp.mean(1 + logvar - mu**2 - jnp.exp(logvar))
            return mu + sigma * eps, dict(quantizer_loss=kl)
        # Eval: deterministic latent, empty diagnostics.
        return mu, dict()
274
+
275
class FSQuantizer(nn.Module):
    """Finite Scalar Quantization: each channel is squashed with tanh and
    rounded to one of `fsq_levels` evenly spaced values; a straight-through
    estimator keeps gradients flowing through the rounding."""
    config: ml_collections.ConfigDict
    train: bool

    @nn.compact
    def __call__(self, x):
        assert self.config['fsq_levels'] % 2 == 1, "FSQ levels must be odd."
        z = jnp.tanh(x)  # [-1, 1]
        z = z * (self.config['fsq_levels']-1) / 2  # [-(levels-1)/2, (levels-1)/2]
        zhat = jnp.round(z)  # integer grid, e.g. [-2, -1, 0, 1, 2] for 5 levels
        # Straight-through: forward rounds, backward is the identity.
        quantized = z + jax.lax.stop_gradient(zhat - z)
        quantized = quantized / (self.config['fsq_levels'] // 2)  # [-1, 1], but quantized.
        result_dict = dict()

        # Diagnostics for codebook usage: treat per-channel levels as digits of
        # a base-`fsq_levels` integer id, then histogram the ids.
        zhat_scaled = zhat + self.config['fsq_levels'] // 2  # shift to [0, levels-1]
        basis = jnp.concatenate((jnp.array([1]), jnp.cumprod(jnp.array([self.config['fsq_levels']] * (x.shape[-1]-1))))).astype(jnp.uint32)
        idx = (zhat_scaled * basis).sum(axis=-1).astype(jnp.uint32)
        idx_flat = idx.reshape(-1)
        usage = jnp.bincount(idx_flat, length=self.config['fsq_levels']**x.shape[-1])

        result_dict.update({
            "z_ids": zhat,
            'usage': usage
        })
        return quantized, result_dict
301
+
302
class VQVAE(nn.Module):
    """VQVAE model: Encoder -> (vq | kl | fsq) bottleneck -> Decoder."""
    config: ml_collections.ConfigDict
    train: bool

    def setup(self):
        """VQVAE setup."""
        # Pick the bottleneck by config; any other value leaves self.quantizer
        # unset and fails at call time.
        if self.config['quantizer_type'] == 'vq':
            self.quantizer = VectorQuantizer(config=self.config, train=self.train)
        elif self.config['quantizer_type'] == 'kl':
            self.quantizer = KLQuantizer(config=self.config, train=self.train)
        elif self.config['quantizer_type'] == 'fsq':
            self.quantizer = FSQuantizer(config=self.config, train=self.train)
        self.encoder = Encoder(config=self.config)
        self.decoder = Decoder(config=self.config)

    def encode(self, image):
        """Encode an image batch and pass it through the bottleneck."""
        encoded_feature = self.encoder(image)
        quantized, result_dict = self.quantizer(encoded_feature)
        print("After quant", quantized.shape)
        return quantized, result_dict

    def decode(self, z_vectors):
        """Decode latent vectors back to image space."""
        print("z_vectors shape", z_vectors.shape)
        reconstructed = self.decoder(z_vectors)
        return reconstructed

    def decode_from_indices(self, z_ids):
        """Decode from integer code ids (quantizers providing decode_ids)."""
        z_vectors = self.quantizer.decode_ids(z_ids)
        reconstructed_image = self.decode(z_vectors)
        return reconstructed_image

    def encode_to_indices(self, image):
        """Encode an image to integer code ids (quantizers that emit z_ids)."""
        encoded_feature = self.encoder(image)
        _, result_dict = self.quantizer(encoded_feature)
        ids = result_dict["z_ids"]
        return ids

    def __call__(self, input_dict):
        # Despite the name, `input_dict` is passed straight through to
        # encode() as the image batch — TODO confirm against callers.
        quantized, result_dict = self.encode(input_dict)
        outputs = self.decoder(quantized)
        return outputs, result_dict
f16c16/models/discriminator.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Discriminator from StyleGAN. https://github.com/google-research/maskgit/blob/main/maskgit/nets/discriminator.py"""
2
+
3
+ import functools
4
+ import math
5
+ from typing import Any, Tuple
6
+ import flax.linen as nn
7
+ from flax.linen.initializers import xavier_uniform
8
+ import jax
9
+ from jax import lax
10
+ import jax.numpy as jnp
11
+ import ml_collections
12
+
13
+ default_kernel_init = xavier_uniform()
14
+
15
+ def _conv_dimension_numbers(input_shape):
16
+ """Computes the dimension numbers based on the input shape."""
17
+ ndim = len(input_shape)
18
+ lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1))
19
+ rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2))
20
+ out_spec = lhs_spec
21
+ return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)
22
+
23
+
24
class BlurPool2D(nn.Module):
    """A layer to do channel-wise blurring + subsampling on 2D inputs.

    Reference:
      Zhang et al. Making Convolutional Networks Shift-Invariant Again.
      https://arxiv.org/pdf/1904.11486.pdf.
    """
    filter_size: int = 4               # side length of the binomial blur kernel
    strides: Tuple[int, int] = (2, 2)  # subsampling factor
    padding: str = 'SAME'

    def setup(self):
        # 1-D binomial filter taps for the requested size.
        if self.filter_size == 3:
            self.filter = [1., 2., 1.]
        elif self.filter_size == 4:
            self.filter = [1., 3., 3., 1.]
        elif self.filter_size == 5:
            self.filter = [1., 4., 6., 4., 1.]
        elif self.filter_size == 6:
            self.filter = [1., 5., 10., 10., 5., 1.]
        elif self.filter_size == 7:
            self.filter = [1., 6., 15., 20., 15., 6., 1.]
        else:
            raise ValueError('Only filter_size of 3, 4, 5, 6 or 7 is supported.')

        # Outer product -> 2-D kernel, normalized to sum to 1, reshaped to
        # HWIO with singleton channels (tiled per channel in __call__).
        self.filter = jnp.array(self.filter, dtype=jnp.float32)
        self.filter = self.filter[:, None] * self.filter[None, :]
        with jax.default_matmul_precision('float32'):
            self.filter /= jnp.sum(self.filter)
        self.filter = jnp.reshape(
            self.filter, [self.filter.shape[0], self.filter.shape[1], 1, 1])

    @nn.compact
    def __call__(self, inputs):
        channel_num = inputs.shape[-1]
        dimension_numbers = _conv_dimension_numbers(inputs.shape)
        # One copy of the blur kernel per channel, applied as a depthwise conv
        # (feature_group_count == channels).
        depthwise_filter = jnp.tile(self.filter, [1, 1, 1, channel_num])
        with jax.default_matmul_precision('float32'):
            outputs = lax.conv_general_dilated(inputs, depthwise_filter, self.strides,
                self.padding, feature_group_count=channel_num, dimension_numbers=dimension_numbers)
        return outputs
65
+
66
class ResBlock(nn.Module):
    """StyleGAN discriminator residual block.

    Main path: 3x3 conv -> act -> blur-downsample -> 3x3 conv -> act.
    Skip path: blur-downsample -> 1x1 conv to `filters`.
    The two paths are summed and rescaled by 1/sqrt(2).

    https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py#L618
    """
    filters: int
    activation_fn: Any

    @nn.compact
    def __call__(self, x):
        in_features = x.shape[-1]
        skip = x
        # Main branch at the input width, then anti-aliased 2x downsample.
        y = nn.Conv(in_features, (3, 3), kernel_init=default_kernel_init)(x)
        y = self.activation_fn(y)
        y = BlurPool2D(filter_size=4)(y)
        # Skip branch: matching blur-downsample, then 1x1 projection.
        skip = BlurPool2D(filter_size=4)(skip)
        skip = nn.Conv(self.filters, (1, 1), use_bias=False, kernel_init=default_kernel_init)(skip)
        # Finish the main branch at the output width and merge.
        y = nn.Conv(self.filters, (3, 3), kernel_init=default_kernel_init)(y)
        y = self.activation_fn(y)
        return (skip + y) / math.sqrt(2)
87
+
88
+
89
class Discriminator(nn.Module):
    """StyleGAN Discriminator."""
    config: ml_collections.ConfigDict

    def setup(self):
        self.input_size = self.config.image_size  # must be a key of `filters` below
        self.activation_fn = functools.partial(jax.nn.leaky_relu, negative_slope=0.2)
        self.channel_multiplier = 1  # width multiplier for resolutions >= 64

    @nn.compact
    def __call__(self, x):
        # Channel width per feature-map resolution (StyleGAN2 schedule).
        filters = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * self.channel_multiplier,
            128: 128 * self.channel_multiplier,
            256: 64 * self.channel_multiplier,
            512: 32 * self.channel_multiplier,
            1024: 16 * self.channel_multiplier,
        }
        x = nn.Conv(filters[self.input_size], (3, 3), kernel_init=default_kernel_init)(x)
        x = self.activation_fn(x)
        log_size = int(math.log2(self.input_size))
        # Each ResBlock halves the resolution until the 4x4 feature map.
        for i in range(log_size, 2, -1):
            x = ResBlock(filters[2**(i - 1)], self.activation_fn)(x)
            print("Disc shape", x.shape)
        x = nn.Conv(filters[4], (3, 3), kernel_init=default_kernel_init)(x)
        x = self.activation_fn(x)
        # Flatten and score: one real/fake logit per example.
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(filters[4], kernel_init=default_kernel_init)(x)
        x = self.activation_fn(x)
        x = nn.Dense(1, kernel_init=default_kernel_init)(x)
        return x
f16c16/models/vqvae.py ADDED
@@ -0,0 +1,527 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+ import flax.linen as nn
3
+ import jax.numpy as jnp
4
+ import functools
5
+ import ml_collections
6
+ import jax
7
+
8
+ from flax.linen import initializers
9
+
10
+ ###########################
11
+ ### Helper Modules
12
+ ### https://github.com/google-research/maskgit/blob/main/maskgit/nets/layers.py
13
+ ###########################
14
+
15
def get_norm_layer(norm_type):
    """Map a norm-type tag to a zero-arg layer factory.

    'LN' -> LayerNorm, 'GN' -> GroupNorm; 'BN' (and anything unknown) raises
    NotImplementedError.
    """
    if norm_type == 'LN':
        return functools.partial(nn.LayerNorm)
    if norm_type == 'GN':
        return functools.partial(nn.GroupNorm)
    # BatchNorm is intentionally unsupported, as is any unrecognized tag.
    raise NotImplementedError
26
+
27
+
28
def tensorflow_style_avg_pooling(x, window_shape, strides, padding: str):
    """Average pooling with TensorFlow's edge semantics.

    Divides each window's sum by the number of elements actually inside it,
    so padded border windows are averaged over their true support instead of
    the full window size.
    """
    full_window = (1,) + window_shape + (1,)
    full_strides = (1,) + strides + (1,)
    summed = jax.lax.reduce_window(x, 0.0, jax.lax.add, full_window,
                                   full_strides, padding)
    # Per-window element counts, computed the same way over an all-ones array.
    counts = jax.lax.reduce_window(jnp.ones_like(x), 0.0, jax.lax.add,
                                   full_window, full_strides, padding)
    return summed / counts
36
+
37
def upsample(x, factor=2):
    """Nearest-neighbor upsample of an NHWC batch by `factor` in both spatial dims."""
    batch, height, width, channels = x.shape
    target_shape = (batch, height * factor, width * factor, channels)
    return jax.image.resize(x, target_shape, method='nearest')
41
+
42
def dsample(x):
    """2x downsample via TF-style average pooling (2x2 window, stride 2)."""
    return tensorflow_style_avg_pooling(
        x, window_shape=(2, 2), strides=(2, 2), padding='same')
44
+
45
def squared_euclidean_distance(a: jnp.ndarray,
                               b: jnp.ndarray,
                               b2: jnp.ndarray = None) -> jnp.ndarray:
    """Computes the pairwise squared Euclidean distance.

    Uses the expansion ||a_i - b_j||^2 = ||a_i||^2 - 2 a_i.b_j + ||b_j||^2 so
    the whole (n, m) matrix comes out of a single matmul.

    Args:
      a: float32: (n, d): An array of points.
      b: float32: (m, d): An array of points.
      b2: float32: (1, m): optional precomputed row of ||b_j||^2 values.

    Returns:
      d: float32: (n, m): d[i, j] is the squared Euclidean distance between
      a[i] and b[j].
    """
    if b2 is None:
        b2 = jnp.sum(b.T**2, axis=0, keepdims=True)
    a_sq = jnp.sum(a**2, axis=1, keepdims=True)
    cross = jnp.matmul(a, b.T)
    return a_sq - 2 * cross + b2
65
+
66
def entropy_loss_fn(affinity, loss_type="softmax", temperature=1.0):
    """Calculates the entropy loss. Affinity is the similarity/distance matrix.

    Encourages confident per-sample code assignments (low sample entropy)
    while keeping the batch-average assignment distribution spread out
    (high average entropy); the loss is their difference.
    """
    logits = affinity.reshape(-1, affinity.shape[-1])
    logits = logits / temperature
    probs = jax.nn.softmax(logits, axis=-1)
    log_probs = jax.nn.log_softmax(logits + 1e-5, axis=-1)
    if loss_type == "softmax":
        targets = probs
    elif loss_type == "argmax":
        # Hard one-hot targets with a straight-through gradient to the probs.
        hard = jax.nn.one_hot(
            jnp.argmax(logits, axis=-1), logits.shape[-1], dtype=logits.dtype)
        targets = probs - jax.lax.stop_gradient(probs - hard)
    else:
        raise ValueError("Entropy loss {} not supported".format(loss_type))
    mean_probs = jnp.mean(targets, axis=0)
    avg_entropy = -jnp.sum(mean_probs * jnp.log(mean_probs + 1e-5))
    sample_entropy = -jnp.mean(jnp.sum(targets * log_probs, axis=-1))
    return sample_entropy - avg_entropy
87
+
88
def sg(x):
    """Shorthand for jax.lax.stop_gradient."""
    return jax.lax.stop_gradient(x)
90
+
91
+
92
+
93
+
94
+ ###########################
95
+ ### Modules
96
+ ###########################
97
+
98
class ResBlock(nn.Module):
    """Basic pre-activation residual block (norm -> act -> 3x3 conv, twice).

    When the input width differs from `filters`, the skip path is projected
    to the output width with a 1x1 conv.
    """
    filters: int
    norm_fn: Any        # zero-arg factory returning a normalization layer
    activation_fn: Any  # elementwise activation, e.g. nn.swish

    @nn.compact
    def __call__(self, x):
        input_dim = x.shape[-1]
        residual = x
        x = self.norm_fn()(x)
        x = self.activation_fn(x)
        x = nn.Conv(self.filters, kernel_size=(3, 3), use_bias=False)(x)
        x = self.norm_fn()(x)
        x = self.activation_fn(x)
        x = nn.Conv(self.filters, kernel_size=(3, 3), use_bias=False)(x)

        if input_dim != self.filters:
            # Project the *input* so the skip path carries the identity signal.
            # BUG FIX: the original applied this 1x1 conv to the transformed
            # `x`, which discards the residual connection whenever the widths
            # differ (the block degenerated to x + conv1x1(x)).
            residual = nn.Conv(self.filters, kernel_size=(1, 1), use_bias=False)(residual)
        return x + residual
118
+
119
class Fourier(nn.Module):
    # Random-Fourier-feature projection (project, then concat cos/sin).
    # NOTE(review): this class is dead/broken as written and cannot run:
    #   * `means` is undefined in setup();
    #   * `self.make_rng` is not usable in setup() without an rng scope;
    #   * __call__ reads the builtin `input` instead of its argument `f`;
    #   * `math` and `torch` are not imported in this JAX module (torch.cat
    #     would not accept a jnp array anyway).
    # Kept byte-for-byte; fix or delete before use.

    def setup(self):

        # Our input comes in as 3 channels... after converting to 512, maybe
        # instead convert to 256 and then apply this?
        self.weight = jax.random.normal(self.make_rng("noise"), means.shape)

    @nn.compact
    def __call__(self, f):
        # This is probably channels-last.
        f = 2 * math.pi * input @ self.weight.T
        return torch.cat([f.cos(), f.sin()], dim = -1)
131
+
132
+ from einops import rearrange
133
class LinearEncoder(nn.Module):
    """Patchify-then-project encoder: 8x8 pixel patches -> 4-dim tokens."""

    config: ml_collections.ConfigDict

    # In this setup none of the conv-stack config (filters, multipliers, ...) is used.
    @nn.compact
    def __call__(self, x):
        print("init encoder")
        print("x shape", x.shape)
        # Fold each non-overlapping 8x8 patch into the channel axis:
        # (..., H, W, C) -> (..., H/8, W/8, C*64).
        x = rearrange(x, '... (h b1) (w b2) c -> ... h w (c b1 b2)', b1=8, b2=8)
        x = nn.Dense(4)(x)  # We just project each patch to 4 dims for now (hard-coded).
        print(x.shape)
        return x
        #k = nn.Dense(self.hidden_size, **self.tc.default_config())(x_modulated)
        #1x1 conv, uplift from 3 to like..... 64
        #That gives us 256x256x64
        #Then pixelshuffle to
151
+
152
class Encoder(nn.Module):
    """From [H,W,D] image to [H',W',D'] embedding. Using Conv layers."""
    config: ml_collections.ConfigDict

    def setup(self):
        self.filters = self.config.filters  # filters is the base channel width
        self.num_res_blocks = self.config.num_res_blocks
        self.channel_multipliers = self.config.channel_multipliers  # per-stage width factors
        self.embedding_dim = self.config.embedding_dim
        self.norm_type = self.config.norm_type
        self.activation_fn = nn.swish
        # NOTE(review): kernel_init is defined but never passed to a layer below.
        self.kernel_init = initializers.he_normal()

    @nn.compact
    def __call__(self, x):
        print("Initializing encoder.")
        norm_fn = get_norm_layer(norm_type=self.norm_type)
        block_args = dict(norm_fn=norm_fn, activation_fn=self.activation_fn)
        print("Incoming encoder shape", x.shape)
        x = nn.Conv(self.filters, kernel_size=(3, 3), use_bias=False)(x)
        print('Encoder layer', x.shape)
        num_blocks = len(self.channel_multipliers)

        # The way SD works: 2x resnet (no channel change), then downsample,
        # repeated 3 times for an 8x total downsample; then an extra resnet
        # stage, and THEN the projection from 512 to 8 / 4 channels.

        for i in range(num_blocks):
            filters = self.filters * self.channel_multipliers[i]
            for _ in range(self.num_res_blocks):
                x = ResBlock(filters, **block_args)(x)
            if i < num_blocks - 1:  # For every stage *except the last*, downsample 2x.
                print("doing downsample")
                x = dsample(x)
                print('Encoder layer', x.shape)

        # After downsampling: the final ("mid") ResBlocks at the deepest width.
        for _ in range(self.num_res_blocks):
            x = ResBlock(filters, **block_args)(x)
        print('Encoder layer final', x.shape)

        x = norm_fn()(x)
        x = self.activation_fn(x)
        # The KL quantizer needs twice the channels (means + logvars).
        last_dim = self.embedding_dim*2 if self.config['quantizer_type'] == 'kl' else self.embedding_dim
        x = nn.Conv(last_dim, kernel_size=(1, 1))(x)
        print("Final embeddings are size", x.shape)
        return x
200
+
201
class Decoder(nn.Module):
    """From [H',W',D'] embedding to [H,W,D] embedding. Using Conv layers."""

    config: ml_collections.ConfigDict

    def setup(self):
        self.filters = self.config.filters
        self.num_res_blocks = self.config.num_res_blocks
        self.channel_multipliers = self.config.channel_multipliers
        self.norm_type = self.config.norm_type
        self.image_channels = self.config.image_channels  # e.g. 3 for RGB output
        self.activation_fn = nn.swish
        # NOTE(review): kernel_init is defined but never passed to a layer below.
        self.kernel_init = initializers.he_normal()

    @nn.compact
    def __call__(self, x):
        norm_fn = get_norm_layer(norm_type=self.norm_type)
        block_args = dict(norm_fn=norm_fn, activation_fn=self.activation_fn,)
        num_blocks = len(self.channel_multipliers)
        # Start at the deepest width, matching the encoder's last stage.
        filters = self.filters * self.channel_multipliers[-1]
        print("Decoder incoming shape", x.shape)

        # We don't need to do anything else here because this conv puts the
        # latent back to the deepest width.

        x = nn.Conv(filters, kernel_size=(3, 3), use_bias=True)(x)
        print("Decoder input", x.shape)


        # This is the mid block.
        for _ in range(self.num_res_blocks):
            x = ResBlock(filters, **block_args)(x)
        print('Mid Block Decoder layer', x.shape)

        # Walk the stages in reverse; the deepest stages keep their width, and
        # a 2x upsample runs between stages (not before the full-res stage 0).

        for i in reversed(range(num_blocks)):
            filters = self.filters * self.channel_multipliers[i]
            for _ in range(self.num_res_blocks):  # symmetric with the encoder
                x = ResBlock(filters, **block_args)(x)
            if i > 0:
                x = upsample(x, 2)
                x = nn.Conv(filters, kernel_size=(3, 3))(x)
            print('Decoder layer', x.shape)
        x = norm_fn()(x)
        x = self.activation_fn(x)
        x = nn.Conv(self.image_channels, kernel_size=(3, 3))(x)
        return x
248
+
249
class VectorQuantizer(nn.Module):
    """Basic vector quantizer.

    Maps each input vector to its nearest codebook entry (L2 distance) and,
    in training mode, adds the standard VQ-VAE commitment/codebook losses plus
    an optional entropy regularizer on codebook usage.
    """
    config: ml_collections.ConfigDict
    train: bool

    @nn.compact
    def __call__(self, x):
        codebook_size = self.config.codebook_size
        emb_dim = x.shape[-1]
        codebook = self.param(
            "codebook",
            jax.nn.initializers.variance_scaling(scale=1.0, mode="fan_in", distribution="uniform"),
            (codebook_size, emb_dim))
        codebook = jnp.asarray(codebook)  # (codebook_size, emb_dim)
        # Pairwise distance from every input vector to every code.
        distances = jnp.reshape(
            squared_euclidean_distance(jnp.reshape(x, (-1, emb_dim)), codebook),
            x.shape[:-1] + (codebook_size,))  # [x, codebook_size] similarity matrix.
        encoding_indices = jnp.argmin(distances, axis=-1)
        encoding_onehot = jax.nn.one_hot(encoding_indices, codebook_size)
        quantized = self.quantize(encoding_onehot)
        result_dict = dict()
        if self.train:
            # Commitment loss pulls the encoder output toward its chosen code;
            # codebook loss pulls the code toward the (stopped) encoder output.
            e_latent_loss = jnp.mean((sg(quantized) - x)**2) * self.config.commitment_cost
            q_latent_loss = jnp.mean((quantized - sg(x))**2)
            entropy_loss = 0.0
            if self.config.entropy_loss_ratio != 0:
                # Negated distances act as affinity logits for the entropy term.
                entropy_loss = entropy_loss_fn(
                    -distances,
                    loss_type=self.config.entropy_loss_type,
                    temperature=self.config.entropy_temperature
                ) * self.config.entropy_loss_ratio
            e_latent_loss = jnp.asarray(e_latent_loss, jnp.float32)
            q_latent_loss = jnp.asarray(q_latent_loss, jnp.float32)
            entropy_loss = jnp.asarray(entropy_loss, jnp.float32)
            loss = e_latent_loss + q_latent_loss + entropy_loss
            result_dict = dict(
                quantizer_loss=loss,
                e_latent_loss=e_latent_loss,
                q_latent_loss=q_latent_loss,
                entropy_loss=entropy_loss)
            # Straight-through estimator: forward pass uses the code, gradients
            # flow to the encoder as if quantization were the identity.
            quantized = x + jax.lax.stop_gradient(quantized - x)

        result_dict.update({
            "z_ids": encoding_indices,
        })
        return quantized, result_dict

    def quantize(self, encoding_onehot: jnp.ndarray) -> jnp.ndarray:
        """Turn one-hot code assignments into the corresponding code vectors."""
        codebook = jnp.asarray(self.variables["params"]["codebook"])
        return jnp.dot(encoding_onehot, codebook)

    def decode_ids(self, ids: jnp.ndarray) -> jnp.ndarray:
        """Look up code vectors for integer code ids."""
        codebook = self.variables["params"]["codebook"]
        return jnp.take(codebook, ids, axis=0)
303
+
304
class KLQuantizer(nn.Module):
    """Variational (KL) bottleneck: the channel axis holds (means, logvars)."""
    config: ml_collections.ConfigDict
    train: bool

    @nn.compact
    def __call__(self, x):
        emb_dim = x.shape[-1] // 2  # Use half as means, half as logvars.
        means = x[..., :emb_dim]
        logvars = x[..., emb_dim:]
        if not self.train:
            # Eval: deterministic latent (the mean); expose std for diagnostics.
            result_dict = dict()
            result_dict["std"] = jnp.exp(0.5 * logvars)
            return means, result_dict
        else:
            # Reparameterization trick: z = mean + std * eps.
            noise = jax.random.normal(self.make_rng("noise"), means.shape)
            stds = jnp.exp(0.5 * logvars)
            z = means + stds * noise
            #kl_loss = -0.5 * jnp.mean(1 + logvars - means**2 - jnp.exp(logvars))

            # New KL: sum over all non-batch axes, then average over the batch
            # (per-image KL rather than the old per-element mean).
            kl_loss = - 0.5 * jnp.sum(1 + logvars - jnp.square(means) - jnp.exp(logvars),axis=tuple(range(1, means.ndim)))
            kl_loss = jnp.mean(kl_loss)

            result_dict = dict(quantizer_loss=kl_loss)
            result_dict["std"] = jnp.exp(0.5 * logvars)
            return z, result_dict
330
+
331
class AEQuantizer(nn.Module):
    """Identity bottleneck (plain autoencoder): latents pass through unchanged."""
    config: ml_collections.ConfigDict
    train: bool

    @nn.compact
    def __call__(self, x):
        # No quantization and no regularizer; report a zero std so logging
        # stays uniform with the KL quantizers.
        return x, {"std": 0.0}
340
+
341
+ import jax
342
+ import jax.numpy as jnp
343
+ from jax import random
344
+
345
def imq_kernel(X: jnp.ndarray, Y: jnp.ndarray, h_dim: int):
    """Inverse-multiquadratic MMD statistic between samples X and Y.

    Sums the estimate over a fixed ladder of kernel scales; within-sample
    terms exclude the diagonal. X and Y are (batch, h_dim) matrices.
    """
    batch_size = X.shape[0]

    def pairwise_sq_dists(A, B):
        # ||a_i||^2 + ||b_j||^2 - 2 a_i.b_j, shape (batch, batch).
        a_sq = jnp.sum(A**2, axis=1, keepdims=True)
        b_sq = jnp.sum(B**2, axis=1, keepdims=True)
        return a_sq + b_sq.T - 2 * jnp.dot(A, B.T)

    dists_x = pairwise_sq_dists(X, X)
    dists_y = pairwise_sq_dists(Y, Y)
    dists_c = pairwise_sq_dists(X, Y)

    off_diag = 1 - jnp.eye(batch_size)
    total = 0
    for scale in [0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0]:
        C = 2 * h_dim * 1.0 * scale
        # Within-sample kernel sums, diagonal excluded.
        within = off_diag * (C / (C + dists_x) + C / (C + dists_y))
        within = jnp.sum(within) / (batch_size - 1)
        # Cross terms between X and Y.
        cross = jnp.sum(C / (C + dists_c)) * 2.0 / batch_size
        total += within - cross

    return total
373
+
374
class MMDQuantizer(nn.Module):
    """WAE-style bottleneck: passes latents through unchanged and, during
    training, adds an IMQ-kernel MMD penalty against a Gaussian prior."""
    config: ml_collections.ConfigDict
    train: bool

    @nn.compact
    def __call__(self, x):
        if not self.train:
            result_dict = dict()
            return x, result_dict
        else:
            print("mmd quantizer")
            batch_size, height, width, latent_channels = x.shape
            # Flatten each example to one vector for the kernel computation.
            z_flat = x.reshape(batch_size, -1)
            print(z_flat.shape)
            # Prior sample. NOTE(review): MMD_weight scales the prior's std
            # here, not the loss term — confirm that is the intent.
            z_fake_flat = jax.random.normal(self.make_rng("noise"), z_flat.shape) * self.config["MMD_weight"]
            print(z_fake_flat.shape)
            mmd_loss = imq_kernel(z_flat, z_fake_flat, z_flat.shape[1])
            print(mmd_loss.shape)
            print(mmd_loss)
            result_dict = dict(quantizer_loss=mmd_loss)
            return x, result_dict
395
+
396
+
397
+
398
class KLQuantizerTwo(nn.Module):
    """Experimental KL variant: treats the encoder output itself as the mean
    with the variance pinned to 1, and returns a noised latent in training."""
    config: ml_collections.ConfigDict
    train: bool

    @nn.compact
    def __call__(self, x):
        #emb_dim = x.shape[-1] // 2 # Use half as means, half as logvars.
        #means = x[..., :emb_dim]
        #logvars = x[..., emb_dim:]

        # We actually want mean/std statistics per example rather than split
        # channels: start as (b, h, w, c) and regularize x directly.

        if not self.train:
            result_dict = dict()
            result_dict["std"] = 1.0
            return x, result_dict
        else:
            # Per-example std over all non-batch axes.
            stds = jnp.std(x, axis = [1,2,3])

            noise = jax.random.normal(self.make_rng("noise"), x.shape)

            # NOTE(review): the log-variance of std s is 2*log(s), not
            # 0.5*log(s); dead in practice — overwritten to 0.0 just below.
            logvars = .5 * jnp.log(stds)
            logvars = logvars.reshape(-1,1,1,1)
            if True:  # This is true for the special KL where sigma is manually set to 1.
                logvars = 0.0


            if False:  # DINO-SSL-style covariance regularizer (disabled).
                x_2 = x.reshape(x.shape[0], -1, x.shape[-1])  # flatten, keep channel axis
                x_2 = jnp.swapaxes(x_2,0,1)
                # Per-position channel covariance across the batch.
                cov = jnp.swapaxes(x_2,1,2) @ x_2 / x.shape[0]
                I_d = jnp.identity(x.shape[-1])
                R_eps = jnp.log(jnp.linalg.det(jnp.expand_dims(I_d, axis = 0) + x.shape[-1]/ (.0001 ** 2) * cov))

                # Something here *does* depend on the channel count; needs to
                # be worked out before enabling.
                kl_loss = R_eps.mean()


            # Denoising form: KL with logvars fixed at 0 (unit variance), i.e.
            # 0.5 * sum(x^2 - 1) per example, averaged over the batch.
            kl_loss = - 0.5 * jnp.sum(1 + logvars - jnp.square(x) - jnp.exp(logvars),axis=tuple(range(1, x.ndim)))
            kl_loss = jnp.mean(kl_loss)

            result_dict = dict(quantizer_loss=kl_loss)
            result_dict["std"] = 1.0

            # For proper "kl two" we return mean + unit-variance noise.
            return x + noise, result_dict
450
+
451
+
452
class FSQuantizer(nn.Module):
    """Finite Scalar Quantization: each channel is squashed with tanh and
    rounded to one of `fsq_levels` evenly spaced values; a straight-through
    estimator keeps gradients flowing through the rounding."""
    config: ml_collections.ConfigDict
    train: bool

    @nn.compact
    def __call__(self, x):
        assert self.config['fsq_levels'] % 2 == 1, "FSQ levels must be odd."
        z = jnp.tanh(x)  # [-1, 1]
        z = z * (self.config['fsq_levels']-1) / 2  # [-(levels-1)/2, (levels-1)/2]
        zhat = jnp.round(z)  # integer grid, e.g. [-2, -1, 0, 1, 2] for 5 levels
        # Straight-through: forward rounds, backward is the identity.
        quantized = z + jax.lax.stop_gradient(zhat - z)
        quantized = quantized / (self.config['fsq_levels'] // 2)  # [-1, 1], but quantized.
        result_dict = dict()

        # Diagnostics for codebook usage: treat per-channel levels as digits of
        # a base-`fsq_levels` integer id, then histogram the ids.
        zhat_scaled = zhat + self.config['fsq_levels'] // 2  # shift to [0, levels-1]
        basis = jnp.concatenate((jnp.array([1]), jnp.cumprod(jnp.array([self.config['fsq_levels']] * (x.shape[-1]-1))))).astype(jnp.uint32)
        idx = (zhat_scaled * basis).sum(axis=-1).astype(jnp.uint32)
        idx_flat = idx.reshape(-1)
        usage = jnp.bincount(idx_flat, length=self.config['fsq_levels']**x.shape[-1])

        result_dict.update({
            "z_ids": zhat,
            'usage': usage
        })
        return quantized, result_dict
478
+
479
class VQVAE(nn.Module):
    """VQVAE model: Encoder -> configurable bottleneck -> Decoder."""
    config: ml_collections.ConfigDict
    train: bool

    def setup(self):
        """VQVAE setup."""
        # Pick the bottleneck by config; any other value leaves self.quantizer
        # unset and fails at call time.
        if self.config['quantizer_type'] == 'vq':
            self.quantizer = VectorQuantizer(config=self.config, train=self.train)
        elif self.config['quantizer_type'] == 'kl':
            self.quantizer = KLQuantizer(config=self.config, train=self.train)
        elif self.config['quantizer_type'] == 'fsq':
            self.quantizer = FSQuantizer(config=self.config, train=self.train)
        elif self.config['quantizer_type'] == 'ae':
            self.quantizer = AEQuantizer(config=self.config, train=self.train)
        elif self.config["quantizer_type"] == "kl_two":
            self.quantizer = KLQuantizerTwo(config=self.config, train=self.train)
        self.encoder = Encoder(config=self.config)
        self.decoder = Decoder(config=self.config)

    def encode(self, image):
        """Encode an image batch and pass it through the bottleneck."""
        encoded_feature = self.encoder(image)
        quantized, result_dict = self.quantizer(encoded_feature)
        print("After quant", quantized.shape)
        return quantized, result_dict

    def decode(self, z_vectors):
        """Decode latent vectors back to image space."""
        print("z_vectors shape", z_vectors.shape)
        reconstructed = self.decoder(z_vectors)
        return reconstructed

    def decode_from_indices(self, z_ids):
        """Decode from integer code ids (quantizers providing decode_ids)."""
        z_vectors = self.quantizer.decode_ids(z_ids)
        reconstructed_image = self.decode(z_vectors)
        return reconstructed_image

    def encode_to_indices(self, image):
        """Encode an image to integer code ids (quantizers that emit z_ids)."""
        encoded_feature = self.encoder(image)
        _, result_dict = self.quantizer(encoded_feature)
        ids = result_dict["z_ids"]
        return ids

    def __call__(self, input_dict):
        # Despite the name, `input_dict` is the image batch fed to encode().
        quantized, result_dict = self.encode(input_dict)
        #Freezing encoder now
        print("encode finished")
        # Expose the pre-decoder latents for downstream statistics.
        result_dict["latents"] = quantized
        outputs = self.decoder(quantized)
        return outputs, result_dict
f16c16/ppl_images.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Image-space PPL-style metric for a trained VQGAN/VQVAE.

Loads a checkpoint, walks between pairs of ImageNet validation images in
image space, and accumulates LPIPS distances between nearby decoder outputs
(divided by eps**2 below to turn the distance into a path-length estimate).
"""
try:  # For debugging; the local debug helper is optional.
    from localutils.debugger import enable_debug
    enable_debug()
except ImportError:
    pass


#import jax
#jax.config.update('jax_platform_name', 'cpu')
import os
# os.environ["JAX_PLATFORMS"] = 'cpu'
import jax
import lpips

# LPIPS with the AlexNet backbone is the perceptual distance for the metric.
loss_fn_alex = lpips.LPIPS(net='alex') # best forward scores
loss_fn_alex = loss_fn_alex.cuda()


import numpy as np
import flax.linen as nn
import jax.numpy as jnp
from absl import app, flags
from functools import partial
import numpy as np
import tqdm
import flax
import optax
import wandb
from ml_collections import config_flags
#import elements
import ml_collections
import tensorflow_datasets as tfds
import tensorflow as tf
# Keep TF off the accelerators; it is only used for the input pipeline.
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")
import matplotlib.pyplot as plt
from typing import Any

from utils.train_state import TrainState, target_update
from utils.checkpoint import Checkpoint
from utils.fid import get_fid_network, fid_from_stats

from train import VQGANModel
from models.vqvae import VQVAE
from models.discriminator import Discriminator

from PIL import Image
import torch

# Importing train above already registered these flags; drop them so we can
# re-declare them with evaluation defaults.
delattr(flags.FLAGS, 'dataset_name')
delattr(flags.FLAGS, 'load_dir')
delattr(flags.FLAGS, 'batch_size')

FLAGS = flags.FLAGS
flags.DEFINE_string('dataset_name', 'imagenet256', 'Environment name.')
flags.DEFINE_string('load_dir', "/home/dkaplan/Downloads/Models/checkpoint(1).tmp", 'Load dir (if not None, load params from here).')


flags.DEFINE_integer('batch_size', 2, 'Total Batch size.')
# Flags are inherited from train.py, so pass your model parameters again here to evaluate.

import gc

def main(_):
    """Load the checkpoint, then loop over validation pairs accumulating LPIPS."""
    device_count = len(jax.local_devices())
    global_device_count = jax.device_count()
    local_batch_size = FLAGS.batch_size // (global_device_count // device_count)

    def get_dataset(is_train):
        # Returns a numpy iterator of square-cropped, resized images in [0, 1].
        if 'imagenet' in FLAGS.dataset_name:
            def deserialization_fn(data):
                image = data['image']
                # Center-crop to a square on the short side, then resize.
                min_side = tf.minimum(tf.shape(image)[0], tf.shape(image)[1])
                image = tf.image.resize_with_crop_or_pad(image, min_side, min_side)
                if 'imagenet256' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (256, 256))
                elif 'imagenet128' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (128, 128))
                else:
                    raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
                if is_train:
                    image = tf.image.random_flip_left_right(image)
                image = tf.cast(image, tf.float32) / 255.0
                return image

            split = tfds.split_for_jax_process('train' if is_train else 'validation', drop_remainder=True)
            dataset = tfds.load('imagenet2012', data_dir="/data/inet", split=split)
            dataset = dataset.map(deserialization_fn, num_parallel_calls=tf.data.AUTOTUNE)
            dataset = dataset.shuffle(10000, seed=42, reshuffle_each_iteration=True)
            dataset = dataset.batch(local_batch_size)
            dataset = dataset.prefetch(tf.data.AUTOTUNE)
            dataset = tfds.as_numpy(dataset)
            dataset = iter(dataset)
            return dataset
        else:
            raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")

    dataset = get_dataset(is_train=True)
    dataset_valid = get_dataset(is_train=False)


    # image = Image.open("osman.png")
    # image = np.array(image) / 255.0
    # print(image)
    # image = jnp.array(image)
    # image = jnp.expand_dims(image, 0)
    # image = jnp.expand_dims(image, 0)

    # One sample batch, used only to infer input shapes for model init.
    example_obs = next(dataset)[:1]

    #Reconstruction loop
    # image = model.reconstruction(image)
    # image = image[0,0,:,:,:]
    # image = (image * 255).astype(np.uint8)
    # image = np.array(image)
    # img = Image.fromarray(image)
    # img.save("osman" + str(i) + ".png")


    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, param_key = jax.random.split(rng)
    print("Total devices", jax.local_devices()[0])


    ###################################
    # Creating Model and put on devices.
    ###################################
    FLAGS.model.image_channels = example_obs.shape[-1]
    FLAGS.model.image_size = example_obs.shape[1]
    vqvae_def = VQVAE(FLAGS.model, train=True)
    vqvae_params = vqvae_def.init({'params': param_key, 'noise': param_key}, example_obs)['params']
    # tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    vqvae_ts = TrainState.create(vqvae_def, vqvae_params)#, tx=tx) #Turning off tx because we don't need it...
    # Second (eval-mode) copy of the model sharing the same params.
    vqvae_def_eps = VQVAE(FLAGS.model, train=False)
    vqvae_eps_ts = TrainState.create(vqvae_def_eps, vqvae_params)
    print("Total num of VQVAE parameters:", sum(x.size for x in jax.tree_util.tree_leaves(vqvae_params)))

    discriminator_def = Discriminator(FLAGS.model)
    discriminator_params = discriminator_def.init(param_key, example_obs)['params']
    # tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    discriminator_ts = TrainState.create(discriminator_def, discriminator_params)#, tx=tx)#No tx again
    print("Total num of Discriminator parameters:", sum(x.size for x in jax.tree_util.tree_leaves(discriminator_params)))

    model = VQGANModel(rng=rng, vqvae=vqvae_ts, vqvae_eps=vqvae_eps_ts, discriminator=discriminator_ts, config=FLAGS.model)

    assert FLAGS.load_dir is not None
    cp = Checkpoint(FLAGS.load_dir)
    model = cp.load_model(model)
    print("Loaded model with step", model.vqvae.step)

    model = flax.jax_utils.replicate(model, devices=jax.local_devices())
    jax.debug.visualize_array_sharding(model.vqvae.params['decoder']['Conv_0']['bias'])
    #print(model.vqvae)


    ####################################
    # Noise stuff
    ###################################

    cpus = jax.devices("cpu")

    #So there are a few ways to calculate PPL here
    #We could take two images in image space
    #Walk between them and check the LPIPS in the output space
    #...actually that's basically it right?
    #We could also do the walk in latent space, which is the same, but with ?? scaling

    #Let's see if they are any different.
    i = 0
    lpips_list = []
    means = []
    stds = []
    for valid_images in dataset_valid:


        valid_images = valid_images.reshape((len(jax.local_devices()), -1, *valid_images.shape[1:])) # [devices, batch//devices, etc..]
        #1, 2, 256, 256, 3
        #Given our 2 images, we want to lerp between them...
        #We want to lerp once to point t, and once to point t + eps
        #And then we want to get the LPIPS between those two images
        #And then we calculate LPIPS
        #And then we divide by eps squared, and done.


        # NOTE(review): reconstruction_ppl_image is defined in train.py;
        # presumably it returns decodes at two nearby interpolation points —
        # confirm against its implementation.
        reconstructed_images, decoded, std, latents, std_noisy, latents_noisy = model.reconstruction_ppl_image(valid_images) # [devices, 8, 256, 256, 3]


        means.append(latents.mean())
        stds.append(latents.std())

        # print("std", std.mean())
        print("latent mean", latents.mean())
        print("actual latent std", latents.std())

        print("latent mean noisy", latents_noisy.mean())
        print("actual latent std noisy", latents_noisy.std())

        #Need to change images back to -1,1 (LPIPS expects [-1, 1] inputs)

        reconstructed_images = reconstructed_images * 2 - 1
        decoded = decoded * 2 -1

        #1,2,256,256,3 -> NCHW layout for the torch LPIPS network
        reconstructed_images = jnp.swapaxes(reconstructed_images, 0, 4)
        decoded = jnp.swapaxes(decoded, 0, 4)

        reconstructed_images = jnp.swapaxes(reconstructed_images, 0, 1)
        decoded = jnp.swapaxes(decoded, 0, 1)

        reconstructed_images = jnp.squeeze(reconstructed_images)
        decoded = jnp.squeeze(decoded)

        #So here, we want to put them on CPU and delete the original


        image_np = np.asarray(reconstructed_images)
        image_np_2 = torch.from_numpy(np.copy(image_np)).cuda()

        decoded_np = np.asarray(decoded)
        decoded_np_2 = torch.from_numpy(np.copy(decoded_np)).cuda()



        lpips_loss = loss_fn_alex(image_np_2, decoded_np_2)
        lpips_cpu = lpips_loss.detach().cpu().squeeze().mean()
        # Divide by eps**2: the interpolation step size used inside
        # reconstruction_ppl_image is assumed to be 1e-4 — keep in sync.
        lpips_cpu = lpips_cpu / (.0001 ** 2)

        print(lpips_cpu)
        lpips_list.append(lpips_cpu)


        i += 1
        #
        if i == 500:
            break

    #1e-4 is 54...
    #1e-5 is 106
    #1e-6 is 126

    #kl2 is 150?



    mean_lpips = jnp.mean(jnp.asarray(lpips_list))
    print(mean_lpips)
    print("mean of means", jnp.asarray(means).mean())
    print("stds of means", jnp.asarray(means).std())
    print("mean of stds", jnp.asarray(stds).mean())
    print("std of stds", jnp.asarray(stds).std())


if __name__ == '__main__':
    app.run(main)
f16c16/ppl_latents.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Latent-space PPL-style metric for a trained VQGAN/VQVAE.

Loads a checkpoint, perturbs latents of ImageNet validation images, and
accumulates LPIPS between the decoded perturbed/unperturbed pairs
(divided by eps**2 below to form a path-length estimate).
"""
try:  # For debugging; the local debug helper is optional.
    from localutils.debugger import enable_debug
    enable_debug()
except ImportError:
    pass


#import jax
#jax.config.update('jax_platform_name', 'cpu')
import os
# os.environ["JAX_PLATFORMS"] = 'cpu'
import jax
import lpips

# LPIPS with the AlexNet backbone is the perceptual distance for the metric.
loss_fn_alex = lpips.LPIPS(net='alex') # best forward scores
loss_fn_alex = loss_fn_alex.cuda()


import numpy as np
import flax.linen as nn
import jax.numpy as jnp
from absl import app, flags
from functools import partial
import numpy as np
import tqdm
import flax
import optax
import wandb
from ml_collections import config_flags
#import elements
import ml_collections
import tensorflow_datasets as tfds
import tensorflow as tf
# Keep TF off the accelerators; it is only used for the input pipeline.
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")
import matplotlib.pyplot as plt
from typing import Any

from utils.train_state import TrainState, target_update
from utils.checkpoint import Checkpoint
from utils.fid import get_fid_network, fid_from_stats

from train import VQGANModel
from models.vqvae import VQVAE
from models.discriminator import Discriminator

from PIL import Image
import torch

# Importing train above already registered these flags; drop them so we can
# re-declare them with evaluation defaults.
delattr(flags.FLAGS, 'dataset_name')
delattr(flags.FLAGS, 'load_dir')
delattr(flags.FLAGS, 'batch_size')

FLAGS = flags.FLAGS
flags.DEFINE_string('dataset_name', 'imagenet256', 'Environment name.')
flags.DEFINE_string('load_dir', "/home/dkaplan/Downloads/Models/checkpoint(1).tmp", 'Load dir (if not None, load params from here).')


flags.DEFINE_integer('batch_size', 2, 'Total Batch size.')
# Flags are inherited from train.py, so pass your model parameters again here to evaluate.

import gc

def main(_):
    """Load the checkpoint, then loop over validation batches accumulating LPIPS."""
    device_count = len(jax.local_devices())
    global_device_count = jax.device_count()
    local_batch_size = FLAGS.batch_size // (global_device_count // device_count)

    def get_dataset(is_train):
        # Returns a numpy iterator of square-cropped, resized images in [0, 1].
        if 'imagenet' in FLAGS.dataset_name:
            def deserialization_fn(data):
                image = data['image']
                # Center-crop to a square on the short side, then resize.
                min_side = tf.minimum(tf.shape(image)[0], tf.shape(image)[1])
                image = tf.image.resize_with_crop_or_pad(image, min_side, min_side)
                if 'imagenet256' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (256, 256))
                elif 'imagenet128' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (128, 128))
                else:
                    raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
                if is_train:
                    image = tf.image.random_flip_left_right(image)
                image = tf.cast(image, tf.float32) / 255.0
                return image

            split = tfds.split_for_jax_process('train' if is_train else 'validation', drop_remainder=True)
            dataset = tfds.load('imagenet2012', data_dir="/data/inet", split=split)
            dataset = dataset.map(deserialization_fn, num_parallel_calls=tf.data.AUTOTUNE)
            dataset = dataset.shuffle(10000, seed=42, reshuffle_each_iteration=True)
            dataset = dataset.batch(local_batch_size)
            dataset = dataset.prefetch(tf.data.AUTOTUNE)
            dataset = tfds.as_numpy(dataset)
            dataset = iter(dataset)
            return dataset
        else:
            raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")

    dataset = get_dataset(is_train=True)
    dataset_valid = get_dataset(is_train=False)


    # image = Image.open("osman.png")
    # image = np.array(image) / 255.0
    # print(image)
    # image = jnp.array(image)
    # image = jnp.expand_dims(image, 0)
    # image = jnp.expand_dims(image, 0)

    # One sample batch, used only to infer input shapes for model init.
    example_obs = next(dataset)[:1]

    #Reconstruction loop
    # image = model.reconstruction(image)
    # image = image[0,0,:,:,:]
    # image = (image * 255).astype(np.uint8)
    # image = np.array(image)
    # img = Image.fromarray(image)
    # img.save("osman" + str(i) + ".png")


    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, param_key = jax.random.split(rng)
    print("Total devices", jax.local_devices()[0])


    ###################################
    # Creating Model and put on devices.
    ###################################
    FLAGS.model.image_channels = example_obs.shape[-1]
    FLAGS.model.image_size = example_obs.shape[1]
    vqvae_def = VQVAE(FLAGS.model, train=True)
    vqvae_params = vqvae_def.init({'params': param_key, 'noise': param_key}, example_obs)['params']
    # tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    vqvae_ts = TrainState.create(vqvae_def, vqvae_params)#, tx=tx) #Turning off tx because we don't need it...
    # Second (eval-mode) copy of the model sharing the same params.
    vqvae_def_eps = VQVAE(FLAGS.model, train=False)
    vqvae_eps_ts = TrainState.create(vqvae_def_eps, vqvae_params)
    print("Total num of VQVAE parameters:", sum(x.size for x in jax.tree_util.tree_leaves(vqvae_params)))

    discriminator_def = Discriminator(FLAGS.model)
    discriminator_params = discriminator_def.init(param_key, example_obs)['params']
    # tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    discriminator_ts = TrainState.create(discriminator_def, discriminator_params)#, tx=tx)#No tx again
    print("Total num of Discriminator parameters:", sum(x.size for x in jax.tree_util.tree_leaves(discriminator_params)))

    model = VQGANModel(rng=rng, vqvae=vqvae_ts, vqvae_eps=vqvae_eps_ts, discriminator=discriminator_ts, config=FLAGS.model)

    assert FLAGS.load_dir is not None
    cp = Checkpoint(FLAGS.load_dir)
    model = cp.load_model(model)
    print("Loaded model with step", model.vqvae.step)

    model = flax.jax_utils.replicate(model, devices=jax.local_devices())
    jax.debug.visualize_array_sharding(model.vqvae.params['decoder']['Conv_0']['bias'])
    #print(model.vqvae)


    ####################################
    # Noise stuff
    ###################################

    cpus = jax.devices("cpu")

    #So there are a few ways to calculate PPL here
    #We could take two images in image space
    #Walk between them and check the LPIPS in the output space
    #...actually that's basically it right?
    #We could also do the walk in latent space, which is the same, but with ?? scaling

    #Let's see if they are any different.


    #We could also try taking a latent, going X/2 direction, and -X/2 direction, and seeing that.
    i = 0
    lpips_list = []
    means = []
    stds = []
    for valid_images in dataset_valid:


        valid_images = valid_images.reshape((len(jax.local_devices()), -1, *valid_images.shape[1:])) # [devices, batch//devices, etc..]
        #1, 2, 256, 256, 3
        #Given our 2 images, we want to lerp between them...
        #We want to lerp once to point t, and once to point t + eps
        #And then we want to get the LPIPS between those two images
        #And then we calculate LPIPS
        #And then we divide by eps squared, and done.


        # NOTE(review): reconstruction_ppl is defined in train.py; presumably
        # it decodes a latent and a noise-perturbed latent — confirm there.
        reconstructed_images, decoded, std, latents = model.reconstruction_ppl(valid_images) # [devices, 8, 256, 256, 3]


        means.append(latents.mean())
        stds.append(latents.std())
        print("noise added", std.mean())
        print("latent mean", latents.mean())
        print("actual latent std", latents.std())

        #Need to change images back to -1,1 (LPIPS expects [-1, 1] inputs)

        reconstructed_images = reconstructed_images * 2 - 1
        decoded = decoded * 2 -1

        #1,2,256,256,3 -> NCHW layout for the torch LPIPS network
        reconstructed_images = jnp.swapaxes(reconstructed_images, 0, 4)
        decoded = jnp.swapaxes(decoded, 0, 4)

        reconstructed_images = jnp.swapaxes(reconstructed_images, 0, 1)
        decoded = jnp.swapaxes(decoded, 0, 1)

        reconstructed_images = jnp.squeeze(reconstructed_images)
        decoded = jnp.squeeze(decoded)

        #So here, we want to put them on CPU and delete the original


        image_np = np.asarray(reconstructed_images)
        image_np_2 = torch.from_numpy(np.copy(image_np)).cuda()

        decoded_np = np.asarray(decoded)
        decoded_np_2 = torch.from_numpy(np.copy(decoded_np)).cuda()



        lpips_loss = loss_fn_alex(image_np_2, decoded_np_2)
        lpips_cpu = lpips_loss.detach().cpu().squeeze().mean()
        # Divide by eps**2: the perturbation step size used inside
        # reconstruction_ppl is assumed to be 1e-4 — keep in sync.
        lpips_cpu = lpips_cpu / (.0001 ** 2)

        print(lpips_cpu)
        lpips_list.append(lpips_cpu)


        i += 1
        #
        if i == 500:
            break


    mean_lpips = jnp.mean(jnp.asarray(lpips_list))
    std_lpips = jnp.std(jnp.asarray(lpips_list))
    print("PPL", mean_lpips)
    print("C std", std_lpips)

    print("mean of means", jnp.asarray(means).mean())
    print("stds of means", jnp.asarray(means).std())
    print("mean of stds", jnp.asarray(stds).mean())
    print("std of stds", jnp.asarray(stds).std())


    # Recorded results from past runs, per KL weight / model variant:
    #ae sym
    # mean of means 0.35234922
    # stds of means 0.4036692
    # mean of stds 2.6363409
    # std of stds 0.30666474


    #1e-6:
    #mean of means -0.018107202
    # stds of means 0.11694455
    # mean of stds 1.0860059
    # std of stds 0.09732369
    #average noise added around .03

    #1e-5:
    # mean of means 0.0065166513
    # stds of means 0.06983645
    # mean of stds 0.9855982
    # std of stds 0.05810356

    #1e-4:
    # PPL 8.167942
    # C std 1.7576017
    # mean of means 0.0065882676
    # stds of means 0.042861093
    # mean of stds 0.7608507
    # std of stds 0.05846726
    #Average noise added???



    #pl300
    #PPL 3.5399284
    #C std 0.45380986
    # mean of means 0.090131655
    # stds of means 0.69894844
    # mean of stds 5.5634923
    # std of stds 0.6767279


    #pl100
    # PPL 3.6192155
    # C std 0.47185272
    # mean of means 0.16227543
    # stds of means 0.53616405
    # mean of stds 4.4914503
    # std of stds 0.6015057

    #kl2 noise thing
    # PPL 1.2598925
    # C std 0.26455516
    # mean of means -0.013443217
    # stds of means 1.5238239
    # mean of stds 40.043938
    # std of stds 1.7931403



if __name__ == '__main__':
    app.run(main)
f16c16/ppl_latents2.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Latent-space PPL variant: compare two independently noised decodes.

Like ppl_latents.py, but uses reconstruction_ppl_two, which returns a second
decoded image (decoded_2); LPIPS is measured between the two decodes rather
than between an original and its perturbation.
"""
try:  # For debugging; the local debug helper is optional.
    from localutils.debugger import enable_debug
    enable_debug()
except ImportError:
    pass


#import jax
#jax.config.update('jax_platform_name', 'cpu')
import os
# os.environ["JAX_PLATFORMS"] = 'cpu'
import jax
import lpips

# LPIPS with the AlexNet backbone is the perceptual distance for the metric.
loss_fn_alex = lpips.LPIPS(net='alex') # best forward scores
loss_fn_alex = loss_fn_alex.cuda()


import numpy as np
import flax.linen as nn
import jax.numpy as jnp
from absl import app, flags
from functools import partial
import numpy as np
import tqdm
import flax
import optax
import wandb
from ml_collections import config_flags
#import elements
import ml_collections
import tensorflow_datasets as tfds
import tensorflow as tf
# Keep TF off the accelerators; it is only used for the input pipeline.
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")
import matplotlib.pyplot as plt
from typing import Any

from utils.train_state import TrainState, target_update
from utils.checkpoint import Checkpoint
from utils.fid import get_fid_network, fid_from_stats

from train import VQGANModel
from models.vqvae import VQVAE
from models.discriminator import Discriminator

from PIL import Image
import torch

# Importing train above already registered these flags; drop them so we can
# re-declare them with evaluation defaults.
delattr(flags.FLAGS, 'dataset_name')
delattr(flags.FLAGS, 'load_dir')
delattr(flags.FLAGS, 'batch_size')

FLAGS = flags.FLAGS
flags.DEFINE_string('dataset_name', 'imagenet256', 'Environment name.')
flags.DEFINE_string('load_dir', "/home/dkaplan/Downloads/Models/checkpoint(1).tmp", 'Load dir (if not None, load params from here).')


flags.DEFINE_integer('batch_size', 2, 'Total Batch size.')
# Flags are inherited from train.py, so pass your model parameters again here to evaluate.

import gc

def main(_):
    """Load the checkpoint, then loop over validation batches accumulating LPIPS."""
    device_count = len(jax.local_devices())
    global_device_count = jax.device_count()
    local_batch_size = FLAGS.batch_size // (global_device_count // device_count)

    def get_dataset(is_train):
        # Returns a numpy iterator of square-cropped, resized images in [0, 1].
        if 'imagenet' in FLAGS.dataset_name:
            def deserialization_fn(data):
                image = data['image']
                # Center-crop to a square on the short side, then resize.
                min_side = tf.minimum(tf.shape(image)[0], tf.shape(image)[1])
                image = tf.image.resize_with_crop_or_pad(image, min_side, min_side)
                if 'imagenet256' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (256, 256))
                elif 'imagenet128' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (128, 128))
                else:
                    raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
                if is_train:
                    image = tf.image.random_flip_left_right(image)
                image = tf.cast(image, tf.float32) / 255.0
                return image

            split = tfds.split_for_jax_process('train' if is_train else 'validation', drop_remainder=True)
            dataset = tfds.load('imagenet2012', data_dir="/data/inet", split=split)
            dataset = dataset.map(deserialization_fn, num_parallel_calls=tf.data.AUTOTUNE)
            dataset = dataset.shuffle(10000, seed=42, reshuffle_each_iteration=True)
            dataset = dataset.batch(local_batch_size)
            dataset = dataset.prefetch(tf.data.AUTOTUNE)
            dataset = tfds.as_numpy(dataset)
            dataset = iter(dataset)
            return dataset
        else:
            raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")

    dataset = get_dataset(is_train=True)
    dataset_valid = get_dataset(is_train=False)


    # image = Image.open("osman.png")
    # image = np.array(image) / 255.0
    # print(image)
    # image = jnp.array(image)
    # image = jnp.expand_dims(image, 0)
    # image = jnp.expand_dims(image, 0)

    # One sample batch, used only to infer input shapes for model init.
    example_obs = next(dataset)[:1]

    #Reconstruction loop
    # image = model.reconstruction(image)
    # image = image[0,0,:,:,:]
    # image = (image * 255).astype(np.uint8)
    # image = np.array(image)
    # img = Image.fromarray(image)
    # img.save("osman" + str(i) + ".png")


    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, param_key = jax.random.split(rng)
    print("Total devices", jax.local_devices()[0])


    ###################################
    # Creating Model and put on devices.
    ###################################
    FLAGS.model.image_channels = example_obs.shape[-1]
    FLAGS.model.image_size = example_obs.shape[1]
    vqvae_def = VQVAE(FLAGS.model, train=True)
    vqvae_params = vqvae_def.init({'params': param_key, 'noise': param_key}, example_obs)['params']
    # tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    vqvae_ts = TrainState.create(vqvae_def, vqvae_params)#, tx=tx) #Turning off tx because we don't need it...
    # Second (eval-mode) copy of the model sharing the same params.
    vqvae_def_eps = VQVAE(FLAGS.model, train=False)
    vqvae_eps_ts = TrainState.create(vqvae_def_eps, vqvae_params)
    print("Total num of VQVAE parameters:", sum(x.size for x in jax.tree_util.tree_leaves(vqvae_params)))

    discriminator_def = Discriminator(FLAGS.model)
    discriminator_params = discriminator_def.init(param_key, example_obs)['params']
    # tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    discriminator_ts = TrainState.create(discriminator_def, discriminator_params)#, tx=tx)#No tx again
    print("Total num of Discriminator parameters:", sum(x.size for x in jax.tree_util.tree_leaves(discriminator_params)))

    model = VQGANModel(rng=rng, vqvae=vqvae_ts, vqvae_eps=vqvae_eps_ts, discriminator=discriminator_ts, config=FLAGS.model)

    assert FLAGS.load_dir is not None
    cp = Checkpoint(FLAGS.load_dir)
    model = cp.load_model(model)
    print("Loaded model with step", model.vqvae.step)

    model = flax.jax_utils.replicate(model, devices=jax.local_devices())
    jax.debug.visualize_array_sharding(model.vqvae.params['decoder']['Conv_0']['bias'])
    #print(model.vqvae)


    ####################################
    # Noise stuff
    ###################################

    cpus = jax.devices("cpu")

    #So there are a few ways to calculate PPL here
    #We could take two images in image space
    #Walk between them and check the LPIPS in the output space
    #...actually that's basically it right?
    #We could also do the walk in latent space, which is the same, but with ?? scaling

    #Let's see if they are any different.


    #We could also try taking a latent, going X/2 direction, and -X/2 direction, and seeing that.
    i = 0
    lpips_list = []
    means = []
    stds = []
    for valid_images in dataset_valid:


        valid_images = valid_images.reshape((len(jax.local_devices()), -1, *valid_images.shape[1:])) # [devices, batch//devices, etc..]
        #1, 2, 256, 256, 3
        #Given our 2 images, we want to lerp between them...
        #We want to lerp once to point t, and once to point t + eps
        #And then we want to get the LPIPS between those two images
        #And then we calculate LPIPS
        #And then we divide by eps squared, and done.


        # NOTE(review): reconstruction_ppl_two is defined in train.py;
        # presumably it decodes two independently perturbed latents — confirm.
        reconstructed_images, decoded, std, latents, decoded_2 = model.reconstruction_ppl_two(valid_images) # [devices, 8, 256, 256, 3]


        means.append(latents.mean())
        stds.append(latents.std())
        # print("std", std.mean())
        print("latent mean", latents.mean())
        print("actual latent std", latents.std())

        #Need to change images back to -1,1 (LPIPS expects [-1, 1] inputs)
        #Why are the images so similar? It's different noises...

        # Note: the comparison here is decoded_2 vs decoded (not the
        # reconstruction); reconstructed_images is overwritten on purpose.
        reconstructed_images = decoded_2 * 2 - 1
        decoded = decoded * 2 -1

        #1,2,256,256,3 -> NCHW layout for the torch LPIPS network
        reconstructed_images = jnp.swapaxes(reconstructed_images, 0, 4)
        decoded = jnp.swapaxes(decoded, 0, 4)

        reconstructed_images = jnp.swapaxes(reconstructed_images, 0, 1)
        decoded = jnp.swapaxes(decoded, 0, 1)

        reconstructed_images = jnp.squeeze(reconstructed_images)
        decoded = jnp.squeeze(decoded)

        #So here, we want to put them on CPU and delete the original


        image_np = np.asarray(reconstructed_images)
        image_np_2 = torch.from_numpy(np.copy(image_np)).cuda()

        decoded_np = np.asarray(decoded)
        decoded_np_2 = torch.from_numpy(np.copy(decoded_np)).cuda()



        lpips_loss = loss_fn_alex(image_np_2, decoded_np_2)
        lpips_cpu = lpips_loss.detach().cpu().squeeze().mean()
        # Divide by eps**2: the perturbation step size used inside
        # reconstruction_ppl_two is assumed to be 1e-4 — keep in sync.
        lpips_cpu = lpips_cpu / (.0001 ** 2)

        print(lpips_cpu)
        lpips_list.append(lpips_cpu)


        i += 1
        #
        if i == 500:
            break




    mean_lpips = jnp.mean(jnp.asarray(lpips_list))
    print(mean_lpips)
    print("mean of means", jnp.asarray(means).mean())
    print("stds of means", jnp.asarray(means).std())
    print("mean of stds", jnp.asarray(stds).mean())
    print("std of stds", jnp.asarray(stds).std())

    # Recorded results from past runs, per KL weight / model variant:
    #1e-4? 8.1371
    #1e-5 9.0486
    #1e-6 9.7


    #ae is a 5.85.....


    #1e-4 kl2 1.26




    #1e-6 is 9.8
    #1e-5 is 9.09
    #2e-5 is ..... between these. hopefully. 8.83
    #1e-4 is 8.16
    #ae (sym) is 5.87 right now, somehow.


    #basicallly ae 5.56, then 4.95?


    #PL100 is 3.6
    #Pl300 is 3.53
    #Pl600 is... 3.97


    #So the kl level barely matters it seems.
    #We might want to try MMD + noise, but it also barely matters I think
    #1e-4 was 1.25
    #5e-5 was 1.225
    #kl2 was like super duper low, forgot to save it lol. 1.17 maybe?

if __name__ == '__main__':
    app.run(main)
f16c16/stats.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
try: # For debugging
    from localutils.debugger import enable_debug
    enable_debug()
except ImportError:
    pass


#import jax
#jax.config.update('jax_platform_name', 'cpu')
import os
# os.environ["JAX_PLATFORMS"] = 'cpu'
import jax
import lpips

# LPIPS with the AlexNet backbone ("best forward scores" per the lpips README).
# It is a torch module and is kept on the GPU; jax handles the reconstruction side.
loss_fn_alex = lpips.LPIPS(net='alex') # best forward scores
loss_fn_alex = loss_fn_alex.cuda()


import numpy as np
import flax.linen as nn
import jax.numpy as jnp
from absl import app, flags
from functools import partial
import tqdm
import flax
import optax
import wandb
from ml_collections import config_flags
#import elements
import ml_collections
import tensorflow_datasets as tfds
import tensorflow as tf
# TensorFlow is only used for the input pipeline; keep it off the accelerators
# so it does not grab memory that jax / torch need.
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")
import matplotlib.pyplot as plt
from typing import Any

from utils.train_state import TrainState, target_update
from utils.checkpoint import Checkpoint
from utils.fid import get_fid_network, fid_from_stats

from train import VQGANModel
from models.vqvae import VQVAE
from models.discriminator import Discriminator

from PIL import Image
import torch

# train.py (imported above for VQGANModel) already registered these flags;
# remove them so they can be redefined with evaluation-specific defaults.
delattr(flags.FLAGS, 'dataset_name')
delattr(flags.FLAGS, 'load_dir')
delattr(flags.FLAGS, 'batch_size')

FLAGS = flags.FLAGS
flags.DEFINE_string('dataset_name', 'imagenet256', 'Environment name.')
flags.DEFINE_string('load_dir', "/home/dkaplan/Downloads/Models/checkpoint(1).tmp", 'Load dir (if not None, load params from here).')

flags.DEFINE_integer('batch_size', 2, 'Total Batch size.')
# Flags are inherited from train.py, so pass your model parameters again here to evaluate.

import gc
64
def main(_):
    """Measure how reconstruction quality degrades with injected latent noise.

    Loads a trained (VQ-)VAE checkpoint, then iterates over imagenet
    validation batches. For each noise level in [0, 1) it records:
      * L2 between the clean reconstruction and the noisy-latent reconstruction,
      * LPIPS (AlexNet, computed in torch on GPU) between the same pair,
      * the SNR reported by the model at that noise level,
    and finally prints per-noise-level means and standard deviations.
    """

    device_count = len(jax.local_devices())
    global_device_count = jax.device_count()
    # Per-process batch; the global batch is split evenly across processes.
    local_batch_size = FLAGS.batch_size // (global_device_count // device_count)

    def get_dataset(is_train):
        # Center-crop to a square, resize to the model resolution, scale to [0, 1].
        if 'imagenet' in FLAGS.dataset_name:
            def deserialization_fn(data):
                image = data['image']
                min_side = tf.minimum(tf.shape(image)[0], tf.shape(image)[1])
                image = tf.image.resize_with_crop_or_pad(image, min_side, min_side)
                if 'imagenet256' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (256, 256))
                elif 'imagenet128' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (128, 128))
                else:
                    raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
                if is_train:
                    image = tf.image.random_flip_left_right(image)
                image = tf.cast(image, tf.float32) / 255.0
                return image

            split = tfds.split_for_jax_process('train' if is_train else 'validation', drop_remainder=True)
            dataset = tfds.load('imagenet2012', data_dir="/data/inet", split=split)
            dataset = dataset.map(deserialization_fn, num_parallel_calls=tf.data.AUTOTUNE)
            dataset = dataset.shuffle(10000, seed=42, reshuffle_each_iteration=True)
            dataset = dataset.batch(local_batch_size)
            dataset = dataset.prefetch(tf.data.AUTOTUNE)
            dataset = tfds.as_numpy(dataset)
            return iter(dataset)
        else:
            raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")

    dataset = get_dataset(is_train=True)
    dataset_valid = get_dataset(is_train=False)

    # A single training example is only needed to size the model inputs.
    example_obs = next(dataset)[:1]

    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, param_key = jax.random.split(rng)
    print("Total devices", jax.local_devices()[0])

    ###################################
    # Creating Model and put on devices.
    ###################################
    FLAGS.model.image_channels = example_obs.shape[-1]
    FLAGS.model.image_size = example_obs.shape[1]
    vqvae_def = VQVAE(FLAGS.model, train=True)
    vqvae_params = vqvae_def.init({'params': param_key, 'noise': param_key}, example_obs)['params']
    # Evaluation only: no optimizer (tx) is attached to either train state.
    vqvae_ts = TrainState.create(vqvae_def, vqvae_params)
    vqvae_def_eps = VQVAE(FLAGS.model, train=False)
    vqvae_eps_ts = TrainState.create(vqvae_def_eps, vqvae_params)
    print("Total num of VQVAE parameters:", sum(x.size for x in jax.tree_util.tree_leaves(vqvae_params)))

    discriminator_def = Discriminator(FLAGS.model)
    discriminator_params = discriminator_def.init(param_key, example_obs)['params']
    discriminator_ts = TrainState.create(discriminator_def, discriminator_params)
    print("Total num of Discriminator parameters:", sum(x.size for x in jax.tree_util.tree_leaves(discriminator_params)))

    model = VQGANModel(rng=rng, vqvae=vqvae_ts, vqvae_eps=vqvae_eps_ts, discriminator=discriminator_ts, config=FLAGS.model)

    assert FLAGS.load_dir is not None
    cp = Checkpoint(FLAGS.load_dir)
    model = cp.load_model(model)
    print("Loaded model with step", model.vqvae.step)

    model = flax.jax_utils.replicate(model, devices=jax.local_devices())
    jax.debug.visualize_array_sharding(model.vqvae.params['decoder']['Conv_0']['bias'])

    ####################################
    # Noise levels to sweep (0.00 .. 0.99 in steps of 0.01).
    ###################################
    noises = [float(number) for number in np.arange(0.00, 1.0, 0.01)]

    i = 0
    l2_dict = {noise: [] for noise in noises}
    lpips_dict = {noise: [] for noise in noises}
    snr_dict = {noise: [] for noise in noises}

    cpus = jax.devices("cpu")
    print(noises)

    for valid_images in dataset_valid:
        print(i)
        # [devices, batch // devices, H, W, C] for pmap.
        valid_images = valid_images.reshape((len(jax.local_devices()), -1, *valid_images.shape[1:]))

        # noisy_reconstructed_images is expected to yield one (image, snr)
        # pair per entry of `noises` — TODO confirm against reconstruction_noisy.
        valid_reconstructed_images, noisy_reconstructed_images, std = model.reconstruction_noisy(valid_images)
        print(std.mean())

        # Hoisted out of the per-noise loop: the clean reconstruction is
        # loop-invariant, so rescale it to [-1, 1] (the LPIPS convention),
        # reorder to NCHW, and copy it to the torch GPU exactly once per batch
        # instead of once per noise level.
        valid_rescaled = valid_reconstructed_images * 2 - 1
        valid_rescaled = jnp.swapaxes(valid_rescaled, 0, 4)
        valid_rescaled = jnp.swapaxes(valid_rescaled, 0, 1)
        valid_rescaled = jnp.squeeze(valid_rescaled)
        valid_rescaled_torch = torch.from_numpy(np.copy(np.asarray(valid_rescaled))).cuda()

        for noise, decoded in zip(noises, noisy_reconstructed_images):
            image, snr = decoded
            # Mean SNR over the batch at this noise level.
            snr_dict[noise].append(snr.mean())

            # L2 in the original [0, 1] space; stored on CPU so device memory
            # stays bounded across the whole sweep.
            l2 = jnp.mean((valid_reconstructed_images - image) ** 2)
            l2_dict[noise].append(jax.device_put(l2, cpus[0]))

            # Rescale the noisy reconstruction to [-1, 1] and reorder to NCHW
            # for LPIPS.
            image = image * 2 - 1
            image = jnp.swapaxes(image, 0, 4)
            image = jnp.swapaxes(image, 0, 1)
            image = jnp.squeeze(image)
            image_torch = torch.from_numpy(np.copy(np.asarray(image))).cuda()

            lpips_loss = loss_fn_alex(valid_rescaled_torch, image_torch)
            lpips_dict[noise].append(lpips_loss.detach().cpu().squeeze().mean())

        i += 1
        # 50 validation batches are enough for stable per-noise means.
        if i == 50:
            break

    # Per-noise-level aggregate statistics.
    mean_l2_dict = {noise: jnp.mean(jnp.asarray(l2_values)) for noise, l2_values in l2_dict.items()}
    std_l2_dict = {noise: jnp.std(jnp.asarray(l2_values)) for noise, l2_values in l2_dict.items()}
    for noise, mean_l2 in mean_l2_dict.items():
        print(f"Mean L2 for noise {noise}: {mean_l2}")

    mean_lpips_dict = {noise: torch.mean(torch.tensor(lpips_values)) for noise, lpips_values in lpips_dict.items()}
    std_lpips_dict = {noise: torch.std(torch.tensor(lpips_values)) for noise, lpips_values in lpips_dict.items()}
    for noise, mean_lpips in mean_lpips_dict.items():
        print(f"Mean Lpips for noise {noise}: {mean_lpips}")

    mean_snr_dict = {noise: jnp.mean(jnp.asarray(snr_values)) for noise, snr_values in snr_dict.items()}
    std_snr_dict = {noise: jnp.std(jnp.asarray(snr_values)) for noise, snr_values in snr_dict.items()}
    for noise, mean_snr in mean_snr_dict.items():
        print(f"Mean SNR for noise {noise}: {mean_snr}")

    # Std of LPIPS per noise level, as a plain python list for plotting.
    array = [np.asarray(std).tolist() for std in std_lpips_dict.values()]

    print(array)
    print(std_lpips_dict)
    # Spread of SNR at a fixed noise level across batches.
    print(std_snr_dict)
361
if __name__ == '__main__':
    # absl entry point: parses command-line flags, then invokes main.
    app.run(main)
f16c16/train.py ADDED
@@ -0,0 +1,676 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
try: # For debugging
    from localutils.debugger import enable_debug
    enable_debug()
except ImportError:
    pass

import flax.linen as nn
import jax.numpy as jnp
from absl import app, flags
from functools import partial
import numpy as np
import tqdm
import jax
import flax
import optax
import wandb
from ml_collections import config_flags
import ml_collections
import tensorflow_datasets as tfds
import tensorflow as tf
# TensorFlow only drives the input pipeline; keep it off the accelerators.
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")
import matplotlib.pyplot as plt
from typing import Any
import os

from utils.wandb import setup_wandb, default_wandb_config
from utils.train_state import TrainState, target_update
from utils.checkpoint import Checkpoint
from utils.pretrained_resnet import get_pretrained_embs, get_pretrained_model
from utils.fid import get_fid_network, fid_from_stats
from models.vqvae import VQVAE
from models.discriminator import Discriminator

FLAGS = flags.FLAGS
flags.DEFINE_string('dataset_name', 'imagenet256', 'Environment name.')
flags.DEFINE_string('save_dir', "/home/lambda/jax-vqvae-vqgan/chkpts/checkpoint", 'Save dir (if not None, save params).')
flags.DEFINE_string('load_dir', "./checkpointbest.tmp.tmp" , 'Load dir (if not None, load params from here).')
flags.DEFINE_integer('seed', 0, 'Random seed.')
flags.DEFINE_integer('log_interval', 1000, 'Logging interval.')
flags.DEFINE_integer('eval_interval', 1000, 'Eval interval.')
flags.DEFINE_integer('save_interval', 1000, 'Save interval.')
flags.DEFINE_integer('batch_size', 64, 'Total Batch size.')
flags.DEFINE_integer('max_steps', int(1_000_000), 'Number of training steps.')

model_config = ml_collections.ConfigDict({
    # VQVAE
    'lr': 0.0001,
    'beta1': 0.0,  # .5
    'beta2': 0.99,  # .9
    'lr_warmup_steps': 4000,
    'lr_decay_steps': 1_000_000,  # They use 'lambdalr'
    'filters': 128,
    'num_res_blocks': 2,
    'channel_multipliers': (1, 1, 2, 2, 4),
    'embedding_dim': 16,
    'norm_type': 'GN',
    'weight_decay': 0.05,  # None maybe?
    'clip_gradient': 1.0,
    'l2_loss_weight': 1.0,  # They use L1 actually
    'eps_update_rate': 0.9999,
    # Quantizer
    'quantizer_type': 'ae',  # or 'fsq', 'kl'
    # Quantizer (VQ)
    'quantizer_loss_ratio': 1,
    'codebook_size': 1024,
    'entropy_loss_ratio': 0.1,
    'entropy_loss_type': 'softmax',
    'entropy_temperature': 0.01,
    'commitment_cost': 0.25,
    # Quantizer (FSQ)
    'fsq_levels': 5,  # Bins per dimension.
    # Quantizer (KL)
    'kl_weight': 0.000001,  # reference implementations use 1e-6; .001 is a common default
    # GAN
    'g_adversarial_loss_weight': 0.5,
    'g_grad_penalty_cost': 10,
    'perceptual_loss_weight': 0.5,
    'gan_warmup_steps': 100000,  # 50000, temporarily extended
    "pl_decay": 0.01,
    "pl_weight": -1,  # -1 disables path-length regularization
    'MMD_weight': 1.0

})

wandb_config = default_wandb_config()
wandb_config.update({
    'project': 'vqvae',
    'name': 'vqvae_{dataset_name}',
})

config_flags.DEFINE_config_dict('wandb', wandb_config, lock_config=False)
config_flags.DEFINE_config_dict('model', model_config, lock_config=False)

##############################################
## Model Definitions.
##############################################
100
@jax.vmap
def sigmoid_cross_entropy_with_logits(*, labels: jnp.ndarray, logits: jnp.ndarray) -> jnp.ndarray:
    """Numerically stable sigmoid cross-entropy.

    Computes ``max(logits, 0) - logits * labels + log(1 + exp(-|logits|))``,
    which never exponentiates a large positive number and so cannot overflow.
    https://github.com/google-research/maskgit/blob/main/maskgit/libml/losses.py
    """
    return jnp.maximum(logits, 0.0) - logits * labels + jnp.log1p(jnp.exp(-jnp.abs(logits)))
109
+
110
+ class VQGANModel(flax.struct.PyTreeNode):
111
+ rng: Any
112
+ config: dict = flax.struct.field(pytree_node=False)
113
+ vqvae: TrainState
114
+ vqvae_eps: TrainState
115
+ discriminator: TrainState
116
+
117
+ # Train G and D.
118
+ @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
119
+ def update(self, images, pmap_axis='data'):
120
+ new_rng, curr_key = jax.random.split(self.rng, 2)
121
+
122
+ resnet, resnet_params = get_pretrained_model('resnet50', 'data/resnet_pretrained.npy')
123
+
124
+ is_gan_training = 1.0 - (self.vqvae.step < self.config['gan_warmup_steps']).astype(jnp.float32)
125
+ #Maybe only start GAN way later on?
126
+
127
+ def loss_fn(params_vqvae, params_disc):
128
+
129
+ def path_reg_loss(latents, targets):#let's have pl_mean be in our self.config
130
+ #1/2 should be our spatial dimensions.
131
+
132
+ latents = latents[0:2, :, :, :]
133
+ targets = targets[0:2, :, :, :]
134
+ pl_noise = jax.random.normal(new_rng, shape = targets.shape) / jnp.sqrt(targets.shape[1] * targets.shape[2])
135
+ def grad_sum(latents, pl_noise):#So we don't have access to the actual decode method
136
+ #return jnp.sum(self.vqvae.decode(latents))
137
+
138
+ #I am not sure if this makes any sense whatsoever tbh
139
+ my_sum = self.vqvae(latents, params=params_vqvae, method="decode", rngs={'noise': curr_key})*pl_noise
140
+ print("Decode shape", my_sum.shape)
141
+ return jnp.sum(my_sum)
142
+
143
+ decode_grad_fn = jax.grad(grad_sum)
144
+ pl_grads = decode_grad_fn(latents, pl_noise)
145
+ pl_lengths = jnp.sqrt(jnp.mean(jnp.sum(jnp.square(pl_grads), axis = [2,3]), axis = 1))
146
+ #pl_lengths = jnp.sqrt(jnp.mean(jnp.sum(jnp.square(pl_grads), axis=2), axis=3))
147
+
148
+ pl_mean = self.vqvae.pl_mean + self.config.pl_decay * (jnp.mean(pl_lengths) - self.vqvae.pl_mean)
149
+ pl_penalty = jnp.square(pl_lengths - pl_mean)
150
+ loss = jnp.mean(pl_penalty)
151
+ return loss, pl_mean
152
+
153
+ if self.config.pl_weight != -1:
154
+ smooth_loss, pl_mean = path_reg_loss(result_dict["latents"], reconstructed_images)
155
+ # self.vqvae.replace(pl_mean = pl_mean)
156
+ #We need to update pl mean in self.vqvae
157
+
158
+ # Reconstruct image
159
+ reconstructed_images, result_dict = self.vqvae(images, params=params_vqvae, rngs={'noise': curr_key})
160
+ print("Reconstructed images shape", reconstructed_images.shape)
161
+ print("Input images shape", images.shape)
162
+ assert reconstructed_images.shape == images.shape
163
+
164
+ # GAN loss on VQVAE output.
165
+ discriminator_fn = lambda x: self.discriminator(x, params=params_disc)
166
+ real_logit, vjp_fn = jax.vjp(discriminator_fn, images, has_aux=False)
167
+ gradient = vjp_fn(jnp.ones_like(real_logit))[0] # Gradient of discriminator output wrt. real images.
168
+ gradient = gradient.reshape((images.shape[0], -1))
169
+ gradient = jnp.asarray(gradient, jnp.float32)
170
+ penalty = jnp.sum(jnp.square(gradient), axis=-1)
171
+ penalty = jnp.mean(penalty) # Gradient penalty for training D.
172
+ fake_logit = discriminator_fn(reconstructed_images)
173
+ d_loss_real = sigmoid_cross_entropy_with_logits(labels=jnp.ones_like(real_logit), logits=real_logit).mean()
174
+ d_loss_fake = sigmoid_cross_entropy_with_logits(labels=jnp.zeros_like(fake_logit), logits=fake_logit).mean()
175
+ loss_d = d_loss_real + d_loss_fake + (penalty * self.config['g_grad_penalty_cost'])
176
+
177
+ d_loss_for_vae = sigmoid_cross_entropy_with_logits(labels=jnp.ones_like(fake_logit), logits=fake_logit).mean()
178
+ d_loss_for_vae = d_loss_for_vae * is_gan_training
179
+
180
+ real_pools, _ = get_pretrained_embs(resnet_params, resnet, images=images)
181
+ fake_pools, _ = get_pretrained_embs(resnet_params, resnet, images=reconstructed_images)
182
+ perceptual_loss = jnp.mean((real_pools - fake_pools)**2)
183
+
184
+ l2_loss = jnp.mean((reconstructed_images - images) ** 2)
185
+ quantizer_loss = result_dict['quantizer_loss'] if 'quantizer_loss' in result_dict else 0.0
186
+ if self.config['quantizer_type'] == 'kl' or self.config["quantizer_type"] == "kl_two":
187
+ quantizer_loss = quantizer_loss * self.config['kl_weight']
188
+ elif self.config["quantizer_type"] == "MMD":
189
+ quantizer_loss = quantizer_loss * self.config['MMD_weight']
190
+ loss_vae = (l2_loss * FLAGS.model['l2_loss_weight']) \
191
+ + (quantizer_loss * FLAGS.model['quantizer_loss_ratio']) \
192
+ + (d_loss_for_vae * FLAGS.model['g_adversarial_loss_weight']) \
193
+ + (perceptual_loss * FLAGS.model['perceptual_loss_weight']) \
194
+ #+ (smooth_loss * FLAGS.model['pl_weight'] )
195
+ codebook_usage = result_dict['usage'] if 'usage' in result_dict else 0.0
196
+
197
+ return_dict = {
198
+ 'loss_vae': loss_vae,
199
+ 'loss_d': loss_d,
200
+ 'l2_loss': l2_loss,
201
+ 'd_loss_for_vae': d_loss_for_vae,
202
+ 'perceptual_loss': perceptual_loss,
203
+ 'quantizer_loss': quantizer_loss,
204
+ 'codebook_usage': codebook_usage,
205
+ #'pl_loss': smooth_loss,
206
+ }
207
+
208
+ if self.config["pl_weight"] != -1:
209
+ loss_vae += (smooth_loss * FLAGS.model["pl_weight"])
210
+ return_dict["pl_mean"] = pl_mean
211
+ return_dict["smooth_loss"] = smooth_loss
212
+
213
+
214
+ return (loss_vae, loss_d), return_dict
215
+
216
+
217
+ # This is a fancy way to do 'jax.grad' so (loss_vae, params_vqvae) and (loss_d, params_disc) are differentiated.
218
+ _, grad_fn, info = jax.vjp(loss_fn, self.vqvae.params, self.discriminator.params, has_aux=True)
219
+ vae_grads, _ = grad_fn((1., 0.))
220
+ _, d_grads = grad_fn((0., 1.))
221
+
222
+ vae_grads = jax.lax.pmean(vae_grads, axis_name=pmap_axis)
223
+ d_grads = jax.lax.pmean(d_grads, axis_name=pmap_axis)
224
+ d_grads = jax.tree.map(lambda x: x * is_gan_training, d_grads)
225
+
226
+ info = jax.lax.pmean(info, axis_name=pmap_axis)
227
+ if self.config['quantizer_type'] == 'fsq':
228
+ info['codebook_usage'] = jnp.sum(info['codebook_usage'] > 0) / info['codebook_usage'].shape[-1]
229
+
230
+ updates, new_opt_state = self.vqvae.tx.update(vae_grads, self.vqvae.opt_state, self.vqvae.params)
231
+ new_params = optax.apply_updates(self.vqvae.params, updates)
232
+
233
+ if self.config["pl_weight"] != -1:
234
+ new_vqvae = self.vqvae.replace(step=self.vqvae.step + 1, params=new_params, opt_state=new_opt_state, pl_mean=info["pl_mean"])
235
+ else:
236
+ new_vqvae = self.vqvae.replace(step=self.vqvae.step + 1, params=new_params, opt_state=new_opt_state)
237
+
238
+ updates, new_opt_state = self.discriminator.tx.update(d_grads, self.discriminator.opt_state, self.discriminator.params)
239
+ new_params = optax.apply_updates(self.discriminator.params, updates)
240
+ new_discriminator = self.discriminator.replace(step=self.discriminator.step + 1, params=new_params, opt_state=new_opt_state)
241
+
242
+ info['grad_norm_vae'] = optax.global_norm(vae_grads)
243
+ info['grad_norm_d'] = optax.global_norm(d_grads)
244
+ info['update_norm'] = optax.global_norm(updates)
245
+ info['param_norm'] = optax.global_norm(new_params)
246
+ info['is_gan_training'] = is_gan_training
247
+
248
+ new_vqvae_eps = target_update(new_vqvae, self.vqvae_eps, 1-self.config['eps_update_rate'])
249
+
250
+ new_model = self.replace(rng=new_rng, vqvae=new_vqvae, vqvae_eps=new_vqvae_eps, discriminator=new_discriminator)
251
+ return new_model, info
252
+
253
+ @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
254
+ def reconstruction(self, images, pmap_axis='data', sampling = True):
255
+ if not sampling:
256
+ reconstructed_images, _ = self.vqvae_eps(images)
257
+ else:#Not sure what our theoretical sampling mode does
258
+ new_rng, curr_key = jax.random.split(self.rng, 2)
259
+ reconstructed_images, _ = self.vqvae_eps(images, rngs={'noise': curr_key})
260
+
261
+ reconstructed_images = jnp.clip(reconstructed_images, 0, 1)
262
+ return reconstructed_images
263
+
264
+ @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
265
+ def reconstruction_sampling(self, images, pmap_axis='data'):
266
+
267
+ reconstructed_images_determistic, _ = self.vqvae_eps(images)
268
+
269
+
270
+ new_rng, curr_key = jax.random.split(self.rng, 2)
271
+ reconstructed_images_sample, result_dict = self.vqvae(images, rngs={'noise': curr_key})
272
+
273
+ #We don't need to return the result dict.
274
+ reconstructed_images_determistic = jnp.clip(reconstructed_images_determistic, 0, 1)
275
+ reconstructed_images_sample = jnp.clip(reconstructed_images_sample, 0, 1)
276
+
277
+ return reconstructed_images_determistic, reconstructed_images_sample
278
+
279
    @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
    def reconstruction_interpolation(self, images, pmap_axis='data'):
        """Return (deterministic, sampled) reconstructions of `images`.

        NOTE(review): despite the name, NO interpolation is performed — this
        body is byte-identical to `reconstruction_sampling`. The comments
        below describe an intended latent/image-space interpolation that was
        never implemented; confirm with the author before relying on the name.
        """
        #So we *have* our two images. We are going to linearly interpolate between them in... latent space
        #But also in image space?
        #Sure, why not
        # Deterministic reconstruction via the EMA ("eps") network.
        reconstructed_images_determistic, _ = self.vqvae_eps(images)

        # Stochastic reconstruction via the training network with a fresh noise key.
        new_rng, curr_key = jax.random.split(self.rng, 2)
        reconstructed_images_sample, result_dict = self.vqvae(images, rngs={'noise': curr_key})

        #We don't need to return the result dict.
        reconstructed_images_determistic = jnp.clip(reconstructed_images_determistic, 0, 1)
        reconstructed_images_sample = jnp.clip(reconstructed_images_sample, 0, 1)

        return reconstructed_images_determistic, reconstructed_images_sample
296
+
297
+ @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
298
+ def get_latent(self, images, pmap_axis='data'):
299
+
300
+ #We do *not* add the noise ourselves, just save it.
301
+ latents, result_dict = self.vqvae_eps(images, params=self.vqvae_eps.params, method="encode")
302
+
303
+ # reconstructed_images, result_dict_two = self.vqvae_eps(images)
304
+ # reconstructed_images = jnp.clip(reconstructed_images, 0, 1)
305
+ #
306
+ #
307
+ # decoded = self.vqvae_eps(latents, params=self.vqvae_eps.params, method="decode")
308
+ # decoded = jnp.clip(decoded, 0, 1)
309
+
310
+ #reconstructed images should be correct
311
+ return latents, result_dict#, result_dict_two, reconstructed_images, decoded
312
+
313
+
314
+ @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
315
+ def reconstruction_noisy(self, images, pmap_axis='data'):
316
+
317
+
318
+ noises = []
319
+ numbers = np.arange(0.00, 1.0, 0.01)
320
+
321
+ for number in numbers:
322
+ noises.append(float(number))
323
+
324
+
325
+ #So 3 things to try out.
326
+ #One is normalize variance of the latents before adding noise, start there
327
+ #The second is plot snr instead.
328
+ #snr = var(latent)/var(noise)
329
+ #var is std^2
330
+
331
+
332
+ #This return the full reconstruction, but *also* the latents.
333
+ reconstructed_images, result_dict = self.vqvae_eps(images)
334
+ latents = result_dict["latents"]
335
+ std = result_dict["std"]
336
+ #We need to check the latnes std
337
+
338
+ #Get rng for creating noise.
339
+ new_rng, curr_key = jax.random.split(self.rng, 2)
340
+
341
+ decode = []
342
+ latent_std = latents.std(axis = [1,2,3]).reshape(-1,1,1,1)
343
+
344
+ for mult in noises:
345
+
346
+ noise = jax.random.normal(curr_key, latents.shape)
347
+ #Combine noise with latents
348
+
349
+
350
+ if True:
351
+ latent_var = latent_std ** 2
352
+ noise_std = mult*noise.std()#noise std should be around 1
353
+ noise_var = mult ** 2
354
+ if noise_var == 0:#If noise is zero, then instead denominator is it's variance
355
+ snr = 0
356
+ else:
357
+ snr = latent_var/noise_var
358
+
359
+ temp_latents = latents + noise*mult
360
+
361
+ #vae_eps is the determinstic one.
362
+ decoded = self.vqvae_eps(temp_latents, params=self.vqvae_eps.params, method="decode")
363
+ decoded = jnp.clip(decoded, 0, 1)
364
+ if True:
365
+ decode.append((decoded, snr))
366
+
367
+ reconstructed_images = jnp.clip(reconstructed_images, 0, 1)
368
+ return reconstructed_images, decode, std
369
+
370
+
371
+ @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
372
+ def reconstruction_ppl(self, images, pmap_axis='data'):
373
+
374
+ epsilon = .0001
375
+ reconstructed_images, result_dict = self.vqvae_eps(images)
376
+ latents = result_dict["latents"]
377
+ std = result_dict["std"]
378
+
379
+ new_rng, curr_key = jax.random.split(self.rng, 2)
380
+
381
+ noise = jax.random.normal(curr_key, latents.shape)
382
+ #Combine noise with latents
383
+
384
+ temp_latents = latents + noise * epsilon
385
+ # print(temp_latents.shape)#Probably should be like, bs, 32,32,4
386
+ # exit()
387
+ decoded = self.vqvae_eps(temp_latents, params=self.vqvae_eps.params, method="decode")
388
+ decoded = jnp.clip(decoded, 0, 1)
389
+
390
+ reconstructed_images = jnp.clip(reconstructed_images, 0, 1)
391
+ return reconstructed_images, decoded, std, latents
392
+
393
+
394
    #So this method simply will return the gradient/jacobian
    @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
    def reconstruction_grad_distance(self, images, pmap_axis='data'):
        """Unimplemented stub.

        Intended (per the comments below) to estimate a local sensitivity
        constant C of the decoder — how much the output moves per small latent
        change — and examine the spread of C. Currently does nothing.
        """
        #We want to try and identify C.
        #C means that when we change our latents by a specific and small number X, our outputs change by C*X also.
        #We want to capture all of the C, and see what their STD is.
        pass
401
+
402
+
403
+ @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
404
+ def reconstruction_ppl_two(self, images, pmap_axis='data'):
405
+
406
+ epsilon = .0001
407
+ reconstructed_images, result_dict = self.vqvae_eps(images)
408
+ latents = result_dict["latents"]
409
+ std = result_dict["std"]
410
+
411
+ new_rng, curr_key = jax.random.split(self.rng, 2)
412
+
413
+ noise = jax.random.normal(curr_key, latents.shape)
414
+ #Combine noise with latents
415
+
416
+ temp_latents = latents + noise/2 * epsilon
417
+
418
+ decoded = self.vqvae_eps(temp_latents, params=self.vqvae_eps.params, method="decode")
419
+ decoded = jnp.clip(decoded, 0, 1)
420
+
421
+ temp_latents_2 = latents + -1 * noise/2 * epsilon
422
+
423
+ decoded_2 = self.vqvae_eps(temp_latents_2, params=self.vqvae_eps.params, method="decode")
424
+ decoded_2 = jnp.clip(decoded_2, 0, 1)
425
+
426
+
427
+ reconstructed_images = jnp.clip(reconstructed_images, 0, 1)
428
+ return reconstructed_images, decoded, std, latents, decoded_2
429
+
430
+ @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
431
+ def reconstruction_ppl_image(self, images, pmap_axis='data'):
432
+
433
+ epsilon = .0001
434
+ new_rng, curr_key = jax.random.split(self.rng, 2)
435
+
436
+ reconstructed_images, result_dict = self.vqvae_eps(images)
437
+ latents = result_dict["latents"]
438
+ std = result_dict["std"]
439
+
440
+
441
+ noise = jax.random.normal(curr_key, images.shape)
442
+ images = images + noise * epsilon
443
+
444
+
445
+ decoded, result_dict_2 = self.vqvae_eps(images)
446
+ decoded = jnp.clip(decoded, 0, 1)
447
+
448
+ latents_noisy = result_dict_2["latents"]
449
+ std_noisy = result_dict_2["std"]
450
+
451
+ reconstructed_images = jnp.clip(reconstructed_images, 0, 1)
452
+ return reconstructed_images, decoded, std, latents, std_noisy, latents_noisy
453
+
454
##############################################
## Training Code.
##############################################
def main(_):
    """Train the VQGAN: data pipeline, model init, and train/eval/FID loop.

    Reads all configuration from absl FLAGS; logs to wandb from process 0 and
    checkpoints to FLAGS.save_dir (best-FID checkpoint to save_dir + "best.tmp").
    """
    np.random.seed(FLAGS.seed)
    print("Using devices", jax.local_devices())
    device_count = len(jax.local_devices())
    global_device_count = jax.device_count()
    # Per-host batch: global batch divided across participating hosts.
    local_batch_size = FLAGS.batch_size // (global_device_count // device_count)
    print("Device count", device_count)
    print("Global device count", global_device_count)
    print("Global Batch: ", FLAGS.batch_size)
    print("Node Batch: ", local_batch_size)
    print("Device Batch:", local_batch_size // device_count)

    # Create wandb logger (process 0 only).
    if jax.process_index() == 0:
        setup_wandb(FLAGS.model.to_dict(), **FLAGS.wandb)

    def get_dataset(is_train):
        # Infinite, shuffled, host-sharded iterator of float images in [0, 1].
        if 'imagenet' in FLAGS.dataset_name:
            def deserialization_fn(data):
                image = data['image']
                # Center-crop to a square before resizing.
                min_side = tf.minimum(tf.shape(image)[0], tf.shape(image)[1])
                image = tf.image.resize_with_crop_or_pad(image, min_side, min_side)
                if 'imagenet256' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (256, 256))
                elif 'imagenet128' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (128, 128))
                else:
                    raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
                if is_train:
                    image = tf.image.random_flip_left_right(image)
                image = tf.cast(image, tf.float32) / 255.0
                return image

            split = tfds.split_for_jax_process('train' if is_train else 'validation', drop_remainder=True)
            print(split)
            dataset = tfds.load('imagenet2012', split=split, data_dir="/dev/shm")
            dataset = dataset.map(deserialization_fn, num_parallel_calls=tf.data.AUTOTUNE)
            dataset = dataset.shuffle(10000, seed=42, reshuffle_each_iteration=True)
            dataset = dataset.repeat()
            dataset = dataset.batch(local_batch_size)
            dataset = dataset.prefetch(tf.data.AUTOTUNE)
            return iter(tfds.as_numpy(dataset))
        else:
            raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")

    dataset = get_dataset(is_train=True)
    dataset_valid = get_dataset(is_train=False)
    example_obs = next(dataset)[:1]

    get_fid_activations = get_fid_network()
    if not os.path.exists('./data/imagenet256_fidstats_openai.npz'):
        raise ValueError("Please download the FID stats file! See the README.")
    truth_fid_stats = np.load('data/imagenet256_fidstats_openai.npz')

    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, param_key = jax.random.split(rng)
    print("Total Memory on device:", float(jax.local_devices()[0].memory_stats()['bytes_limit']) / 1024**3, "GB")

    ###################################
    # Creating Model and put on devices.
    ###################################
    FLAGS.model.image_channels = example_obs.shape[-1]
    FLAGS.model.image_size = example_obs.shape[1]
    vqvae_def = VQVAE(FLAGS.model, train=True)
    vqvae_params = vqvae_def.init({'params': param_key, 'noise': param_key}, example_obs)['params']
    tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    vqvae_ts = TrainState.create(vqvae_def, vqvae_params, tx=tx)
    # EMA ("eps") copy: same initial params, eval-mode definition, no optimizer.
    vqvae_def_eps = VQVAE(FLAGS.model, train=False)
    vqvae_eps_ts = TrainState.create(vqvae_def_eps, vqvae_params)
    print("Total num of VQVAE parameters:", sum(x.size for x in jax.tree_util.tree_leaves(vqvae_params)))

    discriminator_def = Discriminator(FLAGS.model)
    discriminator_params = discriminator_def.init(param_key, example_obs)['params']
    tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    discriminator_ts = TrainState.create(discriminator_def, discriminator_params, tx=tx)
    print("Total num of Discriminator parameters:", sum(x.size for x in jax.tree_util.tree_leaves(discriminator_params)))

    model = VQGANModel(rng=rng, vqvae=vqvae_ts, vqvae_eps=vqvae_eps_ts, discriminator=discriminator_ts, config=FLAGS.model)

    if FLAGS.load_dir is not None:
        try:
            cp = Checkpoint(FLAGS.load_dir)
            model = cp.load_model(model)
            print("Loaded model with step", model.vqvae.step)
        # Was a bare `except:` — that also swallowed KeyboardInterrupt and
        # SystemExit. Best-effort load still falls back to random init.
        except Exception:
            print("Random init")
    else:
        print("Random init")

    model = flax.jax_utils.replicate(model, devices=jax.local_devices())
    jax.debug.visualize_array_sharding(model.vqvae.params['decoder']['Conv_0']['bias'])

    ###################################
    # Train Loop
    ###################################

    best_fid = 100000

    for i in tqdm.tqdm(range(1, FLAGS.max_steps + 1),
                       smoothing=0.1,
                       dynamic_ncols=True):

        batch_images = next(dataset)
        # [devices, batch//devices, H, W, C] for pmap.
        batch_images = batch_images.reshape((len(jax.local_devices()), -1, *batch_images.shape[1:]))

        model, update_info = model.update(batch_images)

        if i % FLAGS.log_interval == 0:
            update_info = jax.tree.map(lambda x: x.mean(), update_info)
            train_metrics = {f'training/{k}': v for k, v in update_info.items()}
            if jax.process_index() == 0:
                wandb.log(train_metrics, step=i)

        if i % FLAGS.eval_interval == 0:
            # Log a handful of train/validation reconstructions.
            reconstructed_images = model.reconstruction(batch_images)  # [devices, B, 256, 256, 3]
            valid_images = next(dataset_valid)
            valid_images = valid_images.reshape((len(jax.local_devices()), -1, *valid_images.shape[1:]))
            valid_reconstructed_images = model.reconstruction(valid_images)

            if jax.process_index() == 0:
                wandb.log({'batch_image_mean': batch_images.mean()}, step=i)
                wandb.log({'reconstructed_images_mean': reconstructed_images.mean()}, step=i)
                wandb.log({'batch_image_std': batch_images.std()}, step=i)
                wandb.log({'reconstructed_images_std': reconstructed_images.std()}, step=i)

                # Side-by-side original/reconstruction grids. The 2x8 grid was
                # sized for 8 devices; only the first 4 columns are filled here.
                fig, axs = plt.subplots(2, 8, figsize=(30, 15))
                for j in range(4):
                    axs[0, j].imshow(batch_images[j, 0], vmin=0, vmax=1)
                    axs[1, j].imshow(reconstructed_images[j, 0], vmin=0, vmax=1)
                wandb.log({'reconstruction': wandb.Image(fig)}, step=i)
                plt.close(fig)
                fig, axs = plt.subplots(2, 8, figsize=(30, 15))
                for j in range(4):
                    axs[0, j].imshow(valid_images[j, 0], vmin=0, vmax=1)
                    axs[1, j].imshow(valid_reconstructed_images[j, 0], vmin=0, vmax=1)
                wandb.log({'reconstruction_valid': wandb.Image(fig)}, step=i)
                plt.close(fig)

            # Validation losses: run update but discard the new model state.
            _, valid_update_info = model.update(valid_images)
            valid_update_info = jax.tree.map(lambda x: x.mean(), valid_update_info)
            valid_metrics = {f'validation/{k}': v for k, v in valid_update_info.items()}
            if jax.process_index() == 0:
                wandb.log(valid_metrics, step=i)

            # FID measurement over ~40k validation reconstructions.
            activations = []
            for _ in range(780):
                valid_images = next(dataset_valid)
                valid_images = valid_images.reshape((len(jax.local_devices()), -1, *valid_images.shape[1:]))
                valid_reconstructed_images = model.reconstruction(valid_images)

                # Inception expects 299x299 inputs scaled to [-1, 1].
                valid_reconstructed_images = jax.image.resize(
                    valid_reconstructed_images,
                    (valid_images.shape[0], valid_images.shape[1], 299, 299, 3),
                    method='bilinear', antialias=False)
                valid_reconstructed_images = 2 * valid_reconstructed_images - 1
                activations += [np.array(get_fid_activations(valid_reconstructed_images))[..., 0, 0, :]]

            # TODO: use all_gather to get activations from all devices.
            activations = np.concatenate(activations, axis=0)
            activations = activations.reshape((-1, activations.shape[-1]))

            print("doing this much FID", activations.shape)
            mu1 = np.mean(activations, axis=0)
            sigma1 = np.cov(activations, rowvar=False)
            fid = fid_from_stats(mu1, sigma1, truth_fid_stats['mu'], truth_fid_stats['sigma'])

            if jax.process_index() == 0:
                wandb.log({'validation/fid': fid}, step=i)
                print("validation FID at step", i, fid)
                # Keep a separate checkpoint of the best-FID model so far.
                if fid < best_fid:
                    model_single = flax.jax_utils.unreplicate(model)
                    cp = Checkpoint(FLAGS.save_dir + "best.tmp")
                    cp.set_model(model_single)
                    cp.save()
                    best_fid = fid

        if (i % FLAGS.save_interval == 0) and (FLAGS.save_dir is not None):
            if jax.process_index() == 0:
                model_single = flax.jax_utils.unreplicate(model)
                cp = Checkpoint(FLAGS.save_dir)
                cp.set_model(model_single)
                cp.save()

if __name__ == '__main__':
    app.run(main)