KublaiKhan1 committed on
Commit
0b206f6
·
verified ·
1 Parent(s): 6d91519

Delete GramAESmall

Browse files
GramAESmall/checkpointbest.tmp.tmp DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:90b7c76d670c5ea797ef13f5501ba445a69b628b9401016aa795895dae4c3216
3
- size 1369029948
 
 
 
 
GramAESmall/train.py DELETED
@@ -1,692 +0,0 @@
1
- try: # For debugging
2
- from localutils.debugger import enable_debug
3
- enable_debug()
4
- except ImportError:
5
- pass
6
-
7
- import flax.linen as nn
8
- import jax.numpy as jnp
9
- from absl import app, flags
10
- from functools import partial
11
- import numpy as np
12
- import tqdm
13
- import jax
14
- import jax.numpy as jnp
15
- import flax
16
- import optax
17
- import wandb
18
- from ml_collections import config_flags
19
- import ml_collections
20
- import tensorflow_datasets as tfds
21
- import tensorflow as tf
22
- tf.config.set_visible_devices([], "GPU")
23
- tf.config.set_visible_devices([], "TPU")
24
- import matplotlib.pyplot as plt
25
- from typing import Any
26
- import os
27
-
28
- from utils.wandb import setup_wandb, default_wandb_config
29
- from utils.train_state import TrainState, target_update
30
- from utils.checkpoint import Checkpoint
31
- from utils.pretrained_resnet import get_pretrained_embs, get_pretrained_model
32
- from utils.fid import get_fid_network, fid_from_stats
33
- from models.vqvae import VQVAE
34
- from models.discriminator import Discriminator
35
-
36
# ---------------------------------------------------------------------------
# Command-line flags.
# ---------------------------------------------------------------------------
FLAGS = flags.FLAGS
flags.DEFINE_string('dataset_name', 'imagenet256', 'Environment name.')
flags.DEFINE_string(
    'save_dir',
    "/home/lambda/jax-vqvae-vqgan/chkpts/checkpoint",
    'Save dir (if not None, save params).',
)
flags.DEFINE_string(
    'load_dir',
    "./checkpointbest.tmp.tmp",
    'Load dir (if not None, load params from here).',
)
flags.DEFINE_integer('seed', 0, 'Random seed.')
flags.DEFINE_integer('log_interval', 1000, 'Logging interval.')
flags.DEFINE_integer('eval_interval', 1000, 'Eval interval.')
flags.DEFINE_integer('save_interval', 1000, 'Save interval.')
flags.DEFINE_integer('batch_size', 64, 'Total Batch size.')
flags.DEFINE_integer('max_steps', 1_000_000, 'Number of training steps.')

# ---------------------------------------------------------------------------
# Model hyper-parameters (overridable on the command line via --model.<key>).
# ---------------------------------------------------------------------------
model_config = ml_collections.ConfigDict({
    # Autoencoder / optimiser.
    'lr': 1e-4,
    'beta1': 0.0,  # alternative tried: .5
    'beta2': 0.99,  # alternative tried: .9
    'lr_warmup_steps': 2000,
    'lr_decay_steps': 500_000,  # reference implementation uses a 'lambdalr' schedule
    'filters': 128,
    'num_res_blocks': 2,
    'channel_multipliers': (1, 2, 4, 4),
    'embedding_dim': 4,  # For FSQ, a good default is 4.
    'norm_type': 'GN',
    'weight_decay': 0.05,  # possibly should be None
    'clip_gradient': 1.0,
    'l2_loss_weight': 1.0,  # reference implementation uses L1
    'eps_update_rate': 0.9999,
    # Bottleneck selection: 'ae', 'fsq' or 'kl'.
    'quantizer_type': 'ae',
    # Bottleneck options (VQ).
    'quantizer_loss_ratio': 1,
    'codebook_size': 1024,
    'entropy_loss_ratio': 0.1,
    'entropy_loss_type': 'softmax',
    'entropy_temperature': 0.01,
    'commitment_cost': 0.25,
    # Bottleneck options (FSQ).
    'fsq_levels': 5,  # Bins per dimension.
    # Bottleneck options (KL).
    'kl_weight': 7e-5,  # reference uses 1e-6; .001 is the usual default
    # GAN.
    'g_adversarial_loss_weight': 0.5,
    'g_grad_penalty_cost': 10,
    'perceptual_loss_weight': 0.5,
    'gan_warmup_steps': 25000,
    # Path-length regulariser; a weight of -1 disables it.
    "pl_decay": 0.01,
    "pl_weight": -1,
    'MMD_weight': 1.0,
})

# ---------------------------------------------------------------------------
# Weights & Biases settings.
# ---------------------------------------------------------------------------
wandb_config = default_wandb_config()
wandb_config.update({
    'project': 'vqvae',
    'name': 'vqvae_{dataset_name}',
})

config_flags.DEFINE_config_dict('wandb', wandb_config, lock_config=False)
config_flags.DEFINE_config_dict('model', model_config, lock_config=False)
95
-
96
- ##############################################
97
- ## Model Definitions.
98
- ##############################################
99
-
100
@jax.vmap
def sigmoid_cross_entropy_with_logits(*, labels: jnp.ndarray, logits: jnp.ndarray) -> jnp.ndarray:
    """Element-wise sigmoid cross-entropy computed from raw logits.

    Uses the overflow-safe form ``max(x, 0) - x * z + log(1 + exp(-|x|))``.
    Adapted from
    https://github.com/google-research/maskgit/blob/main/maskgit/libml/losses.py

    Args:
        labels: Targets, same shape as ``logits``.
        logits: Unnormalised scores.

    Returns:
        Per-element loss with the same shape as ``logits``.
    """
    zero = jnp.zeros_like(logits, dtype=logits.dtype)
    is_non_negative = logits >= zero
    # max(x, 0) and -|x|, both selected with the same mask.
    positive_part = jnp.where(is_non_negative, logits, zero)
    minus_magnitude = jnp.where(is_non_negative, -logits, logits)
    softplus_term = jnp.log1p(jnp.exp(minus_magnitude))
    return positive_part - logits * labels + softplus_term
109
-
110
class VQGANModel(flax.struct.PyTreeNode):
    """Train states for the autoencoder, its slow-moving copy and the discriminator.

    Every public method is wrapped in ``jax.pmap`` over the ``'data'`` axis, so
    the model must be replicated and ``images`` must carry a leading device
    axis when they are called.
    """
    rng: Any  # PRNG key; `update` returns a model holding a fresh one.
    config: dict = flax.struct.field(pytree_node=False)  # Hyper-parameters; not a pytree leaf.
    vqvae: TrainState  # Trainable autoencoder.
    vqvae_eps: TrainState  # Copy refreshed through `target_update` (presumably an EMA).
    discriminator: TrainState  # GAN discriminator.

    # ------------------------------------------------------------------
    # Private helpers (not pmapped; only called from inside pmapped methods).
    # ------------------------------------------------------------------
    def _reconstruct_eps(self, images):
        """Run `vqvae_eps` on `images`; return (clipped reconstruction, latents, std)."""
        reconstructed_images, result_dict = self.vqvae_eps(images)
        return jnp.clip(reconstructed_images, 0, 1), result_dict["latents"], result_dict["std"]

    def _decode_eps(self, latents):
        """Decode `latents` with `vqvae_eps` and clip the result to [0, 1]."""
        decoded = self.vqvae_eps(latents, params=self.vqvae_eps.params, method="decode")
        return jnp.clip(decoded, 0, 1)

    def _deterministic_and_sampled(self, images):
        """Return clipped reconstructions from `vqvae_eps` (no rng) and `vqvae` (noise rng)."""
        deterministic, _ = self.vqvae_eps(images)
        _, curr_key = jax.random.split(self.rng, 2)
        sampled, _ = self.vqvae(images, rngs={'noise': curr_key})
        return jnp.clip(deterministic, 0, 1), jnp.clip(sampled, 0, 1)

    # ------------------------------------------------------------------
    # Training.
    # ------------------------------------------------------------------
    @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
    def update(self, images, pmap_axis='data'):
        """Run one optimisation step for both the autoencoder and the discriminator.

        Args:
            images: Per-device batch of images.
            pmap_axis: Name of the mapped axis used for cross-device averaging.

        Returns:
            ``(new_model, info)`` where ``info`` is a dict of scalar metrics
            averaged over devices.
        """
        new_rng, curr_key = jax.random.split(self.rng, 2)
        use_path_reg = self.config['pl_weight'] != -1

        resnet, resnet_params = get_pretrained_model('resnet50', 'data/resnet_pretrained.npy')

        # 0.0 during GAN warm-up, 1.0 afterwards.
        is_gan_training = 1.0 - (self.vqvae.step < self.config['gan_warmup_steps']).astype(jnp.float32)

        def loss_fn(params_vqvae, params_disc):

            def path_reg_loss(latents, targets):
                """Path-length penalty on the decoder; returns (loss, updated pl_mean)."""
                # Only the first two samples of the batch are used.
                latents = latents[0:2, :, :, :]
                targets = targets[0:2, :, :, :]
                # Derive a dedicated key: `new_rng` is handed to the next step
                # and must not be consumed here.
                pl_key = jax.random.fold_in(curr_key, 1)
                pl_noise = jax.random.normal(pl_key, shape=targets.shape) / jnp.sqrt(targets.shape[1] * targets.shape[2])

                def grad_sum(latents, pl_noise):
                    my_sum = self.vqvae(latents, params=params_vqvae, method="decode", rngs={'noise': curr_key}) * pl_noise
                    print("Decode shape", my_sum.shape)
                    return jnp.sum(my_sum)

                pl_grads = jax.grad(grad_sum)(latents, pl_noise)
                pl_lengths = jnp.sqrt(jnp.mean(jnp.sum(jnp.square(pl_grads), axis=(2, 3)), axis=1))

                pl_mean = self.vqvae.pl_mean + self.config['pl_decay'] * (jnp.mean(pl_lengths) - self.vqvae.pl_mean)
                loss = jnp.mean(jnp.square(pl_lengths - pl_mean))
                return loss, pl_mean

            # Reconstruct image.
            reconstructed_images, result_dict = self.vqvae(images, params=params_vqvae, rngs={'noise': curr_key})
            print("Reconstructed images shape", reconstructed_images.shape)
            print("Input images shape", images.shape)
            assert reconstructed_images.shape == images.shape

            # Must come after the forward pass: it consumes its outputs.
            if use_path_reg:
                smooth_loss, pl_mean = path_reg_loss(result_dict["latents"], reconstructed_images)

            # Off-diagonal mass of the per-sample Gram matrix of the latents
            # (batch x patches x embeddings -> batch x patches x patches).
            # NOTE(review): `gram_loss` is computed but never added to
            # `loss_vae` nor reported; the divisors 992 and 40 are hard-coded.
            # Confirm whether it was meant to be part of the objective.
            latents = result_dict["latents"]
            reshaped_latents = latents.reshape(latents.shape[0], -1, latents.shape[-1])
            gram_matrix = jnp.matmul(reshaped_latents, jnp.transpose(reshaped_latents, (0, 2, 1)))
            sum_of_diagonals = jnp.sum(jnp.einsum('bii->bi', gram_matrix))
            gram_loss = jnp.sum(gram_matrix) - sum_of_diagonals
            gram_loss = gram_loss / 992
            gram_loss = gram_loss / 40

            # GAN loss on VQVAE output.
            discriminator_fn = lambda x: self.discriminator(x, params=params_disc)
            real_logit, vjp_fn = jax.vjp(discriminator_fn, images, has_aux=False)
            gradient = vjp_fn(jnp.ones_like(real_logit))[0]  # Gradient of discriminator output wrt. real images.
            gradient = gradient.reshape((images.shape[0], -1))
            gradient = jnp.asarray(gradient, jnp.float32)
            penalty = jnp.mean(jnp.sum(jnp.square(gradient), axis=-1))  # Gradient penalty for training D.
            fake_logit = discriminator_fn(reconstructed_images)
            d_loss_real = sigmoid_cross_entropy_with_logits(labels=jnp.ones_like(real_logit), logits=real_logit).mean()
            d_loss_fake = sigmoid_cross_entropy_with_logits(labels=jnp.zeros_like(fake_logit), logits=fake_logit).mean()
            loss_d = d_loss_real + d_loss_fake + (penalty * self.config['g_grad_penalty_cost'])

            # Generator-side adversarial term, switched off during warm-up.
            d_loss_for_vae = sigmoid_cross_entropy_with_logits(labels=jnp.ones_like(fake_logit), logits=fake_logit).mean()
            d_loss_for_vae = d_loss_for_vae * is_gan_training

            real_pools, _ = get_pretrained_embs(resnet_params, resnet, images=images)
            fake_pools, _ = get_pretrained_embs(resnet_params, resnet, images=reconstructed_images)
            perceptual_loss = jnp.mean((real_pools - fake_pools) ** 2)

            l2_loss = jnp.mean((reconstructed_images - images) ** 2)
            quantizer_loss = result_dict['quantizer_loss'] if 'quantizer_loss' in result_dict else 0.0
            quantizer_type = self.config['quantizer_type']
            if quantizer_type == 'kl' or quantizer_type == "kl_two":
                quantizer_loss = quantizer_loss * self.config['kl_weight']
            elif quantizer_type == "MMD":
                quantizer_loss = quantizer_loss * self.config['MMD_weight']
            loss_vae = (
                (l2_loss * FLAGS.model['l2_loss_weight'])
                + (quantizer_loss * FLAGS.model['quantizer_loss_ratio'])
                + (d_loss_for_vae * FLAGS.model['g_adversarial_loss_weight'])
                + (perceptual_loss * FLAGS.model['perceptual_loss_weight'])
            )
            codebook_usage = result_dict['usage'] if 'usage' in result_dict else 0.0

            return_dict = {
                'loss_vae': loss_vae,
                'loss_d': loss_d,
                'l2_loss': l2_loss,
                'd_loss_for_vae': d_loss_for_vae,
                'perceptual_loss': perceptual_loss,
                'quantizer_loss': quantizer_loss,
                'codebook_usage': codebook_usage,
            }

            if use_path_reg:
                loss_vae += (smooth_loss * FLAGS.model["pl_weight"])
                return_dict["pl_mean"] = pl_mean
                return_dict["smooth_loss"] = smooth_loss

            return (loss_vae, loss_d), return_dict

        # One forward pass, two backward passes: loss_vae is differentiated
        # wrt. the autoencoder params and loss_d wrt. the discriminator params.
        _, grad_fn, info = jax.vjp(loss_fn, self.vqvae.params, self.discriminator.params, has_aux=True)
        vae_grads, _ = grad_fn((1., 0.))
        _, d_grads = grad_fn((0., 1.))

        vae_grads = jax.lax.pmean(vae_grads, axis_name=pmap_axis)
        d_grads = jax.lax.pmean(d_grads, axis_name=pmap_axis)
        d_grads = jax.tree.map(lambda x: x * is_gan_training, d_grads)

        info = jax.lax.pmean(info, axis_name=pmap_axis)
        if self.config['quantizer_type'] == 'fsq':
            info['codebook_usage'] = jnp.sum(info['codebook_usage'] > 0) / info['codebook_usage'].shape[-1]

        # Autoencoder step.
        vae_updates, vae_opt_state = self.vqvae.tx.update(vae_grads, self.vqvae.opt_state, self.vqvae.params)
        vae_params = optax.apply_updates(self.vqvae.params, vae_updates)
        vae_replacements = dict(step=self.vqvae.step + 1, params=vae_params, opt_state=vae_opt_state)
        if use_path_reg:
            vae_replacements['pl_mean'] = info["pl_mean"]
        new_vqvae = self.vqvae.replace(**vae_replacements)

        # Discriminator step.
        d_updates, d_opt_state = self.discriminator.tx.update(d_grads, self.discriminator.opt_state, self.discriminator.params)
        d_params = optax.apply_updates(self.discriminator.params, d_updates)
        new_discriminator = self.discriminator.replace(step=self.discriminator.step + 1, params=d_params, opt_state=d_opt_state)

        info['grad_norm_vae'] = optax.global_norm(vae_grads)
        info['grad_norm_d'] = optax.global_norm(d_grads)
        # NOTE(review): these two have always reported the *discriminator's*
        # update/param norms (the variables were overwritten by the D step);
        # kept as-is so logged metrics stay comparable.
        info['update_norm'] = optax.global_norm(d_updates)
        info['param_norm'] = optax.global_norm(d_params)
        info['is_gan_training'] = is_gan_training

        new_vqvae_eps = target_update(new_vqvae, self.vqvae_eps, 1 - self.config['eps_update_rate'])

        new_model = self.replace(rng=new_rng, vqvae=new_vqvae, vqvae_eps=new_vqvae_eps, discriminator=new_discriminator)
        return new_model, info

    # ------------------------------------------------------------------
    # Inference / analysis utilities.
    # ------------------------------------------------------------------
    @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
    def reconstruction(self, images, pmap_axis='data', sampling = True):
        """Reconstruct `images` with `vqvae_eps`, clipped to [0, 1].

        When `sampling` is true a 'noise' rng derived from `self.rng` is
        passed to the model; otherwise the model is called without rngs.
        """
        if not sampling:
            reconstructed_images, _ = self.vqvae_eps(images)
        else:
            _, curr_key = jax.random.split(self.rng, 2)
            reconstructed_images, _ = self.vqvae_eps(images, rngs={'noise': curr_key})

        return jnp.clip(reconstructed_images, 0, 1)

    @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
    def reconstruction_sampling(self, images, pmap_axis='data'):
        """Return (reconstruction from `vqvae_eps`, noisy reconstruction from `vqvae`)."""
        return self._deterministic_and_sampled(images)

    @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
    def reconstruction_interpolation(self, images, pmap_axis='data'):
        """Currently identical to `reconstruction_sampling`.

        NOTE(review): the original comments describe interpolating between two
        images in latent space, but no interpolation is implemented.
        """
        return self._deterministic_and_sampled(images)

    @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
    def get_latent(self, images, pmap_axis='data'):
        """Encode `images` with `vqvae_eps`; returns (latents, result_dict). No noise is added here."""
        latents, result_dict = self.vqvae_eps(images, params=self.vqvae_eps.params, method="encode")
        return latents, result_dict

    @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
    def reconstruction_noisy(self, images, pmap_axis='data'):
        """Decode latents perturbed by Gaussian noise at scales 0.00, 0.01, ..., 0.99.

        Returns:
            ``(reconstruction, decode, std)`` where ``decode`` is a list of
            ``(decoded_images, snr)`` pairs, one per noise scale, and
            ``snr = var(latents) / scale**2`` (reported as 0 for scale 0).
        """
        noise_scales = [float(scale) for scale in np.arange(0.00, 1.0, 0.01)]

        reconstructed_images, latents, std = self._reconstruct_eps(images)

        _, curr_key = jax.random.split(self.rng, 2)

        # Loop-invariant: the same key (hence the same noise sample) is used
        # for every scale, and the per-sample latent variance never changes.
        noise = jax.random.normal(curr_key, latents.shape)
        latent_var = latents.std(axis=(1, 2, 3)).reshape(-1, 1, 1, 1) ** 2

        decode = []
        for mult in noise_scales:
            noise_var = mult ** 2
            snr = 0 if noise_var == 0 else latent_var / noise_var
            decode.append((self._decode_eps(latents + noise * mult), snr))

        return reconstructed_images, decode, std

    @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
    def reconstruction_ppl(self, images, pmap_axis='data'):
        """Decode latents nudged by ``epsilon * noise``.

        Returns ``(reconstruction, decoded_perturbed, std, latents)``.
        """
        epsilon = .0001
        reconstructed_images, latents, std = self._reconstruct_eps(images)

        _, curr_key = jax.random.split(self.rng, 2)
        noise = jax.random.normal(curr_key, latents.shape)

        decoded = self._decode_eps(latents + noise * epsilon)
        return reconstructed_images, decoded, std, latents

    # Placeholder: intended to return the gradient/Jacobian of the decoder.
    @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
    def reconstruction_grad_distance(self, images, pmap_axis='data'):
        """Not implemented yet; returns None.

        Planned: estimate the factor C such that a small latent change X
        changes the output by roughly C * X, and inspect the spread of C.
        """
        pass

    @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
    def reconstruction_ppl_two(self, images, pmap_axis='data'):
        """Symmetric variant of `reconstruction_ppl`: decode at ``latents +/- epsilon * noise / 2``.

        Returns ``(reconstruction, decoded_plus, std, latents, decoded_minus)``.
        """
        epsilon = .0001
        reconstructed_images, latents, std = self._reconstruct_eps(images)

        _, curr_key = jax.random.split(self.rng, 2)
        noise = jax.random.normal(curr_key, latents.shape)
        half_step = noise / 2 * epsilon

        decoded = self._decode_eps(latents + half_step)
        decoded_2 = self._decode_eps(latents - half_step)

        return reconstructed_images, decoded, std, latents, decoded_2

    @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
    def reconstruction_ppl_image(self, images, pmap_axis='data'):
        """Like `reconstruction_ppl`, but the perturbation is applied to the input images.

        Returns ``(reconstruction, reconstruction_of_noisy_input, std, latents,
        std_noisy, latents_noisy)``.
        """
        epsilon = .0001
        _, curr_key = jax.random.split(self.rng, 2)

        reconstructed_images, latents, std = self._reconstruct_eps(images)

        noise = jax.random.normal(curr_key, images.shape)
        decoded, latents_noisy, std_noisy = self._reconstruct_eps(images + noise * epsilon)

        return reconstructed_images, decoded, std, latents, std_noisy, latents_noisy
469
-
470
- ##############################################
471
- ## Training Code.
472
- ##############################################
473
def main(_):
    """Train the autoencoder + discriminator, periodically evaluating FID and checkpointing.

    Reads all settings from ``FLAGS``. Raises ``ValueError`` for an unknown
    dataset, a batch size that does not split evenly over the local devices,
    or a missing FID statistics file.
    """
    np.random.seed(FLAGS.seed)
    local_devices = jax.local_devices()
    print("Using devices", local_devices)
    device_count = len(local_devices)
    global_device_count = jax.device_count()
    local_batch_size = FLAGS.batch_size // (global_device_count // device_count)
    print("Device count", device_count)
    print("Global device count", global_device_count)
    print("Global Batch: ", FLAGS.batch_size)
    print("Node Batch: ", local_batch_size)
    print("Device Batch:", local_batch_size // device_count)
    if local_batch_size % device_count != 0:
        # The per-device reshape below would fail anyway; fail early and clearly.
        raise ValueError(
            f"Node batch size {local_batch_size} is not divisible by the "
            f"{device_count} local devices.")

    is_main_process = jax.process_index() == 0

    # Create wandb logger
    if is_main_process:
        setup_wandb(FLAGS.model.to_dict(), **FLAGS.wandb)

    def shard(batch):
        """Reshape [batch, ...] into [devices, batch // devices, ...] for pmap."""
        return batch.reshape((device_count, -1, *batch.shape[1:]))

    def get_dataset(is_train):
        """Return an endless iterator of float image batches in [0, 1]."""
        if 'imagenet' not in FLAGS.dataset_name:
            raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")

        def deserialization_fn(data):
            image = data['image']
            # Centre-crop to a square before resizing.
            min_side = tf.minimum(tf.shape(image)[0], tf.shape(image)[1])
            image = tf.image.resize_with_crop_or_pad(image, min_side, min_side)
            if 'imagenet256' in FLAGS.dataset_name:
                image = tf.image.resize(image, (256, 256))
            elif 'imagenet128' in FLAGS.dataset_name:
                image = tf.image.resize(image, (128, 128))
            else:
                raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
            if is_train:
                image = tf.image.random_flip_left_right(image)
            image = tf.cast(image, tf.float32) / 255.0
            return image

        split = tfds.split_for_jax_process('train' if is_train else 'validation', drop_remainder=True)
        print(split)
        dataset = tfds.load('imagenet2012', split=split, data_dir = "/dev/shm")
        dataset = dataset.map(deserialization_fn, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.shuffle(10000, seed=42, reshuffle_each_iteration=True)
        dataset = dataset.repeat()
        dataset = dataset.batch(local_batch_size)
        dataset = dataset.prefetch(tf.data.AUTOTUNE)
        dataset = tfds.as_numpy(dataset)
        return iter(dataset)

    dataset = get_dataset(is_train=True)
    dataset_valid = get_dataset(is_train=False)
    example_obs = next(dataset)[:1]

    get_fid_activations = get_fid_network()
    if not os.path.exists('./data/imagenet256_fidstats_openai.npz'):
        raise ValueError("Please download the FID stats file! See the README.")
    truth_fid_stats = np.load('data/imagenet256_fidstats_openai.npz')

    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, param_key = jax.random.split(rng)
    # memory_stats() can return None on backends that do not report memory.
    memory_stats = local_devices[0].memory_stats()
    if memory_stats and 'bytes_limit' in memory_stats:
        print("Total Memory on device:", float(memory_stats['bytes_limit']) / 1024**3, "GB")

    ###################################
    # Creating Model and put on devices.
    ###################################
    FLAGS.model.image_channels = example_obs.shape[-1]
    FLAGS.model.image_size = example_obs.shape[1]
    vqvae_def = VQVAE(FLAGS.model, train=True)
    vqvae_params = vqvae_def.init({'params': param_key, 'noise': param_key}, example_obs)['params']
    tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    vqvae_ts = TrainState.create(vqvae_def, vqvae_params, tx=tx)
    vqvae_def_eps = VQVAE(FLAGS.model, train=False)
    vqvae_eps_ts = TrainState.create(vqvae_def_eps, vqvae_params)
    print("Total num of VQVAE parameters:", sum(x.size for x in jax.tree_util.tree_leaves(vqvae_params)))

    discriminator_def = Discriminator(FLAGS.model)
    discriminator_params = discriminator_def.init(param_key, example_obs)['params']
    tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    discriminator_ts = TrainState.create(discriminator_def, discriminator_params, tx=tx)
    print("Total num of Discriminator parameters:", sum(x.size for x in jax.tree_util.tree_leaves(discriminator_params)))

    model = VQGANModel(rng=rng, vqvae=vqvae_ts, vqvae_eps=vqvae_eps_ts, discriminator=discriminator_ts, config=FLAGS.model)

    if FLAGS.load_dir is not None:
        # Best-effort restore: fall back to the fresh initialisation, but say why.
        try:
            cp = Checkpoint(FLAGS.load_dir)
            model = cp.load_model(model)
            print("Loaded model with step", model.vqvae.step)
        except Exception as err:
            print(f"Could not load checkpoint from {FLAGS.load_dir!r}: {err!r}")
            print("Random init")
    else:
        print("Random init")

    model = flax.jax_utils.replicate(model, devices=local_devices)
    jax.debug.visualize_array_sharding(model.vqvae.params['decoder']['Conv_0']['bias'])

    def log_comparison(tag, originals, reconstructions, step):
        """Log a 2-row figure: first image of each device, original above reconstruction."""
        fig, axs = plt.subplots(2, 8, figsize=(30, 15))
        # One column per local device, capped at the width of the grid.
        for j in range(min(axs.shape[1], originals.shape[0])):
            axs[0, j].imshow(originals[j, 0], vmin=0, vmax=1)
            axs[1, j].imshow(reconstructions[j, 0], vmin=0, vmax=1)
        wandb.log({tag: wandb.Image(fig)}, step=step)
        plt.close(fig)

    ###################################
    # Train Loop
    ###################################

    best_fid = 100000
    num_fid_batches = 780  # Roughly 40k-50k images at the default batch size.

    for i in tqdm.tqdm(range(1, FLAGS.max_steps + 1),
                       smoothing=0.1,
                       dynamic_ncols=True):

        batch_images = shard(next(dataset))

        model, update_info = model.update(batch_images)

        if i % FLAGS.log_interval == 0:
            update_info = jax.tree.map(lambda x: x.mean(), update_info)
            train_metrics = {f'training/{k}': v for k, v in update_info.items()}
            if is_main_process:
                wandb.log(train_metrics, step=i)

        if i % FLAGS.eval_interval == 0:
            # Reconstructions of the current training batch and of one validation batch.
            reconstructed_images = model.reconstruction(batch_images)
            valid_images = shard(next(dataset_valid))
            valid_reconstructed_images = model.reconstruction(valid_images)

            if is_main_process:
                wandb.log({
                    'batch_image_mean': batch_images.mean(),
                    'reconstructed_images_mean': reconstructed_images.mean(),
                    'batch_image_std': batch_images.std(),
                    'reconstructed_images_std': reconstructed_images.std(),
                }, step=i)

                log_comparison('reconstruction', batch_images, reconstructed_images, i)
                log_comparison('reconstruction_valid', valid_images, valid_reconstructed_images, i)

            # Validation losses: run an update but discard the resulting model.
            _, valid_update_info = model.update(valid_images)
            valid_update_info = jax.tree.map(lambda x: x.mean(), valid_update_info)
            valid_metrics = {f'validation/{k}': v for k, v in valid_update_info.items()}
            if is_main_process:
                wandb.log(valid_metrics, step=i)

            # FID measurement on reconstructed validation images.
            activations = []
            for _ in range(num_fid_batches):
                valid_images = shard(next(dataset_valid))
                valid_reconstructed_images = model.reconstruction(valid_images)

                valid_reconstructed_images = jax.image.resize(
                    valid_reconstructed_images,
                    (valid_images.shape[0], valid_images.shape[1], 299, 299, 3),
                    method='bilinear', antialias=False)
                # Rescale from [0, 1] to [-1, 1] before the FID network.
                valid_reconstructed_images = 2 * valid_reconstructed_images - 1
                activations += [np.array(get_fid_activations(valid_reconstructed_images))[..., 0, 0, :]]

            # TODO: use all_gather to get activations from all devices.
            activations = np.concatenate(activations, axis=0)
            activations = activations.reshape((-1, activations.shape[-1]))

            print("doing this much FID", activations.shape)
            mu1 = np.mean(activations, axis=0)
            sigma1 = np.cov(activations, rowvar=False)
            fid = fid_from_stats(mu1, sigma1, truth_fid_stats['mu'], truth_fid_stats['sigma'])

            if is_main_process:
                wandb.log({'validation/fid': fid}, step=i)
                print("validation FID at step", i, fid)
                # Keep a separate checkpoint of the best model seen so far.
                if fid < best_fid:
                    best_fid = fid
                    if FLAGS.save_dir is not None:
                        model_single = flax.jax_utils.unreplicate(model)
                        cp = Checkpoint(FLAGS.save_dir + "best.tmp")
                        cp.set_model(model_single)
                        cp.save()

        if (i % FLAGS.save_interval == 0) and (FLAGS.save_dir is not None):
            if is_main_process:
                model_single = flax.jax_utils.unreplicate(model)
                cp = Checkpoint(FLAGS.save_dir)
                cp.set_model(model_single)
                cp.save()
690
-
691
if __name__ == '__main__':
    # Script entry point: hand control to absl, which invokes main().
    app.run(main)