KublaiKhan1 committed on
Commit 464344f · verified · 1 Parent(s): ea03941

Upload folder using huggingface_hub
.gitattributes CHANGED
@@ -77,3 +77,5 @@ heun3_dt01/810001.tmp filter=lfs diff=lfs merge=lfs -text
  1e-6_kl_naive_globalscale_channelmean_sampling/810000.tmp filter=lfs diff=lfs merge=lfs -text
  heun3_dt01/810001/810001.tmp filter=lfs diff=lfs merge=lfs -text
  meanflow/810001.tmp filter=lfs diff=lfs merge=lfs -text
+ sharpness/final.tmp filter=lfs diff=lfs merge=lfs -text
+ sharpness/final_810001.tmp filter=lfs diff=lfs merge=lfs -text
sharpness/final.tmp ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:43b76a05291b4f9715b131d73ce450f6e59a18bcb2b84f4f6c916140b71e5e74
+ size 2110113717
sharpness/final_810001.tmp ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d3c4c0a239f06a5dbdab8892ab1cbc2cef0dc7ada7ade9d94261064aa188cbf0
+ size 2110113717
sharpness/gen_images.py ADDED
@@ -0,0 +1,374 @@
1
+ from typing import Any
2
+ import jax.numpy as jnp
3
+ from absl import app, flags
4
+ from functools import partial
5
+ import numpy as np
6
+ import tqdm
7
+ import jax
9
+ import flax
10
+ import optax
11
+ import wandb
12
+ from ml_collections import config_flags
13
+ import ml_collections
14
+
15
+ from utils.wandb import setup_wandb, default_wandb_config
16
+ from utils.train_state import TrainStateEma
17
+ from utils.checkpoint import Checkpoint
18
+ from utils.stable_vae import StableVAE
19
+ from utils.sharding import create_sharding, all_gather
20
+ from utils.datasets import get_dataset
21
+ from model import DiT
22
+ from helper_eval import eval_model
23
+ from helper_inference import do_inference
24
+
25
+ FLAGS = flags.FLAGS
26
+ flags.DEFINE_string('dataset_name', 'imagenet256', 'Dataset name.')
+ flags.DEFINE_string('load_dir', './sharpness/final.tmp', 'Checkpoint to load (if not None, restore params).')
+ flags.DEFINE_string('save_dir', './checkpoints/', 'Directory to save checkpoints (if not None).')
29
+ flags.DEFINE_string('fid_stats', None, 'FID stats file.')
30
+ flags.DEFINE_integer('seed', 10, 'Random seed.') # Must be the same across all processes.
31
+ flags.DEFINE_integer('log_interval', 1000, 'Logging interval.')
32
+ flags.DEFINE_integer('eval_interval', 1000000, 'Eval interval.')
33
+ flags.DEFINE_integer('save_interval', 10000, 'Save interval.')
34
+ flags.DEFINE_integer('batch_size', 256, 'Mini batch size.')
35
+ flags.DEFINE_integer('max_steps', int(500_000), 'Number of training steps.')
36
+ flags.DEFINE_integer('debug_overfit', 0, 'Debug overfitting.')
37
+ flags.DEFINE_string('mode', 'train', 'train or inference.')
38
+
39
+ model_config = ml_collections.ConfigDict({
40
+ 'lr': 0.0001,
41
+ 'beta1': 0.9,
42
+ 'beta2': 0.999,
43
+ 'weight_decay': 0.1,
44
+ 'use_cosine': 0,
45
+ 'warmup': 0,
46
+ 'dropout': 0.0,
47
+ 'hidden_size': 64, # change this!
48
+ 'patch_size': 8, # change this!
49
+ 'depth': 2, # change this!
50
+ 'num_heads': 2, # change this!
51
+ 'mlp_ratio': 1, # change this!
52
+ 'class_dropout_prob': 0.1,
53
+ 'num_classes': 1000,
54
+ 'denoise_timesteps': 128,
55
+ 'cfg_scale': 4.0,
56
+ 'target_update_rate': 0.999,
57
+ 'use_ema': 0,
58
+ 'use_stable_vae': 1,
59
+ 'sharding': 'dp', # dp or fsdp.
60
+ 't_sampling': 'discrete-dt',
61
+ 'dt_sampling': 'uniform',
62
+ 'bootstrap_cfg': 0,
63
+ 'bootstrap_every': 8, # Make sure it's a divisor of the batch size.
64
+ 'bootstrap_ema': 1,
65
+ 'bootstrap_dt_bias': 0,
66
+ 'train_type': 'shortcut' # or naive.
67
+ })
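+ # Note: 'shortcut' dispatches to targets_shortcut.get_targets; the other train_type values
+ # ('naive', 'progressive', 'consistency-distillation', 'consistency', 'livereflow') map to baselines/ (see update() below).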
68
+
69
+
70
+ wandb_config = default_wandb_config() # Assumes default_wandb_config() returns a ConfigDict; FLAGS.wandb is referenced below (setup_wandb, run_id), so this flag must be defined.
+ config_flags.DEFINE_config_dict('wandb', wandb_config, lock_config=False)
71
+ config_flags.DEFINE_config_dict('model', model_config, lock_config=False)
72
+
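+ # Example invocation (illustrative only; paths and the FID stats file are placeholders):
+ #   python sharpness/gen_images.py --mode=inference --load_dir=./sharpness/final.tmp \
+ #     --batch_size=64 --fid_stats=<path/to/imagenet256_fid_stats.npz> --model.denoise_timesteps=128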
73
+ ##############################################
74
+ ## Training Code.
75
+ ##############################################
76
+ def main(_):
77
+
78
+ np.random.seed(FLAGS.seed)
79
+ print("Using devices", jax.local_devices())
80
+ device_count = len(jax.local_devices())
81
+ global_device_count = jax.device_count()
82
+ print("Device count", device_count)
83
+ print("Global device count", global_device_count)
84
+ local_batch_size = FLAGS.batch_size // (global_device_count // device_count)
85
+ print("Global Batch: ", FLAGS.batch_size)
86
+ print("Node Batch: ", local_batch_size)
87
+ print("Device Batch:", local_batch_size // device_count)
88
+
89
+ # Create wandb logger
90
+ if jax.process_index() == 0 and FLAGS.mode == 'train':
91
+ setup_wandb(FLAGS.model.to_dict(), **FLAGS.wandb)
92
+
93
+ dataset = get_dataset(FLAGS.dataset_name, local_batch_size, True, FLAGS.debug_overfit)
94
+ dataset_valid = get_dataset(FLAGS.dataset_name, local_batch_size, False, FLAGS.debug_overfit)
95
+ example_obs, example_labels = next(dataset)
96
+ example_obs = example_obs[:1]
97
+ example_obs_shape = example_obs.shape
98
+
99
+ if FLAGS.model.use_stable_vae:
100
+ vae = StableVAE.create()
101
+ if 'latent' in FLAGS.dataset_name:
102
+ example_obs = example_obs[:, :, :, example_obs.shape[-1] // 2:]
103
+ example_obs_shape = example_obs.shape
104
+ else:
105
+ example_obs = vae.encode(jax.random.PRNGKey(0), example_obs)
106
+ example_obs_shape = example_obs.shape
107
+ vae_rng = jax.random.PRNGKey(42)
108
+ vae_encode = jax.jit(vae.encode)
109
+ vae_decode = jax.jit(vae.decode)
110
+
111
+ if FLAGS.fid_stats is not None:
112
+ from utils.fid import get_fid_network, fid_from_stats
113
+ get_fid_activations = get_fid_network()
114
+ truth_fid_stats = np.load(FLAGS.fid_stats)
115
+ else:
116
+ get_fid_activations = None
117
+ truth_fid_stats = None
118
+
119
+ ###################################
120
+ # Creating Model and put on devices.
121
+ ###################################
122
+ FLAGS.model.image_channels = example_obs_shape[-1]
123
+ FLAGS.model.image_size = example_obs_shape[1]
124
+ dit_args = {
125
+ 'patch_size': FLAGS.model['patch_size'],
126
+ 'hidden_size': FLAGS.model['hidden_size'],
127
+ 'depth': FLAGS.model['depth'],
128
+ 'num_heads': FLAGS.model['num_heads'],
129
+ 'mlp_ratio': FLAGS.model['mlp_ratio'],
130
+ 'out_channels': example_obs_shape[-1],
131
+ 'class_dropout_prob': FLAGS.model['class_dropout_prob'],
132
+ 'num_classes': FLAGS.model['num_classes'],
133
+ 'dropout': FLAGS.model['dropout'],
134
+ 'ignore_dt': FLAGS.model['train_type'] not in ('shortcut', 'livereflow'),
135
+ }
136
+ model_def = DiT(**dit_args)
137
+ # tabulate_fn = flax.linen.tabulate(model_def, jax.random.PRNGKey(0))
138
+ tabulate_fn = flax.linen.tabulate(model_def, rngs={"params": jax.random.PRNGKey(0), "label":jax.random.PRNGKey(0)})
139
+ print(tabulate_fn(example_obs, jnp.zeros((1,)), jnp.zeros((1,)), jnp.zeros((1,), dtype=jnp.int32)))
140
+
141
+ if FLAGS.model.use_cosine:
142
+ lr_schedule = optax.warmup_cosine_decay_schedule(0.0, FLAGS.model['lr'], FLAGS.model['warmup'], FLAGS.max_steps)
143
+ elif FLAGS.model.warmup > 0:
144
+ lr_schedule = optax.linear_schedule(0.0, FLAGS.model['lr'], FLAGS.model['warmup'])
145
+ else:
146
+ lr_schedule = lambda x: FLAGS.model['lr']
147
+ adam = optax.adamw(learning_rate=lr_schedule, b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'], weight_decay=FLAGS.model['weight_decay'])
148
+ tx = optax.chain(adam)
149
+
150
+ def log_param_shapes(params, label=""):
151
+ flat = flax.traverse_util.flatten_dict(params)
152
+
153
+ # Squeeze a leading axis of size 1 (added by process_allgather at save time); keep other entries unchanged.
+ squeezed_flat = {k: (jnp.squeeze(v, axis=0) if v.shape[0] == 1 else v) for k, v in flat.items()}
154
+ print(f"\n{label} parameter shapes:")
155
+ for k, v in flat.items():
156
+ print(f"{k}: {v.shape}")
157
+ return flax.traverse_util.unflatten_dict(squeezed_flat)
158
+
159
+
160
+ def init(rng):
161
+ param_key, dropout_key, dropout2_key = jax.random.split(rng, 3)
162
+ example_t = jnp.zeros((1,))
163
+ example_dt = jnp.zeros((1,))
164
+ example_label = jnp.zeros((1,), dtype=jnp.int32)
165
+ example_obs = jnp.zeros(example_obs_shape)
166
+ model_rngs = {'params': param_key, 'label_dropout': dropout_key, 'dropout': dropout2_key}
167
+ params = model_def.init(model_rngs, example_obs, example_t, example_dt, example_label)['params']
168
+ opt_state = tx.init(params)
169
+ ts = TrainStateEma.create(model_def, params, rng=rng, tx=tx, opt_state=opt_state)
170
+
171
+ if FLAGS.load_dir is not None:
172
+
173
+ cp = Checkpoint(FLAGS.load_dir)
174
+ train_state_load = cp.load_as_dict()["train_state"]
175
+
176
+ log_param_shapes(ts.params)
177
+ flat = log_param_shapes(train_state_load["params"])
178
+ flat_ema = log_param_shapes(train_state_load["params_ema"])
179
+ flat_mu = log_param_shapes(train_state_load["opt_state"][0][0].mu)
180
+ flat_nu = log_param_shapes(train_state_load["opt_state"][0][0].nu)
181
+
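+ # Rebuild the Adam optimizer state so its mu/nu use the squeezed parameter trees loaded above
+ # (the saved state presumably carries a leading all-gather axis of size 1).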
182
+ from optax import ScaleByAdamState
183
+ opt_state = train_state_load["opt_state"]
184
+ new_state = ScaleByAdamState(
185
+ opt_state[0][0].count,
186
+ mu=flat_mu,
187
+ nu=flat_nu
188
+ )
189
+ opt_state = list(opt_state)
190
+ opt_state[0] = list(opt_state[0])
191
+ opt_state[0][0] = new_state
192
+
193
+ opt_state[0] = tuple(opt_state[0])
194
+ opt_state = tuple(opt_state)
195
+
196
+ train_state_load = TrainStateEma.create(model_def, params = flat, rng = rng, tx = tx, opt_state=opt_state)
197
+
198
+ # Need to replace the EMA params because we track a separate EMA copy.
+ log_param_shapes(train_state_load.params)
+ train_state_load = train_state_load.replace(params_ema=flat_ema) # .replace() returns a new state; assign it.
201
+
202
+ start_step = train_state_load.step
203
+
204
+ ts = train_state_load
205
+
206
+
207
+ return ts
208
+
209
+ rng = jax.random.PRNGKey(FLAGS.seed)
210
+ train_state_shape = jax.eval_shape(init, rng)
211
+
212
+ data_sharding, train_state_sharding, no_shard, shard_data, global_to_local = create_sharding(FLAGS.model.sharding, train_state_shape)
213
+ train_state = jax.jit(init, out_shardings=train_state_sharding)(rng)
214
+ jax.debug.visualize_array_sharding(train_state.params['FinalLayer_0']['Dense_0']['kernel'])
215
+ jax.debug.visualize_array_sharding(train_state.params['TimestepEmbedder_1']['Dense_0']['kernel'])
216
+ jax.experimental.multihost_utils.assert_equal(train_state.params['TimestepEmbedder_1']['Dense_0']['kernel'])
217
+ start_step = 1
218
+
219
+ if False:#FLAGS.load_dir is not None:
220
+ cp = Checkpoint(FLAGS.load_dir)
221
+ replace_dict = cp.load_as_dict()['train_state']
222
+ del replace_dict['opt_state'] # Debug
223
+ train_state = train_state.replace(**replace_dict)
224
+ if FLAGS.wandb.run_id != "None": # If we are continuing a run.
225
+ start_step = train_state.step
226
+ train_state = jax.jit(lambda x : x, out_shardings=train_state_sharding)(train_state)
227
+ print("Loaded model with step", train_state.step)
228
+ train_state = train_state.replace(step=0)
229
+ jax.debug.visualize_array_sharding(train_state.params['FinalLayer_0']['Dense_0']['kernel'])
230
+ del cp
231
+
232
+ if FLAGS.model.train_type == 'progressive' or FLAGS.model.train_type == 'consistency-distillation':
233
+ train_state_teacher = jax.jit(lambda x : x, out_shardings=train_state_sharding)(train_state)
234
+ else:
235
+ train_state_teacher = None
236
+
237
+ visualize_labels = example_labels
238
+ visualize_labels = shard_data(visualize_labels)
239
+ visualize_labels = jax.experimental.multihost_utils.process_allgather(visualize_labels)
240
+ imagenet_labels = open('data/imagenet_labels.txt').read().splitlines()
241
+
242
+ ###################################
243
+ # Update Function
244
+ ###################################
245
+
246
+ @partial(jax.jit, out_shardings=(train_state_sharding, no_shard))
247
+ def update(train_state, train_state_teacher, images, labels, force_t=-1, force_dt=-1):
248
+ new_rng, targets_key, dropout_key, perm_key = jax.random.split(train_state.rng, 4)
249
+ info = {}
250
+
251
+ id_perm = jax.random.permutation(perm_key, images.shape[0])
252
+ images = images[id_perm]
253
+ labels = labels[id_perm]
254
+ images = jax.lax.with_sharding_constraint(images, data_sharding)
255
+ labels = jax.lax.with_sharding_constraint(labels, data_sharding)
256
+
257
+ if FLAGS.model['cfg_scale'] == 0: # For unconditional generation.
258
+ labels = jnp.ones(labels.shape[0], dtype=jnp.int32) * FLAGS.model['num_classes']
259
+
260
+ if FLAGS.model['train_type'] == 'naive':
261
+ from baselines.targets_naive import get_targets
262
+ x_t, v_t, t, dt_base, labels, info = get_targets(FLAGS, targets_key, train_state, images, labels, force_t, force_dt)
263
+ elif FLAGS.model['train_type'] == 'shortcut':
264
+ from targets_shortcut import get_targets
265
+ x_t, v_t, t, dt_base, labels, info = get_targets(FLAGS, targets_key, train_state, images, labels, force_t, force_dt)
266
+ elif FLAGS.model['train_type'] == 'progressive':
267
+ from baselines.targets_progressive import get_targets
268
+ x_t, v_t, t, dt_base, labels, info = get_targets(FLAGS, targets_key, train_state, train_state_teacher, images, labels, force_t, force_dt)
269
+ elif FLAGS.model['train_type'] == 'consistency-distillation':
270
+ from baselines.targets_consistency_distillation import get_targets
271
+ x_t, v_t, t, dt_base, labels, info = get_targets(FLAGS, targets_key, train_state, train_state_teacher, images, labels, force_t, force_dt)
272
+ elif FLAGS.model['train_type'] == 'consistency':
273
+ from baselines.targets_consistency_training import get_targets
274
+ x_t, v_t, t, dt_base, labels, info = get_targets(FLAGS, targets_key, train_state, images, labels, force_t, force_dt)
275
+ elif FLAGS.model['train_type'] == 'livereflow':
276
+ from baselines.targets_livereflow import get_targets
277
+ x_t, v_t, t, dt_base, labels, info = get_targets(FLAGS, targets_key, train_state, images, labels, force_t, force_dt)
278
+
279
+ def loss_fn(grad_params):
280
+ v_prime, logvars, activations = train_state.call_model(x_t, t, dt_base, labels, train=True, rngs={'dropout': dropout_key}, params=grad_params, return_activations=True)
281
+ mse_v = jnp.mean((v_prime - v_t) ** 2, axis=(1, 2, 3))
282
+ loss = jnp.mean(mse_v)
283
+
284
+ info = {
285
+ 'loss': loss,
286
+ 'v_magnitude_prime': jnp.sqrt(jnp.mean(jnp.square(v_prime))),
287
+ **{'activations/' + k : jnp.sqrt(jnp.mean(jnp.square(v))) for k, v in activations.items()},
288
+ }
289
+
290
+ if FLAGS.model['train_type'] == 'shortcut' or FLAGS.model['train_type'] == 'livereflow':
291
+ bootstrap_size = FLAGS.batch_size // FLAGS.model['bootstrap_every']
292
+ info['loss_flow'] = jnp.mean(mse_v[bootstrap_size:])
293
+ info['loss_bootstrap'] = jnp.mean(mse_v[:bootstrap_size])
294
+
295
+ return loss, info
296
+
297
+ grads, new_info = jax.grad(loss_fn, has_aux=True)(train_state.params)
298
+ info = {**info, **new_info}
299
+ updates, new_opt_state = train_state.tx.update(grads, train_state.opt_state, train_state.params)
300
+ new_params = optax.apply_updates(train_state.params, updates)
301
+
302
+ info['grad_norm'] = optax.global_norm(grads)
303
+ info['update_norm'] = optax.global_norm(updates)
304
+ info['param_norm'] = optax.global_norm(new_params)
305
+ info['lr'] = lr_schedule(train_state.step)
306
+
307
+ train_state = train_state.replace(rng=new_rng, step=train_state.step + 1, params=new_params, opt_state=new_opt_state)
308
+ train_state = train_state.update_ema(FLAGS.model['target_update_rate'])
309
+ return train_state, info
310
+
311
+ if FLAGS.mode != 'train':
312
+ do_inference(FLAGS, train_state, None, dataset, dataset_valid, shard_data, vae_encode, vae_decode, update,
313
+ get_fid_activations, imagenet_labels, visualize_labels,
314
+ fid_from_stats, truth_fid_stats)
315
+ return
316
+
317
+ ###################################
318
+ # Train Loop
319
+ ###################################
320
+
321
+ for i in tqdm.tqdm(range(1 + start_step, FLAGS.max_steps + 1 + start_step),
322
+ smoothing=0.1,
323
+ dynamic_ncols=True):
324
+
325
+ # Sample data.
326
+ if not FLAGS.debug_overfit or i == 1:
327
+ batch_images, batch_labels = shard_data(*next(dataset))
328
+ if FLAGS.model.use_stable_vae and 'latent' not in FLAGS.dataset_name:
329
+ vae_rng, vae_key = jax.random.split(vae_rng)
330
+ batch_images = vae_encode(vae_key, batch_images)
331
+
332
+ # Train update.
333
+ train_state, update_info = update(train_state, train_state_teacher, batch_images, batch_labels)
334
+
335
+ if i % FLAGS.log_interval == 0 or i == 1:
336
+ update_info = jax.device_get(update_info)
337
+ update_info = jax.tree_map(lambda x: np.array(x), update_info)
338
+ update_info = jax.tree_map(lambda x: x.mean(), update_info)
339
+ train_metrics = {f'training/{k}': v for k, v in update_info.items()}
340
+
341
+ valid_images, valid_labels = shard_data(*next(dataset_valid))
342
+ if FLAGS.model.use_stable_vae and 'latent' not in FLAGS.dataset_name:
343
+ valid_images = vae_encode(vae_rng, valid_images)
344
+ _, valid_update_info = update(train_state, train_state_teacher, valid_images, valid_labels)
345
+ valid_update_info = jax.device_get(valid_update_info)
346
+ valid_update_info = jax.tree_map(lambda x: x.mean(), valid_update_info)
347
+ train_metrics['training/loss_valid'] = valid_update_info['loss']
348
+
349
+ if jax.process_index() == 0:
350
+ wandb.log(train_metrics, step=i)
351
+
352
+ if FLAGS.model['train_type'] == 'progressive':
353
+ num_sections = np.log2(FLAGS.model['denoise_timesteps']).astype(jnp.int32)
354
+ if i % (FLAGS.max_steps // num_sections) == 0:
355
+ train_state_teacher = jax.jit(lambda x : x, out_shardings=train_state_sharding)(train_state)
356
+
357
+ if i % FLAGS.eval_interval == 0:
358
+ eval_model(FLAGS, train_state, train_state_teacher, i, dataset, dataset_valid, shard_data, vae_encode, vae_decode, update,
359
+ get_fid_activations, imagenet_labels, visualize_labels,
360
+ fid_from_stats, truth_fid_stats)
361
+
362
+ if i % FLAGS.save_interval == 0 and FLAGS.save_dir is not None:
363
+ train_state_gather = jax.experimental.multihost_utils.process_allgather(train_state)
364
+ # This all-gather might be part of the reason the shape is odd.
365
+ if jax.process_index() == 0:
366
+ cp = Checkpoint(FLAGS.save_dir+str(train_state_gather.step+1), parallel=False)
367
+ cp.train_state = train_state_gather
368
+ cp.save()
369
+ del cp
370
+ del train_state_gather
371
+
372
+ if __name__ == '__main__':
373
+ app.run(main)
374
+
sharpness/helper_inference.py ADDED
@@ -0,0 +1,210 @@
1
+ import jax
2
+ import jax.experimental
3
+ import wandb
4
+ import jax.numpy as jnp
5
+ import numpy as np
6
+ import tqdm
7
+ import matplotlib.pyplot as plt
8
+ import os
9
+ from functools import partial
10
+ from absl import app, flags
11
+
12
+ flags.DEFINE_integer('inference_timesteps', 128, 'Number of timesteps for inference.')
13
+ flags.DEFINE_integer('inference_generations', 50000, 'Number of generations for inference.')
14
+ flags.DEFINE_float('inference_cfg_scale', 1.0, 'CFG scale for inference.')
15
+ # Note: we do a CFG sanity check below, but the model isn't really trained with CFG, so guidance isn't expected to work properly.
16
+
17
+ if False:
18
+ classes = np.load("classes.npz")
19
+ global_mean = jnp.load("global_mean.npy")
20
+ # classes loads as an NpzFile; convert it to a regular dict below.
21
+ classes = {key: classes[key] for key in classes.files}
22
+ classes["1000"] = global_mean
23
+ classes_array = jnp.array([classes[str(i)] for i in range(len(classes))])
24
+
25
+
26
+
27
+ def do_inference(
28
+ FLAGS,
29
+ train_state,
30
+ step,
31
+ dataset,
32
+ dataset_valid,
33
+ shard_data,
34
+ vae_encode,
35
+ vae_decode,
36
+ update,
37
+ get_fid_activations,
38
+ imagenet_labels,
39
+ visualize_labels,
40
+ fid_from_stats,
41
+ truth_fid_stats,
42
+ ):
43
+ with jax.spmd_mode('allow_all'):
44
+ global_device_count = jax.device_count()
45
+ key = jax.random.PRNGKey(42 + jax.process_index())
46
+ batch_images, batch_labels = next(dataset)
47
+ valid_images, valid_labels = next(dataset_valid)
48
+ if FLAGS.model.use_stable_vae:
49
+ batch_images = vae_encode(key, batch_images)
50
+ valid_images = vae_encode(key, valid_images)
51
+ batch_labels_sharded, valid_labels_sharded = shard_data(batch_labels, valid_labels)
52
+ labels_uncond = shard_data(jnp.ones(batch_labels.shape, dtype=jnp.int32) * FLAGS.model['num_classes']) # Null token
53
+ eps = jax.random.normal(key, batch_images.shape)
54
+
55
+ def process_img(img):
56
+ if FLAGS.model.use_stable_vae:
57
+ img = vae_decode(img[None])[0]
58
+ img = img * 0.5 + 0.5
59
+ img = jnp.clip(img, 0, 1)
60
+ img = np.array(img)
61
+ return img
62
+
63
+ # @partial(jax.jit, static_argnums=(5,))
64
+ def call_model(train_state, images, t, dt, labels, use_ema=True, perturbe = False):
65
+ if use_ema and FLAGS.model.use_ema:
66
+ call_fn = train_state.call_model_ema
67
+ else:
68
+ call_fn = train_state.call_model
69
+
70
+ key2 = jax.random.PRNGKey(0)
71
+ output = call_fn(images, t, dt, labels, train=False, rngs={"label": key2}, perturbe = perturbe)
72
+
73
+ return output
74
+
75
+ if FLAGS.mode == 'interpolate':
76
+ seed = 5
77
+ eps0 = jax.random.normal(jax.random.PRNGKey(seed), batch_images[0].shape)
78
+ eps1 = jax.random.normal(jax.random.PRNGKey(seed+1), batch_images[0].shape)
79
+ labels = jnp.ones(FLAGS.batch_size,).astype(jnp.int32) * 555
80
+ i = jnp.linspace(0, 1, FLAGS.batch_size)
81
+ i_neg = np.sqrt(1-i**2)
82
+ x = eps0[None] * i_neg[:, None, None, None] + eps1[None] * i[:, None, None, None]
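+ # The weights (i_neg, i) lie on the unit circle, so this interpolates between the two noise draws
+ # while keeping the combined sample at roughly unit variance (slerp-style).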
83
+ t_vector = jnp.full((FLAGS.batch_size, ), 0)
84
+ dt_vector = jnp.zeros_like(t_vector)
85
+ cfg_scale = FLAGS.inference_cfg_scale
86
+ v = call_model(train_state, x, t_vector, dt_vector, labels)
87
+ x = x + v * 1.0
88
+ x = vae_decode(x) # Image is in [-1, 1] space.
89
+ x_render = np.array(jax.experimental.multihost_utils.process_allgather(x))
90
+ os.makedirs(FLAGS.save_dir, exist_ok=True)
91
+ np.save(FLAGS.save_dir + f'/x_render.npy', x_render)
92
+ breakpoint()
93
+
94
+ denoise_timesteps = FLAGS.inference_timesteps
95
+ num_generations = FLAGS.inference_generations
96
+ cfg_scale = FLAGS.inference_cfg_scale
97
+ x0 = []
98
+ x1 = []
99
+ lab = []
100
+ x_render = []
101
+ activations = []
102
+ images_shape = batch_images.shape
103
+ print(f"Calc FID for CFG {cfg_scale} and denoise_timesteps {denoise_timesteps}")
104
+ for fid_it in tqdm.tqdm(range(num_generations // FLAGS.batch_size)):
105
+ key = jax.random.PRNGKey(42)
106
+ key = jax.random.fold_in(key, fid_it)
107
+ key = jax.random.fold_in(key, jax.process_index())
108
+ eps_key, label_key = jax.random.split(key)
109
+ x = jax.random.normal(eps_key, images_shape)
110
+
111
+ e = 0.30
112
+
113
+ labels = jax.random.randint(label_key, (images_shape[0],), 0, FLAGS.model.num_classes)
114
+
115
+ #from baselines.targets_naive import map_labels_to_classes
116
+ #x_cond = map_labels_to_classes(classes_array, labels) * (1-e) + e * x
117
+ #x_uncond = map_labels_to_classes(classes_array, labels_uncond) * (1-e) + e * x
118
+
119
+
120
+ x, labels = shard_data(x, labels)
121
+ x0.append(np.array(jax.experimental.multihost_utils.process_allgather(x)))
122
+ delta_t = 1.0 / denoise_timesteps
123
+ sigmas = []
124
+ for ti in range(denoise_timesteps + 1):
125
+ t = ti / denoise_timesteps # From x_0 (noise) to x_1 (data)
126
+ sigmas.append(t)
127
+ # This gives n + 1 sigma values, because both endpoints (t = 0 and t = 1) are included.
128
+ i = 0
129
+ for ti in range(denoise_timesteps):
130
+ t = ti / denoise_timesteps # From x_0 (noise) to x_1 (data)
131
+ t_vector = jnp.full((images_shape[0], ), t)
132
+ if FLAGS.model.train_type == 'naive':
133
+ dt_flow = np.log2(FLAGS.model['denoise_timesteps']).astype(jnp.int32)
134
+ dt_base = jnp.ones(images_shape[0], dtype=jnp.int32) * dt_flow # Smallest dt.
135
+ else: # shortcut
136
+ dt_flow = np.log2(denoise_timesteps).astype(jnp.int32)
137
+ dt_base = jnp.ones(images_shape[0], dtype=jnp.int32) * dt_flow
138
+ # print(dt_base)
139
+ t_vector, dt_base = shard_data(t_vector, dt_base)
140
+ if cfg_scale == 1:
141
+ v = call_model(train_state, x, t_vector, dt_base, labels, perturbe = True) # perturbe=True here really just means 'fully conditional'.
142
+ elif cfg_scale == 0:
143
+ v = call_model(train_state, x, t_vector, dt_base, labels_uncond)
144
+ else:
145
+ v_pred_uncond = call_model(train_state, x, t_vector, dt_base, labels_uncond)
146
+ v_pred_label = call_model(train_state, x, t_vector, dt_base, labels)
147
+ v = v_pred_uncond + cfg_scale * (v_pred_label - v_pred_uncond)
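+ # Classifier-free guidance: extrapolate from the unconditional velocity toward the conditional one by cfg_scale.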
148
+
149
+ if FLAGS.model.train_type == 'consistency':
150
+ eps = shard_data(jax.random.normal(jax.random.fold_in(eps_key, ti), images_shape))
151
+ x1pred = x + v * (1-t)
152
+ x = x1pred * (t+delta_t) + eps * (1-t-delta_t)
153
+ elif True:
154
+ x = x + v * delta_t # Euler sampling.
155
+ elif False:
156
+
157
+ def get_ancestral_step(t0, t1):
158
+ sigma_up = None
159
+ return 1 / (1 + ((t0 ** 2 * (t1 - 1) ** 4) / ((t0 - 1) ** 2 * t1 ** 4)) ** 0.5), sigma_up
160
+ # def flow_sample_sde_3(model, x, ts):
161
+ #for s, t in tqdm(zip(ts[:-1], ts[1:]), total=len(ts) - 1):
162
+ # dx = model(x, s)
163
+ # denoised = x + dx * (1 - s)
164
+ # noise = torch.randn_like(x)
165
+ # fac_1 = (s * (1 - t) ** 2) / ((1 - s) ** 2 * t)
166
+ # fac_2 = (t ** 2 - 2 * s * t ** 2 + s ** 2 * (2 * t - 1)) / ((1 - s) ** 2 * t)
167
+ # fac_3 = (1 - t) * (fac_2 / t) ** 0.5
168
+ # x = fac_1 * x + fac_2 * denoised + fac_3 * noise
169
+ #return x
170
+ #So our timesteps looks like 0, 1/128..
171
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
172
+ # Euler method
173
+ dt = sigma_down - sigmas[i]
174
+ #Naive up
175
+ sigma_up = sigmas[i+1] - dt
176
+
177
+ x = x + v * dt
178
+ if sigmas[i + 1] != 1.0:
179
+ x = x + jax.random.normal(eps_key, images_shape) * sigma_up * v
180
+
181
+ i += 1
182
+ x1.append(np.array(jax.experimental.multihost_utils.process_allgather(x)))
183
+ lab.append(np.array(jax.experimental.multihost_utils.process_allgather(labels)))
184
+ if FLAGS.model.use_stable_vae:
185
+ x = vae_decode(x) # Image is in [-1, 1] space.
186
+ if num_generations < 10000:
187
+ x_render.append(np.array(jax.experimental.multihost_utils.process_allgather(x)))
188
+ # Keep the rendered images (saved below) only when the generation count is small.
190
+ x = jax.image.resize(x, (x.shape[0], 299, 299, 3), method='bilinear', antialias=False)
191
+ x = jnp.clip(x, -1, 1)
192
+ acts = get_fid_activations(x)[..., 0, 0, :] # [devices, batch//devices, 2048]
193
+ acts = jax.experimental.multihost_utils.process_allgather(acts)
194
+ acts = np.array(acts)
195
+ activations.append(acts)
196
+
197
+ if jax.process_index() == 0:
198
+ activations = np.concatenate(activations, axis=0)
199
+ activations = activations.reshape((-1, activations.shape[-1]))
200
+ mu1 = np.mean(activations, axis=0)
201
+ sigma1 = np.cov(activations, rowvar=False)
202
+ fid = fid_from_stats(mu1, sigma1, truth_fid_stats['mu'], truth_fid_stats['sigma'])
203
+ print(f"FID is {fid}")
204
+ return
205
+
206
+ if FLAGS.save_dir is not None:
207
+ os.makedirs(FLAGS.save_dir, exist_ok=True)
208
+ x_render = np.concatenate(x_render, axis=0)
209
+ np.save(FLAGS.save_dir + f'/x_render.npy', x_render)
210
+
sharpness/model.py ADDED
@@ -0,0 +1,427 @@
1
+ import math
2
+ from typing import Any, Callable, Optional, Tuple, Type, Sequence, Union
3
+ import flax.linen as nn
4
+ import jax
5
+ import jax.numpy as jnp
6
+ from einops import rearrange
7
+
8
+ Array = Any
9
+ PRNGKey = Any
10
+ Shape = Tuple[int]
11
+ Dtype = Any
12
+
13
+ from math_utils import get_2d_sincos_pos_embed, modulate
14
+ from jax._src import core
15
+ from jax._src import dtypes
16
+ from jax._src.nn.initializers import _compute_fans
17
+
18
+ def xavier_uniform_pytorchlike():
19
+ def init(key, shape, dtype):
20
+ dtype = dtypes.canonicalize_dtype(dtype)
21
+ #named_shape = core.as_named_shape(shape)
22
+ if len(shape) == 2: # Dense, [in, out]
23
+ fan_in = shape[0]
24
+ fan_out = shape[1]
25
+ elif len(shape) == 4: # Conv, [k, k, in, out]. Assumes patch-embed style conv.
26
+ fan_in = shape[0] * shape[1] * shape[2]
27
+ fan_out = shape[3]
28
+ else:
29
+ raise ValueError(f"Invalid shape {shape}")
30
+
31
+ variance = 2 / (fan_in + fan_out)
32
+ scale = jnp.sqrt(3 * variance)
33
+ param = jax.random.uniform(key, shape, dtype, -1) * scale
34
+
35
+ return param
36
+ return init
37
+
38
+
39
+ class TrainConfig:
40
+ def __init__(self, dtype):
41
+ self.dtype = dtype
42
+ def kern_init(self, name='default', zero=False):
43
+ if zero or 'bias' in name:
44
+ return nn.initializers.constant(0)
45
+ return xavier_uniform_pytorchlike()
46
+ def default_config(self):
47
+ return {
48
+ 'kernel_init': self.kern_init(),
49
+ 'bias_init': self.kern_init('bias', zero=True),
50
+ 'dtype': self.dtype,
51
+ }
52
+
53
+ class TimestepEmbedder(nn.Module):
54
+ """
55
+ Embeds scalar timesteps into vector representations.
56
+ """
57
+ hidden_size: int
58
+ tc: TrainConfig
59
+ frequency_embedding_size: int = 256
60
+
61
+ @nn.compact
62
+ def __call__(self, t):
63
+ x = self.timestep_embedding(t)
64
+ x = nn.Dense(self.hidden_size, kernel_init=nn.initializers.normal(0.02),
65
+ bias_init=self.tc.kern_init('time_bias'), dtype=self.tc.dtype)(x)
66
+ x = nn.silu(x)
67
+ x = nn.Dense(self.hidden_size, kernel_init=nn.initializers.normal(0.02),
68
+ bias_init=self.tc.kern_init('time_bias'))(x)
69
+ return x
70
+
71
+ # t is between [0, 1].
72
+ def timestep_embedding(self, t, max_period=10000):
73
+ """
74
+ Create sinusoidal timestep embeddings.
75
+ :param t: a 1-D Tensor of N indices, one per batch element.
76
+ These may be fractional.
77
+ :param dim: the dimension of the output.
78
+ :param max_period: controls the minimum frequency of the embeddings.
79
+ :return: an (N, D) Tensor of positional embeddings.
80
+ """
81
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
82
+ t = jax.lax.convert_element_type(t, jnp.float32)
83
+ # t = t * max_period
84
+ dim = self.frequency_embedding_size
85
+ half = dim // 2
86
+ freqs = jnp.exp( -math.log(max_period) * jnp.arange(start=0, stop=half, dtype=jnp.float32) / half)
87
+ args = t[:, None] * freqs[None]
88
+ embedding = jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1)
89
+ embedding = embedding.astype(self.tc.dtype)
90
+ return embedding
91
+
92
+ class LabelEmbedder(nn.Module):
93
+ """
94
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
95
+ """
96
+ num_classes: int
97
+ hidden_size: int
98
+ tc: TrainConfig
99
+
100
+ @nn.compact
101
+ def __call__(self, labels):
102
+ embedding_table = nn.Embed(self.num_classes + 1, self.hidden_size,
103
+ embedding_init=nn.initializers.normal(0.02), dtype=self.tc.dtype)
104
+ embeddings = embedding_table(labels)
105
+ return embeddings
106
+
107
+ class PatchEmbed(nn.Module):
108
+ """ 2D Image to Patch Embedding """
109
+ patch_size: int
110
+ hidden_size: int
111
+ tc: TrainConfig
112
+ bias: bool = True
113
+
114
+ @nn.compact
115
+ def __call__(self, x):
116
+ B, H, W, C = x.shape
117
+ patch_tuple = (self.patch_size, self.patch_size)
118
+ num_patches = (H // self.patch_size)
119
+ x = nn.Conv(self.hidden_size, patch_tuple, patch_tuple, use_bias=self.bias, padding="VALID",
120
+ kernel_init=self.tc.kern_init('patch'), bias_init=self.tc.kern_init('patch_bias', zero=True),
121
+ dtype=self.tc.dtype)(x) # (B, P, P, hidden_size)
122
+ x = rearrange(x, 'b h w c -> b (h w) c', h=num_patches, w=num_patches)
123
+ return x
124
+
125
+ class MlpBlock(nn.Module):
126
+ """Transformer MLP / feed-forward block."""
127
+ mlp_dim: int
128
+ tc: TrainConfig
129
+ out_dim: Optional[int] = None
130
+ dropout_rate: float = None
131
+ train: bool = False
132
+
133
+ @nn.compact
134
+ def __call__(self, inputs):
135
+ """It's just an MLP, so the input shape is (batch, len, emb)."""
136
+ actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
137
+ x = nn.Dense(features=self.mlp_dim, **self.tc.default_config())(inputs)
138
+ x = nn.gelu(x)
139
+ x = nn.Dropout(rate=self.dropout_rate, deterministic=(not self.train))(x)
140
+ output = nn.Dense(features=actual_out_dim, **self.tc.default_config())(x)
141
+ output = nn.Dropout(rate=self.dropout_rate, deterministic=(not self.train))(output)
142
+ return output
143
+
144
+ def modulate(x, shift, scale):
145
+ # scale = jnp.clip(scale, -1, 1)
146
+ return x * (1 + scale[:, None]) + shift[:, None]
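+ # adaLN modulation: per-sample scale/shift broadcast over the token axis, e.g. x of shape
+ # (B, N, D) with shift/scale of shape (B, D) yields another (B, N, D) tensor.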
147
+
148
+ ################################################################################
149
+ # Core DiT Model #
150
+ #################################################################################
151
+
152
+ class DiTBlock(nn.Module):
153
+ """
154
+ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
155
+ """
156
+ hidden_size: int
157
+ num_heads: int
158
+ tc: TrainConfig
159
+ mlp_ratio: float = 4.0
160
+ dropout: float = 0.0
161
+ train: bool = False
162
+
163
+ # @functools.partial(jax.checkpoint, policy=jax.checkpoint_policies.nothing_saveable)
164
+ @nn.compact
165
+ def __call__(self, x, c):
166
+ # Calculate adaLn modulation parameters.
167
+ c = nn.silu(c)
168
+ c = nn.Dense(6 * self.hidden_size, **self.tc.default_config())(c)
169
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = jnp.split(c, 6, axis=-1)
170
+
171
+ # Attention Residual.
172
+ x_norm = nn.LayerNorm(use_bias=False, use_scale=False, dtype=self.tc.dtype)(x)
173
+ x_modulated = modulate(x_norm, shift_msa, scale_msa)
174
+ channels_per_head = self.hidden_size // self.num_heads
175
+ k = nn.Dense(self.hidden_size, **self.tc.default_config())(x_modulated)
176
+ q = nn.Dense(self.hidden_size, **self.tc.default_config())(x_modulated)
177
+ v = nn.Dense(self.hidden_size, **self.tc.default_config())(x_modulated)
178
+ k = jnp.reshape(k, (k.shape[0], k.shape[1], self.num_heads, channels_per_head))
179
+ q = jnp.reshape(q, (q.shape[0], q.shape[1], self.num_heads, channels_per_head))
180
+ v = jnp.reshape(v, (v.shape[0], v.shape[1], self.num_heads, channels_per_head))
181
+ q = q / q.shape[3] # (1/d) scaling.
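+ # Note: this scales queries by 1/d (the per-head channel dim) rather than the more common 1/sqrt(d).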
182
+ w = jnp.einsum('bqhc,bkhc->bhqk', q, k) # [B, HW, HW, num_heads]
183
+ w = w.astype(jnp.float32)
184
+ w = nn.softmax(w, axis=-1)
185
+ y = jnp.einsum('bhqk,bkhc->bqhc', w, v) # [B, HW, num_heads, channels_per_head]
186
+ y = jnp.reshape(y, x.shape) # [B, H, W, C] (C = heads * channels_per_head)
187
+ attn_x = nn.Dense(self.hidden_size, **self.tc.default_config())(y)
188
+ x = x + (gate_msa[:, None] * attn_x)
189
+
190
+ # MLP Residual.
191
+ x_norm2 = nn.LayerNorm(use_bias=False, use_scale=False, dtype=self.tc.dtype)(x)
192
+ x_modulated2 = modulate(x_norm2, shift_mlp, scale_mlp)
193
+ mlp_x = MlpBlock(mlp_dim=int(self.hidden_size * self.mlp_ratio), tc=self.tc,
194
+ dropout_rate=self.dropout, train=self.train)(x_modulated2)
195
+ x = x + (gate_mlp[:, None] * mlp_x)
196
+ return x
197
+
198
+ class FinalLayer(nn.Module):
199
+ """
200
+ The final layer of DiT.
201
+ """
202
+ patch_size: int
203
+ out_channels: int
204
+ hidden_size: int
205
+ tc: TrainConfig
206
+
207
+ @nn.compact
208
+ def __call__(self, x, c):
209
+ c = nn.silu(c)
210
+ c = nn.Dense(2 * self.hidden_size, kernel_init=self.tc.kern_init(zero=True),
211
+ bias_init=self.tc.kern_init('bias', zero=True), dtype=self.tc.dtype)(c)
212
+ shift, scale = jnp.split(c, 2, axis=-1)
213
+ x = nn.LayerNorm(use_bias=False, use_scale=False, dtype=self.tc.dtype)(x)
214
+ x = modulate(x, shift, scale)
215
+ x = nn.Dense(self.patch_size * self.patch_size * self.out_channels,
216
+ kernel_init=self.tc.kern_init('final', zero=True),
217
+ bias_init=self.tc.kern_init('final_bias', zero=True), dtype=self.tc.dtype)(x)
218
+ return x
219
+
220
+
221
+ import jax
222
+ import jax.numpy as jnp
223
+
224
+ def apply_label_embedding_noise(key, label_embeddings):
225
+ """
226
+ Applies Gaussian noise to label embeddings based on specified probabilities.
227
+
228
+ Args:
229
+ key: A JAX random key.
230
+ label_embeddings: A JAX array of shape (batch_size, embedding_dim),
231
+ representing the label embeddings.
232
+
233
+ Returns:
234
+ A tuple containing:
235
+ - noisy_label_embeddings: The label embeddings with noise applied.
236
+ - noise_levels: A JAX array of shape (batch_size,), indicating
+ the alpha value used for each sample (1.0 for no noise,
+ 0.0 for 100% noise, or a uniform sample for partial noise).
+ - key: the updated JAX random key (the input key after splitting).
239
+ """
240
+ batch_size, embedding_dim = label_embeddings.shape
241
+
242
+ # Split key for different random operations
243
+ key, noise_type_key, alpha_key, normal_key = jax.random.split(key, 4)
244
+
245
+ # Determine noise application type for each sample
246
+ # 0: 100% noise (alpha = 0)
247
+ # 1: Partial noise (alpha uniformly 0-1)
248
+ # 2: No noise (do nothing)
249
+ noise_type_choices = jax.random.choice(
250
+ noise_type_key,
251
+ a=jnp.array([0, 1, 2]),
252
+ shape=(batch_size,),
253
+ p=jnp.array([0.00, 0.10, 0.90])
254
+ )
255
+
256
+ # Initialize noise_levels to 1.0 (no noise)
257
+ noise_levels = jnp.ones(batch_size, dtype=label_embeddings.dtype)
258
+
259
+ # Generate alpha values for partial noise
260
+ sampled_alphas = jax.random.uniform(alpha_key, shape=(batch_size,), minval=0.0, maxval=1.0)
261
+
262
+ # Generate Gaussian noise for the entire batch
263
+ # We assume a standard deviation of 1 for the noise, you might want to adjust this.
264
+ gaussian_noise = jax.random.normal(normal_key, shape=label_embeddings.shape)
265
+
266
+ # Initialize noisy_label_embeddings
267
+ noisy_label_embeddings = label_embeddings
268
+
269
+ # Apply 100% noise
270
+ cond_100_percent_noise = (noise_type_choices == 0)
271
+ noisy_label_embeddings = jnp.where(
272
+ cond_100_percent_noise[:, None], # Expand dim for broadcasting
273
+ gaussian_noise,
274
+ noisy_label_embeddings
275
+ )
276
+ noise_levels = jnp.where(cond_100_percent_noise, 0.0, noise_levels)
277
+
278
+ # Apply partial noise
279
+ cond_partial_noise = (noise_type_choices == 1)
280
+ # Reshape sampled_alphas for broadcasting
281
+ alpha_reshaped = sampled_alphas[:, None]
282
+ noisy_label_embeddings = jnp.where(
283
+ cond_partial_noise[:, None],
284
+ label_embeddings * alpha_reshaped + gaussian_noise * (1.0 - alpha_reshaped),
285
+ noisy_label_embeddings
286
+ )
287
+ noise_levels = jnp.where(cond_partial_noise, sampled_alphas, noise_levels)
288
+
289
+ # For cond_no_noise (noise_type_choices == 2), noisy_label_embeddings remains
290
+ # label_embeddings and noise_levels remains 1.0, so no specific action needed.
291
+ return noisy_label_embeddings, noise_levels, key
292
+
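+ # Minimal usage sketch (illustrative): for label embeddings ye of shape (B, D),
+ #   noisy_ye, levels, key = apply_label_embedding_noise(jax.random.PRNGKey(0), ye)
+ # levels is (B,) with 1.0 = clean embedding, 0.0 = pure noise, values in between = partial mixing.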
293
+ class DiT(nn.Module):
294
+ """
295
+ Diffusion model with a Transformer backbone.
296
+ """
297
+ patch_size: int
298
+ hidden_size: int
299
+ depth: int
300
+ num_heads: int
301
+ mlp_ratio: float
302
+ out_channels: int
303
+ class_dropout_prob: float
304
+ num_classes: int
305
+ ignore_dt: bool = False
306
+ dropout: float = 0.0
307
+ dtype: Dtype = jnp.bfloat16
308
+
309
+ @nn.compact
310
+ def __call__(self, x, t, dt, y, train=False, return_activations=False, perturbe = True):
311
+ # (x = (B, H, W, C) image, t = (B,) timesteps, y = (B,) class labels)
312
+ print("DiT: Input of shape", x.shape, "dtype", x.dtype)
313
+ activations = {}
314
+
315
+ key = self.make_rng("label")
316
+
317
+ batch_size = x.shape[0]
318
+ input_size = x.shape[1]
319
+ in_channels = x.shape[-1]
320
+ num_patches = (input_size // self.patch_size) ** 2
321
+ num_patches_side = input_size // self.patch_size
322
+ tc = TrainConfig(dtype=self.dtype)
323
+
324
+ if self.ignore_dt:
325
+ dt = jnp.zeros_like(t)
326
+
327
+ # pos_embed = self.param("pos_embed", get_2d_sincos_pos_embed, self.hidden_size, num_patches)
328
+ # pos_embed = jax.lax.stop_gradient(pos_embed)
329
+ pos_embed = get_2d_sincos_pos_embed(None, self.hidden_size, num_patches)
330
+ x = PatchEmbed(self.patch_size, self.hidden_size, tc=tc)(x) # (B, num_patches, hidden_size)
331
+ print("DiT: After patch embed, shape is", x.shape, "dtype", x.dtype)
332
+ activations['patch_embed'] = x
333
+
334
+ x = x + pos_embed
335
+ x = x.astype(self.dtype)
336
+ te = TimestepEmbedder(self.hidden_size, tc=tc)(t) # (B, hidden_size)
337
+ dte = TimestepEmbedder(self.hidden_size, tc=tc)(dt) # (B, hidden_size)
338
+ ye = LabelEmbedder(self.num_classes, self.hidden_size, tc=tc)(y) # (B, hidden_size)
339
+
340
+
341
+
342
+ # ye_g = TimestepEmbedder(self.hidden_size,tc=tc)
343
+ # CFG-free conditioning: the CFG label-dropout probability is effectively 0 during training.
+ # Instead, we apply Gaussian noise to the label embeddings and condition the model on the noise level.
+
+ # The "perturbed" sampling variant then runs CFG between two conditional passes that differ only in
+ # condition_amount: zeros (fully noised labels) for one pass and ones (clean labels) for the other.
+ # (Open question: how to indicate training mode here. Maybe -1?)
+
+ # Below: condition the forward pass on that noise level.
354
+
355
+ def adjust_condition_amount(train, peturbe, condition_amount):
356
+ def true_fn(_):
357
+ return jnp.ones_like(condition_amount) # peturbe is True → ones
358
+
359
+ def false_fn(_):
360
+ return jnp.zeros_like(condition_amount) # peturbe is False → zeros
361
+
362
+ def train_false_branch(_):
363
+ return jax.lax.cond(peturbe, true_fn, false_fn, operand=None)
364
+
365
+ def train_true_branch(_):
366
+ return condition_amount # leave it unchanged during training
367
+
368
+ return jax.lax.cond(train, train_true_branch, train_false_branch, operand=None)
369
+
370
+ #When perturbe is true, we return ones = no noise
371
+ #When false, return zeros = full noise.
372
+ #For NON training, we don't want to actually modify the labels, only the conditioning.
373
+ #So default during training is apply
374
+ def apply_fn(key, ye, train):
375
+ def true_branch(args):
376
+ key, ye = args
377
+ ye_new, condition_amount, key_new = apply_label_embedding_noise(key, ye)
378
+ return ye_new.astype(jnp.float32), condition_amount, key_new
379
+
380
+ def false_branch(args):
381
+ key, ye = args
382
+ ye_new, condition_amount, key_new = apply_label_embedding_noise(key, ye)
383
+ return ye.astype(jnp.float32), condition_amount, key_new
384
+
385
+ return jax.lax.cond(train, true_branch, false_branch, (key, ye))
386
+
387
+ print("train is", train)#False
388
+ print("perturbe is", perturbe)#False right now (it's getting passed properly)
389
+ print("initial ye", ye[0][0:10])
390
+ ye, condition_amount, key = apply_fn(key, ye, train)
391
+ print("new ye", ye[0][0:10])
392
+ print("condition amount", condition_amount)
393
+ condition_amount = adjust_condition_amount(train, perturbe, condition_amount)
394
+ print("adjusted", condition_amount)
395
+
396
+
397
+ ye_g = TimestepEmbedder(self.hidden_size, tc=tc)(condition_amount)
398
+
399
+ c = te + ye + dte + ye_g
400
+
401
+
402
+ activations['pos_embed'] = pos_embed
403
+ activations['time_embed'] = te
404
+ activations['dt_embed'] = dte
405
+ activations['label_embed'] = ye
406
+ activations['conditioning'] = c
407
+
408
+ print("DiT: Patch Embed of shape", x.shape, "dtype", x.dtype)
409
+ print("DiT: Conditioning of shape", c.shape, "dtype", c.dtype)
410
+ for i in range(self.depth):
411
+ x = DiTBlock(self.hidden_size, self.num_heads, tc, self.mlp_ratio, self.dropout, train)(x, c)
412
+ activations[f'dit_block_{i}'] = x
413
+ x = FinalLayer(self.patch_size, self.out_channels, self.hidden_size, tc)(x, c) # (B, num_patches, p*p*c)
414
+ activations['final_layer'] = x
415
+ # print("DiT: FinalLayer of shape", x.shape, "dtype", x.dtype)
416
+ x = jnp.reshape(x, (batch_size, num_patches_side, num_patches_side,
417
+ self.patch_size, self.patch_size, self.out_channels))
418
+ x = jnp.einsum('bhwpqc->bhpwqc', x)
419
+ x = rearrange(x, 'B H P W Q C -> B (H P) (W Q) C', H=int(num_patches_side), W=int(num_patches_side))
420
+ assert x.shape == (batch_size, input_size, input_size, self.out_channels)
421
+
422
+ t_discrete = jnp.floor(t * 256).astype(jnp.int32)
423
+ logvars = nn.Embed(256, 1, embedding_init=nn.initializers.constant(0))(t_discrete) * 100
424
+
425
+ if return_activations:
426
+ return x, logvars, activations
427
+ return x#, dte, te