Shortcuts / sharpness /model.py
KublaiKhan1's picture
Upload folder using huggingface_hub
464344f verified
raw
history blame
17.3 kB
import math
from typing import Any, Callable, Optional, Tuple, Type, Sequence, Union
import flax.linen as nn
import jax
import jax.numpy as jnp
from einops import rearrange
# Loose aliases used purely as annotation labels; `Any` gives no static checking.
Array = Any
PRNGKey = Any
# NOTE(review): Tuple[int] denotes a 1-tuple; Tuple[int, ...] may have been intended — confirm.
Shape = Tuple[int]
Dtype = Any
from math_utils import get_2d_sincos_pos_embed, modulate
from jax._src import core
from jax._src import dtypes
from jax._src.nn.initializers import _compute_fans
def xavier_uniform_pytorchlike():
    """Build a Xavier/Glorot-uniform initializer for Dense and Conv kernels.

    Returns:
        A function ``init(key, shape, dtype)`` that draws values uniformly
        from ``[-b, b)`` with ``b = sqrt(6 / (fan_in + fan_out))``.

        * 2-D shapes are read as ``[in, out]``.
        * 4-D shapes are read as ``[k, k, in, out]`` with
          ``fan_in = k * k * in`` and ``fan_out = out``.
        * Any other rank raises ``ValueError``.
    """
    def init(key, shape, dtype):
        dtype = dtypes.canonicalize_dtype(dtype)
        rank = len(shape)
        if rank == 2:
            # Dense kernel laid out as [in, out].
            fan_in, fan_out = shape[0], shape[1]
        elif rank == 4:
            # Conv kernel laid out as [k, k, in, out]; assumes a patch-embed style conv.
            k_h, k_w, in_ch, out_ch = shape[0], shape[1], shape[2], shape[3]
            fan_in = k_h * k_w * in_ch
            fan_out = out_ch
        else:
            raise ValueError(f"Invalid shape {shape}")

        variance = 2 / (fan_in + fan_out)
        # A uniform distribution on [-b, b) has variance b**2 / 3, hence b = sqrt(3 * variance).
        bound = jnp.sqrt(3 * variance)
        unit_sample = jax.random.uniform(key, shape, dtype, minval=-1, maxval=1.0)
        return unit_sample * bound

    return init
class TrainConfig:
    """Bundles the compute dtype with the initializers shared by the layers in this file."""

    def __init__(self, dtype):
        self.dtype = dtype

    def kern_init(self, name='default', zero=False):
        """Pick an initializer for the parameter called ``name``.

        A constant-zero initializer is returned when ``zero`` is truthy or when
        ``name`` contains the substring ``'bias'``; every other parameter gets
        the Xavier-uniform initializer.
        """
        wants_zero = zero or 'bias' in name
        if not wants_zero:
            return xavier_uniform_pytorchlike()
        return nn.initializers.constant(0)

    def default_config(self):
        """Return keyword arguments (kernel_init, bias_init, dtype) for a standard ``nn.Dense``."""
        config = {}
        config['kernel_init'] = self.kern_init()
        config['bias_init'] = self.kern_init('bias', zero=True)
        config['dtype'] = self.dtype
        return config
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    The scalars are first expanded into fixed sinusoidal features and then
    passed through Dense -> SiLU -> Dense, each of width ``hidden_size``.
    """
    hidden_size: int
    tc: TrainConfig
    frequency_embedding_size: int = 256

    @nn.compact
    def __call__(self, t):
        features = self.timestep_embedding(t)
        hidden = nn.Dense(
            self.hidden_size,
            kernel_init=nn.initializers.normal(0.02),
            bias_init=self.tc.kern_init('time_bias'),
            dtype=self.tc.dtype,
        )(features)
        hidden = nn.silu(hidden)
        # Unlike the first projection, this layer is not given `dtype`, so it
        # runs with nn.Dense's default dtype handling.
        # NOTE(review): confirm the asymmetry is intentional.
        return nn.Dense(
            self.hidden_size,
            kernel_init=nn.initializers.normal(0.02),
            bias_init=self.tc.kern_init('time_bias'),
        )(hidden)

    # t is between [0, 1].
    def timestep_embedding(self, t, max_period=10000):
        """
        Create sinusoidal timestep embeddings.

        :param t: a 1-D array of N values, one per batch element.
                  These may be fractional.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, 2 * (frequency_embedding_size // 2)) array laid out as
                 [cos | sin], cast to ``tc.dtype``.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        t = jax.lax.convert_element_type(t, jnp.float32)
        # t = t * max_period
        half = self.frequency_embedding_size // 2
        indices = jnp.arange(start=0, stop=half, dtype=jnp.float32)
        # Geometric frequency ladder: index 0 has frequency 1, decaying towards 1 / max_period.
        log_freqs = -math.log(max_period) * indices / half
        freqs = jnp.exp(log_freqs)
        phases = t[:, None] * freqs[None]
        sinusoids = jnp.concatenate([jnp.cos(phases), jnp.sin(phases)], axis=-1)
        return sinusoids.astype(self.tc.dtype)
class LabelEmbedder(nn.Module):
    """
    Embeds integer class labels into vector representations.

    The lookup table has ``num_classes + 1`` rows of width ``hidden_size``.
    No label dropout is performed inside this module.
    NOTE(review): the extra row is presumably reserved for a null /
    unconditional class — confirm against the caller.
    """
    num_classes: int
    hidden_size: int
    tc: TrainConfig

    @nn.compact
    def __call__(self, labels):
        table_rows = self.num_classes + 1
        lookup = nn.Embed(
            table_rows,
            self.hidden_size,
            embedding_init=nn.initializers.normal(0.02),
            dtype=self.tc.dtype,
        )
        return lookup(labels)
class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding.

    A convolution whose kernel and stride both equal ``patch_size`` (VALID
    padding) turns each non-overlapping patch into one ``hidden_size`` vector.
    The resulting grid is flattened row-major into a token sequence.
    """
    patch_size: int
    hidden_size: int
    tc: TrainConfig
    bias: bool = True

    @nn.compact
    def __call__(self, x):
        """Embed ``x`` (rank 4, channels-last) into ``(batch, tokens, hidden_size)``.

        Raises:
            ValueError: if ``x`` is not rank 4.
        """
        if len(x.shape) != 4:
            raise ValueError(f"PatchEmbed expects a rank-4 (B, H, W, C) input, got shape {x.shape}")
        patch_tuple = (self.patch_size, self.patch_size)
        x = nn.Conv(self.hidden_size, patch_tuple, patch_tuple, use_bias=self.bias, padding="VALID",
                    kernel_init=self.tc.kern_init('patch'),
                    bias_init=self.tc.kern_init('patch_bias', zero=True),
                    dtype=self.tc.dtype)(x)  # (B, H // p, W // p, hidden_size)
        # No explicit h/w hints: einops infers both grid sides from the conv
        # output, so rectangular inputs flatten correctly as well as square ones.
        x = rearrange(x, 'b h w c -> b (h w) c')
        return x
class MlpBlock(nn.Module):
    """Transformer MLP / feed-forward block: Dense -> GELU -> Dropout -> Dense -> Dropout."""
    mlp_dim: int
    tc: TrainConfig
    out_dim: Optional[int] = None
    # None means "no dropout" and is treated as a rate of 0.0.
    dropout_rate: Optional[float] = None
    train: bool = False

    @nn.compact
    def __call__(self, inputs):
        """It's just an MLP, so the input shape is (batch, len, emb).

        The output width is ``out_dim`` when given, otherwise the input's last
        dimension. Dropout is only active when ``train`` is True.
        """
        actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
        # nn.Dropout cannot work with rate=None once it is non-deterministic,
        # so resolve the default to an explicit 0.0 before building the layers.
        rate = 0.0 if self.dropout_rate is None else self.dropout_rate
        deterministic = not self.train
        x = nn.Dense(features=self.mlp_dim, **self.tc.default_config())(inputs)
        x = nn.gelu(x)
        x = nn.Dropout(rate=rate, deterministic=deterministic)(x)
        output = nn.Dense(features=actual_out_dim, **self.tc.default_config())(x)
        output = nn.Dropout(rate=rate, deterministic=deterministic)(output)
        return output
def modulate(x, shift, scale):
    """Apply a per-sample affine modulation: ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` get a new axis inserted at position 1 so they
    broadcast across axis 1 of ``x``.

    Note: this definition shadows the ``modulate`` imported from
    ``math_utils`` at the top of the file.
    """
    # scale = jnp.clip(scale, -1, 1)
    gain = 1 + scale[:, None]
    offset = shift[:, None]
    return x * gain + offset
################################################################################
# Core DiT Model #
#################################################################################
class DiTBlock(nn.Module):
    """
    A DiT block: self-attention and MLP residual branches, each modulated
    (shift / scale) and gated by a projection of the conditioning vector.
    """
    hidden_size: int
    num_heads: int
    tc: TrainConfig
    mlp_ratio: float = 4.0
    dropout: float = 0.0
    train: bool = False

    # @functools.partial(jax.checkpoint, policy=jax.checkpoint_policies.nothing_saveable)
    @nn.compact
    def __call__(self, x, c):
        """Run the block on tokens ``x`` (batch, tokens, hidden) with conditioning ``c`` (batch, hidden)."""
        # One Dense over the activated conditioning yields all six modulation tensors.
        adaln = nn.Dense(6 * self.hidden_size, **self.tc.default_config())(nn.silu(c))
        (shift_msa, scale_msa, gate_msa,
         shift_mlp, scale_mlp, gate_mlp) = jnp.split(adaln, 6, axis=-1)

        # ---- Attention residual ----
        attn_norm = nn.LayerNorm(use_bias=False, use_scale=False, dtype=self.tc.dtype)(x)
        attn_in = modulate(attn_norm, shift_msa, scale_msa)
        head_dim = self.hidden_size // self.num_heads

        # Projections are created in k, q, v order.
        # NOTE(review): creation order is assumed to determine the auto-assigned
        # submodule names — confirm before reordering.
        k = nn.Dense(self.hidden_size, **self.tc.default_config())(attn_in)
        q = nn.Dense(self.hidden_size, **self.tc.default_config())(attn_in)
        v = nn.Dense(self.hidden_size, **self.tc.default_config())(attn_in)

        def split_heads(proj):
            # (batch, tokens, hidden) -> (batch, tokens, num_heads, head_dim)
            return jnp.reshape(proj, (proj.shape[0], proj.shape[1], self.num_heads, head_dim))

        k = split_heads(k)
        q = split_heads(q)
        v = split_heads(v)

        q = q / q.shape[3]  # (1/d) scaling: divides by head_dim itself, not its square root.
        scores = jnp.einsum('bqhc,bkhc->bhqk', q, k)  # [B, num_heads, tokens, tokens]
        scores = scores.astype(jnp.float32)  # softmax is evaluated in float32
        probs = nn.softmax(scores, axis=-1)
        mixed = jnp.einsum('bhqk,bkhc->bqhc', probs, v)  # [B, tokens, num_heads, head_dim]
        merged = jnp.reshape(mixed, x.shape)  # heads folded back into the hidden axis
        attn_out = nn.Dense(self.hidden_size, **self.tc.default_config())(merged)
        x = x + (gate_msa[:, None] * attn_out)

        # ---- MLP residual ----
        mlp_norm = nn.LayerNorm(use_bias=False, use_scale=False, dtype=self.tc.dtype)(x)
        mlp_in = modulate(mlp_norm, shift_mlp, scale_mlp)
        mlp_width = int(self.hidden_size * self.mlp_ratio)
        mlp_out = MlpBlock(mlp_dim=mlp_width, tc=self.tc,
                           dropout_rate=self.dropout, train=self.train)(mlp_in)
        return x + (gate_mlp[:, None] * mlp_out)
class FinalLayer(nn.Module):
    """
    The final layer of DiT: a conditioned LayerNorm followed by a linear
    projection to ``patch_size * patch_size * out_channels`` values per token.

    Both Dense layers here use constant-zero kernels and biases at init.
    """
    patch_size: int
    out_channels: int
    hidden_size: int
    tc: TrainConfig

    @nn.compact
    def __call__(self, x, c):
        # Conditioning -> (shift, scale) for the last normalisation.
        mod_params = nn.Dense(
            2 * self.hidden_size,
            kernel_init=self.tc.kern_init(zero=True),
            bias_init=self.tc.kern_init('bias', zero=True),
            dtype=self.tc.dtype,
        )(nn.silu(c))
        shift, scale = jnp.split(mod_params, 2, axis=-1)

        normed = nn.LayerNorm(use_bias=False, use_scale=False, dtype=self.tc.dtype)(x)
        conditioned = modulate(normed, shift, scale)

        values_per_token = self.patch_size * self.patch_size * self.out_channels
        return nn.Dense(
            values_per_token,
            kernel_init=self.tc.kern_init('final', zero=True),
            bias_init=self.tc.kern_init('final_bias', zero=True),
            dtype=self.tc.dtype,
        )(conditioned)
import jax
import jax.numpy as jnp
def apply_label_embedding_noise(key, label_embeddings, noise_probs=(0.00, 0.10, 0.90)):
    """
    Randomly blends Gaussian noise into label embeddings, one decision per sample.

    Each sample independently falls into one of three modes:
        0: 100% noise   (alpha = 0, embedding replaced by noise)
        1: partial noise (alpha ~ U[0, 1), embedding * alpha + noise * (1 - alpha))
        2: no noise     (alpha = 1, embedding returned unchanged)

    Args:
        key: A JAX random key.
        label_embeddings: A JAX array of shape (batch_size, embedding_dim),
            representing the label embeddings.
        noise_probs: Three probabilities, in the mode order listed above, used
            to pick each sample's mode. Defaults to (0.00, 0.10, 0.90).

    Returns:
        A 3-tuple containing:
            - noisy_label_embeddings: The label embeddings with noise applied.
            - noise_levels: A JAX array of shape (batch_size,), holding the
              alpha used for each sample (1.0 for no noise, 0.0 for 100%
              noise, or the uniform sample for partial noise).
            - key: A fresh key split off from the input key, for further use.

    Raises:
        ValueError: if ``label_embeddings`` is not 2-D or ``noise_probs`` does
            not have exactly three entries.
    """
    if len(noise_probs) != 3:
        raise ValueError(f"noise_probs must have exactly 3 entries, got {len(noise_probs)}")
    # Unpacking also enforces a rank-2 input.
    batch_size, _ = label_embeddings.shape

    # Independent keys for the mode choice, the alphas and the noise itself.
    key, noise_type_key, alpha_key, normal_key = jax.random.split(key, 4)

    noise_type_choices = jax.random.choice(
        noise_type_key,
        a=jnp.array([0, 1, 2]),
        shape=(batch_size,),
        p=jnp.array(noise_probs),
    )

    # Candidate alphas and noise are drawn for the whole batch; the masks
    # below decide which samples actually use them.
    sampled_alphas = jax.random.uniform(alpha_key, shape=(batch_size,), minval=0.0, maxval=1.0)
    # Unit standard deviation noise.
    gaussian_noise = jax.random.normal(normal_key, shape=label_embeddings.shape)

    # Start from the "no noise" state (mode 2): embeddings untouched, alpha = 1.
    noisy_label_embeddings = label_embeddings
    noise_levels = jnp.ones(batch_size, dtype=label_embeddings.dtype)

    # Mode 0: replace the embedding entirely.
    cond_100_percent_noise = (noise_type_choices == 0)
    noisy_label_embeddings = jnp.where(
        cond_100_percent_noise[:, None],  # extra axis so the mask broadcasts over features
        gaussian_noise,
        noisy_label_embeddings,
    )
    noise_levels = jnp.where(cond_100_percent_noise, 0.0, noise_levels)

    # Mode 1: convex blend of embedding and noise.
    cond_partial_noise = (noise_type_choices == 1)
    alpha_reshaped = sampled_alphas[:, None]
    noisy_label_embeddings = jnp.where(
        cond_partial_noise[:, None],
        label_embeddings * alpha_reshaped + gaussian_noise * (1.0 - alpha_reshaped),
        noisy_label_embeddings,
    )
    noise_levels = jnp.where(cond_partial_noise, sampled_alphas, noise_levels)

    return noisy_label_embeddings, noise_levels, key
class DiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.

    The image is patch-embedded, given a sin-cos positional embedding, run
    through ``depth`` DiTBlocks conditioned on the sum of four embeddings
    (timestep, dt, label, condition amount) and projected back to image
    space by FinalLayer.
    """
    patch_size: int
    hidden_size: int
    depth: int
    num_heads: int
    mlp_ratio: float
    out_channels: int
    # Not read anywhere inside __call__ below.
    class_dropout_prob: float
    num_classes: int
    # When True, dt is replaced by zeros before being embedded.
    ignore_dt: bool = False
    dropout: float = 0.0
    dtype: Dtype = jnp.bfloat16

    @nn.compact
    def __call__(self, x, t, dt, y, train=False, return_activations=False, perturbe = True):
        """Run the model.

        Args:
            x: image batch; the final assert requires a square
                (input_size, input_size) spatial extent.
            t: (B,) timesteps.
            dt: (B,) step sizes; ignored (zeroed) when ``ignore_dt`` is set.
            y: (B,) class labels.
            train: when truthy, label embeddings are replaced by their noised
                version and the sampled noise levels are used as the
                condition amount.
            return_activations: when truthy, also return ``logvars`` and the
                dict of intermediate activations.
            perturbe: only consulted when ``train`` is falsy; selects an
                all-ones (True) or all-zeros (False) condition amount.

        Returns:
            ``x`` alone, or ``(x, logvars, activations)`` when
            ``return_activations`` is truthy.

        Requires an RNG stream named "label" (see ``make_rng`` below), even
        when ``train`` is falsy.
        """
        # (x = (B, H, W, C) image, t = (B,) timesteps, y = (B,) class labels)
        # NOTE(review): the print calls throughout this method look like debug output.
        print("DiT: Input of shape", x.shape, "dtype", x.dtype)
        activations = {}
        key = self.make_rng("label")
        batch_size = x.shape[0]
        input_size = x.shape[1]
        in_channels = x.shape[-1]  # not used below
        num_patches = (input_size // self.patch_size) ** 2
        num_patches_side = input_size // self.patch_size
        tc = TrainConfig(dtype=self.dtype)
        if self.ignore_dt:
            dt = jnp.zeros_like(t)
        # pos_embed = self.param("pos_embed", get_2d_sincos_pos_embed, self.hidden_size, num_patches)
        # pos_embed = jax.lax.stop_gradient(pos_embed)
        # The positional embedding is computed directly rather than stored as a parameter.
        pos_embed = get_2d_sincos_pos_embed(None, self.hidden_size, num_patches)
        x = PatchEmbed(self.patch_size, self.hidden_size, tc=tc)(x) # (B, num_patches, hidden_size)
        print("DiT: After patch embed, shape is", x.shape, "dtype", x.dtype)
        activations['patch_embed'] = x
        x = x + pos_embed
        x = x.astype(self.dtype)
        # Two separate TimestepEmbedder instances: one for t, one for dt.
        te = TimestepEmbedder(self.hidden_size, tc=tc)(t) # (B, hidden_size)
        dte = TimestepEmbedder(self.hidden_size, tc=tc)(dt) # (B, hidden_size)
        ye = LabelEmbedder(self.num_classes, self.hidden_size, tc=tc)(y) # (B, hidden_size)
        # ye_g = TimestepEmbedder(self.hidden_size,tc=tc)

        # CFG-free conditioning:
        # Gaussian noise is applied to the label embeddings during training, and the
        # model is additionally conditioned on the amount of label signal kept
        # ("condition amount": 1 = no noise, 0 = full noise).
        # Outside training the labels themselves are left untouched and only the
        # condition amount is set: ones when perturbe is True, zeros when it is False.
        def adjust_condition_amount(train, peturbe, condition_amount):
            # Training: keep the sampled condition amount.
            # Otherwise: overwrite it with ones (peturbe) or zeros (not peturbe).
            def true_fn(_):
                return jnp.ones_like(condition_amount) # peturbe is True → ones
            def false_fn(_):
                return jnp.zeros_like(condition_amount) # peturbe is False → zeros
            def train_false_branch(_):
                return jax.lax.cond(peturbe, true_fn, false_fn, operand=None)
            def train_true_branch(_):
                return condition_amount # leave it unchanged during training
            return jax.lax.cond(train, train_true_branch, train_false_branch, operand=None)

        # Both branches run the noise routine so that they return matching
        # structures; only the training branch keeps the noised embeddings.
        def apply_fn(key, ye, train):
            def true_branch(args):
                key, ye = args
                ye_new, condition_amount, key_new = apply_label_embedding_noise(key, ye)
                return ye_new.astype(jnp.float32), condition_amount, key_new
            def false_branch(args):
                key, ye = args
                # ye_new is computed but discarded: the original embeddings are returned.
                ye_new, condition_amount, key_new = apply_label_embedding_noise(key, ye)
                return ye.astype(jnp.float32), condition_amount, key_new
            return jax.lax.cond(train, true_branch, false_branch, (key, ye))

        print("train is", train)#False
        print("perturbe is", perturbe)#False right now (it's getting passed properly)
        print("initial ye", ye[0][0:10])
        # The returned key is not used again in this method.
        ye, condition_amount, key = apply_fn(key, ye, train)
        print("new ye", ye[0][0:10])
        print("condition amount", condition_amount)
        condition_amount = adjust_condition_amount(train, perturbe, condition_amount)
        print("adjusted", condition_amount)
        # The condition amount is embedded with a third TimestepEmbedder instance.
        ye_g = TimestepEmbedder(self.hidden_size, tc=tc)(condition_amount)
        # Conditioning vector: plain sum of the four embeddings.
        c = te + ye + dte + ye_g
        activations['pos_embed'] = pos_embed
        activations['time_embed'] = te
        activations['dt_embed'] = dte
        activations['label_embed'] = ye
        activations['conditioning'] = c
        print("DiT: Patch Embed of shape", x.shape, "dtype", x.dtype)
        print("DiT: Conditioning of shape", c.shape, "dtype", c.dtype)
        for i in range(self.depth):
            x = DiTBlock(self.hidden_size, self.num_heads, tc, self.mlp_ratio, self.dropout, train)(x, c)
            activations[f'dit_block_{i}'] = x
        x = FinalLayer(self.patch_size, self.out_channels, self.hidden_size, tc)(x, c) # (B, num_patches, p*p*c)
        activations['final_layer'] = x
        # print("DiT: FinalLayer of shape", x.shape, "dtype", x.dtype)
        # Un-patchify: tokens -> (grid_h, grid_w, p, p, c) -> interleave the patch
        # axes with the grid axes -> full-resolution image.
        x = jnp.reshape(x, (batch_size, num_patches_side, num_patches_side,
                            self.patch_size, self.patch_size, self.out_channels))
        x = jnp.einsum('bhwpqc->bhpwqc', x)
        x = rearrange(x, 'B H P W Q C -> B (H P) (W Q) C', H=int(num_patches_side), W=int(num_patches_side))
        assert x.shape == (batch_size, input_size, input_size, self.out_channels)
        # NOTE(review): t == 1.0 maps to index 256, one past the end of the
        # 256-row table below — confirm callers keep t < 1.
        t_discrete = jnp.floor(t * 256).astype(jnp.int32)
        # Zero-initialised per-timestep scalar, scaled by 100; only returned
        # when return_activations is set.
        logvars = nn.Embed(256, 1, embedding_init=nn.initializers.constant(0))(t_discrete) * 100
        if return_activations:
            return x, logvars, activations
        return x#, dte, te