Upload folder using huggingface_hub
Browse files- f16c16/all_stats.py +429 -0
- f16c16/decode_only.py +108 -0
- f16c16/encode_latents.py +338 -0
- f16c16/eval_fid.py +214 -0
- f16c16/evaluator.py +654 -0
- f16c16/graph-data.py +169 -0
- f16c16/kl_test.py +31 -0
- f16c16/latent_distances.py +293 -0
- f16c16/make_samples.py +205 -0
- f16c16/models/__pycache__/discriminator.cpython-310.pyc +0 -0
- f16c16/models/__pycache__/discriminator.cpython-312.pyc +0 -0
- f16c16/models/__pycache__/vqvae.cpython-310.pyc +0 -0
- f16c16/models/__pycache__/vqvae.cpython-312.pyc +0 -0
- f16c16/models/back_model.py +343 -0
- f16c16/models/discriminator.py +123 -0
- f16c16/models/vqvae.py +527 -0
- f16c16/ppl_images.py +255 -0
- f16c16/ppl_latents.py +307 -0
- f16c16/ppl_latents2.py +283 -0
- f16c16/stats.py +362 -0
- f16c16/train.py +676 -0
f16c16/all_stats.py
ADDED
|
@@ -0,0 +1,429 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
try:  # For debugging
    # Optional developer-machine hook; localutils is not a project dependency.
    from localutils.debugger import enable_debug
    enable_debug()
except ImportError:
    # Run normally when the local debugging helper is absent.
    pass
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
#import jax
|
| 9 |
+
#jax.config.update('jax_platform_name', 'cpu')
|
| 10 |
+
import os
|
| 11 |
+
#Apparently we've always been running this code on cpu.
|
| 12 |
+
|
| 13 |
+
# os.environ["JAX_PLATFORMS"] = 'cpu'
|
| 14 |
+
|
| 15 |
+
import jax
|
| 16 |
+
import lpips
|
| 17 |
+
|
| 18 |
+
# LPIPS perceptual metric (AlexNet backbone), created once at import time and
# moved to the default CUDA device; used by operations() below.
loss_fn_alex = lpips.LPIPS(net='alex') # best forward scores
loss_fn_alex = loss_fn_alex.cuda()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
from dadapy.data import Data
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
import flax.linen as nn
|
| 26 |
+
import jax.numpy as jnp
|
| 27 |
+
from absl import app, flags
|
| 28 |
+
from functools import partial
|
| 29 |
+
import numpy as np
|
| 30 |
+
import tqdm
|
| 31 |
+
import flax
|
| 32 |
+
import optax
|
| 33 |
+
import wandb
|
| 34 |
+
from ml_collections import config_flags
|
| 35 |
+
#import elements
|
| 36 |
+
import ml_collections
|
| 37 |
+
import tensorflow_datasets as tfds
|
| 38 |
+
import tensorflow as tf
|
| 39 |
+
tf.config.set_visible_devices([], "GPU")
|
| 40 |
+
tf.config.set_visible_devices([], "TPU")
|
| 41 |
+
import matplotlib.pyplot as plt
|
| 42 |
+
from typing import Any
|
| 43 |
+
|
| 44 |
+
from utils.train_state import TrainState, target_update
|
| 45 |
+
from utils.checkpoint import Checkpoint
|
| 46 |
+
from utils.fid import get_fid_network, fid_from_stats
|
| 47 |
+
|
| 48 |
+
from train import VQGANModel
|
| 49 |
+
from models.vqvae import VQVAE
|
| 50 |
+
from models.discriminator import Discriminator
|
| 51 |
+
|
| 52 |
+
from PIL import Image
|
| 53 |
+
import torch
|
| 54 |
+
|
| 55 |
+
# train.py (imported below for VQGANModel) already defines these flags, so
# remove them before re-defining with evaluation-specific defaults.
delattr(flags.FLAGS, 'dataset_name')
delattr(flags.FLAGS, 'load_dir')
delattr(flags.FLAGS, 'batch_size')

FLAGS = flags.FLAGS
flags.DEFINE_string('dataset_name', 'imagenet256', 'Environment name.')
flags.DEFINE_string('load_dir', "/home/dkaplan/Downloads/Models/checkpoint(1).tmp", 'Load dir (if not None, load params from here).')


flags.DEFINE_integer('batch_size', 2, 'Total Batch size.')
# Flags are inherited from train.py, so pass your model parameters again here to evaluate.
|
| 66 |
+
|
| 67 |
+
import gc
|
| 68 |
+
|
| 69 |
+
from scipy.spatial.distance import cdist
|
| 70 |
+
#
|
| 71 |
+
def relative(images, latents):
    """Measure how well the latent map preserves relative pairwise distances.

    For every distinct pair (i, j) this computes the ratio
    C_ij = d(image_i, image_j) / d(latent_i, latent_j) under the Euclidean
    metric and reports the mean and standard deviation of C. A near-constant
    C (small std) means the encoder acts like a similarity transform on
    pairwise structure ("relativemetry" in the notes below).

    Args:
        images: array of shape (N, ...); flattened to (N, D) internally.
        latents: array of shape (N, ...); flattened to (N, D') internally.

    Returns:
        (mean_c, std_c): mean and standard deviation of the distance ratios
        over all off-diagonal pairs with nonzero latent distance.
    """
    images = images.reshape(images.shape[0], -1)
    latents = latents.reshape(latents.shape[0], -1)

    image_distances = cdist(images, images, metric='euclidean')
    latent_distances = cdist(latents, latents, metric='euclidean')

    # BUG FIX: the previous elementwise division of the full matrices divided
    # 0/0 on the diagonal (self-distances), producing NaNs that made
    # np.mean(c) and np.std(c) NaN. Use each unordered pair exactly once
    # (strict upper triangle) and skip pairs with zero latent distance.
    upper = np.triu_indices(images.shape[0], k=1)
    numerators = image_distances[upper]
    denominators = latent_distances[upper]
    valid = denominators > 0  # guard against duplicate latents
    c = numerators[valid] / denominators[valid]

    mean_c = np.mean(c)
    std_c = np.std(c)
    print("mean C", mean_c)
    print("C std", std_c)
    return mean_c, std_c
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def operations(reconstructed_images, decoded, epsilon=1e-4):
    """Compute a PPL-style perceptual (LPIPS) distance between two batches.

    Args:
        reconstructed_images: JAX array of images in [0, 1]; shape is
            (1, B, H, W, C) with the leading singleton squeezed away —
            assumes NHWC layout, TODO confirm against caller.
        decoded: JAX array in [0, 1], same shape as `reconstructed_images`.
        epsilon: latent perturbation step size used to produce the two
            batches; the LPIPS distance is normalized by epsilon**2
            (perceptual-path-length convention). Defaults to 1e-4, the
            value that was previously hard-coded.

    Returns:
        Scalar torch tensor on CPU: mean LPIPS distance divided by epsilon**2.
    """
    # LPIPS expects inputs in [-1, 1]; the data arrives in [0, 1].
    reconstructed_images = reconstructed_images * 2 - 1
    decoded = decoded * 2 - 1

    # Hand the arrays from JAX to torch via DLPack (no host copy).
    reconstructed_images = jax.dlpack.to_dlpack(reconstructed_images)
    reconstructed_images = torch.utils.dlpack.from_dlpack(reconstructed_images)

    decoded = jax.dlpack.to_dlpack(decoded)
    decoded = torch.utils.dlpack.from_dlpack(decoded)

    # Drop the leading device axis: (1, B, H, W, C) -> (B, H, W, C).
    reconstructed_images = reconstructed_images.squeeze()
    decoded = decoded.squeeze()

    # NHWC -> NCHW, the layout the LPIPS network expects.
    reconstructed_images = reconstructed_images.permute(0, 3, 1, 2)
    decoded = decoded.permute(0, 3, 1, 2)

    # loss_fn_alex is the module-level LPIPS(AlexNet) model on CUDA.
    lpips_loss = loss_fn_alex(reconstructed_images, decoded)
    lpips_cpu = lpips_loss.detach().cpu().squeeze().mean()
    # PPL normalization: perceptual distance per unit latent step.
    lpips_cpu = lpips_cpu / (epsilon ** 2)

    return lpips_cpu
|
| 117 |
+
|
| 118 |
+
def main(_):
    """Evaluation driver: loads a trained VQGAN/VQVAE checkpoint and computes
    a battery of latent-space statistics over the ImageNet validation split —
    PPL-style LPIPS scores, per-channel latent mean/std, noisy-latent stats,
    and an intrinsic-dimension estimate (dadapy 2NN).  Results are printed;
    previously observed runs are recorded in the string literals at the end.
    """
    device_count = len(jax.local_devices())
    global_device_count = jax.device_count()
    # Per-process batch; assumes batch_size divides evenly across processes.
    local_batch_size = FLAGS.batch_size // (global_device_count // device_count)

    def get_dataset(is_train):
        # Build an iterator of float32 image batches in [0, 1].
        if 'imagenet' in FLAGS.dataset_name:
            def deserialization_fn(data):
                image = data['image']
                # Center-crop to a square, then resize to the target size.
                min_side = tf.minimum(tf.shape(image)[0], tf.shape(image)[1])
                image = tf.image.resize_with_crop_or_pad(image, min_side, min_side)
                if 'imagenet256' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (256, 256))
                elif 'imagenet128' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (128, 128))
                else:
                    raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
                if is_train:
                    image = tf.image.random_flip_left_right(image)
                image = tf.cast(image, tf.float32) / 255.0
                return image

            split = tfds.split_for_jax_process('train' if is_train else 'validation', drop_remainder=True)
            dataset = tfds.load('imagenet2012', data_dir="/data/inet", split=split)
            dataset = dataset.map(deserialization_fn, num_parallel_calls=tf.data.AUTOTUNE)
            dataset = dataset.shuffle(10000, seed=42, reshuffle_each_iteration=True)
            dataset = dataset.batch(local_batch_size)
            dataset = dataset.prefetch(tf.data.AUTOTUNE)
            dataset = tfds.as_numpy(dataset)
            dataset = iter(dataset)
            return dataset
        else:
            raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")

    dataset = get_dataset(is_train=True)
    dataset_valid = get_dataset(is_train=False)

    # One training example, used only to infer image shape for model init.
    example_obs = next(dataset)[:1]


    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, param_key = jax.random.split(rng)
    print("Total devices", jax.local_devices()[0])


    ###################################
    # Creating Model and put on devices.
    ###################################
    FLAGS.model.image_channels = example_obs.shape[-1]
    FLAGS.model.image_size = example_obs.shape[1]
    vqvae_def = VQVAE(FLAGS.model, train=True)
    vqvae_params = vqvae_def.init({'params': param_key, 'noise': param_key}, example_obs)['params']
    # tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    vqvae_ts = TrainState.create(vqvae_def, vqvae_params)#, tx=tx) #Turning off tx because we don't need it...
    vqvae_def_eps = VQVAE(FLAGS.model, train=False)
    vqvae_eps_ts = TrainState.create(vqvae_def_eps, vqvae_params)
    print("Total num of VQVAE parameters:", sum(x.size for x in jax.tree_util.tree_leaves(vqvae_params)))

    discriminator_def = Discriminator(FLAGS.model)
    discriminator_params = discriminator_def.init(param_key, example_obs)['params']
    # tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    discriminator_ts = TrainState.create(discriminator_def, discriminator_params)#, tx=tx)#No tx again
    print("Total num of Discriminator parameters:", sum(x.size for x in jax.tree_util.tree_leaves(discriminator_params)))

    model = VQGANModel(rng=rng, vqvae=vqvae_ts, vqvae_eps=vqvae_eps_ts, discriminator=discriminator_ts, config=FLAGS.model)

    # Evaluation requires trained weights; refuse to run from scratch.
    assert FLAGS.load_dir is not None
    cp = Checkpoint(FLAGS.load_dir)
    model = cp.load_model(model)
    print("Loaded model with step", model.vqvae.step)

    model = flax.jax_utils.replicate(model, devices=jax.local_devices())
    jax.debug.visualize_array_sharding(model.vqvae.params['decoder']['Conv_0']['bias'])
    #print(model.vqvae)



    ####################################
    # Noise stuff
    ###################################

    cpus = jax.devices("cpu")

    i = 0
    lpips_list = []          # PPL (one-sided walk) LPIPS scores per batch
    lpips_list_ppl_two = []  # PPL (two-sided walk) LPIPS scores per batch
    means = []               # per-channel latent means per batch
    stds = []                # per-channel latent stds per batch

    noisy_means = []
    noisy_stds = []

    predicted_stds = []

    noisy_predicted_stds = []

    latent_list = []
    #TODO
    #equivariance loss, DCT stats, psnr, ssim
    #Instead of isometry, we want... RELATIVEMTRY
    #Gini coefficient
    #denstity cv
    #normalized entropy
    #"uniformity" - basically related to the covariance loss? How spread out the pionts are

    #relativemtry basically says:
    #Given the function F, that turn x into x'
    #For all possible x, y within X, |x - y| = C [x' - y'|
    #Is this a desirable property though?
    #Who cares, let's calculate it anyway

    #
    #Need to try out our own f16c16, which is the same compression as f8c4
    #We will try
    #1,1,2,2,4
    #1,2,2,4,4
    #1,2,4,8,8
    #1,2,4,4,4



    for valid_images in dataset_valid:


        valid_images = valid_images.reshape((len(jax.local_devices()), -1, *valid_images.shape[1:])) # [devices, batch//devices, etc..]
        #1, 2, 256, 256, 3


        #Regular PPL
        reconstructed_images, decoded, std, latents = model.reconstruction_ppl(valid_images) # [devices, 8, 256, 256, 3]
        #Leaves channel dim out
        # NOTE(review): the model-returned `std` is immediately overwritten
        # here, so `predicted_stds` below actually records the empirical
        # latent std, not the model's predicted noise std — confirm intent.
        mean = jnp.mean(latents, axis = [0,1,2,3])
        std = jnp.std(latents, axis = [0,1,2,3])

        #TODO maybe need to put this onto CPU
        latent_list.append(latents)



        means.append(mean)
        stds.append(std)

        predicted_stds.append(std)

        lpips_list.append(operations(reconstructed_images, decoded))


        #PPL two, walk both directions
        reconstructed_images, decoded, std, latents, decoded_2 = model.reconstruction_ppl_two(valid_images) # [devices, 8, 256, 256, 3]
        #For this one we don't care about reconstructed images, only decoded and decoded 2

        lpips_list_ppl_two.append(operations(decoded, decoded_2))




        #Ppl but images.
        reconstructed_images, decoded, std, latents, std_noisy, latents_noisy = model.reconstruction_ppl_image(valid_images) # [devices, 8, 256, 256, 3]
        noisy_means.append(latents_noisy.mean(axis = [0,1,2,3]))
        noisy_stds.append(latents_noisy.std(axis = [0,1,2,3]))
        noisy_predicted_stds.append(std_noisy)

        #TODO WHAT IS THE LOSS FUNCTION FOR THIS ONE
        #it's not quite perplexity, but there's two components
        #one is that we check lpips difference as a function of final image
        #The other is that we look at how far away the latents are, and see if that is consistent.

        i += 1
        # Evaluate at most 500 batches of the validation split.
        if i == 500:
            break

    #Should be just 4 here, so... 0?

    mean_lpips = jnp.mean(jnp.asarray(lpips_list))

    #So our lpips list or whatever is like. Maybe we want per channel?
    std_lpips = jnp.std(jnp.asarray(lpips_list))
    print("PPL Regular", mean_lpips)
    print("C std", std_lpips)

    #So here we have 500/50,000 x 4.
    #We can mean, get the mean per channel.
    #We can get the std per channel.

    print("mean of means", jnp.asarray(means).mean(axis = [0]))
    print("stds of means", jnp.asarray(means).std(axis = [0]))


    print("mean of stds", jnp.asarray(stds).mean(axis = [0]))
    print("std of stds", jnp.asarray(stds).std(axis = [0]))



    mean_lpips = jnp.mean(jnp.asarray(lpips_list_ppl_two))
    std_lpips = jnp.std(jnp.asarray(lpips_list_ppl_two))

    print("PPL Two", mean_lpips)
    print("C std Two", std_lpips)

    print("noisy mean of means", jnp.asarray(noisy_means).mean(axis = [0]))
    print("noisy stds of means", jnp.asarray(noisy_means).std(axis = [0]))
    print("noisy mean of stds", jnp.asarray(noisy_stds).mean(axis = [0]))
    print("noisy std of stds", jnp.asarray(noisy_stds).std(axis = [0]))

    print("Average noise added to image", jnp.asarray(predicted_stds).mean(axis = [0]))
    print("Average noise added to image std", jnp.asarray(predicted_stds).std(axis = [0]))

    print("Average noise added to noisy image", jnp.asarray(noisy_predicted_stds).mean(axis = [0, 1, 2, 3, 4]))
    print("Average noise added to noisy image std", jnp.asarray(noisy_predicted_stds).std(axis = [0, 1, 2, 3, 4]))

    print("Effective new variance (sqrt it)", jnp.asarray(noisy_predicted_stds).std(axis = [0,1,2,3,4]) ** 2 + jnp.asarray(stds).mean(axis = [0]) ** 2)


    #Intrinsic dimension estimate of the latent set (dadapy).
    latent_list = np.asarray(latent_list).squeeze()
    print(latent_list.shape)#Should be like, 500,2,32,32,4
    # NOTE(review): hard-coded 32x32x4 latent shape — only valid for this
    # particular model config; confirm before reusing.
    latent_list = latent_list.reshape(-1,32,32,4)
    latent_list = latent_list.reshape(latent_list.shape[0], -1)
    latent_list = Data(latent_list)
    latent_list.compute_distances(maxk=100)

    # compute the intrinsic dimension using 2nn estimator
    id, id_error, id_distance = latent_list.compute_id_2NN()
    print(id, id_error, id_distance)

    # Recorded results from previous runs follow (kept for reference).
    #None of these stats take anything else into account.
    #No normalization, nothing
    """PL 100
    PPL Regular 6.3766294
    C std 0.9229477
    mean of means 0.16227543
    stds of means 0.53616405
    mean of stds 4.4914503
    std of stds 0.6015057
    PPL Two 6.3642726
    C std Two 0.92391133
    """


    """1e-4
    PPL Regular 12.521122
    C std 2.3125298
    mean of means 0.0065882676
    stds of means 0.042861093
    mean of stds 0.7608507
    std of stds 0.05846726
    PPL Two 12.581134
    C std Two 2.5102239
    Average noise added to image 0.5992337
    Average noise added to image std 0.25218853
    """


    """1e-5
    PPL Regular 13.183324
    C std 2.9292953
    mean of means 0.0065166513
    stds of means 0.06983645
    mean of stds 0.9855982
    std of stds 0.05810356
    PPL Two 13.193566
    C std Two 2.9465785
    Average noise added to image 0.16906397
    Average noise added to image std 0.12756345
    """

    """1e-6
    PPL Regular 14.146276
    C std 3.6374733
    mean of means -0.018107202
    stds of means 0.11694455
    mean of stds 1.0860059
    std of stds 0.09732369
    PPL Two 14.116948
    C std Two 3.547216
    Average noise added to image 0.039256155
    Average noise added to image std 0.026851926
    """

    """AE
    PPL Regular 10.103417
    C std 2.2966182
    mean of means 0.35234922
    stds of means 0.4036692
    mean of stds 2.6363409
    std of stds 0.30666474
    PPL Two 10.075436
    C std Two 2.2949345
    No noise added to image
    """

    """Dino 1e-5
    PPL Regular 2.373527
    C std 0.45295972
    mean of means 2.5987418
    stds of means 3.097953
    mean of stds 49.437305
    std of stds 2.5111952
    PPL Two 2.3797483
    C std Two 0.49930122
    noisy mean of means 2.598704
    noisy stds of means 3.0979395
    noisy mean of stds 49.437298
    noisy std of stds 2.5112264

    """

    #58.344119061134336 0.0 57.78905382129868
| 428 |
+
# Entry point: absl parses the flags (including those inherited from
# train.py) and invokes main.
if __name__ == '__main__':
    app.run(main)
|
f16c16/decode_only.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
try:  # For debugging
    # Optional developer-machine hook; localutils is not a project dependency.
    from localutils.debugger import enable_debug
    enable_debug()
except ImportError:
    # Run normally when the local debugging helper is absent.
    pass
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
#import jax
|
| 9 |
+
#jax.config.update('jax_platform_name', 'cpu')
|
| 10 |
+
import os
|
| 11 |
+
import jax
|
| 12 |
+
|
| 13 |
+
import flax.linen as nn
|
| 14 |
+
import jax.numpy as jnp
|
| 15 |
+
from absl import app, flags
|
| 16 |
+
from functools import partial
|
| 17 |
+
import numpy as np
|
| 18 |
+
import tqdm
|
| 19 |
+
import flax
|
| 20 |
+
import optax
|
| 21 |
+
import wandb
|
| 22 |
+
from ml_collections import config_flags
|
| 23 |
+
#import elements
|
| 24 |
+
import ml_collections
|
| 25 |
+
import tensorflow_datasets as tfds
|
| 26 |
+
import tensorflow as tf
|
| 27 |
+
tf.config.set_visible_devices([], "GPU")
|
| 28 |
+
tf.config.set_visible_devices([], "TPU")
|
| 29 |
+
import matplotlib.pyplot as plt
|
| 30 |
+
from typing import Any
|
| 31 |
+
|
| 32 |
+
from utils.train_state import TrainState, target_update
|
| 33 |
+
from utils.checkpoint import Checkpoint
|
| 34 |
+
from utils.fid import get_fid_network, fid_from_stats
|
| 35 |
+
|
| 36 |
+
from train import VQGANModel
|
| 37 |
+
from models.vqvae import VQVAE
|
| 38 |
+
from models.discriminator import Discriminator
|
| 39 |
+
|
| 40 |
+
from PIL import Image
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# train.py (imported below for VQGANModel) already defines these flags, so
# remove them before re-defining with decode-specific defaults.
delattr(flags.FLAGS, 'dataset_name')
delattr(flags.FLAGS, 'load_dir')
delattr(flags.FLAGS, 'batch_size')

FLAGS = flags.FLAGS
flags.DEFINE_string('dataset_name', 'imagenet256', 'Environment name.')
flags.DEFINE_string('load_dir', "/home/dkaplan/Documents/LiClipse Workspace/VAE/jax-vqvae-vqgan/7e-5_sdlike_sym/checkpoint.tmp", 'Load dir (if not None, load params from here).')
flags.DEFINE_integer('batch_size', 2, 'Total Batch size.')
# Flags are inherited from train.py, so pass your model parameters again here to evaluate.
|
| 52 |
+
|
| 53 |
+
def main(_):
    """Build the VQGAN model, load checkpoint weights, and return the model
    replicated across local devices (decode-only helper: no dataset is
    loaded; parameter shapes come from the model config).

    Returns:
        The replicated VQGANModel with parameters restored from
        FLAGS.load_dir.
    """
    device_count = len(jax.local_devices())
    global_device_count = jax.device_count()
    # Kept for parity with train.py / all_stats.py, though unused here.
    local_batch_size = FLAGS.batch_size // (global_device_count // device_count)

    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, param_key = jax.random.split(rng)
    print("Total devices", jax.local_devices()[0])

    ###################################
    # Creating Model and put on devices.
    ###################################

    # BUG FIX: `example_obs` was referenced below but never defined in this
    # file, so main() raised NameError immediately.  Build a dummy
    # observation from the model config (falling back to 256x256 RGB, which
    # matches the 'imagenet256' default) just to initialize parameter shapes.
    image_size = int(getattr(FLAGS.model, 'image_size', 256))
    image_channels = int(getattr(FLAGS.model, 'image_channels', 3))
    example_obs = jnp.zeros((1, image_size, image_size, image_channels), dtype=jnp.float32)

    vqvae_def = VQVAE(FLAGS.model, train=True)
    vqvae_params = vqvae_def.init({'params': param_key, 'noise': param_key}, example_obs)['params']
    tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    vqvae_ts = TrainState.create(vqvae_def, vqvae_params, tx=tx)
    vqvae_def_eps = VQVAE(FLAGS.model, train=False)
    vqvae_eps_ts = TrainState.create(vqvae_def_eps, vqvae_params)
    print("Total num of VQVAE parameters:", sum(x.size for x in jax.tree_util.tree_leaves(vqvae_params)))

    discriminator_def = Discriminator(FLAGS.model)
    discriminator_params = discriminator_def.init(param_key, example_obs)['params']
    tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    discriminator_ts = TrainState.create(discriminator_def, discriminator_params, tx=tx)
    print("Total num of Discriminator parameters:", sum(x.size for x in jax.tree_util.tree_leaves(discriminator_params)))

    model = VQGANModel(rng=rng, vqvae=vqvae_ts, vqvae_eps=vqvae_eps_ts, discriminator=discriminator_ts, config=FLAGS.model)

    # Decoding requires trained weights; refuse to run from scratch.
    assert FLAGS.load_dir is not None
    cp = Checkpoint(FLAGS.load_dir)
    model = cp.load_model(model)
    print("Loaded model with step", model.vqvae.step)

    model = flax.jax_utils.replicate(model, devices=jax.local_devices())
    jax.debug.visualize_array_sharding(model.vqvae.params['decoder']['Conv_0']['bias'])

    return model
|
| 93 |
+
|
| 94 |
+
#Stuff and things.
|
| 95 |
+
# image2 = valid_reconstructed_images[0,0,:,:,:]
|
| 96 |
+
# image2 = (image2 * 255).astype(np.uint8)
|
| 97 |
+
# image2 = np.array(image2)
|
| 98 |
+
# image2 = Image.fromarray(image2)
|
| 99 |
+
# image2.save("recon" + str(i) + ".png")
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# images.append((valid_reconstructed_images*255).astype(np.uint8))
|
| 106 |
+
|
| 107 |
+
# Entry point: absl parses the flags (including those inherited from
# train.py) and invokes main.
if __name__ == '__main__':
    app.run(main)
|
f16c16/encode_latents.py
ADDED
|
@@ -0,0 +1,338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
try:  # For debugging
    # Optional developer-machine hook; localutils is not a project dependency.
    from localutils.debugger import enable_debug
    enable_debug()
except ImportError:
    # Run normally when the local debugging helper is absent.
    pass
|
| 6 |
+
|
| 7 |
+
#GPU, batch 16, latent:
|
| 8 |
+
"""[[[[-9.51360688e-02 -6.00612536e-02 -6.76547512e-02 -3.73330832e-01]
|
| 9 |
+
[-3.10049266e-01 -6.82027787e-02 1.09544434e-01 -1.51526511e-01]
|
| 10 |
+
[-1.63606599e-01 1.52324408e-01 1.03230253e-01 -3.34064662e-01]
|
| 11 |
+
...
|
| 12 |
+
[-9.08230543e-02 2.53294855e-01 6.09488077e-02 -3.55355501e-01]
|
| 13 |
+
[-2.16098756e-01 -3.44716787e-01 5.68981618e-02 -1.19108176e+00]
|
| 14 |
+
[ 9.24487635e-02 2.20324457e-01 1.84478119e-01 4.46850598e-01]]
|
| 15 |
+
|
| 16 |
+
[[-1.60119295e-01 2.00234763e-02 -1.43943653e-01 -2.22745568e-01]
|
| 17 |
+
[-2.55345762e-01 1.55626327e-01 4.85354941e-03 -1.33636221e-01]
|
| 18 |
+
[-1.64813206e-01 1.63652197e-01 -6.96032941e-02 -3.96138221e-01]
|
| 19 |
+
...
|
| 20 |
+
[-1.74221992e-01 2.78679162e-01 -1.02342315e-01 -4.71356630e-01]
|
| 21 |
+
[-9.72934887e-02 2.24700689e-01 -1.54692575e-01 -8.07371676e-01]
|
| 22 |
+
[ 1.58384442e-02 9.63119492e-02 4.84653771e-01 8.73409092e-01]]
|
| 23 |
+
|
| 24 |
+
[[-1.16939977e-01 2.56956398e-01 -1.04373530e-01 -1.33346528e-01]
|
| 25 |
+
[-1.52860105e-01 1.76005200e-01 -1.16914781e-02 -1.92210004e-01]
|
| 26 |
+
[-5.50103635e-02 2.04600886e-01 -1.73305750e-01 -4.94984031e-01]
|
| 27 |
+
...
|
| 28 |
+
[-3.88413459e-01 3.15461606e-01 -1.25539899e-01 -5.62439263e-01]
|
| 29 |
+
[-1.97147772e-01 -2.31708195e-02 -1.44041494e-01 -8.99005592e-01]
|
| 30 |
+
[ 3.42922032e-01 2.24075779e-01 4.25257713e-01 5.85853398e-01]]
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
#CPU, batch 16, latent
|
| 35 |
+
|
| 36 |
+
"""
|
| 37 |
+
[[[[-8.47917721e-02 -8.92071351e-02 -1.05532585e-02 -3.59174877e-01]
|
| 38 |
+
[-1.11725748e-01 -1.22415572e-01 3.33435684e-02 -3.60438257e-01]
|
| 39 |
+
[-1.36060238e-01 -1.37327328e-01 3.79590057e-02 -3.73947173e-01]
|
| 40 |
+
...
|
| 41 |
+
[ 7.88694695e-02 -5.03079742e-02 6.75498620e-02 -3.39441150e-01]
|
| 42 |
+
[-1.63178548e-01 -3.21848512e-01 1.72039792e-02 -9.50528085e-01]
|
| 43 |
+
[ 2.21429523e-02 1.48582339e-01 1.54685006e-01 6.86266243e-01]]
|
| 44 |
+
|
| 45 |
+
[[-1.69139117e-01 7.81316869e-03 4.33448888e-02 -3.37453634e-01]
|
| 46 |
+
[-1.96011692e-01 -4.98509258e-02 3.32896858e-02 -3.53303224e-01]
|
| 47 |
+
[-9.82111022e-02 -1.94629002e-02 -1.63653865e-02 -3.32124978e-01]
|
| 48 |
+
...
|
| 49 |
+
[-7.72062615e-02 2.95878220e-02 -7.62912910e-03 -3.61496925e-01]
|
| 50 |
+
[-2.26189673e-01 -5.97889721e-02 -1.16483821e-02 -7.82557964e-01]
|
| 51 |
+
[-6.18810430e-02 7.75512159e-02 2.37205133e-01 8.39313030e-01]]
|
| 52 |
+
|
| 53 |
+
[[-9.37198251e-02 -4.58365604e-02 -2.44572274e-02 -3.00568134e-01]
|
| 54 |
+
[-1.32911175e-01 -9.60890502e-02 -4.78822738e-04 -3.28105956e-01]
|
| 55 |
+
[-7.67295957e-02 -6.57245517e-02 -3.78448963e-02 -3.29079330e-01]
|
| 56 |
+
...
|
| 57 |
+
[-1.21173687e-01 4.07976359e-02 4.05129045e-02 -3.48512828e-01]
|
| 58 |
+
[-1.64501339e-01 -9.52737629e-02 -1.06653105e-03 -8.39630961e-01]
|
| 59 |
+
[ 2.64041096e-01 2.43525319e-02 3.05205405e-01 4.92310941e-01]]
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
#CPU, 8 vs GPU 8
|
| 63 |
+
"""
|
| 64 |
+
[[[[[-3.18646997e-01 -4.77920741e-01 1.07763827e+00 1.70530510e+00]
|
| 65 |
+
[-6.31720126e-01 -2.49106735e-01 1.66874206e+00 -5.45821428e-01]
|
| 66 |
+
[-4.03593808e-01 2.76418477e-01 1.29216135e+00 8.79887521e-01]
|
| 67 |
+
...
|
| 68 |
+
[-2.03093603e-01 -7.97204554e-01 3.61778885e-01 -3.68656218e-01]
|
| 69 |
+
[-2.61139393e-01 1.64036989e+00 -2.22024798e-01 3.49313989e-02]
|
| 70 |
+
[ 6.32668972e-01 -4.74448204e-01 1.55093277e+00 5.57837903e-01]]
|
| 71 |
+
|
| 72 |
+
[[-7.24952042e-01 4.80744302e-01 3.05105478e-01 1.06132841e+00]
|
| 73 |
+
[ 8.95307362e-02 1.45687327e-01 1.57945228e+00 -1.11452961e+00]
|
| 74 |
+
[-4.61988777e-01 -4.11880344e-01 1.70428991e+00 4.31171536e-01]
|
| 75 |
+
...
|
| 76 |
+
[-1.17851949e+00 2.03509808e-01 1.84925032e+00 -5.68852723e-01]
|
| 77 |
+
[ 5.74628949e-01 -8.48990500e-01 -2.50778824e-01 1.92248678e+00]
|
| 78 |
+
[-2.69778688e-02 -8.46022546e-01 -7.89667487e-01 9.26319182e-01]]
|
| 79 |
+
|
| 80 |
+
[[-3.10738117e-01 6.01165593e-02 1.57032907e-01 1.53192639e+00]
|
| 81 |
+
[ 6.55903339e-01 7.50707746e-01 6.03949744e-03 1.31769347e+00]
|
| 82 |
+
[ 3.26834202e-01 -2.33611539e-01 1.35725603e-01 -2.39371091e-01]
|
| 83 |
+
...
|
| 84 |
+
[ 2.19290599e-01 -2.21653271e+00 -2.21055865e+00 1.49363160e+00]
|
| 85 |
+
[-1.45460200e+00 1.18737824e-01 1.56015289e+00 8.23014230e-03]
|
| 86 |
+
[ 3.44308168e-01 1.08958745e+00 -1.23330317e-01 5.41093886e-01]]
|
| 87 |
+
|
| 88 |
+
#GPU
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
"""
|
| 92 |
+
|
| 93 |
+
# import jax
# jax.config.update('jax_platform_name', 'cpu')
import os

# os.environ["JAX_PLATFORMS"] = 'cpu'

import jax
import lpips

# LPIPS perceptual metric (AlexNet backbone, per the lpips README "best forward scores").
loss_fn_alex = lpips.LPIPS(net='alex') # best forward scores
loss_fn_alex = loss_fn_alex.cuda()


import numpy as np
import flax.linen as nn
import jax.numpy as jnp
from absl import app, flags
from functools import partial
import numpy as np
import tqdm
import flax
import optax
import wandb
from ml_collections import config_flags
#import elements
import ml_collections
import tensorflow_datasets as tfds
import tensorflow as tf
# Keep TensorFlow off the accelerators so it does not grab memory that JAX needs.
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")
import matplotlib.pyplot as plt
from typing import Any

from utils.train_state import TrainState, target_update
from utils.checkpoint import Checkpoint
from utils.fid import get_fid_network, fid_from_stats

from train import VQGANModel
from models.vqvae import VQVAE
from models.discriminator import Discriminator

from PIL import Image
import torch

# Importing train.py registered its flags; drop the ones we redefine below.
delattr(flags.FLAGS, 'dataset_name')
delattr(flags.FLAGS, 'load_dir')
delattr(flags.FLAGS, 'batch_size')

FLAGS = flags.FLAGS
flags.DEFINE_string('dataset_name', 'imagenet256', 'Environment name.')
flags.DEFINE_string('load_dir', "/home/dkaplan/Downloads/Models/checkpoint(1).tmp", 'Load dir (if not None, load params from here).')

from safetensors.torch import save_file

flags.DEFINE_integer('batch_size', 8, 'Total Batch size.')
# Flags are inherited from train.py, so pass your model parameters again here to evaluate.

import gc
def main(_):
    """Encode ImageNet train images (and their horizontal flips) into VQGAN
    latents and write them out as sharded .safetensors files.

    Runs until the train iterator is exhausted; every 5000 accumulated batches
    one shard is flushed to /data/inet_latents.
    """
    device_count = len(jax.local_devices())
    global_device_count = jax.device_count()
    # Per-process batch size (equals FLAGS.batch_size on a single host).
    local_batch_size = FLAGS.batch_size // (global_device_count // device_count)

    def get_dataset(is_train):
        # Train iterator yields (image, flipped_image, label);
        # validation yields (image, label). Images are float32 in [0, 1].
        if 'imagenet' in FLAGS.dataset_name:
            def deserialization_fn(data):
                image = data['image']
                label = data["label"]
                # Center-crop to a square, then resize to the target resolution.
                min_side = tf.minimum(tf.shape(image)[0], tf.shape(image)[1])
                image = tf.image.resize_with_crop_or_pad(image, min_side, min_side)
                if 'imagenet256' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (256, 256))
                elif 'imagenet128' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (128, 128))
                else:
                    raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
                if is_train:
                    # Deterministic flip (not random augmentation): both
                    # orientations are encoded and saved side by side.
                    image_flip =tf.image.flip_left_right(image)
                    image_flip = tf.cast(image_flip, tf.float32) / 255.0
                    image = tf.cast(image, tf.float32) / 255.0
                    return image, image_flip, label

                image = tf.cast(image, tf.float32) / 255.0
                return image, label

            split = tfds.split_for_jax_process('train' if is_train else 'validation', drop_remainder=True)
            dataset = tfds.load('imagenet2012', data_dir="/data/inet", split=split)
            dataset = dataset.map(deserialization_fn, num_parallel_calls=tf.data.AUTOTUNE)
            dataset = dataset.shuffle(10000, seed=42, reshuffle_each_iteration=True)
            dataset = dataset.batch(local_batch_size)
            dataset = dataset.prefetch(tf.data.AUTOTUNE)
            dataset = tfds.as_numpy(dataset)
            dataset = iter(dataset)
            return dataset
        else:
            raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")

    dataset = get_dataset(is_train=True)
    dataset_valid = get_dataset(is_train=False)

    # First image of the first train batch, used only to infer model input shape.
    example_obs = next(dataset)[:1][0]

    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, param_key = jax.random.split(rng)
    print("Total devices", jax.local_devices()[0])

    ###################################
    # Creating Model and put on devices.
    ###################################
    FLAGS.model.image_channels = example_obs.shape[-1]
    FLAGS.model.image_size = example_obs.shape[1]
    vqvae_def = VQVAE(FLAGS.model, train=True)
    vqvae_params = vqvae_def.init({'params': param_key, 'noise': param_key}, example_obs)['params']
    # tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    # No optimizer: this script only runs inference, so tx is deliberately omitted.
    vqvae_ts = TrainState.create(vqvae_def, vqvae_params)#, tx=tx) #Turning off tx because we don't need it...
    vqvae_def_eps = VQVAE(FLAGS.model, train=False)
    vqvae_eps_ts = TrainState.create(vqvae_def_eps, vqvae_params)
    print("Total num of VQVAE parameters:", sum(x.size for x in jax.tree_util.tree_leaves(vqvae_params)))

    discriminator_def = Discriminator(FLAGS.model)
    discriminator_params = discriminator_def.init(param_key, example_obs)['params']
    # tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    discriminator_ts = TrainState.create(discriminator_def, discriminator_params)#, tx=tx)#No tx again
    print("Total num of Discriminator parameters:", sum(x.size for x in jax.tree_util.tree_leaves(discriminator_params)))

    model = VQGANModel(rng=rng, vqvae=vqvae_ts, vqvae_eps=vqvae_eps_ts, discriminator=discriminator_ts, config=FLAGS.model)

    assert FLAGS.load_dir is not None
    cp = Checkpoint(FLAGS.load_dir)
    model = cp.load_model(model)
    print("Loaded model with step", model.vqvae.step)

    # Replicate parameters across local devices for data-parallel encoding.
    model = flax.jax_utils.replicate(model, devices=jax.local_devices())
    jax.debug.visualize_array_sharding(model.vqvae.params['decoder']['Conv_0']['bias'])

    latents = []
    latents_flip = []
    labels = []
    saved_files = 0
    for image, image_flip, label in dataset:
        # Shard the batch across local devices: [devices, batch//devices, ...].
        image = image.reshape((len(jax.local_devices()), -1, *image.shape[1:])) # [devices, batch//devices, etc..]
        latent, result_dict = model.get_latent(image)

        # Same for the horizontally-flipped copy of the batch.
        image_flip = image_flip.reshape((len(jax.local_devices()), -1, *image_flip.shape[1:])) # [devices, batch//devices, etc..]
        latent_flip, result_dict_flip = model.get_latent(image_flip)

        latents.append(latent.squeeze())
        latents_flip.append(latent_flip.squeeze())
        labels.append(label)

        # Flush one shard to disk every 5000 accumulated batches.
        if len(latents) == 5000:
            latents = jnp.concatenate(latents, axis=0)
            latents_flip = jnp.concatenate(latents_flip, axis=0)
            labels = jnp.concatenate(labels, axis=0)

            # JAX -> numpy -> torch; np.copy gives torch a writable buffer.
            latents_torch = np.asarray(latents)
            latents_torch = torch.from_numpy(np.copy(latents_torch))

            latents_flip_torch = np.asarray(latents_flip)
            latents_flip_torch = torch.from_numpy(np.copy(latents_flip_torch))

            labels_torch = np.asarray(labels)
            labels_torch = torch.from_numpy(np.copy(labels_torch))

            save_dict = {
                'latents': latents_torch,
                'latents_flip': latents_flip_torch,
                'labels': labels_torch
            }

            print(latents_torch.shape)
            print(latents_flip_torch.shape)
            print(labels_torch.shape)

            output_dir = "/data/inet_latents"
            save_filename = os.path.join(output_dir, f'latents_shard{saved_files:03d}.safetensors')
            save_file(
                save_dict,
                save_filename,
                metadata={'total_size': f'{latents_torch.shape[0]}', 'dtype': f'{latents_torch.dtype}', 'device': f'{latents_torch.device}'}
            )

            # Reset accumulators for the next shard.
            latents = []
            latents_flip = []
            labels = []
            saved_files += 1
# Script entry point: absl parses the flags inherited from train.py plus the
# ones redefined above, then calls main().
if __name__ == '__main__':
    app.run(main)
|
f16c16/eval_fid.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
try: # For debugging
    from localutils.debugger import enable_debug
    enable_debug()
except ImportError:
    pass

import flax.linen as nn
import jax.numpy as jnp
from absl import app, flags
from functools import partial
import numpy as np
import tqdm
import jax
import jax.numpy as jnp
import flax
import optax
import wandb
from ml_collections import config_flags
#import elements
import ml_collections
import tensorflow_datasets as tfds
import tensorflow as tf
# Keep TensorFlow off the accelerators so it does not grab memory that JAX needs.
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")
import matplotlib.pyplot as plt
from typing import Any

from utils.train_state import TrainState, target_update
from utils.checkpoint import Checkpoint
from utils.fid import get_fid_network, fid_from_stats

from train import VQGANModel
from models.vqvae import VQVAE
from models.discriminator import Discriminator

# Importing train.py registered its flags; drop the ones we redefine below.
delattr(flags.FLAGS, 'dataset_name')
delattr(flags.FLAGS, 'load_dir')
delattr(flags.FLAGS, 'batch_size')

FLAGS = flags.FLAGS
flags.DEFINE_string('dataset_name', 'imagenet256', 'Environment name.')
flags.DEFINE_string('load_dir', "./checkpointbest.tmp.tmp", 'Load dir (if not None, load params from here).')
flags.DEFINE_integer('batch_size', 128, 'Total Batch size.')
# Flags are inherited from train.py, so pass your model parameters again here to evaluate.
def main(_):
    """Reconstruct ImageNet validation images through a loaded VQGAN and
    report reconstruction FID against precomputed reference statistics.

    Also dumps the uint8 reconstructions to ./images_recon.npz.
    """
    device_count = len(jax.local_devices())
    global_device_count = jax.device_count()
    # Per-process batch size (equals FLAGS.batch_size on a single host).
    local_batch_size = FLAGS.batch_size // (global_device_count // device_count)

    def get_dataset(is_train):
        # Iterator of [local_batch_size, H, W, 3] float32 images in [0, 1].
        if 'imagenet' in FLAGS.dataset_name:
            def deserialization_fn(data):
                image = data['image']
                # Center-crop to a square, then resize to the target resolution.
                min_side = tf.minimum(tf.shape(image)[0], tf.shape(image)[1])
                image = tf.image.resize_with_crop_or_pad(image, min_side, min_side)
                if 'imagenet256' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (256, 256))
                elif 'imagenet128' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (128, 128))
                else:
                    raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
                if is_train:
                    image = tf.image.random_flip_left_right(image)
                image = tf.cast(image, tf.float32) / 255.0
                return image

            split = tfds.split_for_jax_process('train' if is_train else 'validation', drop_remainder=True)
            dataset = tfds.load('imagenet2012', data_dir="/dev/shm", split=split)
            dataset = dataset.map(deserialization_fn, num_parallel_calls=tf.data.AUTOTUNE)
            dataset = dataset.shuffle(10000, seed=42, reshuffle_each_iteration=True)
            dataset = dataset.batch(local_batch_size)
            dataset = dataset.prefetch(tf.data.AUTOTUNE)
            dataset = tfds.as_numpy(dataset)
            dataset = iter(dataset)
            return dataset
        else:
            raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")

    # NOTE(review): both iterators are built from the validation split here.
    dataset = get_dataset(is_train=False)
    dataset_valid = get_dataset(is_train=False)
    example_obs = next(dataset)[:1]

    get_fid_activations = get_fid_network()
    # Precomputed reference (ground-truth) FID statistics for ImageNet-256.
    truth_fid_stats = np.load('data/imagenet256_fidstats_openai.npz')
    # truth_fid_stats = np.load('base_stats.npz')

    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, param_key = jax.random.split(rng)
    print("Total Memory on device:", float(jax.local_devices()[0].memory_stats()['bytes_limit']) / 1024**3, "GB")

    ###################################
    # Creating Model and put on devices.
    ###################################
    FLAGS.model.image_channels = example_obs.shape[-1]
    FLAGS.model.image_size = example_obs.shape[1]
    vqvae_def = VQVAE(FLAGS.model, train=True)
    vqvae_params = vqvae_def.init({'params': param_key, 'noise': param_key}, example_obs)['params']
    tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    vqvae_ts = TrainState.create(vqvae_def, vqvae_params, tx=tx)
    vqvae_def_eps = VQVAE(FLAGS.model, train=False)
    vqvae_eps_ts = TrainState.create(vqvae_def_eps, vqvae_params)
    print("Total num of VQVAE parameters:", sum(x.size for x in jax.tree_util.tree_leaves(vqvae_params)))

    discriminator_def = Discriminator(FLAGS.model)
    discriminator_params = discriminator_def.init(param_key, example_obs)['params']
    tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    discriminator_ts = TrainState.create(discriminator_def, discriminator_params, tx=tx)
    print("Total num of Discriminator parameters:", sum(x.size for x in jax.tree_util.tree_leaves(discriminator_params)))

    model = VQGANModel(rng=rng, vqvae=vqvae_ts, vqvae_eps=vqvae_eps_ts, discriminator=discriminator_ts, config=FLAGS.model)

    assert FLAGS.load_dir is not None
    cp = Checkpoint(FLAGS.load_dir)
    model = cp.load_model(model)
    print("Loaded model with step", model.vqvae.step)

    # Replicate parameters across local devices for data-parallel reconstruction.
    model = flax.jax_utils.replicate(model, devices=jax.local_devices())
    jax.debug.visualize_array_sharding(model.vqvae.params['decoder']['Conv_0']['bias'])
    #print(model.vqvae)

    ###################################
    # FID Evaluation.
    ###################################

    activations = []
    activations_base = []

    images = []
    images_original = []
    for valid_images in dataset_valid:

        images_original.append((valid_images*255).astype(np.uint8))
        # Pad the last partial batch with zeros so it shards evenly across
        # devices; the padded rows are stripped back out below.
        if valid_images.shape[0] < local_batch_size:
            zeros_added = local_batch_size - valid_images.shape[0]
            valid_images = np.concatenate([valid_images, np.zeros((local_batch_size - valid_images.shape[0], *valid_images.shape[1:]))], axis=0)
        else:
            zeros_added = 0

        print(len(jax.local_devices()))
        print(valid_images.shape)
        valid_images = valid_images.reshape((len(jax.local_devices()), -1, *valid_images.shape[1:])) # [devices, batch//devices, etc..]
        print(valid_images.shape)
        valid_reconstructed_images = model.reconstruction(valid_images) # [devices, 8, 256, 256, 3]
        print(valid_reconstructed_images.shape)

        # Plotting is currently disabled: the loop below `continue`s immediately,
        # so the imshow calls are dead code.
        fig, axs = plt.subplots(2, 8, figsize=(30, 15))

        for j in range(1):
            continue#Turn this off for now
            axs[0, j].imshow(valid_images[j, 0], vmin=0, vmax=1)
            axs[1, j].imshow(valid_reconstructed_images[j, 0], vmin=0, vmax=1)
        #wandb.log({'reconstruction': wandb.Image(fig)}, step=i)

        # Keep uint8 reconstructions (assumes model output is in [0, 1] —
        # TODO confirm; see the 2*x-1 rescale below) for the npz dump.
        add_images = valid_reconstructed_images.reshape(-1,256,256,3)
        if zeros_added > 0:
            add_images = add_images[:-zeros_added, :, :, :]
        images.append((add_images*255).astype(np.uint8))

        #valid = (valid_reconstructed_images + 1 ) * 127.5
        #images2.append(valid.clamp(0,255).astype(npuint8))

        # Resize to the 299x299 Inception input and rescale to [-1, 1]
        # before extracting FID activations.
        valid_reconstructed_images = jax.image.resize(valid_reconstructed_images, (valid_images.shape[0], valid_images.shape[1], 299, 299, 3),
                                                      method='bilinear', antialias=True)
        valid_reconstructed_images = 2 * valid_reconstructed_images - 1
        acts = np.array(get_fid_activations(valid_reconstructed_images))[..., 0, 0, :]
        if zeros_added > 0:
            acts = acts[:-zeros_added]
        activations.append(acts)

        # Used to grab baseline truths (disabled).
        if False:
            valid_reconstructed_images = jax.image.resize(valid_images, (valid_images.shape[0], valid_images.shape[1], 299, 299, 3),
                                                          method='bilinear', antialias=True)
            valid_reconstructed_images = 2 * valid_reconstructed_images - 1
            acts = np.array(get_fid_activations(valid_reconstructed_images))[..., 0, 0, :]

            if zeros_added > 0:
                acts = acts[:-zeros_added]
            activations_base.append(acts)
        # This is fine because it's just length (progress counter).
        print(len(activations) * FLAGS.batch_size)

    images = np.concatenate(images, axis = 0)
    #images_original = np.concatenate(images_original, axis = 0)
    print(images.shape)
    #print(images_original.shape)
    # Collapse to a flat [N, 256, 256, 3] array of reconstructions.
    images = images.reshape(-1, 256, 256, 3)
    #images2 = images_original.reshape(-1,256,256,3)

    activations = np.concatenate(activations, axis=0)
    activations = activations.reshape((-1, activations.shape[-1]))
    mu1 = np.mean(activations, axis=0)
    sigma1 = np.cov(activations, rowvar=False)
    #print(mu1)
    #print(sigma1)
    fid = fid_from_stats(mu1, sigma1, truth_fid_stats['mu'], truth_fid_stats['sigma'])

    print("FID:", fid)

    np.savez("./images_recon.npz", arr_0 = images)
    #np.savez("./images_original.npz", arr_0 = images2)
# Script entry point: absl parses the flags inherited from train.py plus the
# ones redefined above, then calls main().
if __name__ == '__main__':
    app.run(main)
|
f16c16/evaluator.py
ADDED
|
@@ -0,0 +1,654 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
import io
import os
import random
import warnings
import zipfile
from abc import ABC, abstractmethod
from contextlib import contextmanager
from functools import partial
from multiprocessing import cpu_count
from multiprocessing.pool import ThreadPool
from typing import Iterable, Optional, Tuple

import numpy as np
import requests
import tensorflow.compat.v1 as tf
from scipy import linalg
from tqdm.auto import tqdm

# Frozen Inception-v3 graph used by OpenAI's reference FID/IS evaluator.
INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb"
INCEPTION_V3_PATH = "classify_image_graph_def.pb"

# Tensor names inside the Inception graph: pooled features (FID) and
# spatial features (sFID).
FID_POOL_NAME = "pool_3:0"
FID_SPATIAL_NAME = "mixed_6/conv:0"
def main():
    """Compare a sample npz batch against a reference npz batch, printing
    Inception Score, FID, sFID, Precision and Recall."""
    parser = argparse.ArgumentParser()
    parser.add_argument("ref_batch", help="path to reference batch npz file")
    parser.add_argument("sample_batch", help="path to sample batch npz file")
    args = parser.parse_args()

    config = tf.ConfigProto(
        allow_soft_placement=True  # allows DecodeJpeg to run on CPU in Inception graph
    )
    config.gpu_options.allow_growth = True
    evaluator = Evaluator(tf.Session(config=config))

    print("warming up TensorFlow...")
    # This will cause TF to print a bunch of verbose stuff now rather
    # than after the next print(), to help prevent confusion.
    evaluator.warmup()

    print("computing reference batch activations...")
    ref_acts = evaluator.read_activations(args.ref_batch)
    print("computing/reading reference batch statistics...")
    ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)

    print("computing sample batch activations...")
    sample_acts = evaluator.read_activations(args.sample_batch)
    print("computing/reading sample batch statistics...")
    sample_stats, sample_stats_spatial = evaluator.read_statistics(args.sample_batch, sample_acts)

    print("Computing evaluations...")
    print("Inception Score:", evaluator.compute_inception_score(sample_acts[0]))
    print("FID:", sample_stats.frechet_distance(ref_stats))
    print("sFID:", sample_stats_spatial.frechet_distance(ref_stats_spatial))
    prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
    print("Precision:", prec)
    print("Recall:", recall)
class InvalidFIDException(Exception):
    """Exception type for FID computation failures (not raised in this chunk)."""
+
|
| 67 |
+
class FIDStatistics:
    """Container for Inception activation statistics: mean vector and covariance."""

    def __init__(self, mu: np.ndarray, sigma: np.ndarray):
        # mu: [N] mean of activations; sigma: [N, N] covariance of activations.
        self.mu = mu
        self.sigma = sigma

    def frechet_distance(self, other, eps=1e-6):
        """
        Compute the Frechet distance between two sets of statistics.
        """
        # Port of the reference implementation:
        # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132
        mean_a = np.atleast_1d(self.mu)
        mean_b = np.atleast_1d(other.mu)
        cov_a = np.atleast_2d(self.sigma)
        cov_b = np.atleast_2d(other.sigma)

        assert (
            mean_a.shape == mean_b.shape
        ), f"Training and test mean vectors have different lengths: {mean_a.shape}, {mean_b.shape}"
        assert (
            cov_a.shape == cov_b.shape
        ), f"Training and test covariances have different dimensions: {cov_a.shape}, {cov_b.shape}"

        mean_diff = mean_a - mean_b

        # sqrtm of the covariance product may be numerically singular;
        # retry with a small jitter on the diagonals if so.
        sqrt_prod, _ = linalg.sqrtm(cov_a.dot(cov_b), disp=False)
        if not np.isfinite(sqrt_prod).all():
            warnings.warn(
                "fid calculation produces singular product; adding %s to diagonal of cov estimates"
                % eps
            )
            jitter = np.eye(cov_a.shape[0]) * eps
            sqrt_prod = linalg.sqrtm((cov_a + jitter).dot(cov_b + jitter))

        # Numerical error can leave a tiny imaginary component; strip it,
        # but fail loudly if it is not negligible.
        if np.iscomplexobj(sqrt_prod):
            if not np.allclose(np.diagonal(sqrt_prod).imag, 0, atol=1e-3):
                max_imag = np.max(np.abs(sqrt_prod.imag))
                raise ValueError("Imaginary component {}".format(max_imag))
            sqrt_prod = sqrt_prod.real

        return (
            mean_diff.dot(mean_diff)
            + np.trace(cov_a)
            + np.trace(cov_b)
            - 2 * np.trace(sqrt_prod)
        )
|
| 118 |
+
class Evaluator:
    """
    Computes evaluation metrics from InceptionV3 features: FID/sFID statistics,
    Inception score, and improved precision/recall, using a TF session.
    """

    def __init__(
        self,
        session,
        batch_size=64,
        softmax_batch_size=512,
    ):
        # :param session: TensorFlow session owning the inception graph.
        # :param batch_size: images per feature-extraction run.
        # :param softmax_batch_size: pool activations per softmax run.
        self.sess = session
        self.batch_size = batch_size
        self.softmax_batch_size = softmax_batch_size
        self.manifold_estimator = ManifoldEstimator(session)
        with self.sess.graph.as_default():
            # Placeholders: NHWC image batches and 2048-d pool_3 activations.
            self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
            self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048])
            self.pool_features, self.spatial_features = _create_feature_graph(self.image_input)
            self.softmax = _create_softmax_graph(self.softmax_input)

    def warmup(self):
        """Run one dummy batch so graph setup cost is paid up front."""
        self.compute_activations(np.zeros([1, 8, 64, 64, 3]))

    def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]:
        """Stream images out of an npz file and return their activations."""
        with open_npz_array(npz_path, "arr_0") as reader:
            return self.compute_activations(reader.read_batches(self.batch_size))

    def compute_activations(self, batches: Iterable[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
        """
        Compute image features for downstream evals.

        :param batches: a iterator over NHWC numpy arrays in [0, 255].
        :return: a tuple of numpy arrays of shape [N x X], where X is a feature
                 dimension. The tuple is (pool_3, spatial).
        """
        preds = []
        spatial_preds = []
        for batch in tqdm(batches):
            batch = batch.astype(np.float32)
            pred, spatial_pred = self.sess.run(
                [self.pool_features, self.spatial_features], {self.image_input: batch}
            )
            # Flatten each sample's features to a single vector.
            preds.append(pred.reshape([pred.shape[0], -1]))
            spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))
        return (
            np.concatenate(preds, axis=0),
            np.concatenate(spatial_preds, axis=0),
        )

    def read_statistics(
        self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray]
    ) -> Tuple[FIDStatistics, FIDStatistics]:
        """
        Load precomputed (mu, sigma) pairs from the npz when present;
        otherwise derive them from the supplied activations.
        """
        obj = np.load(npz_path)
        if "mu" in list(obj.keys()):
            return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics(
                obj["mu_s"], obj["sigma_s"]
            )
        return tuple(self.compute_statistics(x) for x in activations)

    def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
        """Mean and covariance of a set of activations."""
        mu = np.mean(activations, axis=0)
        sigma = np.cov(activations, rowvar=False)
        return FIDStatistics(mu, sigma)

    def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float:
        """
        Inception score of pool_3 activations, averaged over splits of
        ``split_size`` samples.
        """
        softmax_out = []
        for i in range(0, len(activations), self.softmax_batch_size):
            acts = activations[i : i + self.softmax_batch_size]
            softmax_out.append(self.sess.run(self.softmax, feed_dict={self.softmax_input: acts}))
        preds = np.concatenate(softmax_out, axis=0)
        # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
        scores = []
        for i in range(0, len(preds), split_size):
            part = preds[i : i + split_size]
            kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
            kl = np.mean(np.sum(kl, 1))
            scores.append(np.exp(kl))
        return float(np.mean(scores))

    def compute_prec_recall(
        self, activations_ref: np.ndarray, activations_sample: np.ndarray
    ) -> Tuple[float, float]:
        """
        Improved precision/recall of the sample manifold vs. the reference.

        :return: (precision, recall) floats for the first neighborhood size.
        """
        radii_1 = self.manifold_estimator.manifold_radii(activations_ref)
        radii_2 = self.manifold_estimator.manifold_radii(activations_sample)
        pr = self.manifold_estimator.evaluate_pr(
            activations_ref, radii_1, activations_sample, radii_2
        )
        return (float(pr[0][0]), float(pr[1][0]))
class ManifoldEstimator:
    """
    A helper for comparing manifolds of feature vectors.

    Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57
    """

    def __init__(
        self,
        session,
        row_batch_size=10000,
        col_batch_size=10000,
        nhood_sizes=(3,),
        clamp_to_percentile=None,
        eps=1e-5,
    ):
        """
        Estimate the manifold of given feature vectors.

        :param session: the TensorFlow session.
        :param row_batch_size: row batch size to compute pairwise distances
            (parameter to trade-off between memory usage and performance).
        :param col_batch_size: column batch size to compute pairwise distances.
        :param nhood_sizes: number of neighbors used to estimate the manifold.
        :param clamp_to_percentile: prune hyperspheres that have radius larger than
            the given percentile.
        :param eps: small number for numerical stability.
        """
        self.distance_block = DistanceBlock(session)
        self.row_batch_size = row_batch_size
        self.col_batch_size = col_batch_size
        self.nhood_sizes = nhood_sizes
        self.num_nhoods = len(nhood_sizes)
        self.clamp_to_percentile = clamp_to_percentile
        self.eps = eps

    def warmup(self):
        """Run evaluate_pr once on dummy data to trigger TF graph setup."""
        feats, radii = (
            np.zeros([1, 2048], dtype=np.float32),
            np.zeros([1, 1], dtype=np.float32),
        )
        self.evaluate_pr(feats, radii, feats, radii)

    def manifold_radii(self, features: np.ndarray) -> np.ndarray:
        """
        For each feature vector, compute the distance to its k-th nearest
        neighbor for every k in ``self.nhood_sizes``.

        :param features: [N x D] feature vectors.
        :return: [N x num_nhoods] radii (squared distances).
        """
        num_images = len(features)

        # Estimate manifold of features by calculating distances to k-NN of each sample.
        radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
        distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)
        seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)

        for begin1 in range(0, num_images, self.row_batch_size):
            end1 = min(begin1 + self.row_batch_size, num_images)
            row_batch = features[begin1:end1]

            for begin2 in range(0, num_images, self.col_batch_size):
                end2 = min(begin2 + self.col_batch_size, num_images)
                col_batch = features[begin2:end2]

                # Compute distances between batches.
                distance_batch[
                    0 : end1 - begin1, begin2:end2
                ] = self.distance_block.pairwise_distances(row_batch, col_batch)

            # Find the k-nearest neighbor from the current batch.
            radii[begin1:end1, :] = np.concatenate(
                [
                    x[:, self.nhood_sizes]
                    for x in _numpy_partition(distance_batch[0 : end1 - begin1, :], seq, axis=1)
                ],
                axis=0,
            )

        if self.clamp_to_percentile is not None:
            # Prune hyperspheres above the percentile by zeroing their radii.
            max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)
            radii[radii > max_distances] = 0
        return radii

    def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray):
        """
        Evaluate if new feature vectors are at the manifold.
        """
        num_eval_images = eval_features.shape[0]
        num_ref_images = radii.shape[0]
        distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32)
        batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
        max_realism_score = np.zeros([num_eval_images], dtype=np.float32)
        nearest_indices = np.zeros([num_eval_images], dtype=np.int32)

        for begin1 in range(0, num_eval_images, self.row_batch_size):
            end1 = min(begin1 + self.row_batch_size, num_eval_images)
            feature_batch = eval_features[begin1:end1]

            for begin2 in range(0, num_ref_images, self.col_batch_size):
                end2 = min(begin2 + self.col_batch_size, num_ref_images)
                ref_batch = features[begin2:end2]

                distance_batch[
                    0 : end1 - begin1, begin2:end2
                ] = self.distance_block.pairwise_distances(feature_batch, ref_batch)

            # From the minibatch of new feature vectors, determine if they are in the estimated manifold.
            # If a feature vector is inside a hypersphere of some reference sample, then
            # the new sample lies at the estimated manifold.
            # The radii of the hyperspheres are determined from distances of neighborhood size k.
            samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii
            batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)

            max_realism_score[begin1:end1] = np.max(
                radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1
            )
            nearest_indices[begin1:end1] = np.argmin(distance_batch[0 : end1 - begin1, :], axis=1)

        return {
            "fraction": float(np.mean(batch_predictions)),
            "batch_predictions": batch_predictions,
            # NOTE: key typo ("realisim") kept for backward compatibility with callers.
            "max_realisim_score": max_realism_score,
            "nearest_indices": nearest_indices,
        }

    def evaluate_pr(
        self,
        features_1: np.ndarray,
        radii_1: np.ndarray,
        features_2: np.ndarray,
        radii_2: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Evaluate precision and recall efficiently.

        :param features_1: [N1 x D] feature vectors for reference batch.
        :param radii_1: [N1 x K1] radii for reference vectors.
        :param features_2: [N2 x D] feature vectors for the other batch.
        :param radii_2: [N x K2] radii for other vectors.
        :return: a tuple of arrays for (precision, recall):
                 - precision: an np.ndarray of length K1
                 - recall: an np.ndarray of length K2
        """
        # BUGFIX: np.bool was a deprecated alias for the builtin bool and was
        # removed in NumPy 1.24; use the builtin dtype instead.
        features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=bool)
        features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=bool)
        for begin_1 in range(0, len(features_1), self.row_batch_size):
            end_1 = begin_1 + self.row_batch_size
            batch_1 = features_1[begin_1:end_1]
            for begin_2 in range(0, len(features_2), self.col_batch_size):
                end_2 = begin_2 + self.col_batch_size
                batch_2 = features_2[begin_2:end_2]
                batch_1_in, batch_2_in = self.distance_block.less_thans(
                    batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]
                )
                features_1_status[begin_1:end_1] |= batch_1_in
                features_2_status[begin_2:end_2] |= batch_2_in
        return (
            np.mean(features_2_status.astype(np.float64), axis=0),
            np.mean(features_1_status.astype(np.float64), axis=0),
        )
class DistanceBlock:
    """
    Calculate pairwise distances between vectors.

    Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34
    """

    def __init__(self, session):
        self.session = session

        # Initialize TF graph to calculate pairwise distances.
        with session.graph.as_default():
            self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])
            self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])
            # Compute in float16 first (cheaper); fall back to a float32
            # recomputation when the half-precision result is non-finite.
            distance_block_16 = _batch_pairwise_distances(
                tf.cast(self._features_batch1, tf.float16),
                tf.cast(self._features_batch2, tf.float16),
            )
            self.distance_block = tf.cond(
                tf.reduce_all(tf.math.is_finite(distance_block_16)),
                lambda: tf.cast(distance_block_16, tf.float32),
                lambda: _batch_pairwise_distances(self._features_batch1, self._features_batch2),
            )

            # Extra logic for less thans.
            self._radii1 = tf.placeholder(tf.float32, shape=[None, None])
            self._radii2 = tf.placeholder(tf.float32, shape=[None, None])
            dist32 = tf.cast(self.distance_block, tf.float32)[..., None]
            # Per-row/column membership masks: is any distance within the radius?
            self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1)
            self._batch_2_in = tf.math.reduce_any(dist32 <= self._radii1[:, None], axis=0)

    def pairwise_distances(self, U, V):
        """
        Evaluate pairwise distances between two batches of feature vectors.
        """
        return self.session.run(
            self.distance_block,
            feed_dict={self._features_batch1: U, self._features_batch2: V},
        )

    def less_thans(self, batch_1, radii_1, batch_2, radii_2):
        # Returns two boolean masks: for each vector, whether it lies inside
        # any hypersphere of the other batch (given the other batch's radii).
        return self.session.run(
            [self._batch_1_in, self._batch_2_in],
            feed_dict={
                self._features_batch1: batch_1,
                self._features_batch2: batch_2,
                self._radii1: radii_1,
                self._radii2: radii_2,
            },
        )
def _batch_pairwise_distances(U, V):
    """
    Compute pairwise distances between two batches of feature vectors.
    """
    # Returns squared Euclidean distances via the ||u||^2 - 2u.v + ||v||^2
    # expansion, clamped at zero to absorb floating-point cancellation.
    with tf.variable_scope("pairwise_dist_block"):
        # Squared norms of each row in U and V.
        norm_u = tf.reduce_sum(tf.square(U), 1)
        norm_v = tf.reduce_sum(tf.square(V), 1)

        # norm_u as a column and norm_v as a row vectors.
        norm_u = tf.reshape(norm_u, [-1, 1])
        norm_v = tf.reshape(norm_v, [1, -1])

        # Pairwise squared Euclidean distances.
        D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0)

    return D
class NpzArrayReader(ABC):
    """Abstract reader that yields an npz array in fixed-size batches."""

    @abstractmethod
    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        """Return the next batch of up to ``batch_size`` rows, or None at EOF."""

    @abstractmethod
    def remaining(self) -> int:
        """Number of rows not yet consumed."""

    def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
        """Wrap ``read_batch`` into a sized iterable over all remaining rows."""

        def _drain():
            while True:
                chunk = self.read_batch(batch_size)
                if chunk is None:
                    return
                yield chunk

        # Ceiling division: a final partial batch still counts.
        total_batches = -(-self.remaining() // batch_size)
        return BatchIterator(_drain, total_batches)
class BatchIterator:
    """An iterable with a known length, backed by a generator factory."""

    def __init__(self, gen_fn, length):
        # gen_fn: zero-argument callable producing a fresh iterator per pass.
        # length: number of items gen_fn will yield, reported by __len__.
        self.gen_fn = gen_fn
        self.length = length

    def __len__(self):
        return self.length

    def __iter__(self):
        # Delegate to the factory so each iteration starts a new generator.
        return self.gen_fn()
class StreamingNpzArrayReader(NpzArrayReader):
    """Stream batches straight from an open ``.npy`` member without loading it all."""

    def __init__(self, arr_f, shape, dtype):
        # arr_f: file-like object positioned just past the .npy header.
        self.arr_f = arr_f
        self.shape = shape
        self.dtype = dtype
        self.idx = 0

    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        """Read and decode the next batch of rows, or return None at EOF."""
        total_rows = self.shape[0]
        if self.idx >= total_rows:
            return None

        rows = min(batch_size, total_rows - self.idx)
        self.idx += rows

        # Zero-itemsize dtypes carry no bytes on disk; fabricate the array.
        if self.dtype.itemsize == 0:
            return np.ndarray([rows, *self.shape[1:]], dtype=self.dtype)

        item_count = rows * np.prod(self.shape[1:])
        raw = _read_bytes(self.arr_f, int(item_count * self.dtype.itemsize), "array data")
        return np.frombuffer(raw, dtype=self.dtype).reshape([rows, *self.shape[1:]])

    def remaining(self) -> int:
        return max(0, self.shape[0] - self.idx)
class MemoryNpzArrayReader(NpzArrayReader):
    """Serve batches out of an array already held fully in memory."""

    def __init__(self, arr):
        self.arr = arr
        self.idx = 0

    @classmethod
    def load(cls, path: str, arr_name: str):
        """Load ``arr_name`` from the npz at ``path`` into memory."""
        with open(path, "rb") as f:
            arr = np.load(f)[arr_name]
        return cls(arr)

    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        if self.idx >= self.arr.shape[0]:
            return None

        start = self.idx
        self.idx = start + batch_size
        return self.arr[start : start + batch_size]

    def remaining(self) -> int:
        return max(0, self.arr.shape[0] - self.idx)
@contextmanager
def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
    """
    Yield a batch reader for array ``arr_name`` inside the npz at ``path``.

    Streams directly from disk when the .npy header allows it; otherwise
    falls back to loading the whole array into memory.
    """
    with _open_npy_file(path, arr_name) as arr_f:
        version = np.lib.format.read_magic(arr_f)
        header_reader = {
            (1, 0): np.lib.format.read_array_header_1_0,
            (2, 0): np.lib.format.read_array_header_2_0,
        }.get(version)
        if header_reader is None:
            # Unknown .npy format version: punt to a full in-memory load.
            yield MemoryNpzArrayReader.load(path, arr_name)
            return
        shape, is_fortran, dtype = header_reader(arr_f)
        if is_fortran or dtype.hasobject:
            # Fortran-ordered or object arrays can't be streamed row-wise.
            yield MemoryNpzArrayReader.load(path, arr_name)
        else:
            yield StreamingNpzArrayReader(arr_f, shape, dtype)
def _read_bytes(fp, size, error_template="ran out of data"):
    """
    Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886

    Read from file-like object until size bytes are read.
    Raises ValueError if not EOF is encountered before size bytes are read.
    Non-blocking objects only supported if they derive from io objects.
    Required as e.g. ZipExtFile in python 2.6 can return less data than
    requested.
    """
    collected = b""
    while len(collected) < size:
        # io files (default in python3) return None or raise on
        # would-block; regular files can't be non-blocking.
        try:
            piece = fp.read(size - len(collected))
        except io.BlockingIOError:
            continue
        collected += piece
        if not piece:
            break  # EOF before we had enough bytes
    if len(collected) == size:
        return collected
    raise ValueError(
        "EOF: reading %s, expected %d bytes got %d" % (error_template, size, len(collected))
    )
@contextmanager
def _open_npy_file(path: str, arr_name: str):
    """Open the ``{arr_name}.npy`` member of the npz (zip) archive at ``path``."""
    member = f"{arr_name}.npy"
    with open(path, "rb") as f:
        with zipfile.ZipFile(f, "r") as zip_f:
            if member not in zip_f.namelist():
                raise ValueError(f"missing {arr_name} in npz file")
            with zip_f.open(member, "r") as arr_f:
                yield arr_f
def _download_inception_model():
    """Download the InceptionV3 graph to INCEPTION_V3_PATH unless already present."""
    if os.path.exists(INCEPTION_V3_PATH):
        return
    print("downloading InceptionV3 model...")
    # Stream into a temp file and rename atomically so a partial download
    # never masquerades as the finished model.
    partial_path = INCEPTION_V3_PATH + ".tmp"
    with requests.get(INCEPTION_V3_URL, stream=True) as resp:
        resp.raise_for_status()
        with open(partial_path, "wb") as out_f:
            for chunk in tqdm(resp.iter_content(chunk_size=8192)):
                out_f.write(chunk)
    os.rename(partial_path, INCEPTION_V3_PATH)
def _create_feature_graph(input_batch):
    """Import InceptionV3 and return (pool_3, spatial) feature tensors for ``input_batch``."""
    _download_inception_model()
    # Random prefix so repeated imports into the same graph don't collide.
    prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
    with open(INCEPTION_V3_PATH, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    pool3, spatial = tf.import_graph_def(
        graph_def,
        input_map={f"ExpandDims:0": input_batch},
        return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME],
        name=prefix,
    )
    # Make the batch dimension dynamic on the imported ops.
    _update_shapes(pool3)
    # Only the first 7 spatial channels are used downstream (sFID convention).
    spatial = spatial[..., :7]
    return pool3, spatial
def _create_softmax_graph(input_batch):
    """Build a softmax over class logits computed from pool activations."""
    _download_inception_model()
    # Random prefix so repeated imports into the same graph don't collide.
    prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
    with open(INCEPTION_V3_PATH, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    (matmul,) = tf.import_graph_def(
        graph_def, return_elements=[f"softmax/logits/MatMul"], name=prefix
    )
    # Reuse the pretrained classifier weight matrix from the imported MatMul op.
    w = matmul.inputs[1]
    logits = tf.matmul(input_batch, w)
    return tf.nn.softmax(logits)
def _update_shapes(pool3):
    # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63
    """Rewrite static shapes in the imported graph so the batch dim is dynamic."""
    ops = pool3.graph.get_operations()
    for op in ops:
        for o in op.outputs:
            shape = o.get_shape()
            if shape._dims is not None:  # pylint: disable=protected-access
                # shape = [s.value for s in shape] TF 1.x
                shape = [s for s in shape]  # TF 2.x
                new_shape = []
                for j, s in enumerate(shape):
                    # Replace a hard-coded batch size of 1 with None (variable).
                    if s == 1 and j == 0:
                        new_shape.append(None)
                    else:
                        new_shape.append(s)
                # NOTE(review): writes TF's private _shape_val directly —
                # fragile across TF versions; confirm when upgrading TF.
                o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
    return pool3
def _numpy_partition(arr, kth, **kwargs):
    """
    Apply ``np.partition`` to chunks of ``arr`` in parallel worker threads.

    :param arr: sequence/array split row-wise across workers.
    :param kth: passed through to ``np.partition``.
    :return: a list of partitioned chunks; concatenating them restores
        the row order of ``arr``.
    """
    # Guard: the original code divided by num_workers and created a
    # ThreadPool(0) for empty input, both of which raise.
    if not len(arr):
        return []

    num_workers = min(cpu_count(), len(arr))
    chunk_size = len(arr) // num_workers
    extra = len(arr) % num_workers

    # Split into num_workers chunks; the first `extra` chunks get one extra row.
    start_idx = 0
    batches = []
    for i in range(num_workers):
        size = chunk_size + (1 if i < extra else 0)
        batches.append(arr[start_idx : start_idx + size])
        start_idx += size

    with ThreadPool(num_workers) as pool:
        return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))
# Script entry point; main() is defined earlier in this file.
if __name__ == "__main__":
    main()
f16c16/graph-data.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
import numpy as np


def _load_metrics(path):
    """
    Parse one results log and collect its metric series.

    Relevant lines look like ``Mean L2: 0.12`` or ``Mean Lpips: 0.34``; the
    value after the colon is appended in order of appearance. Every line is
    echoed to stdout (matching the original script's behavior).

    :param path: path to the text log to parse.
    :return: (mean_l2, mean_lpips) lists of floats.
    """
    mean_l2, mean_lpips = [], []
    with open(path, "r") as f:
        print("read")
        for line in f.readlines():
            print(line)
            if "Mean L2" in line:
                mean_l2.append(float(line.split(":")[1].strip()))
            elif "Mean Lpips" in line:
                mean_lpips.append(float(line.split(":")[1].strip()))
    return mean_l2, mean_lpips


# Noise levels on the x-axis: 0.00, 0.01, ..., 0.99.
noises = [float(number) for number in np.arange(0.00, 1.0, 0.01)]

# One (L2, LPIPS) series per experiment log; the copy-pasted reader blocks
# were consolidated into _load_metrics.
mean_l2, mean_lpips = _load_metrics("./1e-4.txt")
mean_l2_2, mean_lpips_2 = _load_metrics("./1e-5.txt")
mean_l2_3, mean_lpips_3 = _load_metrics("./2e-5.txt")
mean_l2_4, mean_lpips_4 = _load_metrics("./1e-6.txt")
mean_l2_5, mean_lpips_5 = _load_metrics("./pl600.txt")
mean_l2_6, mean_lpips_6 = _load_metrics("./100pl.txt")
mean_l2_7, mean_lpips_7 = _load_metrics("./300pl.txt")
mean_l2_8, mean_lpips_8 = _load_metrics("./1e-6_asym.txt")

plt.figure(figsize=(10, 6))

# Truncate every series to the first `do` noise levels.
do = 100
mean_lpips = mean_lpips[0:do]
mean_lpips_2 = mean_lpips_2[0:do]
mean_lpips_3 = mean_lpips_3[0:do]
mean_lpips_4 = mean_lpips_4[0:do]
mean_lpips_5 = mean_lpips_5[0:do]
mean_lpips_6 = mean_lpips_6[0:do]
mean_lpips_7 = mean_lpips_7[0:do]
mean_lpips_8 = mean_lpips_8[0:do]
noises = noises[0:do]

# Plot Mean Lpips (the Mean L2 plots are intentionally disabled).
plt.plot(noises, mean_lpips, label='Mean Lpips 1e-4', marker='s', linestyle='--', color='r')
plt.plot(noises, mean_lpips_3, label='Mean Lpips 2e-5', marker='s', linestyle='--', color='b')
plt.plot(noises, mean_lpips_2, label='Mean Lpips 1e-5', marker='s', linestyle='--', color='g')
plt.plot(noises, mean_lpips_4, label='Mean Lpips 1e-6', marker='s', linestyle='--', color='y')
plt.plot(noises, mean_lpips_8, label='Mean Lpips 1e-6asym', marker='s', linestyle='--')
# plt.plot(noises, mean_lpips_5, label='Mean Lpips PL600', marker='s', linestyle='--')
# plt.plot(noises, mean_lpips_6, label='Mean Lpips Pl100', marker='s', linestyle='--')
plt.plot(noises, mean_lpips_7, label='Mean Lpips Pl300', marker='s', linestyle='--')

# Labels and title
plt.xlabel('Noise Level')
plt.ylabel('Value')
plt.title('Mean L2 and Mean Lpips vs. Noise Level')

# Show grid
plt.grid(True)

# Add legend
plt.legend()

# Show the plot
plt.show()
f16c16/kl_test.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import jax
import jax.numpy as jnp


# Scratch experiment: compare the KL term computed from per-sample means
# against the element-wise version, with unit variance (logvars == 0).
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (2, 32, 32, 4))
print(x.mean())

# Per-sample means over the spatial/channel axes.
means = jnp.mean(x, axis=[1, 2, 3])
# So this gives us the means of each individual one, cool
print(means)

logvars = 0.0

print("square of means shit", jnp.square(means))
print(means)

# KL using only the per-sample means (reduces over no axes: means is 1-D).
reduce_axes = tuple(range(1, means.ndim))
kl_loss = -0.5 * jnp.sum(1 + logvars - jnp.square(means) - jnp.exp(logvars), axis=reduce_axes)
print(kl_loss)
print(jnp.mean(kl_loss))

print("x mean again", x.mean())
print(x)
print(jnp.square(x))

# KL computed element-wise over each sample, then averaged across the batch.
kl_loss = -0.5 * jnp.sum(
    1 + logvars - jnp.square(x) - jnp.exp(logvars), axis=tuple(range(1, x.ndim))
)
print(kl_loss)
print(jnp.mean(kl_loss))
f16c16/latent_distances.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
try: # For debugging
|
| 2 |
+
from localutils.debugger import enable_debug
|
| 3 |
+
enable_debug()
|
| 4 |
+
except ImportError:
|
| 5 |
+
pass
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
#import jax
|
| 9 |
+
#jax.config.update('jax_platform_name', 'cpu')
|
| 10 |
+
import os
|
| 11 |
+
# os.environ["JAX_PLATFORMS"] = 'cpu'
|
| 12 |
+
import jax
|
| 13 |
+
import lpips
|
| 14 |
+
|
| 15 |
+
loss_fn_alex = lpips.LPIPS(net='alex') # best forward scores
|
| 16 |
+
loss_fn_alex = loss_fn_alex.cuda()
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import flax.linen as nn
|
| 21 |
+
import jax.numpy as jnp
|
| 22 |
+
from absl import app, flags
|
| 23 |
+
from functools import partial
|
| 24 |
+
import numpy as np
|
| 25 |
+
import tqdm
|
| 26 |
+
import flax
|
| 27 |
+
import optax
|
| 28 |
+
import wandb
|
| 29 |
+
from ml_collections import config_flags
|
| 30 |
+
#import elements
|
| 31 |
+
import ml_collections
|
| 32 |
+
import tensorflow_datasets as tfds
|
| 33 |
+
import tensorflow as tf
|
| 34 |
+
tf.config.set_visible_devices([], "GPU")
|
| 35 |
+
tf.config.set_visible_devices([], "TPU")
|
| 36 |
+
import matplotlib.pyplot as plt
|
| 37 |
+
from typing import Any
|
| 38 |
+
|
| 39 |
+
from utils.train_state import TrainState, target_update
|
| 40 |
+
from utils.checkpoint import Checkpoint
|
| 41 |
+
from utils.fid import get_fid_network, fid_from_stats
|
| 42 |
+
|
| 43 |
+
from train import VQGANModel
|
| 44 |
+
from models.vqvae import VQVAE
|
| 45 |
+
from models.discriminator import Discriminator
|
| 46 |
+
|
| 47 |
+
from PIL import Image
|
| 48 |
+
import torch
|
| 49 |
+
|
| 50 |
+
delattr(flags.FLAGS, 'dataset_name')
|
| 51 |
+
delattr(flags.FLAGS, 'load_dir')
|
| 52 |
+
delattr(flags.FLAGS, 'batch_size')
|
| 53 |
+
|
| 54 |
+
FLAGS = flags.FLAGS
|
| 55 |
+
flags.DEFINE_string('dataset_name', 'imagenet256', 'Environment name.')
|
| 56 |
+
flags.DEFINE_string('load_dir', "/home/dkaplan/Downloads/Models/checkpoint(1).tmp", 'Load dir (if not None, load params from here).')
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
flags.DEFINE_integer('batch_size', 2, 'Total Batch size.')
|
| 60 |
+
# Flags are inhereited from train.py, so pass your model parameters again here to evaluate.
|
| 61 |
+
|
| 62 |
+
import gc
|
| 63 |
+
|
| 64 |
+
def main(_):
|
| 65 |
+
device_count = len(jax.local_devices())
|
| 66 |
+
global_device_count = jax.device_count()
|
| 67 |
+
local_batch_size = FLAGS.batch_size // (global_device_count // device_count)
|
| 68 |
+
|
| 69 |
+
def get_dataset(is_train):
|
| 70 |
+
if 'imagenet' in FLAGS.dataset_name:
|
| 71 |
+
def deserialization_fn(data):
|
| 72 |
+
image = data['image']
|
| 73 |
+
min_side = tf.minimum(tf.shape(image)[0], tf.shape(image)[1])
|
| 74 |
+
image = tf.image.resize_with_crop_or_pad(image, min_side, min_side)
|
| 75 |
+
if 'imagenet256' in FLAGS.dataset_name:
|
| 76 |
+
image = tf.image.resize(image, (256, 256))
|
| 77 |
+
elif 'imagenet128' in FLAGS.dataset_name:
|
| 78 |
+
image = tf.image.resize(image, (128, 128))
|
| 79 |
+
else:
|
| 80 |
+
raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
|
| 81 |
+
if is_train:
|
| 82 |
+
image = tf.image.random_flip_left_right(image)
|
| 83 |
+
image = tf.cast(image, tf.float32) / 255.0
|
| 84 |
+
return image
|
| 85 |
+
|
| 86 |
+
split = tfds.split_for_jax_process('train' if is_train else 'validation', drop_remainder=True)
|
| 87 |
+
dataset = tfds.load('imagenet2012', data_dir="/data/inet", split=split)
|
| 88 |
+
dataset = dataset.map(deserialization_fn, num_parallel_calls=tf.data.AUTOTUNE)
|
| 89 |
+
dataset = dataset.shuffle(10000, seed=42, reshuffle_each_iteration=True)
|
| 90 |
+
dataset = dataset.batch(local_batch_size)
|
| 91 |
+
dataset = dataset.prefetch(tf.data.AUTOTUNE)
|
| 92 |
+
dataset = tfds.as_numpy(dataset)
|
| 93 |
+
dataset = iter(dataset)
|
| 94 |
+
return dataset
|
| 95 |
+
else:
|
| 96 |
+
raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
|
| 97 |
+
|
| 98 |
+
dataset = get_dataset(is_train=True)
|
| 99 |
+
dataset_valid = get_dataset(is_train=False)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# image = Image.open("osman.png")
|
| 103 |
+
# image = np.array(image) / 255.0
|
| 104 |
+
# print(image)
|
| 105 |
+
# image = jnp.array(image)
|
| 106 |
+
# image = jnp.expand_dims(image, 0)
|
| 107 |
+
# image = jnp.expand_dims(image, 0)
|
| 108 |
+
|
| 109 |
+
example_obs = next(dataset)[:1]
|
| 110 |
+
|
| 111 |
+
#Reconstruction loop
|
| 112 |
+
# image = model.reconstruction(image)
|
| 113 |
+
# image = image[0,0,:,:,:]
|
| 114 |
+
# image = (image * 255).astype(np.uint8)
|
| 115 |
+
# image = np.array(image)
|
| 116 |
+
# img = Image.fromarray(image)
|
| 117 |
+
# img.save("osman" + str(i) + ".png")
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
rng = jax.random.PRNGKey(FLAGS.seed)
|
| 121 |
+
rng, param_key = jax.random.split(rng)
|
| 122 |
+
print("Total devices", jax.local_devices()[0])
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
###################################
|
| 126 |
+
# Creating Model and put on devices.
|
| 127 |
+
###################################
|
| 128 |
+
FLAGS.model.image_channels = example_obs.shape[-1]
|
| 129 |
+
FLAGS.model.image_size = example_obs.shape[1]
|
| 130 |
+
vqvae_def = VQVAE(FLAGS.model, train=True)
|
| 131 |
+
vqvae_params = vqvae_def.init({'params': param_key, 'noise': param_key}, example_obs)['params']
|
| 132 |
+
# tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
|
| 133 |
+
vqvae_ts = TrainState.create(vqvae_def, vqvae_params)#, tx=tx) #Turning off tx because we don't need it...
|
| 134 |
+
vqvae_def_eps = VQVAE(FLAGS.model, train=False)
|
| 135 |
+
vqvae_eps_ts = TrainState.create(vqvae_def_eps, vqvae_params)
|
| 136 |
+
print("Total num of VQVAE parameters:", sum(x.size for x in jax.tree_util.tree_leaves(vqvae_params)))
|
| 137 |
+
|
| 138 |
+
discriminator_def = Discriminator(FLAGS.model)
|
| 139 |
+
discriminator_params = discriminator_def.init(param_key, example_obs)['params']
|
| 140 |
+
# tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
|
| 141 |
+
discriminator_ts = TrainState.create(discriminator_def, discriminator_params)#, tx=tx)#No tx again
|
| 142 |
+
print("Total num of Discriminator parameters:", sum(x.size for x in jax.tree_util.tree_leaves(discriminator_params)))
|
| 143 |
+
|
| 144 |
+
model = VQGANModel(rng=rng, vqvae=vqvae_ts, vqvae_eps=vqvae_eps_ts, discriminator=discriminator_ts, config=FLAGS.model)
|
| 145 |
+
|
| 146 |
+
assert FLAGS.load_dir is not None
|
| 147 |
+
cp = Checkpoint(FLAGS.load_dir)
|
| 148 |
+
model = cp.load_model(model)
|
| 149 |
+
print("Loaded model with step", model.vqvae.step)
|
| 150 |
+
|
| 151 |
+
model = flax.jax_utils.replicate(model, devices=jax.local_devices())
|
| 152 |
+
jax.debug.visualize_array_sharding(model.vqvae.params['decoder']['Conv_0']['bias'])
|
| 153 |
+
#print(model.vqvae)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
####################################
|
| 157 |
+
# Noise stuff
|
| 158 |
+
###################################
|
| 159 |
+
|
| 160 |
+
cpus = jax.devices("cpu")
|
| 161 |
+
|
| 162 |
+
#So there are a few ways to calculate PPL here
|
| 163 |
+
#We could take two images in image space
|
| 164 |
+
#Walk between them and check the LPIPS in the output space
|
| 165 |
+
#...actually that's basically it right?
|
| 166 |
+
#We could also do the walk in latent space, which is the same, but with ?? scaling
|
| 167 |
+
|
| 168 |
+
#Let's see if they are any different.
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
#We could also try taking a latent, going X/2 direction, and -X/2 direction, and seeing that.
|
| 172 |
+
i = 0
|
| 173 |
+
lpips_list = []
|
| 174 |
+
means = []
|
| 175 |
+
stds = []
|
| 176 |
+
for valid_images in dataset_valid:
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
valid_images = valid_images.reshape((len(jax.local_devices()), -1, *valid_images.shape[1:])) # [devices, batch//devices, etc..]
|
| 180 |
+
#1, 2, 256, 256, 3
|
| 181 |
+
#Given our 2 images, we want to lerp between them...
|
| 182 |
+
#We want to lerp once to point t, and once to point t + eps
|
| 183 |
+
#And then we want to get the LPIPS between those two images
|
| 184 |
+
#And then we calculate LPIPS
|
| 185 |
+
#And then we divide by eps squared, and done.
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
reconstructed_images, decoded, std, latents = model.latent_distances(valid_images) # [devices, 8, 256, 256, 3]
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
means.append(latents.mean())
|
| 192 |
+
stds.append(latents.std())
|
| 193 |
+
# print("std", std.mean())
|
| 194 |
+
print("latent mean", latents.mean())
|
| 195 |
+
print("actual latent std", latents.std())
|
| 196 |
+
|
| 197 |
+
#Need to change images back to -1,1
|
| 198 |
+
|
| 199 |
+
reconstructed_images = reconstructed_images * 2 - 1
|
| 200 |
+
decoded = decoded * 2 -1
|
| 201 |
+
|
| 202 |
+
#1,2,256,256,3
|
| 203 |
+
reconstructed_images = jnp.swapaxes(reconstructed_images, 0, 4)
|
| 204 |
+
decoded = jnp.swapaxes(decoded, 0, 4)
|
| 205 |
+
|
| 206 |
+
reconstructed_images = jnp.swapaxes(reconstructed_images, 0, 1)
|
| 207 |
+
decoded = jnp.swapaxes(decoded, 0, 1)
|
| 208 |
+
|
| 209 |
+
reconstructed_images = jnp.squeeze(reconstructed_images)
|
| 210 |
+
decoded = jnp.squeeze(decoded)
|
| 211 |
+
|
| 212 |
+
#So here, we want to put them on CPU and delete the original
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
image_np = np.asarray(reconstructed_images)
|
| 216 |
+
image_np_2 = torch.from_numpy(np.copy(image_np)).cuda()
|
| 217 |
+
|
| 218 |
+
decoded_np = np.asarray(decoded)
|
| 219 |
+
decoded_np_2 = torch.from_numpy(np.copy(decoded_np)).cuda()
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
lpips_loss = loss_fn_alex(image_np_2, decoded_np_2)
|
| 224 |
+
lpips_cpu = lpips_loss.detach().cpu().squeeze().mean()
|
| 225 |
+
lpips_cpu = lpips_cpu / (.0001 ** 2)
|
| 226 |
+
|
| 227 |
+
print(lpips_cpu)
|
| 228 |
+
lpips_list.append(lpips_cpu)
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
i += 1
|
| 232 |
+
#
|
| 233 |
+
if i == 500:
|
| 234 |
+
break
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
mean_lpips = jnp.mean(jnp.asarray(lpips_list))
|
| 238 |
+
print(mean_lpips)
|
| 239 |
+
print("mean of means", jnp.asarray(means).mean())
|
| 240 |
+
print("stds of means", jnp.asarray(means).std())
|
| 241 |
+
print("mean of stds", jnp.asarray(stds).mean())
|
| 242 |
+
print("std of stds", jnp.asarray(stds).std())
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
#actual ae sym
|
| 246 |
+
# mean of means 0.35234922
|
| 247 |
+
# stds of means 0.4036692
|
| 248 |
+
# mean of stds 2.6363409
|
| 249 |
+
# std of stds 0.30666474
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
#1e-6:
|
| 253 |
+
#mean of means -0.018107202
|
| 254 |
+
# stds of means 0.11694455
|
| 255 |
+
# mean of stds 1.0860059
|
| 256 |
+
# std of stds 0.09732369
|
| 257 |
+
|
| 258 |
+
#1e-5:
|
| 259 |
+
# mean of means 0.0065166513
|
| 260 |
+
# stds of means 0.06983645
|
| 261 |
+
# mean of stds 0.9855982
|
| 262 |
+
# std of stds 0.05810356
|
| 263 |
+
|
| 264 |
+
#1e-4:
|
| 265 |
+
# mean of means 0.0065882676
|
| 266 |
+
# stds of means 0.042861093
|
| 267 |
+
# mean of stds 0.7608507
|
| 268 |
+
# std of stds 0.05846726
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
#pl300
|
| 272 |
+
# mean of means 0.090131655
|
| 273 |
+
# stds of means 0.69894844
|
| 274 |
+
# mean of stds 5.5634923
|
| 275 |
+
# std of stds 0.6767279
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
#pl100
|
| 279 |
+
# mean of means 0.16227543
|
| 280 |
+
# stds of means 0.53616405
|
| 281 |
+
# mean of stds 4.4914503
|
| 282 |
+
# std of stds 0.6015057
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
#Maybe we want to do "std multiplied PPL"? smoo
|
| 287 |
+
|
| 288 |
+
#Grab the STD of the Lpips
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
if __name__ == '__main__':
|
| 293 |
+
app.run(main)
|
f16c16/make_samples.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
try: # For debugging
|
| 2 |
+
from localutils.debugger import enable_debug
|
| 3 |
+
enable_debug()
|
| 4 |
+
except ImportError:
|
| 5 |
+
pass
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
#import jax
|
| 9 |
+
#jax.config.update('jax_platform_name', 'cpu')
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
# os.environ["JAX_PLATFORMS"] = 'cpu'
|
| 13 |
+
|
| 14 |
+
import jax
|
| 15 |
+
|
| 16 |
+
import flax.linen as nn
|
| 17 |
+
import jax.numpy as jnp
|
| 18 |
+
from absl import app, flags
|
| 19 |
+
from functools import partial
|
| 20 |
+
import numpy as np
|
| 21 |
+
import tqdm
|
| 22 |
+
import flax
|
| 23 |
+
import optax
|
| 24 |
+
import wandb
|
| 25 |
+
from ml_collections import config_flags
|
| 26 |
+
#import elements
|
| 27 |
+
import ml_collections
|
| 28 |
+
import tensorflow_datasets as tfds
|
| 29 |
+
import tensorflow as tf
|
| 30 |
+
tf.config.set_visible_devices([], "GPU")
|
| 31 |
+
tf.config.set_visible_devices([], "TPU")
|
| 32 |
+
import matplotlib.pyplot as plt
|
| 33 |
+
from typing import Any
|
| 34 |
+
|
| 35 |
+
from utils.train_state import TrainState, target_update
|
| 36 |
+
from utils.checkpoint import Checkpoint
|
| 37 |
+
from utils.fid import get_fid_network, fid_from_stats
|
| 38 |
+
|
| 39 |
+
from train import VQGANModel
|
| 40 |
+
from models.vqvae import VQVAE
|
| 41 |
+
from models.discriminator import Discriminator
|
| 42 |
+
|
| 43 |
+
from PIL import Image
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
delattr(flags.FLAGS, 'dataset_name')
|
| 47 |
+
delattr(flags.FLAGS, 'load_dir')
|
| 48 |
+
delattr(flags.FLAGS, 'batch_size')
|
| 49 |
+
|
| 50 |
+
FLAGS = flags.FLAGS
|
| 51 |
+
flags.DEFINE_string('dataset_name', 'imagenet256', 'Environment name.')
|
| 52 |
+
flags.DEFINE_string('load_dir', "/home/dkaplan/Documents/LiClipse Workspace/VAE/jax-vqvae-vqgan/7e-5_sdlike_sym/checkpoint.tmp", 'Load dir (if not None, load params from here).')
|
| 53 |
+
flags.DEFINE_integer('batch_size', 16, 'Total Batch size.')
|
| 54 |
+
# Flags are inhereited from train.py, so pass your model parameters again here to evaluate.
|
| 55 |
+
|
| 56 |
+
def main(_):
|
| 57 |
+
device_count = len(jax.local_devices())
|
| 58 |
+
global_device_count = jax.device_count()
|
| 59 |
+
local_batch_size = FLAGS.batch_size // (global_device_count // device_count)
|
| 60 |
+
|
| 61 |
+
def get_dataset(is_train):
|
| 62 |
+
if 'imagenet' in FLAGS.dataset_name:
|
| 63 |
+
def deserialization_fn(data):
|
| 64 |
+
image = data['image']
|
| 65 |
+
min_side = tf.minimum(tf.shape(image)[0], tf.shape(image)[1])
|
| 66 |
+
image = tf.image.resize_with_crop_or_pad(image, min_side, min_side)
|
| 67 |
+
if 'imagenet256' in FLAGS.dataset_name:
|
| 68 |
+
image = tf.image.resize(image, (256, 256))
|
| 69 |
+
elif 'imagenet128' in FLAGS.dataset_name:
|
| 70 |
+
image = tf.image.resize(image, (128, 128))
|
| 71 |
+
else:
|
| 72 |
+
raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
|
| 73 |
+
if is_train:
|
| 74 |
+
image_flip = tf.image.flip_left_right(image)
|
| 75 |
+
image_flip = tf.cast(image_flip, tf.float32) / 255.0
|
| 76 |
+
image = tf.cast(image, tf.float32) / 255.0
|
| 77 |
+
return image, image_flip, data["label"]
|
| 78 |
+
image = tf.cast(image, tf.float32) / 255.0
|
| 79 |
+
return image
|
| 80 |
+
|
| 81 |
+
split = tfds.split_for_jax_process('train' if is_train else 'validation', drop_remainder=True)
|
| 82 |
+
dataset = tfds.load('imagenet2012', data_dir="/data/inet", split=split)
|
| 83 |
+
dataset = dataset.map(deserialization_fn, num_parallel_calls=tf.data.AUTOTUNE)
|
| 84 |
+
dataset = dataset.shuffle(10000, seed=42, reshuffle_each_iteration=True)
|
| 85 |
+
dataset = dataset.batch(local_batch_size)
|
| 86 |
+
dataset = dataset.prefetch(tf.data.AUTOTUNE)
|
| 87 |
+
dataset = tfds.as_numpy(dataset)
|
| 88 |
+
dataset = iter(dataset)
|
| 89 |
+
return dataset
|
| 90 |
+
else:
|
| 91 |
+
raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
|
| 92 |
+
|
| 93 |
+
dataset = get_dataset(is_train=True)
|
| 94 |
+
dataset_valid = get_dataset(is_train=False)
|
| 95 |
+
|
| 96 |
+
example_obs = next(dataset)[0][:1]
|
| 97 |
+
|
| 98 |
+
get_fid_activations = get_fid_network()
|
| 99 |
+
truth_fid_stats = np.load('data/imagenet256_fidstats_openai.npz')
|
| 100 |
+
|
| 101 |
+
rng = jax.random.PRNGKey(FLAGS.seed)
|
| 102 |
+
rng, param_key = jax.random.split(rng)
|
| 103 |
+
print("Total devices", jax.local_devices()[0])
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
###################################
|
| 107 |
+
# Creating Model and put on devices.
|
| 108 |
+
###################################
|
| 109 |
+
FLAGS.model.image_channels = example_obs.shape[-1]
|
| 110 |
+
FLAGS.model.image_size = example_obs.shape[1]
|
| 111 |
+
vqvae_def = VQVAE(FLAGS.model, train=True)
|
| 112 |
+
vqvae_params = vqvae_def.init({'params': param_key, 'noise': param_key}, example_obs)['params']
|
| 113 |
+
tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
|
| 114 |
+
vqvae_ts = TrainState.create(vqvae_def, vqvae_params, tx=tx)
|
| 115 |
+
vqvae_def_eps = VQVAE(FLAGS.model, train=False)
|
| 116 |
+
vqvae_eps_ts = TrainState.create(vqvae_def_eps, vqvae_params)
|
| 117 |
+
print("Total num of VQVAE parameters:", sum(x.size for x in jax.tree_util.tree_leaves(vqvae_params)))
|
| 118 |
+
|
| 119 |
+
discriminator_def = Discriminator(FLAGS.model)
|
| 120 |
+
discriminator_params = discriminator_def.init(param_key, example_obs)['params']
|
| 121 |
+
tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
|
| 122 |
+
discriminator_ts = TrainState.create(discriminator_def, discriminator_params, tx=tx)
|
| 123 |
+
print("Total num of Discriminator parameters:", sum(x.size for x in jax.tree_util.tree_leaves(discriminator_params)))
|
| 124 |
+
|
| 125 |
+
model = VQGANModel(rng=rng, vqvae=vqvae_ts, vqvae_eps=vqvae_eps_ts, discriminator=discriminator_ts, config=FLAGS.model)
|
| 126 |
+
|
| 127 |
+
assert FLAGS.load_dir is not None
|
| 128 |
+
cp = Checkpoint(FLAGS.load_dir)
|
| 129 |
+
model = cp.load_model(model)
|
| 130 |
+
print("Loaded model with step", model.vqvae.step)
|
| 131 |
+
|
| 132 |
+
model = flax.jax_utils.replicate(model, devices=jax.local_devices())
|
| 133 |
+
jax.debug.visualize_array_sharding(model.vqvae.params['decoder']['Conv_0']['bias'])
|
| 134 |
+
#print(model.vqvae)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
####################################
|
| 138 |
+
# FID Evaluation.
|
| 139 |
+
###################################
|
| 140 |
+
|
| 141 |
+
i = 0
|
| 142 |
+
for valid_images, image_flip, label in dataset:#dataset_valid:
|
| 143 |
+
|
| 144 |
+
valid_images = valid_images.reshape((len(jax.local_devices()), -1, *valid_images.shape[1:])) # [devices, batch//devices, etc..]
|
| 145 |
+
valid_reconstructed_images = model.reconstruction(valid_images) # [devices, 8, 256, 256, 3]
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
#load up custom image
|
| 149 |
+
# image = Image.open("osman.png")
|
| 150 |
+
# image = np.array(image) / 255.0
|
| 151 |
+
# print(image)
|
| 152 |
+
# image = jnp.array(image)
|
| 153 |
+
# image = jnp.expand_dims(image, 0)
|
| 154 |
+
# image = jnp.expand_dims(image, 0)
|
| 155 |
+
#Try saving the image off the bat
|
| 156 |
+
# image_orig =
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
# image = model.reconstruction(image)
|
| 160 |
+
# image = image[0,0,:,:,:]
|
| 161 |
+
# image = (image * 255).astype(np.uint8)
|
| 162 |
+
# image = np.array(image)
|
| 163 |
+
# img = Image.fromarray(image)
|
| 164 |
+
# img.save("osman" + str(i) + ".png")
|
| 165 |
+
# exit()
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
#Whatever...
|
| 169 |
+
#top left mine
|
| 170 |
+
#Bottom right SD
|
| 171 |
+
|
| 172 |
+
# fig, axs = plt.subplots(2, 2, figsize=(30, 15))
|
| 173 |
+
|
| 174 |
+
# axs[0, 0].imshow(valid_images[0, 0], vmin=0, vmax=1)
|
| 175 |
+
# axs[1, 0].imshow(valid_reconstructed_images[0, 0], vmin=0, vmax=1)
|
| 176 |
+
# axs[0, 1].imshow
|
| 177 |
+
|
| 178 |
+
# plt.savefig("img.jpg")
|
| 179 |
+
|
| 180 |
+
image = valid_images[0,0,:,:,:]
|
| 181 |
+
image = (image * 255).astype(np.uint8)
|
| 182 |
+
img = Image.fromarray(image)
|
| 183 |
+
img.save("original" + str(i) + ".png")
|
| 184 |
+
|
| 185 |
+
image2 = valid_reconstructed_images[0,0,:,:,:]
|
| 186 |
+
image2 = (image2 * 255).astype(np.uint8)
|
| 187 |
+
image2 = np.array(image2)
|
| 188 |
+
image2 = Image.fromarray(image2)
|
| 189 |
+
|
| 190 |
+
image2.save("recon" + str(i) + ".png")
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
i += 1
|
| 194 |
+
|
| 195 |
+
if i == 6:
|
| 196 |
+
exit()
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
# images.append((valid_reconstructed_images*255).astype(np.uint8))
|
| 203 |
+
|
| 204 |
+
if __name__ == '__main__':
|
| 205 |
+
app.run(main)
|
f16c16/models/__pycache__/discriminator.cpython-310.pyc
ADDED
|
Binary file (4.68 kB). View file
|
|
|
f16c16/models/__pycache__/discriminator.cpython-312.pyc
ADDED
|
Binary file (8.13 kB). View file
|
|
|
f16c16/models/__pycache__/vqvae.cpython-310.pyc
ADDED
|
Binary file (14.7 kB). View file
|
|
|
f16c16/models/__pycache__/vqvae.cpython-312.pyc
ADDED
|
Binary file (26.9 kB). View file
|
|
|
f16c16/models/back_model.py
ADDED
|
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
import flax.linen as nn
|
| 3 |
+
import jax.numpy as jnp
|
| 4 |
+
import functools
|
| 5 |
+
import ml_collections
|
| 6 |
+
import jax
|
| 7 |
+
|
| 8 |
+
###########################
|
| 9 |
+
### Helper Modules
|
| 10 |
+
### https://github.com/google-research/maskgit/blob/main/maskgit/nets/layers.py
|
| 11 |
+
###########################
|
| 12 |
+
|
| 13 |
+
def get_norm_layer(norm_type):
|
| 14 |
+
"""Normalization layer."""
|
| 15 |
+
if norm_type == 'BN':
|
| 16 |
+
raise NotImplementedError
|
| 17 |
+
elif norm_type == 'LN':
|
| 18 |
+
norm_fn = functools.partial(nn.LayerNorm)
|
| 19 |
+
elif norm_type == 'GN':
|
| 20 |
+
norm_fn = functools.partial(nn.GroupNorm)
|
| 21 |
+
else:
|
| 22 |
+
raise NotImplementedError
|
| 23 |
+
return norm_fn
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def tensorflow_style_avg_pooling(x, window_shape, strides, padding: str):
|
| 27 |
+
pool_sum = jax.lax.reduce_window(x, 0.0, jax.lax.add,
|
| 28 |
+
(1,) + window_shape + (1,),
|
| 29 |
+
(1,) + strides + (1,), padding)
|
| 30 |
+
pool_denom = jax.lax.reduce_window(
|
| 31 |
+
jnp.ones_like(x), 0.0, jax.lax.add, (1,) + window_shape + (1,),
|
| 32 |
+
(1,) + strides + (1,), padding)
|
| 33 |
+
return pool_sum / pool_denom
|
| 34 |
+
|
| 35 |
+
def upsample(x, factor=2):
|
| 36 |
+
n, h, w, c = x.shape
|
| 37 |
+
x = jax.image.resize(x, (n, h * factor, w * factor, c), method='nearest')
|
| 38 |
+
return x
|
| 39 |
+
|
| 40 |
+
def dsample(x):
|
| 41 |
+
return tensorflow_style_avg_pooling(x, (2, 2), strides=(2, 2), padding='same')
|
| 42 |
+
|
| 43 |
+
def squared_euclidean_distance(a: jnp.ndarray,
|
| 44 |
+
b: jnp.ndarray,
|
| 45 |
+
b2: jnp.ndarray = None) -> jnp.ndarray:
|
| 46 |
+
"""Computes the pairwise squared Euclidean distance.
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
a: float32: (n, d): An array of points.
|
| 50 |
+
b: float32: (m, d): An array of points.
|
| 51 |
+
b2: float32: (d, m): b square transpose.
|
| 52 |
+
|
| 53 |
+
Returns:
|
| 54 |
+
d: float32: (n, m): Where d[i, j] is the squared Euclidean distance between
|
| 55 |
+
a[i] and b[j].
|
| 56 |
+
"""
|
| 57 |
+
if b2 is None:
|
| 58 |
+
b2 = jnp.sum(b.T**2, axis=0, keepdims=True)
|
| 59 |
+
a2 = jnp.sum(a**2, axis=1, keepdims=True)
|
| 60 |
+
ab = jnp.matmul(a, b.T)
|
| 61 |
+
d = a2 - 2 * ab + b2
|
| 62 |
+
return d
|
| 63 |
+
|
| 64 |
+
def entropy_loss_fn(affinity, loss_type="softmax", temperature=1.0):
|
| 65 |
+
"""Calculates the entropy loss. Affinity is the similarity/distance matrix."""
|
| 66 |
+
flat_affinity = affinity.reshape(-1, affinity.shape[-1])
|
| 67 |
+
flat_affinity /= temperature
|
| 68 |
+
probs = jax.nn.softmax(flat_affinity, axis=-1)
|
| 69 |
+
log_probs = jax.nn.log_softmax(flat_affinity + 1e-5, axis=-1)
|
| 70 |
+
if loss_type == "softmax":
|
| 71 |
+
target_probs = probs
|
| 72 |
+
elif loss_type == "argmax":
|
| 73 |
+
codes = jnp.argmax(flat_affinity, axis=-1)
|
| 74 |
+
onehots = jax.nn.one_hot(
|
| 75 |
+
codes, flat_affinity.shape[-1], dtype=flat_affinity.dtype)
|
| 76 |
+
onehots = probs - jax.lax.stop_gradient(probs - onehots)
|
| 77 |
+
target_probs = onehots
|
| 78 |
+
else:
|
| 79 |
+
raise ValueError("Entropy loss {} not supported".format(loss_type))
|
| 80 |
+
avg_probs = jnp.mean(target_probs, axis=0)
|
| 81 |
+
avg_entropy = -jnp.sum(avg_probs * jnp.log(avg_probs + 1e-5))
|
| 82 |
+
sample_entropy = -jnp.mean(jnp.sum(target_probs * log_probs, axis=-1))
|
| 83 |
+
loss = sample_entropy - avg_entropy
|
| 84 |
+
return loss
|
| 85 |
+
|
| 86 |
+
def sg(x):
|
| 87 |
+
return jax.lax.stop_gradient(x)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
###########################
|
| 93 |
+
### Modules
|
| 94 |
+
###########################
|
| 95 |
+
|
| 96 |
+
class ResBlock(nn.Module):
|
| 97 |
+
"""Basic Residual Block."""
|
| 98 |
+
filters: int
|
| 99 |
+
norm_fn: Any
|
| 100 |
+
activation_fn: Any
|
| 101 |
+
|
| 102 |
+
@nn.compact
|
| 103 |
+
def __call__(self, x):
|
| 104 |
+
input_dim = x.shape[-1]
|
| 105 |
+
residual = x
|
| 106 |
+
x = self.norm_fn()(x)
|
| 107 |
+
x = self.activation_fn(x)
|
| 108 |
+
x = nn.Conv(self.filters, kernel_size=(3, 3), use_bias=False)(x)
|
| 109 |
+
x = self.norm_fn()(x)
|
| 110 |
+
x = self.activation_fn(x)
|
| 111 |
+
x = nn.Conv(self.filters, kernel_size=(3, 3), use_bias=False)(x)
|
| 112 |
+
|
| 113 |
+
if input_dim != self.filters:
|
| 114 |
+
residual = nn.Conv(self.filters, kernel_size=(1, 1), use_bias=False)(x)
|
| 115 |
+
return x + residual
|
| 116 |
+
|
| 117 |
+
class Encoder(nn.Module):
|
| 118 |
+
"""From [H,W,D] image to [H',W',D'] embedding. Using Conv layers."""
|
| 119 |
+
config: ml_collections.ConfigDict
|
| 120 |
+
|
| 121 |
+
def setup(self):
|
| 122 |
+
self.filters = self.config.filters
|
| 123 |
+
self.num_res_blocks = self.config.num_res_blocks
|
| 124 |
+
self.channel_multipliers = self.config.channel_multipliers
|
| 125 |
+
self.embedding_dim = self.config.embedding_dim
|
| 126 |
+
self.norm_type = self.config.norm_type
|
| 127 |
+
self.activation_fn = nn.swish
|
| 128 |
+
|
| 129 |
+
@nn.compact
|
| 130 |
+
def __call__(self, x):
|
| 131 |
+
print("Initializing encoder.")
|
| 132 |
+
norm_fn = get_norm_layer(norm_type=self.norm_type)
|
| 133 |
+
block_args = dict(norm_fn=norm_fn, activation_fn=self.activation_fn)
|
| 134 |
+
print("Incoming encoder shape", x.shape)
|
| 135 |
+
x = nn.Conv(self.filters, kernel_size=(3, 3), use_bias=False)(x)
|
| 136 |
+
print('Encoder layer', x.shape)
|
| 137 |
+
num_blocks = len(self.channel_multipliers)
|
| 138 |
+
for i in range(num_blocks):
|
| 139 |
+
filters = self.filters * self.channel_multipliers[i]
|
| 140 |
+
for _ in range(self.num_res_blocks):
|
| 141 |
+
x = ResBlock(filters, **block_args)(x)
|
| 142 |
+
if i < num_blocks - 1:
|
| 143 |
+
x = dsample(x)
|
| 144 |
+
print('Encoder layer', x.shape)
|
| 145 |
+
|
| 146 |
+
for _ in range(self.num_res_blocks):
|
| 147 |
+
x = ResBlock(filters, **block_args)(x)
|
| 148 |
+
print('Encoder layer', x.shape)
|
| 149 |
+
x = norm_fn()(x)
|
| 150 |
+
x = self.activation_fn(x)
|
| 151 |
+
last_dim = self.embedding_dim*2 if self.config['quantizer_type'] == 'kl' else self.embedding_dim
|
| 152 |
+
x = nn.Conv(last_dim, kernel_size=(1, 1))(x)
|
| 153 |
+
print("Before final", x.shape)
|
| 154 |
+
x = nn.Conv(8, kernel_size=(1,1))(x)
|
| 155 |
+
print("Final embeddings are size", x.shape)
|
| 156 |
+
return x
|
| 157 |
+
|
| 158 |
+
class Decoder(nn.Module):
|
| 159 |
+
"""From [H',W',D'] embedding to [H,W,D] embedding. Using Conv layers."""
|
| 160 |
+
|
| 161 |
+
config: ml_collections.ConfigDict
|
| 162 |
+
|
| 163 |
+
def setup(self):
|
| 164 |
+
self.filters = self.config.filters
|
| 165 |
+
self.num_res_blocks = self.config.num_res_blocks
|
| 166 |
+
self.channel_multipliers = self.config.channel_multipliers
|
| 167 |
+
self.norm_type = self.config.norm_type
|
| 168 |
+
self.image_channels = self.config.image_channels
|
| 169 |
+
self.activation_fn = nn.swish
|
| 170 |
+
|
| 171 |
+
@nn.compact
|
| 172 |
+
def __call__(self, x):
|
| 173 |
+
norm_fn = get_norm_layer(norm_type=self.norm_type)
|
| 174 |
+
block_args = dict(norm_fn=norm_fn, activation_fn=self.activation_fn,)
|
| 175 |
+
num_blocks = len(self.channel_multipliers)
|
| 176 |
+
filters = self.filters * self.channel_multipliers[-1]
|
| 177 |
+
print("Decoder incoming shape", x.shape)
|
| 178 |
+
|
| 179 |
+
#We don't need to do anything here because it'll put it back to 512
|
| 180 |
+
|
| 181 |
+
x = nn.Conv(filters, kernel_size=(3, 3), use_bias=True)(x)
|
| 182 |
+
print("Decoder input", x.shape)
|
| 183 |
+
|
| 184 |
+
for _ in range(self.num_res_blocks):
|
| 185 |
+
x = ResBlock(filters, **block_args)(x)
|
| 186 |
+
print('Decoder layer', x.shape)
|
| 187 |
+
for i in reversed(range(num_blocks)):
|
| 188 |
+
filters = self.filters * self.channel_multipliers[i]
|
| 189 |
+
for _ in range(self.num_res_blocks):
|
| 190 |
+
x = ResBlock(filters, **block_args)(x)
|
| 191 |
+
if i > 0:
|
| 192 |
+
x = upsample(x, 2)
|
| 193 |
+
x = nn.Conv(filters, kernel_size=(3, 3))(x)
|
| 194 |
+
print('Decoder layer', x.shape)
|
| 195 |
+
x = norm_fn()(x)
|
| 196 |
+
x = self.activation_fn(x)
|
| 197 |
+
x = nn.Conv(self.image_channels, kernel_size=(3, 3))(x)
|
| 198 |
+
return x
|
| 199 |
+
|
| 200 |
+
class VectorQuantizer(nn.Module):
|
| 201 |
+
"""Basic vector quantizer."""
|
| 202 |
+
config: ml_collections.ConfigDict
|
| 203 |
+
train: bool
|
| 204 |
+
|
| 205 |
+
@nn.compact
|
| 206 |
+
def __call__(self, x):
|
| 207 |
+
codebook_size = self.config.codebook_size
|
| 208 |
+
emb_dim = x.shape[-1]
|
| 209 |
+
codebook = self.param(
|
| 210 |
+
"codebook",
|
| 211 |
+
jax.nn.initializers.variance_scaling(scale=1.0, mode="fan_in", distribution="uniform"),
|
| 212 |
+
(codebook_size, emb_dim))
|
| 213 |
+
codebook = jnp.asarray(codebook) # (codebook_size, emb_dim)
|
| 214 |
+
distances = jnp.reshape(
|
| 215 |
+
squared_euclidean_distance(jnp.reshape(x, (-1, emb_dim)), codebook),
|
| 216 |
+
x.shape[:-1] + (codebook_size,)) # [x, codebook_size] similarity matrix.
|
| 217 |
+
encoding_indices = jnp.argmin(distances, axis=-1)
|
| 218 |
+
encoding_onehot = jax.nn.one_hot(encoding_indices, codebook_size)
|
| 219 |
+
quantized = self.quantize(encoding_onehot)
|
| 220 |
+
result_dict = dict()
|
| 221 |
+
if self.train:
|
| 222 |
+
e_latent_loss = jnp.mean((sg(quantized) - x)**2) * self.config.commitment_cost
|
| 223 |
+
q_latent_loss = jnp.mean((quantized - sg(x))**2)
|
| 224 |
+
entropy_loss = 0.0
|
| 225 |
+
if self.config.entropy_loss_ratio != 0:
|
| 226 |
+
entropy_loss = entropy_loss_fn(
|
| 227 |
+
-distances,
|
| 228 |
+
loss_type=self.config.entropy_loss_type,
|
| 229 |
+
temperature=self.config.entropy_temperature
|
| 230 |
+
) * self.config.entropy_loss_ratio
|
| 231 |
+
e_latent_loss = jnp.asarray(e_latent_loss, jnp.float32)
|
| 232 |
+
q_latent_loss = jnp.asarray(q_latent_loss, jnp.float32)
|
| 233 |
+
entropy_loss = jnp.asarray(entropy_loss, jnp.float32)
|
| 234 |
+
loss = e_latent_loss + q_latent_loss + entropy_loss
|
| 235 |
+
result_dict = dict(
|
| 236 |
+
quantizer_loss=loss,
|
| 237 |
+
e_latent_loss=e_latent_loss,
|
| 238 |
+
q_latent_loss=q_latent_loss,
|
| 239 |
+
entropy_loss=entropy_loss)
|
| 240 |
+
quantized = x + jax.lax.stop_gradient(quantized - x)
|
| 241 |
+
|
| 242 |
+
result_dict.update({
|
| 243 |
+
"z_ids": encoding_indices,
|
| 244 |
+
})
|
| 245 |
+
return quantized, result_dict
|
| 246 |
+
|
| 247 |
+
def quantize(self, encoding_onehot: jnp.ndarray) -> jnp.ndarray:
|
| 248 |
+
codebook = jnp.asarray(self.variables["params"]["codebook"])
|
| 249 |
+
return jnp.dot(encoding_onehot, codebook)
|
| 250 |
+
|
| 251 |
+
def decode_ids(self, ids: jnp.ndarray) -> jnp.ndarray:
|
| 252 |
+
codebook = self.variables["params"]["codebook"]
|
| 253 |
+
return jnp.take(codebook, ids, axis=0)
|
| 254 |
+
|
| 255 |
+
class KLQuantizer(nn.Module):
    """VAE-style Gaussian 'quantizer' with the reparameterization trick.

    The incoming feature's last dimension is split in half: the first half
    is interpreted as means, the second half as log-variances.
    """
    config: ml_collections.ConfigDict
    train: bool

    @nn.compact
    def __call__(self, x):
        half = x.shape[-1] // 2
        means = x[..., :half]
        logvars = x[..., half:]
        if not self.train:
            # Eval path: deterministic latent (the mean), no KL term.
            return means, dict()
        # Train path: sample z = mu + sigma * eps and penalize with the KL
        # divergence to a standard normal.
        eps = jax.random.normal(self.make_rng("noise"), means.shape)
        sigma = jnp.exp(0.5 * logvars)
        z = means + sigma * eps
        kl_loss = -0.5 * jnp.mean(1 + logvars - means**2 - jnp.exp(logvars))
        return z, dict(quantizer_loss=kl_loss)
|
| 274 |
+
|
| 275 |
+
class FSQuantizer(nn.Module):
    """Finite Scalar Quantization: round each latent channel to a fixed grid.

    Returns the straight-through quantized latent plus diagnostics
    (`z_ids`: the integer grid values; `usage`: flat-index histogram).
    """
    config: ml_collections.ConfigDict
    train: bool

    @nn.compact
    def __call__(self, x):
        levels = self.config['fsq_levels']
        assert levels % 2 == 1, "FSQ levels must be odd."
        half = levels // 2
        z = jnp.tanh(x)                      # squash to [-1, 1]
        z = z * (levels - 1) / 2             # scale to [-half, half]
        zhat = jnp.round(z)                  # integer grid, e.g. -2..2
        # Straight-through rounding: forward uses zhat, gradient flows via z.
        quantized = z + jax.lax.stop_gradient(zhat - z)
        quantized = quantized / half         # back to [-1, 1], but discrete

        # Codebook-usage diagnostics: treat each channel tuple as a digit
        # string base `levels` and histogram the resulting flat index.
        dim = x.shape[-1]
        shifted = zhat + half                # 0..levels-1 per channel
        basis = jnp.concatenate(
            (jnp.array([1]),
             jnp.cumprod(jnp.array([levels] * (dim - 1))))).astype(jnp.uint32)
        flat_ids = (shifted * basis).sum(axis=-1).astype(jnp.uint32).reshape(-1)
        usage = jnp.bincount(flat_ids, length=levels ** dim)

        return quantized, {"z_ids": zhat, "usage": usage}
|
| 301 |
+
|
| 302 |
+
class VQVAE(nn.Module):
    """VQVAE model: encoder -> quantizer ('vq' | 'kl' | 'fsq') -> decoder."""
    config: ml_collections.ConfigDict
    train: bool

    def setup(self):
        """Instantiate the encoder, decoder, and the configured quantizer."""
        quantizer_type = self.config['quantizer_type']
        if quantizer_type == 'vq':
            self.quantizer = VectorQuantizer(config=self.config, train=self.train)
        elif quantizer_type == 'kl':
            self.quantizer = KLQuantizer(config=self.config, train=self.train)
        elif quantizer_type == 'fsq':
            self.quantizer = FSQuantizer(config=self.config, train=self.train)
        else:
            # Fail fast on a typo instead of an AttributeError at first use
            # (previously self.quantizer was silently left undefined).
            raise NotImplementedError(
                f"Unknown quantizer_type: {quantizer_type!r}")
        self.encoder = Encoder(config=self.config)
        self.decoder = Decoder(config=self.config)

    def encode(self, image):
        """Encode an image and quantize the embedding.

        Returns (quantized latent, auxiliary result dict from the quantizer).
        """
        encoded_feature = self.encoder(image)
        quantized, result_dict = self.quantizer(encoded_feature)
        print("After quant", quantized.shape)
        return quantized, result_dict

    def decode(self, z_vectors):
        """Decode latent vectors back to image space."""
        print("z_vectors shape", z_vectors.shape)
        reconstructed = self.decoder(z_vectors)
        return reconstructed

    def decode_from_indices(self, z_ids):
        """Decode from discrete codebook ids (only meaningful for 'vq')."""
        z_vectors = self.quantizer.decode_ids(z_ids)
        reconstructed_image = self.decode(z_vectors)
        return reconstructed_image

    def encode_to_indices(self, image):
        """Encode an image to discrete codebook ids."""
        encoded_feature = self.encoder(image)
        _, result_dict = self.quantizer(encoded_feature)
        ids = result_dict["z_ids"]
        return ids

    def __call__(self, input_dict):
        quantized, result_dict = self.encode(input_dict)
        outputs = self.decoder(quantized)
        return outputs, result_dict
|
f16c16/models/discriminator.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Discriminator from StyleGAN. https://github.com/google-research/maskgit/blob/main/maskgit/nets/discriminator.py"""
|
| 2 |
+
|
| 3 |
+
import functools
|
| 4 |
+
import math
|
| 5 |
+
from typing import Any, Tuple
|
| 6 |
+
import flax.linen as nn
|
| 7 |
+
from flax.linen.initializers import xavier_uniform
|
| 8 |
+
import jax
|
| 9 |
+
from jax import lax
|
| 10 |
+
import jax.numpy as jnp
|
| 11 |
+
import ml_collections
|
| 12 |
+
|
| 13 |
+
default_kernel_init = xavier_uniform()
|
| 14 |
+
|
| 15 |
+
def _conv_dimension_numbers(input_shape):
|
| 16 |
+
"""Computes the dimension numbers based on the input shape."""
|
| 17 |
+
ndim = len(input_shape)
|
| 18 |
+
lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1))
|
| 19 |
+
rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2))
|
| 20 |
+
out_spec = lhs_spec
|
| 21 |
+
return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class BlurPool2D(nn.Module):
|
| 25 |
+
"""A layer to do channel-wise blurring + subsampling on 2D inputs.
|
| 26 |
+
|
| 27 |
+
Reference:
|
| 28 |
+
Zhang et al. Making Convolutional Networks Shift-Invariant Again.
|
| 29 |
+
https://arxiv.org/pdf/1904.11486.pdf.
|
| 30 |
+
"""
|
| 31 |
+
filter_size: int = 4
|
| 32 |
+
strides: Tuple[int, int] = (2, 2)
|
| 33 |
+
padding: str = 'SAME'
|
| 34 |
+
|
| 35 |
+
def setup(self):
|
| 36 |
+
if self.filter_size == 3:
|
| 37 |
+
self.filter = [1., 2., 1.]
|
| 38 |
+
elif self.filter_size == 4:
|
| 39 |
+
self.filter = [1., 3., 3., 1.]
|
| 40 |
+
elif self.filter_size == 5:
|
| 41 |
+
self.filter = [1., 4., 6., 4., 1.]
|
| 42 |
+
elif self.filter_size == 6:
|
| 43 |
+
self.filter = [1., 5., 10., 10., 5., 1.]
|
| 44 |
+
elif self.filter_size == 7:
|
| 45 |
+
self.filter = [1., 6., 15., 20., 15., 6., 1.]
|
| 46 |
+
else:
|
| 47 |
+
raise ValueError('Only filter_size of 3, 4, 5, 6 or 7 is supported.')
|
| 48 |
+
|
| 49 |
+
self.filter = jnp.array(self.filter, dtype=jnp.float32)
|
| 50 |
+
self.filter = self.filter[:, None] * self.filter[None, :]
|
| 51 |
+
with jax.default_matmul_precision('float32'):
|
| 52 |
+
self.filter /= jnp.sum(self.filter)
|
| 53 |
+
self.filter = jnp.reshape(
|
| 54 |
+
self.filter, [self.filter.shape[0], self.filter.shape[1], 1, 1])
|
| 55 |
+
|
| 56 |
+
@nn.compact
|
| 57 |
+
def __call__(self, inputs):
|
| 58 |
+
channel_num = inputs.shape[-1]
|
| 59 |
+
dimension_numbers = _conv_dimension_numbers(inputs.shape)
|
| 60 |
+
depthwise_filter = jnp.tile(self.filter, [1, 1, 1, channel_num])
|
| 61 |
+
with jax.default_matmul_precision('float32'):
|
| 62 |
+
outputs = lax.conv_general_dilated(inputs, depthwise_filter, self.strides,
|
| 63 |
+
self.padding, feature_group_count=channel_num, dimension_numbers=dimension_numbers)
|
| 64 |
+
return outputs
|
| 65 |
+
|
| 66 |
+
class ResBlock(nn.Module):
|
| 67 |
+
"""StyleGAN ResBlock for D.
|
| 68 |
+
|
| 69 |
+
https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py#L618
|
| 70 |
+
"""
|
| 71 |
+
filters: int
|
| 72 |
+
activation_fn: Any
|
| 73 |
+
|
| 74 |
+
@nn.compact
|
| 75 |
+
def __call__(self, x):
|
| 76 |
+
input_dim = x.shape[-1]
|
| 77 |
+
residual = x
|
| 78 |
+
x = nn.Conv(input_dim, (3, 3), kernel_init=default_kernel_init)(x)
|
| 79 |
+
x = self.activation_fn(x)
|
| 80 |
+
x = BlurPool2D(filter_size=4)(x)
|
| 81 |
+
residual = BlurPool2D(filter_size=4)(residual)
|
| 82 |
+
residual = nn.Conv(self.filters, (1, 1), use_bias=False, kernel_init=default_kernel_init)(residual)
|
| 83 |
+
x = nn.Conv(self.filters, (3, 3), kernel_init=default_kernel_init)(x)
|
| 84 |
+
x = self.activation_fn(x)
|
| 85 |
+
out = (residual + x) / math.sqrt(2)
|
| 86 |
+
return out
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class Discriminator(nn.Module):
|
| 90 |
+
"""StyleGAN Discriminator."""
|
| 91 |
+
config: ml_collections.ConfigDict
|
| 92 |
+
|
| 93 |
+
def setup(self):
|
| 94 |
+
self.input_size = self.config.image_size
|
| 95 |
+
self.activation_fn = functools.partial(jax.nn.leaky_relu, negative_slope=0.2)
|
| 96 |
+
self.channel_multiplier = 1
|
| 97 |
+
|
| 98 |
+
@nn.compact
|
| 99 |
+
def __call__(self, x):
|
| 100 |
+
filters = {
|
| 101 |
+
4: 512,
|
| 102 |
+
8: 512,
|
| 103 |
+
16: 512,
|
| 104 |
+
32: 512,
|
| 105 |
+
64: 256 * self.channel_multiplier,
|
| 106 |
+
128: 128 * self.channel_multiplier,
|
| 107 |
+
256: 64 * self.channel_multiplier,
|
| 108 |
+
512: 32 * self.channel_multiplier,
|
| 109 |
+
1024: 16 * self.channel_multiplier,
|
| 110 |
+
}
|
| 111 |
+
x = nn.Conv(filters[self.input_size], (3, 3), kernel_init=default_kernel_init)(x)
|
| 112 |
+
x = self.activation_fn(x)
|
| 113 |
+
log_size = int(math.log2(self.input_size))
|
| 114 |
+
for i in range(log_size, 2, -1):
|
| 115 |
+
x = ResBlock(filters[2**(i - 1)], self.activation_fn)(x)
|
| 116 |
+
print("Disc shape", x.shape)
|
| 117 |
+
x = nn.Conv(filters[4], (3, 3), kernel_init=default_kernel_init)(x)
|
| 118 |
+
x = self.activation_fn(x)
|
| 119 |
+
x = x.reshape((x.shape[0], -1))
|
| 120 |
+
x = nn.Dense(filters[4], kernel_init=default_kernel_init)(x)
|
| 121 |
+
x = self.activation_fn(x)
|
| 122 |
+
x = nn.Dense(1, kernel_init=default_kernel_init)(x)
|
| 123 |
+
return x
|
f16c16/models/vqvae.py
ADDED
|
@@ -0,0 +1,527 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
import flax.linen as nn
|
| 3 |
+
import jax.numpy as jnp
|
| 4 |
+
import functools
|
| 5 |
+
import ml_collections
|
| 6 |
+
import jax
|
| 7 |
+
|
| 8 |
+
from flax.linen import initializers
|
| 9 |
+
|
| 10 |
+
###########################
|
| 11 |
+
### Helper Modules
|
| 12 |
+
### https://github.com/google-research/maskgit/blob/main/maskgit/nets/layers.py
|
| 13 |
+
###########################
|
| 14 |
+
|
| 15 |
+
def get_norm_layer(norm_type):
    """Map a norm-type string to a flax normalization-layer factory.

    Supported: 'LN' (LayerNorm) and 'GN' (GroupNorm). 'BN' is recognized but
    intentionally unsupported; both it and unknown strings raise
    NotImplementedError, matching the original behavior.
    """
    if norm_type == 'LN':
        return functools.partial(nn.LayerNorm)
    if norm_type == 'GN':
        return functools.partial(nn.GroupNorm)
    raise NotImplementedError
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def tensorflow_style_avg_pooling(x, window_shape, strides, padding: str):
    """Average pooling over NHWC input, dividing by the actual window coverage.

    Unlike a naive sum / window-area, the denominator counts only the valid
    elements in each window, which matches TensorFlow's SAME-padding average
    pooling at the borders.
    """
    full_window = (1,) + window_shape + (1,)
    full_strides = (1,) + strides + (1,)
    window_sum = jax.lax.reduce_window(
        x, 0.0, jax.lax.add, full_window, full_strides, padding)
    window_count = jax.lax.reduce_window(
        jnp.ones_like(x), 0.0, jax.lax.add, full_window, full_strides, padding)
    return window_sum / window_count
|
| 36 |
+
|
| 37 |
+
def upsample(x, factor=2):
    """Nearest-neighbor upsample of an NHWC tensor by `factor` per spatial dim."""
    n, h, w, c = x.shape
    target_shape = (n, h * factor, w * factor, c)
    return jax.image.resize(x, target_shape, method='nearest')
|
| 41 |
+
|
| 42 |
+
def dsample(x):
    """Halve the spatial resolution of an NHWC tensor via 2x2 average pooling."""
    return tensorflow_style_avg_pooling(
        x, window_shape=(2, 2), strides=(2, 2), padding='same')
|
| 44 |
+
|
| 45 |
+
def squared_euclidean_distance(a: jnp.ndarray,
                               b: jnp.ndarray,
                               b2: jnp.ndarray = None) -> jnp.ndarray:
    """Compute the pairwise squared Euclidean distance between two point sets.

    Uses the expansion ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2 so the whole
    computation is one matmul plus two row-norm reductions.

    Args:
      a: float32 (n, d) array of points.
      b: float32 (m, d) array of points.
      b2: optional precomputed (1, m) row of squared norms of b.

    Returns:
      (n, m) array where entry [i, j] is ||a[i] - b[j]||^2.
    """
    if b2 is None:
        b2 = jnp.sum(b.T ** 2, axis=0, keepdims=True)
    a_sq = jnp.sum(a ** 2, axis=1, keepdims=True)
    cross = jnp.matmul(a, b.T)
    return a_sq - 2 * cross + b2
|
| 65 |
+
|
| 66 |
+
def entropy_loss_fn(affinity, loss_type="softmax", temperature=1.0):
    """Entropy regularizer over codebook assignments.

    Encourages each sample to commit to a single code (low per-sample entropy)
    while spreading usage across the whole codebook (high average entropy).

    Args:
      affinity: [..., codebook_size] similarity logits (higher = closer).
      loss_type: "softmax" uses soft assignments; "argmax" uses hard one-hot
        assignments with a straight-through gradient.
      temperature: softmax temperature applied to the logits.

    Returns:
      Scalar loss = sample_entropy - avg_entropy.
    """
    logits = affinity.reshape(-1, affinity.shape[-1])
    logits = logits / temperature
    probs = jax.nn.softmax(logits, axis=-1)
    log_probs = jax.nn.log_softmax(logits + 1e-5, axis=-1)
    if loss_type == "softmax":
        target_probs = probs
    elif loss_type == "argmax":
        codes = jnp.argmax(logits, axis=-1)
        onehots = jax.nn.one_hot(codes, logits.shape[-1], dtype=logits.dtype)
        # Straight-through: forward pass sees the one-hots, gradients flow
        # through the soft probabilities.
        target_probs = probs - jax.lax.stop_gradient(probs - onehots)
    else:
        raise ValueError("Entropy loss {} not supported".format(loss_type))
    avg_probs = jnp.mean(target_probs, axis=0)
    avg_entropy = -jnp.sum(avg_probs * jnp.log(avg_probs + 1e-5))
    sample_entropy = -jnp.mean(jnp.sum(target_probs * log_probs, axis=-1))
    return sample_entropy - avg_entropy
|
| 87 |
+
|
| 88 |
+
def sg(x):
    """Shorthand for jax.lax.stop_gradient."""
    return jax.lax.stop_gradient(x)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
###########################
|
| 95 |
+
### Modules
|
| 96 |
+
###########################
|
| 97 |
+
|
| 98 |
+
class ResBlock(nn.Module):
    """Basic residual block: two pre-activation (norm -> act -> conv) stages.

    When the input channel count differs from `filters`, the skip path is
    projected with a 1x1 conv so the shapes match for the addition.
    """
    filters: int
    norm_fn: Any
    activation_fn: Any

    @nn.compact
    def __call__(self, x):
        input_dim = x.shape[-1]
        residual = x
        x = self.norm_fn()(x)
        x = self.activation_fn(x)
        x = nn.Conv(self.filters, kernel_size=(3, 3), use_bias=False)(x)
        x = self.norm_fn()(x)
        x = self.activation_fn(x)
        x = nn.Conv(self.filters, kernel_size=(3, 3), use_bias=False)(x)

        if input_dim != self.filters:
            # BUG FIX: project the *skip* tensor, not the main branch.
            # Previously this read `nn.Conv(...)(x)`, which made the output
            # effectively 2 * conv(x) instead of conv(x) + proj(input),
            # discarding the residual connection whenever channels changed.
            residual = nn.Conv(
                self.filters, kernel_size=(1, 1), use_bias=False)(residual)
        return x + residual
|
| 118 |
+
|
| 119 |
+
class Fourier(nn.Module):
    """Random Fourier feature embedding.

    Projects the input with a fixed random Gaussian matrix and returns the
    concatenated cos/sin features, so the last dimension becomes
    2 * `features`.

    NOTE(review): the original body was dead code — it referenced undefined
    names (`means`, `math`, `input`, `torch`) and would raise NameError if
    ever instantiated. Rewritten as a working JAX/Flax module with the same
    external interface (`__call__(self, f)`).
    """
    features: int = 256  # number of random projections (output is 2x this)

    @nn.compact
    def __call__(self, f):
        # Fixed (non-trained) random projection, drawn once at init time.
        weight = self.param(
            "weight",
            jax.nn.initializers.normal(stddev=1.0),
            (self.features, f.shape[-1]))
        proj = 2 * jnp.pi * f @ jax.lax.stop_gradient(weight).T
        return jnp.concatenate([jnp.cos(proj), jnp.sin(proj)], axis=-1)
|
| 131 |
+
|
| 132 |
+
from einops import rearrange
|
| 133 |
+
class LinearEncoder(nn.Module):
|
| 134 |
+
|
| 135 |
+
config: ml_collections.ConfigDict
|
| 136 |
+
|
| 137 |
+
#So in this setup, we don't carea bout anything
|
| 138 |
+
@nn.compact
|
| 139 |
+
def __call__(self, x):
|
| 140 |
+
print("init encoder")
|
| 141 |
+
print("x shape", x.shape)
|
| 142 |
+
x = rearrange(x, '... (h b1) (w b2) c -> ... h w (c b1 b2)', b1=8, b2=8)
|
| 143 |
+
x = nn.Dense(4)(x)#We just put to 4 for now
|
| 144 |
+
print(x.shape)
|
| 145 |
+
return x
|
| 146 |
+
#k = nn.Dense(self.hidden_size, **self.tc.default_config())(x_modulated)
|
| 147 |
+
#1x1 conv, uplift from 3 to like..... 64
|
| 148 |
+
#That gives us 256x256x64
|
| 149 |
+
#Then pixelshuffle to
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class Encoder(nn.Module):
|
| 153 |
+
"""From [H,W,D] image to [H',W',D'] embedding. Using Conv layers."""
|
| 154 |
+
config: ml_collections.ConfigDict
|
| 155 |
+
|
| 156 |
+
def setup(self):
|
| 157 |
+
self.filters = self.config.filters#filters is the original setup
|
| 158 |
+
self.num_res_blocks = self.config.num_res_blocks
|
| 159 |
+
self.channel_multipliers = self.config.channel_multipliers
|
| 160 |
+
self.embedding_dim = self.config.embedding_dim
|
| 161 |
+
self.norm_type = self.config.norm_type
|
| 162 |
+
self.activation_fn = nn.swish
|
| 163 |
+
self.kernel_init = initializers.he_normal()
|
| 164 |
+
|
| 165 |
+
@nn.compact
|
| 166 |
+
def __call__(self, x):
|
| 167 |
+
print("Initializing encoder.")
|
| 168 |
+
norm_fn = get_norm_layer(norm_type=self.norm_type)
|
| 169 |
+
block_args = dict(norm_fn=norm_fn, activation_fn=self.activation_fn)
|
| 170 |
+
print("Incoming encoder shape", x.shape)
|
| 171 |
+
x = nn.Conv(self.filters, kernel_size=(3, 3), use_bias=False)(x)
|
| 172 |
+
print('Encoder layer', x.shape)
|
| 173 |
+
num_blocks = len(self.channel_multipliers)
|
| 174 |
+
|
| 175 |
+
#The way SD works, is it does 2x resnet, not changing anything, then downsample
|
| 176 |
+
#It does this 3 times, leading to 8x downsample
|
| 177 |
+
#Then it has an extra resnet block, and THEN from 512 to 8 / 4
|
| 178 |
+
|
| 179 |
+
for i in range(num_blocks):
|
| 180 |
+
filters = self.filters * self.channel_multipliers[i]
|
| 181 |
+
for _ in range(self.num_res_blocks):
|
| 182 |
+
x = ResBlock(filters, **block_args)(x)
|
| 183 |
+
if i < num_blocks - 1:#For each block *except end* do downsample
|
| 184 |
+
print("doing downsample")
|
| 185 |
+
x = dsample(x)
|
| 186 |
+
print('Encoder layer', x.shape)
|
| 187 |
+
|
| 188 |
+
#After we are done downsampling, we do the 2 resnet, and down below here, we have the 2 midblock?
|
| 189 |
+
|
| 190 |
+
for _ in range(self.num_res_blocks):
|
| 191 |
+
x = ResBlock(filters, **block_args)(x)
|
| 192 |
+
print('Encoder layer final', x.shape)
|
| 193 |
+
|
| 194 |
+
x = norm_fn()(x)
|
| 195 |
+
x = self.activation_fn(x)
|
| 196 |
+
last_dim = self.embedding_dim*2 if self.config['quantizer_type'] == 'kl' else self.embedding_dim
|
| 197 |
+
x = nn.Conv(last_dim, kernel_size=(1, 1))(x)
|
| 198 |
+
print("Final embeddings are size", x.shape)
|
| 199 |
+
return x
|
| 200 |
+
|
| 201 |
+
class Decoder(nn.Module):
|
| 202 |
+
"""From [H',W',D'] embedding to [H,W,D] embedding. Using Conv layers."""
|
| 203 |
+
|
| 204 |
+
config: ml_collections.ConfigDict
|
| 205 |
+
|
| 206 |
+
def setup(self):
|
| 207 |
+
self.filters = self.config.filters
|
| 208 |
+
self.num_res_blocks = self.config.num_res_blocks
|
| 209 |
+
self.channel_multipliers = self.config.channel_multipliers
|
| 210 |
+
self.norm_type = self.config.norm_type
|
| 211 |
+
self.image_channels = self.config.image_channels
|
| 212 |
+
self.activation_fn = nn.swish
|
| 213 |
+
self.kernel_init = initializers.he_normal()
|
| 214 |
+
|
| 215 |
+
@nn.compact
|
| 216 |
+
def __call__(self, x):
|
| 217 |
+
norm_fn = get_norm_layer(norm_type=self.norm_type)
|
| 218 |
+
block_args = dict(norm_fn=norm_fn, activation_fn=self.activation_fn,)
|
| 219 |
+
num_blocks = len(self.channel_multipliers)
|
| 220 |
+
filters = self.filters * self.channel_multipliers[-1]
|
| 221 |
+
print("Decoder incoming shape", x.shape)
|
| 222 |
+
|
| 223 |
+
#We don't need to do anything here because it'll put it back to 512
|
| 224 |
+
|
| 225 |
+
x = nn.Conv(filters, kernel_size=(3, 3), use_bias=True)(x)
|
| 226 |
+
print("Decoder input", x.shape)
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
#This is the mid block
|
| 230 |
+
for _ in range(self.num_res_blocks):
|
| 231 |
+
x = ResBlock(filters, **block_args)(x)
|
| 232 |
+
print('Mid Block Decoder layer', x.shape)
|
| 233 |
+
|
| 234 |
+
#First two SET of blocks is just 3 resnet, no channel changes, we are already at 4x = 512
|
| 235 |
+
|
| 236 |
+
for i in reversed(range(num_blocks)):
|
| 237 |
+
filters = self.filters * self.channel_multipliers[i]
|
| 238 |
+
for _ in range(self.num_res_blocks):#sym
|
| 239 |
+
x = ResBlock(filters, **block_args)(x)
|
| 240 |
+
if i > 0:
|
| 241 |
+
x = upsample(x, 2)
|
| 242 |
+
x = nn.Conv(filters, kernel_size=(3, 3))(x)
|
| 243 |
+
print('Decoder layer', x.shape)
|
| 244 |
+
x = norm_fn()(x)
|
| 245 |
+
x = self.activation_fn(x)
|
| 246 |
+
x = nn.Conv(self.image_channels, kernel_size=(3, 3))(x)
|
| 247 |
+
return x
|
| 248 |
+
|
| 249 |
+
class VectorQuantizer(nn.Module):
|
| 250 |
+
"""Basic vector quantizer."""
|
| 251 |
+
config: ml_collections.ConfigDict
|
| 252 |
+
train: bool
|
| 253 |
+
|
| 254 |
+
@nn.compact
|
| 255 |
+
def __call__(self, x):
|
| 256 |
+
codebook_size = self.config.codebook_size
|
| 257 |
+
emb_dim = x.shape[-1]
|
| 258 |
+
codebook = self.param(
|
| 259 |
+
"codebook",
|
| 260 |
+
jax.nn.initializers.variance_scaling(scale=1.0, mode="fan_in", distribution="uniform"),
|
| 261 |
+
(codebook_size, emb_dim))
|
| 262 |
+
codebook = jnp.asarray(codebook) # (codebook_size, emb_dim)
|
| 263 |
+
distances = jnp.reshape(
|
| 264 |
+
squared_euclidean_distance(jnp.reshape(x, (-1, emb_dim)), codebook),
|
| 265 |
+
x.shape[:-1] + (codebook_size,)) # [x, codebook_size] similarity matrix.
|
| 266 |
+
encoding_indices = jnp.argmin(distances, axis=-1)
|
| 267 |
+
encoding_onehot = jax.nn.one_hot(encoding_indices, codebook_size)
|
| 268 |
+
quantized = self.quantize(encoding_onehot)
|
| 269 |
+
result_dict = dict()
|
| 270 |
+
if self.train:
|
| 271 |
+
e_latent_loss = jnp.mean((sg(quantized) - x)**2) * self.config.commitment_cost
|
| 272 |
+
q_latent_loss = jnp.mean((quantized - sg(x))**2)
|
| 273 |
+
entropy_loss = 0.0
|
| 274 |
+
if self.config.entropy_loss_ratio != 0:
|
| 275 |
+
entropy_loss = entropy_loss_fn(
|
| 276 |
+
-distances,
|
| 277 |
+
loss_type=self.config.entropy_loss_type,
|
| 278 |
+
temperature=self.config.entropy_temperature
|
| 279 |
+
) * self.config.entropy_loss_ratio
|
| 280 |
+
e_latent_loss = jnp.asarray(e_latent_loss, jnp.float32)
|
| 281 |
+
q_latent_loss = jnp.asarray(q_latent_loss, jnp.float32)
|
| 282 |
+
entropy_loss = jnp.asarray(entropy_loss, jnp.float32)
|
| 283 |
+
loss = e_latent_loss + q_latent_loss + entropy_loss
|
| 284 |
+
result_dict = dict(
|
| 285 |
+
quantizer_loss=loss,
|
| 286 |
+
e_latent_loss=e_latent_loss,
|
| 287 |
+
q_latent_loss=q_latent_loss,
|
| 288 |
+
entropy_loss=entropy_loss)
|
| 289 |
+
quantized = x + jax.lax.stop_gradient(quantized - x)
|
| 290 |
+
|
| 291 |
+
result_dict.update({
|
| 292 |
+
"z_ids": encoding_indices,
|
| 293 |
+
})
|
| 294 |
+
return quantized, result_dict
|
| 295 |
+
|
| 296 |
+
def quantize(self, encoding_onehot: jnp.ndarray) -> jnp.ndarray:
|
| 297 |
+
codebook = jnp.asarray(self.variables["params"]["codebook"])
|
| 298 |
+
return jnp.dot(encoding_onehot, codebook)
|
| 299 |
+
|
| 300 |
+
def decode_ids(self, ids: jnp.ndarray) -> jnp.ndarray:
|
| 301 |
+
codebook = self.variables["params"]["codebook"]
|
| 302 |
+
return jnp.take(codebook, ids, axis=0)
|
| 303 |
+
|
| 304 |
+
class KLQuantizer(nn.Module):
|
| 305 |
+
config: ml_collections.ConfigDict
|
| 306 |
+
train: bool
|
| 307 |
+
|
| 308 |
+
@nn.compact
|
| 309 |
+
def __call__(self, x):
|
| 310 |
+
emb_dim = x.shape[-1] // 2 # Use half as means, half as logvars.
|
| 311 |
+
means = x[..., :emb_dim]
|
| 312 |
+
logvars = x[..., emb_dim:]
|
| 313 |
+
if not self.train:
|
| 314 |
+
result_dict = dict()
|
| 315 |
+
result_dict["std"] = jnp.exp(0.5 * logvars)
|
| 316 |
+
return means, result_dict
|
| 317 |
+
else:
|
| 318 |
+
noise = jax.random.normal(self.make_rng("noise"), means.shape)
|
| 319 |
+
stds = jnp.exp(0.5 * logvars)
|
| 320 |
+
z = means + stds * noise
|
| 321 |
+
#kl_loss = -0.5 * jnp.mean(1 + logvars - means**2 - jnp.exp(logvars))
|
| 322 |
+
|
| 323 |
+
#New kl
|
| 324 |
+
kl_loss = - 0.5 * jnp.sum(1 + logvars - jnp.square(means) - jnp.exp(logvars),axis=tuple(range(1, means.ndim)))
|
| 325 |
+
kl_loss = jnp.mean(kl_loss)
|
| 326 |
+
|
| 327 |
+
result_dict = dict(quantizer_loss=kl_loss)
|
| 328 |
+
result_dict["std"] = jnp.exp(0.5 * logvars)
|
| 329 |
+
return z, result_dict
|
| 330 |
+
|
| 331 |
+
class AEQuantizer(nn.Module): #cooking
|
| 332 |
+
config: ml_collections.ConfigDict
|
| 333 |
+
train: bool
|
| 334 |
+
|
| 335 |
+
@nn.compact
|
| 336 |
+
def __call__(self, x):
|
| 337 |
+
result_dict = dict()
|
| 338 |
+
result_dict["std"] = 0.0
|
| 339 |
+
return x, result_dict
|
| 340 |
+
|
| 341 |
+
import jax
|
| 342 |
+
import jax.numpy as jnp
|
| 343 |
+
from jax import random
|
| 344 |
+
|
| 345 |
+
def imq_kernel(X: jnp.ndarray, Y: jnp.ndarray, h_dim: int):
|
| 346 |
+
batch_size = X.shape[0]
|
| 347 |
+
|
| 348 |
+
norms_x = jnp.sum(X**2, axis=1, keepdims=True) # batch_size x 1
|
| 349 |
+
prods_x = jnp.dot(X, X.T) # batch_size x batch_size
|
| 350 |
+
dists_x = norms_x + norms_x.T - 2 * prods_x
|
| 351 |
+
|
| 352 |
+
norms_y = jnp.sum(Y**2, axis=1, keepdims=True) # batch_size x 1
|
| 353 |
+
prods_y = jnp.dot(Y, Y.T) # batch_size x batch_size
|
| 354 |
+
dists_y = norms_y + norms_y.T - 2 * prods_y
|
| 355 |
+
|
| 356 |
+
dot_prd = jnp.dot(X, Y.T)
|
| 357 |
+
dists_c = norms_x + norms_y.T - 2 * dot_prd
|
| 358 |
+
|
| 359 |
+
stats = 0
|
| 360 |
+
for scale in [0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0]:
|
| 361 |
+
C = 2 * h_dim * 1.0 * scale
|
| 362 |
+
res1 = C / (C + dists_x)
|
| 363 |
+
res1 += C / (C + dists_y)
|
| 364 |
+
|
| 365 |
+
res1 = (1 - jnp.eye(batch_size)) * res1
|
| 366 |
+
res1 = jnp.sum(res1) / (batch_size - 1)
|
| 367 |
+
|
| 368 |
+
res2 = C / (C + dists_c)
|
| 369 |
+
res2 = jnp.sum(res2) * 2.0 / batch_size
|
| 370 |
+
stats += res1 - res2
|
| 371 |
+
|
| 372 |
+
return stats
|
| 373 |
+
|
| 374 |
+
class MMDQuantizer(nn.Module): #cooking
    """WAE-style bottleneck: no quantization, but the flattened latents are
    pulled toward a scaled Gaussian prior with an IMQ-kernel MMD penalty."""
    config: ml_collections.ConfigDict
    train: bool

    @nn.compact
    def __call__(self, x):
        if not self.train:
            # Eval: pure pass-through, no loss term.
            result_dict = dict()
            return x, result_dict
        else:
            print("mmd quantizer")
            batch_size, height, width, latent_channels = x.shape
            # One flat vector per example for the kernel computation.
            z_flat = x.reshape(batch_size, -1)
            print(z_flat.shape)
            # Prior samples ~ N(0, I), scaled by MMD_weight.
            # NOTE(review): the scale multiplies the *prior* samples, so
            # MMD_weight sets the target latent std rather than weighting
            # the loss — confirm this is intended.
            z_fake_flat = jax.random.normal(self.make_rng("noise"), z_flat.shape) * self.config["MMD_weight"]
            print(z_fake_flat.shape)
            mmd_loss = imq_kernel(z_flat, z_fake_flat, z_flat.shape[1])
            print(mmd_loss.shape)
            print(mmd_loss)
            result_dict = dict(quantizer_loss=mmd_loss)
            return x, result_dict
| 395 |
+
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
class KLQuantizerTwo(nn.Module):
    """Experimental 'denoising' KL bottleneck.

    During training it adds unit-variance Gaussian noise to the raw latents
    and applies a KL-style penalty with the variance term manually pinned
    (logvars forced to 0), which reduces to an L2 pull on the latents.
    Eval mode is a pure pass-through. Several disabled experimental
    branches are kept for reference.
    """
    config: ml_collections.ConfigDict
    train: bool

    @nn.compact
    def __call__(self, x):
        #emb_dim = x.shape[-1] // 2 # Use half as means, half as logvars.
        #means = x[..., :emb_dim]
        #logvars = x[..., emb_dim:]

        #Wwe actually wanna do mean and STD on the batch axis?


        #we start as b hw 8, go to b hw 4, with mean and std over those.

        if not self.train:
            result_dict = dict()
            result_dict["std"] = 1.0
            return x, result_dict
        else:
            # Per-example std over all non-batch axes.
            # NOTE(review): computed but immediately overridden below
            # (logvars is reset to 0.0); only relevant if the manual
            # sigma=1 switch is turned off.
            stds = jnp.std(x, axis = [1,2,3])

            noise = jax.random.normal(self.make_rng("noise"), x.shape)

            logvars = .5 * jnp.log(stds)
            logvars = logvars.reshape(-1,1,1,1)
            if True:#This is true for special KL where we set sigma to 1 manually
                logvars = 0.0


            if False:#dinossl
                x_2 = x.reshape(x.shape[0], -1, x.shape[-1])#Linear with channel size
                x_2 = jnp.swapaxes(x_2,0,1)
                #then/ get the covariance
                cov = jnp.swapaxes(x_2,1,2) @ x_2 / x.shape[0]
                #Not sure about this, we also have regular cov
                I_d = jnp.identity(x.shape[-1])
                R_eps = jnp.log(jnp.linalg.det(jnp.expand_dims(I_d, axis = 0) + x.shape[-1]/ (.0001 ** 2) * cov))

                #So something here *does* depend on the -1 shape, but I need to math it out.
                kl_loss = R_eps.mean()


            #This is the denoising version
            # With logvars == 0 this is 0.5 * sum(x^2) per example,
            # averaged over the batch.
            kl_loss = - 0.5 * jnp.sum(1 + logvars - jnp.square(x) - jnp.exp(logvars),axis=tuple(range(1, x.ndim)))
            kl_loss = jnp.mean(kl_loss)

            result_dict = dict(quantizer_loss=kl_loss)
            result_dict["std"] = 1.0

            #For proper kl two, we need to return noise + mean.
            return x + noise, result_dict
| 450 |
+
|
| 451 |
+
|
| 452 |
+
class FSQuantizer(nn.Module):
    """Finite Scalar Quantization bottleneck.

    Squashes each channel with tanh, snaps it to one of ``fsq_levels``
    evenly spaced values via a straight-through estimator, and reports
    per-code usage counts for codebook diagnostics.
    """
    config: ml_collections.ConfigDict
    train: bool

    @nn.compact
    def __call__(self, x):
        levels = self.config['fsq_levels']
        assert levels % 2 == 1, "FSQ levels must be odd."
        half = levels // 2

        # Bound to [-1, 1], stretch to [-half, half], round to integers.
        scaled = jnp.tanh(x) * (levels - 1) / 2
        snapped = jnp.round(scaled)

        # Straight-through: forward uses the rounded value, the gradient
        # flows through the unrounded one. Renormalize back to [-1, 1].
        quantized = scaled + jax.lax.stop_gradient(snapped - scaled)
        quantized = quantized / half

        # Diagnostics: treat the per-channel integer codes as digits of a
        # base-`levels` number and histogram the resulting flat indices.
        channels = x.shape[-1]
        shifted = snapped + half
        basis = jnp.concatenate(
            (jnp.array([1]), jnp.cumprod(jnp.array([levels] * (channels - 1))))
        ).astype(jnp.uint32)
        flat_codes = (shifted * basis).sum(axis=-1).astype(jnp.uint32).reshape(-1)
        usage = jnp.bincount(flat_codes, length=levels ** channels)

        return quantized, {"z_ids": snapped, 'usage': usage}
| 478 |
+
|
| 479 |
+
class VQVAE(nn.Module):
    """VQVAE model: encoder -> configurable bottleneck -> decoder."""
    config: ml_collections.ConfigDict
    train: bool

    def setup(self):
        """Instantiate the bottleneck chosen by ``quantizer_type`` plus the
        encoder/decoder pair."""
        bottlenecks = {
            'vq': VectorQuantizer,
            'kl': KLQuantizer,
            'fsq': FSQuantizer,
            'ae': AEQuantizer,
            'kl_two': KLQuantizerTwo,
        }
        chosen = bottlenecks.get(self.config['quantizer_type'])
        # An unknown type leaves self.quantizer unset, matching the
        # original if/elif chain's behavior.
        if chosen is not None:
            self.quantizer = chosen(config=self.config, train=self.train)
        self.encoder = Encoder(config=self.config)
        self.decoder = Decoder(config=self.config)

    def encode(self, image):
        """Encode an image batch into (possibly quantized) latents."""
        features = self.encoder(image)
        quantized, result_dict = self.quantizer(features)
        print("After quant", quantized.shape)
        return quantized, result_dict

    def decode(self, z_vectors):
        """Decode latent vectors back into image space."""
        print("z_vectors shape", z_vectors.shape)
        return self.decoder(z_vectors)

    def decode_from_indices(self, z_ids):
        """Decode discrete code indices (VQ bottleneck only) into an image."""
        return self.decode(self.quantizer.decode_ids(z_ids))

    def encode_to_indices(self, image):
        """Encode an image into discrete code indices (VQ bottleneck only)."""
        _, result_dict = self.quantizer(self.encoder(image))
        return result_dict["z_ids"]

    def __call__(self, input_dict):
        """Full autoencode pass: returns (reconstruction, stats dict)."""
        quantized, result_dict = self.encode(input_dict)
        #Freezing encoder now
        print("encode finished")
        result_dict["latents"] = quantized
        return self.decoder(quantized), result_dict
f16c16/ppl_images.py
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
try: # For debugging
    from localutils.debugger import enable_debug
    enable_debug()
except ImportError:
    pass


#import jax
#jax.config.update('jax_platform_name', 'cpu')
import os
# os.environ["JAX_PLATFORMS"] = 'cpu'
import jax
import lpips

# LPIPS perceptual metric (PyTorch); AlexNet backbone gives the best
# forward scores per the LPIPS authors. Lives on the CUDA GPU while JAX
# handles the model.
loss_fn_alex = lpips.LPIPS(net='alex') # best forward scores
loss_fn_alex = loss_fn_alex.cuda()


import numpy as np
import flax.linen as nn
import jax.numpy as jnp
from absl import app, flags
from functools import partial
import numpy as np
import tqdm
import flax
import optax
import wandb
from ml_collections import config_flags
#import elements
import ml_collections
import tensorflow_datasets as tfds
import tensorflow as tf
# Keep TensorFlow off the accelerators: it only drives the input pipeline;
# JAX owns the GPU/TPU memory.
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")
import matplotlib.pyplot as plt
from typing import Any

from utils.train_state import TrainState, target_update
from utils.checkpoint import Checkpoint
from utils.fid import get_fid_network, fid_from_stats

from train import VQGANModel
from models.vqvae import VQVAE
from models.discriminator import Discriminator

from PIL import Image
import torch

# train.py already registered these flags on import; drop them so this
# script can re-define its own values below.
delattr(flags.FLAGS, 'dataset_name')
delattr(flags.FLAGS, 'load_dir')
delattr(flags.FLAGS, 'batch_size')

FLAGS = flags.FLAGS
flags.DEFINE_string('dataset_name', 'imagenet256', 'Environment name.')
flags.DEFINE_string('load_dir', "/home/dkaplan/Downloads/Models/checkpoint(1).tmp", 'Load dir (if not None, load params from here).')


flags.DEFINE_integer('batch_size', 2, 'Total Batch size.')
# Flags are inhereited from train.py, so pass your model parameters again here to evaluate.

import gc
+
def main(_):
    """Estimate a PPL-style perceptual score for the autoencoder, in image space.

    Streams ImageNet validation batches, runs the model's
    ``reconstruction_ppl_image`` pass (reconstruction plus a perturbed
    decode), scores the two decodes against each other with LPIPS
    (AlexNet, on the PyTorch GPU), scales by 1/eps^2 and averages over up
    to 500 batches. Also logs latent mean/std statistics along the way.
    """
    device_count = len(jax.local_devices())
    global_device_count = jax.device_count()
    # Per-process batch size (global batch split across hosts).
    local_batch_size = FLAGS.batch_size // (global_device_count // device_count)

    def get_dataset(is_train):
        # Iterator of square-cropped, resized, [0, 1]-scaled ImageNet images.
        if 'imagenet' in FLAGS.dataset_name:
            def deserialization_fn(data):
                image = data['image']
                min_side = tf.minimum(tf.shape(image)[0], tf.shape(image)[1])
                image = tf.image.resize_with_crop_or_pad(image, min_side, min_side)
                if 'imagenet256' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (256, 256))
                elif 'imagenet128' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (128, 128))
                else:
                    raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
                if is_train:
                    image = tf.image.random_flip_left_right(image)
                image = tf.cast(image, tf.float32) / 255.0
                return image

            split = tfds.split_for_jax_process('train' if is_train else 'validation', drop_remainder=True)
            dataset = tfds.load('imagenet2012', data_dir="/data/inet", split=split)
            dataset = dataset.map(deserialization_fn, num_parallel_calls=tf.data.AUTOTUNE)
            dataset = dataset.shuffle(10000, seed=42, reshuffle_each_iteration=True)
            dataset = dataset.batch(local_batch_size)
            dataset = dataset.prefetch(tf.data.AUTOTUNE)
            dataset = tfds.as_numpy(dataset)
            dataset = iter(dataset)
            return dataset
        else:
            raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")

    dataset = get_dataset(is_train=True)
    dataset_valid = get_dataset(is_train=False)


    # image = Image.open("osman.png")
    # image = np.array(image) / 255.0
    # print(image)
    # image = jnp.array(image)
    # image = jnp.expand_dims(image, 0)
    # image = jnp.expand_dims(image, 0)

    # One example batch, used only to infer image shape for model init.
    example_obs = next(dataset)[:1]

    #Reconstruction loop
    # image = model.reconstruction(image)
    # image = image[0,0,:,:,:]
    # image = (image * 255).astype(np.uint8)
    # image = np.array(image)
    # img = Image.fromarray(image)
    # img.save("osman" + str(i) + ".png")


    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, param_key = jax.random.split(rng)
    print("Total devices", jax.local_devices()[0])


    ###################################
    # Creating Model and put on devices.
    ###################################
    FLAGS.model.image_channels = example_obs.shape[-1]
    FLAGS.model.image_size = example_obs.shape[1]
    vqvae_def = VQVAE(FLAGS.model, train=True)
    vqvae_params = vqvae_def.init({'params': param_key, 'noise': param_key}, example_obs)['params']
    # tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    vqvae_ts = TrainState.create(vqvae_def, vqvae_params)#, tx=tx) #Turning off tx because we don't need it...
    # Eval-mode copy (train=False) sharing the same parameters.
    vqvae_def_eps = VQVAE(FLAGS.model, train=False)
    vqvae_eps_ts = TrainState.create(vqvae_def_eps, vqvae_params)
    print("Total num of VQVAE parameters:", sum(x.size for x in jax.tree_util.tree_leaves(vqvae_params)))

    discriminator_def = Discriminator(FLAGS.model)
    discriminator_params = discriminator_def.init(param_key, example_obs)['params']
    # tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    discriminator_ts = TrainState.create(discriminator_def, discriminator_params)#, tx=tx)#No tx again
    print("Total num of Discriminator parameters:", sum(x.size for x in jax.tree_util.tree_leaves(discriminator_params)))

    model = VQGANModel(rng=rng, vqvae=vqvae_ts, vqvae_eps=vqvae_eps_ts, discriminator=discriminator_ts, config=FLAGS.model)

    assert FLAGS.load_dir is not None
    cp = Checkpoint(FLAGS.load_dir)
    model = cp.load_model(model)
    print("Loaded model with step", model.vqvae.step)

    # Replicate across local devices for the model's pmapped methods.
    model = flax.jax_utils.replicate(model, devices=jax.local_devices())
    jax.debug.visualize_array_sharding(model.vqvae.params['decoder']['Conv_0']['bias'])
    #print(model.vqvae)


    ####################################
    # Noise stuff
    ###################################

    cpus = jax.devices("cpu")

    #So there are a few ways to calculate PPL here
    #We could take two images in image space
    #Walk between them and check the LPIPS in the output space
    #...actually that's basically it right?
    #We could also do the walk in latent space, which is the same, but with ?? scaling

    #Let's see if they are any different.
    i = 0
    lpips_list = []
    means = []
    stds = []
    for valid_images in dataset_valid:


        valid_images = valid_images.reshape((len(jax.local_devices()), -1, *valid_images.shape[1:])) # [devices, batch//devices, etc..]
        #1, 2, 256, 256, 3
        #Given our 2 images, we want to lerp between them...
        #We want to lerp once to point t, and once to point t + eps
        #And then we want to get the LPIPS between those two images
        #And then we calculate LPIPS
        #And then we divide by eps squared, and done.



        # NOTE(review): reconstruction_ppl_image lives in train.py;
        # presumably `decoded` is the eps-perturbed decode — confirm there.
        reconstructed_images, decoded, std, latents, std_noisy, latents_noisy = model.reconstruction_ppl_image(valid_images) # [devices, 8, 256, 256, 3]


        means.append(latents.mean())
        stds.append(latents.std())

        # print("std", std.mean())
        print("latent mean", latents.mean())
        print("actual latent std", latents.std())

        print("latent mean noisy", latents_noisy.mean())
        print("actual latent std noisy", latents_noisy.std())

        #Need to change images back to -1,1

        reconstructed_images = reconstructed_images * 2 - 1
        decoded = decoded * 2 -1

        #1,2,256,256,3
        # Move channels to the front (NCHW) for the PyTorch LPIPS network.
        reconstructed_images = jnp.swapaxes(reconstructed_images, 0, 4)
        decoded = jnp.swapaxes(decoded, 0, 4)

        reconstructed_images = jnp.swapaxes(reconstructed_images, 0, 1)
        decoded = jnp.swapaxes(decoded, 0, 1)

        reconstructed_images = jnp.squeeze(reconstructed_images)
        decoded = jnp.squeeze(decoded)

        #So here, we want to put them on CPU and delete the original


        image_np = np.asarray(reconstructed_images)
        image_np_2 = torch.from_numpy(np.copy(image_np)).cuda()

        decoded_np = np.asarray(decoded)
        decoded_np_2 = torch.from_numpy(np.copy(decoded_np)).cuda()



        lpips_loss = loss_fn_alex(image_np_2, decoded_np_2)
        lpips_cpu = lpips_loss.detach().cpu().squeeze().mean()
        # PPL normalization: divide by eps^2 (eps = 1e-4).
        lpips_cpu = lpips_cpu / (.0001 ** 2)

        print(lpips_cpu)
        lpips_list.append(lpips_cpu)


        i += 1
        #
        if i == 500:
            break

    #1e-4 is 54...
    #1e-5 is 106
    #1e-6 is 126

    #kl2 is 150?



    mean_lpips = jnp.mean(jnp.asarray(lpips_list))
    print(mean_lpips)
    print("mean of means", jnp.asarray(means).mean())
    print("stds of means", jnp.asarray(means).std())
    print("mean of stds", jnp.asarray(stds).mean())
    print("std of stds", jnp.asarray(stds).std())
| 254 |
+
# Script entry point: absl parses flags, then runs main.
if __name__ == '__main__':
    app.run(main)
|
f16c16/ppl_latents.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
try: # For debugging
    from localutils.debugger import enable_debug
    enable_debug()
except ImportError:
    pass


#import jax
#jax.config.update('jax_platform_name', 'cpu')
import os
# os.environ["JAX_PLATFORMS"] = 'cpu'
import jax
import lpips

# LPIPS perceptual metric (PyTorch); AlexNet backbone gives the best
# forward scores per the LPIPS authors. Lives on the CUDA GPU while JAX
# handles the model.
loss_fn_alex = lpips.LPIPS(net='alex') # best forward scores
loss_fn_alex = loss_fn_alex.cuda()


import numpy as np
import flax.linen as nn
import jax.numpy as jnp
from absl import app, flags
from functools import partial
import numpy as np
import tqdm
import flax
import optax
import wandb
from ml_collections import config_flags
#import elements
import ml_collections
import tensorflow_datasets as tfds
import tensorflow as tf
# Keep TensorFlow off the accelerators: it only drives the input pipeline;
# JAX owns the GPU/TPU memory.
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")
import matplotlib.pyplot as plt
from typing import Any

from utils.train_state import TrainState, target_update
from utils.checkpoint import Checkpoint
from utils.fid import get_fid_network, fid_from_stats

from train import VQGANModel
from models.vqvae import VQVAE
from models.discriminator import Discriminator

from PIL import Image
import torch

# train.py already registered these flags on import; drop them so this
# script can re-define its own values below.
delattr(flags.FLAGS, 'dataset_name')
delattr(flags.FLAGS, 'load_dir')
delattr(flags.FLAGS, 'batch_size')

FLAGS = flags.FLAGS
flags.DEFINE_string('dataset_name', 'imagenet256', 'Environment name.')
flags.DEFINE_string('load_dir', "/home/dkaplan/Downloads/Models/checkpoint(1).tmp", 'Load dir (if not None, load params from here).')


flags.DEFINE_integer('batch_size', 2, 'Total Batch size.')
# Flags are inhereited from train.py, so pass your model parameters again here to evaluate.

import gc
+
def main(_):
    """Estimate a PPL-style perceptual score for the autoencoder, in latent space.

    Streams ImageNet validation batches, runs the model's
    ``reconstruction_ppl`` pass (reconstruction plus a latent-perturbed
    decode), scores the two decodes against each other with LPIPS
    (AlexNet, on the PyTorch GPU), scales by 1/eps^2 and averages over up
    to 500 batches. Logs latent mean/std statistics; past results for
    several checkpoints are recorded in the trailing comments.
    """
    device_count = len(jax.local_devices())
    global_device_count = jax.device_count()
    # Per-process batch size (global batch split across hosts).
    local_batch_size = FLAGS.batch_size // (global_device_count // device_count)

    def get_dataset(is_train):
        # Iterator of square-cropped, resized, [0, 1]-scaled ImageNet images.
        if 'imagenet' in FLAGS.dataset_name:
            def deserialization_fn(data):
                image = data['image']
                min_side = tf.minimum(tf.shape(image)[0], tf.shape(image)[1])
                image = tf.image.resize_with_crop_or_pad(image, min_side, min_side)
                if 'imagenet256' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (256, 256))
                elif 'imagenet128' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (128, 128))
                else:
                    raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
                if is_train:
                    image = tf.image.random_flip_left_right(image)
                image = tf.cast(image, tf.float32) / 255.0
                return image

            split = tfds.split_for_jax_process('train' if is_train else 'validation', drop_remainder=True)
            dataset = tfds.load('imagenet2012', data_dir="/data/inet", split=split)
            dataset = dataset.map(deserialization_fn, num_parallel_calls=tf.data.AUTOTUNE)
            dataset = dataset.shuffle(10000, seed=42, reshuffle_each_iteration=True)
            dataset = dataset.batch(local_batch_size)
            dataset = dataset.prefetch(tf.data.AUTOTUNE)
            dataset = tfds.as_numpy(dataset)
            dataset = iter(dataset)
            return dataset
        else:
            raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")

    dataset = get_dataset(is_train=True)
    dataset_valid = get_dataset(is_train=False)


    # image = Image.open("osman.png")
    # image = np.array(image) / 255.0
    # print(image)
    # image = jnp.array(image)
    # image = jnp.expand_dims(image, 0)
    # image = jnp.expand_dims(image, 0)

    # One example batch, used only to infer image shape for model init.
    example_obs = next(dataset)[:1]

    #Reconstruction loop
    # image = model.reconstruction(image)
    # image = image[0,0,:,:,:]
    # image = (image * 255).astype(np.uint8)
    # image = np.array(image)
    # img = Image.fromarray(image)
    # img.save("osman" + str(i) + ".png")


    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, param_key = jax.random.split(rng)
    print("Total devices", jax.local_devices()[0])


    ###################################
    # Creating Model and put on devices.
    ###################################
    FLAGS.model.image_channels = example_obs.shape[-1]
    FLAGS.model.image_size = example_obs.shape[1]
    vqvae_def = VQVAE(FLAGS.model, train=True)
    vqvae_params = vqvae_def.init({'params': param_key, 'noise': param_key}, example_obs)['params']
    # tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    vqvae_ts = TrainState.create(vqvae_def, vqvae_params)#, tx=tx) #Turning off tx because we don't need it...
    # Eval-mode copy (train=False) sharing the same parameters.
    vqvae_def_eps = VQVAE(FLAGS.model, train=False)
    vqvae_eps_ts = TrainState.create(vqvae_def_eps, vqvae_params)
    print("Total num of VQVAE parameters:", sum(x.size for x in jax.tree_util.tree_leaves(vqvae_params)))

    discriminator_def = Discriminator(FLAGS.model)
    discriminator_params = discriminator_def.init(param_key, example_obs)['params']
    # tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    discriminator_ts = TrainState.create(discriminator_def, discriminator_params)#, tx=tx)#No tx again
    print("Total num of Discriminator parameters:", sum(x.size for x in jax.tree_util.tree_leaves(discriminator_params)))

    model = VQGANModel(rng=rng, vqvae=vqvae_ts, vqvae_eps=vqvae_eps_ts, discriminator=discriminator_ts, config=FLAGS.model)

    assert FLAGS.load_dir is not None
    cp = Checkpoint(FLAGS.load_dir)
    model = cp.load_model(model)
    print("Loaded model with step", model.vqvae.step)

    # Replicate across local devices for the model's pmapped methods.
    model = flax.jax_utils.replicate(model, devices=jax.local_devices())
    jax.debug.visualize_array_sharding(model.vqvae.params['decoder']['Conv_0']['bias'])
    #print(model.vqvae)


    ####################################
    # Noise stuff
    ###################################

    cpus = jax.devices("cpu")

    #So there are a few ways to calculate PPL here
    #We could take two images in image space
    #Walk between them and check the LPIPS in the output space
    #...actually that's basically it right?
    #We could also do the walk in latent space, which is the same, but with ?? scaling

    #Let's see if they are any different.


    #We could also try taking a latent, going X/2 direction, and -X/2 direction, and seeing that.
    i = 0
    lpips_list = []
    means = []
    stds = []
    for valid_images in dataset_valid:


        valid_images = valid_images.reshape((len(jax.local_devices()), -1, *valid_images.shape[1:])) # [devices, batch//devices, etc..]
        #1, 2, 256, 256, 3
        #Given our 2 images, we want to lerp between them...
        #We want to lerp once to point t, and once to point t + eps
        #And then we want to get the LPIPS between those two images
        #And then we calculate LPIPS
        #And then we divide by eps squared, and done.


        # NOTE(review): reconstruction_ppl lives in train.py; presumably
        # `decoded` is the decode of the eps-perturbed latent — confirm there.
        reconstructed_images, decoded, std, latents = model.reconstruction_ppl(valid_images) # [devices, 8, 256, 256, 3]


        means.append(latents.mean())
        stds.append(latents.std())
        print("noise added", std.mean())
        print("latent mean", latents.mean())
        print("actual latent std", latents.std())

        #Need to change images back to -1,1

        reconstructed_images = reconstructed_images * 2 - 1
        decoded = decoded * 2 -1

        #1,2,256,256,3
        # Move channels to the front (NCHW) for the PyTorch LPIPS network.
        reconstructed_images = jnp.swapaxes(reconstructed_images, 0, 4)
        decoded = jnp.swapaxes(decoded, 0, 4)

        reconstructed_images = jnp.swapaxes(reconstructed_images, 0, 1)
        decoded = jnp.swapaxes(decoded, 0, 1)

        reconstructed_images = jnp.squeeze(reconstructed_images)
        decoded = jnp.squeeze(decoded)

        #So here, we want to put them on CPU and delete the original


        image_np = np.asarray(reconstructed_images)
        image_np_2 = torch.from_numpy(np.copy(image_np)).cuda()

        decoded_np = np.asarray(decoded)
        decoded_np_2 = torch.from_numpy(np.copy(decoded_np)).cuda()



        lpips_loss = loss_fn_alex(image_np_2, decoded_np_2)
        lpips_cpu = lpips_loss.detach().cpu().squeeze().mean()
        # PPL normalization: divide by eps^2 (eps = 1e-4).
        lpips_cpu = lpips_cpu / (.0001 ** 2)

        print(lpips_cpu)
        lpips_list.append(lpips_cpu)


        i += 1
        #
        if i == 500:
            break


    mean_lpips = jnp.mean(jnp.asarray(lpips_list))
    std_lpips = jnp.std(jnp.asarray(lpips_list))
    print("PPL", mean_lpips)
    print("C std", std_lpips)

    print("mean of means", jnp.asarray(means).mean())
    print("stds of means", jnp.asarray(means).std())
    print("mean of stds", jnp.asarray(stds).mean())
    print("std of stds", jnp.asarray(stds).std())


    # Recorded results from past runs (per checkpoint / KL weight):
    #ae sym
    # mean of means 0.35234922
    # stds of means 0.4036692
    # mean of stds 2.6363409
    # std of stds 0.30666474


    #1e-6:
    #mean of means -0.018107202
    # stds of means 0.11694455
    # mean of stds 1.0860059
    # std of stds 0.09732369
    #average noise added around .03

    #1e-5:
    # mean of means 0.0065166513
    # stds of means 0.06983645
    # mean of stds 0.9855982
    # std of stds 0.05810356

    #1e-4:
    # PPL 8.167942
    # C std 1.7576017
    # mean of means 0.0065882676
    # stds of means 0.042861093
    # mean of stds 0.7608507
    # std of stds 0.05846726
    #Average noise added???



    #pl300
    #PPL 3.5399284
    #C std 0.45380986
    # mean of means 0.090131655
    # stds of means 0.69894844
    # mean of stds 5.5634923
    # std of stds 0.6767279


    #pl100
    # PPL 3.6192155
    # C std 0.47185272
    # mean of means 0.16227543
    # stds of means 0.53616405
    # mean of stds 4.4914503
    # std of stds 0.6015057

    #kl2 noise thing
    # PPL 1.2598925
    # C std 0.26455516
    # mean of means -0.013443217
    # stds of means 1.5238239
    # mean of stds 40.043938
    # std of stds 1.7931403
| 306 |
+
if __name__ == '__main__':
|
| 307 |
+
app.run(main)
|
f16c16/ppl_latents2.py
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
try: # For debugging
|
| 2 |
+
from localutils.debugger import enable_debug
|
| 3 |
+
enable_debug()
|
| 4 |
+
except ImportError:
|
| 5 |
+
pass
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
#import jax
|
| 9 |
+
#jax.config.update('jax_platform_name', 'cpu')
|
| 10 |
+
import os
|
| 11 |
+
# os.environ["JAX_PLATFORMS"] = 'cpu'
|
| 12 |
+
import jax
|
| 13 |
+
import lpips
|
| 14 |
+
|
| 15 |
+
loss_fn_alex = lpips.LPIPS(net='alex') # best forward scores
|
| 16 |
+
loss_fn_alex = loss_fn_alex.cuda()
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import flax.linen as nn
|
| 21 |
+
import jax.numpy as jnp
|
| 22 |
+
from absl import app, flags
|
| 23 |
+
from functools import partial
|
| 24 |
+
import numpy as np
|
| 25 |
+
import tqdm
|
| 26 |
+
import flax
|
| 27 |
+
import optax
|
| 28 |
+
import wandb
|
| 29 |
+
from ml_collections import config_flags
|
| 30 |
+
#import elements
|
| 31 |
+
import ml_collections
|
| 32 |
+
import tensorflow_datasets as tfds
|
| 33 |
+
import tensorflow as tf
|
| 34 |
+
tf.config.set_visible_devices([], "GPU")
|
| 35 |
+
tf.config.set_visible_devices([], "TPU")
|
| 36 |
+
import matplotlib.pyplot as plt
|
| 37 |
+
from typing import Any
|
| 38 |
+
|
| 39 |
+
from utils.train_state import TrainState, target_update
|
| 40 |
+
from utils.checkpoint import Checkpoint
|
| 41 |
+
from utils.fid import get_fid_network, fid_from_stats
|
| 42 |
+
|
| 43 |
+
from train import VQGANModel
|
| 44 |
+
from models.vqvae import VQVAE
|
| 45 |
+
from models.discriminator import Discriminator
|
| 46 |
+
|
| 47 |
+
from PIL import Image
|
| 48 |
+
import torch
|
| 49 |
+
|
| 50 |
+
delattr(flags.FLAGS, 'dataset_name')
|
| 51 |
+
delattr(flags.FLAGS, 'load_dir')
|
| 52 |
+
delattr(flags.FLAGS, 'batch_size')
|
| 53 |
+
|
| 54 |
+
FLAGS = flags.FLAGS
|
| 55 |
+
flags.DEFINE_string('dataset_name', 'imagenet256', 'Environment name.')
|
| 56 |
+
flags.DEFINE_string('load_dir', "/home/dkaplan/Downloads/Models/checkpoint(1).tmp", 'Load dir (if not None, load params from here).')
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
flags.DEFINE_integer('batch_size', 2, 'Total Batch size.')
|
| 60 |
+
# Flags are inhereited from train.py, so pass your model parameters again here to evaluate.
|
| 61 |
+
|
| 62 |
+
import gc
|
| 63 |
+
|
| 64 |
+
def main(_):
|
| 65 |
+
device_count = len(jax.local_devices())
|
| 66 |
+
global_device_count = jax.device_count()
|
| 67 |
+
local_batch_size = FLAGS.batch_size // (global_device_count // device_count)
|
| 68 |
+
|
| 69 |
+
def get_dataset(is_train):
|
| 70 |
+
if 'imagenet' in FLAGS.dataset_name:
|
| 71 |
+
def deserialization_fn(data):
|
| 72 |
+
image = data['image']
|
| 73 |
+
min_side = tf.minimum(tf.shape(image)[0], tf.shape(image)[1])
|
| 74 |
+
image = tf.image.resize_with_crop_or_pad(image, min_side, min_side)
|
| 75 |
+
if 'imagenet256' in FLAGS.dataset_name:
|
| 76 |
+
image = tf.image.resize(image, (256, 256))
|
| 77 |
+
elif 'imagenet128' in FLAGS.dataset_name:
|
| 78 |
+
image = tf.image.resize(image, (128, 128))
|
| 79 |
+
else:
|
| 80 |
+
raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
|
| 81 |
+
if is_train:
|
| 82 |
+
image = tf.image.random_flip_left_right(image)
|
| 83 |
+
image = tf.cast(image, tf.float32) / 255.0
|
| 84 |
+
return image
|
| 85 |
+
|
| 86 |
+
split = tfds.split_for_jax_process('train' if is_train else 'validation', drop_remainder=True)
|
| 87 |
+
dataset = tfds.load('imagenet2012', data_dir="/data/inet", split=split)
|
| 88 |
+
dataset = dataset.map(deserialization_fn, num_parallel_calls=tf.data.AUTOTUNE)
|
| 89 |
+
dataset = dataset.shuffle(10000, seed=42, reshuffle_each_iteration=True)
|
| 90 |
+
dataset = dataset.batch(local_batch_size)
|
| 91 |
+
dataset = dataset.prefetch(tf.data.AUTOTUNE)
|
| 92 |
+
dataset = tfds.as_numpy(dataset)
|
| 93 |
+
dataset = iter(dataset)
|
| 94 |
+
return dataset
|
| 95 |
+
else:
|
| 96 |
+
raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
|
| 97 |
+
|
| 98 |
+
dataset = get_dataset(is_train=True)
|
| 99 |
+
dataset_valid = get_dataset(is_train=False)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# image = Image.open("osman.png")
|
| 103 |
+
# image = np.array(image) / 255.0
|
| 104 |
+
# print(image)
|
| 105 |
+
# image = jnp.array(image)
|
| 106 |
+
# image = jnp.expand_dims(image, 0)
|
| 107 |
+
# image = jnp.expand_dims(image, 0)
|
| 108 |
+
|
| 109 |
+
example_obs = next(dataset)[:1]
|
| 110 |
+
|
| 111 |
+
#Reconstruction loop
|
| 112 |
+
# image = model.reconstruction(image)
|
| 113 |
+
# image = image[0,0,:,:,:]
|
| 114 |
+
# image = (image * 255).astype(np.uint8)
|
| 115 |
+
# image = np.array(image)
|
| 116 |
+
# img = Image.fromarray(image)
|
| 117 |
+
# img.save("osman" + str(i) + ".png")
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
rng = jax.random.PRNGKey(FLAGS.seed)
|
| 121 |
+
rng, param_key = jax.random.split(rng)
|
| 122 |
+
print("Total devices", jax.local_devices()[0])
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
###################################
|
| 126 |
+
# Creating Model and put on devices.
|
| 127 |
+
###################################
|
| 128 |
+
FLAGS.model.image_channels = example_obs.shape[-1]
|
| 129 |
+
FLAGS.model.image_size = example_obs.shape[1]
|
| 130 |
+
vqvae_def = VQVAE(FLAGS.model, train=True)
|
| 131 |
+
vqvae_params = vqvae_def.init({'params': param_key, 'noise': param_key}, example_obs)['params']
|
| 132 |
+
# tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
|
| 133 |
+
vqvae_ts = TrainState.create(vqvae_def, vqvae_params)#, tx=tx) #Turning off tx because we don't need it...
|
| 134 |
+
vqvae_def_eps = VQVAE(FLAGS.model, train=False)
|
| 135 |
+
vqvae_eps_ts = TrainState.create(vqvae_def_eps, vqvae_params)
|
| 136 |
+
print("Total num of VQVAE parameters:", sum(x.size for x in jax.tree_util.tree_leaves(vqvae_params)))
|
| 137 |
+
|
| 138 |
+
discriminator_def = Discriminator(FLAGS.model)
|
| 139 |
+
discriminator_params = discriminator_def.init(param_key, example_obs)['params']
|
| 140 |
+
# tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
|
| 141 |
+
discriminator_ts = TrainState.create(discriminator_def, discriminator_params)#, tx=tx)#No tx again
|
| 142 |
+
print("Total num of Discriminator parameters:", sum(x.size for x in jax.tree_util.tree_leaves(discriminator_params)))
|
| 143 |
+
|
| 144 |
+
model = VQGANModel(rng=rng, vqvae=vqvae_ts, vqvae_eps=vqvae_eps_ts, discriminator=discriminator_ts, config=FLAGS.model)
|
| 145 |
+
|
| 146 |
+
assert FLAGS.load_dir is not None
|
| 147 |
+
cp = Checkpoint(FLAGS.load_dir)
|
| 148 |
+
model = cp.load_model(model)
|
| 149 |
+
print("Loaded model with step", model.vqvae.step)
|
| 150 |
+
|
| 151 |
+
model = flax.jax_utils.replicate(model, devices=jax.local_devices())
|
| 152 |
+
jax.debug.visualize_array_sharding(model.vqvae.params['decoder']['Conv_0']['bias'])
|
| 153 |
+
#print(model.vqvae)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
####################################
|
| 157 |
+
# Noise stuff
|
| 158 |
+
###################################
|
| 159 |
+
|
| 160 |
+
cpus = jax.devices("cpu")
|
| 161 |
+
|
| 162 |
+
#So there are a few ways to calculate PPL here
|
| 163 |
+
#We could take two images in image space
|
| 164 |
+
#Walk between them and check the LPIPS in the output space
|
| 165 |
+
#...actually that's basically it right?
|
| 166 |
+
#We could also do the walk in latent space, which is the same, but with ?? scaling
|
| 167 |
+
|
| 168 |
+
#Let's see if they are any different.
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
#We could also try taking a latent, going X/2 direction, and -X/2 direction, and seeing that.
|
| 172 |
+
i = 0
|
| 173 |
+
lpips_list = []
|
| 174 |
+
means = []
|
| 175 |
+
stds = []
|
| 176 |
+
for valid_images in dataset_valid:
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
valid_images = valid_images.reshape((len(jax.local_devices()), -1, *valid_images.shape[1:])) # [devices, batch//devices, etc..]
|
| 180 |
+
#1, 2, 256, 256, 3
|
| 181 |
+
#Given our 2 images, we want to lerp between them...
|
| 182 |
+
#We want to lerp once to point t, and once to point t + eps
|
| 183 |
+
#And then we want to get the LPIPS between those two images
|
| 184 |
+
#And then we calculate LPIPS
|
| 185 |
+
#And then we divide by eps squared, and done.
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
reconstructed_images, decoded, std, latents, decoded_2 = model.reconstruction_ppl_two(valid_images) # [devices, 8, 256, 256, 3]
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
means.append(latents.mean())
|
| 192 |
+
stds.append(latents.std())
|
| 193 |
+
# print("std", std.mean())
|
| 194 |
+
print("latent mean", latents.mean())
|
| 195 |
+
print("actual latent std", latents.std())
|
| 196 |
+
|
| 197 |
+
#Need to change images back to -1,1
|
| 198 |
+
#Why are the images so similar? It's different noises...
|
| 199 |
+
|
| 200 |
+
reconstructed_images = decoded_2 * 2 - 1
|
| 201 |
+
decoded = decoded * 2 -1
|
| 202 |
+
|
| 203 |
+
#1,2,256,256,3
|
| 204 |
+
reconstructed_images = jnp.swapaxes(reconstructed_images, 0, 4)
|
| 205 |
+
decoded = jnp.swapaxes(decoded, 0, 4)
|
| 206 |
+
|
| 207 |
+
reconstructed_images = jnp.swapaxes(reconstructed_images, 0, 1)
|
| 208 |
+
decoded = jnp.swapaxes(decoded, 0, 1)
|
| 209 |
+
|
| 210 |
+
reconstructed_images = jnp.squeeze(reconstructed_images)
|
| 211 |
+
decoded = jnp.squeeze(decoded)
|
| 212 |
+
|
| 213 |
+
#So here, we want to put them on CPU and delete the original
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
image_np = np.asarray(reconstructed_images)
|
| 217 |
+
image_np_2 = torch.from_numpy(np.copy(image_np)).cuda()
|
| 218 |
+
|
| 219 |
+
decoded_np = np.asarray(decoded)
|
| 220 |
+
decoded_np_2 = torch.from_numpy(np.copy(decoded_np)).cuda()
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
lpips_loss = loss_fn_alex(image_np_2, decoded_np_2)
|
| 225 |
+
lpips_cpu = lpips_loss.detach().cpu().squeeze().mean()
|
| 226 |
+
lpips_cpu = lpips_cpu / (.0001 ** 2)
|
| 227 |
+
|
| 228 |
+
print(lpips_cpu)
|
| 229 |
+
lpips_list.append(lpips_cpu)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
i += 1
|
| 233 |
+
#
|
| 234 |
+
if i == 500:
|
| 235 |
+
break
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
mean_lpips = jnp.mean(jnp.asarray(lpips_list))
|
| 241 |
+
print(mean_lpips)
|
| 242 |
+
print("mean of means", jnp.asarray(means).mean())
|
| 243 |
+
print("stds of means", jnp.asarray(means).std())
|
| 244 |
+
print("mean of stds", jnp.asarray(stds).mean())
|
| 245 |
+
print("std of stds", jnp.asarray(stds).std())
|
| 246 |
+
|
| 247 |
+
#1e-4? 8.1371
|
| 248 |
+
#1e-5 9.0486
|
| 249 |
+
#1e-6 9.7
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
#ae is a 5.85.....
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
#1e-4 kl2 1.26
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
#1e-6 is 9.8
|
| 262 |
+
#1e-5 is 9.09
|
| 263 |
+
#2e-5 is ..... between these. hopefully. 8.83
|
| 264 |
+
#1e-4 is 8.16
|
| 265 |
+
#ae (sym) is 5.87 right now, somehow.
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
#basicallly ae 5.56, then 4.95?
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
#PL100 is 3.6
|
| 272 |
+
#Pl300 is 3.53
|
| 273 |
+
#Pl600 is... 3.97
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
#So the kl level barely matters it seems.
|
| 277 |
+
#We might want to try MMD + noise, but it also barely matters I think
|
| 278 |
+
#1e-4 was 1.25
|
| 279 |
+
#5e-5 was 1.225
|
| 280 |
+
#kl2 was like super duper low, forgot to save it lol. 1.17 maybe?
|
| 281 |
+
|
| 282 |
+
if __name__ == '__main__':
|
| 283 |
+
app.run(main)
|
f16c16/stats.py
ADDED
|
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
try: # For debugging
|
| 2 |
+
from localutils.debugger import enable_debug
|
| 3 |
+
enable_debug()
|
| 4 |
+
except ImportError:
|
| 5 |
+
pass
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
#import jax
|
| 9 |
+
#jax.config.update('jax_platform_name', 'cpu')
|
| 10 |
+
import os
|
| 11 |
+
# os.environ["JAX_PLATFORMS"] = 'cpu'
|
| 12 |
+
import jax
|
| 13 |
+
import lpips
|
| 14 |
+
|
| 15 |
+
loss_fn_alex = lpips.LPIPS(net='alex') # best forward scores
|
| 16 |
+
loss_fn_alex = loss_fn_alex.cuda()
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import flax.linen as nn
|
| 21 |
+
import jax.numpy as jnp
|
| 22 |
+
from absl import app, flags
|
| 23 |
+
from functools import partial
|
| 24 |
+
import numpy as np
|
| 25 |
+
import tqdm
|
| 26 |
+
import flax
|
| 27 |
+
import optax
|
| 28 |
+
import wandb
|
| 29 |
+
from ml_collections import config_flags
|
| 30 |
+
#import elements
|
| 31 |
+
import ml_collections
|
| 32 |
+
import tensorflow_datasets as tfds
|
| 33 |
+
import tensorflow as tf
|
| 34 |
+
tf.config.set_visible_devices([], "GPU")
|
| 35 |
+
tf.config.set_visible_devices([], "TPU")
|
| 36 |
+
import matplotlib.pyplot as plt
|
| 37 |
+
from typing import Any
|
| 38 |
+
|
| 39 |
+
from utils.train_state import TrainState, target_update
|
| 40 |
+
from utils.checkpoint import Checkpoint
|
| 41 |
+
from utils.fid import get_fid_network, fid_from_stats
|
| 42 |
+
|
| 43 |
+
from train import VQGANModel
|
| 44 |
+
from models.vqvae import VQVAE
|
| 45 |
+
from models.discriminator import Discriminator
|
| 46 |
+
|
| 47 |
+
from PIL import Image
|
| 48 |
+
import torch
|
| 49 |
+
|
| 50 |
+
delattr(flags.FLAGS, 'dataset_name')
|
| 51 |
+
delattr(flags.FLAGS, 'load_dir')
|
| 52 |
+
delattr(flags.FLAGS, 'batch_size')
|
| 53 |
+
|
| 54 |
+
FLAGS = flags.FLAGS
|
| 55 |
+
flags.DEFINE_string('dataset_name', 'imagenet256', 'Environment name.')
|
| 56 |
+
flags.DEFINE_string('load_dir', "/home/dkaplan/Downloads/Models/checkpoint(1).tmp", 'Load dir (if not None, load params from here).')
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
flags.DEFINE_integer('batch_size', 2, 'Total Batch size.')
|
| 60 |
+
# Flags are inhereited from train.py, so pass your model parameters again here to evaluate.
|
| 61 |
+
|
| 62 |
+
import gc
|
| 63 |
+
|
| 64 |
+
def main(_):
|
| 65 |
+
|
| 66 |
+
device_count = len(jax.local_devices())
|
| 67 |
+
global_device_count = jax.device_count()
|
| 68 |
+
local_batch_size = FLAGS.batch_size // (global_device_count // device_count)
|
| 69 |
+
|
| 70 |
+
def get_dataset(is_train):
|
| 71 |
+
if 'imagenet' in FLAGS.dataset_name:
|
| 72 |
+
def deserialization_fn(data):
|
| 73 |
+
image = data['image']
|
| 74 |
+
min_side = tf.minimum(tf.shape(image)[0], tf.shape(image)[1])
|
| 75 |
+
image = tf.image.resize_with_crop_or_pad(image, min_side, min_side)
|
| 76 |
+
if 'imagenet256' in FLAGS.dataset_name:
|
| 77 |
+
image = tf.image.resize(image, (256, 256))
|
| 78 |
+
elif 'imagenet128' in FLAGS.dataset_name:
|
| 79 |
+
image = tf.image.resize(image, (128, 128))
|
| 80 |
+
else:
|
| 81 |
+
raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
|
| 82 |
+
if is_train:
|
| 83 |
+
image = tf.image.random_flip_left_right(image)
|
| 84 |
+
image = tf.cast(image, tf.float32) / 255.0
|
| 85 |
+
return image
|
| 86 |
+
|
| 87 |
+
split = tfds.split_for_jax_process('train' if is_train else 'validation', drop_remainder=True)
|
| 88 |
+
dataset = tfds.load('imagenet2012', data_dir="/data/inet", split=split)
|
| 89 |
+
dataset = dataset.map(deserialization_fn, num_parallel_calls=tf.data.AUTOTUNE)
|
| 90 |
+
dataset = dataset.shuffle(10000, seed=42, reshuffle_each_iteration=True)
|
| 91 |
+
dataset = dataset.batch(local_batch_size)
|
| 92 |
+
dataset = dataset.prefetch(tf.data.AUTOTUNE)
|
| 93 |
+
dataset = tfds.as_numpy(dataset)
|
| 94 |
+
dataset = iter(dataset)
|
| 95 |
+
return dataset
|
| 96 |
+
else:
|
| 97 |
+
raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
|
| 98 |
+
|
| 99 |
+
dataset = get_dataset(is_train=True)
|
| 100 |
+
dataset_valid = get_dataset(is_train=False)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# image = Image.open("osman.png")
|
| 104 |
+
# image = np.array(image) / 255.0
|
| 105 |
+
# print(image)
|
| 106 |
+
# image = jnp.array(image)
|
| 107 |
+
# image = jnp.expand_dims(image, 0)
|
| 108 |
+
# image = jnp.expand_dims(image, 0)
|
| 109 |
+
|
| 110 |
+
example_obs = next(dataset)[:1]
|
| 111 |
+
|
| 112 |
+
#Reconstruction loop
|
| 113 |
+
# image = model.reconstruction(image)
|
| 114 |
+
# image = image[0,0,:,:,:]
|
| 115 |
+
# image = (image * 255).astype(np.uint8)
|
| 116 |
+
# image = np.array(image)
|
| 117 |
+
# img = Image.fromarray(image)
|
| 118 |
+
# img.save("osman" + str(i) + ".png")
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
rng = jax.random.PRNGKey(FLAGS.seed)
|
| 122 |
+
rng, param_key = jax.random.split(rng)
|
| 123 |
+
print("Total devices", jax.local_devices()[0])
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
###################################
|
| 127 |
+
# Creating Model and put on devices.
|
| 128 |
+
###################################
|
| 129 |
+
FLAGS.model.image_channels = example_obs.shape[-1]
|
| 130 |
+
FLAGS.model.image_size = example_obs.shape[1]
|
| 131 |
+
vqvae_def = VQVAE(FLAGS.model, train=True)
|
| 132 |
+
vqvae_params = vqvae_def.init({'params': param_key, 'noise': param_key}, example_obs)['params']
|
| 133 |
+
# tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
|
| 134 |
+
vqvae_ts = TrainState.create(vqvae_def, vqvae_params)#, tx=tx) #Turning off tx because we don't need it...
|
| 135 |
+
vqvae_def_eps = VQVAE(FLAGS.model, train=False)
|
| 136 |
+
vqvae_eps_ts = TrainState.create(vqvae_def_eps, vqvae_params)
|
| 137 |
+
print("Total num of VQVAE parameters:", sum(x.size for x in jax.tree_util.tree_leaves(vqvae_params)))
|
| 138 |
+
|
| 139 |
+
discriminator_def = Discriminator(FLAGS.model)
|
| 140 |
+
discriminator_params = discriminator_def.init(param_key, example_obs)['params']
|
| 141 |
+
# tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
|
| 142 |
+
discriminator_ts = TrainState.create(discriminator_def, discriminator_params)#, tx=tx)#No tx again
|
| 143 |
+
print("Total num of Discriminator parameters:", sum(x.size for x in jax.tree_util.tree_leaves(discriminator_params)))
|
| 144 |
+
|
| 145 |
+
model = VQGANModel(rng=rng, vqvae=vqvae_ts, vqvae_eps=vqvae_eps_ts, discriminator=discriminator_ts, config=FLAGS.model)
|
| 146 |
+
|
| 147 |
+
assert FLAGS.load_dir is not None
|
| 148 |
+
cp = Checkpoint(FLAGS.load_dir)
|
| 149 |
+
model = cp.load_model(model)
|
| 150 |
+
print("Loaded model with step", model.vqvae.step)
|
| 151 |
+
|
| 152 |
+
model = flax.jax_utils.replicate(model, devices=jax.local_devices())
|
| 153 |
+
jax.debug.visualize_array_sharding(model.vqvae.params['decoder']['Conv_0']['bias'])
|
| 154 |
+
#print(model.vqvae)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
####################################
|
| 158 |
+
# Noise stuff
|
| 159 |
+
###################################
|
| 160 |
+
|
| 161 |
+
#on the other end also.
|
| 162 |
+
noises = []
|
| 163 |
+
|
| 164 |
+
numbers = np.arange(0.00, 1.0, 0.01)
|
| 165 |
+
|
| 166 |
+
for number in numbers:
|
| 167 |
+
noises.append(float(number))
|
| 168 |
+
|
| 169 |
+
# numbers = np.arange(.4, 3, .5)
|
| 170 |
+
# for number in numbers:
|
| 171 |
+
# noises.append(float(number))
|
| 172 |
+
|
| 173 |
+
i = 0
|
| 174 |
+
l2_dict = {noise: [] for noise in noises}
|
| 175 |
+
lpips_dict = {noise: [] for noise in noises}
|
| 176 |
+
snr_dict = {noise: [] for noise in noises}
|
| 177 |
+
|
| 178 |
+
cpus = jax.devices("cpu")
|
| 179 |
+
print(noises)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
for valid_images in dataset_valid:
|
| 183 |
+
print(i)
|
| 184 |
+
valid_images = valid_images.reshape((len(jax.local_devices()), -1, *valid_images.shape[1:])) # [devices, batch//devices, etc..]
|
| 185 |
+
|
| 186 |
+
# valid_reconstructed_images = model.reconstruction(valid_images) # [devices, 8, 256, 256, 3]
|
| 187 |
+
|
| 188 |
+
valid_reconstructed_images, noisy_reconstructed_images, std = model.reconstruction_noisy(valid_images)
|
| 189 |
+
print(std.mean())
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
# valid_reconstructed_images, noisy_reconstructed_images = model.reconstruction_sampling(valid_images) # [devices, 8, 256, 256, 3]
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
# print(latents)
|
| 197 |
+
#Calculate MSE between valid and noisy.
|
| 198 |
+
if True:
|
| 199 |
+
for noise, decoded in zip(noises, noisy_reconstructed_images):
|
| 200 |
+
image, snr = decoded
|
| 201 |
+
snr = snr.mean()#So this gives us the snr for a given noise level. need to mean it..
|
| 202 |
+
snr_dict[noise].append(snr)
|
| 203 |
+
#So we put it into the noise list.
|
| 204 |
+
|
| 205 |
+
# print("snr", snr)
|
| 206 |
+
l2 = jnp.mean((valid_reconstructed_images - image) ** 2)
|
| 207 |
+
l2_cpu = jax.device_put(l2, cpus[0])
|
| 208 |
+
l2_dict[noise].append(l2_cpu)
|
| 209 |
+
|
| 210 |
+
#Need to change images back to -1,1
|
| 211 |
+
|
| 212 |
+
image = image * 2 - 1
|
| 213 |
+
valid_rescaled = valid_reconstructed_images * 2 -1
|
| 214 |
+
|
| 215 |
+
#1,2,256,256,3
|
| 216 |
+
image = jnp.swapaxes(image, 0, 4)
|
| 217 |
+
valid_rescaled = jnp.swapaxes(valid_rescaled, 0, 4)
|
| 218 |
+
|
| 219 |
+
image = jnp.swapaxes(image, 0, 1)
|
| 220 |
+
valid_rescaled = jnp.swapaxes(valid_rescaled, 0, 1)
|
| 221 |
+
|
| 222 |
+
image = jnp.squeeze(image)
|
| 223 |
+
valid_rescaled = jnp.squeeze(valid_rescaled)
|
| 224 |
+
|
| 225 |
+
#So here, we want to put them on CPU and delete the original
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
image_np = np.asarray(image)
|
| 229 |
+
image_np_2 = torch.from_numpy(np.copy(image_np)).cuda()
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
#Can be run only once if needd
|
| 233 |
+
valid_rescaled_np = np.asarray(valid_rescaled)
|
| 234 |
+
valid_rescaled_np_2 = torch.from_numpy(np.copy(valid_rescaled_np)).cuda()
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
lpips_loss = loss_fn_alex(valid_rescaled_np_2, image_np_2)
|
| 239 |
+
lpips_cpu = lpips_loss.detach().cpu().squeeze().mean()
|
| 240 |
+
lpips_dict[noise].append(lpips_cpu)
|
| 241 |
+
elif False:#Check l2 and lpips on our 2 images..
|
| 242 |
+
|
| 243 |
+
l2 = jnp.mean((valid_reconstructed_images - noisy_reconstructed_images) ** 2)
|
| 244 |
+
l2_cpu = jax.device_put(l2, cpus[0])
|
| 245 |
+
print("L2", l2_cpu)
|
| 246 |
+
|
| 247 |
+
#Need to change images back to -1,1
|
| 248 |
+
valid_reconstructed_images = valid_reconstructed_images * 2 - 1
|
| 249 |
+
noisy_reconstructed_images = noisy_reconstructed_images * 2 -1
|
| 250 |
+
|
| 251 |
+
#1,2,256,256,3
|
| 252 |
+
valid_reconstructed_images = jnp.swapaxes(valid_reconstructed_images, 0, 4)
|
| 253 |
+
noisy_reconstructed_images = jnp.swapaxes(noisy_reconstructed_images, 0, 4)
|
| 254 |
+
|
| 255 |
+
valid_reconstructed_images = jnp.swapaxes(valid_reconstructed_images, 0, 1)
|
| 256 |
+
noisy_reconstructed_images = jnp.swapaxes(noisy_reconstructed_images, 0, 1)
|
| 257 |
+
|
| 258 |
+
valid_reconstructed_images = jnp.squeeze(valid_reconstructed_images)
|
| 259 |
+
noisy_reconstructed_images = jnp.squeeze(noisy_reconstructed_images)
|
| 260 |
+
|
| 261 |
+
#So here, we want to put them on CPU and delete the original
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
image_np = np.asarray(valid_reconstructed_images)
|
| 265 |
+
image_np_2 = torch.from_numpy(np.copy(image_np)).cuda()
|
| 266 |
+
|
| 267 |
+
valid_rescaled_np = np.asarray(noisy_reconstructed_images)
|
| 268 |
+
valid_rescaled_np_2 = torch.from_numpy(np.copy(valid_rescaled_np)).cuda()
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
lpips_loss = loss_fn_alex(valid_rescaled_np_2, image_np_2)
|
| 273 |
+
lpips_cpu = lpips_loss.detach().cpu().squeeze().mean()
|
| 274 |
+
print("Lpips", lpips_cpu)
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
if False:
|
| 278 |
+
image = valid_images[0,0,:,:,:]
|
| 279 |
+
image = (image * 255).astype(np.uint8)
|
| 280 |
+
img = Image.fromarray(image)
|
| 281 |
+
img.save("original" + str(i) + ".png")
|
| 282 |
+
|
| 283 |
+
image2 = valid_reconstructed_images[0,0,:,:,:]
|
| 284 |
+
image2 = (image2 * 255).astype(np.uint8)
|
| 285 |
+
image2 = np.array(image2)
|
| 286 |
+
image2 = Image.fromarray(image2)
|
| 287 |
+
image2.save("recon" + str(i) + ".png")
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
#Needs [0] if list
|
| 291 |
+
# image3 = noisy_reconstructed_images[0][0,0,:,:,:]
|
| 292 |
+
|
| 293 |
+
image3 = noisy_reconstructed_images[2][0,0,:,:,:]
|
| 294 |
+
image3 = (image3 * 255).astype(np.uint8)
|
| 295 |
+
image3 = np.array(image3)
|
| 296 |
+
image3 = Image.fromarray(image3)
|
| 297 |
+
image3.save("noisy_recon_0_" + str(i) + ".png")
|
| 298 |
+
|
| 299 |
+
image4 = noisy_reconstructed_images[-1][0,0,:,:,:]
|
| 300 |
+
image4 = (image4 * 255).astype(np.uint8)
|
| 301 |
+
image4 = np.array(image4)
|
| 302 |
+
image4 = Image.fromarray(image4)
|
| 303 |
+
image4.save("noisy_recon_last_" + str(i) + ".png")
|
| 304 |
+
|
| 305 |
+
# del valid_images
|
| 306 |
+
# del valid_reconstructed_images
|
| 307 |
+
# del noisy_reconstructed_images
|
| 308 |
+
|
| 309 |
+
# gc.collect()
|
| 310 |
+
# torch.cuda.empty_cache()
|
| 311 |
+
i += 1
|
| 312 |
+
#
|
| 313 |
+
if i == 50:
|
| 314 |
+
break
|
| 315 |
+
#Now we have our l2 set.
|
| 316 |
+
|
| 317 |
+
mean_l2_dict = {noise: jnp.mean(jnp.asarray(l2_values)) for noise, l2_values in l2_dict.items()}
|
| 318 |
+
std_l2_dict = {noise: jnp.std(jnp.asarray(l2_values)) for noise, l2_values in l2_dict.items()}
|
| 319 |
+
for noise, mean_l2 in mean_l2_dict.items():
|
| 320 |
+
print(f"Mean L2 for noise {noise}: {mean_l2}")
|
| 321 |
+
|
| 322 |
+
mean_lpips_dict = {noise: torch.mean(torch.tensor(lpips_values)) for noise, lpips_values in lpips_dict.items()}
|
| 323 |
+
std_lpips_dict = {noise: torch.std(torch.tensor(lpips_values)) for noise, lpips_values in lpips_dict.items()}
|
| 324 |
+
for noise, mean_lpips in mean_lpips_dict.items():
|
| 325 |
+
print(f"Mean Lpips for noise {noise}: {mean_lpips}")
|
| 326 |
+
|
| 327 |
+
mean_snr_dict = {noise: jnp.mean(jnp.asarray(snr_values)) for noise, snr_values in snr_dict.items()}
|
| 328 |
+
std_snr_dict = {noise: jnp.std(jnp.asarray(snr_values)) for noise, snr_values in snr_dict.items()}
|
| 329 |
+
for noise, mean_snr in mean_snr_dict.items():
|
| 330 |
+
print(f"Mean SNR for noise {noise}: {mean_snr}")
|
| 331 |
+
|
| 332 |
+
array = []
|
| 333 |
+
for noise, std in std_lpips_dict.items():
|
| 334 |
+
array.append(np.asarray(std).tolist())
|
| 335 |
+
|
| 336 |
+
print(array)
|
| 337 |
+
print(std_lpips_dict)
|
| 338 |
+
print(std_snr_dict)#This tells us the range of SNR for a given image/noise level, which... should be lower...?
|
| 339 |
+
|
| 340 |
+
#pl300
|
| 341 |
+
#it's noise to std of the lpips at that noise, but we need....
|
| 342 |
+
#So our points are mean of the lpips at a noise level
|
| 343 |
+
#Mean of the
|
| 344 |
+
''' PL300
|
| 345 |
+
{0.0: tensor(0.), 0.01: tensor(1.2151e-05), 0.02: tensor(4.2352e-05), 0.03: tensor(8.5722e-05), 0.04: tensor(0.0001), 0.05: tensor(0.0002), 0.06: tensor(0.0003), 0.07: tensor(0.0003), 0.08: tensor(0.0004), 0.09: tensor(0.0005), 0.1: tensor(0.0006), 0.11: tensor(0.0007), 0.12: tensor(0.0008), 0.13: tensor(0.0009), 0.14: tensor(0.0011), 0.15: tensor(0.0012), 0.16: tensor(0.0013), 0.17: tensor(0.0015), 0.18: tensor(0.0016), 0.19: tensor(0.0017), 0.2: tensor(0.0019), 0.21: tensor(0.0020), 0.22: tensor(0.0022), 0.23: tensor(0.0023), 0.24: tensor(0.0025), 0.25: tensor(0.0027), 0.26: tensor(0.0028), 0.27: tensor(0.0030), 0.28: tensor(0.0032), 0.29: tensor(0.0034), 0.3: tensor(0.0036), 0.31: tensor(0.0037), 0.32: tensor(0.0039), 0.33: tensor(0.0041), 0.34: tensor(0.0043), 0.35000000000000003: tensor(0.0045), 0.36: tensor(0.0047), 0.37: tensor(0.0050), 0.38: tensor(0.0052), 0.39: tensor(0.0054), 0.4: tensor(0.0056), 0.41000000000000003: tensor(0.0059), 0.42: tensor(0.0061), 0.43: tensor(0.0063), 0.44: tensor(0.0066), 0.45: tensor(0.0068), 0.46: tensor(0.0070), 0.47000000000000003: tensor(0.0073), 0.48: tensor(0.0075), 0.49: tensor(0.0078), 0.5: tensor(0.0080), 0.51: tensor(0.0083), 0.52: tensor(0.0086), 0.53: tensor(0.0088), 0.54: tensor(0.0091), 0.55: tensor(0.0094), 0.56: tensor(0.0097), 0.5700000000000001: tensor(0.0100), 0.58: tensor(0.0102), 0.59: tensor(0.0105), 0.6: tensor(0.0108), 0.61: tensor(0.0111), 0.62: tensor(0.0114), 0.63: tensor(0.0118), 0.64: tensor(0.0121), 0.65: tensor(0.0124), 0.66: tensor(0.0127), 0.67: tensor(0.0130), 0.68: tensor(0.0133), 0.6900000000000001: tensor(0.0136), 0.7000000000000001: tensor(0.0140), 0.71: tensor(0.0143), 0.72: tensor(0.0146), 0.73: tensor(0.0149), 0.74: tensor(0.0152), 0.75: tensor(0.0156), 0.76: tensor(0.0159), 0.77: tensor(0.0162), 0.78: tensor(0.0166), 0.79: tensor(0.0169), 0.8: tensor(0.0172), 0.81: tensor(0.0176), 0.8200000000000001: tensor(0.0179), 0.8300000000000001: tensor(0.0183), 0.84: tensor(0.0186), 0.85: 
tensor(0.0190), 0.86: tensor(0.0193), 0.87: tensor(0.0197), 0.88: tensor(0.0200), 0.89: tensor(0.0204), 0.9: tensor(0.0208), 0.91: tensor(0.0211), 0.92: tensor(0.0215), 0.93: tensor(0.0218), 0.9400000000000001: tensor(0.0222), 0.9500000000000001: tensor(0.0226), 0.96: tensor(0.0229), 0.97: tensor(0.0233), 0.98: tensor(0.0236), 0.99: tensor(0.0240)}
|
| 346 |
+
1e-4
|
| 347 |
+
{0.0: tensor(0.), 0.01: tensor(7.1912e-05), 0.02: tensor(0.0003), 0.03: tensor(0.0006), 0.04: tensor(0.0009), 0.05: tensor(0.0014), 0.06: tensor(0.0018), 0.07: tensor(0.0023), 0.08: tensor(0.0029), 0.09: tensor(0.0034), 0.1: tensor(0.0039), 0.11: tensor(0.0044), 0.12: tensor(0.0049), 0.13: tensor(0.0054), 0.14: tensor(0.0059), 0.15: tensor(0.0064), 0.16: tensor(0.0070), 0.17: tensor(0.0075), 0.18: tensor(0.0080), 0.19: tensor(0.0085), 0.2: tensor(0.0090), 0.21: tensor(0.0096), 0.22: tensor(0.0101), 0.23: tensor(0.0107), 0.24: tensor(0.0112), 0.25: tensor(0.0118), 0.26: tensor(0.0123), 0.27: tensor(0.0129), 0.28: tensor(0.0135), 0.29: tensor(0.0141), 0.3: tensor(0.0147), 0.31: tensor(0.0153), 0.32: tensor(0.0159), 0.33: tensor(0.0166), 0.34: tensor(0.0173), 0.35000000000000003: tensor(0.0180), 0.36: tensor(0.0187), 0.37: tensor(0.0194), 0.38: tensor(0.0201), 0.39: tensor(0.0207), 0.4: tensor(0.0214), 0.41000000000000003: tensor(0.0221), 0.42: tensor(0.0228), 0.43: tensor(0.0236), 0.44: tensor(0.0243), 0.45: tensor(0.0250), 0.46: tensor(0.0258), 0.47000000000000003: tensor(0.0266), 0.48: tensor(0.0274), 0.49: tensor(0.0282), 0.5: tensor(0.0290), 0.51: tensor(0.0298), 0.52: tensor(0.0305), 0.53: tensor(0.0313), 0.54: tensor(0.0321), 0.55: tensor(0.0328), 0.56: tensor(0.0336), 0.5700000000000001: tensor(0.0344), 0.58: tensor(0.0353), 0.59: tensor(0.0361), 0.6: tensor(0.0370), 0.61: tensor(0.0378), 0.62: tensor(0.0386), 0.63: tensor(0.0395), 0.64: tensor(0.0403), 0.65: tensor(0.0410), 0.66: tensor(0.0417), 0.67: tensor(0.0424), 0.68: tensor(0.0430), 0.6900000000000001: tensor(0.0436), 0.7000000000000001: tensor(0.0442), 0.71: tensor(0.0448), 0.72: tensor(0.0454), 0.73: tensor(0.0459), 0.74: tensor(0.0464), 0.75: tensor(0.0468), 0.76: tensor(0.0472), 0.77: tensor(0.0477), 0.78: tensor(0.0480), 0.79: tensor(0.0484), 0.8: tensor(0.0488), 0.81: tensor(0.0493), 0.8200000000000001: tensor(0.0497), 0.8300000000000001: tensor(0.0501), 0.84: tensor(0.0506), 0.85: tensor(0.0510), 
0.86: tensor(0.0513), 0.87: tensor(0.0516), 0.88: tensor(0.0519), 0.89: tensor(0.0521), 0.9: tensor(0.0522), 0.91: tensor(0.0524), 0.92: tensor(0.0525), 0.93: tensor(0.0526), 0.9400000000000001: tensor(0.0526), 0.9500000000000001: tensor(0.0526), 0.96: tensor(0.0526), 0.97: tensor(0.0526), 0.98: tensor(0.0525), 0.99: tensor(0.0525)}
|
| 348 |
+
|
| 349 |
+
'''
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
# for (noise, lpips), (noise_2, snr) in zip(mean_lpips_dict.items(), mean_snr_dict.items()):
|
| 353 |
+
# print(noise, snr)
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
#So here we want to print out our x, which is the mean_snr, and our y, which is the mean noise
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
# images.append((valid_reconstructed_images*255).astype(np.uint8))
|
| 360 |
+
|
| 361 |
+
if __name__ == '__main__':
|
| 362 |
+
app.run(main)
|
f16c16/train.py
ADDED
|
@@ -0,0 +1,676 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
try: # For debugging
|
| 2 |
+
from localutils.debugger import enable_debug
|
| 3 |
+
enable_debug()
|
| 4 |
+
except ImportError:
|
| 5 |
+
pass
|
| 6 |
+
|
| 7 |
+
import flax.linen as nn
|
| 8 |
+
import jax.numpy as jnp
|
| 9 |
+
from absl import app, flags
|
| 10 |
+
from functools import partial
|
| 11 |
+
import numpy as np
|
| 12 |
+
import tqdm
|
| 13 |
+
import jax
|
| 14 |
+
import jax.numpy as jnp
|
| 15 |
+
import flax
|
| 16 |
+
import optax
|
| 17 |
+
import wandb
|
| 18 |
+
from ml_collections import config_flags
|
| 19 |
+
import ml_collections
|
| 20 |
+
import tensorflow_datasets as tfds
|
| 21 |
+
import tensorflow as tf
|
| 22 |
+
tf.config.set_visible_devices([], "GPU")
|
| 23 |
+
tf.config.set_visible_devices([], "TPU")
|
| 24 |
+
import matplotlib.pyplot as plt
|
| 25 |
+
from typing import Any
|
| 26 |
+
import os
|
| 27 |
+
|
| 28 |
+
from utils.wandb import setup_wandb, default_wandb_config
|
| 29 |
+
from utils.train_state import TrainState, target_update
|
| 30 |
+
from utils.checkpoint import Checkpoint
|
| 31 |
+
from utils.pretrained_resnet import get_pretrained_embs, get_pretrained_model
|
| 32 |
+
from utils.fid import get_fid_network, fid_from_stats
|
| 33 |
+
from models.vqvae import VQVAE
|
| 34 |
+
from models.discriminator import Discriminator
|
| 35 |
+
|
| 36 |
+
FLAGS = flags.FLAGS
|
| 37 |
+
flags.DEFINE_string('dataset_name', 'imagenet256', 'Environment name.')
|
| 38 |
+
flags.DEFINE_string('save_dir', "/home/lambda/jax-vqvae-vqgan/chkpts/checkpoint", 'Save dir (if not None, save params).')
|
| 39 |
+
flags.DEFINE_string('load_dir', "./checkpointbest.tmp.tmp" , 'Load dir (if not None, load params from here).')
|
| 40 |
+
flags.DEFINE_integer('seed', 0, 'Random seed.')
|
| 41 |
+
flags.DEFINE_integer('log_interval', 1000, 'Logging interval.')
|
| 42 |
+
flags.DEFINE_integer('eval_interval', 1000, 'Eval interval.')
|
| 43 |
+
flags.DEFINE_integer('save_interval', 1000, 'Save interval.')
|
| 44 |
+
flags.DEFINE_integer('batch_size', 64, 'Total Batch size.')
|
| 45 |
+
flags.DEFINE_integer('max_steps', int(1_000_000), 'Number of training steps.')
|
| 46 |
+
|
| 47 |
+
model_config = ml_collections.ConfigDict({
|
| 48 |
+
# VQVAE
|
| 49 |
+
'lr': 0.0001,
|
| 50 |
+
'beta1': 0.0,#.5
|
| 51 |
+
'beta2': 0.99,#.9
|
| 52 |
+
'lr_warmup_steps': 4000,
|
| 53 |
+
'lr_decay_steps': 1_000_000, #They use 'lambdalr'
|
| 54 |
+
'filters': 128,
|
| 55 |
+
'num_res_blocks': 2,
|
| 56 |
+
'channel_multipliers': (1, 1, 2, 2, 4),
|
| 57 |
+
'embedding_dim': 16,
|
| 58 |
+
'norm_type': 'GN',
|
| 59 |
+
'weight_decay': 0.05,#None maybe?
|
| 60 |
+
'clip_gradient': 1.0,
|
| 61 |
+
'l2_loss_weight': 1.0,#They use L1 actually
|
| 62 |
+
'eps_update_rate': 0.9999,
|
| 63 |
+
# Quantizer
|
| 64 |
+
'quantizer_type': 'ae', # or 'fsq', 'kl'
|
| 65 |
+
# Quantizer (VQ)
|
| 66 |
+
'quantizer_loss_ratio': 1,
|
| 67 |
+
'codebook_size': 1024,
|
| 68 |
+
'entropy_loss_ratio': 0.1,
|
| 69 |
+
'entropy_loss_type': 'softmax',
|
| 70 |
+
'entropy_temperature': 0.01,
|
| 71 |
+
'commitment_cost': 0.25,
|
| 72 |
+
# Quantizer (FSQ)
|
| 73 |
+
'fsq_levels': 5, # Bins per dimension.
|
| 74 |
+
# Quantizer (KL)
|
| 75 |
+
'kl_weight': 0.000001,#They use 1e-6 on their stuff LUL. .001 is the default
|
| 76 |
+
# GAN
|
| 77 |
+
'g_adversarial_loss_weight': 0.5,
|
| 78 |
+
'g_grad_penalty_cost': 10,
|
| 79 |
+
'perceptual_loss_weight': 0.5,
|
| 80 |
+
'gan_warmup_steps': 100000,#50000, #Temporary extra time
|
| 81 |
+
"pl_decay": 0.01,
|
| 82 |
+
"pl_weight": -1,
|
| 83 |
+
'MMD_weight': 1.0
|
| 84 |
+
|
| 85 |
+
})
|
| 86 |
+
|
| 87 |
+
wandb_config = default_wandb_config()
|
| 88 |
+
wandb_config.update({
|
| 89 |
+
'project': 'vqvae',
|
| 90 |
+
'name': 'vqvae_{dataset_name}',
|
| 91 |
+
})
|
| 92 |
+
|
| 93 |
+
config_flags.DEFINE_config_dict('wandb', wandb_config, lock_config=False)
|
| 94 |
+
config_flags.DEFINE_config_dict('model', model_config, lock_config=False)
|
| 95 |
+
|
| 96 |
+
##############################################
|
| 97 |
+
## Model Definitions.
|
| 98 |
+
##############################################
|
| 99 |
+
|
| 100 |
+
@jax.vmap
|
| 101 |
+
def sigmoid_cross_entropy_with_logits(*, labels: jnp.ndarray, logits: jnp.ndarray) -> jnp.ndarray:
|
| 102 |
+
"""https://github.com/google-research/maskgit/blob/main/maskgit/libml/losses.py
|
| 103 |
+
"""
|
| 104 |
+
zeros = jnp.zeros_like(logits, dtype=logits.dtype)
|
| 105 |
+
condition = (logits >= zeros)
|
| 106 |
+
relu_logits = jnp.where(condition, logits, zeros)
|
| 107 |
+
neg_abs_logits = jnp.where(condition, -logits, logits)
|
| 108 |
+
return relu_logits - logits * labels + jnp.log1p(jnp.exp(neg_abs_logits))
|
| 109 |
+
|
| 110 |
+
class VQGANModel(flax.struct.PyTreeNode):
|
| 111 |
+
rng: Any
|
| 112 |
+
config: dict = flax.struct.field(pytree_node=False)
|
| 113 |
+
vqvae: TrainState
|
| 114 |
+
vqvae_eps: TrainState
|
| 115 |
+
discriminator: TrainState
|
| 116 |
+
|
| 117 |
+
# Train G and D.
|
| 118 |
+
@partial(jax.pmap, axis_name='data', in_axes=(0, 0))
|
| 119 |
+
def update(self, images, pmap_axis='data'):
|
| 120 |
+
new_rng, curr_key = jax.random.split(self.rng, 2)
|
| 121 |
+
|
| 122 |
+
resnet, resnet_params = get_pretrained_model('resnet50', 'data/resnet_pretrained.npy')
|
| 123 |
+
|
| 124 |
+
is_gan_training = 1.0 - (self.vqvae.step < self.config['gan_warmup_steps']).astype(jnp.float32)
|
| 125 |
+
#Maybe only start GAN way later on?
|
| 126 |
+
|
| 127 |
+
def loss_fn(params_vqvae, params_disc):
|
| 128 |
+
|
| 129 |
+
def path_reg_loss(latents, targets):#let's have pl_mean be in our self.config
|
| 130 |
+
#1/2 should be our spatial dimensions.
|
| 131 |
+
|
| 132 |
+
latents = latents[0:2, :, :, :]
|
| 133 |
+
targets = targets[0:2, :, :, :]
|
| 134 |
+
pl_noise = jax.random.normal(new_rng, shape = targets.shape) / jnp.sqrt(targets.shape[1] * targets.shape[2])
|
| 135 |
+
def grad_sum(latents, pl_noise):#So we don't have access to the actual decode method
|
| 136 |
+
#return jnp.sum(self.vqvae.decode(latents))
|
| 137 |
+
|
| 138 |
+
#I am not sure if this makes any sense whatsoever tbh
|
| 139 |
+
my_sum = self.vqvae(latents, params=params_vqvae, method="decode", rngs={'noise': curr_key})*pl_noise
|
| 140 |
+
print("Decode shape", my_sum.shape)
|
| 141 |
+
return jnp.sum(my_sum)
|
| 142 |
+
|
| 143 |
+
decode_grad_fn = jax.grad(grad_sum)
|
| 144 |
+
pl_grads = decode_grad_fn(latents, pl_noise)
|
| 145 |
+
pl_lengths = jnp.sqrt(jnp.mean(jnp.sum(jnp.square(pl_grads), axis = [2,3]), axis = 1))
|
| 146 |
+
#pl_lengths = jnp.sqrt(jnp.mean(jnp.sum(jnp.square(pl_grads), axis=2), axis=3))
|
| 147 |
+
|
| 148 |
+
pl_mean = self.vqvae.pl_mean + self.config.pl_decay * (jnp.mean(pl_lengths) - self.vqvae.pl_mean)
|
| 149 |
+
pl_penalty = jnp.square(pl_lengths - pl_mean)
|
| 150 |
+
loss = jnp.mean(pl_penalty)
|
| 151 |
+
return loss, pl_mean
|
| 152 |
+
|
| 153 |
+
if self.config.pl_weight != -1:
|
| 154 |
+
smooth_loss, pl_mean = path_reg_loss(result_dict["latents"], reconstructed_images)
|
| 155 |
+
# self.vqvae.replace(pl_mean = pl_mean)
|
| 156 |
+
#We need to update pl mean in self.vqvae
|
| 157 |
+
|
| 158 |
+
# Reconstruct image
|
| 159 |
+
reconstructed_images, result_dict = self.vqvae(images, params=params_vqvae, rngs={'noise': curr_key})
|
| 160 |
+
print("Reconstructed images shape", reconstructed_images.shape)
|
| 161 |
+
print("Input images shape", images.shape)
|
| 162 |
+
assert reconstructed_images.shape == images.shape
|
| 163 |
+
|
| 164 |
+
# GAN loss on VQVAE output.
|
| 165 |
+
discriminator_fn = lambda x: self.discriminator(x, params=params_disc)
|
| 166 |
+
real_logit, vjp_fn = jax.vjp(discriminator_fn, images, has_aux=False)
|
| 167 |
+
gradient = vjp_fn(jnp.ones_like(real_logit))[0] # Gradient of discriminator output wrt. real images.
|
| 168 |
+
gradient = gradient.reshape((images.shape[0], -1))
|
| 169 |
+
gradient = jnp.asarray(gradient, jnp.float32)
|
| 170 |
+
penalty = jnp.sum(jnp.square(gradient), axis=-1)
|
| 171 |
+
penalty = jnp.mean(penalty) # Gradient penalty for training D.
|
| 172 |
+
fake_logit = discriminator_fn(reconstructed_images)
|
| 173 |
+
d_loss_real = sigmoid_cross_entropy_with_logits(labels=jnp.ones_like(real_logit), logits=real_logit).mean()
|
| 174 |
+
d_loss_fake = sigmoid_cross_entropy_with_logits(labels=jnp.zeros_like(fake_logit), logits=fake_logit).mean()
|
| 175 |
+
loss_d = d_loss_real + d_loss_fake + (penalty * self.config['g_grad_penalty_cost'])
|
| 176 |
+
|
| 177 |
+
d_loss_for_vae = sigmoid_cross_entropy_with_logits(labels=jnp.ones_like(fake_logit), logits=fake_logit).mean()
|
| 178 |
+
d_loss_for_vae = d_loss_for_vae * is_gan_training
|
| 179 |
+
|
| 180 |
+
real_pools, _ = get_pretrained_embs(resnet_params, resnet, images=images)
|
| 181 |
+
fake_pools, _ = get_pretrained_embs(resnet_params, resnet, images=reconstructed_images)
|
| 182 |
+
perceptual_loss = jnp.mean((real_pools - fake_pools)**2)
|
| 183 |
+
|
| 184 |
+
l2_loss = jnp.mean((reconstructed_images - images) ** 2)
|
| 185 |
+
quantizer_loss = result_dict['quantizer_loss'] if 'quantizer_loss' in result_dict else 0.0
|
| 186 |
+
if self.config['quantizer_type'] == 'kl' or self.config["quantizer_type"] == "kl_two":
|
| 187 |
+
quantizer_loss = quantizer_loss * self.config['kl_weight']
|
| 188 |
+
elif self.config["quantizer_type"] == "MMD":
|
| 189 |
+
quantizer_loss = quantizer_loss * self.config['MMD_weight']
|
| 190 |
+
loss_vae = (l2_loss * FLAGS.model['l2_loss_weight']) \
|
| 191 |
+
+ (quantizer_loss * FLAGS.model['quantizer_loss_ratio']) \
|
| 192 |
+
+ (d_loss_for_vae * FLAGS.model['g_adversarial_loss_weight']) \
|
| 193 |
+
+ (perceptual_loss * FLAGS.model['perceptual_loss_weight']) \
|
| 194 |
+
#+ (smooth_loss * FLAGS.model['pl_weight'] )
|
| 195 |
+
codebook_usage = result_dict['usage'] if 'usage' in result_dict else 0.0
|
| 196 |
+
|
| 197 |
+
return_dict = {
|
| 198 |
+
'loss_vae': loss_vae,
|
| 199 |
+
'loss_d': loss_d,
|
| 200 |
+
'l2_loss': l2_loss,
|
| 201 |
+
'd_loss_for_vae': d_loss_for_vae,
|
| 202 |
+
'perceptual_loss': perceptual_loss,
|
| 203 |
+
'quantizer_loss': quantizer_loss,
|
| 204 |
+
'codebook_usage': codebook_usage,
|
| 205 |
+
#'pl_loss': smooth_loss,
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
if self.config["pl_weight"] != -1:
|
| 209 |
+
loss_vae += (smooth_loss * FLAGS.model["pl_weight"])
|
| 210 |
+
return_dict["pl_mean"] = pl_mean
|
| 211 |
+
return_dict["smooth_loss"] = smooth_loss
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
return (loss_vae, loss_d), return_dict
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
# This is a fancy way to do 'jax.grad' so (loss_vae, params_vqvae) and (loss_d, params_disc) are differentiated.
|
| 218 |
+
_, grad_fn, info = jax.vjp(loss_fn, self.vqvae.params, self.discriminator.params, has_aux=True)
|
| 219 |
+
vae_grads, _ = grad_fn((1., 0.))
|
| 220 |
+
_, d_grads = grad_fn((0., 1.))
|
| 221 |
+
|
| 222 |
+
vae_grads = jax.lax.pmean(vae_grads, axis_name=pmap_axis)
|
| 223 |
+
d_grads = jax.lax.pmean(d_grads, axis_name=pmap_axis)
|
| 224 |
+
d_grads = jax.tree.map(lambda x: x * is_gan_training, d_grads)
|
| 225 |
+
|
| 226 |
+
info = jax.lax.pmean(info, axis_name=pmap_axis)
|
| 227 |
+
if self.config['quantizer_type'] == 'fsq':
|
| 228 |
+
info['codebook_usage'] = jnp.sum(info['codebook_usage'] > 0) / info['codebook_usage'].shape[-1]
|
| 229 |
+
|
| 230 |
+
updates, new_opt_state = self.vqvae.tx.update(vae_grads, self.vqvae.opt_state, self.vqvae.params)
|
| 231 |
+
new_params = optax.apply_updates(self.vqvae.params, updates)
|
| 232 |
+
|
| 233 |
+
if self.config["pl_weight"] != -1:
|
| 234 |
+
new_vqvae = self.vqvae.replace(step=self.vqvae.step + 1, params=new_params, opt_state=new_opt_state, pl_mean=info["pl_mean"])
|
| 235 |
+
else:
|
| 236 |
+
new_vqvae = self.vqvae.replace(step=self.vqvae.step + 1, params=new_params, opt_state=new_opt_state)
|
| 237 |
+
|
| 238 |
+
updates, new_opt_state = self.discriminator.tx.update(d_grads, self.discriminator.opt_state, self.discriminator.params)
|
| 239 |
+
new_params = optax.apply_updates(self.discriminator.params, updates)
|
| 240 |
+
new_discriminator = self.discriminator.replace(step=self.discriminator.step + 1, params=new_params, opt_state=new_opt_state)
|
| 241 |
+
|
| 242 |
+
info['grad_norm_vae'] = optax.global_norm(vae_grads)
|
| 243 |
+
info['grad_norm_d'] = optax.global_norm(d_grads)
|
| 244 |
+
info['update_norm'] = optax.global_norm(updates)
|
| 245 |
+
info['param_norm'] = optax.global_norm(new_params)
|
| 246 |
+
info['is_gan_training'] = is_gan_training
|
| 247 |
+
|
| 248 |
+
new_vqvae_eps = target_update(new_vqvae, self.vqvae_eps, 1-self.config['eps_update_rate'])
|
| 249 |
+
|
| 250 |
+
new_model = self.replace(rng=new_rng, vqvae=new_vqvae, vqvae_eps=new_vqvae_eps, discriminator=new_discriminator)
|
| 251 |
+
return new_model, info
|
| 252 |
+
|
| 253 |
+
@partial(jax.pmap, axis_name='data', in_axes=(0, 0))
|
| 254 |
+
def reconstruction(self, images, pmap_axis='data', sampling = True):
|
| 255 |
+
if not sampling:
|
| 256 |
+
reconstructed_images, _ = self.vqvae_eps(images)
|
| 257 |
+
else:#Not sure what our theoretical sampling mode does
|
| 258 |
+
new_rng, curr_key = jax.random.split(self.rng, 2)
|
| 259 |
+
reconstructed_images, _ = self.vqvae_eps(images, rngs={'noise': curr_key})
|
| 260 |
+
|
| 261 |
+
reconstructed_images = jnp.clip(reconstructed_images, 0, 1)
|
| 262 |
+
return reconstructed_images
|
| 263 |
+
|
| 264 |
+
@partial(jax.pmap, axis_name='data', in_axes=(0, 0))
|
| 265 |
+
def reconstruction_sampling(self, images, pmap_axis='data'):
|
| 266 |
+
|
| 267 |
+
reconstructed_images_determistic, _ = self.vqvae_eps(images)
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
new_rng, curr_key = jax.random.split(self.rng, 2)
|
| 271 |
+
reconstructed_images_sample, result_dict = self.vqvae(images, rngs={'noise': curr_key})
|
| 272 |
+
|
| 273 |
+
#We don't need to return the result dict.
|
| 274 |
+
reconstructed_images_determistic = jnp.clip(reconstructed_images_determistic, 0, 1)
|
| 275 |
+
reconstructed_images_sample = jnp.clip(reconstructed_images_sample, 0, 1)
|
| 276 |
+
|
| 277 |
+
return reconstructed_images_determistic, reconstructed_images_sample
|
| 278 |
+
|
| 279 |
+
@partial(jax.pmap, axis_name='data', in_axes=(0, 0))
|
| 280 |
+
def reconstruction_interpolation(self, images, pmap_axis='data'):
|
| 281 |
+
|
| 282 |
+
#So we *have* our two images. We are going to linearly interpolate between them in... latent space
|
| 283 |
+
#But also in image space?
|
| 284 |
+
#Sure, why not
|
| 285 |
+
reconstructed_images_determistic, _ = self.vqvae_eps(images)
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
new_rng, curr_key = jax.random.split(self.rng, 2)
|
| 289 |
+
reconstructed_images_sample, result_dict = self.vqvae(images, rngs={'noise': curr_key})
|
| 290 |
+
|
| 291 |
+
#We don't need to return the result dict.
|
| 292 |
+
reconstructed_images_determistic = jnp.clip(reconstructed_images_determistic, 0, 1)
|
| 293 |
+
reconstructed_images_sample = jnp.clip(reconstructed_images_sample, 0, 1)
|
| 294 |
+
|
| 295 |
+
return reconstructed_images_determistic, reconstructed_images_sample
|
| 296 |
+
|
| 297 |
+
@partial(jax.pmap, axis_name='data', in_axes=(0, 0))
|
| 298 |
+
def get_latent(self, images, pmap_axis='data'):
|
| 299 |
+
|
| 300 |
+
#We do *not* add the noise ourselves, just save it.
|
| 301 |
+
latents, result_dict = self.vqvae_eps(images, params=self.vqvae_eps.params, method="encode")
|
| 302 |
+
|
| 303 |
+
# reconstructed_images, result_dict_two = self.vqvae_eps(images)
|
| 304 |
+
# reconstructed_images = jnp.clip(reconstructed_images, 0, 1)
|
| 305 |
+
#
|
| 306 |
+
#
|
| 307 |
+
# decoded = self.vqvae_eps(latents, params=self.vqvae_eps.params, method="decode")
|
| 308 |
+
# decoded = jnp.clip(decoded, 0, 1)
|
| 309 |
+
|
| 310 |
+
#reconstructed images should be correct
|
| 311 |
+
return latents, result_dict#, result_dict_two, reconstructed_images, decoded
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
@partial(jax.pmap, axis_name='data', in_axes=(0, 0))
|
| 315 |
+
def reconstruction_noisy(self, images, pmap_axis='data'):
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
noises = []
|
| 319 |
+
numbers = np.arange(0.00, 1.0, 0.01)
|
| 320 |
+
|
| 321 |
+
for number in numbers:
|
| 322 |
+
noises.append(float(number))
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
#So 3 things to try out.
|
| 326 |
+
#One is normalize variance of the latents before adding noise, start there
|
| 327 |
+
#The second is plot snr instead.
|
| 328 |
+
#snr = var(latent)/var(noise)
|
| 329 |
+
#var is std^2
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
#This return the full reconstruction, but *also* the latents.
|
| 333 |
+
reconstructed_images, result_dict = self.vqvae_eps(images)
|
| 334 |
+
latents = result_dict["latents"]
|
| 335 |
+
std = result_dict["std"]
|
| 336 |
+
#We need to check the latnes std
|
| 337 |
+
|
| 338 |
+
#Get rng for creating noise.
|
| 339 |
+
new_rng, curr_key = jax.random.split(self.rng, 2)
|
| 340 |
+
|
| 341 |
+
decode = []
|
| 342 |
+
latent_std = latents.std(axis = [1,2,3]).reshape(-1,1,1,1)
|
| 343 |
+
|
| 344 |
+
for mult in noises:
|
| 345 |
+
|
| 346 |
+
noise = jax.random.normal(curr_key, latents.shape)
|
| 347 |
+
#Combine noise with latents
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
if True:
|
| 351 |
+
latent_var = latent_std ** 2
|
| 352 |
+
noise_std = mult*noise.std()#noise std should be around 1
|
| 353 |
+
noise_var = mult ** 2
|
| 354 |
+
if noise_var == 0:#If noise is zero, then instead denominator is it's variance
|
| 355 |
+
snr = 0
|
| 356 |
+
else:
|
| 357 |
+
snr = latent_var/noise_var
|
| 358 |
+
|
| 359 |
+
temp_latents = latents + noise*mult
|
| 360 |
+
|
| 361 |
+
#vae_eps is the determinstic one.
|
| 362 |
+
decoded = self.vqvae_eps(temp_latents, params=self.vqvae_eps.params, method="decode")
|
| 363 |
+
decoded = jnp.clip(decoded, 0, 1)
|
| 364 |
+
if True:
|
| 365 |
+
decode.append((decoded, snr))
|
| 366 |
+
|
| 367 |
+
reconstructed_images = jnp.clip(reconstructed_images, 0, 1)
|
| 368 |
+
return reconstructed_images, decode, std
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
@partial(jax.pmap, axis_name='data', in_axes=(0, 0))
|
| 372 |
+
def reconstruction_ppl(self, images, pmap_axis='data'):
|
| 373 |
+
|
| 374 |
+
epsilon = .0001
|
| 375 |
+
reconstructed_images, result_dict = self.vqvae_eps(images)
|
| 376 |
+
latents = result_dict["latents"]
|
| 377 |
+
std = result_dict["std"]
|
| 378 |
+
|
| 379 |
+
new_rng, curr_key = jax.random.split(self.rng, 2)
|
| 380 |
+
|
| 381 |
+
noise = jax.random.normal(curr_key, latents.shape)
|
| 382 |
+
#Combine noise with latents
|
| 383 |
+
|
| 384 |
+
temp_latents = latents + noise * epsilon
|
| 385 |
+
# print(temp_latents.shape)#Probably should be like, bs, 32,32,4
|
| 386 |
+
# exit()
|
| 387 |
+
decoded = self.vqvae_eps(temp_latents, params=self.vqvae_eps.params, method="decode")
|
| 388 |
+
decoded = jnp.clip(decoded, 0, 1)
|
| 389 |
+
|
| 390 |
+
reconstructed_images = jnp.clip(reconstructed_images, 0, 1)
|
| 391 |
+
return reconstructed_images, decoded, std, latents
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
#So this method simply will return the gradient/jacobian
|
| 395 |
+
@partial(jax.pmap, axis_name='data', in_axes=(0, 0))
|
| 396 |
+
def reconstruction_grad_distance(self, images, pmap_axis='data'):
|
| 397 |
+
#We want to try and identify C.
|
| 398 |
+
#C means that when we change our latents by a specific and small number X, our outputs change by C*X also.
|
| 399 |
+
#We want to capture all of the C, and see what their STD is.
|
| 400 |
+
pass
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
@partial(jax.pmap, axis_name='data', in_axes=(0, 0))
|
| 404 |
+
def reconstruction_ppl_two(self, images, pmap_axis='data'):
|
| 405 |
+
|
| 406 |
+
epsilon = .0001
|
| 407 |
+
reconstructed_images, result_dict = self.vqvae_eps(images)
|
| 408 |
+
latents = result_dict["latents"]
|
| 409 |
+
std = result_dict["std"]
|
| 410 |
+
|
| 411 |
+
new_rng, curr_key = jax.random.split(self.rng, 2)
|
| 412 |
+
|
| 413 |
+
noise = jax.random.normal(curr_key, latents.shape)
|
| 414 |
+
#Combine noise with latents
|
| 415 |
+
|
| 416 |
+
temp_latents = latents + noise/2 * epsilon
|
| 417 |
+
|
| 418 |
+
decoded = self.vqvae_eps(temp_latents, params=self.vqvae_eps.params, method="decode")
|
| 419 |
+
decoded = jnp.clip(decoded, 0, 1)
|
| 420 |
+
|
| 421 |
+
temp_latents_2 = latents + -1 * noise/2 * epsilon
|
| 422 |
+
|
| 423 |
+
decoded_2 = self.vqvae_eps(temp_latents_2, params=self.vqvae_eps.params, method="decode")
|
| 424 |
+
decoded_2 = jnp.clip(decoded_2, 0, 1)
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
reconstructed_images = jnp.clip(reconstructed_images, 0, 1)
|
| 428 |
+
return reconstructed_images, decoded, std, latents, decoded_2
|
| 429 |
+
|
| 430 |
+
@partial(jax.pmap, axis_name='data', in_axes=(0, 0))
|
| 431 |
+
def reconstruction_ppl_image(self, images, pmap_axis='data'):
|
| 432 |
+
|
| 433 |
+
epsilon = .0001
|
| 434 |
+
new_rng, curr_key = jax.random.split(self.rng, 2)
|
| 435 |
+
|
| 436 |
+
reconstructed_images, result_dict = self.vqvae_eps(images)
|
| 437 |
+
latents = result_dict["latents"]
|
| 438 |
+
std = result_dict["std"]
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
noise = jax.random.normal(curr_key, images.shape)
|
| 442 |
+
images = images + noise * epsilon
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
decoded, result_dict_2 = self.vqvae_eps(images)
|
| 446 |
+
decoded = jnp.clip(decoded, 0, 1)
|
| 447 |
+
|
| 448 |
+
latents_noisy = result_dict_2["latents"]
|
| 449 |
+
std_noisy = result_dict_2["std"]
|
| 450 |
+
|
| 451 |
+
reconstructed_images = jnp.clip(reconstructed_images, 0, 1)
|
| 452 |
+
return reconstructed_images, decoded, std, latents, std_noisy, latents_noisy
|
| 453 |
+
|
| 454 |
+
##############################################
|
| 455 |
+
## Training Code.
|
| 456 |
+
##############################################
|
| 457 |
+
def main(_):
    """Train the VQGAN autoencoder and periodically evaluate reconstruction FID.

    Driven entirely by absl FLAGS (dataset name, model config, wandb setup,
    checkpoint directories). Designed for multi-host, multi-device JAX
    execution: data is sharded per host, the model is pmap-replicated, and
    logging/checkpointing happen only on process 0.
    """
    np.random.seed(FLAGS.seed)
    print("Using devices", jax.local_devices())
    device_count = len(jax.local_devices())
    global_device_count = jax.device_count()
    # Per-host batch: the global batch is divided evenly across hosts.
    local_batch_size = FLAGS.batch_size // (global_device_count // device_count)
    print("Device count", device_count)
    print("Global device count", global_device_count)
    print("Global Batch: ", FLAGS.batch_size)
    print("Node Batch: ", local_batch_size)
    print("Device Batch:", local_batch_size // device_count)

    # Create the wandb logger only on the first host.
    if jax.process_index() == 0:
        setup_wandb(FLAGS.model.to_dict(), **FLAGS.wandb)

    def get_dataset(is_train):
        # Build an infinite, host-sharded iterator of float32 images in [0, 1].
        if 'imagenet' in FLAGS.dataset_name:
            def deserialization_fn(data):
                image = data['image']
                # Center-crop to a square before resizing to the target size.
                min_side = tf.minimum(tf.shape(image)[0], tf.shape(image)[1])
                image = tf.image.resize_with_crop_or_pad(image, min_side, min_side)
                if 'imagenet256' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (256, 256))
                elif 'imagenet128' in FLAGS.dataset_name:
                    image = tf.image.resize(image, (128, 128))
                else:
                    raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
                if is_train:
                    image = tf.image.random_flip_left_right(image)
                image = tf.cast(image, tf.float32) / 255.0
                return image

            split = tfds.split_for_jax_process('train' if is_train else 'validation', drop_remainder=True)
            print(split)
            dataset = tfds.load('imagenet2012', split=split, data_dir="/dev/shm")
            dataset = dataset.map(deserialization_fn, num_parallel_calls=tf.data.AUTOTUNE)
            dataset = dataset.shuffle(10000, seed=42, reshuffle_each_iteration=True)
            dataset = dataset.repeat()
            dataset = dataset.batch(local_batch_size)
            dataset = dataset.prefetch(tf.data.AUTOTUNE)
            dataset = tfds.as_numpy(dataset)
            dataset = iter(dataset)
            return dataset
        else:
            raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")

    dataset = get_dataset(is_train=True)
    dataset_valid = get_dataset(is_train=False)
    example_obs = next(dataset)[:1]

    get_fid_activations = get_fid_network()
    if not os.path.exists('./data/imagenet256_fidstats_openai.npz'):
        raise ValueError("Please download the FID stats file! See the README.")
    truth_fid_stats = np.load('data/imagenet256_fidstats_openai.npz')

    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, param_key = jax.random.split(rng)
    print("Total Memory on device:", float(jax.local_devices()[0].memory_stats()['bytes_limit']) / 1024**3, "GB")

    ###################################
    # Creating Model and put on devices.
    ###################################
    # Infer image geometry from a real batch rather than hard-coding it.
    FLAGS.model.image_channels = example_obs.shape[-1]
    FLAGS.model.image_size = example_obs.shape[1]
    vqvae_def = VQVAE(FLAGS.model, train=True)
    vqvae_params = vqvae_def.init({'params': param_key, 'noise': param_key}, example_obs)['params']
    tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    vqvae_ts = TrainState.create(vqvae_def, vqvae_params, tx=tx)
    # Eval-mode twin sharing the same initial parameters (no optimizer).
    vqvae_def_eps = VQVAE(FLAGS.model, train=False)
    vqvae_eps_ts = TrainState.create(vqvae_def_eps, vqvae_params)
    print("Total num of VQVAE parameters:", sum(x.size for x in jax.tree_util.tree_leaves(vqvae_params)))

    discriminator_def = Discriminator(FLAGS.model)
    discriminator_params = discriminator_def.init(param_key, example_obs)['params']
    tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
    discriminator_ts = TrainState.create(discriminator_def, discriminator_params, tx=tx)
    print("Total num of Discriminator parameters:", sum(x.size for x in jax.tree_util.tree_leaves(discriminator_params)))

    model = VQGANModel(rng=rng, vqvae=vqvae_ts, vqvae_eps=vqvae_eps_ts, discriminator=discriminator_ts, config=FLAGS.model)

    if FLAGS.load_dir is not None:
        try:
            cp = Checkpoint(FLAGS.load_dir)
            model = cp.load_model(model)
            print("Loaded model with step", model.vqvae.step)
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # are not swallowed; any load failure falls back to random init.
        except Exception:
            print("Random init")
    else:
        print("Random init")

    model = flax.jax_utils.replicate(model, devices=jax.local_devices())
    jax.debug.visualize_array_sharding(model.vqvae.params['decoder']['Conv_0']['bias'])

    ###################################
    # Train Loop
    ###################################

    best_fid = 100000

    for i in tqdm.tqdm(range(1, FLAGS.max_steps + 1),
                       smoothing=0.1,
                       dynamic_ncols=True):

        batch_images = next(dataset)
        # [devices, batch // devices, H, W, C] for pmap.
        batch_images = batch_images.reshape((len(jax.local_devices()), -1, *batch_images.shape[1:]))

        model, update_info = model.update(batch_images)

        if i % FLAGS.log_interval == 0:
            update_info = jax.tree.map(lambda x: x.mean(), update_info)
            train_metrics = {f'training/{k}': v for k, v in update_info.items()}
            if jax.process_index() == 0:
                wandb.log(train_metrics, step=i)

        if i % FLAGS.eval_interval == 0:
            # Qualitative reconstructions on train and validation batches.
            reconstructed_images = model.reconstruction(batch_images)  # [devices, b, H, W, C]
            valid_images = next(dataset_valid)
            valid_images = valid_images.reshape((len(jax.local_devices()), -1, *valid_images.shape[1:]))
            valid_reconstructed_images = model.reconstruction(valid_images)

            if jax.process_index() == 0:
                wandb.log({'batch_image_mean': batch_images.mean()}, step=i)
                wandb.log({'reconstructed_images_mean': reconstructed_images.mean()}, step=i)
                wandb.log({'batch_image_std': batch_images.std()}, step=i)
                wandb.log({'reconstructed_images_std': reconstructed_images.std()}, step=i)

                # Side-by-side grid: originals (top row) vs reconstructions
                # (bottom row). The grid is sized for 8 devices; only the
                # first 4 columns are filled on a 4-device host.
                fig, axs = plt.subplots(2, 8, figsize=(30, 15))
                for j in range(4):
                    axs[0, j].imshow(batch_images[j, 0], vmin=0, vmax=1)
                    axs[1, j].imshow(reconstructed_images[j, 0], vmin=0, vmax=1)
                wandb.log({'reconstruction': wandb.Image(fig)}, step=i)
                plt.close(fig)
                fig, axs = plt.subplots(2, 8, figsize=(30, 15))
                for j in range(4):
                    axs[0, j].imshow(valid_images[j, 0], vmin=0, vmax=1)
                    axs[1, j].imshow(valid_reconstructed_images[j, 0], vmin=0, vmax=1)
                wandb.log({'reconstruction_valid': wandb.Image(fig)}, step=i)
                plt.close(fig)

            # Validation losses: run an update but discard the new model state.
            # Runs on every process (collective op inside update).
            _, valid_update_info = model.update(valid_images)
            valid_update_info = jax.tree.map(lambda x: x.mean(), valid_update_info)
            valid_metrics = {f'validation/{k}': v for k, v in valid_update_info.items()}
            if jax.process_index() == 0:
                wandb.log(valid_metrics, step=i)

            # FID over ~40k validation reconstructions (780 batches).
            activations = []
            for _ in range(780):
                valid_images = next(dataset_valid)
                valid_images = valid_images.reshape((len(jax.local_devices()), -1, *valid_images.shape[1:]))
                valid_reconstructed_images = model.reconstruction(valid_images)

                # InceptionV3 expects 299x299 inputs scaled to [-1, 1].
                valid_reconstructed_images = jax.image.resize(valid_reconstructed_images, (valid_images.shape[0], valid_images.shape[1], 299, 299, 3),
                                                              method='bilinear', antialias=False)
                valid_reconstructed_images = 2 * valid_reconstructed_images - 1
                activations += [np.array(get_fid_activations(valid_reconstructed_images))[..., 0, 0, :]]

            # TODO: use all_gather so multi-host runs include every host's
            # activations; currently only this host's images contribute.
            activations = np.concatenate(activations, axis=0)
            activations = activations.reshape((-1, activations.shape[-1]))

            print("doing this much FID", activations.shape)
            mu1 = np.mean(activations, axis=0)
            sigma1 = np.cov(activations, rowvar=False)
            fid = fid_from_stats(mu1, sigma1, truth_fid_stats['mu'], truth_fid_stats['sigma'])

            if jax.process_index() == 0:
                wandb.log({'validation/fid': fid}, step=i)
                print("validation FID at step", i, fid)
                # Keep a rolling best-so-far checkpoint. Guarding on save_dir
                # prevents a TypeError when no save directory is configured.
                if fid < best_fid and FLAGS.save_dir is not None:
                    model_single = flax.jax_utils.unreplicate(model)
                    cp = Checkpoint(FLAGS.save_dir + "best.tmp")
                    cp.set_model(model_single)
                    cp.save()
                    best_fid = fid

        if (i % FLAGS.save_interval == 0) and (FLAGS.save_dir is not None):
            if jax.process_index() == 0:
                model_single = flax.jax_utils.unreplicate(model)
                cp = Checkpoint(FLAGS.save_dir)
                cp.set_model(model_single)
                cp.save()
|
| 674 |
+
|
| 675 |
+
# Script entry point: absl parses command-line FLAGS, then invokes main.
if __name__ == '__main__':
    app.run(main)
|