KublaiKhan1 commited on
Commit
0679648
·
verified ·
1 Parent(s): 1ac0dd1

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -312,3 +312,5 @@ NormalizedGram10K/checkpoint.tmp filter=lfs diff=lfs merge=lfs -text
312
  NormalizedGram10K/checkpointbest.tmp.tmp filter=lfs diff=lfs merge=lfs -text
313
  1e-17/checkpoint.tmp filter=lfs diff=lfs merge=lfs -text
314
  1e-17/checkpointbest.tmp.tmp filter=lfs diff=lfs merge=lfs -text
 
 
 
312
  NormalizedGram10K/checkpointbest.tmp.tmp filter=lfs diff=lfs merge=lfs -text
313
  1e-17/checkpoint.tmp filter=lfs diff=lfs merge=lfs -text
314
  1e-17/checkpointbest.tmp.tmp filter=lfs diff=lfs merge=lfs -text
315
+ gram_smol_latent/checkpoint.tmp filter=lfs diff=lfs merge=lfs -text
316
+ gram_smol_latent/checkpointbest.tmp.tmp filter=lfs diff=lfs merge=lfs -text
gram_smol_latent/checkpoint.tmp ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4fe93a1683447f9d80fe36da35860f79d297bd34883a7590c095c680cdf863ae
3
+ size 1369029948
gram_smol_latent/checkpointbest.tmp.tmp ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3829bad8e723cb16ff5c980859c1440515980c8c94bc470c4776641b65e0d6bf
3
+ size 1369029948
gram_smol_latent/train.py ADDED
@@ -0,0 +1,713 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ try: # For debugging
2
+ from localutils.debugger import enable_debug
3
+ enable_debug()
4
+ except ImportError:
5
+ pass
6
+
7
+ import flax.linen as nn
8
+ import jax.numpy as jnp
9
+ from absl import app, flags
10
+ from functools import partial
11
+ import numpy as np
12
+ import tqdm
13
+ import jax
14
+ import jax.numpy as jnp
15
+ import flax
16
+ import optax
17
+ import wandb
18
+ from ml_collections import config_flags
19
+ import ml_collections
20
+ import tensorflow_datasets as tfds
21
+ import tensorflow as tf
22
+ tf.config.set_visible_devices([], "GPU")
23
+ tf.config.set_visible_devices([], "TPU")
24
+ import matplotlib.pyplot as plt
25
+ from typing import Any
26
+ import os
27
+
28
+ from utils.wandb import setup_wandb, default_wandb_config
29
+ from utils.train_state import TrainState, target_update
30
+ from utils.checkpoint import Checkpoint
31
+ from utils.pretrained_resnet import get_pretrained_embs, get_pretrained_model
32
+ from utils.fid import get_fid_network, fid_from_stats
33
+ from models.vqvae import VQVAE
34
+ from models.discriminator import Discriminator
35
+
36
# ---------------------------------------------------------------------------
# Command-line flags and default hyperparameter configuration.
# Flags cover run-level settings (paths, intervals, batch size); the two
# ConfigDicts below hold model hyperparameters and wandb settings, and are
# exposed as overridable --model.* / --wandb.* flags at the bottom.
# ---------------------------------------------------------------------------
FLAGS = flags.FLAGS
flags.DEFINE_string('dataset_name', 'imagenet256', 'Environment name.')
flags.DEFINE_string('save_dir', "/home/lambda/jax-vqvae-vqgan/chkpts/checkpoint", 'Save dir (if not None, save params).')
flags.DEFINE_string('load_dir', "./checkpointbest.tmp.tmp" , 'Load dir (if not None, load params from here).')
flags.DEFINE_integer('seed', 0, 'Random seed.')
flags.DEFINE_integer('log_interval', 1000, 'Logging interval.')
flags.DEFINE_integer('eval_interval', 1000, 'Eval interval.')
flags.DEFINE_integer('save_interval', 1000, 'Save interval.')
flags.DEFINE_integer('batch_size', 64, 'Total Batch size.')
flags.DEFINE_integer('max_steps', int(1_000_000), 'Number of training steps.')

# Model hyperparameters. Grouped by concern: VQVAE backbone, quantizer
# variants (VQ / FSQ / KL), and GAN/auxiliary-loss weights. Inline notes
# record where these deviate from the reference VQGAN recipe.
model_config = ml_collections.ConfigDict({
    # VQVAE
    'lr': 0.0001,
    'beta1': 0.0,   # reference implementation uses 0.5
    'beta2': 0.99,  # reference implementation uses 0.9
    'lr_warmup_steps': 2000,
    'lr_decay_steps': 500_000,  # reference uses a 'lambdalr' schedule
    'filters': 128,
    'num_res_blocks': 2,
    'channel_multipliers': (1, 2, 4, 4),  # matches the reference encoder/decoder
    'embedding_dim': 4,  # For FSQ, a good default is 4.
    'norm_type': 'GN',
    'weight_decay': 0.05,  # possibly should be None
    'clip_gradient': 1.0,
    'l2_loss_weight': 1.0,  # reference actually uses an L1 reconstruction loss
    'eps_update_rate': 0.9999,
    # Quantizer
    'quantizer_type': 'ae',  # or 'fsq', 'kl'
    # Quantizer (VQ)
    'quantizer_loss_ratio': 1,
    'codebook_size': 1024,
    'entropy_loss_ratio': 0.1,
    'entropy_loss_type': 'softmax',
    'entropy_temperature': 0.01,
    'commitment_cost': 0.25,
    # Quantizer (FSQ)
    'fsq_levels': 5,  # Bins per dimension.
    # Quantizer (KL)
    'kl_weight': 0.00007,  # reference uses 1e-6; 0.001 is the usual default
    # GAN
    'g_adversarial_loss_weight': 0.5,
    'g_grad_penalty_cost': 10,
    'perceptual_loss_weight': 0.5,
    'gan_warmup_steps': 25000,
    # Path-length regularization (StyleGAN2-style); -1 disables it.
    "pl_decay": 0.01,
    "pl_weight": -1,
    'MMD_weight': 1.0

})

# wandb logging configuration; '{dataset_name}' is substituted at setup time.
wandb_config = default_wandb_config()
wandb_config.update({
    'project': 'vqvae',
    'name': 'vqvae_{dataset_name}',
})

# Expose both ConfigDicts as nested command-line flags.
config_flags.DEFINE_config_dict('wandb', wandb_config, lock_config=False)
config_flags.DEFINE_config_dict('model', model_config, lock_config=False)
95
+
96
+ ##############################################
97
+ ## Model Definitions.
98
+ ##############################################
99
+
100
@jax.vmap
def sigmoid_cross_entropy_with_logits(*, labels: jnp.ndarray, logits: jnp.ndarray) -> jnp.ndarray:
    """Numerically stable sigmoid cross-entropy, batched over axis 0 via vmap.

    Computes max(x, 0) - x*z + log(1 + exp(-|x|)) elementwise, which is the
    standard stable rewrite of -z*log(sigmoid(x)) - (1-z)*log(1-sigmoid(x)).
    https://github.com/google-research/maskgit/blob/main/maskgit/libml/losses.py
    """
    # max(x, 0) replaces the where(x >= 0, x, 0) formulation.
    positive_part = jnp.maximum(logits, jnp.zeros_like(logits, dtype=logits.dtype))
    # -|x| keeps the exponent non-positive so exp() cannot overflow.
    stable_exponent = -jnp.abs(logits)
    return positive_part - logits * labels + jnp.log1p(jnp.exp(stable_exponent))
109
+
110
class VQGANModel(flax.struct.PyTreeNode):
    """Pmapped VQGAN training/inference state: VQVAE, its EMA copy, and a discriminator.

    A flax PyTreeNode, so `self.replace(...)` returns an updated copy; all
    methods are functionally pure and pmapped over the 'data' axis.
    """
    rng: Any
    # NOTE(review): declared as `dict` but constructed with an
    # ml_collections.ConfigDict in main() (attribute access like
    # `self.config.pl_decay` relies on that) — confirm before tightening.
    config: dict = flax.struct.field(pytree_node=False)
    vqvae: TrainState
    vqvae_eps: TrainState  # EMA (target) copy of the VQVAE, used for eval.
    discriminator: TrainState

    # Train G and D.
    @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
    def update(self, images, pmap_axis='data'):
        """One joint generator+discriminator step on a per-device image batch.

        Returns (new_model, info) where info holds pmean-reduced loss metrics.
        """
        new_rng, curr_key = jax.random.split(self.rng, 2)

        # Frozen perceptual backbone; loaded each call (cached by the helper,
        # presumably — TODO confirm against utils.pretrained_resnet).
        resnet, resnet_params = get_pretrained_model('resnet50', 'data/resnet_pretrained.npy')

        # 0.0 during GAN warmup, 1.0 afterwards; multiplies adversarial terms
        # and discriminator gradients so the GAN only kicks in after warmup.
        is_gan_training = 1.0 - (self.vqvae.step < self.config['gan_warmup_steps']).astype(jnp.float32)

        def loss_fn(params_vqvae, params_disc):
            """Returns ((loss_vae, loss_d), metrics) for jax.vjp below."""

            def path_reg_loss(latents, targets):
                """StyleGAN2-style path-length penalty on the decoder Jacobian.

                Uses only the first 2 samples to keep the extra backward pass
                cheap. Also returns the updated running mean of path lengths.
                """
                latents = latents[0:2, :, :, :]
                targets = targets[0:2, :, :, :]
                # Noise scaled by 1/sqrt(H*W) so the Jacobian-vector product
                # magnitude is independent of spatial resolution.
                pl_noise = jax.random.normal(new_rng, shape = targets.shape) / jnp.sqrt(targets.shape[1] * targets.shape[2])
                def grad_sum(latents, pl_noise):
                    # sum(decode(latents) * noise): its gradient wrt latents
                    # is J^T @ noise, i.e. a Jacobian-vector product.
                    my_sum = self.vqvae(latents, params=params_vqvae, method="decode", rngs={'noise': curr_key})*pl_noise
                    print("Decode shape", my_sum.shape)
                    return jnp.sum(my_sum)

                decode_grad_fn = jax.grad(grad_sum)
                pl_grads = decode_grad_fn(latents, pl_noise)
                # Per-sample length: mean over channels of the summed squared
                # spatial gradient, then sqrt.
                pl_lengths = jnp.sqrt(jnp.mean(jnp.sum(jnp.square(pl_grads), axis = [2,3]), axis = 1))

                # Exponential moving average of the path length; penalize
                # deviation from it.
                # NOTE(review): assumes TrainState carries a `pl_mean` field —
                # confirm against utils.train_state.
                pl_mean = self.vqvae.pl_mean + self.config.pl_decay * (jnp.mean(pl_lengths) - self.vqvae.pl_mean)
                pl_penalty = jnp.square(pl_lengths - pl_mean)
                loss = jnp.mean(pl_penalty)
                return loss, pl_mean

            # NOTE(review): BUG if ever enabled — `result_dict` and
            # `reconstructed_images` are not defined until the forward pass
            # below, so this raises NameError whenever pl_weight != -1. It is
            # dormant only because the default config sets pl_weight = -1.
            # The call should move after the reconstruction line.
            if self.config.pl_weight != -1:
                smooth_loss, pl_mean = path_reg_loss(result_dict["latents"], reconstructed_images)

            # Reconstruct image
            reconstructed_images, result_dict = self.vqvae(images, params=params_vqvae, rngs={'noise': curr_key})
            print("Reconstructed images shape", reconstructed_images.shape)
            print("Input images shape", images.shape)
            assert reconstructed_images.shape == images.shape


            def calculate_covariance_loss_single(image):
                """Calculates the covariance loss for one image.

                Penalizes the Frobenius distance between the channel covariance
                (over spatial positions) and the identity, i.e. pushes output
                channels toward being decorrelated with unit variance.
                """
                # image.shape is (H, W, C)
                C = image.shape[-1]

                # Reshape the spatial dimensions into one dimension of "observations"
                # New shape: (H*W, C)
                reshaped_features = image.reshape(-1, C)

                # Calculate the covariance matrix of the channels.
                # We treat each channel as a variable and spatial locations as observations.
                # The resulting shape will be (C, C).
                cov_matrix = jnp.cov(reshaped_features, rowvar=False)

                # The target is the identity matrix of size (C, C)
                identity_matrix = jnp.eye(C)

                # The loss is the sum of squared differences (Frobenius norm squared)
                loss = jnp.sum(jnp.square(cov_matrix - identity_matrix))

                return loss


            B, H, W, C = reconstructed_images.shape
            # NOTE(review): `reshaped_features` here is unused — the vmapped
            # loss below consumes `reconstructed_images` directly.
            reshaped_features = reconstructed_images.reshape(B, -1, C)
            batched_loss_fn = jax.vmap(calculate_covariance_loss_single, in_axes=0)
            per_image_losses = batched_loss_fn(reconstructed_images)

            # Weight 1 for now; the loss is small — tune upward until it
            # starts harming reconstructions.
            gram_loss = jnp.mean(per_image_losses) * 1

            # GAN loss on VQVAE output.
            discriminator_fn = lambda x: self.discriminator(x, params=params_disc)
            # vjp gives both real logits and a pullback for the R1-style
            # gradient penalty on real images.
            real_logit, vjp_fn = jax.vjp(discriminator_fn, images, has_aux=False)
            gradient = vjp_fn(jnp.ones_like(real_logit))[0] # Gradient of discriminator output wrt. real images.
            gradient = gradient.reshape((images.shape[0], -1))
            gradient = jnp.asarray(gradient, jnp.float32)
            penalty = jnp.sum(jnp.square(gradient), axis=-1)
            penalty = jnp.mean(penalty) # Gradient penalty for training D.
            fake_logit = discriminator_fn(reconstructed_images)
            # Non-saturating BCE: D wants real->1, fake->0.
            d_loss_real = sigmoid_cross_entropy_with_logits(labels=jnp.ones_like(real_logit), logits=real_logit).mean()
            d_loss_fake = sigmoid_cross_entropy_with_logits(labels=jnp.zeros_like(fake_logit), logits=fake_logit).mean()
            loss_d = d_loss_real + d_loss_fake + (penalty * self.config['g_grad_penalty_cost'])

            # Generator's adversarial term: fool D into predicting 1 on fakes.
            # Zeroed out during warmup via is_gan_training.
            d_loss_for_vae = sigmoid_cross_entropy_with_logits(labels=jnp.ones_like(fake_logit), logits=fake_logit).mean()
            d_loss_for_vae = d_loss_for_vae * is_gan_training

            # Perceptual loss on pooled ResNet features of real vs. fake.
            real_pools, _ = get_pretrained_embs(resnet_params, resnet, images=images)
            fake_pools, _ = get_pretrained_embs(resnet_params, resnet, images=reconstructed_images)
            perceptual_loss = jnp.mean((real_pools - fake_pools)**2)

            l2_loss = jnp.mean((reconstructed_images - images) ** 2)
            # Quantizer loss key is absent for the plain-AE variant.
            quantizer_loss = result_dict['quantizer_loss'] if 'quantizer_loss' in result_dict else 0.0
            if self.config['quantizer_type'] == 'kl' or self.config["quantizer_type"] == "kl_two":
                quantizer_loss = quantizer_loss * self.config['kl_weight']
            elif self.config["quantizer_type"] == "MMD":
                quantizer_loss = quantizer_loss * self.config['MMD_weight']
            loss_vae = (l2_loss * FLAGS.model['l2_loss_weight']) \
                + (quantizer_loss * FLAGS.model['quantizer_loss_ratio']) \
                + (d_loss_for_vae * FLAGS.model['g_adversarial_loss_weight']) \
                + (perceptual_loss * FLAGS.model['perceptual_loss_weight']) \
                + gram_loss
            codebook_usage = result_dict['usage'] if 'usage' in result_dict else 0.0

            return_dict = {
                'loss_vae': loss_vae,
                'loss_d': loss_d,
                'l2_loss': l2_loss,
                'd_loss_for_vae': d_loss_for_vae,
                'perceptual_loss': perceptual_loss,
                'quantizer_loss': quantizer_loss,
                'codebook_usage': codebook_usage,
                'cov loss': gram_loss
            }

            # NOTE(review): the pl term is added to loss_vae AFTER
            # return_dict['loss_vae'] was captured, so the logged value
            # excludes it; only relevant when pl_weight != -1.
            if self.config["pl_weight"] != -1:
                loss_vae += (smooth_loss * FLAGS.model["pl_weight"])
                return_dict["pl_mean"] = pl_mean
                return_dict["smooth_loss"] = smooth_loss


            return (loss_vae, loss_d), return_dict


        # This is a fancy way to do 'jax.grad' so (loss_vae, params_vqvae) and (loss_d, params_disc) are differentiated.
        # Pulling back (1, 0) gives d(loss_vae)/d(params); (0, 1) gives
        # d(loss_d)/d(params); each is taken wrt. the matching param tree.
        _, grad_fn, info = jax.vjp(loss_fn, self.vqvae.params, self.discriminator.params, has_aux=True)
        vae_grads, _ = grad_fn((1., 0.))
        _, d_grads = grad_fn((0., 1.))

        # Average gradients and metrics across devices.
        vae_grads = jax.lax.pmean(vae_grads, axis_name=pmap_axis)
        d_grads = jax.lax.pmean(d_grads, axis_name=pmap_axis)
        # Freeze the discriminator during GAN warmup.
        d_grads = jax.tree.map(lambda x: x * is_gan_training, d_grads)

        info = jax.lax.pmean(info, axis_name=pmap_axis)
        if self.config['quantizer_type'] == 'fsq':
            # For FSQ, report fraction of codebook entries actually used.
            info['codebook_usage'] = jnp.sum(info['codebook_usage'] > 0) / info['codebook_usage'].shape[-1]

        # Apply optimizer updates to the VQVAE.
        updates, new_opt_state = self.vqvae.tx.update(vae_grads, self.vqvae.opt_state, self.vqvae.params)
        new_params = optax.apply_updates(self.vqvae.params, updates)

        # Carry the path-length running mean forward only when pl reg is on.
        if self.config["pl_weight"] != -1:
            new_vqvae = self.vqvae.replace(step=self.vqvae.step + 1, params=new_params, opt_state=new_opt_state, pl_mean=info["pl_mean"])
        else:
            new_vqvae = self.vqvae.replace(step=self.vqvae.step + 1, params=new_params, opt_state=new_opt_state)

        # Apply optimizer updates to the discriminator.
        updates, new_opt_state = self.discriminator.tx.update(d_grads, self.discriminator.opt_state, self.discriminator.params)
        new_params = optax.apply_updates(self.discriminator.params, updates)
        new_discriminator = self.discriminator.replace(step=self.discriminator.step + 1, params=new_params, opt_state=new_opt_state)

        info['grad_norm_vae'] = optax.global_norm(vae_grads)
        info['grad_norm_d'] = optax.global_norm(d_grads)
        # NOTE(review): these report the DISCRIMINATOR's update/param norms
        # only — `updates`/`new_params` were rebound above.
        info['update_norm'] = optax.global_norm(updates)
        info['param_norm'] = optax.global_norm(new_params)
        info['is_gan_training'] = is_gan_training

        # EMA update of the eval copy toward the freshly-updated VQVAE.
        new_vqvae_eps = target_update(new_vqvae, self.vqvae_eps, 1-self.config['eps_update_rate'])

        new_model = self.replace(rng=new_rng, vqvae=new_vqvae, vqvae_eps=new_vqvae_eps, discriminator=new_discriminator)
        return new_model, info

    @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
    def reconstruction(self, images, pmap_axis='data', sampling = True):
        """Reconstruct images with the EMA VQVAE, clipped to [0, 1].

        NOTE(review): `sampling` is never passed by callers (pmap in_axes
        covers only self/images), so the non-sampling branch is dead here.
        """
        if not sampling:
            reconstructed_images, _ = self.vqvae_eps(images)
        else:
            # Sampling path: provide a noise RNG to the forward pass.
            new_rng, curr_key = jax.random.split(self.rng, 2)
            reconstructed_images, _ = self.vqvae_eps(images, rngs={'noise': curr_key})

        reconstructed_images = jnp.clip(reconstructed_images, 0, 1)
        return reconstructed_images

    @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
    def reconstruction_sampling(self, images, pmap_axis='data'):
        """Return (deterministic EMA reconstruction, stochastic reconstruction)."""
        reconstructed_images_determistic, _ = self.vqvae_eps(images)

        new_rng, curr_key = jax.random.split(self.rng, 2)
        reconstructed_images_sample, result_dict = self.vqvae(images, rngs={'noise': curr_key})

        # We don't need to return the result dict.
        reconstructed_images_determistic = jnp.clip(reconstructed_images_determistic, 0, 1)
        reconstructed_images_sample = jnp.clip(reconstructed_images_sample, 0, 1)

        return reconstructed_images_determistic, reconstructed_images_sample

    @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
    def reconstruction_interpolation(self, images, pmap_axis='data'):
        """Placeholder for latent/image interpolation experiments.

        NOTE(review): currently identical in behavior to
        reconstruction_sampling — no interpolation is implemented yet.
        """
        reconstructed_images_determistic, _ = self.vqvae_eps(images)

        new_rng, curr_key = jax.random.split(self.rng, 2)
        reconstructed_images_sample, result_dict = self.vqvae(images, rngs={'noise': curr_key})

        # We don't need to return the result dict.
        reconstructed_images_determistic = jnp.clip(reconstructed_images_determistic, 0, 1)
        reconstructed_images_sample = jnp.clip(reconstructed_images_sample, 0, 1)

        return reconstructed_images_determistic, reconstructed_images_sample

    @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
    def get_latent(self, images, pmap_axis='data'):
        """Encode images to latents with the EMA VQVAE (noise is NOT added)."""
        latents, result_dict = self.vqvae_eps(images, params=self.vqvae_eps.params, method="encode")
        return latents, result_dict

    @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
    def reconstruction_noisy(self, images, pmap_axis='data'):
        """Decode latents perturbed by noise at 100 increasing magnitudes.

        For each multiplier in [0.00, 0.99] step 0.01, adds scaled Gaussian
        noise to the latents, decodes, and records (decoded, snr) where
        snr = var(latent) / var(noise). Returns the clean reconstruction,
        the list of noisy decodes, and the encoder's std output.
        """
        noises = []
        numbers = np.arange(0.00, 1.0, 0.01)

        for number in numbers:
            noises.append(float(number))

        # Forward pass returns the full reconstruction *and* the latents.
        reconstructed_images, result_dict = self.vqvae_eps(images)
        latents = result_dict["latents"]
        std = result_dict["std"]

        # Get rng for creating noise.
        new_rng, curr_key = jax.random.split(self.rng, 2)

        decode = []
        # Per-sample latent std, shaped for broadcasting.
        latent_std = latents.std(axis = [1,2,3]).reshape(-1,1,1,1)

        for mult in noises:

            # NOTE(review): same key every iteration, so the noise direction
            # is identical across multipliers (only the scale varies).
            noise = jax.random.normal(curr_key, latents.shape)

            if True:
                latent_var = latent_std ** 2
                noise_std = mult*noise.std()  # noise std should be around 1
                noise_var = mult ** 2
                if noise_var == 0:  # zero noise would divide by zero
                    snr = 0
                else:
                    snr = latent_var/noise_var

            temp_latents = latents + noise*mult

            # vqvae_eps is the deterministic (EMA) model.
            decoded = self.vqvae_eps(temp_latents, params=self.vqvae_eps.params, method="decode")
            decoded = jnp.clip(decoded, 0, 1)
            if True:
                decode.append((decoded, snr))

        reconstructed_images = jnp.clip(reconstructed_images, 0, 1)
        return reconstructed_images, decode, std

    @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
    def reconstruction_ppl(self, images, pmap_axis='data'):
        """Single small latent perturbation (eps=1e-4) for PPL-style probing."""
        epsilon = .0001
        reconstructed_images, result_dict = self.vqvae_eps(images)
        latents = result_dict["latents"]
        std = result_dict["std"]

        new_rng, curr_key = jax.random.split(self.rng, 2)

        noise = jax.random.normal(curr_key, latents.shape)

        temp_latents = latents + noise * epsilon
        decoded = self.vqvae_eps(temp_latents, params=self.vqvae_eps.params, method="decode")
        decoded = jnp.clip(decoded, 0, 1)

        reconstructed_images = jnp.clip(reconstructed_images, 0, 1)
        return reconstructed_images, decoded, std, latents

    @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
    def reconstruction_grad_distance(self, images, pmap_axis='data'):
        """Unimplemented: intended to measure the local Lipschitz constant
        of the decoder (how much outputs move per unit latent perturbation)."""
        pass

    @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
    def reconstruction_ppl_two(self, images, pmap_axis='data'):
        """Symmetric two-sided latent perturbation (+/- eps/2 along one noise
        direction) — central-difference variant of reconstruction_ppl."""
        epsilon = .0001
        reconstructed_images, result_dict = self.vqvae_eps(images)
        latents = result_dict["latents"]
        std = result_dict["std"]

        new_rng, curr_key = jax.random.split(self.rng, 2)

        noise = jax.random.normal(curr_key, latents.shape)

        temp_latents = latents + noise/2 * epsilon

        decoded = self.vqvae_eps(temp_latents, params=self.vqvae_eps.params, method="decode")
        decoded = jnp.clip(decoded, 0, 1)

        temp_latents_2 = latents + -1 * noise/2 * epsilon

        decoded_2 = self.vqvae_eps(temp_latents_2, params=self.vqvae_eps.params, method="decode")
        decoded_2 = jnp.clip(decoded_2, 0, 1)

        reconstructed_images = jnp.clip(reconstructed_images, 0, 1)
        return reconstructed_images, decoded, std, latents, decoded_2

    @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
    def reconstruction_ppl_image(self, images, pmap_axis='data'):
        """Perturb in IMAGE space (eps=1e-4) and re-encode, returning both the
        clean and perturbed latents/stds for sensitivity comparison."""
        epsilon = .0001
        new_rng, curr_key = jax.random.split(self.rng, 2)

        reconstructed_images, result_dict = self.vqvae_eps(images)
        latents = result_dict["latents"]
        std = result_dict["std"]

        noise = jax.random.normal(curr_key, images.shape)
        images = images + noise * epsilon

        decoded, result_dict_2 = self.vqvae_eps(images)
        decoded = jnp.clip(decoded, 0, 1)

        latents_noisy = result_dict_2["latents"]
        std_noisy = result_dict_2["std"]

        reconstructed_images = jnp.clip(reconstructed_images, 0, 1)
        return reconstructed_images, decoded, std, latents, std_noisy, latents_noisy
488
+
489
+ ##############################################
490
+ ## Training Code.
491
+ ##############################################
492
def main(_):
    """Train the VQGAN: build datasets, init models, run the update/eval loop.

    Per-process flow: shard the global batch across local devices, call the
    pmapped update each step, and periodically log metrics, sample
    reconstructions, compute FID against precomputed stats, and checkpoint.
    """
    np.random.seed(FLAGS.seed)
    print("Using devices", jax.local_devices())
    device_count = len(jax.local_devices())
    global_device_count = jax.device_count()
    # Per-process (node) batch; each device then gets local/device_count.
    local_batch_size = FLAGS.batch_size // (global_device_count // device_count)
    print("Device count", device_count)
    print("Global device count", global_device_count)
    print("Global Batch: ", FLAGS.batch_size)
    print("Node Batch: ", local_batch_size)
    print("Device Batch:", local_batch_size // device_count)

    # Create wandb logger (only on the first host in multi-host setups).
    if jax.process_index() == 0:
        setup_wandb(FLAGS.model.to_dict(), **FLAGS.wandb)

    def get_dataset(is_train):
        """Build an infinite, batched numpy iterator over ImageNet images in [0, 1]."""
        if 'imagenet' in FLAGS.dataset_name:
            def deserialization_fn(data):
                # Center-crop to a square of the shorter side, then resize.
                image = data['image']
                min_side = tf.minimum(tf.shape(image)[0], tf.shape(image)[1])
                image = tf.image.resize_with_crop_or_pad(image, min_side, min_side)
                if 'imagenet256' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (256, 256))
                elif 'imagenet128' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (128, 128))
                else:
                    raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
                if is_train:
                    image = tf.image.random_flip_left_right(image)
                image = tf.cast(image, tf.float32) / 255.0
                return image

            # Each JAX process reads its own shard of the split.
            split = tfds.split_for_jax_process('train' if is_train else 'validation', drop_remainder=True)
            print(split)
            # data_dir on /dev/shm keeps dataset reads in RAM.
            dataset = tfds.load('imagenet2012', split=split, data_dir = "/dev/shm")
            dataset = dataset.map(deserialization_fn, num_parallel_calls=tf.data.AUTOTUNE)
            dataset = dataset.shuffle(10000, seed=42, reshuffle_each_iteration=True)
            dataset = dataset.repeat()
            dataset = dataset.batch(local_batch_size)
            dataset = dataset.prefetch(tf.data.AUTOTUNE)
            dataset = tfds.as_numpy(dataset)
            dataset = iter(dataset)
            return dataset
        else:
            raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")

    dataset = get_dataset(is_train=True)
    dataset_valid = get_dataset(is_train=False)
    # Single example used to infer shapes for model init.
    example_obs = next(dataset)[:1]

    get_fid_activations = get_fid_network()
    if not os.path.exists('./data/imagenet256_fidstats_openai.npz'):
        raise ValueError("Please download the FID stats file! See the README.")
    truth_fid_stats = np.load('data/imagenet256_fidstats_openai.npz')

    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, param_key = jax.random.split(rng)
    print("Total Memory on device:", float(jax.local_devices()[0].memory_stats()['bytes_limit']) / 1024**3, "GB")

    ###################################
    # Creating Model and put on devices.
    ###################################
    # Record image geometry in the (unlocked) model config for the networks.
    FLAGS.model.image_channels = example_obs.shape[-1]
    FLAGS.model.image_size = example_obs.shape[1]
    vqvae_def = VQVAE(FLAGS.model, train=True)
    vqvae_params = vqvae_def.init({'params': param_key, 'noise': param_key}, example_obs)['params']
    tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    vqvae_ts = TrainState.create(vqvae_def, vqvae_params, tx=tx)
    # EMA copy shares the initial params; train=False and no optimizer.
    vqvae_def_eps = VQVAE(FLAGS.model, train=False)
    vqvae_eps_ts = TrainState.create(vqvae_def_eps, vqvae_params)
    print("Total num of VQVAE parameters:", sum(x.size for x in jax.tree_util.tree_leaves(vqvae_params)))

    discriminator_def = Discriminator(FLAGS.model)
    discriminator_params = discriminator_def.init(param_key, example_obs)['params']
    tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    discriminator_ts = TrainState.create(discriminator_def, discriminator_params, tx=tx)
    print("Total num of Discriminator parameters:", sum(x.size for x in jax.tree_util.tree_leaves(discriminator_params)))

    model = VQGANModel(rng=rng, vqvae=vqvae_ts, vqvae_eps=vqvae_eps_ts, discriminator=discriminator_ts, config=FLAGS.model)

    # Optionally restore from checkpoint; any load failure falls back to the
    # fresh init. NOTE(review): the bare `except:` hides real errors (shape
    # mismatches, corrupt files) behind "Random init" — consider narrowing.
    if FLAGS.load_dir is not None:
        try:
            cp = Checkpoint(FLAGS.load_dir)
            model = cp.load_model(model)
            print("Loaded model with step", model.vqvae.step)
        except:
            print("Random init")
    else:
        print("Random init")

    # Replicate the full model state across local devices for pmap.
    model = flax.jax_utils.replicate(model, devices=jax.local_devices())
    jax.debug.visualize_array_sharding(model.vqvae.params['decoder']['Conv_0']['bias'])

    ###################################
    # Train Loop
    ###################################

    best_fid = 100000

    for i in tqdm.tqdm(range(1, FLAGS.max_steps + 1),
                       smoothing=0.1,
                       dynamic_ncols=True):

        batch_images = next(dataset)
        # Reshape to [devices, batch//devices, H, W, C] for pmap.
        batch_images = batch_images.reshape((len(jax.local_devices()), -1, *batch_images.shape[1:]))

        model, update_info = model.update(batch_images)

        # NOTE(review): printing the whole info dict every step forces a
        # device sync each iteration — likely a debugging leftover.
        print(update_info)

        if i % FLAGS.log_interval == 0:
            # Mean over the device axis before logging.
            update_info = jax.tree.map(lambda x: x.mean(), update_info)
            train_metrics = {f'training/{k}': v for k, v in update_info.items()}
            if jax.process_index() == 0:
                wandb.log(train_metrics, step=i)

        if i % FLAGS.eval_interval == 0:
            # Sample reconstructions from train and validation batches.
            reconstructed_images = model.reconstruction(batch_images) # [devices, 8, 256, 256, 3]
            valid_images = next(dataset_valid)
            valid_images = valid_images.reshape((len(jax.local_devices()), -1, *valid_images.shape[1:]))
            valid_reconstructed_images = model.reconstruction(valid_images) # [devices, 8, 256, 256, 3]

            if jax.process_index() == 0:
                wandb.log({'batch_image_mean': batch_images.mean()}, step=i)
                wandb.log({'reconstructed_images_mean': reconstructed_images.mean()}, step=i)
                wandb.log({'batch_image_std': batch_images.std()}, step=i)
                wandb.log({'reconstructed_images_std': reconstructed_images.std()}, step=i)

                # Side-by-side original (top row) vs. reconstruction (bottom).
                # Grid is sized for 8 devices but only the first 4 are drawn,
                # indexing [device, 0] — one image per device.
                fig, axs = plt.subplots(2, 8, figsize=(30, 15))
                for j in range(4):
                    axs[0, j].imshow(batch_images[j, 0], vmin=0, vmax=1)
                    axs[1, j].imshow(reconstructed_images[j, 0], vmin=0, vmax=1)
                wandb.log({'reconstruction': wandb.Image(fig)}, step=i)
                plt.close(fig)
                fig, axs = plt.subplots(2, 8, figsize=(30, 15))
                for j in range(4):
                    axs[0, j].imshow(valid_images[j, 0], vmin=0, vmax=1)
                    axs[1, j].imshow(valid_reconstructed_images[j, 0], vmin=0, vmax=1)
                wandb.log({'reconstruction_valid': wandb.Image(fig)}, step=i)
                plt.close(fig)

            # Validation Losses: run update() but DISCARD the new model, so
            # only the metrics are kept (no parameter update from val data).
            _, valid_update_info = model.update(valid_images)
            valid_update_info = jax.tree.map(lambda x: x.mean(), valid_update_info)
            valid_metrics = {f'validation/{k}': v for k, v in valid_update_info.items()}
            if jax.process_index() == 0:
                wandb.log(valid_metrics, step=i)

            # FID measurement over ~780 validation batches (~40k images).
            activations = []
            activations2 = []
            for _ in range(780):
                valid_images = next(dataset_valid)
                valid_images = valid_images.reshape((len(jax.local_devices()), -1, *valid_images.shape[1:]))
                valid_reconstructed_images = model.reconstruction(valid_images)

                # Inception expects 299x299 inputs in [-1, 1].
                valid_reconstructed_images = jax.image.resize(valid_reconstructed_images, (valid_images.shape[0], valid_images.shape[1], 299, 299, 3),
                                                              method='bilinear', antialias=False)
                valid_reconstructed_images = 2 * valid_reconstructed_images - 1
                activations += [np.array(get_fid_activations(valid_reconstructed_images))[..., 0, 0, :]]

            # TODO: use all_gather to get activations from all devices.
            activations = np.concatenate(activations, axis=0)
            activations = activations.reshape((-1, activations.shape[-1]))

            print("doing this much FID", activations.shape)
            mu1 = np.mean(activations, axis=0)
            sigma1 = np.cov(activations, rowvar=False)
            fid = fid_from_stats(mu1, sigma1, truth_fid_stats['mu'], truth_fid_stats['sigma'])

            if jax.process_index() == 0:
                wandb.log({'validation/fid': fid}, step=i)
                print("validation FID at step", i, fid)
                # Keep a separate "best" checkpoint whenever FID improves.
                if fid < best_fid:
                    model_single = flax.jax_utils.unreplicate(model)
                    cp = Checkpoint(FLAGS.save_dir + "best.tmp")
                    cp.set_model(model_single)
                    cp.save()
                    best_fid = fid

        # Periodic (non-best) checkpoint, written only by the first host.
        if (i % FLAGS.save_interval == 0) and (FLAGS.save_dir is not None):
            if jax.process_index() == 0:
                model_single = flax.jax_utils.unreplicate(model)
                cp = Checkpoint(FLAGS.save_dir)
                cp.set_model(model_single)
                cp.save()
712
# Script entry point: absl parses flags, then invokes main().
if __name__ == '__main__':
    app.run(main)