KublaiKhan1 committed on
Commit c449eaa · verified · 1 Parent(s): 7ecf49d

Upload folder using huggingface_hub

Files changed (1)
  1. learned_cfg/train.py +386 -0
learned_cfg/train.py ADDED
@@ -0,0 +1,386 @@
+ from typing import Any
+ import jax.numpy as jnp
+ from absl import app, flags
+ from functools import partial
+ import numpy as np
+ import tqdm
+ import jax
+ import jax.experimental.multihost_utils  # needed below for process_allgather / assert_equal
+ import flax
+ import optax
+ import wandb
+ from ml_collections import config_flags
+ import ml_collections
+
+ from utils.wandb import setup_wandb, default_wandb_config
+ from utils.train_state import TrainStateEma
+ from utils.checkpoint import Checkpoint
+ from utils.stable_vae import StableVAE
+ from utils.sharding import create_sharding, all_gather
+ from utils.datasets import get_dataset
+ from model import DiT
+ from helper_eval import eval_model
+ from helper_inference import do_inference
+
+ FLAGS = flags.FLAGS
+ flags.DEFINE_string('dataset_name', 'imagenet256', 'Dataset name.')
+ flags.DEFINE_string('load_dir', None, 'Checkpoint dir to load from (if not None, load params).')
+ flags.DEFINE_string('save_dir', './checkpoints/', 'Checkpoint dir (if not None, save params).')
+ flags.DEFINE_string('fid_stats', None, 'FID stats file.')
+ flags.DEFINE_integer('seed', 10, 'Random seed.')  # Must be the same across all processes.
+ flags.DEFINE_integer('log_interval', 1000, 'Logging interval.')
+ flags.DEFINE_integer('eval_interval', 1000000, 'Eval interval.')
+ flags.DEFINE_integer('save_interval', 10000, 'Save interval.')
+ flags.DEFINE_integer('batch_size', 512, 'Global mini-batch size.')
+ flags.DEFINE_integer('max_steps', int(500_000), 'Number of training steps.')
+ flags.DEFINE_integer('debug_overfit', 0, 'Debug overfitting.')
+ flags.DEFINE_string('mode', 'train', 'train or inference.')
+
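+ # Example invocation (a sketch; the model overrides below are illustrative placeholders,
+ # passed through the ml_collections config_flags defined further down):
+ #   python learned_cfg/train.py --dataset_name=imagenet256 --batch_size=512 \
+ #     --model.hidden_size=768 --model.depth=12 --model.num_heads=12 --model.patch_size=2
+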
+ model_config = ml_collections.ConfigDict({
+     'lr': 0.0001,
+     'beta1': 0.9,
+     'beta2': 0.999,
+     'weight_decay': 0.1,
+     'use_cosine': 0,
+     'warmup': 0,
+     'dropout': 0.0,
+     'hidden_size': 64,  # change this!
+     'patch_size': 8,  # change this!
+     'depth': 2,  # change this!
+     'num_heads': 2,  # change this!
+     'mlp_ratio': 1,  # change this!
+     'class_dropout_prob': 0.1,
+     'num_classes': 1000,
+     'denoise_timesteps': 128,
+     'cfg_scale': 4.0,
+     'target_update_rate': 0.999,
+     'use_ema': 0,
+     'use_stable_vae': 1,
+     'sharding': 'dp',  # dp or fsdp.
+     't_sampling': 'discrete-dt',
+     'dt_sampling': 'uniform',
+     'bootstrap_cfg': 0,
+     'bootstrap_every': 8,  # Make sure it's a divisor of batch size.
+     'bootstrap_ema': 1,
+     'bootstrap_dt_bias': 0,
+     'train_type': 'shortcut',  # shortcut, naive, progressive, consistency-distillation, consistency, or livereflow.
+ })
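+ # Note: the hidden_size / patch_size / depth / num_heads / mlp_ratio defaults above are
+ # debug-sized (hence the "change this!" notes); override them via --model.<key> for real runs.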
+
+
+ wandb_config = default_wandb_config()
+ wandb_config.update({
+     'project': 'shortcut',
+     'name': 'shortcut_{dataset_name}',
+ })
+
+ config_flags.DEFINE_config_dict('wandb', wandb_config, lock_config=False)
+ config_flags.DEFINE_config_dict('model', model_config, lock_config=False)
+
+ ##############################################
+ ## Training Code.
+ ##############################################
+ def main(_):
+
+     np.random.seed(FLAGS.seed)
+     print("Using devices", jax.local_devices())
+     device_count = len(jax.local_devices())
+     global_device_count = jax.device_count()
+     print("Device count", device_count)
+     print("Global device count", global_device_count)
+     local_batch_size = FLAGS.batch_size // (global_device_count // device_count)
+     print("Global Batch: ", FLAGS.batch_size)
+     print("Node Batch: ", local_batch_size)
+     print("Device Batch:", local_batch_size // device_count)
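+     # e.g. one host with 8 devices and batch_size=512: node batch 512, per-device batch 64;
+     # four hosts of 8 devices each: node batch 128, per-device batch 16.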
+
+     # Create wandb logger
+     if jax.process_index() == 0 and FLAGS.mode == 'train':
+         setup_wandb(FLAGS.model.to_dict(), **FLAGS.wandb)
+
+     dataset = get_dataset(FLAGS.dataset_name, local_batch_size, True, FLAGS.debug_overfit)
+     dataset_valid = get_dataset(FLAGS.dataset_name, int(local_batch_size / 8), False, FLAGS.debug_overfit)
+     example_obs, example_labels = next(dataset)
+     example_obs = example_obs[:1]
+     example_obs_shape = example_obs.shape
+
+     if FLAGS.model.use_stable_vae:
+         vae = StableVAE.create()
+         if 'latent' in FLAGS.dataset_name:
+             example_obs = example_obs[:, :, :, example_obs.shape[-1] // 2:]
+             example_obs_shape = example_obs.shape
+         else:
+             example_obs = vae.encode(jax.random.PRNGKey(0), example_obs)
+             example_obs_shape = example_obs.shape
+         vae_rng = jax.random.PRNGKey(42)
+         vae_encode = jax.jit(vae.encode)
+         vae_decode = jax.jit(vae.decode)
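+         # (Assumption: 'latent' datasets provide pre-computed VAE latents, and the channel
+         # slice above keeps the latent half; otherwise raw images are VAE-encoded on the fly.)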
+
+     if FLAGS.fid_stats is not None:
+         from utils.fid import get_fid_network, fid_from_stats
+         get_fid_activations = get_fid_network()
+         truth_fid_stats = np.load(FLAGS.fid_stats)
+     else:
+         get_fid_activations = None
+         truth_fid_stats = None
+         fid_from_stats = None  # keep the name defined; it is passed to eval/inference below.
+
+     ###################################
+     # Create the model and put it on devices.
+     ###################################
+     FLAGS.model.image_channels = example_obs_shape[-1]
+     FLAGS.model.image_size = example_obs_shape[1]
+     dit_args = {
+         'patch_size': FLAGS.model['patch_size'],
+         'hidden_size': FLAGS.model['hidden_size'],
+         'depth': FLAGS.model['depth'],
+         'num_heads': FLAGS.model['num_heads'],
+         'mlp_ratio': FLAGS.model['mlp_ratio'],
+         'out_channels': example_obs_shape[-1],
+         'class_dropout_prob': FLAGS.model['class_dropout_prob'],
+         'num_classes': FLAGS.model['num_classes'],
+         'dropout': FLAGS.model['dropout'],
+         'ignore_dt': FLAGS.model['train_type'] not in ('shortcut', 'livereflow'),
+     }
+     model_def = DiT(**dit_args)
+     tabulate_fn = flax.linen.tabulate(model_def, jax.random.PRNGKey(0))
+     print(tabulate_fn(example_obs, jnp.zeros((1,)), jnp.zeros((1,)), jnp.zeros((1,), dtype=jnp.int32)))
+
+     if FLAGS.model.use_cosine:
+         lr_schedule = optax.warmup_cosine_decay_schedule(0.0, FLAGS.model['lr'], FLAGS.model['warmup'], FLAGS.max_steps)
+     elif FLAGS.model.warmup > 0:
+         lr_schedule = optax.linear_schedule(0.0, FLAGS.model['lr'], FLAGS.model['warmup'])
+     else:
+         lr_schedule = lambda x: FLAGS.model['lr']
+     adam = optax.adamw(learning_rate=lr_schedule, b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'], weight_decay=FLAGS.model['weight_decay'])
+     tx = optax.chain(adam)
+
+     def log_param_shapes(params, label=""):
+         # Prints every parameter shape and returns a copy with any leading singleton axis
+         # squeezed out (presumably because gathered checkpoints carry an extra leading axis).
+         flat = flax.traverse_util.flatten_dict(params)
+         squeezed_flat = {k: jnp.squeeze(v, axis=0) for k, v in flat.items() if v.shape[0] == 1}
+         print(f"\n{label} parameter shapes:")
+         for k, v in flat.items():
+             print(f"{k}: {v.shape}")
+         return flax.traverse_util.unflatten_dict(squeezed_flat)
+
+
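+     # init() is used twice below: under jax.eval_shape to get the train-state pytree
+     # structure for sharding, and under jax.jit(..., out_shardings=...) to build the
+     # actual (possibly checkpoint-initialized) sharded state.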
+     def init(rng):
+         param_key, dropout_key, dropout2_key = jax.random.split(rng, 3)
+         example_t = jnp.zeros((1,))
+         example_dt = jnp.zeros((1,))
+         example_label = jnp.zeros((1,), dtype=jnp.int32)
+         example_obs = jnp.zeros(example_obs_shape)
+         model_rngs = {'params': param_key, 'label_dropout': dropout_key, 'dropout': dropout2_key}
+         params = model_def.init(model_rngs, example_obs, example_t, example_dt, example_label)['params']
+         opt_state = tx.init(params)
+         ts = TrainStateEma.create(model_def, params, rng=rng, tx=tx, opt_state=opt_state)
+
+         if FLAGS.load_dir is not None:
+             cp = Checkpoint(FLAGS.load_dir)
+             train_state_load = cp.load_as_dict()["train_state"]
+
+             log_param_shapes(ts.params)
+             flat = log_param_shapes(train_state_load["params"])
+             flat_ema = log_param_shapes(train_state_load["params_ema"])
+             flat_mu = log_param_shapes(train_state_load["opt_state"][0][0].mu)
+             flat_nu = log_param_shapes(train_state_load["opt_state"][0][0].nu)
+
+             # Rebuild the Adam state around the squeezed mu/nu pytrees.
+             from optax import ScaleByAdamState
+             opt_state = train_state_load["opt_state"]
+             new_state = ScaleByAdamState(
+                 opt_state[0][0].count,
+                 mu=flat_mu,
+                 nu=flat_nu
+             )
+             opt_state = list(opt_state)
+             opt_state[0] = list(opt_state[0])
+             opt_state[0][0] = new_state
+             opt_state[0] = tuple(opt_state[0])
+             opt_state = tuple(opt_state)
+
+             train_state_load = TrainStateEma.create(model_def, params=flat, rng=rng, tx=tx, opt_state=opt_state)
+
+             # Need to replace the EMA params separately, since we keep a separate EMA copy.
+             log_param_shapes(train_state_load.params)
+             train_state_load = train_state_load.replace(params_ema=flat_ema)
+
+             start_step = train_state_load.step  # (local; main sets start_step separately below)
+
+             ts = train_state_load
+
+         return ts
+
+     rng = jax.random.PRNGKey(FLAGS.seed)
+     train_state_shape = jax.eval_shape(init, rng)
+
+     data_sharding, train_state_sharding, no_shard, shard_data, global_to_local = create_sharding(FLAGS.model.sharding, train_state_shape)
+     train_state = jax.jit(init, out_shardings=train_state_sharding)(rng)
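+     # Sanity checks: visualize how two representative parameters are sharded, and assert
+     # that they are identical across hosts.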
+     jax.debug.visualize_array_sharding(train_state.params['FinalLayer_0']['Dense_0']['kernel'])
+     jax.debug.visualize_array_sharding(train_state.params['TimestepEmbedder_1']['Dense_0']['kernel'])
+     jax.experimental.multihost_utils.assert_equal(train_state.params['TimestepEmbedder_1']['Dense_0']['kernel'])
+     start_step = 1
+
+     if False:  # FLAGS.load_dir is not None: (disabled; loading is handled in init() above)
+         cp = Checkpoint(FLAGS.load_dir)
+         replace_dict = cp.load_as_dict()['train_state']
+         del replace_dict['opt_state']  # Debug
+         train_state = train_state.replace(**replace_dict)
+         if FLAGS.wandb.run_id != "None":  # If we are continuing a run.
+             start_step = train_state.step
+         train_state = jax.jit(lambda x: x, out_shardings=train_state_sharding)(train_state)
+         print("Loaded model with step", train_state.step)
+         train_state = train_state.replace(step=0)
+         jax.debug.visualize_array_sharding(train_state.params['FinalLayer_0']['Dense_0']['kernel'])
+         del cp
+
+     if FLAGS.model.train_type == 'progressive' or FLAGS.model.train_type == 'consistency-distillation':
+         train_state_teacher = jax.jit(lambda x: x, out_shardings=train_state_sharding)(train_state)
+     else:
+         train_state_teacher = None
+
+     visualize_labels = example_labels
+     visualize_labels = shard_data(visualize_labels)
+     visualize_labels = jax.experimental.multihost_utils.process_allgather(visualize_labels)
+     imagenet_labels = open('data/imagenet_labels.txt').read().splitlines()
+
+     ###################################
+     # Update Function
+     ###################################
+
+     @partial(jax.jit, out_shardings=(train_state_sharding, no_shard))
+     def update(train_state, train_state_teacher, images, labels, force_t=-1, force_dt=-1):
+         new_rng, targets_key, dropout_key, perm_key = jax.random.split(train_state.rng, 4)
+         info = {}
+
+         id_perm = jax.random.permutation(perm_key, images.shape[0])
+         images = images[id_perm]
+         labels = labels[id_perm]
+         images = jax.lax.with_sharding_constraint(images, data_sharding)
+         labels = jax.lax.with_sharding_constraint(labels, data_sharding)
+
+         # The CFG scale is a learned model parameter ('cfg_weight'), read from the current
+         # params so that the targets below track it as it changes during training.
+         cfg_scale = train_state.params["cfg_weight"]
+         if FLAGS.model['cfg_scale'] == 0:  # For unconditional generation.
+             labels = jnp.ones(labels.shape[0], dtype=jnp.int32) * FLAGS.model['num_classes']
+
+         if FLAGS.model['train_type'] == 'naive':
+             from baselines.targets_naive import get_targets
+             x_t, v_t, t, dt_base, labels, info = get_targets(FLAGS, targets_key, train_state, images, labels, force_t, force_dt)
+         elif FLAGS.model['train_type'] == 'shortcut':
+             from targets_shortcut import get_targets
+             x_t, v_t, t, dt_base, labels, info = get_targets(FLAGS, targets_key, train_state, images, labels, force_t, force_dt, cfg_scale)
+         elif FLAGS.model['train_type'] == 'progressive':
+             from baselines.targets_progressive import get_targets
+             x_t, v_t, t, dt_base, labels, info = get_targets(FLAGS, targets_key, train_state, train_state_teacher, images, labels, force_t, force_dt)
+         elif FLAGS.model['train_type'] == 'consistency-distillation':
+             from baselines.targets_consistency_distillation import get_targets
+             x_t, v_t, t, dt_base, labels, info = get_targets(FLAGS, targets_key, train_state, train_state_teacher, images, labels, force_t, force_dt)
+         elif FLAGS.model['train_type'] == 'consistency':
+             from baselines.targets_consistency_training import get_targets
+             x_t, v_t, t, dt_base, labels, info = get_targets(FLAGS, targets_key, train_state, images, labels, force_t, force_dt)
+         elif FLAGS.model['train_type'] == 'livereflow':
+             from baselines.targets_livereflow import get_targets
+             x_t, v_t, t, dt_base, labels, info = get_targets(FLAGS, targets_key, train_state, images, labels, force_t, force_dt)
+
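+         # Each get_targets variant returns noisy inputs x_t, target velocities v_t, times t,
+         # step sizes dt_base, labels (possibly modified), and a dict of logging info.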
+         def loss_fn(grad_params):
+             v_prime, logvars, activations = train_state.call_model(x_t, t, dt_base, labels, train=True, rngs={'dropout': dropout_key}, params=grad_params, return_activations=True)
+             mse_v = jnp.mean((v_prime - v_t) ** 2, axis=(1, 2, 3))
+             loss = jnp.mean(mse_v)
+
+             info = {
+                 'loss': loss,
+                 'cfg_scale': cfg_scale,
+                 'v_magnitude_prime': jnp.sqrt(jnp.mean(jnp.square(v_prime))),
+                 **{'activations/' + k: jnp.sqrt(jnp.mean(jnp.square(v))) for k, v in activations.items()},
+             }
+
+             if FLAGS.model['train_type'] == 'shortcut' or FLAGS.model['train_type'] == 'livereflow':
+                 bootstrap_size = FLAGS.batch_size // FLAGS.model['bootstrap_every']
+                 info['loss_flow'] = jnp.mean(mse_v[bootstrap_size:])
+                 info['loss_bootstrap'] = jnp.mean(mse_v[:bootstrap_size])
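+                 # (Assumes the targets place the bootstrap / self-consistency examples in the
+                 # first batch_size // bootstrap_every slots and flow-matching examples after.)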
+
+             return loss, info
+
+         grads, new_info = jax.grad(loss_fn, has_aux=True)(train_state.params)
+         info = {**info, **new_info}
+         updates, new_opt_state = train_state.tx.update(grads, train_state.opt_state, train_state.params)
+         new_params = optax.apply_updates(train_state.params, updates)
+
+         info['grad_norm'] = optax.global_norm(grads)
+         info['update_norm'] = optax.global_norm(updates)
+         info['param_norm'] = optax.global_norm(new_params)
+         info['lr'] = lr_schedule(train_state.step)
+
+         train_state = train_state.replace(rng=new_rng, step=train_state.step + 1, params=new_params, opt_state=new_opt_state)
+         train_state = train_state.update_ema(FLAGS.model['target_update_rate'])
+         return train_state, info
+
+     if FLAGS.mode != 'train':
+         do_inference(FLAGS, train_state, None, dataset, dataset_valid, shard_data, vae_encode, vae_decode, update,
+                      get_fid_activations, imagenet_labels, visualize_labels,
+                      fid_from_stats, truth_fid_stats)
+         return
+
+     ###################################
+     # Train Loop
+     ###################################
+
+     for i in tqdm.tqdm(range(1 + start_step, FLAGS.max_steps + 1 + start_step),
+                        smoothing=0.1,
+                        dynamic_ncols=True):
+
+         # Sample data.
+         if not FLAGS.debug_overfit or i == 1:
+             batch_images, batch_labels = shard_data(*next(dataset))
+             if FLAGS.model.use_stable_vae and 'latent' not in FLAGS.dataset_name:
+                 vae_rng, vae_key = jax.random.split(vae_rng)
+                 batch_images = vae_encode(vae_key, batch_images)
+
+         # Train update.
+         train_state, update_info = update(train_state, train_state_teacher, batch_images, batch_labels)
+
+         if i % FLAGS.log_interval == 0 or i == 1:
+             print("logging")
+             update_info = jax.device_get(update_info)
+             update_info = jax.tree_util.tree_map(lambda x: np.array(x), update_info)
+             update_info = jax.tree_util.tree_map(lambda x: x.mean(), update_info)
+             train_metrics = {f'training/{k}': v for k, v in update_info.items()}
+
+             # Validation pass (note: this can OOM, since it runs a full update on a valid batch).
+             valid_images, valid_labels = shard_data(*next(dataset_valid))
+             if FLAGS.model.use_stable_vae and 'latent' not in FLAGS.dataset_name:
+                 valid_images = vae_encode(vae_rng, valid_images)
+             _, valid_update_info = update(train_state, train_state_teacher, valid_images, valid_labels)
+             valid_update_info = jax.device_get(valid_update_info)
+             valid_update_info = jax.tree_util.tree_map(lambda x: x.mean(), valid_update_info)
+             train_metrics['training/loss_valid'] = valid_update_info['loss']
+
+             if jax.process_index() == 0:
+                 wandb.log(train_metrics, step=i)
+
+         if FLAGS.model['train_type'] == 'progressive':
+             num_sections = np.log2(FLAGS.model['denoise_timesteps']).astype(jnp.int32)
+             if i % (FLAGS.max_steps // num_sections) == 0:
+                 train_state_teacher = jax.jit(lambda x: x, out_shardings=train_state_sharding)(train_state)
+
+         if i % FLAGS.eval_interval == 0:
+             eval_model(FLAGS, train_state, train_state_teacher, i, dataset, dataset_valid, shard_data, vae_encode, vae_decode, update,
+                        get_fid_activations, imagenet_labels, visualize_labels,
+                        fid_from_stats, truth_fid_stats)
+
+         if i % FLAGS.save_interval == 0 and FLAGS.save_dir is not None:
+             train_state_gather = jax.experimental.multihost_utils.process_allgather(train_state)
+             # This all-gather might be part of the reason the saved parameter shapes look odd.
+             if jax.process_index() == 0:
+                 cp = Checkpoint(FLAGS.save_dir + str(train_state_gather.step + 1), parallel=False)
+                 cp.train_state = train_state_gather
+                 cp.save()
+                 del cp
+             del train_state_gather
+
+ if __name__ == '__main__':
+     app.run(main)
+