KublaiKhan1 commited on
Commit
95f1da1
·
verified ·
1 Parent(s): 5cb5570

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -65,3 +65,4 @@ dt0_1/final.tmp filter=lfs diff=lfs merge=lfs -text
65
  1e-6kl_0.5std/final.tmp filter=lfs diff=lfs merge=lfs -text
66
  global_over_four_channel_mean/final.tmp filter=lfs diff=lfs merge=lfs -text
67
  class_mean_0.30/810001.tmp filter=lfs diff=lfs merge=lfs -text
 
 
65
  1e-6kl_0.5std/final.tmp filter=lfs diff=lfs merge=lfs -text
66
  global_over_four_channel_mean/final.tmp filter=lfs diff=lfs merge=lfs -text
67
  class_mean_0.30/810001.tmp filter=lfs diff=lfs merge=lfs -text
68
+ class_mean_0.05/810000.tmp filter=lfs diff=lfs merge=lfs -text
class_mean_0.05/810000.tmp ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ff9cc75c61e5f1b5b0de44fefb7cd11934b88003143c7deec51ef1a5f8a72165
3
+ size 2097505397
class_mean_0.05/helper_inference.py ADDED
@@ -0,0 +1,453 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import jax
import jax.experimental
import wandb
import jax.numpy as jnp
import numpy as np
import tqdm
import matplotlib.pyplot as plt
import os
from functools import partial
from absl import app, flags

# Inference-only command-line flags (absl); read later via FLAGS.
flags.DEFINE_integer('inference_timesteps', 128, 'Number of timesteps for inference.')
flags.DEFINE_integer('inference_generations', 50000, 'Number of generations for inference.')
flags.DEFINE_float('inference_cfg_scale', 1.5, 'CFG scale for inference.')

# Per-class latent means, loaded eagerly at import time.
# NOTE(review): hard-coded relative paths — running from another cwd will fail.
classes = np.load("classes.npz")
global_mean = jnp.load("global_mean.npy")
# np.load on an .npz returns a lazy NpzFile; materialize it as a plain dict.
classes = {key: classes[key] for key in classes.files}
# Slot "1000" (one past the last real class id) holds the global mean; it is
# the entry looked up for the unconditional / null-token label.
classes["1000"] = global_mean
# Stack the per-class means so an integer label can index them directly.
# Assumes the npz keys are exactly the stringified ids "0".."999" — TODO confirm.
classes_array = jnp.array([classes[str(i)] for i in range(len(classes))])
22
+
def do_inference(
    FLAGS,
    train_state,
    step,
    dataset,
    dataset_valid,
    shard_data,
    vae_encode,
    vae_decode,
    update,
    get_fid_activations,
    imagenet_labels,
    visualize_labels,
    fid_from_stats,
    truth_fid_stats,
):
    """Sample images from the trained flow model and report FID.

    Two visible modes:
      * ``FLAGS.mode == 'interpolate'``: one-shot generation along a spherical
        interpolation between two fixed noise seeds, saved to ``FLAGS.save_dir``
        and then halted at a deliberate ``breakpoint()``.
      * otherwise: generate ``FLAGS.inference_generations`` images with an Euler
        sampler (the other integrators below sit in dead ``elif`` arms), push
        them through the Inception network, and print FID on process 0.

    Args (all supplied by the training harness; concrete types not visible here):
        FLAGS: config object (model config, batch_size, mode, save_dir, ...).
        train_state: exposes ``call_model`` / ``call_model_ema``.
        step: unused in this function.
        dataset, dataset_valid: iterators yielding (images, labels) batches.
        shard_data: shards array(s) across devices.
        vae_encode, vae_decode: stable-VAE codec (used when
            ``FLAGS.model.use_stable_vae``).
        update: unused in this function.
        get_fid_activations: maps images in [-1, 1] to Inception activations.
        imagenet_labels, visualize_labels: unused in this function.
        fid_from_stats: FID from two (mu, sigma) pairs.
        truth_fid_stats: dict with reference ``'mu'`` and ``'sigma'``.

    Returns:
        None. Side effects: prints FID, may write ``.npy`` renders / debug PNGs.
    """
    with jax.spmd_mode('allow_all'):
        global_device_count = jax.device_count()
        # Distinct key per host so hosts do not draw identical noise.
        key = jax.random.PRNGKey(42 + jax.process_index())
        batch_images, batch_labels = next(dataset)
        valid_images, valid_labels = next(dataset_valid)
        if FLAGS.model.use_stable_vae:
            # Work in VAE latent space rather than pixel space.
            batch_images = vae_encode(key, batch_images)
            valid_images = vae_encode(key, valid_images)
        batch_labels_sharded, valid_labels_sharded = shard_data(batch_labels, valid_labels)
        # Class id == num_classes acts as the null token for unconditional calls.
        labels_uncond = shard_data(jnp.ones(batch_labels.shape, dtype=jnp.int32) * FLAGS.model['num_classes']) # Null token
        eps = jax.random.normal(key, batch_images.shape)

        def process_img(img):
            # Decode one latent/image to a [0, 1] numpy array for visualization.
            if FLAGS.model.use_stable_vae:
                img = vae_decode(img[None])[0]
            img = img * 0.5 + 0.5
            img = jnp.clip(img, 0, 1)
            img = np.array(img)
            return img

        @partial(jax.jit, static_argnums=(5,))
        def call_model(train_state, images, t, dt, labels, use_ema=True, perturbe = False):
            # Velocity prediction; prefers EMA weights when the model keeps them.
            if use_ema and FLAGS.model.use_ema:
                call_fn = train_state.call_model_ema
            else:
                call_fn = train_state.call_model
            output = call_fn(images, t, dt, labels, train=False)#, perturbe = perturbe)
            return output

        if FLAGS.mode == 'interpolate':
            # Spherical interpolation between two fixed noise draws, decoded
            # with a single full-length (t=0, dt=0) model step.
            seed = 5
            eps0 = jax.random.normal(jax.random.PRNGKey(seed), batch_images[0].shape)
            eps1 = jax.random.normal(jax.random.PRNGKey(seed+1), batch_images[0].shape)
            # Fixed class 555 for every interpolant.
            labels = jnp.ones(FLAGS.batch_size,).astype(jnp.int32) * 555
            i = jnp.linspace(0, 1, FLAGS.batch_size)
            # cos/sin weights keep the interpolated noise at unit variance.
            i_neg = np.sqrt(1-i**2)
            x = eps0[None] * i_neg[:, None, None, None] + eps1[None] * i[:, None, None, None]
            t_vector = jnp.full((FLAGS.batch_size, ), 0)
            dt_vector = jnp.zeros_like(t_vector)
            cfg_scale = FLAGS.inference_cfg_scale
            v = call_model(train_state, x, t_vector, dt_vector, labels)
            # One-step generation: x1 = x0 + v * 1.0.
            x = x + v * 1.0
            x = vae_decode(x) # Image is in [-1, 1] space.
            x_render = np.array(jax.experimental.multihost_utils.process_allgather(x))
            os.makedirs(FLAGS.save_dir, exist_ok=True)
            np.save(FLAGS.save_dir + f'/x_render.npy', x_render)
            # NOTE(review): intentional hard stop for interactive inspection.
            breakpoint()

        denoise_timesteps = FLAGS.inference_timesteps
        num_generations = FLAGS.inference_generations
        cfg_scale = FLAGS.inference_cfg_scale
        x0 = []            # gathered starting noise per batch
        x1 = []            # gathered final latents per batch
        lab = []           # gathered labels per batch
        x_render = []      # decoded images (only kept for small runs)
        activations = []   # Inception features for FID
        images_shape = batch_images.shape
        print(f"Calc FID for CFG {cfg_scale} and denoise_timesteps {denoise_timesteps}")
        print("should do x", num_generations // FLAGS.batch_size)
        for fid_it in tqdm.tqdm(range(num_generations // FLAGS.batch_size)):
            # Deterministic per-iteration, per-host key.
            key = jax.random.PRNGKey(42)
            key = jax.random.fold_in(key, fid_it)
            key = jax.random.fold_in(key, jax.process_index())
            eps_key, label_key = jax.random.split(key)
            x = jax.random.normal(eps_key, images_shape)
            labels = jax.random.randint(label_key, (images_shape[0],), 0, FLAGS.model.num_classes)
            # Mix the class-mean prior with Gaussian noise: e is the noise share.
            e = 0.30

            from baselines.targets_naive import map_labels_to_classes
            x_cond = map_labels_to_classes(classes_array, labels) * (1-e) + e * x
            x_uncond = map_labels_to_classes(classes_array, labels_uncond) * (1-e) + e * x
            # print("first xcond", x_cond[0])


            x_cond, labels = shard_data(x_cond, labels)
            # print("sharded xcond", x_cond[0])
            x_uncond, _ = shard_data(x_uncond, labels)

            x0.append(np.array(jax.experimental.multihost_utils.process_allgather(x)))

            # Debug-only block (disabled).
            if False:
                print(x.shape)#256,32,32,4
                print(x_cond.shape)
                print(labels)
            # Debug-only block (disabled): sanity-check losses at several noise
            # levels and dump intermediate PNGs, then hard-exit.
            if False:
                x = vae_decode(x[0:5])
                x_cond = vae_decode(x_cond[0:5])
                x_uncond = vae_decode(x_uncond[0:5])
                #They are all 0 to 255
                x = ((x + 1) * 127.5).clip(0, 255)
                x_cond = ((x_cond + 1) * 127.5).clip(0, 255)
                x_uncond = ((x_uncond + 1) * 127.5).clip(0, 255)
                noise_levels = [0,.01,.05,.1,.2,.33,.66,1.0]


                x = x[0:5]


                for noise_level in noise_levels:

                    x_1 = batch_images[0:5]
                    x_0 = x[0:5]
                    e = 0.05
                    labels = labels[0:5]
                    #what...?
                    print("noise level", noise_level)
                    print("noise shape", x_0.shape)#batch, 256, 256, 4
                    x_0 = map_labels_to_classes(classes_array, labels)*(1-e) + e * x_0#So this is just full noise right? noise level starts at 0, which means we are full noise.
                    #print("classes mapped shape", x_0.shape)
                    #exit()
                    # Linear interpolation between prior x_0 and data x_1.
                    x_t = (1 - (1 - 1e-5) * noise_level) * x_0 + noise_level * x_1

                    # Flow-matching target velocity.
                    v_t = x_1 - (1 - 1e-5) * x_0
                    #print("v_t is", v_t)
                    #x_t, v_t, t, dt_base, labels, info = get_targets(FLAGS, targets_key, train_state, images, labels, force_t, force_dt, classes)


                    dt_flow = np.log2(FLAGS.model['denoise_timesteps']).astype(jnp.int32)
                    dt_base = jnp.ones(x_0.shape[0], dtype=jnp.int32) * dt_flow # Smallest dt.
                    # Broadcast the scalar noise level to one entry per example.
                    # NOTE(review): the int32 cast truncates fractional levels
                    # (.01, .05, ...) to 0 — confirm this is intended.
                    noise_level = jnp.ones(x_0.shape[0], dtype=jnp.int32) * noise_level

                    # Query the model at the noised point.
                    v = call_model(train_state, x_t[0:5], noise_level, dt_base, labels)
                    diff = (v_t - v) ** 2
                    print("first loss", diff.mean())


                    # NOTE(review): v_t was built from x_1/x_0, so the images
                    # below mix targets and predictions — per original comment.
                    image = x_0[0] + v_t[0]
                    image = vae_decode(jnp.expand_dims(image, axis = 0)).squeeze()
                    image = ((image + 1) * 127.5).clip(0, 255)
                    from PIL import Image
                    image = np.array(image).astype(np.uint8)
                    image = Image.fromarray(image)
                    image.save("denoised_image_real_v" + str(noise_level) + ".png")


                    image = x_0[0] + v[0]
                    image = vae_decode(jnp.expand_dims(image, axis = 0)).squeeze()
                    image = ((image + 1) * 127.5).clip(0, 255)
                    from PIL import Image
                    image = np.array(image).astype(np.uint8)
                    image = Image.fromarray(image)
                    image.save("denoised_image_" + str(noise_level) + ".png")

                    image = x_1[0]
                    image = vae_decode(jnp.expand_dims(image, axis = 0)).squeeze()
                    image = ((image + 1) * 127.5).clip(0, 255)
                    from PIL import Image
                    image = np.array(image).astype(np.uint8)
                    image = Image.fromarray(image)
                    image.save("actual_image_" + str(noise_level) + ".png")

                    image = x_t[0]
                    image = vae_decode(jnp.expand_dims(image, axis = 0)).squeeze()
                    image = ((image + 1) * 127.5).clip(0, 255)
                    from PIL import Image
                    image = np.array(image).astype(np.uint8)
                    image = Image.fromarray(image)
                    image.save("noised_image_" + str(noise_level) + ".png")



                    # Disabled alternative check kept verbatim for reference.
                    """
                    print("first dtbase", dt_base)
                    from baselines.targets_naive import get_targets
                    x_t, v_t, t, dt_base, labels, info = get_targets(FLAGS, eps_key, train_state, batch_images[0:5], labels[0:5], -1, -1, classes_array)
                    #print("v_t2", v_t)
                    #This uses random ts, so it doesn't tell us shit.
                    v = call_model(train_state, x_t[0:5], noise_level, dt_base, labels)
                    print("second dtbase", dt_base)
                    print("second loss", ((v_t - v) ** 2).mean())
                    #Noise level 1.0 should be loss around 0.03...
                    #get mse v, vt_t
                    #if needed.
                    """
                    exit()
                    break

                print("doing some decoding stuff")
                for i in range(0,5):
                    image = x[i]
                    from PIL import Image
                    image = np.array(image).astype(np.uint8)
                    image = Image.fromarray(image)
                    image.save("noisestuff" + str(i) + ".png")
                for i in range(0,5):
                    image = x_cond[i]
                    from PIL import Image
                    image = np.array(image).astype(np.uint8)
                    image = Image.fromarray(image)
                    image.save("condstuff" + str(i) + ".png")
                image = x_uncond[0]
                image = np.array(image).astype(np.uint8)
                image = Image.fromarray(image)
                image.save("uncondtuff" + str(i) + ".png")
                #exit()


            # --- Main denoising loop: integrate velocity field from t=0 to 1.
            delta_t = 1.0 / denoise_timesteps
            for ti in range(denoise_timesteps):
                t = ti / denoise_timesteps # From x_0 (noise) to x_1 (data)
                t_vector = jnp.full((images_shape[0], ), t)
                if FLAGS.model.train_type == 'naive':
                    dt_flow = np.log2(FLAGS.model['denoise_timesteps']).astype(jnp.int32)
                    dt_base = jnp.ones(images_shape[0], dtype=jnp.int32) * dt_flow # Smallest dt.
                else: # shortcut
                    dt_flow = np.log2(denoise_timesteps).astype(jnp.int32)#[128,64,32,16,8,4,2,1] = [7,6,5,4,3,2,1,0]
                    dt_base = jnp.ones(images_shape[0], dtype=jnp.int32) * dt_flow #For 128 steps, distance = 7, maximum distance.

                t_vector, dt_base = shard_data(t_vector, dt_base)
                # Classifier-free guidance: blend cond/uncond velocities.
                if cfg_scale == 1:
                    v = call_model(train_state, x, t_vector, dt_base, labels)
                elif cfg_scale == 0:
                    v = call_model(train_state, x, t_vector, dt_base, labels_uncond)
                else:
                    v_pred_uncond = call_model(train_state, x_uncond, t_vector, dt_base, labels_uncond)
                    v_pred_label = call_model(train_state, x_cond, t_vector, dt_base, labels)
                    v = v_pred_uncond + cfg_scale * (v_pred_label - v_pred_uncond)

                if FLAGS.model.train_type == 'consistency':
                    # Consistency-style step: predict x1, then re-noise to t+dt.
                    eps = shard_data(jax.random.normal(jax.random.fold_in(eps_key, ti), images_shape))
                    x1pred = x + v * (1-t)
                    x = x1pred * (t+delta_t) + eps * (1-t-delta_t)

                elif True:
                    x = x + v * delta_t # Euler sampling.
                # NOTE(review): every arm below is unreachable — the `elif True`
                # above always wins. Kept as experiments with other integrators.
                elif False:#special predictor. So with special. If we do a natural prediction of step 4, distance = 2... we do a step same x, but longer distance. so as if we were doing 2 steps
                    if ti + 1 == denoise_timesteps:
                        x = x + v * delta_t
                    else:
                        dt_flow = np.log2(denoise_timesteps/2).astype(jnp.int32)#[128,64,32,16,8,4,2,1] = [7,6,5,4,3,2,1,0]
                        dt_base = jnp.ones(images_shape[0], dtype=jnp.int32) * dt_flow

                        v_2_c = call_model(train_state, x, t_vector, dt_base, labels)
                        v_2_u = call_model(train_state, x, t_vector, dt_base, labels_uncond)
                        v_2 = v_2_u + cfg_scale * (v_2_c - v_2_u)

                        #We might be able to skip doing CFG in the future

                        v_prime = (v + v_2) / 2
                        x = x + v_prime * delta_t
                elif False:#midpiont

                    #print("ts", t)

                    if ti + 1 == denoise_timesteps:# or ti == 0:
                        x = x + v * delta_t
                    else:
                        pass
                elif True:#heun 3

                    if ti + 1 == denoise_timesteps:
                        x = x + v * delta_t # Final Euler step
                    else:
                        # Stage 1
                        k1 = v # already computed
                        t1 = t

                        # Stage 2
                        x2 = x + (delta_t / 3) * k1
                        t_vector_2 = jnp.full((images_shape[0],), t1 + delta_t / 3)
                        t_vector_2 = shard_data(t_vector_2)
                        k2_c = call_model(train_state, x2, t_vector_2, dt_base, labels)
                        k2_u = call_model(train_state, x2, t_vector_2, dt_base, labels_uncond)
                        k2 = k2_u + cfg_scale * (k2_c - k2_u)

                        # Stage 3
                        x3 = x + (2 * delta_t / 3) * k2
                        t_vector_3 = jnp.full((images_shape[0],), t1 + 2 * delta_t / 3)
                        t_vector_3 = shard_data(t_vector_3)
                        k3_c = call_model(train_state, x3, t_vector_3, dt_base, labels)
                        k3_u = call_model(train_state, x3, t_vector_3, dt_base, labels_uncond)
                        k3 = k3_u + cfg_scale * (k3_c - k3_u)

                        # Combine stages
                        v_prime = (1/4) * k1 + (3/4) * k3
                        x = x + v_prime * delta_t
                elif True:#Third order RK

                    if ti + 1 == denoise_timesteps:
                        x = x + v * delta_t # Final Euler step
                    else:
                        x1 = x
                        t1 = t
                        v1 = v

                        # Stage 2
                        x2 = x1 + v1 * delta_t / 2
                        t_vector_2 = jnp.full((images_shape[0],), t1 + delta_t / 2)
                        t_vector_2 = shard_data(t_vector_2)
                        v2_c = call_model(train_state, x2, t_vector_2, dt_base, labels)
                        v2_u = call_model(train_state, x2, t_vector_2, dt_base, labels_uncond)
                        v2 = v2_u + cfg_scale * (v2_c - v2_u)

                        # Stage 3
                        x3 = x1 - v1 * delta_t + 2 * v2 * delta_t
                        t_vector_3 = jnp.full((images_shape[0],), t1 + delta_t)
                        t_vector_3 = shard_data(t_vector_3)
                        v3_c = call_model(train_state, x3, t_vector_3, dt_base, labels)
                        v3_u = call_model(train_state, x3, t_vector_3, dt_base, labels_uncond)
                        v3 = v3_u + cfg_scale * (v3_c - v3_u)

                        # Weighted sum of stages
                        v_prime = (v1 + 4 * v2 + v3) / 6
                        x = x + v_prime * delta_t

                elif True:#heun
                    #Last time euler
                    if ti + 1 == denoise_timesteps:# or ti == 0:
                        x = x + v * delta_t
                    else:
                        x_2 = x + v * delta_t
                        #print("original t", t_vector)
                        t_vector_2 = jnp.full((images_shape[0], ), t + delta_t)
                        t_vector_2 = shard_data(t_vector_2)
                        #print("second t", t_vector_2)
                        v_2_c = call_model(train_state, x_2, t_vector_2, dt_base, labels)
                        v_2_u = call_model(train_state, x_2, t_vector_2, dt_base, labels_uncond)
                        v_2 = v_2_u + cfg_scale * (v_2_c - v_2_u)

                        # print(jnp.linalg.norm(v))
                        # print(jnp.linalg.norm(v_2))

                        v_prime = (v + v_2) / 2
                        x = x + v_prime * delta_t


                elif False:#DPM++2M maybe?

                    if ti + 1 == denoise_timesteps:
                        x = x + v * delta_t
                        continue
                    sigma_hat = t#Current timestep for me

                    #we already have v here, v = d

                    #Should just be the next timestep?
                    sigma_i_1 = t + delta_t
                    # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
                    sigma_mid = ((sigma_hat ** (1 / 3) + sigma_i_1 ** (1 / 3)) / 2) ** 3
                    dt_1 = sigma_mid - sigma_hat
                    dt_2 = sigma_i_1 - sigma_hat

                    x_2 = x + v * dt_1

                    t_vector_2 = jnp.full((images_shape[0], ), sigma_mid)

                    v_2_c = call_model(train_state, x_2, t_vector_2, dt_base, labels)
                    v_2_u = call_model(train_state, x_2, t_vector_2, dt_base, labels_uncond)
                    v_2 = v_2_u + cfg_scale * (v_2_c - v_2_u)

                    x = x + v_2 * dt_2

                # NOTE(review): dead branch; references undefined torch/t_prev/
                # t_curr/img/model, and `x = x = ...` drops the `x +` base term.
                elif False:#RF-solver solution #tcurr and tprev are... 0,0 1,1, 1,2, 2,2, 3,3 3,4, 4,4, 4,5....
                    img_mid = x + (t_prev - t_curr)/2 * v
                    t_vec_mid = torch.full((img.shape[0],), (t_curr + (t_prev - t_curr) / 2), dtype=img.dtype, device=img.device)
                    v_2 = model(img_mid, t_vec_mid)

                    first_order = (v_2 - v) / ((t_prev - t_curr) / 2)
                    x = x = (t_prev - t_curr) * v + .5 * (t_prev - t_curr) ** 2 * first_order


            # Gather the finished latents/labels across hosts.
            x1.append(np.array(jax.experimental.multihost_utils.process_allgather(x)))
            lab.append(np.array(jax.experimental.multihost_utils.process_allgather(labels)))
            if FLAGS.model.use_stable_vae:
                x = vae_decode(x) # Image is in [-1, 1] space.
            # Only keep decoded renders for small runs to bound memory.
            if num_generations < 10000:
                x_render.append(np.array(jax.experimental.multihost_utils.process_allgather(x)))

            #This happens EVERY LOOP
            print("decode n shit", x.shape)
            # Debug-only PNG dump (disabled).
            if False:
                for i in range(0,5):
                    image = x[i]
                    image = ((image + 1) * 127.5).clip(0, 255)
                    from PIL import Image
                    image = np.array(image).astype(np.uint8)
                    image = Image.fromarray(image)
                    image.save("stuff" + str(i) + ".png")
                print("done")
                # exit()

            # Resize to Inception's 299x299 input and accumulate activations.
            x = jax.image.resize(x, (x.shape[0], 299, 299, 3), method='bilinear', antialias=False)
            x = jnp.clip(x, -1, 1)
            acts = get_fid_activations(x)[..., 0, 0, :] # [devices, batch//devices, 2048]
            acts = jax.experimental.multihost_utils.process_allgather(acts)
            acts = np.array(acts)
            activations.append(acts)

        # FID is computed and printed on the main process only.
        if jax.process_index() == 0:
            activations = np.concatenate(activations, axis=0)
            activations = activations.reshape((-1, activations.shape[-1]))
            mu1 = np.mean(activations, axis=0)
            sigma1 = np.cov(activations, rowvar=False)
            fid = fid_from_stats(mu1, sigma1, truth_fid_stats['mu'], truth_fid_stats['sigma'])
            print(f"FID is {fid}")
        return

        # NOTE(review): everything below is unreachable — the unconditional
        # `return` above exits first. Confirm whether saving was meant to run.
        if FLAGS.save_dir is not None:
            os.makedirs(FLAGS.save_dir, exist_ok=True)
            x_render = np.concatenate(x_render, axis=0)
            np.save(FLAGS.save_dir + f'/x_render.npy', x_render)

        # x0 = np.concatenate(x0, axis=0)
        # x1 = np.concatenate(x1, axis=0)
        # lab = np.concatenate(lab, axis=0)
        # os.makedirs(FLAGS.save_dir, exist_ok=True)
        # np.save(FLAGS.save_dir + f'/x0.npy', x0)
        # np.save(FLAGS.save_dir + f'/x1.npy', x1)
        # np.save(FLAGS.save_dir + f'/lab.npy', lab)