# Shortcuts/class_mean_0.05/helper_inference.py
import os
from functools import partial

import jax
import jax.experimental.multihost_utils  # process_allgather is used for cross-host gathers below
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import tqdm
import wandb
from absl import app, flags
from PIL import Image
flags.DEFINE_integer('inference_timesteps', 128, 'Number of timesteps for inference.')
flags.DEFINE_integer('inference_generations', 50000, 'Number of generations for inference.')
flags.DEFINE_float('inference_cfg_scale', 1.5, 'CFG scale for inference.')
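# Class-mean prior used by this "class_mean" variant: classes.npz is expected to hold one
# mean latent per ImageNet class, keyed by the class index as a string, and global_mean.npy
# the dataset-wide mean latent used for the null (unconditional) class.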
classes = np.load("classes.npz")
global_mean = jnp.load("global_mean.npy")
classes = {key: classes[key] for key in classes.files}  # Materialize the NpzFile into a plain dict.
classes["1000"] = global_mean  # Index num_classes == 1000 acts as the null / unconditional class.
classes_array = jnp.array([classes[str(i)] for i in range(len(classes))])
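# A minimal sketch of what map_labels_to_classes (imported from baselines.targets_naive
# inside the FID loop below) is assumed to do: gather one class-mean latent per label.
#
#     def map_labels_to_classes(classes_array, labels):
#         # classes_array: [num_classes + 1, H, W, C]; labels: [batch] of int32 class ids
#         return classes_array[labels]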
def do_inference(
FLAGS,
train_state,
step,
dataset,
dataset_valid,
shard_data,
vae_encode,
vae_decode,
update,
get_fid_activations,
imagenet_labels,
visualize_labels,
fid_from_stats,
truth_fid_stats,
):
with jax.spmd_mode('allow_all'):
global_device_count = jax.device_count()
key = jax.random.PRNGKey(42 + jax.process_index())
batch_images, batch_labels = next(dataset)
valid_images, valid_labels = next(dataset_valid)
if FLAGS.model.use_stable_vae:
batch_images = vae_encode(key, batch_images)
valid_images = vae_encode(key, valid_images)
batch_labels_sharded, valid_labels_sharded = shard_data(batch_labels, valid_labels)
labels_uncond = shard_data(jnp.ones(batch_labels.shape, dtype=jnp.int32) * FLAGS.model['num_classes']) # Null token
eps = jax.random.normal(key, batch_images.shape)
def process_img(img):
if FLAGS.model.use_stable_vae:
img = vae_decode(img[None])[0]
img = img * 0.5 + 0.5
img = jnp.clip(img, 0, 1)
img = np.array(img)
return img
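        # (process_img is a convenience helper that maps a single latent to a displayable
        #  [0, 1] numpy image; none of the code paths below call it directly.)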
        @partial(jax.jit, static_argnums=(5,))
        def call_model(train_state, images, t, dt, labels, use_ema=True):
            if use_ema and FLAGS.model.use_ema:
                call_fn = train_state.call_model_ema
            else:
                call_fn = train_state.call_model
            output = call_fn(images, t, dt, labels, train=False)
            return output
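        # use_ema is listed in static_argnums, so it is a Python-level constant per compiled
        # variant of call_model rather than a traced argument.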
        if FLAGS.mode == 'interpolate':
            seed = 5
            eps0 = jax.random.normal(jax.random.PRNGKey(seed), batch_images[0].shape)
            eps1 = jax.random.normal(jax.random.PRNGKey(seed + 1), batch_images[0].shape)
            labels = jnp.ones(FLAGS.batch_size,).astype(jnp.int32) * 555  # Fixed class id for every sample.
            # Spherical interpolation between the two noise samples across the batch.
            i = jnp.linspace(0, 1, FLAGS.batch_size)
            i_neg = np.sqrt(1 - i**2)
            x = eps0[None] * i_neg[:, None, None, None] + eps1[None] * i[:, None, None, None]
            t_vector = jnp.full((FLAGS.batch_size,), 0)
            dt_vector = jnp.zeros_like(t_vector)  # dt_base = 0: one full-length shortcut step.
            cfg_scale = FLAGS.inference_cfg_scale
            v = call_model(train_state, x, t_vector, dt_vector, labels)
            x = x + v * 1.0  # Single step from noise to data.
            x = vae_decode(x)  # Image is in [-1, 1] space.
            x_render = np.array(jax.experimental.multihost_utils.process_allgather(x))
            os.makedirs(FLAGS.save_dir, exist_ok=True)
            np.save(FLAGS.save_dir + '/x_render.npy', x_render)
            return
denoise_timesteps = FLAGS.inference_timesteps
num_generations = FLAGS.inference_generations
cfg_scale = FLAGS.inference_cfg_scale
x0 = []
x1 = []
lab = []
x_render = []
activations = []
images_shape = batch_images.shape
print(f"Calc FID for CFG {cfg_scale} and denoise_timesteps {denoise_timesteps}")
print("should do x", num_generations // FLAGS.batch_size)
for fid_it in tqdm.tqdm(range(num_generations // FLAGS.batch_size)):
key = jax.random.PRNGKey(42)
key = jax.random.fold_in(key, fid_it)
key = jax.random.fold_in(key, jax.process_index())
eps_key, label_key = jax.random.split(key)
x = jax.random.normal(eps_key, images_shape)
labels = jax.random.randint(label_key, (images_shape[0],), 0, FLAGS.model.num_classes)
            # Re-center the starting points: the conditional/unconditional inputs begin from
            # the class-mean latent, perturbed by a small fraction (e) of the Gaussian noise x.
            e = 0.30
            from baselines.targets_naive import map_labels_to_classes
            x_cond = map_labels_to_classes(classes_array, labels) * (1 - e) + e * x
            x_uncond = map_labels_to_classes(classes_array, labels_uncond) * (1 - e) + e * x
            x_cond, labels = shard_data(x_cond, labels)
            x_uncond, _ = shard_data(x_uncond, labels)
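            # x_cond / x_uncond are the class-mean initializations fed to the conditional and
            # unconditional model calls in the guided branch below; x (pure noise) is the state
            # that the denoising loop actually integrates.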
x0.append(np.array(jax.experimental.multihost_utils.process_allgather(x)))
if False:
print(x.shape)#256,32,32,4
print(x_cond.shape)
print(labels)
            if False:  # Disabled debug path: decode and inspect the first few latents.
                x = vae_decode(x[0:5])
                x_cond = vae_decode(x_cond[0:5])
                x_uncond = vae_decode(x_uncond[0:5])
                # Convert decoded images from [-1, 1] to [0, 255].
                x = ((x + 1) * 127.5).clip(0, 255)
                x_cond = ((x_cond + 1) * 127.5).clip(0, 255)
                x_uncond = ((x_uncond + 1) * 127.5).clip(0, 255)
                noise_levels = [0, .01, .05, .1, .2, .33, .66, 1.0]
                x = x[0:5]
                for noise_level in noise_levels:
                    x_1 = batch_images[0:5]
                    x_0 = x[0:5]
                    e = 0.05
                    labels = labels[0:5]
                    print("noise level", noise_level)
                    print("noise shape", x_0.shape)  # (batch, H, W, C)
                    # Starting point: class-mean latent with a small noise perturbation (e).
                    x_0 = map_labels_to_classes(classes_array, labels) * (1 - e) + e * x_0
                    # Interpolate between x_0 and the real latent x_1, and form the
                    # corresponding flow-matching target velocity.
                    x_t = (1 - (1 - 1e-5) * noise_level) * x_0 + noise_level * x_1
                    v_t = x_1 - (1 - 1e-5) * x_0
                    dt_flow = np.log2(FLAGS.model['denoise_timesteps']).astype(jnp.int32)
                    dt_base = jnp.ones(x_0.shape[0], dtype=jnp.int32) * dt_flow  # Smallest dt.
                    # Broadcast the scalar noise level to a per-example vector (the scalar is
                    # kept around for the filenames below).
                    noise_level_vec = jnp.full(x_0.shape[0], noise_level, dtype=jnp.float32)
                    v = call_model(train_state, x_t[0:5], noise_level_vec, dt_base, labels)
                    diff = (v_t - v) ** 2
                    print("first loss", diff.mean())
                    # Note: these reconstructions use v_t computed from x_1 (real latents) and
                    # x_0 (class-mean starts), so they are sanity checks only.
                    # x_0 + v_t should land (almost) exactly on the real image.
                    image = x_0[0] + v_t[0]
                    image = vae_decode(jnp.expand_dims(image, axis=0)).squeeze()
                    image = ((image + 1) * 127.5).clip(0, 255)
                    image = Image.fromarray(np.array(image).astype(np.uint8))
                    image.save("denoised_image_real_v" + str(noise_level) + ".png")
                    # x_0 + v uses the model's predicted velocity instead.
                    image = x_0[0] + v[0]
                    image = vae_decode(jnp.expand_dims(image, axis=0)).squeeze()
                    image = ((image + 1) * 127.5).clip(0, 255)
                    image = Image.fromarray(np.array(image).astype(np.uint8))
                    image.save("denoised_image_" + str(noise_level) + ".png")
                    # The real (target) image.
                    image = x_1[0]
                    image = vae_decode(jnp.expand_dims(image, axis=0)).squeeze()
                    image = ((image + 1) * 127.5).clip(0, 255)
                    image = Image.fromarray(np.array(image).astype(np.uint8))
                    image.save("actual_image_" + str(noise_level) + ".png")
                    # The interpolated (noised) input the model actually saw.
                    image = x_t[0]
                    image = vae_decode(jnp.expand_dims(image, axis=0)).squeeze()
                    image = ((image + 1) * 127.5).clip(0, 255)
                    image = Image.fromarray(np.array(image).astype(np.uint8))
                    image.save("noised_image_" + str(noise_level) + ".png")
"""
print("first dtbase", dt_base)
from baselines.targets_naive import get_targets
x_t, v_t, t, dt_base, labels, info = get_targets(FLAGS, eps_key, train_state, batch_images[0:5], labels[0:5], -1, -1, classes_array)
#print("v_t2", v_t)
                    #This draws random t values, so it is not directly comparable.
v = call_model(train_state, x_t[0:5], noise_level, dt_base, labels)
print("second dtbase", dt_base)
print("second loss", ((v_t - v) ** 2).mean())
#Noise level 1.0 should be loss around 0.03...
#get mse v, vt_t
#if needed.
"""
exit()
break
print("doing some decoding stuff")
for i in range(0,5):
image = x[i]
from PIL import Image
image = np.array(image).astype(np.uint8)
image = Image.fromarray(image)
image.save("noisestuff" + str(i) + ".png")
for i in range(0,5):
image = x_cond[i]
from PIL import Image
image = np.array(image).astype(np.uint8)
image = Image.fromarray(image)
image.save("condstuff" + str(i) + ".png")
image = x_uncond[0]
image = np.array(image).astype(np.uint8)
image = Image.fromarray(image)
image.save("uncondtuff" + str(i) + ".png")
#exit()
            delta_t = 1.0 / denoise_timesteps
            for ti in range(denoise_timesteps):
                t = ti / denoise_timesteps  # Integrate from x_0 (noise, t=0) to x_1 (data, t=1).
                t_vector = jnp.full((images_shape[0],), t)
                if FLAGS.model.train_type == 'naive':
                    dt_flow = np.log2(FLAGS.model['denoise_timesteps']).astype(jnp.int32)
                    dt_base = jnp.ones(images_shape[0], dtype=jnp.int32) * dt_flow  # Smallest dt.
                else:  # shortcut
                    # dt_base encodes the step size as a power of two: denoise_timesteps of
                    # [128, 64, 32, 16, 8, 4, 2, 1] map to dt_base [7, 6, 5, 4, 3, 2, 1, 0],
                    # so 128 steps uses dt_base = 7 (the smallest step).
                    dt_flow = np.log2(denoise_timesteps).astype(jnp.int32)
                    dt_base = jnp.ones(images_shape[0], dtype=jnp.int32) * dt_flow
                t_vector, dt_base = shard_data(t_vector, dt_base)
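                # Classifier-free guidance: v = v_uncond + cfg_scale * (v_cond - v_uncond).
                # cfg_scale == 1 reduces to the conditional prediction and cfg_scale == 0 to the
                # unconditional one, so those cases need only a single forward pass.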
if cfg_scale == 1:
v = call_model(train_state, x, t_vector, dt_base, labels)
elif cfg_scale == 0:
v = call_model(train_state, x, t_vector, dt_base, labels_uncond)
else:
v_pred_uncond = call_model(train_state, x_uncond, t_vector, dt_base, labels_uncond)
v_pred_label = call_model(train_state, x_cond, t_vector, dt_base, labels)
v = v_pred_uncond + cfg_scale * (v_pred_label - v_pred_uncond)
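                # The chain below picks the integrator via hard-coded elif True / elif False
                # toggles: only the first branch whose condition is True runs, so everything
                # after the plain Euler step is currently inert.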
if FLAGS.model.train_type == 'consistency':
eps = shard_data(jax.random.normal(jax.random.fold_in(eps_key, ti), images_shape))
x1pred = x + v * (1-t)
x = x1pred * (t+delta_t) + eps * (1-t-delta_t)
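                    # (The two lines above predict the data endpoint x1pred = x + v * (1 - t),
                    #  then re-noise it to time t + delta_t with a freshly drawn eps.)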
elif True:
x = x + v * delta_t # Euler sampling.
                elif False:
                    # "Special" predictor: average the current velocity with a second prediction
                    # at the same x but with a step size twice as large (dt_base one level
                    # coarser), as if taking two of the current steps at once.
                    if ti + 1 == denoise_timesteps:
                        x = x + v * delta_t
                    else:
                        dt_flow = np.log2(denoise_timesteps / 2).astype(jnp.int32)
                        dt_base = jnp.ones(images_shape[0], dtype=jnp.int32) * dt_flow
                        v_2_c = call_model(train_state, x, t_vector, dt_base, labels)
                        v_2_u = call_model(train_state, x, t_vector, dt_base, labels_uncond)
                        v_2 = v_2_u + cfg_scale * (v_2_c - v_2_u)
                        # CFG on the second prediction could potentially be skipped in the future.
                        v_prime = (v + v_2) / 2
                        x = x + v_prime * delta_t
                elif False:  # Midpoint method (unfinished stub).
                    if ti + 1 == denoise_timesteps:
                        x = x + v * delta_t
                    else:
                        pass
                elif True:  # Heun's third-order method (three stages).
if ti + 1 == denoise_timesteps:
x = x + v * delta_t # Final Euler step
else:
# Stage 1
k1 = v # already computed
t1 = t
# Stage 2
x2 = x + (delta_t / 3) * k1
t_vector_2 = jnp.full((images_shape[0],), t1 + delta_t / 3)
t_vector_2 = shard_data(t_vector_2)
k2_c = call_model(train_state, x2, t_vector_2, dt_base, labels)
k2_u = call_model(train_state, x2, t_vector_2, dt_base, labels_uncond)
k2 = k2_u + cfg_scale * (k2_c - k2_u)
# Stage 3
x3 = x + (2 * delta_t / 3) * k2
t_vector_3 = jnp.full((images_shape[0],), t1 + 2 * delta_t / 3)
t_vector_3 = shard_data(t_vector_3)
k3_c = call_model(train_state, x3, t_vector_3, dt_base, labels)
k3_u = call_model(train_state, x3, t_vector_3, dt_base, labels_uncond)
k3 = k3_u + cfg_scale * (k3_c - k3_u)
# Combine stages
v_prime = (1/4) * k1 + (3/4) * k3
x = x + v_prime * delta_t
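                # (Heun's third-order tableau: stages at t, t + dt/3 and t + 2*dt/3,
                #  combined with weights 1/4, 0, 3/4.)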
                elif True:  # Kutta's classical third-order Runge-Kutta method.
if ti + 1 == denoise_timesteps:
x = x + v * delta_t # Final Euler step
else:
x1 = x
t1 = t
v1 = v
# Stage 2
x2 = x1 + v1 * delta_t / 2
t_vector_2 = jnp.full((images_shape[0],), t1 + delta_t / 2)
t_vector_2 = shard_data(t_vector_2)
v2_c = call_model(train_state, x2, t_vector_2, dt_base, labels)
v2_u = call_model(train_state, x2, t_vector_2, dt_base, labels_uncond)
v2 = v2_u + cfg_scale * (v2_c - v2_u)
# Stage 3
x3 = x1 - v1 * delta_t + 2 * v2 * delta_t
t_vector_3 = jnp.full((images_shape[0],), t1 + delta_t)
t_vector_3 = shard_data(t_vector_3)
v3_c = call_model(train_state, x3, t_vector_3, dt_base, labels)
v3_u = call_model(train_state, x3, t_vector_3, dt_base, labels_uncond)
v3 = v3_u + cfg_scale * (v3_c - v3_u)
# Weighted sum of stages
v_prime = (v1 + 4 * v2 + v3) / 6
x = x + v_prime * delta_t
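                # (Kutta's classical third-order tableau: stages at t, t + dt/2 and t + dt,
                #  with x3 = x1 - v1*dt + 2*v2*dt and weights (1, 4, 1) / 6.)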
                elif True:  # Heun's second-order method.
                    if ti + 1 == denoise_timesteps:
                        # Plain Euler on the final step (no point beyond t = 1 to evaluate).
                        x = x + v * delta_t
                    else:
                        x_2 = x + v * delta_t  # Euler predictor.
                        t_vector_2 = jnp.full((images_shape[0],), t + delta_t)
                        t_vector_2 = shard_data(t_vector_2)
                        v_2_c = call_model(train_state, x_2, t_vector_2, dt_base, labels)
                        v_2_u = call_model(train_state, x_2, t_vector_2, dt_base, labels_uncond)
                        v_2 = v_2_u + cfg_scale * (v_2_c - v_2_u)
                        v_prime = (v + v_2) / 2  # Corrector: average of the two velocities.
                        x = x + v_prime * delta_t
                elif False:  # Possibly DPM-Solver++(2M)-style midpoint update.
                    if ti + 1 == denoise_timesteps:
                        x = x + v * delta_t
                        continue
                    sigma_hat = t  # Current timestep.
                    sigma_i_1 = t + delta_t  # Next timestep.
                    # Midpoint chosen according to a rho=3 Karras schedule.
                    sigma_mid = ((sigma_hat ** (1 / 3) + sigma_i_1 ** (1 / 3)) / 2) ** 3
                    dt_1 = sigma_mid - sigma_hat
                    dt_2 = sigma_i_1 - sigma_hat
                    x_2 = x + v * dt_1
                    t_vector_2 = jnp.full((images_shape[0],), sigma_mid)
                    t_vector_2 = shard_data(t_vector_2)  # (Added for consistency with the other branches.)
                    v_2_c = call_model(train_state, x_2, t_vector_2, dt_base, labels)
                    v_2_u = call_model(train_state, x_2, t_vector_2, dt_base, labels_uncond)
                    v_2 = v_2_u + cfg_scale * (v_2_c - v_2_u)
                    x = x + v_2 * dt_2
                elif False:  # RF-Solver-style second-order update, adapted here from a PyTorch
                    # sketch; t_curr / t_prev are not defined in this loop, so this branch is a
                    # reference only (the original pairs run 0,0 1,1 1,2 2,2 3,3 3,4 4,4 4,5 ...).
                    img_mid = x + (t_prev - t_curr) / 2 * v
                    t_vec_mid = jnp.full((images_shape[0],), t_curr + (t_prev - t_curr) / 2)
                    t_vec_mid = shard_data(t_vec_mid)
                    v_2 = call_model(train_state, img_mid, t_vec_mid, dt_base, labels)
                    # Finite-difference estimate of dv/dt at the midpoint.
                    first_order = (v_2 - v) / ((t_prev - t_curr) / 2)
                    # Second-order Taylor step.
                    x = x + (t_prev - t_curr) * v + 0.5 * (t_prev - t_curr) ** 2 * first_order
x1.append(np.array(jax.experimental.multihost_utils.process_allgather(x)))
lab.append(np.array(jax.experimental.multihost_utils.process_allgather(labels)))
if FLAGS.model.use_stable_vae:
x = vae_decode(x) # Image is in [-1, 1] space.
if num_generations < 10000:
x_render.append(np.array(jax.experimental.multihost_utils.process_allgather(x)))
            # Note: the decode and FID-activation computation below run on every fid_it iteration.
            print("decoded batch shape:", x.shape)
            if False:  # Disabled debug path: dump the first few decoded samples to disk.
                for i in range(0, 5):
                    image = ((x[i] + 1) * 127.5).clip(0, 255)
                    image = np.array(image).astype(np.uint8)
                    Image.fromarray(image).save("stuff" + str(i) + ".png")
                print("done")
# exit()
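            # Resize to 299x299 and clip to [-1, 1] for the Inception-based FID feature
            # extractor; get_fid_activations is assumed to return pooled InceptionV3 features,
            # and the [..., 0, 0, :] indexing drops the 1x1 spatial dims (2048-d output).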
x = jax.image.resize(x, (x.shape[0], 299, 299, 3), method='bilinear', antialias=False)
x = jnp.clip(x, -1, 1)
acts = get_fid_activations(x)[..., 0, 0, :] # [devices, batch//devices, 2048]
acts = jax.experimental.multihost_utils.process_allgather(acts)
acts = np.array(acts)
activations.append(acts)
if jax.process_index() == 0:
activations = np.concatenate(activations, axis=0)
activations = activations.reshape((-1, activations.shape[-1]))
mu1 = np.mean(activations, axis=0)
sigma1 = np.cov(activations, rowvar=False)
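            # fid_from_stats is expected to compute the standard Frechet distance between the
            # two Gaussians: ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2)).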
fid = fid_from_stats(mu1, sigma1, truth_fid_stats['mu'], truth_fid_stats['sigma'])
print(f"FID is {fid}")
return
if FLAGS.save_dir is not None:
os.makedirs(FLAGS.save_dir, exist_ok=True)
x_render = np.concatenate(x_render, axis=0)
            np.save(FLAGS.save_dir + '/x_render.npy', x_render)
# x0 = np.concatenate(x0, axis=0)
# x1 = np.concatenate(x1, axis=0)
# lab = np.concatenate(lab, axis=0)
# os.makedirs(FLAGS.save_dir, exist_ok=True)
# np.save(FLAGS.save_dir + f'/x0.npy', x0)
# np.save(FLAGS.save_dir + f'/x1.npy', x1)
# np.save(FLAGS.save_dir + f'/lab.npy', lab)