KublaiKhan1 commited on
Commit
51e86ae
·
verified ·
1 Parent(s): 5cf25d3

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. whiten/calc_means.py +438 -0
  2. whiten/stable_vae.py +101 -0
  3. whiten/whiten.py +41 -0
whiten/calc_means.py ADDED
@@ -0,0 +1,438 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+ import jax.numpy as jnp
3
+ from absl import app, flags
4
+ from functools import partial
5
+ import numpy as np
6
+ import tqdm
7
+ import jax
8
+ import jax.numpy as jnp
9
+ import flax
10
+ import optax
11
+ import wandb
12
+ from ml_collections import config_flags
13
+ import ml_collections
14
+
15
+ from utils.wandb import setup_wandb, default_wandb_config
16
+ from utils.train_state import TrainStateEma
17
+ from utils.checkpoint import Checkpoint
18
+ from utils.stable_vae import StableVAE
19
+ from utils.sharding import create_sharding, all_gather
20
+ from utils.datasets import get_dataset
21
+ from model import DiT
22
+ from helper_eval import eval_model
23
+ from helper_inference import do_inference
24
+
25
+ FLAGS = flags.FLAGS
26
+ flags.DEFINE_string('dataset_name', 'imagenet256', 'Environment name.')
27
+ flags.DEFINE_string('load_dir', None, 'Logging dir (if not None, save params).')
28
+ flags.DEFINE_string('save_dir', './checkpoints/', 'Logging dir (if not None, save params).')
29
+ flags.DEFINE_string('fid_stats', None, 'FID stats file.')
30
+ flags.DEFINE_integer('seed', 10, 'Random seed.') # Must be the same across all processes.
31
+ flags.DEFINE_integer('log_interval', 1000, 'Logging interval.')
32
+ flags.DEFINE_integer('eval_interval', 1000000, 'Eval interval.')
33
+ flags.DEFINE_integer('save_interval', 10000, 'Save interval.')
34
+ flags.DEFINE_integer('batch_size', 512, 'Mini batch size.')
35
+ flags.DEFINE_integer('max_steps', int(810_000), 'Number of training steps.')
36
+ flags.DEFINE_integer('debug_overfit', 0, 'Debug overfitting.')
37
+ flags.DEFINE_string('mode', 'train', 'train or inference.')
38
+
39
+ model_config = ml_collections.ConfigDict({
40
+ 'lr': 0.0001,
41
+ 'beta1': 0.9,
42
+ 'beta2': 0.999,
43
+ 'weight_decay': 0.1,
44
+ 'use_cosine': 0,
45
+ 'warmup': 0,
46
+ 'dropout': 0.0,
47
+ 'hidden_size': 768, # change this!
48
+ 'patch_size': 2, # change this!
49
+ 'depth': 12, # change this!
50
+ 'num_heads': 12, # change this!
51
+ 'mlp_ratio': 4, # change this!
52
+ 'class_dropout_prob': 0.1,
53
+ 'num_classes': 1000,
54
+ 'denoise_timesteps': 128,
55
+ 'cfg_scale': 4.0,
56
+ 'target_update_rate': 0.999,
57
+ 'use_ema': 0,
58
+ 'use_stable_vae': 1,
59
+ 'sharding': 'dp', # dp or fsdp.
60
+ 't_sampling': 'discrete-dt',
61
+ 'dt_sampling': 'uniform',
62
+ 'bootstrap_cfg': 1,
63
+ 'bootstrap_every': 8, # Make sure its a divisor of batch size.
64
+ 'bootstrap_ema': 1,
65
+ 'bootstrap_dt_bias': 0,
66
+ 'train_type': 'shortcut' # or naive.
67
+ })
68
+
69
+
70
+ #wandb_config = default_wandb_config()
71
+ #wandb_config.update({
72
+ # 'project': 'shortcut',
73
+ # 'name': 'shortcut_{dataset_name}',
74
+ #})
75
+
76
+ #config_flags.DEFINE_config_dict('wandb', wandb_config, lock_config=False)
77
+ config_flags.DEFINE_config_dict('model', model_config, lock_config=False)
78
+
79
+ ##############################################
80
+ ## Training Code.
81
+ ##############################################
82
+ def main(_):
83
+
84
+ np.random.seed(FLAGS.seed)
85
+ print("Using devices", jax.local_devices())
86
+ device_count = len(jax.local_devices())
87
+ global_device_count = jax.device_count()
88
+ print("Device count", device_count)
89
+ print("Global device count", global_device_count)
90
+ local_batch_size = FLAGS.batch_size // (global_device_count // device_count)
91
+ print("Global Batch: ", FLAGS.batch_size)
92
+ print("Node Batch: ", local_batch_size)
93
+ print("Device Batch:", local_batch_size // device_count)
94
+
95
+
96
+ dataset = get_dataset(FLAGS.dataset_name, local_batch_size, True, FLAGS.debug_overfit)
97
+ dataset_valid = get_dataset(FLAGS.dataset_name, local_batch_size, False, FLAGS.debug_overfit)
98
+ example_obs, example_labels = next(dataset)
99
+ example_obs = example_obs[:1]
100
+ example_obs_shape = example_obs.shape
101
+
102
+ if FLAGS.model.use_stable_vae:
103
+ vae = StableVAE.create()
104
+ if 'latent' in FLAGS.dataset_name:
105
+ example_obs = example_obs[:, :, :, example_obs.shape[-1] // 2:]
106
+ example_obs_shape = example_obs.shape
107
+ else:
108
+ example_obs = vae.encode(jax.random.PRNGKey(0), example_obs)
109
+ example_obs_shape = example_obs.shape
110
+ vae_rng = jax.random.PRNGKey(42)
111
+ vae_encode = jax.jit(vae.encode)
112
+ vae_decode = jax.jit(vae.decode)
113
+
114
+ if FLAGS.fid_stats is not None:
115
+ from utils.fid import get_fid_network, fid_from_stats
116
+ get_fid_activations = get_fid_network()
117
+ truth_fid_stats = np.load(FLAGS.fid_stats)
118
+ else:
119
+ get_fid_activations = None
120
+ truth_fid_stats = None
121
+
122
+ ###################################
123
+ # Creating Model and put on devices.
124
+ ###################################
125
+ FLAGS.model.image_channels = example_obs_shape[-1]
126
+ FLAGS.model.image_size = example_obs_shape[1]
127
+ dit_args = {
128
+ 'patch_size': FLAGS.model['patch_size'],
129
+ 'hidden_size': FLAGS.model['hidden_size'],
130
+ 'depth': FLAGS.model['depth'],
131
+ 'num_heads': FLAGS.model['num_heads'],
132
+ 'mlp_ratio': FLAGS.model['mlp_ratio'],
133
+ 'out_channels': example_obs_shape[-1],
134
+ 'class_dropout_prob': FLAGS.model['class_dropout_prob'],
135
+ 'num_classes': FLAGS.model['num_classes'],
136
+ 'dropout': FLAGS.model['dropout'],
137
+ 'ignore_dt': False if (FLAGS.model['train_type'] in ('shortcut', 'livereflow')) else True,
138
+ }
139
+ model_def = DiT(**dit_args)
140
+ tabulate_fn = flax.linen.tabulate(model_def, jax.random.PRNGKey(0))
141
+ print(tabulate_fn(example_obs, jnp.zeros((1,)), jnp.zeros((1,)), jnp.zeros((1,), dtype=jnp.int32)))
142
+
143
+ if FLAGS.model.use_cosine:
144
+ lr_schedule = optax.warmup_cosine_decay_schedule(0.0, FLAGS.model['lr'], FLAGS.model['warmup'], FLAGS.max_steps)
145
+ elif FLAGS.model.warmup > 0:
146
+ lr_schedule = optax.linear_schedule(0.0, FLAGS.model['lr'], FLAGS.model['warmup'])
147
+ else:
148
+ lr_schedule = lambda x: FLAGS.model['lr']
149
+ adam = optax.adamw(learning_rate=lr_schedule, b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'], weight_decay=FLAGS.model['weight_decay'])
150
+ tx = optax.chain(adam)
151
+
152
+ start_step = 1
153
+
154
+ def log_param_shapes(params, label=""):
155
+ flat = flax.traverse_util.flatten_dict(params)
156
+
157
+ squeezed_flat = {k: jnp.squeeze(v, axis = 0) for k, v in flat.items() if v.shape[0] == 1}
158
+ print(f"\n{label} parameter shapes:")
159
+ for k, v in flat.items():
160
+ print(f"{k}: {v.shape}")
161
+ return flax.traverse_util.unflatten_dict(squeezed_flat)
162
+
163
+
164
+ def init(rng):
165
+ param_key, dropout_key, dropout2_key = jax.random.split(rng, 3)
166
+ example_t = jnp.zeros((1,))
167
+ example_dt = jnp.zeros((1,))
168
+ example_label = jnp.zeros((1,), dtype=jnp.int32)
169
+ example_obs = jnp.zeros(example_obs_shape)
170
+ model_rngs = {'params': param_key, 'label_dropout': dropout_key, 'dropout': dropout2_key}
171
+ params = model_def.init(model_rngs, example_obs, example_t, example_dt, example_label)['params']
172
+ opt_state = tx.init(params)
173
+
174
+ ts = TrainStateEma.create(model_def, params, rng=rng, tx=tx, opt_state=opt_state)
175
+
176
+
177
+ if FLAGS.load_dir is not None:
178
+
179
+ cp = Checkpoint(FLAGS.load_dir)
180
+ train_state_load = cp.load_as_dict()["train_state"]
181
+
182
+ log_param_shapes(ts.params)
183
+ flat = log_param_shapes(train_state_load["params"])
184
+ flat_ema = log_param_shapes(train_state_load["params_ema"])
185
+ flat_mu = log_param_shapes(train_state_load["opt_state"][0][0].mu)
186
+ flat_nu = log_param_shapes(train_state_load["opt_state"][0][0].nu)
187
+
188
+ from optax import ScaleByAdamState
189
+ opt_state = train_state_load["opt_state"]
190
+ new_state = ScaleByAdamState(
191
+ opt_state[0][0].count,
192
+ mu=flat_mu,
193
+ nu=flat_nu
194
+ )
195
+ opt_state = list(opt_state)
196
+ opt_state[0] = list(opt_state[0])
197
+ opt_state[0][0] = new_state
198
+
199
+ opt_state[0] = tuple(opt_state[0])
200
+ opt_state = tuple(opt_state)
201
+
202
+ train_state_load = TrainStateEma.create(model_def, params = flat, rng = rng, tx = tx, opt_state=opt_state)
203
+
204
+
205
+ #Need to replace EMA because we have a separate ema
206
+ log_param_shapes(train_state_load.params)
207
+
208
+ train_state_load = train_state_load.replace(params_ema = flat_ema)
209
+ start_step = train_state_load.step
210
+
211
+ ts = train_state_load
212
+
213
+ return ts
214
+
215
+ rng = jax.random.PRNGKey(FLAGS.seed)
216
+ train_state_shape = jax.eval_shape(init, rng)
217
+
218
+ data_sharding, train_state_sharding, no_shard, shard_data, global_to_local = create_sharding(FLAGS.model.sharding, train_state_shape)
219
+ train_state = jax.jit(init, out_shardings=train_state_sharding)(rng)
220
+ jax.debug.visualize_array_sharding(train_state.params['FinalLayer_0']['Dense_0']['kernel'])
221
+ jax.debug.visualize_array_sharding(train_state.params['TimestepEmbedder_1']['Dense_0']['kernel'])
222
+ jax.experimental.multihost_utils.assert_equal(train_state.params['TimestepEmbedder_1']['Dense_0']['kernel'])
223
+
224
+
225
+ if FLAGS.model.train_type == 'progressive' or FLAGS.model.train_type == 'consistency-distillation':
226
+ train_state_teacher = jax.jit(lambda x : x, out_shardings=train_state_sharding)(train_state)
227
+ else:
228
+ train_state_teacher = None
229
+
230
+ visualize_labels = example_labels
231
+ visualize_labels = shard_data(visualize_labels)
232
+ visualize_labels = jax.experimental.multihost_utils.process_allgather(visualize_labels)
233
+ imagenet_labels = open('data/imagenet_labels.txt').read().splitlines()
234
+
235
+ ###################################
236
+ # Update Function
237
+ ###################################
238
+
239
+ @partial(jax.jit, out_shardings=(train_state_sharding, no_shard))
240
+ def update(train_state, train_state_teacher, images, labels, force_t=-1, force_dt=-1):
241
+ new_rng, targets_key, dropout_key, perm_key = jax.random.split(train_state.rng, 4)
242
+ info = {}
243
+
244
+ id_perm = jax.random.permutation(perm_key, images.shape[0])
245
+ images = images[id_perm]
246
+ labels = labels[id_perm]
247
+ images = jax.lax.with_sharding_constraint(images, data_sharding)
248
+ labels = jax.lax.with_sharding_constraint(labels, data_sharding)
249
+
250
+ if FLAGS.model['cfg_scale'] == 0: # For unconditional generation.
251
+ labels = jnp.ones(labels.shape[0], dtype=jnp.int32) * FLAGS.model['num_classes']
252
+
253
+ if FLAGS.model['train_type'] == 'naive':
254
+ from baselines.targets_naive import get_targets
255
+ x_t, v_t, t, dt_base, labels, info = get_targets(FLAGS, targets_key, train_state, images, labels, force_t, force_dt)
256
+ elif FLAGS.model['train_type'] == 'shortcut':
257
+ from targets_shortcut import get_targets
258
+ x_t, v_t, t, dt_base, labels, info = get_targets(FLAGS, targets_key, train_state, images, labels, force_t, force_dt)
259
+ elif FLAGS.model['train_type'] == 'progressive':
260
+ from baselines.targets_progressive import get_targets
261
+ x_t, v_t, t, dt_base, labels, info = get_targets(FLAGS, targets_key, train_state, train_state_teacher, images, labels, force_t, force_dt)
262
+ elif FLAGS.model['train_type'] == 'consistency-distillation':
263
+ from baselines.targets_consistency_distillation import get_targets
264
+ x_t, v_t, t, dt_base, labels, info = get_targets(FLAGS, targets_key, train_state, train_state_teacher, images, labels, force_t, force_dt)
265
+ elif FLAGS.model['train_type'] == 'consistency':
266
+ from baselines.targets_consistency_training import get_targets
267
+ x_t, v_t, t, dt_base, labels, info = get_targets(FLAGS, targets_key, train_state, images, labels, force_t, force_dt)
268
+ elif FLAGS.model['train_type'] == 'livereflow':
269
+ from baselines.targets_livereflow import get_targets
270
+ x_t, v_t, t, dt_base, labels, info = get_targets(FLAGS, targets_key, train_state, images, labels, force_t, force_dt)
271
+
272
+ def loss_fn(grad_params):
273
+ v_prime, logvars, activations = train_state.call_model(x_t, t, dt_base, labels, train=True, rngs={'dropout': dropout_key}, params=grad_params, return_activations=True)
274
+ mse_v = jnp.mean((v_prime - v_t) ** 2, axis=(1, 2, 3))
275
+ loss = jnp.mean(mse_v)
276
+
277
+ if True:#cosine direction velocity
278
+ cos_loss = 1-optax.cosine_distance(v_prime, v_t, axis = 3, epsilon = 1e-5)
279
+ cos_v = jnp.mean(cos_loss, axis = [1,2])
280
+ cos_loss = cos_v.mean()
281
+
282
+ info = {
283
+ 'loss': loss,
284
+ 'v_magnitude_prime': jnp.sqrt(jnp.mean(jnp.square(v_prime))),
285
+ **{'activations/' + k : jnp.sqrt(jnp.mean(jnp.square(v))) for k, v in activations.items()},
286
+ 'cosine_loss': cos_loss,
287
+ }
288
+
289
+ if FLAGS.model['train_type'] == 'shortcut' or FLAGS.model['train_type'] == 'livereflow':
290
+ bootstrap_size = FLAGS.batch_size // FLAGS.model['bootstrap_every']
291
+ info['loss_flow'] = jnp.mean(mse_v[bootstrap_size:])
292
+ info['loss_bootstrap'] = jnp.mean(mse_v[:bootstrap_size])
293
+ info['cosine_loss_flow'] = jnp.mean(cos_v[bootstrap_size:])
294
+ info['cosine_loss_boostrap'] = jnp.mean(cos_v[:bootstrap_size])
295
+ if True:
296
+ loss = loss + cos_loss
297
+
298
+ return loss, info
299
+
300
+ grads, new_info = jax.grad(loss_fn, has_aux=True)(train_state.params)
301
+ info = {**info, **new_info}
302
+ updates, new_opt_state = train_state.tx.update(grads, train_state.opt_state, train_state.params)
303
+ new_params = optax.apply_updates(train_state.params, updates)
304
+
305
+ info['grad_norm'] = optax.global_norm(grads)
306
+ info['update_norm'] = optax.global_norm(updates)
307
+ info['param_norm'] = optax.global_norm(new_params)
308
+ info['lr'] = lr_schedule(train_state.step)
309
+
310
+ train_state = train_state.replace(rng=new_rng, step=train_state.step + 1, params=new_params, opt_state=new_opt_state)
311
+ train_state = train_state.update_ema(FLAGS.model['target_update_rate'])
312
+ return train_state, info
313
+
314
+
315
+ ###################################
316
+ # Train Loop
317
+ ###################################
318
+ global_mean = None
319
+ class_means = {}
320
+ total = 1281167
321
+ #Do we need to do global means more often?
322
+ print("starting this shit")
323
+ i = 0
324
+ cpus = jax.devices("cpu")
325
+ images = []
326
+ for i in range(0, int(total/512)):
327
+ print(i)
328
+ i += 1
329
+
330
+ batch_images, batch_labels = shard_data(*next(dataset))
331
+ vae_rng, vae_key = jax.random.split(vae_rng)
332
+ batch_images = vae_encode(vae_key, batch_images)
333
+ #print(batch_images.shape)#512x32x32x4
334
+ if global_mean == None:
335
+ global_mean = batch_images.mean(axis = 0)/total
336
+ else:
337
+ global_mean += batch_images.mean(axis = 0)/total
338
+
339
+ for key, bimage in zip(batch_labels, batch_images):
340
+ key = str(int(key))
341
+ if key in class_means.keys():
342
+ class_means[key] = class_means[key] + bimage/total
343
+ else:
344
+ class_means[key] = np.asarray(bimage/total)
345
+
346
+ # z = jax.device_put(batch_images, cpus[0])
347
+ images.append(batch_images)
348
+
349
+ images = jnp.asarray(images)
350
+ #maybe just save images and exit?
351
+ np.savez("images.npz", images)
352
+ exit()
353
+
354
+ """
355
+ #Get per channel stats.
356
+ batch_shape = images.shape[0] * images.shape[1]
357
+ H, W = images.shape[2], images.shape[3]
358
+ images_white = jnp.zeros(images.shape)
359
+ stats = []
360
+
361
+ for c in range(images.shape[-1]):
362
+
363
+ x = images[:,:,:,:,c].reshape(batch_shape, -1)#Get h*w by batch
364
+ mean = x.mean(axis = 0, keepdims = True)
365
+ x_centered = x - mean
366
+ cov = x_centered.T @ x_centered / (batch_shape - 1) # shape: (H*W, H*W)
367
+ U, S, _ = jnp.linalg.svd(cov, full_matrices=False)
368
+ S_inv_root = jnp.diag(1.0 / jnp.sqrt(S + 1e-5))
369
+ zca = U @ S_inv_root @ U.T
370
+
371
+ x_whitened = (zca @ x_centered.T).T # shape: (B, H*W)
372
+ images_whitened[:, :, :, :, c] = x_whitened.view(B, H, W)
373
+
374
+ stats.append((mean, zca)) # Save stats for unwhitening
375
+
376
+ #Now we need to save stats?
377
+ np.savez("stats.npz", stats)
378
+ """
379
+
380
+ # jnp.save("global_mean", global_mean)
381
+ # np.savez("classes.npz", **class_means)
382
+ exit()
383
+ for i in tqdm.tqdm(range(1 + start_step, FLAGS.max_steps + 1 + start_step),
384
+ smoothing=0.1,
385
+ dynamic_ncols=True):
386
+
387
+ # Sample data.
388
+ if not FLAGS.debug_overfit or i == 1:
389
+ batch_images, batch_labels = shard_data(*next(dataset))
390
+ if FLAGS.model.use_stable_vae and 'latent' not in FLAGS.dataset_name:
391
+ vae_rng, vae_key = jax.random.split(vae_rng)
392
+ batch_images = vae_encode(vae_key, batch_images)
393
+
394
+
395
+
396
+ # Train update.
397
+ train_state, update_info = update(train_state, train_state_teacher, batch_images, batch_labels)
398
+
399
+ if i % FLAGS.log_interval == 0 or i == 1:
400
+ update_info = jax.device_get(update_info)
401
+ update_info = jax.tree_map(lambda x: np.array(x), update_info)
402
+ update_info = jax.tree_map(lambda x: x.mean(), update_info)
403
+ train_metrics = {f'training/{k}': v for k, v in update_info.items()}
404
+
405
+ valid_images, valid_labels = shard_data(*next(dataset_valid))
406
+ if FLAGS.model.use_stable_vae and 'latent' not in FLAGS.dataset_name:
407
+ valid_images = vae_encode(vae_rng, valid_images)
408
+ _, valid_update_info = update(train_state, train_state_teacher, valid_images, valid_labels)
409
+ valid_update_info = jax.device_get(valid_update_info)
410
+ valid_update_info = jax.tree_map(lambda x: x.mean(), valid_update_info)
411
+ train_metrics['training/loss_valid'] = valid_update_info['loss']
412
+ train_metrics['training/loss_cosine'] = valid_update_info['cosine_loss']
413
+
414
+ if jax.process_index() == 0:
415
+ wandb.log(train_metrics, step=i)
416
+
417
+ if FLAGS.model['train_type'] == 'progressive':
418
+ num_sections = np.log2(FLAGS.model['denoise_timesteps']).astype(jnp.int32)
419
+ if i % (FLAGS.max_steps // num_sections) == 0:
420
+ train_state_teacher = jax.jit(lambda x : x, out_shardings=train_state_sharding)(train_state)
421
+
422
+ if i % FLAGS.eval_interval == 0:
423
+ eval_model(FLAGS, train_state, train_state_teacher, i, dataset, dataset_valid, shard_data, vae_encode, vae_decode, update,
424
+ get_fid_activations, imagenet_labels, visualize_labels,
425
+ fid_from_stats, truth_fid_stats)
426
+
427
+ if i % FLAGS.save_interval == 0 and FLAGS.save_dir is not None:
428
+ train_state_gather = jax.experimental.multihost_utils.process_allgather(train_state)
429
+ if jax.process_index() == 0:
430
+ cp = Checkpoint(FLAGS.save_dir+str(train_state_gather.step+1), parallel=False)
431
+ cp.train_state = train_state_gather
432
+ cp.save()
433
+ del cp
434
+ del train_state_gather
435
+
436
+ if __name__ == '__main__':
437
+ app.run(main)
438
+
whiten/stable_vae.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from functools import partial, cached_property
3
+
4
+ import jax
5
+ from diffusers import FlaxAutoencoderKL
6
+ from einops import rearrange
7
+ from flax import struct
8
+
9
+ from jaxtyping import Array, PyTree, Key, Float, Shaped, Int, UInt8, jaxtyped
10
+ from typeguard import typechecked
11
+ from functools import partial
12
+ typecheck = partial(jaxtyped, typechecker=typechecked)
13
+
14
+ import jax.numpy as jnp
15
+
16
+ import pickle
17
+ def load_stats(path="stats.pkl"):
18
+ with open(path, "rb") as f:
19
+ return pickle.load(f)
20
+
21
+ try:
22
+ stats = load_stats()#mean, zca
23
+ except:
24
+ pass
25
+ @struct.dataclass
26
+ class StableVAE:
27
+ params: PyTree[Float[Array, "..."]]
28
+ module: FlaxAutoencoderKL = struct.field(pytree_node=False)
29
+
30
+ @classmethod
31
+ def create(cls) -> "VAE":
32
+ # module, params = FlaxAutoencoderKL.from_pretrained(
33
+ # "stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae"
34
+ # )
35
+ module, params = FlaxAutoencoderKL.from_pretrained(
36
+ "pcuenq/sd-vae-ft-mse-flax"
37
+ )
38
+ params = jax.device_get(params)
39
+ return cls(
40
+ params=params,
41
+ module=module,
42
+ )
43
+
44
+ @partial(jax.jit, static_argnames="scale")
45
+ def encode(
46
+ self, key: Key[Array, ""], images: Float[Array, "b h w 3"], scale: bool = True
47
+ ) -> Float[Array, "b lh lw 4"]:
48
+ images = rearrange(images, "b h w c -> b c h w")
49
+ latents = self.module.apply(
50
+ {"params": self.params}, images, method=self.module.encode
51
+ ).latent_dist.sample(key)
52
+
53
+ # return latents
54
+ B, H, W, C = latents.shape
55
+ latents_whitened = jnp.zeros(latents.shape)
56
+ for c in range(C):
57
+ x = latents[:, :, :, c].reshape(B, -1)#We are channels last probably
58
+ mean, zca = stats[c]
59
+
60
+ x_centered = x - mean
61
+ x_whitened = (zca @ x_centered.T).T
62
+ latents_whitened = latents_whitened.at[:, :, :, c].set(x_whitened.reshape(B, H, W))
63
+
64
+ # if scale:
65
+ # latents *= self.module.config.scaling_factor
66
+ return latents_whitened
67
+
68
+ @partial(jax.jit, static_argnames="scale")
69
+ def decode(
70
+ self, latents: Float[Array, "b lh lw 4"], scale: bool = True
71
+ ) -> Float[Array, "b h w 3"]:
72
+ #if scale:
73
+ # latents /= self.module.config.scaling_factor
74
+
75
+ # latents = latents.reshape(1)#256x32x32x4
76
+ #Not sure these latents are correct shape, but whatever
77
+ B, H, W, C = latents.shape
78
+ latents_unwhitened = jnp.zeros(latents.shape)
79
+
80
+ for c in range(C):
81
+ x = latents[:, :, :, c].reshape(B, -1)
82
+ mean, zca = stats[c]
83
+ zca_inv = jnp.linalg.inv(zca)
84
+
85
+ x_unwhitened = (zca_inv @ x.T).T + mean
86
+ latents_unwhitened = latents_unwhitened.at[:, : ,: ,c].set(x_unwhitened.reshape(B,H,W))
87
+
88
+ latents = latents_unwhitened
89
+ #I don't think you need to sample to encode and sample to decode.
90
+ images = self.module.apply(
91
+ {"params": self.params}, latents, method=self.module.decode
92
+ ).sample
93
+
94
+ # convert to channels-last
95
+ #This actually just converts to channels FIRST, which is needed to convert to image
96
+ images = rearrange(images, "b c h w -> b h w c")
97
+ return images
98
+
99
+ @cached_property
100
+ def downscale_factor(self) -> int:
101
+ return 2 ** (len(self.module.block_out_channels) - 1)
whiten/whiten.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import jax
3
+ import jax.numpy as jnp
4
+ import gc
5
+ images = np.load("images.npz")["arr_0"]
6
+ print(images.shape)
7
+ if True:
8
+ batch_shape = images.shape[0] * images.shape[1]
9
+ H, W = images.shape[2], images.shape[3]
10
+ images_white = jnp.zeros(images.shape)
11
+ stats = []
12
+
13
+ for c in range(images.shape[-1]):
14
+ print(c)
15
+ x = images[:,:,:,:,c].reshape(batch_shape, -1)#Get h*w by batch
16
+ print(x.shape)
17
+ mean = x.mean(axis = 0, keepdims = True)
18
+ print(mean.shape)#It's like 1024, because it's reshaped.
19
+ x = x - mean
20
+ cov = x.T @ x / (batch_shape - 1) # shape: (H*W, H*W)
21
+ U, S, _ = jnp.linalg.svd(cov, full_matrices=False)
22
+ S_inv_root = jnp.diag(1.0 / jnp.sqrt(S + 1e-5))
23
+ zca = U @ S_inv_root @ U.T
24
+ del cov
25
+ del U
26
+ del S
27
+ del _
28
+ del S_inv_root
29
+ x = (zca @ x.T).T # shape: (B, H*W)
30
+ gc.collect()
31
+ #images_whitened[:, :, :, :, c] = x.reshape(images.shape[0], images.shape[1],images.shape[2], images.shape[3])
32
+
33
+ #only need mean and zca..
34
+ stats.append((mean, zca)) # Save stats for unwhitening
35
+
36
+ #Now we need to save stats?
37
+ # np.savez("stats.npz", stats)
38
+ import pickle
39
+ with open("stats.pkl","wb") as f:
40
+ pickle.dump(stats, f)
41
+