Upload folder using huggingface_hub
Browse files- learned_cfg/model.py +309 -0
- learned_cfg/targets_shortcut.py +125 -0
learned_cfg/model.py
ADDED
|
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Any, Callable, Optional, Tuple, Type, Sequence, Union
|
| 3 |
+
import flax.linen as nn
|
| 4 |
+
import jax
|
| 5 |
+
import jax.numpy as jnp
|
| 6 |
+
from einops import rearrange
|
| 7 |
+
|
| 8 |
+
Array = Any
|
| 9 |
+
PRNGKey = Any
|
| 10 |
+
Shape = Tuple[int]
|
| 11 |
+
Dtype = Any
|
| 12 |
+
|
| 13 |
+
from math_utils import get_2d_sincos_pos_embed, modulate
|
| 14 |
+
from jax._src import core
|
| 15 |
+
from jax._src import dtypes
|
| 16 |
+
from jax._src.nn.initializers import _compute_fans
|
| 17 |
+
|
| 18 |
+
def xavier_uniform_pytorchlike():
|
| 19 |
+
def init(key, shape, dtype):
|
| 20 |
+
dtype = dtypes.canonicalize_dtype(dtype)
|
| 21 |
+
#named_shape = core.as_named_shape(shape)
|
| 22 |
+
if len(shape) == 2: # Dense, [in, out]
|
| 23 |
+
fan_in = shape[0]
|
| 24 |
+
fan_out = shape[1]
|
| 25 |
+
elif len(shape) == 4: # Conv, [k, k, in, out]. Assumes patch-embed style conv.
|
| 26 |
+
fan_in = shape[0] * shape[1] * shape[2]
|
| 27 |
+
fan_out = shape[3]
|
| 28 |
+
else:
|
| 29 |
+
raise ValueError(f"Invalid shape {shape}")
|
| 30 |
+
|
| 31 |
+
variance = 2 / (fan_in + fan_out)
|
| 32 |
+
scale = jnp.sqrt(3 * variance)
|
| 33 |
+
param = jax.random.uniform(key, shape, dtype, -1) * scale
|
| 34 |
+
|
| 35 |
+
return param
|
| 36 |
+
return init
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class TrainConfig:
|
| 40 |
+
def __init__(self, dtype):
|
| 41 |
+
self.dtype = dtype
|
| 42 |
+
def kern_init(self, name='default', zero=False):
|
| 43 |
+
if zero or 'bias' in name:
|
| 44 |
+
return nn.initializers.constant(0)
|
| 45 |
+
return xavier_uniform_pytorchlike()
|
| 46 |
+
def default_config(self):
|
| 47 |
+
return {
|
| 48 |
+
'kernel_init': self.kern_init(),
|
| 49 |
+
'bias_init': self.kern_init('bias', zero=True),
|
| 50 |
+
'dtype': self.dtype,
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
class TimestepEmbedder(nn.Module):
|
| 54 |
+
"""
|
| 55 |
+
Embeds scalar timesteps into vector representations.
|
| 56 |
+
"""
|
| 57 |
+
hidden_size: int
|
| 58 |
+
tc: TrainConfig
|
| 59 |
+
frequency_embedding_size: int = 256
|
| 60 |
+
|
| 61 |
+
@nn.compact
|
| 62 |
+
def __call__(self, t):
|
| 63 |
+
x = self.timestep_embedding(t)
|
| 64 |
+
x = nn.Dense(self.hidden_size, kernel_init=nn.initializers.normal(0.02),
|
| 65 |
+
bias_init=self.tc.kern_init('time_bias'), dtype=self.tc.dtype)(x)
|
| 66 |
+
x = nn.silu(x)
|
| 67 |
+
x = nn.Dense(self.hidden_size, kernel_init=nn.initializers.normal(0.02),
|
| 68 |
+
bias_init=self.tc.kern_init('time_bias'))(x)
|
| 69 |
+
return x
|
| 70 |
+
|
| 71 |
+
# t is between [0, 1].
|
| 72 |
+
def timestep_embedding(self, t, max_period=10000):
|
| 73 |
+
"""
|
| 74 |
+
Create sinusoidal timestep embeddings.
|
| 75 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
| 76 |
+
These may be fractional.
|
| 77 |
+
:param dim: the dimension of the output.
|
| 78 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
| 79 |
+
:return: an (N, D) Tensor of positional embeddings.
|
| 80 |
+
"""
|
| 81 |
+
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
| 82 |
+
t = jax.lax.convert_element_type(t, jnp.float32)
|
| 83 |
+
# t = t * max_period
|
| 84 |
+
dim = self.frequency_embedding_size
|
| 85 |
+
half = dim // 2
|
| 86 |
+
freqs = jnp.exp( -math.log(max_period) * jnp.arange(start=0, stop=half, dtype=jnp.float32) / half)
|
| 87 |
+
args = t[:, None] * freqs[None]
|
| 88 |
+
embedding = jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1)
|
| 89 |
+
embedding = embedding.astype(self.tc.dtype)
|
| 90 |
+
return embedding
|
| 91 |
+
|
| 92 |
+
class LabelEmbedder(nn.Module):
|
| 93 |
+
"""
|
| 94 |
+
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
| 95 |
+
"""
|
| 96 |
+
num_classes: int
|
| 97 |
+
hidden_size: int
|
| 98 |
+
tc: TrainConfig
|
| 99 |
+
|
| 100 |
+
@nn.compact
|
| 101 |
+
def __call__(self, labels):
|
| 102 |
+
embedding_table = nn.Embed(self.num_classes + 1, self.hidden_size,
|
| 103 |
+
embedding_init=nn.initializers.normal(0.02), dtype=self.tc.dtype)
|
| 104 |
+
embeddings = embedding_table(labels)
|
| 105 |
+
return embeddings
|
| 106 |
+
|
| 107 |
+
class PatchEmbed(nn.Module):
|
| 108 |
+
""" 2D Image to Patch Embedding """
|
| 109 |
+
patch_size: int
|
| 110 |
+
hidden_size: int
|
| 111 |
+
tc: TrainConfig
|
| 112 |
+
bias: bool = True
|
| 113 |
+
|
| 114 |
+
@nn.compact
|
| 115 |
+
def __call__(self, x):
|
| 116 |
+
B, H, W, C = x.shape
|
| 117 |
+
patch_tuple = (self.patch_size, self.patch_size)
|
| 118 |
+
num_patches = (H // self.patch_size)
|
| 119 |
+
x = nn.Conv(self.hidden_size, patch_tuple, patch_tuple, use_bias=self.bias, padding="VALID",
|
| 120 |
+
kernel_init=self.tc.kern_init('patch'), bias_init=self.tc.kern_init('patch_bias', zero=True),
|
| 121 |
+
dtype=self.tc.dtype)(x) # (B, P, P, hidden_size)
|
| 122 |
+
x = rearrange(x, 'b h w c -> b (h w) c', h=num_patches, w=num_patches)
|
| 123 |
+
return x
|
| 124 |
+
|
| 125 |
+
class MlpBlock(nn.Module):
|
| 126 |
+
"""Transformer MLP / feed-forward block."""
|
| 127 |
+
mlp_dim: int
|
| 128 |
+
tc: TrainConfig
|
| 129 |
+
out_dim: Optional[int] = None
|
| 130 |
+
dropout_rate: float = None
|
| 131 |
+
train: bool = False
|
| 132 |
+
|
| 133 |
+
@nn.compact
|
| 134 |
+
def __call__(self, inputs):
|
| 135 |
+
"""It's just an MLP, so the input shape is (batch, len, emb)."""
|
| 136 |
+
actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
|
| 137 |
+
x = nn.Dense(features=self.mlp_dim, **self.tc.default_config())(inputs)
|
| 138 |
+
x = nn.gelu(x)
|
| 139 |
+
x = nn.Dropout(rate=self.dropout_rate, deterministic=(not self.train))(x)
|
| 140 |
+
output = nn.Dense(features=actual_out_dim, **self.tc.default_config())(x)
|
| 141 |
+
output = nn.Dropout(rate=self.dropout_rate, deterministic=(not self.train))(output)
|
| 142 |
+
return output
|
| 143 |
+
|
| 144 |
+
def modulate(x, shift, scale):
|
| 145 |
+
# scale = jnp.clip(scale, -1, 1)
|
| 146 |
+
return x * (1 + scale[:, None]) + shift[:, None]
|
| 147 |
+
|
| 148 |
+
################################################################################
|
| 149 |
+
# Core DiT Model #
|
| 150 |
+
#################################################################################
|
| 151 |
+
|
| 152 |
+
class DiTBlock(nn.Module):
|
| 153 |
+
"""
|
| 154 |
+
A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
|
| 155 |
+
"""
|
| 156 |
+
hidden_size: int
|
| 157 |
+
num_heads: int
|
| 158 |
+
tc: TrainConfig
|
| 159 |
+
mlp_ratio: float = 4.0
|
| 160 |
+
dropout: float = 0.0
|
| 161 |
+
train: bool = False
|
| 162 |
+
|
| 163 |
+
# @functools.partial(jax.checkpoint, policy=jax.checkpoint_policies.nothing_saveable)
|
| 164 |
+
@nn.compact
|
| 165 |
+
def __call__(self, x, c):
|
| 166 |
+
# Calculate adaLn modulation parameters.
|
| 167 |
+
c = nn.silu(c)
|
| 168 |
+
c = nn.Dense(6 * self.hidden_size, **self.tc.default_config())(c)
|
| 169 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = jnp.split(c, 6, axis=-1)
|
| 170 |
+
|
| 171 |
+
# Attention Residual.
|
| 172 |
+
x_norm = nn.LayerNorm(use_bias=False, use_scale=False, dtype=self.tc.dtype)(x)
|
| 173 |
+
x_modulated = modulate(x_norm, shift_msa, scale_msa)
|
| 174 |
+
channels_per_head = self.hidden_size // self.num_heads
|
| 175 |
+
k = nn.Dense(self.hidden_size, **self.tc.default_config())(x_modulated)
|
| 176 |
+
q = nn.Dense(self.hidden_size, **self.tc.default_config())(x_modulated)
|
| 177 |
+
v = nn.Dense(self.hidden_size, **self.tc.default_config())(x_modulated)
|
| 178 |
+
k = jnp.reshape(k, (k.shape[0], k.shape[1], self.num_heads, channels_per_head))
|
| 179 |
+
q = jnp.reshape(q, (q.shape[0], q.shape[1], self.num_heads, channels_per_head))
|
| 180 |
+
v = jnp.reshape(v, (v.shape[0], v.shape[1], self.num_heads, channels_per_head))
|
| 181 |
+
q = q / q.shape[3] # (1/d) scaling.
|
| 182 |
+
w = jnp.einsum('bqhc,bkhc->bhqk', q, k) # [B, HW, HW, num_heads]
|
| 183 |
+
w = w.astype(jnp.float32)
|
| 184 |
+
w = nn.softmax(w, axis=-1)
|
| 185 |
+
y = jnp.einsum('bhqk,bkhc->bqhc', w, v) # [B, HW, num_heads, channels_per_head]
|
| 186 |
+
y = jnp.reshape(y, x.shape) # [B, H, W, C] (C = heads * channels_per_head)
|
| 187 |
+
attn_x = nn.Dense(self.hidden_size, **self.tc.default_config())(y)
|
| 188 |
+
x = x + (gate_msa[:, None] * attn_x)
|
| 189 |
+
|
| 190 |
+
# MLP Residual.
|
| 191 |
+
x_norm2 = nn.LayerNorm(use_bias=False, use_scale=False, dtype=self.tc.dtype)(x)
|
| 192 |
+
x_modulated2 = modulate(x_norm2, shift_mlp, scale_mlp)
|
| 193 |
+
mlp_x = MlpBlock(mlp_dim=int(self.hidden_size * self.mlp_ratio), tc=self.tc,
|
| 194 |
+
dropout_rate=self.dropout, train=self.train)(x_modulated2)
|
| 195 |
+
x = x + (gate_mlp[:, None] * mlp_x)
|
| 196 |
+
return x
|
| 197 |
+
|
| 198 |
+
class FinalLayer(nn.Module):
|
| 199 |
+
"""
|
| 200 |
+
The final layer of DiT.
|
| 201 |
+
"""
|
| 202 |
+
patch_size: int
|
| 203 |
+
out_channels: int
|
| 204 |
+
hidden_size: int
|
| 205 |
+
tc: TrainConfig
|
| 206 |
+
|
| 207 |
+
@nn.compact
|
| 208 |
+
def __call__(self, x, c):
|
| 209 |
+
c = nn.silu(c)
|
| 210 |
+
c = nn.Dense(2 * self.hidden_size, kernel_init=self.tc.kern_init(zero=True),
|
| 211 |
+
bias_init=self.tc.kern_init('bias', zero=True), dtype=self.tc.dtype)(c)
|
| 212 |
+
shift, scale = jnp.split(c, 2, axis=-1)
|
| 213 |
+
x = nn.LayerNorm(use_bias=False, use_scale=False, dtype=self.tc.dtype)(x)
|
| 214 |
+
x = modulate(x, shift, scale)
|
| 215 |
+
x = nn.Dense(self.patch_size * self.patch_size * self.out_channels,
|
| 216 |
+
kernel_init=self.tc.kern_init('final', zero=True),
|
| 217 |
+
bias_init=self.tc.kern_init('final_bias', zero=True), dtype=self.tc.dtype)(x)
|
| 218 |
+
return x
|
| 219 |
+
|
| 220 |
+
class DiT(nn.Module):
|
| 221 |
+
"""
|
| 222 |
+
Diffusion model with a Transformer backbone.
|
| 223 |
+
"""
|
| 224 |
+
patch_size: int
|
| 225 |
+
hidden_size: int
|
| 226 |
+
depth: int
|
| 227 |
+
num_heads: int
|
| 228 |
+
mlp_ratio: float
|
| 229 |
+
out_channels: int
|
| 230 |
+
class_dropout_prob: float
|
| 231 |
+
num_classes: int
|
| 232 |
+
ignore_dt: bool = False
|
| 233 |
+
dropout: float = 0.0
|
| 234 |
+
dtype: Dtype = jnp.bfloat16
|
| 235 |
+
init_cfg_scale: float = 1.5
|
| 236 |
+
|
| 237 |
+
@nn.compact
|
| 238 |
+
def __call__(self, x, t, dt, y, train=False, return_activations=False, perturbe = False):
|
| 239 |
+
# (x = (B, H, W, C) image, t = (B,) timesteps, y = (B,) class labels)
|
| 240 |
+
print("DiT: Input of shape", x.shape, "dtype", x.dtype)
|
| 241 |
+
activations = {}
|
| 242 |
+
|
| 243 |
+
#cfg weight, only if learned
|
| 244 |
+
"""cfg_weight = self.param('cfg_weight',
|
| 245 |
+
lambda rng, shape: jnp.ones([1]) * self.init_cfg_scale,
|
| 246 |
+
(1,))
|
| 247 |
+
"""
|
| 248 |
+
|
| 249 |
+
batch_size = x.shape[0]
|
| 250 |
+
input_size = x.shape[1]
|
| 251 |
+
in_channels = x.shape[-1]
|
| 252 |
+
num_patches = (input_size // self.patch_size) ** 2
|
| 253 |
+
num_patches_side = input_size // self.patch_size
|
| 254 |
+
tc = TrainConfig(dtype=self.dtype)
|
| 255 |
+
|
| 256 |
+
if self.ignore_dt:
|
| 257 |
+
dt = jnp.zeros_like(t)
|
| 258 |
+
|
| 259 |
+
# pos_embed = self.param("pos_embed", get_2d_sincos_pos_embed, self.hidden_size, num_patches)
|
| 260 |
+
# pos_embed = jax.lax.stop_gradient(pos_embed)
|
| 261 |
+
pos_embed = get_2d_sincos_pos_embed(None, self.hidden_size, num_patches)
|
| 262 |
+
x = PatchEmbed(self.patch_size, self.hidden_size, tc=tc)(x) # (B, num_patches, hidden_size)
|
| 263 |
+
print("DiT: After patch embed, shape is", x.shape, "dtype", x.dtype)
|
| 264 |
+
activations['patch_embed'] = x
|
| 265 |
+
|
| 266 |
+
#Pertube
|
| 267 |
+
#result = jnp.array(jnp.logical_not(perturbe), dtype=int)
|
| 268 |
+
#dt = dt * result#So this was effectively cond + dt 0 instead of 7. FID was like 100.
|
| 269 |
+
|
| 270 |
+
#Let's try modifying the label embedding, adding noise?
|
| 271 |
+
x = x + pos_embed
|
| 272 |
+
x = x.astype(self.dtype)
|
| 273 |
+
te = TimestepEmbedder(self.hidden_size, tc=tc)(t) # (B, hidden_size)
|
| 274 |
+
dte = TimestepEmbedder(self.hidden_size, tc=tc)(dt) # (B, hidden_size)
|
| 275 |
+
ye = LabelEmbedder(self.num_classes, self.hidden_size, tc=tc)(y) # (B, hidden_size)
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
result = jnp.array(perturbe)
|
| 279 |
+
#Create noise, multiply noise by perturbe, linear interpolation of ye with noise
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
c = te + ye + dte
|
| 283 |
+
|
| 284 |
+
activations['pos_embed'] = pos_embed
|
| 285 |
+
activations['time_embed'] = te
|
| 286 |
+
activations['dt_embed'] = dte
|
| 287 |
+
activations['label_embed'] = ye
|
| 288 |
+
activations['conditioning'] = c
|
| 289 |
+
|
| 290 |
+
print("DiT: Patch Embed of shape", x.shape, "dtype", x.dtype)
|
| 291 |
+
print("DiT: Conditioning of shape", c.shape, "dtype", c.dtype)
|
| 292 |
+
for i in range(self.depth):
|
| 293 |
+
x = DiTBlock(self.hidden_size, self.num_heads, tc, self.mlp_ratio, self.dropout, train)(x, c)
|
| 294 |
+
activations[f'dit_block_{i}'] = x
|
| 295 |
+
x = FinalLayer(self.patch_size, self.out_channels, self.hidden_size, tc)(x, c) # (B, num_patches, p*p*c)
|
| 296 |
+
activations['final_layer'] = x
|
| 297 |
+
# print("DiT: FinalLayer of shape", x.shape, "dtype", x.dtype)
|
| 298 |
+
x = jnp.reshape(x, (batch_size, num_patches_side, num_patches_side,
|
| 299 |
+
self.patch_size, self.patch_size, self.out_channels))
|
| 300 |
+
x = jnp.einsum('bhwpqc->bhpwqc', x)
|
| 301 |
+
x = rearrange(x, 'B H P W Q C -> B (H P) (W Q) C', H=int(num_patches_side), W=int(num_patches_side))
|
| 302 |
+
assert x.shape == (batch_size, input_size, input_size, self.out_channels)
|
| 303 |
+
|
| 304 |
+
t_discrete = jnp.floor(t * 256).astype(jnp.int32)
|
| 305 |
+
logvars = nn.Embed(256, 1, embedding_init=nn.initializers.constant(0))(t_discrete) * 100
|
| 306 |
+
|
| 307 |
+
if return_activations:
|
| 308 |
+
return x, logvars, activations
|
| 309 |
+
return x
|
learned_cfg/targets_shortcut.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import jax
|
| 2 |
+
import jax.numpy as jnp
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
def get_targets(FLAGS, key, train_state, images, labels, force_t=-1, force_dt=-1, cfg_scale = None):
|
| 6 |
+
label_key, time_key, noise_key = jax.random.split(key, 3)
|
| 7 |
+
info = {}
|
| 8 |
+
|
| 9 |
+
# 1) =========== Sample dt. ============
|
| 10 |
+
bootstrap_batchsize = FLAGS.batch_size // FLAGS.model['bootstrap_every']
|
| 11 |
+
log2_sections = np.log2(FLAGS.model['denoise_timesteps']).astype(np.int32)
|
| 12 |
+
if FLAGS.model['bootstrap_dt_bias'] == 0:
|
| 13 |
+
dt_base = jnp.repeat(log2_sections - 1 - jnp.arange(log2_sections), bootstrap_batchsize // log2_sections)
|
| 14 |
+
dt_base = jnp.concatenate([dt_base, jnp.zeros(bootstrap_batchsize-dt_base.shape[0],)])
|
| 15 |
+
num_dt_cfg = bootstrap_batchsize // log2_sections
|
| 16 |
+
else:
|
| 17 |
+
dt_base = jnp.repeat(log2_sections - 1 - jnp.arange(log2_sections-2), (bootstrap_batchsize // 2) // log2_sections)
|
| 18 |
+
dt_base = jnp.concatenate([dt_base, jnp.ones(bootstrap_batchsize // 4), jnp.zeros(bootstrap_batchsize // 4)])
|
| 19 |
+
dt_base = jnp.concatenate([dt_base, jnp.zeros(bootstrap_batchsize-dt_base.shape[0],)])
|
| 20 |
+
num_dt_cfg = (bootstrap_batchsize // 2) // log2_sections
|
| 21 |
+
force_dt_vec = jnp.ones(bootstrap_batchsize, dtype=jnp.float32) * force_dt
|
| 22 |
+
dt_base = jnp.where(force_dt_vec != -1, force_dt_vec, dt_base)
|
| 23 |
+
dt = 1 / (2 ** (dt_base)) # [1, 1/2, 1/4, 1/8, 1/16, 1/32]
|
| 24 |
+
dt_base_bootstrap = dt_base + 1
|
| 25 |
+
dt_bootstrap = dt / 2
|
| 26 |
+
|
| 27 |
+
# 2) =========== Sample t. ============
|
| 28 |
+
dt_sections = jnp.power(2, dt_base) # [1, 2, 4, 8, 16, 32]
|
| 29 |
+
t = jax.random.randint(time_key, (bootstrap_batchsize,), minval=0, maxval=dt_sections).astype(jnp.float32)
|
| 30 |
+
t = t / dt_sections # Between 0 and 1.
|
| 31 |
+
force_t_vec = jnp.ones(bootstrap_batchsize, dtype=jnp.float32) * force_t
|
| 32 |
+
t = jnp.where(force_t_vec != -1, force_t_vec, t)
|
| 33 |
+
t_full = t[:, None, None, None]
|
| 34 |
+
|
| 35 |
+
# 3) =========== Generate Bootstrap Targets ============
|
| 36 |
+
x_1 = images[:bootstrap_batchsize]
|
| 37 |
+
x_0 = jax.random.normal(noise_key, x_1.shape)
|
| 38 |
+
x_t = (1 - (1 - 1e-5) * t_full) * x_0 + t_full * x_1
|
| 39 |
+
bst_labels = labels[:bootstrap_batchsize]
|
| 40 |
+
call_model_fn = train_state.call_model if FLAGS.model['bootstrap_ema'] == 0 else train_state.call_model_ema
|
| 41 |
+
if not FLAGS.model['bootstrap_cfg']:
|
| 42 |
+
v_b1 = call_model_fn(x_t, t, dt_base_bootstrap, bst_labels, train=False)
|
| 43 |
+
t2 = t + dt_bootstrap
|
| 44 |
+
x_t2 = x_t + dt_bootstrap[:, None, None, None] * v_b1
|
| 45 |
+
x_t2 = jnp.clip(x_t2, -4, 4)
|
| 46 |
+
v_b2 = call_model_fn(x_t2, t2, dt_base_bootstrap, bst_labels, train=False)
|
| 47 |
+
v_target = (v_b1 + v_b2) / 2
|
| 48 |
+
else:
|
| 49 |
+
x_t_extra = jnp.concatenate([x_t, x_t[:num_dt_cfg]], axis=0)
|
| 50 |
+
t_extra = jnp.concatenate([t, t[:num_dt_cfg]], axis=0)
|
| 51 |
+
dt_base_extra = jnp.concatenate([dt_base_bootstrap, dt_base_bootstrap[:num_dt_cfg]], axis=0)
|
| 52 |
+
labels_extra = jnp.concatenate([bst_labels, jnp.ones(num_dt_cfg, dtype=jnp.int32) * FLAGS.model['num_classes']], axis=0)
|
| 53 |
+
v_b1_raw = call_model_fn(x_t_extra, t_extra, dt_base_extra, labels_extra, train=False)
|
| 54 |
+
v_b_cond = v_b1_raw[:x_1.shape[0]]
|
| 55 |
+
v_b_uncond = v_b1_raw[x_1.shape[0]:]
|
| 56 |
+
v_cfg = v_b_uncond + cfg_scale * (v_b_cond[:num_dt_cfg] - v_b_uncond)#CFG scale is now a learned parameter
|
| 57 |
+
v_b1 = jnp.concatenate([v_cfg, v_b_cond[num_dt_cfg:]], axis=0)
|
| 58 |
+
|
| 59 |
+
t2 = t + dt_bootstrap
|
| 60 |
+
x_t2 = x_t + dt_bootstrap[:, None, None, None] * v_b1
|
| 61 |
+
x_t2 = jnp.clip(x_t2, -4, 4)
|
| 62 |
+
x_t2_extra = jnp.concatenate([x_t2, x_t2[:num_dt_cfg]], axis=0)
|
| 63 |
+
t2_extra = jnp.concatenate([t2, t2[:num_dt_cfg]], axis=0)
|
| 64 |
+
v_b2_raw = call_model_fn(x_t2_extra, t2_extra, dt_base_extra, labels_extra, train=False)
|
| 65 |
+
v_b2_cond = v_b2_raw[:x_1.shape[0]]
|
| 66 |
+
v_b2_uncond = v_b2_raw[x_1.shape[0]:]
|
| 67 |
+
if False:#Not doing learned cfg scale right now
|
| 68 |
+
pass
|
| 69 |
+
v_b2_cfg = v_b2_uncond + cfg_scale * (v_b2_cond[:num_dt_cfg] - v_b2_uncond)#cfg scale is once again a learned p
|
| 70 |
+
|
| 71 |
+
v_b2 = jnp.concatenate([v_b2_cfg, v_b2_cond[num_dt_cfg:]], axis=0)
|
| 72 |
+
v_target = (v_b1 + v_b2) / 2
|
| 73 |
+
|
| 74 |
+
v_target = jnp.clip(v_target, -4, 4)
|
| 75 |
+
bst_v = v_target
|
| 76 |
+
bst_dt = dt_base
|
| 77 |
+
bst_t = t
|
| 78 |
+
bst_xt = x_t
|
| 79 |
+
bst_l = bst_labels
|
| 80 |
+
|
| 81 |
+
# 4) =========== Generate Flow-Matching Targets ============
|
| 82 |
+
|
| 83 |
+
labels_dropout = jax.random.bernoulli(label_key, FLAGS.model['class_dropout_prob'], (labels.shape[0],))
|
| 84 |
+
labels_dropped = jnp.where(labels_dropout, FLAGS.model['num_classes'], labels)
|
| 85 |
+
info['dropped_ratio'] = jnp.mean(labels_dropped == FLAGS.model['num_classes'])
|
| 86 |
+
|
| 87 |
+
# Sample t.
|
| 88 |
+
t = jax.random.randint(time_key, (images.shape[0],), minval=0, maxval=FLAGS.model['denoise_timesteps']).astype(jnp.float32)
|
| 89 |
+
t /= FLAGS.model['denoise_timesteps']
|
| 90 |
+
|
| 91 |
+
do_logit = True
|
| 92 |
+
if do_logit:
|
| 93 |
+
#Despite the fact that this actually violates our normal flow timesteps, whatever.
|
| 94 |
+
t = jax.random.normal(time_key, (images.shape[0],)).astype(jnp.float32)
|
| 95 |
+
|
| 96 |
+
t = 1/ (1 + jnp.exp(-t))
|
| 97 |
+
t = jnp.round(t * FLAGS.model["denoise_timesteps"])/FLAGS.model["denoise_timesteps"]
|
| 98 |
+
|
| 99 |
+
force_t_vec = jnp.ones(images.shape[0], dtype=jnp.float32) * force_t
|
| 100 |
+
t = jnp.where(force_t_vec != -1, force_t_vec, t) # If force_t is not -1, then use force_t.
|
| 101 |
+
t_full = t[:, None, None, None] # [batch, 1, 1, 1]
|
| 102 |
+
|
| 103 |
+
# Sample flow pairs x_t, v_t.
|
| 104 |
+
x_0 = jax.random.normal(noise_key, images.shape)
|
| 105 |
+
x_1 = images
|
| 106 |
+
x_t = x_t = (1 - (1 - 1e-5) * t_full) * x_0 + t_full * x_1
|
| 107 |
+
v_t = v_t = x_1 - (1 - 1e-5) * x_0
|
| 108 |
+
dt_flow = np.log2(FLAGS.model['denoise_timesteps']).astype(jnp.int32)
|
| 109 |
+
dt_base = jnp.ones(images.shape[0], dtype=jnp.int32) * dt_flow
|
| 110 |
+
|
| 111 |
+
# ==== 5) Merge Flow+Bootstrap ====
|
| 112 |
+
bst_size = FLAGS.batch_size // FLAGS.model['bootstrap_every']
|
| 113 |
+
bst_size_data = FLAGS.batch_size - bst_size
|
| 114 |
+
x_t = jnp.concatenate([bst_xt, x_t[:bst_size_data]], axis=0)
|
| 115 |
+
t = jnp.concatenate([bst_t, t[:bst_size_data]], axis=0)
|
| 116 |
+
dt_base = jnp.concatenate([bst_dt, dt_base[:bst_size_data]], axis=0)
|
| 117 |
+
v_t = jnp.concatenate([bst_v, v_t[:bst_size_data]], axis=0)
|
| 118 |
+
labels_dropped = jnp.concatenate([bst_l, labels_dropped[:bst_size_data]], axis=0)
|
| 119 |
+
info['bootstrap_ratio'] = jnp.mean(dt_base != dt_flow)
|
| 120 |
+
|
| 121 |
+
info['v_magnitude_bootstrap'] = jnp.sqrt(jnp.mean(jnp.square(bst_v)))
|
| 122 |
+
info['v_magnitude_b1'] = jnp.sqrt(jnp.mean(jnp.square(v_b1)))
|
| 123 |
+
info['v_magnitude_b2'] = jnp.sqrt(jnp.mean(jnp.square(v_b2)))
|
| 124 |
+
|
| 125 |
+
return x_t, v_t, t, dt_base, labels_dropped, info
|