# Optional local debugging hook: enabled only when the `localutils` helper
# package is importable; otherwise the script continues without it.
try:
    from localutils.debugger import enable_debug
    enable_debug()
except ImportError:
    pass


import flax.linen as nn
import jax.numpy as jnp
from absl import app, flags
from functools import partial
import numpy as np
import tqdm
import jax
import flax
import optax
import wandb
from ml_collections import config_flags
import ml_collections
import tensorflow_datasets as tfds
import tensorflow as tf
# Hide GPU and TPU devices from TensorFlow (it is only used for the input
# pipeline here) -- presumably so it does not claim accelerator memory that
# JAX needs. These calls stay directly after the TensorFlow import.
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")
import matplotlib.pyplot as plt
from typing import Any
import os


from utils.wandb import setup_wandb, default_wandb_config
from utils.train_state import TrainState, target_update
from utils.checkpoint import Checkpoint
from utils.pretrained_resnet import get_pretrained_embs, get_pretrained_model
from utils.fid import get_fid_network, fid_from_stats
from models.vqvae import VQVAE
from models.discriminator import Discriminator
|
|
# Command-line flags: dataset selection, checkpoint locations, and the
# logging / evaluation / checkpoint cadence (all intervals are in train steps).
FLAGS = flags.FLAGS
flags.DEFINE_string('dataset_name', 'imagenet256', 'Dataset name.')
flags.DEFINE_string('save_dir', "/home/lambda/jax-vqvae-vqgan/chkpts/checkpoint", 'Save dir (if not None, save params).')
flags.DEFINE_string('load_dir', "/home/lambda/jax-vqvae-vqgan/chkpts/checkpoint.tmp", 'Load dir (if not None, load params from here).')
flags.DEFINE_integer('seed', 0, 'Random seed.')
flags.DEFINE_integer('log_interval', 1000, 'Logging interval.')
flags.DEFINE_integer('eval_interval', 1000, 'Eval interval.')
flags.DEFINE_integer('save_interval', 1000, 'Save interval.')
flags.DEFINE_integer('batch_size', 64, 'Total Batch size.')
flags.DEFINE_integer('max_steps', 1_000_000, 'Number of training steps.')
|
|
# Default hyperparameters. The whole dict is exposed as the `--model` config
# flag and handed to the VQVAE / Discriminator constructors and to VQGANModel.
# NOTE(review): 'lr_warmup_steps', 'lr_decay_steps', 'weight_decay' and
# 'clip_gradient' are not read anywhere in this file (the optimizers are plain
# Adam); they may be consumed by the model modules -- confirm.
model_config = ml_collections.ConfigDict({
    # Optimizer and autoencoder architecture.
    'lr': 0.0001,
    'beta1': 0.0,
    'beta2': 0.99,
    'lr_warmup_steps': 2000,
    'lr_decay_steps': 500_000,
    'filters': 128,
    'num_res_blocks': 2,
    'channel_multipliers': (1, 1, 2, 2, 4, 4),
    'embedding_dim': 32,
    'norm_type': 'GN',
    'weight_decay': 0.05,
    'clip_gradient': 1.0,
    'l2_loss_weight': 1.0,
    # The slow copy of the VQ-VAE is updated with rate 1 - eps_update_rate.
    'eps_update_rate': 0.9999,
    # Quantizer selection. This file also special-cases 'fsq', 'kl', 'kl_two'.
    'quantizer_type': 'vq',
    # Vector-quantizer settings.
    'quantizer_loss_ratio': 1,
    'codebook_size': 1024,
    'entropy_loss_ratio': 0.1,
    'entropy_loss_type': 'softmax',
    'entropy_temperature': 0.01,
    'commitment_cost': 0.25,
    # FSQ settings.
    'fsq_levels': 5,
    # Weight on the quantizer loss when quantizer_type is 'kl' or 'kl_two'.
    'kl_weight': 0.000001,
    # Adversarial / perceptual loss settings.
    'g_adversarial_loss_weight': 0.5,
    # Scales the gradient penalty added to the discriminator loss.
    'g_grad_penalty_cost': 10,
    'perceptual_loss_weight': 0.5,
    # Adversarial terms and discriminator gradients are zeroed before this step.
    'gan_warmup_steps': 25000,
})
|
|
# Weights & Biases settings: start from the project defaults and override the
# project and run name. The '{dataset_name}' placeholder is left unformatted
# here -- presumably filled in by `setup_wandb`; verify against that helper.
wandb_config = default_wandb_config()
wandb_config.update({
    'project': 'vqvae',
    'name': 'vqvae_{dataset_name}',
})


# Register both config dicts as unlocked `--wandb.*` / `--model.*` flags, so
# new keys can still be added at runtime (main() adds image_channels/image_size).
config_flags.DEFINE_config_dict('wandb', wandb_config, lock_config=False)
config_flags.DEFINE_config_dict('model', model_config, lock_config=False)
|
|
| |
| |
| |
|
|
@jax.vmap
def sigmoid_cross_entropy_with_logits(*, labels: jnp.ndarray, logits: jnp.ndarray) -> jnp.ndarray:
    """Element-wise sigmoid cross-entropy computed directly from logits.

    Evaluates ``max(x, 0) - x * z + log(1 + exp(-|x|))`` with ``x = logits``
    and ``z = labels``; the exponent is never positive, so ``exp`` cannot
    overflow for large-magnitude logits.

    Because of the ``jax.vmap`` decorator both keyword arguments are mapped
    over their leading axis, so inputs need at least one dimension.

    Adapted from
    https://github.com/google-research/maskgit/blob/main/maskgit/libml/losses.py

    Args:
        labels: Target values, same shape as ``logits``.
        logits: Unnormalized scores.

    Returns:
        Per-element loss with the same shape as ``logits``.
    """
    relu_logits = jnp.maximum(logits, 0)
    neg_abs_logits = -jnp.abs(logits)
    return relu_logits - logits * labels + jnp.log1p(jnp.exp(neg_abs_logits))
|
|
class VQGANModel(flax.struct.PyTreeNode):
    """Train states for the VQ-VAE, its slow-moving copy and the discriminator.

    Both methods are wrapped in ``jax.pmap`` over ``(self, images)``, so the
    model must be replicated across local devices and ``images`` must carry a
    leading device axis.
    """
    # PRNG key; split on every `update` call to produce the 'noise' rng.
    rng: Any
    # Hyperparameter mapping; kept out of the pytree (pytree_node=False).
    config: dict = flax.struct.field(pytree_node=False)
    # VQ-VAE being trained.
    vqvae: TrainState
    # Copy of the VQ-VAE refreshed via `target_update` after each step and used
    # for non-sampling reconstruction (presumably an EMA -- see `target_update`).
    vqvae_eps: TrainState
    # Discriminator being trained.
    discriminator: TrainState

    @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
    def update(self, images, pmap_axis='data'):
        """Run one optimization step for the VQ-VAE and the discriminator.

        Args:
            images: Per-device image batch.
            pmap_axis: Name of the pmap axis used to average gradients/metrics.

        Returns:
            ``(new_model, info)`` where ``info`` is a dict of scalar metrics.
        """
        new_rng, curr_key = jax.random.split(self.rng, 2)

        # Network used for the perceptual loss (runs while the step is traced).
        resnet, resnet_params = get_pretrained_model('resnet50', 'data/resnet_pretrained.npy')

        # 0.0 while the VQ-VAE step counter is below gan_warmup_steps, else 1.0.
        is_gan_training = 1.0 - (self.vqvae.step < self.config['gan_warmup_steps']).astype(jnp.float32)

        def loss_fn(params_vqvae, params_disc):
            """Return ((vae loss, discriminator loss), metrics dict)."""
            reconstructed_images, result_dict = self.vqvae(images, params=params_vqvae, rngs={'noise': curr_key})
            print("Reconstructed images shape", reconstructed_images.shape)
            print("Input images shape", images.shape)
            assert reconstructed_images.shape == images.shape

            # Discriminator loss, with a penalty on the squared norm of the
            # gradient of the real logits with respect to the input images.
            discriminator_fn = lambda x: self.discriminator(x, params=params_disc)
            real_logit, vjp_fn = jax.vjp(discriminator_fn, images, has_aux=False)
            gradient = vjp_fn(jnp.ones_like(real_logit))[0]
            gradient = gradient.reshape((images.shape[0], -1))
            gradient = jnp.asarray(gradient, jnp.float32)
            penalty = jnp.sum(jnp.square(gradient), axis=-1)
            penalty = jnp.mean(penalty)
            fake_logit = discriminator_fn(reconstructed_images)
            d_loss_real = sigmoid_cross_entropy_with_logits(labels=jnp.ones_like(real_logit), logits=real_logit).mean()
            d_loss_fake = sigmoid_cross_entropy_with_logits(labels=jnp.zeros_like(fake_logit), logits=fake_logit).mean()
            loss_d = d_loss_real + d_loss_fake + (penalty * self.config['g_grad_penalty_cost'])

            # Adversarial term for the autoencoder; zero during GAN warmup.
            d_loss_for_vae = sigmoid_cross_entropy_with_logits(labels=jnp.ones_like(fake_logit), logits=fake_logit).mean()
            d_loss_for_vae = d_loss_for_vae * is_gan_training

            # Perceptual loss: MSE between pretrained-network embeddings.
            real_pools, _ = get_pretrained_embs(resnet_params, resnet, images=images)
            fake_pools, _ = get_pretrained_embs(resnet_params, resnet, images=reconstructed_images)
            perceptual_loss = jnp.mean((real_pools - fake_pools) ** 2)

            l2_loss = jnp.mean((reconstructed_images - images) ** 2)
            quantizer_loss = result_dict['quantizer_loss'] if 'quantizer_loss' in result_dict else 0.0
            if self.config['quantizer_type'] in ('kl', 'kl_two'):
                quantizer_loss = quantizer_loss * self.config['kl_weight']
            # Weights come from self.config, like every other hyperparameter in
            # this class, rather than from the global FLAGS object.
            loss_vae = (l2_loss * self.config['l2_loss_weight']) \
                + (quantizer_loss * self.config['quantizer_loss_ratio']) \
                + (d_loss_for_vae * self.config['g_adversarial_loss_weight']) \
                + (perceptual_loss * self.config['perceptual_loss_weight'])
            codebook_usage = result_dict['usage'] if 'usage' in result_dict else 0.0
            return (loss_vae, loss_d), {
                'loss_vae': loss_vae,
                'loss_d': loss_d,
                'l2_loss': l2_loss,
                'd_loss_for_vae': d_loss_for_vae,
                'perceptual_loss': perceptual_loss,
                'quantizer_loss': quantizer_loss,
                'codebook_usage': codebook_usage,
            }

        # One forward pass, two backward passes: the (1, 0) cotangent selects
        # d(loss_vae)/d(vqvae params); (0, 1) selects d(loss_d)/d(disc params).
        _, grad_fn, info = jax.vjp(loss_fn, self.vqvae.params, self.discriminator.params, has_aux=True)
        vae_grads, _ = grad_fn((1., 0.))
        _, d_grads = grad_fn((0., 1.))

        vae_grads = jax.lax.pmean(vae_grads, axis_name=pmap_axis)
        d_grads = jax.lax.pmean(d_grads, axis_name=pmap_axis)
        # Discriminator gradients are zeroed during GAN warmup.
        d_grads = jax.tree_util.tree_map(lambda x: x * is_gan_training, d_grads)

        info = jax.lax.pmean(info, axis_name=pmap_axis)
        if self.config['quantizer_type'] == 'fsq':
            # Convert the usage vector into the fraction of non-zero entries.
            info['codebook_usage'] = jnp.sum(info['codebook_usage'] > 0) / info['codebook_usage'].shape[-1]

        vae_updates, vae_opt_state = self.vqvae.tx.update(vae_grads, self.vqvae.opt_state, self.vqvae.params)
        vae_params = optax.apply_updates(self.vqvae.params, vae_updates)
        new_vqvae = self.vqvae.replace(step=self.vqvae.step + 1, params=vae_params, opt_state=vae_opt_state)

        d_updates, d_opt_state = self.discriminator.tx.update(d_grads, self.discriminator.opt_state, self.discriminator.params)
        d_params = optax.apply_updates(self.discriminator.params, d_updates)
        new_discriminator = self.discriminator.replace(step=self.discriminator.step + 1, params=d_params, opt_state=d_opt_state)

        info['grad_norm_vae'] = optax.global_norm(vae_grads)
        info['grad_norm_d'] = optax.global_norm(d_grads)
        # NOTE: these two norms have always been computed from the
        # discriminator step (the values bound last), not from the VQ-VAE.
        info['update_norm'] = optax.global_norm(d_updates)
        info['param_norm'] = optax.global_norm(d_params)
        info['is_gan_training'] = is_gan_training

        new_vqvae_eps = target_update(new_vqvae, self.vqvae_eps, 1 - self.config['eps_update_rate'])

        new_model = self.replace(rng=new_rng, vqvae=new_vqvae, vqvae_eps=new_vqvae_eps, discriminator=new_discriminator)
        return new_model, info

    @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
    def reconstruction(self, images, pmap_axis='data', sampling=False):
        """Reconstruct ``images`` and clip the result to ``[0, 1]``.

        Args:
            images: Per-device image batch.
            pmap_axis: Unused; kept for signature parity with ``update``.
            sampling: If False (default) use ``vqvae_eps`` without an rng;
                otherwise use ``vqvae`` with a 'noise' rng split from ``rng``.

        Returns:
            The clipped reconstructed images.
        """
        if not sampling:
            reconstructed_images, _ = self.vqvae_eps(images)
        else:
            _, curr_key = jax.random.split(self.rng, 2)
            reconstructed_images, _ = self.vqvae(images, rngs={'noise': curr_key})

        return jnp.clip(reconstructed_images, 0, 1)
|
|
| |
| |
| |
def main(_):
    """Train the VQGAN: build data and models, then run the train/eval loop.

    Every `log_interval` steps training metrics are logged; every
    `eval_interval` steps reconstructions, validation metrics and FID are
    logged and the best-FID checkpoint is saved; every `save_interval` steps a
    regular checkpoint is saved. Logging and saving happen on process 0 only.

    Raises:
        ValueError: If the dataset name is not supported or the reference FID
            statistics file is missing.
    """
    np.random.seed(FLAGS.seed)
    print("Using devices", jax.local_devices())
    device_count = len(jax.local_devices())
    global_device_count = jax.device_count()
    # FLAGS.batch_size is the global batch; divide it by the number of hosts.
    local_batch_size = FLAGS.batch_size // (global_device_count // device_count)
    print("Device count", device_count)
    print("Global device count", global_device_count)
    print("Global Batch: ", FLAGS.batch_size)
    print("Node Batch: ", local_batch_size)
    print("Device Batch:", local_batch_size // device_count)

    if jax.process_index() == 0:
        setup_wandb(FLAGS.model.to_dict(), **FLAGS.wandb)

    def get_dataset(is_train):
        """Return an endless iterator of float image batches in [0, 1]."""
        if 'imagenet' not in FLAGS.dataset_name:
            raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")

        def deserialization_fn(data):
            image = data['image']
            # Center-crop to a square before resizing.
            min_side = tf.minimum(tf.shape(image)[0], tf.shape(image)[1])
            image = tf.image.resize_with_crop_or_pad(image, min_side, min_side)
            if 'imagenet256' in FLAGS.dataset_name:
                image = tf.image.resize(image, (256, 256))
            elif 'imagenet128' in FLAGS.dataset_name:
                image = tf.image.resize(image, (128, 128))
            else:
                raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
            if is_train:
                image = tf.image.random_flip_left_right(image)
            image = tf.cast(image, tf.float32) / 255.0
            return image

        split = tfds.split_for_jax_process('train' if is_train else 'validation', drop_remainder=True)
        print(split)
        dataset = tfds.load('imagenet2012', split=split, data_dir="/dev/shm")
        dataset = dataset.map(deserialization_fn, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.shuffle(10000, seed=42, reshuffle_each_iteration=True)
        dataset = dataset.repeat()
        dataset = dataset.batch(local_batch_size)
        dataset = dataset.prefetch(tf.data.AUTOTUNE)
        dataset = tfds.as_numpy(dataset)
        return iter(dataset)

    def shard(batch):
        """Reshape [local_batch, ...] to [local_devices, per_device, ...]."""
        return batch.reshape((device_count, -1, *batch.shape[1:]))

    def log_comparison(key, originals, reconstructions, step):
        """Log a figure: originals on the top row, reconstructions below."""
        fig, axs = plt.subplots(2, 8, figsize=(30, 15))
        # One image per device, at most four; never index past the device axis.
        for j in range(min(4, originals.shape[0])):
            axs[0, j].imshow(originals[j, 0], vmin=0, vmax=1)
            axs[1, j].imshow(reconstructions[j, 0], vmin=0, vmax=1)
        wandb.log({key: wandb.Image(fig)}, step=step)
        plt.close(fig)

    dataset = get_dataset(is_train=True)
    dataset_valid = get_dataset(is_train=False)
    example_obs = next(dataset)[:1]

    get_fid_activations = get_fid_network()
    # Check the file that is actually loaded (the check used to test another path).
    fid_stats_path = "./base_stats.npz"
    if not os.path.exists(fid_stats_path):
        raise ValueError(
            f"FID stats file not found at {fid_stats_path}. "
            "Please download the FID stats file! See the README.")
    truth_fid_stats = np.load(fid_stats_path)

    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, param_key = jax.random.split(rng)
    # memory_stats() may be unavailable on some backends; skip the print then.
    memory_stats = jax.local_devices()[0].memory_stats()
    if memory_stats and 'bytes_limit' in memory_stats:
        print("Total Memory on device:", float(memory_stats['bytes_limit']) / 1024**3, "GB")

    # Build the models. The image shape is taken from one example batch.
    FLAGS.model.image_channels = example_obs.shape[-1]
    FLAGS.model.image_size = example_obs.shape[1]
    vqvae_def = VQVAE(FLAGS.model, train=True)
    vqvae_params = vqvae_def.init({'params': param_key, 'noise': param_key}, example_obs)['params']
    tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    vqvae_ts = TrainState.create(vqvae_def, vqvae_params, tx=tx)
    vqvae_def_eps = VQVAE(FLAGS.model, train=False)
    vqvae_eps_ts = TrainState.create(vqvae_def_eps, vqvae_params)
    print("Total num of VQVAE parameters:", sum(x.size for x in jax.tree_util.tree_leaves(vqvae_params)))

    discriminator_def = Discriminator(FLAGS.model)
    discriminator_params = discriminator_def.init(param_key, example_obs)['params']
    tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    discriminator_ts = TrainState.create(discriminator_def, discriminator_params, tx=tx)
    print("Total num of Discriminator parameters:", sum(x.size for x in jax.tree_util.tree_leaves(discriminator_params)))

    model = VQGANModel(rng=rng, vqvae=vqvae_ts, vqvae_eps=vqvae_eps_ts, discriminator=discriminator_ts, config=FLAGS.model)

    # Best-effort restore: any loading error falls back to the fresh model,
    # but the reason is reported and KeyboardInterrupt is no longer swallowed.
    if FLAGS.load_dir is not None:
        try:
            cp = Checkpoint(FLAGS.load_dir)
            model = cp.load_model(model)
            print("Loaded model with step", model.vqvae.step)
        except Exception as err:
            print("Random init")
            print(f"Could not load checkpoint from {FLAGS.load_dir}: {err!r}")
    else:
        print("Random init")

    model = flax.jax_utils.replicate(model, devices=jax.local_devices())
    # Debug aid; assumes the VQ-VAE params contain decoder/Conv_0/bias.
    jax.debug.visualize_array_sharding(model.vqvae.params['decoder']['Conv_0']['bias'])

    best_fid = 100000
    num_fid_batches = 780

    for i in tqdm.tqdm(range(1, FLAGS.max_steps + 1),
                       smoothing=0.1,
                       dynamic_ncols=True):

        batch_images = shard(next(dataset))
        model, update_info = model.update(batch_images)

        if i % FLAGS.log_interval == 0:
            update_info = jax.tree_util.tree_map(lambda x: x.mean(), update_info)
            train_metrics = {f'training/{k}': v for k, v in update_info.items()}
            if jax.process_index() == 0:
                wandb.log(train_metrics, step=i)

        if i % FLAGS.eval_interval == 0:
            # Reconstructions of the current train batch and one validation batch.
            reconstructed_images = model.reconstruction(batch_images)
            valid_images = shard(next(dataset_valid))
            valid_reconstructed_images = model.reconstruction(valid_images)

            if jax.process_index() == 0:
                wandb.log({
                    'batch_image_mean': batch_images.mean(),
                    'reconstructed_images_mean': reconstructed_images.mean(),
                    'batch_image_std': batch_images.std(),
                    'reconstructed_images_std': reconstructed_images.std(),
                }, step=i)
                log_comparison('reconstruction', batch_images, reconstructed_images, i)
                log_comparison('reconstruction_valid', valid_images, valid_reconstructed_images, i)

            # Validation metrics: run the update step on validation data but
            # discard the returned model so no parameters change.
            _, valid_update_info = model.update(valid_images)
            valid_update_info = jax.tree_util.tree_map(lambda x: x.mean(), valid_update_info)
            valid_metrics = {f'validation/{k}': v for k, v in valid_update_info.items()}
            if jax.process_index() == 0:
                wandb.log(valid_metrics, step=i)

            # FID: collect FID-network activations of reconstructed validation images.
            activations = []
            for _ in range(num_fid_batches):
                valid_images = shard(next(dataset_valid))
                valid_reconstructed_images = model.reconstruction(valid_images)
                # Resize to 299x299 and rescale from [0, 1] to [-1, 1].
                valid_reconstructed_images = jax.image.resize(
                    valid_reconstructed_images,
                    (valid_images.shape[0], valid_images.shape[1], 299, 299, 3),
                    method='bilinear', antialias=False)
                valid_reconstructed_images = 2 * valid_reconstructed_images - 1
                activations.append(np.array(get_fid_activations(valid_reconstructed_images))[..., 0, 0, :])

            activations = np.concatenate(activations, axis=0)
            activations = activations.reshape((-1, activations.shape[-1]))

            print("doing this much FID", activations.shape)
            mu1 = np.mean(activations, axis=0)
            sigma1 = np.cov(activations, rowvar=False)
            fid = fid_from_stats(mu1, sigma1, truth_fid_stats['mu'], truth_fid_stats['sigma'])

            if jax.process_index() == 0:
                wandb.log({'validation/fid': fid}, step=i)
                print("validation FID at step", i, fid)

            if fid < best_fid:
                best_fid = fid
                # Same gating as the periodic checkpoint below: needs a save
                # directory and is written by process 0 only.
                if (FLAGS.save_dir is not None) and jax.process_index() == 0:
                    model_single = flax.jax_utils.unreplicate(model)
                    cp = Checkpoint(FLAGS.save_dir + "best.tmp")
                    cp.set_model(model_single)
                    cp.save()

        if (i % FLAGS.save_interval == 0) and (FLAGS.save_dir is not None):
            if jax.process_index() == 0:
                model_single = flax.jax_utils.unreplicate(model)
                cp = Checkpoint(FLAGS.save_dir)
                cp.set_model(model_single)
                cp.save()
|
|
# Script entry point: absl parses the command-line flags, then calls main().
if __name__ == '__main__':
    app.run(main)
|
|