| import jax |
| import jax.experimental |
| import wandb |
| import jax.numpy as jnp |
| import numpy as np |
| import tqdm |
| import matplotlib.pyplot as plt |
| import os |
| from functools import partial |
| from absl import app, flags |
|
|
# Command-line flags controlling the sampling / FID-evaluation run:
# number of ODE integration steps, total images to generate, and the
# classifier-free-guidance scale.
flags.DEFINE_integer('inference_timesteps', 128, 'Number of timesteps for inference.')
flags.DEFINE_integer('inference_generations', 50000, 'Number of generations for inference.')
flags.DEFINE_float('inference_cfg_scale', 1.5, 'CFG scale for inference.')
|
|
# Per-class latent prototypes plus a global mean used as the extra
# "unconditional" entry (presumably class id == num_classes == 1000;
# confirm against the training config).
_class_archive = np.load("classes.npz")
global_mean = jnp.load("global_mean.npy")

# Materialize the npz archive into a plain dict keyed by class-id string.
classes = {name: _class_archive[name] for name in _class_archive.files}
# Append the global mean under key "1000" (the null/unconditional class).
classes["1000"] = global_mean
# Stack into one device array indexed by integer class id 0..len(classes)-1.
classes_array = jnp.array([classes[str(i)] for i in range(len(classes))])
|
|
def do_inference(
    FLAGS,
    train_state,
    step,
    dataset,
    dataset_valid,
    shard_data,
    vae_encode,
    vae_decode,
    update,
    get_fid_activations,
    imagenet_labels,
    visualize_labels,
    fid_from_stats,
    truth_fid_stats,
):
    """Sample from a trained flow model and (optionally) report FID.

    Two paths are visible in this body:
      * FLAGS.mode == 'interpolate': renders a spherical interpolation between
        two fixed noise seeds for a hard-coded label, saves the result, then
        halts at a deliberate breakpoint().
      * otherwise: generates FLAGS.inference_generations images by integrating
        the model's velocity field for FLAGS.inference_timesteps Euler steps
        with classifier-free guidance at FLAGS.inference_cfg_scale, collects
        Inception activations via get_fid_activations, and prints FID on
        process 0.

    NOTE(review): experiment-style code. Several alternative ODE integrators
    live in dead `elif True:` / `elif False:` branches after the first live
    Euler branch, there are `if False:` debug-dump sections, and intentional
    `breakpoint()` / `exit()` stops. Indentation of this block was
    reconstructed from syntax (the original formatting was mangled) — the
    hedged NOTE(review) comments below mark the ambiguous spots.

    Args:
        FLAGS: absl flags/config object (model options, batch_size, mode, ...).
        train_state: exposes call_model / call_model_ema for the network.
        step: training step; not used in this body.
        dataset: train-batch iterator yielding (images, labels).
        dataset_valid: validation-batch iterator yielding (images, labels).
        shard_data: shards one or more arrays across devices/hosts.
        vae_encode: image -> latent encoder (used iff FLAGS.model.use_stable_vae).
        vae_decode: latent -> image decoder.
        update: training update fn; not used in this body.
        get_fid_activations: InceptionV3 feature extractor for FID.
        imagenet_labels: label names; not used in this body.
        visualize_labels: label helper; not used in this body.
        fid_from_stats: computes FID from two (mu, sigma) Gaussians.
        truth_fid_stats: dict with reference 'mu' and 'sigma'.
    """
    with jax.spmd_mode('allow_all'):
        global_device_count = jax.device_count()
        # Distinct RNG stream per host process.
        key = jax.random.PRNGKey(42 + jax.process_index())
        batch_images, batch_labels = next(dataset)
        valid_images, valid_labels = next(dataset_valid)
        if FLAGS.model.use_stable_vae:
            # Work in VAE latent space rather than pixel space.
            batch_images = vae_encode(key, batch_images)
            valid_images = vae_encode(key, valid_images)
        batch_labels_sharded, valid_labels_sharded = shard_data(batch_labels, valid_labels)
        # Label id num_classes acts as the "null" class for classifier-free guidance.
        labels_uncond = shard_data(jnp.ones(batch_labels.shape, dtype=jnp.int32) * FLAGS.model['num_classes'])
        eps = jax.random.normal(key, batch_images.shape)

        def process_img(img):
            # Decode (if latents) and map from [-1, 1] to [0, 1] for display.
            if FLAGS.model.use_stable_vae:
                img = vae_decode(img[None])[0]
            img = img * 0.5 + 0.5
            img = jnp.clip(img, 0, 1)
            img = np.array(img)
            return img

        # static_argnums=(5,) makes `use_ema` a compile-time constant under jit.
        @partial(jax.jit, static_argnums=(5,))
        def call_model(train_state, images, t, dt, labels, use_ema=True, perturbe = False):
            if use_ema and FLAGS.model.use_ema:
                call_fn = train_state.call_model_ema
            else:
                call_fn = train_state.call_model
            output = call_fn(images, t, dt, labels, train=False)
            return output

        if FLAGS.mode == 'interpolate':
            # Spherical interpolation between two fixed noise draws; one image
            # per interpolation coefficient across the batch.
            seed = 5
            eps0 = jax.random.normal(jax.random.PRNGKey(seed), batch_images[0].shape)
            eps1 = jax.random.normal(jax.random.PRNGKey(seed+1), batch_images[0].shape)
            # Hard-coded class id 555 for every sample.
            labels = jnp.ones(FLAGS.batch_size,).astype(jnp.int32) * 555
            i = jnp.linspace(0, 1, FLAGS.batch_size)
            i_neg = np.sqrt(1-i**2)  # keeps the mixed noise at unit "radius"
            x = eps0[None] * i_neg[:, None, None, None] + eps1[None] * i[:, None, None, None]
            t_vector = jnp.full((FLAGS.batch_size, ), 0)
            dt_vector = jnp.zeros_like(t_vector)
            cfg_scale = FLAGS.inference_cfg_scale
            # One-step generation: x1 = x0 + v(x0, t=0) * 1.0.
            v = call_model(train_state, x, t_vector, dt_vector, labels)
            x = x + v * 1.0
            x = vae_decode(x)
            x_render = np.array(jax.experimental.multihost_utils.process_allgather(x))
            os.makedirs(FLAGS.save_dir, exist_ok=True)
            np.save(FLAGS.save_dir + f'/x_render.npy', x_render)
            breakpoint()  # NOTE(review): intentional stop for interactive inspection.

        denoise_timesteps = FLAGS.inference_timesteps
        num_generations = FLAGS.inference_generations
        cfg_scale = FLAGS.inference_cfg_scale
        x0 = []           # initial noise draws (gathered across hosts)
        x1 = []           # final latents per batch
        lab = []          # sampled labels per batch
        x_render = []     # decoded images (kept only for small runs)
        activations = []  # Inception activations for FID
        images_shape = batch_images.shape
        print(f"Calc FID for CFG {cfg_scale} and denoise_timesteps {denoise_timesteps}")
        print("should do x", num_generations // FLAGS.batch_size)
        for fid_it in tqdm.tqdm(range(num_generations // FLAGS.batch_size)):
            # Deterministic per-iteration, per-host RNG.
            key = jax.random.PRNGKey(42)
            key = jax.random.fold_in(key, fid_it)
            key = jax.random.fold_in(key, jax.process_index())
            eps_key, label_key = jax.random.split(key)
            x = jax.random.normal(eps_key, images_shape)
            labels = jax.random.randint(label_key, (images_shape[0],), 0, FLAGS.model.num_classes)

            # Mixing weight: start points are (1-e)*class-prototype + e*noise.
            e = 0.30

            # NOTE(review): import inside the loop; could be hoisted to module level.
            from baselines.targets_naive import map_labels_to_classes
            x_cond = map_labels_to_classes(classes_array, labels) * (1-e) + e * x
            x_uncond = map_labels_to_classes(classes_array, labels_uncond) * (1-e) + e * x

            x_cond, labels = shard_data(x_cond, labels)
            x_uncond, _ = shard_data(x_uncond, labels)

            x0.append(np.array(jax.experimental.multihost_utils.process_allgather(x)))

            if False:  # debug: shape dumps
                print(x.shape)
                print(x_cond.shape)
                print(labels)
            if False:  # debug: decode a few samples, probe per-noise-level loss, then exit()
                x = vae_decode(x[0:5])
                x_cond = vae_decode(x_cond[0:5])
                x_uncond = vae_decode(x_uncond[0:5])

                # Map decoded [-1, 1] images into [0, 255].
                x = ((x + 1) * 127.5).clip(0, 255)
                x_cond = ((x_cond + 1) * 127.5).clip(0, 255)
                x_uncond = ((x_uncond + 1) * 127.5).clip(0, 255)
                noise_levels = [0,.01,.05,.1,.2,.33,.66,1.0]

                x = x[0:5]

                for noise_level in noise_levels:
                    x_1 = batch_images[0:5]
                    x_0 = x[0:5]
                    e = 0.05  # NOTE(review): shadows the outer e = 0.30 inside this dead branch
                    labels = labels[0:5]

                    print("noise level", noise_level)
                    print("noise shape", x_0.shape)
                    x_0 = map_labels_to_classes(classes_array, labels)*(1-e) + e * x_0

                    # Linear interpolant between start x_0 and data x_1.
                    x_t = (1 - (1 - 1e-5) * noise_level) * x_0 + noise_level * x_1

                    # Ground-truth velocity of the (near-)linear path.
                    v_t = x_1 - (1 - 1e-5) * x_0

                    dt_flow = np.log2(FLAGS.model['denoise_timesteps']).astype(jnp.int32)
                    dt_base = jnp.ones(x_0.shape[0], dtype=jnp.int32) * dt_flow

                    # NOTE(review): rebinds the float loop variable to an int vector,
                    # so the saved filenames below embed the vector's repr.
                    noise_level = jnp.ones(x_0.shape[0], dtype=jnp.int32) * noise_level

                    v = call_model(train_state, x_t[0:5], noise_level, dt_base, labels)
                    diff = (v_t - v) ** 2
                    print("first loss", diff.mean())

                    # One-step reconstruction using the TRUE velocity.
                    image = x_0[0] + v_t[0]
                    image = vae_decode(jnp.expand_dims(image, axis = 0)).squeeze()
                    image = ((image + 1) * 127.5).clip(0, 255)
                    from PIL import Image
                    image = np.array(image).astype(np.uint8)
                    image = Image.fromarray(image)
                    image.save("denoised_image_real_v" + str(noise_level) + ".png")

                    # One-step reconstruction using the PREDICTED velocity.
                    image = x_0[0] + v[0]
                    image = vae_decode(jnp.expand_dims(image, axis = 0)).squeeze()
                    image = ((image + 1) * 127.5).clip(0, 255)
                    from PIL import Image
                    image = np.array(image).astype(np.uint8)
                    image = Image.fromarray(image)
                    image.save("denoised_image_" + str(noise_level) + ".png")

                    # The real data image for reference.
                    image = x_1[0]
                    image = vae_decode(jnp.expand_dims(image, axis = 0)).squeeze()
                    image = ((image + 1) * 127.5).clip(0, 255)
                    from PIL import Image
                    image = np.array(image).astype(np.uint8)
                    image = Image.fromarray(image)
                    image.save("actual_image_" + str(noise_level) + ".png")

                    # The noised interpolant that was fed to the model.
                    image = x_t[0]
                    image = vae_decode(jnp.expand_dims(image, axis = 0)).squeeze()
                    image = ((image + 1) * 127.5).clip(0, 255)
                    from PIL import Image
                    image = np.array(image).astype(np.uint8)
                    image = Image.fromarray(image)
                    image.save("noised_image_" + str(noise_level) + ".png")

                    """
                    print("first dtbase", dt_base)
                    from baselines.targets_naive import get_targets
                    x_t, v_t, t, dt_base, labels, info = get_targets(FLAGS, eps_key, train_state, batch_images[0:5], labels[0:5], -1, -1, classes_array)
                    #print("v_t2", v_t)
                    #This uses random ts, so it doesn't tell us shit.
                    v = call_model(train_state, x_t[0:5], noise_level, dt_base, labels)
                    print("second dtbase", dt_base)
                    print("second loss", ((v_t - v) ** 2).mean())
                    #Noise level 1.0 should be loss around 0.03...
                    #get mse v, vt_t
                    #if needed.
                    """
                    exit()  # NOTE(review): hard stop; everything after (incl. break) is unreachable.
                    break

                print("doing some decoding stuff")
                for i in range(0,5):
                    image = x[i]
                    from PIL import Image
                    image = np.array(image).astype(np.uint8)
                    image = Image.fromarray(image)
                    image.save("noisestuff" + str(i) + ".png")
                for i in range(0,5):
                    image = x_cond[i]
                    from PIL import Image
                    image = np.array(image).astype(np.uint8)
                    image = Image.fromarray(image)
                    image.save("condstuff" + str(i) + ".png")
                # NOTE(review): uses the leftover loop index i (== 4) in the filename.
                image = x_uncond[0]
                image = np.array(image).astype(np.uint8)
                image = Image.fromarray(image)
                image.save("uncondtuff" + str(i) + ".png")

            # Integrate the flow ODE from t=0 to t=1. Only the first
            # `elif True:` (plain Euler) branch below is live.
            delta_t = 1.0 / denoise_timesteps
            for ti in range(denoise_timesteps):
                t = ti / denoise_timesteps
                t_vector = jnp.full((images_shape[0], ), t)
                if FLAGS.model.train_type == 'naive':
                    # dt conditioning fixed to the training-time step count.
                    dt_flow = np.log2(FLAGS.model['denoise_timesteps']).astype(jnp.int32)
                    dt_base = jnp.ones(images_shape[0], dtype=jnp.int32) * dt_flow
                else:
                    # dt conditioning matches the inference step count.
                    dt_flow = np.log2(denoise_timesteps).astype(jnp.int32)
                    dt_base = jnp.ones(images_shape[0], dtype=jnp.int32) * dt_flow

                t_vector, dt_base = shard_data(t_vector, dt_base)
                # Classifier-free guidance: 1 => conditional only, 0 => unconditional only.
                if cfg_scale == 1:
                    v = call_model(train_state, x, t_vector, dt_base, labels)
                elif cfg_scale == 0:
                    v = call_model(train_state, x, t_vector, dt_base, labels_uncond)
                else:
                    # NOTE(review): the two branches are evaluated at x_uncond / x_cond
                    # (the prototype-mixed starts), not at the current state x, and
                    # x_cond/x_uncond are never advanced — confirm this is intended.
                    v_pred_uncond = call_model(train_state, x_uncond, t_vector, dt_base, labels_uncond)
                    v_pred_label = call_model(train_state, x_cond, t_vector, dt_base, labels)
                    v = v_pred_uncond + cfg_scale * (v_pred_label - v_pred_uncond)

                if FLAGS.model.train_type == 'consistency':
                    # Consistency-style update: predict x1, then re-noise to t+dt.
                    eps = shard_data(jax.random.normal(jax.random.fold_in(eps_key, ti), images_shape))
                    x1pred = x + v * (1-t)
                    x = x1pred * (t+delta_t) + eps * (1-t-delta_t)

                elif True:  # LIVE branch: plain Euler step.
                    x = x + v * delta_t
                elif False:  # dead: average with a second call conditioned on half the dt.
                    if ti + 1 == denoise_timesteps:
                        x = x + v * delta_t
                    else:
                        dt_flow = np.log2(denoise_timesteps/2).astype(jnp.int32)
                        dt_base = jnp.ones(images_shape[0], dtype=jnp.int32) * dt_flow

                        v_2_c = call_model(train_state, x, t_vector, dt_base, labels)
                        v_2_u = call_model(train_state, x, t_vector, dt_base, labels_uncond)
                        v_2 = v_2_u + cfg_scale * (v_2_c - v_2_u)

                        v_prime = (v + v_2) / 2
                        x = x + v_prime * delta_t
                elif False:  # dead: only take the very last step.
                    if ti + 1 == denoise_timesteps:
                        x = x + v * delta_t
                    else:
                        pass
                elif True:  # dead (earlier elif True wins): two extra stages at t+dt/3 and t+2dt/3.
                    if ti + 1 == denoise_timesteps:
                        x = x + v * delta_t
                    else:
                        k1 = v
                        t1 = t

                        x2 = x + (delta_t / 3) * k1
                        t_vector_2 = jnp.full((images_shape[0],), t1 + delta_t / 3)
                        t_vector_2 = shard_data(t_vector_2)
                        k2_c = call_model(train_state, x2, t_vector_2, dt_base, labels)
                        k2_u = call_model(train_state, x2, t_vector_2, dt_base, labels_uncond)
                        k2 = k2_u + cfg_scale * (k2_c - k2_u)

                        x3 = x + (2 * delta_t / 3) * k2
                        t_vector_3 = jnp.full((images_shape[0],), t1 + 2 * delta_t / 3)
                        t_vector_3 = shard_data(t_vector_3)
                        k3_c = call_model(train_state, x3, t_vector_3, dt_base, labels)
                        k3_u = call_model(train_state, x3, t_vector_3, dt_base, labels_uncond)
                        k3 = k3_u + cfg_scale * (k3_c - k3_u)

                        v_prime = (1/4) * k1 + (3/4) * k3
                        x = x + v_prime * delta_t
                elif True:  # dead: three-stage scheme with (1, 4, 1)/6 weights.
                    if ti + 1 == denoise_timesteps:
                        x = x + v * delta_t
                    else:
                        x1 = x  # NOTE(review): shadows the outer x1 results list inside this dead branch
                        t1 = t
                        v1 = v

                        x2 = x1 + v1 * delta_t / 2
                        t_vector_2 = jnp.full((images_shape[0],), t1 + delta_t / 2)
                        t_vector_2 = shard_data(t_vector_2)
                        v2_c = call_model(train_state, x2, t_vector_2, dt_base, labels)
                        v2_u = call_model(train_state, x2, t_vector_2, dt_base, labels_uncond)
                        v2 = v2_u + cfg_scale * (v2_c - v2_u)

                        x3 = x1 - v1 * delta_t + 2 * v2 * delta_t
                        t_vector_3 = jnp.full((images_shape[0],), t1 + delta_t)
                        t_vector_3 = shard_data(t_vector_3)
                        v3_c = call_model(train_state, x3, t_vector_3, dt_base, labels)
                        v3_u = call_model(train_state, x3, t_vector_3, dt_base, labels_uncond)
                        v3 = v3_u + cfg_scale * (v3_c - v3_u)

                        v_prime = (v1 + 4 * v2 + v3) / 6
                        x = x + v_prime * delta_t

                elif True:  # dead: two-stage trapezoidal (Heun-style) step.
                    if ti + 1 == denoise_timesteps:
                        x = x + v * delta_t
                    else:
                        x_2 = x + v * delta_t

                        t_vector_2 = jnp.full((images_shape[0], ), t + delta_t)
                        t_vector_2 = shard_data(t_vector_2)

                        v_2_c = call_model(train_state, x_2, t_vector_2, dt_base, labels)
                        v_2_u = call_model(train_state, x_2, t_vector_2, dt_base, labels_uncond)
                        v_2 = v_2_u + cfg_scale * (v_2_c - v_2_u)

                        v_prime = (v + v_2) / 2
                        x = x + v_prime * delta_t

                elif False:  # dead: midpoint step evaluated at a cube-root-interpolated sigma.
                    if ti + 1 == denoise_timesteps:
                        x = x + v * delta_t
                        continue
                    sigma_hat = t

                    sigma_i_1 = t + delta_t

                    sigma_mid = ((sigma_hat ** (1 / 3) + sigma_i_1 ** (1 / 3)) / 2) ** 3
                    dt_1 = sigma_mid - sigma_hat
                    dt_2 = sigma_i_1 - sigma_hat

                    x_2 = x + v * dt_1

                    # NOTE(review): unlike other branches, t_vector_2 is not sharded here.
                    t_vector_2 = jnp.full((images_shape[0], ), sigma_mid)

                    v_2_c = call_model(train_state, x_2, t_vector_2, dt_base, labels)
                    v_2_u = call_model(train_state, x_2, t_vector_2, dt_base, labels_uncond)
                    v_2 = v_2_u + cfg_scale * (v_2_c - v_2_u)

                    x = x + v_2 * dt_2

                elif False:  # dead: pasted PyTorch snippet; torch/img/t_prev/t_curr/model are undefined here.
                    img_mid = x + (t_prev - t_curr)/2 * v
                    t_vec_mid = torch.full((img.shape[0],), (t_curr + (t_prev - t_curr) / 2), dtype=img.dtype, device=img.device)
                    v_2 = model(img_mid, t_vec_mid)

                    first_order = (v_2 - v) / ((t_prev - t_curr) / 2)
                    # NOTE(review): `x = x =` double assignment and missing `x +` look like a
                    # transcription bug, but the branch is unreachable.
                    x = x = (t_prev - t_curr) * v + .5 * (t_prev - t_curr) ** 2 * first_order

            # Gather this batch's results across hosts.
            x1.append(np.array(jax.experimental.multihost_utils.process_allgather(x)))
            lab.append(np.array(jax.experimental.multihost_utils.process_allgather(labels)))
            if FLAGS.model.use_stable_vae:
                x = vae_decode(x)
            if num_generations < 10000:
                # Only keep full rendered images for small runs (memory).
                x_render.append(np.array(jax.experimental.multihost_utils.process_allgather(x)))

            print("decode n shit", x.shape)
            if False:  # debug: dump a few decoded samples to disk.
                for i in range(0,5):
                    image = x[i]
                    image = ((image + 1) * 127.5).clip(0, 255)
                    from PIL import Image
                    image = np.array(image).astype(np.uint8)
                    image = Image.fromarray(image)
                    image.save("stuff" + str(i) + ".png")
                print("done")

            # InceptionV3 expects 299x299 inputs; clip decoded images to [-1, 1].
            x = jax.image.resize(x, (x.shape[0], 299, 299, 3), method='bilinear', antialias=False)
            x = jnp.clip(x, -1, 1)
            acts = get_fid_activations(x)[..., 0, 0, :]
            acts = jax.experimental.multihost_utils.process_allgather(acts)
            acts = np.array(acts)
            activations.append(acts)

        if jax.process_index() == 0:
            # Fit a Gaussian to the generated activations and compare with the
            # reference statistics.
            activations = np.concatenate(activations, axis=0)
            activations = activations.reshape((-1, activations.shape[-1]))
            mu1 = np.mean(activations, axis=0)
            sigma1 = np.cov(activations, rowvar=False)
            fid = fid_from_stats(mu1, sigma1, truth_fid_stats['mu'], truth_fid_stats['sigma'])
            print(f"FID is {fid}")
            # NOTE(review): reconstructed indentation — confirm whether this return
            # belongs inside the process-0 guard (non-zero ranks would fall through
            # to the save below) or at function level (making the save dead code).
            return

        if FLAGS.save_dir is not None:
            os.makedirs(FLAGS.save_dir, exist_ok=True)
            x_render = np.concatenate(x_render, axis=0)
            np.save(FLAGS.save_dir + f'/x_render.npy', x_render)
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
|
|