KublaiKhan1 committed on
Commit
7ad2969
·
verified ·
1 Parent(s): 1703a82

Upload folder using huggingface_hub

1e-6_kl_naive_globalscale_channelmean/log.txt ADDED
The diff for this file is too large to render. See raw diff
 
1e-6_kl_naive_globalscale_channelmean/train.py ADDED
@@ -0,0 +1,504 @@
+ from typing import Any
+ import jax
+ import jax.numpy as jnp
+ import jax.experimental.multihost_utils  # Used below for process_allgather / assert_equal.
+ from absl import app, flags
+ from functools import partial
+ import numpy as np
+ import tqdm
+ import flax
+ import optax
+ import wandb
+ from ml_collections import config_flags
+ import ml_collections
+
+ from PIL import Image
+
+ from utils.wandb import setup_wandb, default_wandb_config
+ from utils.train_state import TrainStateEma
+ from utils.checkpoint import Checkpoint
+ from utils.stable_vae import StableVAE
+ from utils.my_vae import MyVAE
+ from utils.sharding import create_sharding, all_gather
+ from utils.datasets import get_dataset
+ from model import DiT
+ from helper_eval import eval_model
+ from helper_inference import do_inference
+
+ FLAGS = flags.FLAGS
+ flags.DEFINE_string('dataset_name', 'imagenet256', 'Dataset name.')
+ flags.DEFINE_string('load_dir', None, 'Checkpoint dir to load params from (if not None).')
+ flags.DEFINE_string('save_dir', './checkpoints/', 'Checkpoint dir to save params to (if not None).')
+ flags.DEFINE_string('fid_stats', None, 'FID stats file.')
+ flags.DEFINE_integer('seed', 10, 'Random seed.')  # Must be the same across all processes.
+ flags.DEFINE_integer('log_interval', 100, 'Logging interval.')
+ flags.DEFINE_integer('eval_interval', 1000000, 'Eval interval.')
+ flags.DEFINE_integer('save_interval', 10000, 'Save interval.')
+ flags.DEFINE_integer('batch_size', 512, 'Mini batch size.')
+ flags.DEFINE_integer('max_steps', int(500_000), 'Number of training steps.')
+ flags.DEFINE_integer('debug_overfit', 0, 'Debug overfitting.')
+ flags.DEFINE_string('mode', 'train', 'train or inference.')
+
+ model_config = ml_collections.ConfigDict({
+     'lr': 0.0001,
+     'beta1': 0.9,
+     'beta2': 0.999,
+     'weight_decay': 0.1,
+     'use_cosine': 0,
+     'warmup': 0,
+     'dropout': 0.0,
+     'hidden_size': 64,  # change this!
+     'patch_size': 8,  # change this!
+     'depth': 2,  # change this!
+     'num_heads': 2,  # change this!
+     'mlp_ratio': 1,  # change this!
+     'class_dropout_prob': 0.1,
+     'num_classes': 1000,
+     'denoise_timesteps': 128,
+     'cfg_scale': 4.0,
+     'target_update_rate': 0.999,
+     'use_ema': 0,
+     'use_stable_vae': 1,
+     'sharding': 'dp',  # dp or fsdp.
+     't_sampling': 'discrete-dt',
+     'dt_sampling': 'uniform',
+     'bootstrap_cfg': 0,
+     'bootstrap_every': 8,  # Make sure it's a divisor of the batch size.
+     'bootstrap_ema': 1,
+     'bootstrap_dt_bias': 0,
+     'train_type': 'shortcut',  # shortcut, naive, progressive, consistency-distillation, consistency, or livereflow.
+ })
+
+
+ wandb_config = default_wandb_config()
+ wandb_config.update({
+     'project': 'shortcut',
+     'name': 'shortcut_{dataset_name}',
+ })
+
+ config_flags.DEFINE_config_dict('wandb', wandb_config, lock_config=False)
+ config_flags.DEFINE_config_dict('model', model_config, lock_config=False)
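+
+ # Both config dicts are registered as ml_collections config flags, so individual
+ # entries can be overridden at launch, e.g. (hypothetical run):
+ #   python train.py --model.train_type=naive --model.lr=3e-4 --wandb.name=my_run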
+
+ ##############################################
+ ## Training Code.
+ ##############################################
+ def main(_):
+     np.random.seed(FLAGS.seed)
+     print("Using devices", jax.local_devices())
+     device_count = len(jax.local_devices())
+     global_device_count = jax.device_count()
+     print("Device count", device_count)
+     print("Global device count", global_device_count)
+     local_batch_size = FLAGS.batch_size // (global_device_count // device_count)
+     print("Global Batch: ", FLAGS.batch_size)
+     print("Node Batch: ", local_batch_size)
+     print("Device Batch:", local_batch_size // device_count)
+
+     # Create wandb logger.
+     if jax.process_index() == 0 and FLAGS.mode == 'train':
+         setup_wandb(FLAGS.model.to_dict(), **FLAGS.wandb)
+
+     dataset = get_dataset(FLAGS.dataset_name, local_batch_size, True, FLAGS.debug_overfit)
+     dataset_valid = get_dataset(FLAGS.dataset_name, local_batch_size, False, FLAGS.debug_overfit)
+     example_obs, example_labels = next(dataset)
+     test_data = example_obs[:4]
+     example_obs = example_obs[:1]
+
+     example_obs_shape = example_obs.shape
+
+     if FLAGS.model.use_stable_vae:
+         # vae = StableVAE.create()
+         print("creating model")
+         # Create the VAE from an example image batch.
+         vae = MyVAE.create(example_obs)
+         print("model done")
+         if 'latent' in FLAGS.dataset_name:
+             example_obs = example_obs[:, :, :, example_obs.shape[-1] // 2:]
+             example_obs_shape = example_obs.shape
+         else:
+             # Need to expand the obs shape and repeat, since the jitted VAE
+             # expects a leading device axis.
+             pass
+         x = jnp.expand_dims(example_obs, axis=0)
+         # Repeat along the new leading axis to build a small device batch.
+         # x = jnp.repeat(x, repeats=4, axis=1)
+         x = jnp.repeat(x, repeats=4, axis=0)
+         print("Input to vae", x.shape)
+         example_obs, res = vae.encode(x)
+         print("output example shape", example_obs.shape)
+
+         vae_rng = jax.random.PRNGKey(42)
+
+         # How do we do this?
+         # vae_encode = jax.jit(vae.encode)
+         # vae_decode = jax.jit(vae.decode)
+         vae_encode = vae.encode
+         vae_decode = vae.decode
+
+         print("Test data shape", test_data.shape)  # (4, 256, 256, 3)?
+         # Save the first image as a sanity check.
+         first = test_data[0]
+         image = (first * 255).astype(np.uint8)
+         image = np.array(image)
+         img = Image.fromarray(image)
+         img.save("testimg.png")
+
+         # Needs expansion to a leading axis.
+         x = jnp.expand_dims(test_data, axis=0)
+         # So now we are (1, 4, 256, 256, 3); swap to (4, 1, 256, 256, 3).
+         x = jnp.swapaxes(x, 0, 1)
+         print("x shape", x.shape)
+         encoded, res = vae_encode(x)
+         print("encoded shape", encoded.shape)
+         # It's possible we want to compress this.
+         decoded = vae_decode(encoded)
+         print("image shape", decoded.shape)
+         # Encode, decode, and log a round-trip reconstruction.
+         decoded_img = decoded[0][0]
+         print("decoded img shape", decoded_img.shape)
+
+         image = (decoded_img * 255).astype(np.uint8)
+         image = np.array(image)
+         img = Image.fromarray(image)
+         img.save("decodedimg.png")
+
+         # Recompute the example shape from the encoded latents.
+         example_obs = example_obs.squeeze()
+         print("obs shape", example_obs.shape)
+         example_obs_shape = example_obs.shape
+     if FLAGS.fid_stats is not None:
+         from utils.fid import get_fid_network, fid_from_stats
+         get_fid_activations = get_fid_network()
+         truth_fid_stats = np.load(FLAGS.fid_stats)
+     else:
+         get_fid_activations = None
+         truth_fid_stats = None
+         fid_from_stats = None  # Otherwise undefined when passed to eval/inference below.
+
+     ###################################
+     # Create the model and put it on devices.
+     ###################################
+     FLAGS.model.image_channels = example_obs_shape[-1]
+     FLAGS.model.image_size = example_obs_shape[1]
+     dit_args = {
+         'patch_size': FLAGS.model['patch_size'],
+         'hidden_size': FLAGS.model['hidden_size'],
+         'depth': FLAGS.model['depth'],
+         'num_heads': FLAGS.model['num_heads'],
+         'mlp_ratio': FLAGS.model['mlp_ratio'],
+         'out_channels': example_obs_shape[-1],
+         'class_dropout_prob': FLAGS.model['class_dropout_prob'],
+         'num_classes': FLAGS.model['num_classes'],
+         'dropout': FLAGS.model['dropout'],
+         'ignore_dt': FLAGS.model['train_type'] not in ('shortcut', 'livereflow'),
+     }
+     model_def = DiT(**dit_args)
+     tabulate_fn = flax.linen.tabulate(model_def, jax.random.PRNGKey(0))
+     print(tabulate_fn(example_obs, jnp.zeros((1,)), jnp.zeros((1,)), jnp.zeros((1,), dtype=jnp.int32)))
+
+     if FLAGS.model.use_cosine:
+         lr_schedule = optax.warmup_cosine_decay_schedule(0.0, FLAGS.model['lr'], FLAGS.model['warmup'], FLAGS.max_steps)
+     elif FLAGS.model.warmup > 0:
+         lr_schedule = optax.linear_schedule(0.0, FLAGS.model['lr'], FLAGS.model['warmup'])
+     else:
+         lr_schedule = lambda x: FLAGS.model['lr']
+     adam = optax.adamw(learning_rate=lr_schedule, b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'], weight_decay=FLAGS.model['weight_decay'])
+     tx = optax.chain(adam)
+     start_step = 1
+
+     def log_param_shapes(params, label=""):
+         # Print every parameter shape and return a copy with leading size-1
+         # axes squeezed out.
+         flat = flax.traverse_util.flatten_dict(params)
+         squeezed_flat = {k: jnp.squeeze(v, axis=0) for k, v in flat.items() if v.shape[0] == 1}
+         print(f"\n{label} parameter shapes:")
+         for k, v in flat.items():
+             print(f"{k}: {v.shape}")
+         return flax.traverse_util.unflatten_dict(squeezed_flat)
+
+     def init(rng):
+         param_key, dropout_key, dropout2_key = jax.random.split(rng, 3)
+         example_t = jnp.zeros((1,))
+         example_dt = jnp.zeros((1,))
+         example_label = jnp.zeros((1,), dtype=jnp.int32)
+         example_obs = jnp.zeros(example_obs_shape)
+         model_rngs = {'params': param_key, 'label_dropout': dropout_key, 'dropout': dropout2_key}
+         params = model_def.init(model_rngs, example_obs, example_t, example_dt, example_label)['params']
+         opt_state = tx.init(params)
+         ts = TrainStateEma.create(model_def, params, rng=rng, tx=tx, opt_state=opt_state)
+
+         if FLAGS.load_dir is not None:
+             cp = Checkpoint(FLAGS.load_dir)
+             train_state_load = cp.load_as_dict()["train_state"]
+             start_step = train_state_load["step"]  # Local to init; the outer start_step is not updated.
+
+             log_param_shapes(ts.params)
+             flat = log_param_shapes(train_state_load["params"])
+             flat_ema = log_param_shapes(train_state_load["params_ema"])
+             flat_mu = log_param_shapes(train_state_load["opt_state"][0][0].mu)
+             flat_nu = log_param_shapes(train_state_load["opt_state"][0][0].nu)
+
+             # Rebuild the Adam state around the squeezed mu/nu pytrees.
+             from optax import ScaleByAdamState
+             opt_state = train_state_load["opt_state"]
+             new_state = ScaleByAdamState(
+                 opt_state[0][0].count,
+                 mu=flat_mu,
+                 nu=flat_nu,
+             )
+             opt_state = list(opt_state)
+             opt_state[0] = list(opt_state[0])
+             opt_state[0][0] = new_state
+             opt_state[0] = tuple(opt_state[0])
+             opt_state = tuple(opt_state)
+
+             train_state_load = TrainStateEma.create(model_def, params=flat, rng=rng, tx=tx, opt_state=opt_state)
+             train_state_load = train_state_load.replace(step=start_step)
+             # Need to replace the EMA params because we keep a separate EMA copy.
+             log_param_shapes(train_state_load.params)
+             train_state_load = train_state_load.replace(params_ema=flat_ema)
+             ts = train_state_load
+
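+         # Note on the squeezes above (inferred from the save path at the bottom
+         # of this file, not guaranteed): checkpoints are gathered with
+         # process_allgather before saving, which can stack a leading axis onto
+         # every leaf, e.g. (1, 64, 256) instead of (64, 256). log_param_shapes
+         # drops that axis so the loaded leaves match the freshly initialized model.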
+         return ts
+
+     rng = jax.random.PRNGKey(FLAGS.seed)
+     train_state_shape = jax.eval_shape(init, rng)
+
+     data_sharding, train_state_sharding, no_shard, shard_data, global_to_local = create_sharding(FLAGS.model.sharding, train_state_shape)
+     train_state = jax.jit(init, out_shardings=train_state_sharding)(rng)
+
+     # We can only visualize here if the params are squeezed... which might cause errors later?
+     jax.debug.visualize_array_sharding(train_state.params['FinalLayer_0']['Dense_0']['kernel'])
+     jax.debug.visualize_array_sharding(train_state.params['TimestepEmbedder_1']['Dense_0']['kernel'])
+     jax.experimental.multihost_utils.assert_equal(train_state.params['TimestepEmbedder_1']['Dense_0']['kernel'])
+
+
+     if False:  # FLAGS.load_dir is not None:  (disabled alternative load path)
+         cp = Checkpoint(FLAGS.load_dir)
+         replace_dict = cp.load_as_dict()['train_state']
+
+         log_param_shapes(train_state.params, "Before load")
+         train_state = train_state.replace(**replace_dict)
+
+         flat = log_param_shapes(train_state.params, "Before squeeze")
+         train_state = train_state.replace(params=flat)
+         log_param_shapes(flat, "after squeeze")
+
+         flat_ema = log_param_shapes(train_state.params_ema, "before ema")
+         train_state = train_state.replace(params_ema=flat_ema)
+         log_param_shapes(flat_ema, "after squeeze")
+         print(train_state.step)
+         exit()
+
+         if FLAGS.wandb.run_id != "None":  # If we are continuing a run.
+             start_step = train_state.step
+             train_state = jax.jit(lambda x: x, out_shardings=train_state_sharding)(train_state)
+             print("Loaded model with step", train_state.step)
+
+         train_state = train_state.replace(step=0)
+         jax.debug.visualize_array_sharding(train_state.params['FinalLayer_0']['Dense_0']['kernel'])
+         del cp
+
+     if FLAGS.model.train_type in ('progressive', 'consistency-distillation'):
+         train_state_teacher = jax.jit(lambda x: x, out_shardings=train_state_sharding)(train_state)
+     else:
+         train_state_teacher = None
+
+     visualize_labels = example_labels
+     visualize_labels = shard_data(visualize_labels)
+     visualize_labels = jax.experimental.multihost_utils.process_allgather(visualize_labels)
+     imagenet_labels = open('data/imagenet_labels.txt').read().splitlines()
+
+     ###################################
+     # Update Function
+     ###################################
+
+     @partial(jax.jit, out_shardings=(train_state_sharding, no_shard))
+     def update(train_state, train_state_teacher, images, labels, force_t=-1, force_dt=-1):
+         new_rng, targets_key, dropout_key, perm_key = jax.random.split(train_state.rng, 4)
+         info = {}
+
+         # Randomly permute the batch.
+         id_perm = jax.random.permutation(perm_key, images.shape[0])
+         images = images[id_perm]
+         labels = labels[id_perm]
+         images = jax.lax.with_sharding_constraint(images, data_sharding)
+         labels = jax.lax.with_sharding_constraint(labels, data_sharding)
+
+         if FLAGS.model['cfg_scale'] == 0:  # For unconditional generation.
+             labels = jnp.ones(labels.shape[0], dtype=jnp.int32) * FLAGS.model['num_classes']
+
+         if FLAGS.model['train_type'] == 'naive':
+             from baselines.targets_naive import get_targets
+             x_t, v_t, t, dt_base, labels, info = get_targets(FLAGS, targets_key, train_state, images, labels, force_t, force_dt)
+         elif FLAGS.model['train_type'] == 'shortcut':
+             from targets_shortcut import get_targets
+             x_t, v_t, t, dt_base, labels, info = get_targets(FLAGS, targets_key, train_state, images, labels, force_t, force_dt)
+         elif FLAGS.model['train_type'] == 'progressive':
+             from baselines.targets_progressive import get_targets
+             x_t, v_t, t, dt_base, labels, info = get_targets(FLAGS, targets_key, train_state, train_state_teacher, images, labels, force_t, force_dt)
+         elif FLAGS.model['train_type'] == 'consistency-distillation':
+             from baselines.targets_consistency_distillation import get_targets
+             x_t, v_t, t, dt_base, labels, info = get_targets(FLAGS, targets_key, train_state, train_state_teacher, images, labels, force_t, force_dt)
+         elif FLAGS.model['train_type'] == 'consistency':
+             from baselines.targets_consistency_training import get_targets
+             x_t, v_t, t, dt_base, labels, info = get_targets(FLAGS, targets_key, train_state, images, labels, force_t, force_dt)
+         elif FLAGS.model['train_type'] == 'livereflow':
+             from baselines.targets_livereflow import get_targets
+             x_t, v_t, t, dt_base, labels, info = get_targets(FLAGS, targets_key, train_state, images, labels, force_t, force_dt)
+
+         def loss_fn(grad_params):
+             v_prime, logvars, activations = train_state.call_model(x_t, t, dt_base, labels, train=True, rngs={'dropout': dropout_key}, params=grad_params, return_activations=True)
+             mse_v = jnp.mean((v_prime - v_t) ** 2, axis=(1, 2, 3))
+             loss = jnp.mean(mse_v)
+
+             info = {
+                 'loss': loss,
+                 'v_magnitude_prime': jnp.sqrt(jnp.mean(jnp.square(v_prime))),
+                 **{'activations/' + k: jnp.sqrt(jnp.mean(jnp.square(v))) for k, v in activations.items()},
+             }
+
+             if FLAGS.model['train_type'] in ('shortcut', 'livereflow'):
+                 # The first bootstrap_size examples carry bootstrap targets; the rest are flow targets.
+                 bootstrap_size = FLAGS.batch_size // FLAGS.model['bootstrap_every']
+                 info['loss_flow'] = jnp.mean(mse_v[bootstrap_size:])
+                 info['loss_bootstrap'] = jnp.mean(mse_v[:bootstrap_size])
+
+             return loss, info
+
+         grads, new_info = jax.grad(loss_fn, has_aux=True)(train_state.params)
+         info = {**info, **new_info}
+         updates, new_opt_state = train_state.tx.update(grads, train_state.opt_state, train_state.params)
+         new_params = optax.apply_updates(train_state.params, updates)
+
+         info['grad_norm'] = optax.global_norm(grads)
+         info['update_norm'] = optax.global_norm(updates)
+         info['param_norm'] = optax.global_norm(new_params)
+         info['lr'] = lr_schedule(train_state.step)
+
+         train_state = train_state.replace(rng=new_rng, step=train_state.step + 1, params=new_params, opt_state=new_opt_state)
+         train_state = train_state.update_ema(FLAGS.model['target_update_rate'])
+         return train_state, info
+
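+     # The get_targets helpers above live in separate modules and are not part of
+     # this diff. For orientation, a standard linear-interpolation flow-matching
+     # target (the construction the 'naive' path is named after) can be sketched
+     # as below. Illustrative and unused; not the repository's actual targets_naive.
+     def _naive_targets_sketch(key, images):
+         noise_key, t_key = jax.random.split(key)
+         x1 = images                                  # data sample
+         x0 = jax.random.normal(noise_key, x1.shape)  # Gaussian noise
+         t = jax.random.uniform(t_key, (x1.shape[0],))
+         tb = t[:, None, None, None]                  # broadcast t over H, W, C
+         x_t = (1 - tb) * x0 + tb * x1                # interpolant at time t
+         v_t = x1 - x0                                # velocity target
+         return x_t, v_t, t
+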
+     if FLAGS.mode != 'train':
+         print("doing the else")
+         # Sweep cfg scales and step counts for inference; 1.5 is already done.
+         cfgs = [1.0, 1.25, 1.5, 1.75, 2.0, 2.25, 2.5, 2.75, 3.0]
+         steps = [128, 64, 32, 16, 8, 4, 2, 1]
+         for cfg in cfgs:
+             for step in steps:
+                 FLAGS.inference_timesteps = step
+                 FLAGS.inference_cfg_scale = cfg
+                 do_inference(FLAGS, train_state, None, dataset, dataset_valid, shard_data, vae_encode, vae_decode, update,
+                              get_fid_activations, imagenet_labels, visualize_labels,
+                              fid_from_stats, truth_fid_stats)
+         exit()
+
+         return
+
+     ###################################
+     # Train Loop
+     ###################################
+
+     for i in tqdm.tqdm(range(1 + start_step, FLAGS.max_steps + 1 + start_step),
+                        smoothing=0.1,
+                        dynamic_ncols=True):
+
+         # Sample data.
+         if not FLAGS.debug_overfit or i == 1:
+             batch_images, batch_labels = shard_data(*next(dataset))
+             if FLAGS.model.use_stable_vae and 'latent' not in FLAGS.dataset_name:
+                 vae_rng, vae_key = jax.random.split(vae_rng)
+                 # Split the batch over local devices for the jitted encoder.
+                 batch_images_reshaped = batch_images.reshape((len(jax.local_devices()), -1, *batch_images.shape[1:]))
+                 batch_images_reshaped, result_dict = vae_encode(batch_images_reshaped)  # e.g. (4, 128, 32, 32, 4)
+                 batch_images = batch_images_reshaped.reshape(-1, *batch_images_reshaped.shape[2:])
+                 # We don't sample from the VAE right now, so the key goes unused.
+
+                 # Normalize globally.
+                 mean = jnp.array([0.04621413, 0.00622245, -0.03867066, -0.12760854])
+                 std = jnp.array([1.1124766, 1.1514145, 1.1221403, 1.0895475])
+                 batch_images = (batch_images - mean) / std.mean()
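+                 # The normalization above is the "globalscale_channelmean" scheme
+                 # in the run folder name: each latent channel is centered by its
+                 # own mean, but all channels are divided by one shared scale,
+                 # std.mean(), preserving relative channel magnitudes. Presumably
+                 # sampling applies the inverse, x * std.mean() + mean, before
+                 # decoding; that code is outside this file.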
+
+         # Train update.
+         train_state, update_info = update(train_state, train_state_teacher, batch_images, batch_labels)
+
+         if i % FLAGS.log_interval == 0 or i == 1:
+             update_info = jax.device_get(update_info)
+             update_info = jax.tree_util.tree_map(lambda x: np.array(x), update_info)
+             update_info = jax.tree_util.tree_map(lambda x: x.mean(), update_info)
+             train_metrics = {f'training/{k}': v for k, v in update_info.items()}
+
+             # Also measure the loss on a validation batch.
+             valid_images, valid_labels = shard_data(*next(dataset_valid))
+             if FLAGS.model.use_stable_vae and 'latent' not in FLAGS.dataset_name:
+                 valid_images_reshaped = valid_images.reshape((len(jax.local_devices()), -1, *valid_images.shape[1:]))
+                 valid_images_reshaped, result_dict = vae_encode(valid_images_reshaped)
+                 valid_images = valid_images_reshaped.reshape(-1, *valid_images_reshaped.shape[2:])
+
+                 # Same global normalization as the training batch.
+                 mean = jnp.array([0.04621413, 0.00622245, -0.03867066, -0.12760854])
+                 std = jnp.array([1.1124766, 1.1514145, 1.1221403, 1.0895475])
+                 valid_images = (valid_images - mean) / std.mean()
+
+             _, valid_update_info = update(train_state, train_state_teacher, valid_images, valid_labels)
+             valid_update_info = jax.device_get(valid_update_info)
+             valid_update_info = jax.tree_util.tree_map(lambda x: x.mean(), valid_update_info)
+             train_metrics['training/loss_valid'] = valid_update_info['loss']
+
+             if jax.process_index() == 0:
+                 print(train_metrics)
+                 wandb.log(train_metrics, step=i)
+
+         if FLAGS.model['train_type'] == 'progressive':
+             # Periodically refresh the teacher with the current student.
+             num_sections = np.log2(FLAGS.model['denoise_timesteps']).astype(jnp.int32)
+             if i % (FLAGS.max_steps // num_sections) == 0:
+                 train_state_teacher = jax.jit(lambda x: x, out_shardings=train_state_sharding)(train_state)
+
+         if i % FLAGS.eval_interval == 0:
+             eval_model(FLAGS, train_state, train_state_teacher, i, dataset, dataset_valid, shard_data, vae_encode, vae_decode, update,
+                        get_fid_activations, imagenet_labels, visualize_labels,
+                        fid_from_stats, truth_fid_stats)
+
+         if i % FLAGS.save_interval == 0 and FLAGS.save_dir is not None:
+             train_state_gather = jax.experimental.multihost_utils.process_allgather(train_state)
+             if jax.process_index() == 0:
+                 cp = Checkpoint(FLAGS.save_dir + str(train_state_gather.step + 1), parallel=False)
+                 cp.train_state = train_state_gather
+                 cp.save()
+                 del cp
+             del train_state_gather
+
+
+ if __name__ == '__main__':
+     app.run(main)