KublaiKhan1 committed on
Commit 7ecf49d · verified · 1 Parent(s): 51e86ae

Upload folder using huggingface_hub

learned_cfg/model.py ADDED
@@ -0,0 +1,309 @@
+ import math
+ from typing import Any, Optional, Tuple
+ import flax.linen as nn
+ import jax
+ import jax.numpy as jnp
+ from einops import rearrange
+
+ Array = Any
+ PRNGKey = Any
+ Shape = Tuple[int, ...]
+ Dtype = Any
+
+ from math_utils import get_2d_sincos_pos_embed
+ from jax._src import dtypes
+
+ def xavier_uniform_pytorchlike():
+     def init(key, shape, dtype):
+         dtype = dtypes.canonicalize_dtype(dtype)
+         if len(shape) == 2:  # Dense, [in, out].
+             fan_in = shape[0]
+             fan_out = shape[1]
+         elif len(shape) == 4:  # Conv, [k, k, in, out]. Assumes patch-embed style conv.
+             fan_in = shape[0] * shape[1] * shape[2]
+             fan_out = shape[3]
+         else:
+             raise ValueError(f"Invalid shape {shape}")
+
+         variance = 2 / (fan_in + fan_out)
+         scale = jnp.sqrt(3 * variance)
+         param = jax.random.uniform(key, shape, dtype, -1) * scale  # Uniform on [-scale, scale).
+         return param
+     return init
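+ # For a 768x768 Dense kernel, for instance, variance = 2/1536 and the uniform
+ # bound is sqrt(3 * 2/1536) = 0.0625, matching torch.nn.init.xavier_uniform_.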
+
+
+ class TrainConfig:
+     def __init__(self, dtype):
+         self.dtype = dtype
+
+     def kern_init(self, name='default', zero=False):
+         if zero or 'bias' in name:
+             return nn.initializers.constant(0)
+         return xavier_uniform_pytorchlike()
+
+     def default_config(self):
+         return {
+             'kernel_init': self.kern_init(),
+             'bias_init': self.kern_init('bias', zero=True),
+             'dtype': self.dtype,
+         }
+
+ class TimestepEmbedder(nn.Module):
+     """
+     Embeds scalar timesteps into vector representations.
+     """
+     hidden_size: int
+     tc: TrainConfig
+     frequency_embedding_size: int = 256
+
+     @nn.compact
+     def __call__(self, t):
+         x = self.timestep_embedding(t)
+         x = nn.Dense(self.hidden_size, kernel_init=nn.initializers.normal(0.02),
+                      bias_init=self.tc.kern_init('time_bias'), dtype=self.tc.dtype)(x)
+         x = nn.silu(x)
+         x = nn.Dense(self.hidden_size, kernel_init=nn.initializers.normal(0.02),
+                      bias_init=self.tc.kern_init('time_bias'))(x)
+         return x
+
+     # t is in [0, 1].
+     def timestep_embedding(self, t, max_period=10000):
+         """
+         Create sinusoidal timestep embeddings.
+         :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
+         :param max_period: controls the minimum frequency of the embeddings.
+         :return: an (N, frequency_embedding_size) Tensor of positional embeddings.
+         """
+         # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+         t = jax.lax.convert_element_type(t, jnp.float32)
+         dim = self.frequency_embedding_size
+         half = dim // 2
+         freqs = jnp.exp(-math.log(max_period) * jnp.arange(start=0, stop=half, dtype=jnp.float32) / half)
+         args = t[:, None] * freqs[None]
+         embedding = jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1)
+         embedding = embedding.astype(self.tc.dtype)
+         return embedding
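+     # The frequencies fall geometrically from 1 down to 1/max_period, so each
+     # scalar t is encoded as half = dim // 2 cos/sin pairs at distinct scales.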
+
+ class LabelEmbedder(nn.Module):
+     """
+     Embeds class labels into vector representations. The table has one extra row,
+     used as the null label when labels are dropped for classifier-free guidance.
+     """
+     num_classes: int
+     hidden_size: int
+     tc: TrainConfig
+
+     @nn.compact
+     def __call__(self, labels):
+         embedding_table = nn.Embed(self.num_classes + 1, self.hidden_size,
+                                    embedding_init=nn.initializers.normal(0.02), dtype=self.tc.dtype)
+         embeddings = embedding_table(labels)
+         return embeddings
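+ # Label dropout itself happens in the data path (see targets_shortcut.py), which
+ # replaces dropped labels with index num_classes before calling this embedder.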
+
+ class PatchEmbed(nn.Module):
+     """2D Image to Patch Embedding."""
+     patch_size: int
+     hidden_size: int
+     tc: TrainConfig
+     bias: bool = True
+
+     @nn.compact
+     def __call__(self, x):
+         B, H, W, C = x.shape
+         patch_tuple = (self.patch_size, self.patch_size)
+         patches_per_side = H // self.patch_size
+         x = nn.Conv(self.hidden_size, patch_tuple, patch_tuple, use_bias=self.bias, padding="VALID",
+                     kernel_init=self.tc.kern_init('patch'), bias_init=self.tc.kern_init('patch_bias', zero=True),
+                     dtype=self.tc.dtype)(x)  # (B, P, P, hidden_size)
+         x = rearrange(x, 'b h w c -> b (h w) c', h=patches_per_side, w=patches_per_side)
+         return x
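+ # Shape check: a 32x32 input with patch_size=2 yields 16 patches per side,
+ # i.e. a (B, 256, hidden_size) token sequence.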
+
+ class MlpBlock(nn.Module):
+     """Transformer MLP / feed-forward block."""
+     mlp_dim: int
+     tc: TrainConfig
+     out_dim: Optional[int] = None
+     dropout_rate: float = 0.0
+     train: bool = False
+
+     @nn.compact
+     def __call__(self, inputs):
+         """It's just an MLP, so the input shape is (batch, len, emb)."""
+         actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
+         x = nn.Dense(features=self.mlp_dim, **self.tc.default_config())(inputs)
+         x = nn.gelu(x)
+         x = nn.Dropout(rate=self.dropout_rate, deterministic=(not self.train))(x)
+         output = nn.Dense(features=actual_out_dim, **self.tc.default_config())(x)
+         output = nn.Dropout(rate=self.dropout_rate, deterministic=(not self.train))(output)
+         return output
+
+ def modulate(x, shift, scale):
+     return x * (1 + scale[:, None]) + shift[:, None]
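+ # x is (B, T, D) while shift/scale are (B, D); the [:, None] adds a token axis
+ # so one (shift, scale) pair per example modulates every token.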
+
+ ################################################################################
+ #                              Core DiT Model                                  #
+ ################################################################################
+
+ class DiTBlock(nn.Module):
+     """
+     A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
+     """
+     hidden_size: int
+     num_heads: int
+     tc: TrainConfig
+     mlp_ratio: float = 4.0
+     dropout: float = 0.0
+     train: bool = False
+
+     # @functools.partial(jax.checkpoint, policy=jax.checkpoint_policies.nothing_saveable)
+     @nn.compact
+     def __call__(self, x, c):
+         # Calculate adaLN modulation parameters.
+         c = nn.silu(c)
+         c = nn.Dense(6 * self.hidden_size, **self.tc.default_config())(c)
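+         # Note: DiT's adaLN-Zero zero-initializes this projection so each block
+         # starts as the identity; here the default Xavier init is used instead.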
+         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = jnp.split(c, 6, axis=-1)
+
+         # Attention residual.
+         x_norm = nn.LayerNorm(use_bias=False, use_scale=False, dtype=self.tc.dtype)(x)
+         x_modulated = modulate(x_norm, shift_msa, scale_msa)
+         channels_per_head = self.hidden_size // self.num_heads
+         k = nn.Dense(self.hidden_size, **self.tc.default_config())(x_modulated)
+         q = nn.Dense(self.hidden_size, **self.tc.default_config())(x_modulated)
+         v = nn.Dense(self.hidden_size, **self.tc.default_config())(x_modulated)
+         k = jnp.reshape(k, (k.shape[0], k.shape[1], self.num_heads, channels_per_head))
+         q = jnp.reshape(q, (q.shape[0], q.shape[1], self.num_heads, channels_per_head))
+         v = jnp.reshape(v, (v.shape[0], v.shape[1], self.num_heads, channels_per_head))
+         q = q / q.shape[3]  # (1/d) scaling; note this is 1/d, not the usual 1/sqrt(d).
+         w = jnp.einsum('bqhc,bkhc->bhqk', q, k)  # [B, num_heads, HW, HW]
+         w = w.astype(jnp.float32)
+         w = nn.softmax(w, axis=-1)
+         y = jnp.einsum('bhqk,bkhc->bqhc', w, v)  # [B, HW, num_heads, channels_per_head]
+         y = jnp.reshape(y, x.shape)  # [B, HW, C] (C = num_heads * channels_per_head)
+         attn_x = nn.Dense(self.hidden_size, **self.tc.default_config())(y)
+         x = x + (gate_msa[:, None] * attn_x)
+
+         # MLP residual.
+         x_norm2 = nn.LayerNorm(use_bias=False, use_scale=False, dtype=self.tc.dtype)(x)
+         x_modulated2 = modulate(x_norm2, shift_mlp, scale_mlp)
+         mlp_x = MlpBlock(mlp_dim=int(self.hidden_size * self.mlp_ratio), tc=self.tc,
+                          dropout_rate=self.dropout, train=self.train)(x_modulated2)
+         x = x + (gate_mlp[:, None] * mlp_x)
+         return x
+
+ class FinalLayer(nn.Module):
+     """
+     The final layer of DiT.
+     """
+     patch_size: int
+     out_channels: int
+     hidden_size: int
+     tc: TrainConfig
+
+     @nn.compact
+     def __call__(self, x, c):
+         c = nn.silu(c)
+         c = nn.Dense(2 * self.hidden_size, kernel_init=self.tc.kern_init(zero=True),
+                      bias_init=self.tc.kern_init('bias', zero=True), dtype=self.tc.dtype)(c)
+         shift, scale = jnp.split(c, 2, axis=-1)
+         x = nn.LayerNorm(use_bias=False, use_scale=False, dtype=self.tc.dtype)(x)
+         x = modulate(x, shift, scale)
+         x = nn.Dense(self.patch_size * self.patch_size * self.out_channels,
+                      kernel_init=self.tc.kern_init('final', zero=True),
+                      bias_init=self.tc.kern_init('final_bias', zero=True), dtype=self.tc.dtype)(x)
+         return x
+
+ class DiT(nn.Module):
+     """
+     Diffusion model with a Transformer backbone.
+     """
+     patch_size: int
+     hidden_size: int
+     depth: int
+     num_heads: int
+     mlp_ratio: float
+     out_channels: int
+     class_dropout_prob: float
+     num_classes: int
+     ignore_dt: bool = False
+     dropout: float = 0.0
+     dtype: Dtype = jnp.bfloat16
+     init_cfg_scale: float = 1.5
+
+     @nn.compact
+     def __call__(self, x, t, dt, y, train=False, return_activations=False, perturbe=False):
+         # (x = (B, H, W, C) image, t = (B,) timesteps, dt = (B,) step sizes, y = (B,) class labels)
+         print("DiT: Input of shape", x.shape, "dtype", x.dtype)
+         activations = {}
+
+         # CFG weight, only if learned:
+         # cfg_weight = self.param('cfg_weight',
+         #                         lambda rng, shape: jnp.ones([1]) * self.init_cfg_scale,
+         #                         (1,))
+
+         batch_size = x.shape[0]
+         input_size = x.shape[1]
+         in_channels = x.shape[-1]
+         num_patches = (input_size // self.patch_size) ** 2
+         num_patches_side = input_size // self.patch_size
+         tc = TrainConfig(dtype=self.dtype)
+
+         if self.ignore_dt:
+             dt = jnp.zeros_like(t)
+
+         # Alternative: a learned (but gradient-stopped) positional embedding parameter:
+         # pos_embed = self.param("pos_embed", get_2d_sincos_pos_embed, self.hidden_size, num_patches)
+         # pos_embed = jax.lax.stop_gradient(pos_embed)
+         pos_embed = get_2d_sincos_pos_embed(None, self.hidden_size, num_patches)
+         x = PatchEmbed(self.patch_size, self.hidden_size, tc=tc)(x)  # (B, num_patches, hidden_size)
+         print("DiT: After patch embed, shape is", x.shape, "dtype", x.dtype)
+         activations['patch_embed'] = x
+
+         # Earlier perturbation experiment:
+         # result = jnp.array(jnp.logical_not(perturbe), dtype=int)
+         # dt = dt * result  # Effectively conditioned with dt 0 instead of 7; FID was ~100.
+         # Next idea: modify the label embedding by adding noise.
+         x = x + pos_embed
+         x = x.astype(self.dtype)
+         te = TimestepEmbedder(self.hidden_size, tc=tc)(t)    # (B, hidden_size)
+         dte = TimestepEmbedder(self.hidden_size, tc=tc)(dt)  # (B, hidden_size)
+         ye = LabelEmbedder(self.num_classes, self.hidden_size, tc=tc)(y)  # (B, hidden_size)
+
+         result = jnp.array(perturbe)  # Currently unused.
+         # Create noise, multiply it by perturbe, and linearly interpolate ye with the noise.
+
+         c = te + ye + dte
+
+         activations['pos_embed'] = pos_embed
+         activations['time_embed'] = te
+         activations['dt_embed'] = dte
+         activations['label_embed'] = ye
+         activations['conditioning'] = c
+
+         print("DiT: Patch Embed of shape", x.shape, "dtype", x.dtype)
+         print("DiT: Conditioning of shape", c.shape, "dtype", c.dtype)
+         for i in range(self.depth):
+             x = DiTBlock(self.hidden_size, self.num_heads, tc, self.mlp_ratio, self.dropout, train)(x, c)
+             activations[f'dit_block_{i}'] = x
+         x = FinalLayer(self.patch_size, self.out_channels, self.hidden_size, tc)(x, c)  # (B, num_patches, p*p*c)
+         activations['final_layer'] = x
+         x = jnp.reshape(x, (batch_size, num_patches_side, num_patches_side,
+                             self.patch_size, self.patch_size, self.out_channels))
+         x = jnp.einsum('bhwpqc->bhpwqc', x)
+         x = rearrange(x, 'B H P W Q C -> B (H P) (W Q) C', H=int(num_patches_side), W=int(num_patches_side))
+         assert x.shape == (batch_size, input_size, input_size, self.out_channels)
+
+         t_discrete = jnp.clip(jnp.floor(t * 256).astype(jnp.int32), 0, 255)  # Clip guards t == 1.0.
+         logvars = nn.Embed(256, 1, embedding_init=nn.initializers.constant(0))(t_discrete) * 100
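+         # logvars is a learned per-timestep log-variance over 256 discrete t bins;
+         # the x100 factor presumably gives the zero-initialized table a larger
+         # effective step size.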
+
+         if return_activations:
+             return x, logvars, activations
+         return x
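+
+ # A minimal smoke-test sketch (hypothetical sizes; assumes math_utils is importable):
+ # model = DiT(patch_size=2, hidden_size=64, depth=2, num_heads=4, mlp_ratio=4.0,
+ #             out_channels=4, class_dropout_prob=0.1, num_classes=10)
+ # x = jnp.zeros((2, 8, 8, 4)); t = dt = jnp.zeros((2,)); y = jnp.zeros((2,), jnp.int32)
+ # params = model.init(jax.random.PRNGKey(0), x, t, dt, y)
+ # out = model.apply(params, x, t, dt, y)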
learned_cfg/targets_shortcut.py ADDED
@@ -0,0 +1,125 @@
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+
+ def get_targets(FLAGS, key, train_state, images, labels, force_t=-1, force_dt=-1, cfg_scale=None):
+     label_key, time_key, noise_key = jax.random.split(key, 3)
+     info = {}
+
+     # 1) =========== Sample dt. ============
+     bootstrap_batchsize = FLAGS.batch_size // FLAGS.model['bootstrap_every']
+     log2_sections = np.log2(FLAGS.model['denoise_timesteps']).astype(np.int32)
+     if FLAGS.model['bootstrap_dt_bias'] == 0:
+         dt_base = jnp.repeat(log2_sections - 1 - jnp.arange(log2_sections), bootstrap_batchsize // log2_sections)
+         dt_base = jnp.concatenate([dt_base, jnp.zeros(bootstrap_batchsize - dt_base.shape[0])])
+         num_dt_cfg = bootstrap_batchsize // log2_sections
+     else:
+         dt_base = jnp.repeat(log2_sections - 1 - jnp.arange(log2_sections - 2), (bootstrap_batchsize // 2) // log2_sections)
+         dt_base = jnp.concatenate([dt_base, jnp.ones(bootstrap_batchsize // 4), jnp.zeros(bootstrap_batchsize // 4)])
+         dt_base = jnp.concatenate([dt_base, jnp.zeros(bootstrap_batchsize - dt_base.shape[0])])
+         num_dt_cfg = (bootstrap_batchsize // 2) // log2_sections
+     force_dt_vec = jnp.ones(bootstrap_batchsize, dtype=jnp.float32) * force_dt
+     dt_base = jnp.where(force_dt_vec != -1, force_dt_vec, dt_base)
+     dt = 1 / (2 ** dt_base)  # [1, 1/2, 1/4, 1/8, 1/16, 1/32]
+     dt_base_bootstrap = dt_base + 1
+     dt_bootstrap = dt / 2
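+     # For example, with denoise_timesteps=128 we get log2_sections=7, so dt_base
+     # cycles through [6, 5, 4, 3, 2, 1, 0] and dt through [1/64, ..., 1]; the
+     # bootstrap target is built from two steps of half that size (dt_bootstrap).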
+
+     # 2) =========== Sample t. ============
+     dt_sections = jnp.power(2, dt_base)  # [1, 2, 4, 8, 16, 32]
+     t = jax.random.randint(time_key, (bootstrap_batchsize,), minval=0, maxval=dt_sections).astype(jnp.float32)
+     t = t / dt_sections  # Between 0 and 1.
+     force_t_vec = jnp.ones(bootstrap_batchsize, dtype=jnp.float32) * force_t
+     t = jnp.where(force_t_vec != -1, force_t_vec, t)
+     t_full = t[:, None, None, None]
+
+     # 3) =========== Generate Bootstrap Targets ============
+     x_1 = images[:bootstrap_batchsize]
+     x_0 = jax.random.normal(noise_key, x_1.shape)
+     x_t = (1 - (1 - 1e-5) * t_full) * x_0 + t_full * x_1
+     bst_labels = labels[:bootstrap_batchsize]
+     call_model_fn = train_state.call_model if FLAGS.model['bootstrap_ema'] == 0 else train_state.call_model_ema
+     if not FLAGS.model['bootstrap_cfg']:
+         v_b1 = call_model_fn(x_t, t, dt_base_bootstrap, bst_labels, train=False)
+         t2 = t + dt_bootstrap
+         x_t2 = x_t + dt_bootstrap[:, None, None, None] * v_b1
+         x_t2 = jnp.clip(x_t2, -4, 4)
+         v_b2 = call_model_fn(x_t2, t2, dt_base_bootstrap, bst_labels, train=False)
+         v_target = (v_b1 + v_b2) / 2
+     else:
+         x_t_extra = jnp.concatenate([x_t, x_t[:num_dt_cfg]], axis=0)
+         t_extra = jnp.concatenate([t, t[:num_dt_cfg]], axis=0)
+         dt_base_extra = jnp.concatenate([dt_base_bootstrap, dt_base_bootstrap[:num_dt_cfg]], axis=0)
+         labels_extra = jnp.concatenate([bst_labels, jnp.ones(num_dt_cfg, dtype=jnp.int32) * FLAGS.model['num_classes']], axis=0)
+         v_b1_raw = call_model_fn(x_t_extra, t_extra, dt_base_extra, labels_extra, train=False)
+         v_b_cond = v_b1_raw[:x_1.shape[0]]
+         v_b_uncond = v_b1_raw[x_1.shape[0]:]
+         v_cfg = v_b_uncond + cfg_scale * (v_b_cond[:num_dt_cfg] - v_b_uncond)  # CFG scale is now a learned parameter.
+         v_b1 = jnp.concatenate([v_cfg, v_b_cond[num_dt_cfg:]], axis=0)
+
+         t2 = t + dt_bootstrap
+         x_t2 = x_t + dt_bootstrap[:, None, None, None] * v_b1
+         x_t2 = jnp.clip(x_t2, -4, 4)
+         x_t2_extra = jnp.concatenate([x_t2, x_t2[:num_dt_cfg]], axis=0)
+         t2_extra = jnp.concatenate([t2, t2[:num_dt_cfg]], axis=0)
+         v_b2_raw = call_model_fn(x_t2_extra, t2_extra, dt_base_extra, labels_extra, train=False)
+         v_b2_cond = v_b2_raw[:x_1.shape[0]]
+         v_b2_uncond = v_b2_raw[x_1.shape[0]:]
+         v_b2_cfg = v_b2_uncond + cfg_scale * (v_b2_cond[:num_dt_cfg] - v_b2_uncond)  # CFG scale is once again a learned parameter.
+         v_b2 = jnp.concatenate([v_b2_cfg, v_b2_cond[num_dt_cfg:]], axis=0)
+         v_target = (v_b1 + v_b2) / 2
+
+     v_target = jnp.clip(v_target, -4, 4)
+     bst_v = v_target
+     bst_dt = dt_base
+     bst_t = t
+     bst_xt = x_t
+     bst_l = bst_labels
+
+     # 4) =========== Generate Flow-Matching Targets ============
+     labels_dropout = jax.random.bernoulli(label_key, FLAGS.model['class_dropout_prob'], (labels.shape[0],))
+     labels_dropped = jnp.where(labels_dropout, FLAGS.model['num_classes'], labels)
+     info['dropped_ratio'] = jnp.mean(labels_dropped == FLAGS.model['num_classes'])
+
+     # Sample t.
+     t = jax.random.randint(time_key, (images.shape[0],), minval=0, maxval=FLAGS.model['denoise_timesteps']).astype(jnp.float32)
+     t /= FLAGS.model['denoise_timesteps']
+
+     do_logit = True
+     if do_logit:
+         # Logit-normal sampling; this deviates from the uniform grid of flow
+         # timesteps above, but is used anyway.
+         t = jax.random.normal(time_key, (images.shape[0],)).astype(jnp.float32)
+         t = jax.nn.sigmoid(t)
+         t = jnp.round(t * FLAGS.model["denoise_timesteps"]) / FLAGS.model["denoise_timesteps"]
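+     # With t = sigmoid(z), z ~ N(0, 1), the timesteps follow a logit-normal
+     # distribution concentrated around t = 0.5 rather than uniform on [0, 1].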
+
+     force_t_vec = jnp.ones(images.shape[0], dtype=jnp.float32) * force_t
+     t = jnp.where(force_t_vec != -1, force_t_vec, t)  # If force_t is not -1, then use force_t.
+     t_full = t[:, None, None, None]  # [batch, 1, 1, 1]
+
+     # Sample flow pairs x_t, v_t.
+     x_0 = jax.random.normal(noise_key, images.shape)
+     x_1 = images
+     x_t = (1 - (1 - 1e-5) * t_full) * x_0 + t_full * x_1
+     v_t = x_1 - (1 - 1e-5) * x_0
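+     # v_t is the time-derivative of the interpolation above:
+     # d/dt [(1 - (1 - 1e-5) t) x_0 + t x_1] = x_1 - (1 - 1e-5) x_0.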
+     dt_flow = np.log2(FLAGS.model['denoise_timesteps']).astype(jnp.int32)
+     dt_base = jnp.ones(images.shape[0], dtype=jnp.int32) * dt_flow
+
+     # ==== 5) Merge Flow + Bootstrap ====
+     bst_size = FLAGS.batch_size // FLAGS.model['bootstrap_every']
+     bst_size_data = FLAGS.batch_size - bst_size
+     x_t = jnp.concatenate([bst_xt, x_t[:bst_size_data]], axis=0)
+     t = jnp.concatenate([bst_t, t[:bst_size_data]], axis=0)
+     dt_base = jnp.concatenate([bst_dt, dt_base[:bst_size_data]], axis=0)
+     v_t = jnp.concatenate([bst_v, v_t[:bst_size_data]], axis=0)
+     labels_dropped = jnp.concatenate([bst_l, labels_dropped[:bst_size_data]], axis=0)
+     info['bootstrap_ratio'] = jnp.mean(dt_base != dt_flow)
+
+     info['v_magnitude_bootstrap'] = jnp.sqrt(jnp.mean(jnp.square(bst_v)))
+     info['v_magnitude_b1'] = jnp.sqrt(jnp.mean(jnp.square(v_b1)))
+     info['v_magnitude_b2'] = jnp.sqrt(jnp.mean(jnp.square(v_b2)))
+
+     return x_t, v_t, t, dt_base, labels_dropped, info
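+
+ # Returned values: x_t (noisy inputs), v_t (velocity targets), t (timesteps),
+ # dt_base (log2 shortcut step sizes), labels_dropped (labels with CFG dropout
+ # applied), and info (diagnostic scalars).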