KublaiKhan1 commited on
Commit
561f2b7
·
verified ·
1 Parent(s): a3e9437

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. NormalizedGram/train.py +706 -0
NormalizedGram/train.py ADDED
@@ -0,0 +1,706 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ try: # For debugging
2
+ from localutils.debugger import enable_debug
3
+ enable_debug()
4
+ except ImportError:
5
+ pass
6
+
7
+ import flax.linen as nn
8
+ import jax.numpy as jnp
9
+ from absl import app, flags
10
+ from functools import partial
11
+ import numpy as np
12
+ import tqdm
13
+ import jax
14
+ import jax.numpy as jnp
15
+ import flax
16
+ import optax
17
+ import wandb
18
+ from ml_collections import config_flags
19
+ import ml_collections
20
+ import tensorflow_datasets as tfds
21
+ import tensorflow as tf
22
+ tf.config.set_visible_devices([], "GPU")
23
+ tf.config.set_visible_devices([], "TPU")
24
+ import matplotlib.pyplot as plt
25
+ from typing import Any
26
+ import os
27
+
28
+ from utils.wandb import setup_wandb, default_wandb_config
29
+ from utils.train_state import TrainState, target_update
30
+ from utils.checkpoint import Checkpoint
31
+ from utils.pretrained_resnet import get_pretrained_embs, get_pretrained_model
32
+ from utils.fid import get_fid_network, fid_from_stats
33
+ from models.vqvae import VQVAE
34
+ from models.discriminator import Discriminator
35
+
36
+ FLAGS = flags.FLAGS
37
+ flags.DEFINE_string('dataset_name', 'imagenet256', 'Environment name.')
38
+ flags.DEFINE_string('save_dir', "/home/lambda/jax-vqvae-vqgan/chkpts/checkpoint", 'Save dir (if not None, save params).')
39
+ flags.DEFINE_string('load_dir', "./checkpointbest.tmp.tmp" , 'Load dir (if not None, load params from here).')
40
+ flags.DEFINE_integer('seed', 0, 'Random seed.')
41
+ flags.DEFINE_integer('log_interval', 1000, 'Logging interval.')
42
+ flags.DEFINE_integer('eval_interval', 1000, 'Eval interval.')
43
+ flags.DEFINE_integer('save_interval', 1000, 'Save interval.')
44
+ flags.DEFINE_integer('batch_size', 64, 'Total Batch size.')
45
+ flags.DEFINE_integer('max_steps', int(1_000_000), 'Number of training steps.')
46
+
47
+ model_config = ml_collections.ConfigDict({
48
+ # VQVAE
49
+ 'lr': 0.0001,
50
+ 'beta1': 0.0,#.5
51
+ 'beta2': 0.99,#.9
52
+ 'lr_warmup_steps': 2000,
53
+ 'lr_decay_steps': 500_000,#They use 'lambdalr'
54
+ 'filters': 128,
55
+ 'num_res_blocks': 2,
56
+ 'channel_multipliers': (1, 2, 4, 4),#Seems right
57
+ 'embedding_dim': 4, # For FSQ, a good default is 4.
58
+ 'norm_type': 'GN',
59
+ 'weight_decay': 0.05,#None maybe?
60
+ 'clip_gradient': 1.0,
61
+ 'l2_loss_weight': 1.0,#They use L1 actually
62
+ 'eps_update_rate': 0.9999,
63
+ # Quantizer
64
+ 'quantizer_type': 'ae', # or 'fsq', 'kl'
65
+ # Quantizer (VQ)
66
+ 'quantizer_loss_ratio': 1,
67
+ 'codebook_size': 1024,
68
+ 'entropy_loss_ratio': 0.1,
69
+ 'entropy_loss_type': 'softmax',
70
+ 'entropy_temperature': 0.01,
71
+ 'commitment_cost': 0.25,
72
+ # Quantizer (FSQ)
73
+ 'fsq_levels': 5, # Bins per dimension.
74
+ # Quantizer (KL)
75
+ 'kl_weight': 0.00007,#They use 1e-6 on their stuff LUL. .001 is the default
76
+ # GAN
77
+ 'g_adversarial_loss_weight': 0.5,
78
+ 'g_grad_penalty_cost': 10,
79
+ 'perceptual_loss_weight': 0.5,
80
+ 'gan_warmup_steps': 25000,
81
+ "pl_decay": 0.01,
82
+ "pl_weight": -1,
83
+ 'MMD_weight': 1.0
84
+
85
+ })
86
+
87
+ wandb_config = default_wandb_config()
88
+ wandb_config.update({
89
+ 'project': 'vqvae',
90
+ 'name': 'vqvae_{dataset_name}',
91
+ })
92
+
93
+ config_flags.DEFINE_config_dict('wandb', wandb_config, lock_config=False)
94
+ config_flags.DEFINE_config_dict('model', model_config, lock_config=False)
95
+
96
+ ##############################################
97
+ ## Model Definitions.
98
+ ##############################################
99
+
100
+ @jax.vmap
101
+ def sigmoid_cross_entropy_with_logits(*, labels: jnp.ndarray, logits: jnp.ndarray) -> jnp.ndarray:
102
+ """https://github.com/google-research/maskgit/blob/main/maskgit/libml/losses.py
103
+ """
104
+ zeros = jnp.zeros_like(logits, dtype=logits.dtype)
105
+ condition = (logits >= zeros)
106
+ relu_logits = jnp.where(condition, logits, zeros)
107
+ neg_abs_logits = jnp.where(condition, -logits, logits)
108
+ return relu_logits - logits * labels + jnp.log1p(jnp.exp(neg_abs_logits))
109
+
110
+ class VQGANModel(flax.struct.PyTreeNode):
111
+ rng: Any
112
+ config: dict = flax.struct.field(pytree_node=False)
113
+ vqvae: TrainState
114
+ vqvae_eps: TrainState
115
+ discriminator: TrainState
116
+
117
+ # Train G and D.
118
+ @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
119
+ def update(self, images, pmap_axis='data'):
120
+ new_rng, curr_key = jax.random.split(self.rng, 2)
121
+
122
+ resnet, resnet_params = get_pretrained_model('resnet50', 'data/resnet_pretrained.npy')
123
+
124
+ is_gan_training = 1.0 - (self.vqvae.step < self.config['gan_warmup_steps']).astype(jnp.float32)
125
+
126
+ def loss_fn(params_vqvae, params_disc):
127
+
128
+ def path_reg_loss(latents, targets):#let's have pl_mean be in our self.config
129
+ #1/2 should be our spatial dimensions.
130
+
131
+ latents = latents[0:2, :, :, :]
132
+ targets = targets[0:2, :, :, :]
133
+ pl_noise = jax.random.normal(new_rng, shape = targets.shape) / jnp.sqrt(targets.shape[1] * targets.shape[2])
134
+ def grad_sum(latents, pl_noise):#So we don't have access to the actual decode method
135
+ #return jnp.sum(self.vqvae.decode(latents))
136
+
137
+ #I am not sure if this makes any sense whatsoever tbh
138
+ my_sum = self.vqvae(latents, params=params_vqvae, method="decode", rngs={'noise': curr_key})*pl_noise
139
+ print("Decode shape", my_sum.shape)
140
+ return jnp.sum(my_sum)
141
+
142
+ decode_grad_fn = jax.grad(grad_sum)
143
+ pl_grads = decode_grad_fn(latents, pl_noise)
144
+ pl_lengths = jnp.sqrt(jnp.mean(jnp.sum(jnp.square(pl_grads), axis = [2,3]), axis = 1))
145
+ #pl_lengths = jnp.sqrt(jnp.mean(jnp.sum(jnp.square(pl_grads), axis=2), axis=3))
146
+
147
+ pl_mean = self.vqvae.pl_mean + self.config.pl_decay * (jnp.mean(pl_lengths) - self.vqvae.pl_mean)
148
+ pl_penalty = jnp.square(pl_lengths - pl_mean)
149
+ loss = jnp.mean(pl_penalty)
150
+ return loss, pl_mean
151
+
152
+ if self.config.pl_weight != -1:
153
+ smooth_loss, pl_mean = path_reg_loss(result_dict["latents"], reconstructed_images)
154
+ # self.vqvae.replace(pl_mean = pl_mean)
155
+ #We need to update pl mean in self.vqvae
156
+
157
+ # Reconstruct image
158
+ reconstructed_images, result_dict = self.vqvae(images, params=params_vqvae, rngs={'noise': curr_key})
159
+ print("Reconstructed images shape", reconstructed_images.shape)
160
+ print("Input images shape", images.shape)
161
+ assert reconstructed_images.shape == images.shape
162
+
163
+ #Dino loss/gram loss/vicreg go here
164
+ #Let's start with a gram based thing.
165
+ #First, we need the gram matrix of every patch against every other patch.
166
+ #What's our plan
167
+ #What's our shape
168
+ #print(result_dict["latents"].shape)#16x32x32x4
169
+
170
+ #Gram is not normalized, so let's try that first.
171
+ reshaped_latents = result_dict["latents"].reshape(result_dict["latents"].shape[0],-1,result_dict["latents"].shape[-1])
172
+ #Reshape to batch x patches x embeddings
173
+ #Calculate gram matrix
174
+ reshaped_latents = reshaped_latents/jnp.linalg.norm(reshaped_latents, axis = -1, keepdims=True)
175
+ x_transposed = jnp.transpose(reshaped_latents, (0, 2, 1))
176
+
177
+ #Let's try.... normalized gram matrix...
178
+ gram_matrix = jnp.matmul(reshaped_latents, x_transposed)
179
+ diagonal_elements = jnp.einsum('bii->bi', gram_matrix)
180
+ sum_of_diagonals = jnp.sum(diagonal_elements)
181
+ total_sum = jnp.sum(gram_matrix)
182
+ gram_loss = total_sum - sum_of_diagonals
183
+ gram_loss = gram_loss / 992 #divide by 32x32 - 32
184
+
185
+ gram_loss = gram_loss / 1000
186
+
187
+
188
+ # GAN loss on VQVAE output.
189
+ discriminator_fn = lambda x: self.discriminator(x, params=params_disc)
190
+ real_logit, vjp_fn = jax.vjp(discriminator_fn, images, has_aux=False)
191
+ gradient = vjp_fn(jnp.ones_like(real_logit))[0] # Gradient of discriminator output wrt. real images.
192
+ gradient = gradient.reshape((images.shape[0], -1))
193
+ gradient = jnp.asarray(gradient, jnp.float32)
194
+ penalty = jnp.sum(jnp.square(gradient), axis=-1)
195
+ penalty = jnp.mean(penalty) # Gradient penalty for training D.
196
+ fake_logit = discriminator_fn(reconstructed_images)
197
+ d_loss_real = sigmoid_cross_entropy_with_logits(labels=jnp.ones_like(real_logit), logits=real_logit).mean()
198
+ d_loss_fake = sigmoid_cross_entropy_with_logits(labels=jnp.zeros_like(fake_logit), logits=fake_logit).mean()
199
+ loss_d = d_loss_real + d_loss_fake + (penalty * self.config['g_grad_penalty_cost'])
200
+
201
+ d_loss_for_vae = sigmoid_cross_entropy_with_logits(labels=jnp.ones_like(fake_logit), logits=fake_logit).mean()
202
+ d_loss_for_vae = d_loss_for_vae * is_gan_training
203
+
204
+ real_pools, _ = get_pretrained_embs(resnet_params, resnet, images=images)
205
+ fake_pools, _ = get_pretrained_embs(resnet_params, resnet, images=reconstructed_images)
206
+ perceptual_loss = jnp.mean((real_pools - fake_pools)**2)
207
+
208
+ l2_loss = jnp.mean((reconstructed_images - images) ** 2)
209
+ quantizer_loss = result_dict['quantizer_loss'] if 'quantizer_loss' in result_dict else 0.0
210
+ if self.config['quantizer_type'] == 'kl' or self.config["quantizer_type"] == "kl_two":
211
+ quantizer_loss = quantizer_loss * self.config['kl_weight']
212
+ elif self.config["quantizer_type"] == "MMD":
213
+ quantizer_loss = quantizer_loss * self.config['MMD_weight']
214
+ loss_vae = (l2_loss * FLAGS.model['l2_loss_weight']) \
215
+ + (quantizer_loss * FLAGS.model['quantizer_loss_ratio']) \
216
+ + (d_loss_for_vae * FLAGS.model['g_adversarial_loss_weight']) \
217
+ + (perceptual_loss * FLAGS.model['perceptual_loss_weight']) \
218
+ + gram_loss * .05 \
219
+ #+ (smooth_loss * FLAGS.model['pl_weight'] )
220
+ codebook_usage = result_dict['usage'] if 'usage' in result_dict else 0.0
221
+
222
+ return_dict = {
223
+ 'loss_vae': loss_vae,
224
+ 'loss_d': loss_d,
225
+ 'l2_loss': l2_loss,
226
+ 'd_loss_for_vae': d_loss_for_vae,
227
+ 'perceptual_loss': perceptual_loss,
228
+ 'quantizer_loss': quantizer_loss,
229
+ 'codebook_usage': codebook_usage,
230
+ 'gram_loss' : gram_loss
231
+ }
232
+
233
+ if self.config["pl_weight"] != -1:
234
+ loss_vae += (smooth_loss * FLAGS.model["pl_weight"])
235
+ return_dict["pl_mean"] = pl_mean
236
+ return_dict["smooth_loss"] = smooth_loss
237
+
238
+
239
+ return (loss_vae, loss_d), return_dict
240
+
241
+
242
+ # This is a fancy way to do 'jax.grad' so (loss_vae, params_vqvae) and (loss_d, params_disc) are differentiated.
243
+ _, grad_fn, info = jax.vjp(loss_fn, self.vqvae.params, self.discriminator.params, has_aux=True)
244
+ vae_grads, _ = grad_fn((1., 0.))
245
+ _, d_grads = grad_fn((0., 1.))
246
+
247
+ vae_grads = jax.lax.pmean(vae_grads, axis_name=pmap_axis)
248
+ d_grads = jax.lax.pmean(d_grads, axis_name=pmap_axis)
249
+ d_grads = jax.tree.map(lambda x: x * is_gan_training, d_grads)
250
+
251
+ info = jax.lax.pmean(info, axis_name=pmap_axis)
252
+ if self.config['quantizer_type'] == 'fsq':
253
+ info['codebook_usage'] = jnp.sum(info['codebook_usage'] > 0) / info['codebook_usage'].shape[-1]
254
+
255
+ updates, new_opt_state = self.vqvae.tx.update(vae_grads, self.vqvae.opt_state, self.vqvae.params)
256
+ new_params = optax.apply_updates(self.vqvae.params, updates)
257
+
258
+ if self.config["pl_weight"] != -1:
259
+ new_vqvae = self.vqvae.replace(step=self.vqvae.step + 1, params=new_params, opt_state=new_opt_state, pl_mean=info["pl_mean"])
260
+ else:
261
+ new_vqvae = self.vqvae.replace(step=self.vqvae.step + 1, params=new_params, opt_state=new_opt_state)
262
+
263
+ updates, new_opt_state = self.discriminator.tx.update(d_grads, self.discriminator.opt_state, self.discriminator.params)
264
+ new_params = optax.apply_updates(self.discriminator.params, updates)
265
+ new_discriminator = self.discriminator.replace(step=self.discriminator.step + 1, params=new_params, opt_state=new_opt_state)
266
+
267
+ info['grad_norm_vae'] = optax.global_norm(vae_grads)
268
+ info['grad_norm_d'] = optax.global_norm(d_grads)
269
+ info['update_norm'] = optax.global_norm(updates)
270
+ info['param_norm'] = optax.global_norm(new_params)
271
+ info['is_gan_training'] = is_gan_training
272
+
273
+ new_vqvae_eps = target_update(new_vqvae, self.vqvae_eps, 1-self.config['eps_update_rate'])
274
+
275
+ new_model = self.replace(rng=new_rng, vqvae=new_vqvae, vqvae_eps=new_vqvae_eps, discriminator=new_discriminator)
276
+ return new_model, info
277
+
278
+ @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
279
+ def reconstruction(self, images, pmap_axis='data', sampling = True):
280
+ if not sampling:
281
+ reconstructed_images, _ = self.vqvae_eps(images)
282
+ else:#Not sure what our theoretical sampling mode does
283
+ new_rng, curr_key = jax.random.split(self.rng, 2)
284
+ reconstructed_images, _ = self.vqvae_eps(images, rngs={'noise': curr_key})
285
+
286
+ reconstructed_images = jnp.clip(reconstructed_images, 0, 1)
287
+ return reconstructed_images
288
+
289
+ @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
290
+ def reconstruction_sampling(self, images, pmap_axis='data'):
291
+
292
+ reconstructed_images_determistic, _ = self.vqvae_eps(images)
293
+
294
+
295
+ new_rng, curr_key = jax.random.split(self.rng, 2)
296
+ reconstructed_images_sample, result_dict = self.vqvae(images, rngs={'noise': curr_key})
297
+
298
+ #We don't need to return the result dict.
299
+ reconstructed_images_determistic = jnp.clip(reconstructed_images_determistic, 0, 1)
300
+ reconstructed_images_sample = jnp.clip(reconstructed_images_sample, 0, 1)
301
+
302
+ return reconstructed_images_determistic, reconstructed_images_sample
303
+
304
+ @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
305
+ def reconstruction_interpolation(self, images, pmap_axis='data'):
306
+
307
+ #So we *have* our two images. We are going to linearly interpolate between them in... latent space
308
+ #But also in image space?
309
+ #Sure, why not
310
+ reconstructed_images_determistic, _ = self.vqvae_eps(images)
311
+
312
+
313
+ new_rng, curr_key = jax.random.split(self.rng, 2)
314
+ reconstructed_images_sample, result_dict = self.vqvae(images, rngs={'noise': curr_key})
315
+
316
+ #We don't need to return the result dict.
317
+ reconstructed_images_determistic = jnp.clip(reconstructed_images_determistic, 0, 1)
318
+ reconstructed_images_sample = jnp.clip(reconstructed_images_sample, 0, 1)
319
+
320
+ return reconstructed_images_determistic, reconstructed_images_sample
321
+
322
+ @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
323
+ def get_latent(self, images, pmap_axis='data'):
324
+
325
+ #We do *not* add the noise ourselves, just save it.
326
+ latents, result_dict = self.vqvae_eps(images, params=self.vqvae_eps.params, method="encode")
327
+
328
+ # reconstructed_images, result_dict_two = self.vqvae_eps(images)
329
+ # reconstructed_images = jnp.clip(reconstructed_images, 0, 1)
330
+ #
331
+ #
332
+ # decoded = self.vqvae_eps(latents, params=self.vqvae_eps.params, method="decode")
333
+ # decoded = jnp.clip(decoded, 0, 1)
334
+
335
+ #reconstructed images should be correct
336
+ return latents, result_dict#, result_dict_two, reconstructed_images, decoded
337
+
338
+
339
+ @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
340
+ def reconstruction_noisy(self, images, pmap_axis='data'):
341
+
342
+
343
+ noises = []
344
+ numbers = np.arange(0.00, 1.0, 0.01)
345
+
346
+ for number in numbers:
347
+ noises.append(float(number))
348
+
349
+
350
+ #So 3 things to try out.
351
+ #One is normalize variance of the latents before adding noise, start there
352
+ #The second is plot snr instead.
353
+ #snr = var(latent)/var(noise)
354
+ #var is std^2
355
+
356
+
357
+ #This return the full reconstruction, but *also* the latents.
358
+ reconstructed_images, result_dict = self.vqvae_eps(images)
359
+ latents = result_dict["latents"]
360
+ std = result_dict["std"]
361
+ #We need to check the latnes std
362
+
363
+ #Get rng for creating noise.
364
+ new_rng, curr_key = jax.random.split(self.rng, 2)
365
+
366
+ decode = []
367
+ latent_std = latents.std(axis = [1,2,3]).reshape(-1,1,1,1)
368
+
369
+ for mult in noises:
370
+
371
+ noise = jax.random.normal(curr_key, latents.shape)
372
+ #Combine noise with latents
373
+
374
+
375
+ if True:
376
+ latent_var = latent_std ** 2
377
+ noise_std = mult*noise.std()#noise std should be around 1
378
+ noise_var = mult ** 2
379
+ if noise_var == 0:#If noise is zero, then instead denominator is it's variance
380
+ snr = 0
381
+ else:
382
+ snr = latent_var/noise_var
383
+
384
+ temp_latents = latents + noise*mult
385
+
386
+ #vae_eps is the determinstic one.
387
+ decoded = self.vqvae_eps(temp_latents, params=self.vqvae_eps.params, method="decode")
388
+ decoded = jnp.clip(decoded, 0, 1)
389
+ if True:
390
+ decode.append((decoded, snr))
391
+
392
+ reconstructed_images = jnp.clip(reconstructed_images, 0, 1)
393
+ return reconstructed_images, decode, std
394
+
395
+
396
+ @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
397
+ def reconstruction_ppl(self, images, pmap_axis='data'):
398
+
399
+ epsilon = .0001
400
+ reconstructed_images, result_dict = self.vqvae_eps(images)
401
+ latents = result_dict["latents"]
402
+ std = result_dict["std"]
403
+
404
+ new_rng, curr_key = jax.random.split(self.rng, 2)
405
+
406
+ noise = jax.random.normal(curr_key, latents.shape)
407
+ #Combine noise with latents
408
+
409
+ temp_latents = latents + noise * epsilon
410
+ # print(temp_latents.shape)#Probably should be like, bs, 32,32,4
411
+ # exit()
412
+ decoded = self.vqvae_eps(temp_latents, params=self.vqvae_eps.params, method="decode")
413
+ decoded = jnp.clip(decoded, 0, 1)
414
+
415
+ reconstructed_images = jnp.clip(reconstructed_images, 0, 1)
416
+ return reconstructed_images, decoded, std, latents
417
+
418
+
419
+ #So this method simply will return the gradient/jacobian
420
+ @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
421
+ def reconstruction_grad_distance(self, images, pmap_axis='data'):
422
+ #We want to try and identify C.
423
+ #C means that when we change our latents by a specific and small number X, our outputs change by C*X also.
424
+ #We want to capture all of the C, and see what their STD is.
425
+ pass
426
+
427
+
428
+ @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
429
+ def reconstruction_ppl_two(self, images, pmap_axis='data'):
430
+
431
+ epsilon = .0001
432
+ reconstructed_images, result_dict = self.vqvae_eps(images)
433
+ latents = result_dict["latents"]
434
+ std = result_dict["std"]
435
+
436
+ new_rng, curr_key = jax.random.split(self.rng, 2)
437
+
438
+ noise = jax.random.normal(curr_key, latents.shape)
439
+ #Combine noise with latents
440
+
441
+ temp_latents = latents + noise/2 * epsilon
442
+
443
+ decoded = self.vqvae_eps(temp_latents, params=self.vqvae_eps.params, method="decode")
444
+ decoded = jnp.clip(decoded, 0, 1)
445
+
446
+ temp_latents_2 = latents + -1 * noise/2 * epsilon
447
+
448
+ decoded_2 = self.vqvae_eps(temp_latents_2, params=self.vqvae_eps.params, method="decode")
449
+ decoded_2 = jnp.clip(decoded_2, 0, 1)
450
+
451
+
452
+ reconstructed_images = jnp.clip(reconstructed_images, 0, 1)
453
+ return reconstructed_images, decoded, std, latents, decoded_2
454
+
455
+ @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
456
+ def reconstruction_ppl_image(self, images, pmap_axis='data'):
457
+
458
+ epsilon = .0001
459
+ new_rng, curr_key = jax.random.split(self.rng, 2)
460
+
461
+ reconstructed_images, result_dict = self.vqvae_eps(images)
462
+ latents = result_dict["latents"]
463
+ std = result_dict["std"]
464
+
465
+
466
+ noise = jax.random.normal(curr_key, images.shape)
467
+ images = images + noise * epsilon
468
+
469
+
470
+ decoded, result_dict_2 = self.vqvae_eps(images)
471
+ decoded = jnp.clip(decoded, 0, 1)
472
+
473
+ latents_noisy = result_dict_2["latents"]
474
+ std_noisy = result_dict_2["std"]
475
+
476
+ reconstructed_images = jnp.clip(reconstructed_images, 0, 1)
477
+ return reconstructed_images, decoded, std, latents, std_noisy, latents_noisy
478
+
479
+ ##############################################
480
+ ## Training Code.
481
+ ##############################################
482
+ def main(_):
483
+ np.random.seed(FLAGS.seed)
484
+ print("Using devices", jax.local_devices())
485
+ device_count = len(jax.local_devices())
486
+ global_device_count = jax.device_count()
487
+ local_batch_size = FLAGS.batch_size // (global_device_count // device_count)
488
+ print("Device count", device_count)
489
+ print("Global device count", global_device_count)
490
+ print("Global Batch: ", FLAGS.batch_size)
491
+ print("Node Batch: ", local_batch_size)
492
+ print("Device Batch:", local_batch_size // device_count)
493
+
494
+ # Create wandb logger
495
+ if jax.process_index() == 0:
496
+ setup_wandb(FLAGS.model.to_dict(), **FLAGS.wandb)
497
+
498
+ def get_dataset(is_train):
499
+ if 'imagenet' in FLAGS.dataset_name:
500
+ def deserialization_fn(data):
501
+ image = data['image']
502
+ min_side = tf.minimum(tf.shape(image)[0], tf.shape(image)[1])
503
+ image = tf.image.resize_with_crop_or_pad(image, min_side, min_side)
504
+ if 'imagenet256' in FLAGS.dataset_name:
505
+ image = tf.image.resize(image, (256, 256))
506
+ elif 'imagenet128' in FLAGS.dataset_name:
507
+ image = tf.image.resize(image, (128, 128))
508
+ else:
509
+ raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
510
+ if is_train:
511
+ image = tf.image.random_flip_left_right(image)
512
+ image = tf.cast(image, tf.float32) / 255.0
513
+ return image
514
+
515
+
516
+ split = tfds.split_for_jax_process('train' if is_train else 'validation', drop_remainder=True)
517
+ print(split)
518
+ dataset = tfds.load('imagenet2012', split=split, data_dir = "/dev/shm")
519
+ dataset = dataset.map(deserialization_fn, num_parallel_calls=tf.data.AUTOTUNE)
520
+ dataset = dataset.shuffle(10000, seed=42, reshuffle_each_iteration=True)
521
+ dataset = dataset.repeat()
522
+ dataset = dataset.batch(local_batch_size)
523
+ dataset = dataset.prefetch(tf.data.AUTOTUNE)
524
+ dataset = tfds.as_numpy(dataset)
525
+ dataset = iter(dataset)
526
+ return dataset
527
+ else:
528
+ raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
529
+
530
+ dataset = get_dataset(is_train=True)
531
+ dataset_valid = get_dataset(is_train=False)
532
+ example_obs = next(dataset)[:1]
533
+
534
+ get_fid_activations = get_fid_network()
535
+ if not os.path.exists('./data/imagenet256_fidstats_openai.npz'):
536
+ raise ValueError("Please download the FID stats file! See the README.")
537
+ truth_fid_stats = np.load('data/imagenet256_fidstats_openai.npz')
538
+ #truth_fid_stats = np.load("./base_stats.npz")
539
+
540
+ rng = jax.random.PRNGKey(FLAGS.seed)
541
+ rng, param_key = jax.random.split(rng)
542
+ print("Total Memory on device:", float(jax.local_devices()[0].memory_stats()['bytes_limit']) / 1024**3, "GB")
543
+
544
+ ###################################
545
+ # Creating Model and put on devices.
546
+ ###################################
547
+ FLAGS.model.image_channels = example_obs.shape[-1]
548
+ FLAGS.model.image_size = example_obs.shape[1]
549
+ vqvae_def = VQVAE(FLAGS.model, train=True)
550
+ vqvae_params = vqvae_def.init({'params': param_key, 'noise': param_key}, example_obs)['params']
551
+ tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
552
+ vqvae_ts = TrainState.create(vqvae_def, vqvae_params, tx=tx)
553
+ vqvae_def_eps = VQVAE(FLAGS.model, train=False)
554
+ vqvae_eps_ts = TrainState.create(vqvae_def_eps, vqvae_params)
555
+ print("Total num of VQVAE parameters:", sum(x.size for x in jax.tree_util.tree_leaves(vqvae_params)))
556
+
557
+ discriminator_def = Discriminator(FLAGS.model)
558
+ discriminator_params = discriminator_def.init(param_key, example_obs)['params']
559
+ tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
560
+ discriminator_ts = TrainState.create(discriminator_def, discriminator_params, tx=tx)
561
+ print("Total num of Discriminator parameters:", sum(x.size for x in jax.tree_util.tree_leaves(discriminator_params)))
562
+
563
+ model = VQGANModel(rng=rng, vqvae=vqvae_ts, vqvae_eps=vqvae_eps_ts, discriminator=discriminator_ts, config=FLAGS.model)
564
+
565
+ if FLAGS.load_dir is not None:
566
+ try:
567
+ cp = Checkpoint(FLAGS.load_dir)
568
+ model = cp.load_model(model)
569
+ print("Loaded model with step", model.vqvae.step)
570
+ except:
571
+ print("Random init")
572
+ else:
573
+ print("Random init")
574
+
575
+ model = flax.jax_utils.replicate(model, devices=jax.local_devices())
576
+ jax.debug.visualize_array_sharding(model.vqvae.params['decoder']['Conv_0']['bias'])
577
+
578
+ ###################################
579
+ # Train Loop
580
+ ###################################
581
+
582
+ best_fid = 100000
583
+
584
+ for i in tqdm.tqdm(range(1, FLAGS.max_steps + 1),
585
+ smoothing=0.1,
586
+ dynamic_ncols=True):
587
+
588
+ batch_images = next(dataset)
589
+ batch_images = batch_images.reshape((len(jax.local_devices()), -1, *batch_images.shape[1:])) # [devices, batch//devices, etc..]
590
+
591
+ model, update_info = model.update(batch_images)
592
+
593
+ print(update_info["gram_loss"])
594
+ print(update_info["loss_vae"])
595
+ print(update_info["l2_loss"])
596
+
597
+
598
+ if i % FLAGS.log_interval == 0:
599
+ update_info = jax.tree.map(lambda x: x.mean(), update_info)
600
+ train_metrics = {f'training/{k}': v for k, v in update_info.items()}
601
+ if jax.process_index() == 0:
602
+ wandb.log(train_metrics, step=i)
603
+
604
+ if i % FLAGS.eval_interval == 0:
605
+ # Print some images
606
+ reconstructed_images = model.reconstruction(batch_images) # [devices, 8, 256, 256, 3]
607
+ valid_images = next(dataset_valid)
608
+ valid_images = valid_images.reshape((len(jax.local_devices()), -1, *valid_images.shape[1:])) # [devices, batch//devices, etc..]
609
+ valid_reconstructed_images = model.reconstruction(valid_images) # [devices, 8, 256, 256, 3]
610
+
611
+ if jax.process_index() == 0:
612
+ wandb.log({'batch_image_mean': batch_images.mean()}, step=i)
613
+ wandb.log({'reconstructed_images_mean': reconstructed_images.mean()}, step=i)
614
+ wandb.log({'batch_image_std': batch_images.std()}, step=i)
615
+ wandb.log({'reconstructed_images_std': reconstructed_images.std()}, step=i)
616
+
617
+ # plot comparison witah matplotlib. put each reconstruction side by side.
618
+ fig, axs = plt.subplots(2, 8, figsize=(30, 15))
619
+ #print("batch shape", batch_images.shape)#batch shape (4, 32, 256, 256, 3) #THE FIRST SHAPE IS DEVICES
620
+ #print("recon shape", reconstructed_images.shape)#it's all the same lol
621
+ #print("valid shape", valid_images.shape)
622
+ #it seems to be made for 8 device, aka tpuv3 instead
623
+ for j in range(4):#fuck it
624
+ axs[0, j].imshow(batch_images[j, 0], vmin=0, vmax=1)
625
+ axs[1, j].imshow(reconstructed_images[j, 0], vmin=0, vmax=1)
626
+ wandb.log({'reconstruction': wandb.Image(fig)}, step=i)
627
+ plt.close(fig)
628
+ fig, axs = plt.subplots(2, 8, figsize=(30, 15))
629
+ for j in range(4):
630
+ axs[0, j].imshow(valid_images[j, 0], vmin=0, vmax=1)
631
+ axs[1, j].imshow(valid_reconstructed_images[j, 0], vmin=0, vmax=1)
632
+ wandb.log({'reconstruction_valid': wandb.Image(fig)}, step=i)
633
+ plt.close(fig)
634
+
635
+ # Validation Losses
636
+ _, valid_update_info = model.update(valid_images)
637
+ valid_update_info = jax.tree.map(lambda x: x.mean(), valid_update_info)
638
+ valid_metrics = {f'validation/{k}': v for k, v in valid_update_info.items()}
639
+ if jax.process_index() == 0:
640
+ wandb.log(valid_metrics, step=i)
641
+
642
+ # FID measurement.
643
+ activations = []
644
+ activations2 = []
645
+ for _ in range(780):#This is apprximately 40k
646
+ valid_images = next(dataset_valid)
647
+ valid_images = valid_images.reshape((len(jax.local_devices()), -1, *valid_images.shape[1:])) # [devices, batch//devices, etc..]
648
+ valid_reconstructed_images = model.reconstruction(valid_images) # [devices, 8, 256, 256, 3]
649
+
650
+ valid_reconstructed_images = jax.image.resize(valid_reconstructed_images, (valid_images.shape[0], valid_images.shape[1], 299, 299, 3),
651
+ method='bilinear', antialias=False)
652
+ valid_reconstructed_images = 2 * valid_reconstructed_images - 1
653
+ activations += [np.array(get_fid_activations(valid_reconstructed_images))[..., 0, 0, :]]
654
+
655
+
656
+ #Only needed when we save
657
+ #valid_reconstructed_images = jax.image.resize(valid_images, (valid_images.shape[0], valid_images.shape[1], 299, 299, 3),
658
+ #method='bilinear', antialias=False)
659
+ #valid_reconstructed_images = 2 * valid_reconstructed_images - 1
660
+ #activations2 += [np.array(get_fid_activations(valid_reconstructed_images))[..., 0, 0, :]]
661
+
662
+
663
+ # TODO: use all_gather to get activations from all devices.
664
+ #This seems to be FID with only 64 images?
665
+ activations = np.concatenate(activations, axis=0)
666
+ activations = activations.reshape((-1, activations.shape[-1]))
667
+
668
+ # activations2 = np.concatenate(activations2, axis = 0)
669
+ # activations2 = activations2.reshape((-1, activations2.shape[-1]))
670
+
671
+ print("doing this much FID", activations.shape)#8192, 2048 should be 2048 items then I guess
672
+ mu1 = np.mean(activations, axis=0)
673
+ sigma1 = np.cov(activations, rowvar=False)
674
+ fid = fid_from_stats(mu1, sigma1, truth_fid_stats['mu'], truth_fid_stats['sigma'])
675
+
676
+ # mu2 = np.mean(activations2, axis = 0)
677
+ # sigma2 = np.cov(activations2, rowvar = False)
678
+
679
+ #save mu2 and sigma2
680
+ #And then exit for now
681
+ # np.savez("base.npz", mu = mu2, sigma = sigma2)
682
+ # exit()
683
+
684
+ #Used with loading base
685
+ #fid = fid_from_stats(mu1, sigma1, mu2, sigma2)
686
+
687
+ if jax.process_index() == 0:
688
+ wandb.log({'validation/fid': fid}, step=i)
689
+ print("validation FID at step", i, fid)
690
+ #Then if fid is smaller than previous best FID, save new FID
691
+ if fid < best_fid:
692
+ model_single = flax.jax_utils.unreplicate(model)
693
+ cp = Checkpoint(FLAGS.save_dir + "best.tmp")
694
+ cp.set_model(model_single)
695
+ cp.save()
696
+ best_fid = fid
697
+
698
+ if (i % FLAGS.save_interval == 0) and (FLAGS.save_dir is not None):
699
+ if jax.process_index() == 0:
700
+ model_single = flax.jax_utils.unreplicate(model)
701
+ cp = Checkpoint(FLAGS.save_dir)
702
+ cp.set_model(model_single)
703
+ cp.save()
704
+
705
+ if __name__ == '__main__':
706
+ app.run(main)