KublaiKhan1 commited on
Commit
6bf8496
·
verified ·
1 Parent(s): 464344f

Upload folder using huggingface_hub

Browse files
meanflow/helper_inference.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax
2
+ import jax.experimental
3
+ import wandb
4
+ import jax.numpy as jnp
5
+ import numpy as np
6
+ import tqdm
7
+ import matplotlib.pyplot as plt
8
+ import os
9
+ from functools import partial
10
+ from absl import app, flags
11
+
12
# Command-line flags that configure sampling when running inference.
flags.DEFINE_integer(
    name='inference_timesteps',
    default=1,
    help='Number of timesteps for inference.',
)
flags.DEFINE_integer(
    name='inference_generations',
    default=50000,
    help='Number of generations for inference.',
)
flags.DEFINE_float(
    name='inference_cfg_scale',
    default=1.0,
    help='CFG scale for inference.',
)
15
+
16
def do_inference(
    FLAGS,
    train_state,
    step,
    dataset,
    dataset_valid,
    shard_data,
    vae_encode,
    vae_decode,
    update,
    get_fid_activations,
    imagenet_labels,
    visualize_labels,
    fid_from_stats,
    truth_fid_stats,
):
    """Sample from the model and report FID (or dump an interpolation strip).

    Two modes, selected by ``FLAGS.mode``:

    * ``'interpolate'``: builds ``FLAGS.batch_size`` noise tensors that blend
      two fixed-seed noises, runs a single model call, decodes with
      ``vae_decode`` and saves the result to ``FLAGS.save_dir/x_render.npy``,
      then returns.
    * anything else: generates ``FLAGS.inference_generations`` samples in
      batches of ``FLAGS.batch_size``, feeds them through
      ``get_fid_activations`` and prints the FID against ``truth_fid_stats``
      on process 0. When fewer than 10000 generations are requested and
      ``FLAGS.save_dir`` is set, the decoded samples are saved as well.

    Args:
        FLAGS: Config object; reads ``mode``, ``batch_size``, ``save_dir``,
            ``inference_timesteps``, ``inference_generations``,
            ``inference_cfg_scale`` and ``model`` (``use_stable_vae``,
            ``use_ema``, ``num_classes``, ``train_type``,
            ``denoise_timesteps``).
        train_state: Object exposing ``call_model`` / ``call_model_ema``.
        step: Unused; kept for interface compatibility.
        dataset: Iterator yielding ``(images, labels)``; one batch is drawn
            only to obtain the sample shape.
        dataset_valid: Unused; kept for interface compatibility.
        shard_data: Callable that shards one or more arrays across devices.
        vae_encode: Callable ``(key, images) -> latents``.
        vae_decode: Callable ``latents -> images``.
        update: Unused; kept for interface compatibility.
        get_fid_activations: Callable mapping 299x299 images to activations.
        imagenet_labels: Unused; kept for interface compatibility.
        visualize_labels: Unused; kept for interface compatibility.
        fid_from_stats: Callable ``(mu1, sigma1, mu2, sigma2) -> fid``.
        truth_fid_stats: Mapping with reference ``'mu'`` and ``'sigma'``.

    Returns:
        None.

    Raises:
        ValueError: If ``inference_generations`` is smaller than
            ``batch_size`` so that no sampling batch would be produced.
    """
    # `import jax.experimental` does not guarantee that the `multihost_utils`
    # submodule has been loaded, so import it explicitly here.
    from jax.experimental import multihost_utils

    with jax.spmd_mode('allow_all'):
        key = jax.random.PRNGKey(42 + jax.process_index())
        # A single training batch is pulled only to learn the sample shape
        # (after optional VAE encoding) and the label-batch shape.
        batch_images, batch_labels = next(dataset)
        if FLAGS.model.use_stable_vae:
            batch_images = vae_encode(key, batch_images)
        # Index `num_classes` is used as the null (unconditional) token.
        labels_uncond = shard_data(
            jnp.ones(batch_labels.shape, dtype=jnp.int32) * FLAGS.model['num_classes']
        )

        @partial(jax.jit, static_argnums=(5,))
        def call_model(train_state, images, t, dt, labels, use_ema=True):
            # Use the EMA forward pass only when both the caller and the
            # config ask for it.
            if use_ema and FLAGS.model.use_ema:
                call_fn = train_state.call_model_ema
            else:
                call_fn = train_state.call_model
            return call_fn(images, t, dt, labels, train=False)

        if FLAGS.mode == 'interpolate':
            seed = 5
            eps0 = jax.random.normal(jax.random.PRNGKey(seed), batch_images[0].shape)
            eps1 = jax.random.normal(jax.random.PRNGKey(seed + 1), batch_images[0].shape)
            labels = jnp.ones(FLAGS.batch_size,).astype(jnp.int32) * 555
            # Blend weights chosen so that weight_0**2 + weight_1**2 == 1.
            i = jnp.linspace(0, 1, FLAGS.batch_size)
            i_neg = np.sqrt(1 - i**2)
            x = eps0[None] * i_neg[:, None, None, None] + eps1[None] * i[:, None, None, None]
            t_vector = jnp.full((FLAGS.batch_size,), 0)
            dt_vector = jnp.zeros_like(t_vector)
            v = call_model(train_state, x, t_vector, dt_vector, labels)
            # NOTE(review): this branch calls the model at t=0 and *adds* v,
            # while the FID loop below calls it at t=1 and *subtracts* v.
            # Confirm which convention matches the trained model.
            x = x + v * 1.0
            x = vae_decode(x)  # Image is in [-1, 1] space.
            x_render = np.array(multihost_utils.process_allgather(x))
            os.makedirs(FLAGS.save_dir, exist_ok=True)
            np.save(FLAGS.save_dir + '/x_render.npy', x_render)
            # Interpolation is a standalone mode: stop here instead of
            # dropping into a debugger and then falling through to FID.
            return

        denoise_timesteps = FLAGS.inference_timesteps
        num_generations = FLAGS.inference_generations
        cfg_scale = FLAGS.inference_cfg_scale
        num_batches = num_generations // FLAGS.batch_size
        if num_batches == 0:
            raise ValueError(
                f"inference_generations ({num_generations}) must be at least "
                f"batch_size ({FLAGS.batch_size}) to compute FID."
            )

        images_shape = batch_images.shape
        delta_t = 1.0 / denoise_timesteps
        # Hard-coded switch: when True the model is always queried at t=1
        # with dt=0 instead of the per-step t and log2 dt.
        meanflow = True
        # Decoded samples are kept only for small runs.
        keep_renders = num_generations < 10000
        x_render = []
        activations = []
        print(f"Calc FID for CFG {cfg_scale} and denoise_timesteps {denoise_timesteps}")
        for fid_it in tqdm.tqdm(range(num_batches)):
            # Deterministic per-batch, per-process randomness.
            key = jax.random.PRNGKey(42)
            key = jax.random.fold_in(key, fid_it)
            key = jax.random.fold_in(key, jax.process_index())
            eps_key, label_key = jax.random.split(key)
            x = jax.random.normal(eps_key, images_shape)
            labels = jax.random.randint(label_key, (images_shape[0],), 0, FLAGS.model.num_classes)
            x, labels = shard_data(x, labels)

            for ti in range(denoise_timesteps):
                t = 1 if meanflow else ti / denoise_timesteps
                t_vector = jnp.full((images_shape[0],), t)
                if FLAGS.model.train_type == 'naive':
                    dt_flow = np.log2(FLAGS.model['denoise_timesteps']).astype(jnp.int32)
                else:  # shortcut
                    dt_flow = np.log2(denoise_timesteps).astype(jnp.int32)
                dt_base = jnp.ones(images_shape[0], dtype=jnp.int32) * dt_flow
                if meanflow:
                    dt_base = dt_base * 0

                t_vector, dt_base = shard_data(t_vector, dt_base)
                if cfg_scale == 1:
                    v = call_model(train_state, x, t_vector, dt_base, labels)
                elif cfg_scale == 0:
                    v = call_model(train_state, x, t_vector, dt_base, labels_uncond)
                else:
                    # Classifier-free guidance: extrapolate from the
                    # unconditional prediction towards the conditional one.
                    v_pred_uncond = call_model(train_state, x, t_vector, dt_base, labels_uncond)
                    v_pred_label = call_model(train_state, x, t_vector, dt_base, labels)
                    v = v_pred_uncond + cfg_scale * (v_pred_label - v_pred_uncond)

                if FLAGS.model.train_type == 'consistency':
                    # Predict the clean sample, then re-noise to the next t.
                    eps = shard_data(jax.random.normal(jax.random.fold_in(eps_key, ti), images_shape))
                    x1pred = x + v * (1 - t)
                    x = x1pred * (t + delta_t) + eps * (1 - t - delta_t)
                else:
                    # NOTE(review): sign flipped relative to plain Euler
                    # (x + v * delta_t); the original author flagged that the
                    # sampler must match the trained model. Verify.
                    x = x - v * delta_t

            if FLAGS.model.use_stable_vae:
                x = vae_decode(x)  # Image is in [-1, 1] space.
            if keep_renders:
                x_render.append(np.array(multihost_utils.process_allgather(x)))
            x = jax.image.resize(x, (x.shape[0], 299, 299, 3), method='bilinear', antialias=False)
            x = jnp.clip(x, -1, 1)
            # Original author's note on the layout:
            # [devices, batch//devices, 2048] -- TODO confirm.
            acts = get_fid_activations(x)[..., 0, 0, :]
            acts = np.array(multihost_utils.process_allgather(acts))
            activations.append(acts)

        if jax.process_index() == 0:
            activations = np.concatenate(activations, axis=0)
            activations = activations.reshape((-1, activations.shape[-1]))
            mu1 = np.mean(activations, axis=0)
            sigma1 = np.cov(activations, rowvar=False)
            fid = fid_from_stats(mu1, sigma1, truth_fid_stats['mu'], truth_fid_stats['sigma'])
            print(f"FID is {fid}")

            # Renders exist only for runs below 10000 generations;
            # np.concatenate raises on an empty list, so skip saving otherwise.
            if FLAGS.save_dir is not None and x_render:
                os.makedirs(FLAGS.save_dir, exist_ok=True)
                np.save(FLAGS.save_dir + '/x_render.npy', np.concatenate(x_render, axis=0))
meanflow/notes.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Katherine's reverse time:
2
+
3
+ noise 0 clean 1
4
+
5
+ sample t >= r
6
+
7
+ z = (1 - r) * e + r * x, v = x - e
8
+
9
+ jvp = (v, 1, 0) (v, r, t)
10
+ u_gt = v + (t-r) * stopgrad
11
+
12
+ sample = z = z + (t - r) * model(z,r,t)
13
+ Although it's actually model(z,r,t-r)
14
+
15
+ Sampling is r=0, t=1
meanflow/targets_naive.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import numpy as np
4
+
5
def get_targets(FLAGS, key, train_state, images, labels, force_t=-1, force_dt=-1):
    """Build naive flow-matching training targets for one batch.

    Args:
        FLAGS: Config; reads ``model['class_dropout_prob']``,
            ``model['num_classes']``, ``model['denoise_timesteps']`` and
            ``dataset_name``.
        key: JAX PRNG key, split into label / time / noise keys.
        train_state: Not used by this function.
        images: Batch indexed as ``[batch, ..., channels]``. When
            ``'latent'`` is in ``FLAGS.dataset_name`` the last axis is split
            in half into the source (first half) and target (second half).
        labels: Class labels, one per batch element.
        force_t: If not ``-1``, every element uses this time instead of a
            randomly sampled one.
        force_dt: Not used by this function.

    Returns:
        Tuple ``(x_t, v_t, t, dt_base, labels_dropped, info)`` where ``info``
        holds ``'dropped_ratio'``, the fraction of labels equal to the null
        class index.
    """
    model_cfg = FLAGS.model
    num_steps = model_cfg['denoise_timesteps']
    null_class = model_cfg['num_classes']
    batch = images.shape[0]
    label_key, time_key, noise_key = jax.random.split(key, 3)

    # Randomly replace labels with the null class.
    drop_mask = jax.random.bernoulli(label_key, model_cfg['class_dropout_prob'], (labels.shape[0],))
    labels_dropped = jnp.where(drop_mask, null_class, labels)
    info = {'dropped_ratio': jnp.mean(labels_dropped == null_class)}

    # Draw t on the discrete grid {0, 1/N, ..., (N-1)/N}, unless overridden.
    sampled_t = jax.random.randint(time_key, (batch,), minval=0, maxval=num_steps).astype(jnp.float32)
    sampled_t = sampled_t / num_steps
    forced_t = jnp.ones(batch, dtype=jnp.float32) * force_t
    t = jnp.where(forced_t != -1, forced_t, sampled_t)
    t_bcast = t[:, None, None, None]  # [batch, 1, 1, 1]

    # Pick the two endpoints of the flow.
    if 'latent' in FLAGS.dataset_name:
        half = images.shape[-1] // 2
        x_0 = images[..., :half]
        x_1 = images[..., half:]
    else:
        x_1 = images
        x_0 = jax.random.normal(noise_key, images.shape)

    # Interpolate between the endpoints; the 1e-5 keeps a small residue of
    # x_0 in the mix at t = 1.
    keep = 1 - 1e-5
    x_t = (1 - keep * t_bcast) * x_0 + t_bcast * x_1
    v_t = x_1 - keep * x_0

    # Every element gets the same step-size index, log2(denoise_timesteps).
    dt_flow = np.log2(num_steps).astype(jnp.int32)
    dt_base = jnp.ones(batch, dtype=jnp.int32) * dt_flow

    return x_t, v_t, t, dt_base, labels_dropped, info