# src/models.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import UNet2DModel
from transformers import ViTForImageClassification, ViTConfig
import math
from typing import Optional, List
import numpy as np
# =============================================================================
# TIME EMBEDDING (shared utility)
# =============================================================================
class TimeEmbedding(nn.Module):
def __init__(self, dim: int) -> None:
super().__init__()
self.dim = dim
def forward(self, t: torch.Tensor) -> torch.Tensor:
device = t.device
half_dim = self.dim // 2
embeddings = math.log(10000) / (half_dim - 1)
embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
embeddings = t[:, None] * embeddings[None, :]
embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
return embeddings
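# Minimal usage sketch (illustrative, not part of the original module): the
# embedding maps a batch of scalar timesteps to `dim` sinusoidal features
# (sin and cos halves concatenated).
def _demo_time_embedding() -> torch.Tensor:
    emb = TimeEmbedding(dim=64)
    t = torch.rand(8)            # 8 timesteps in [0, 1]
    out = emb(t)                 # -> [8, 64]
    assert out.shape == (8, 64)
    return out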
class DiTTimestepEmbedder(nn.Module):
def __init__(self, hidden_size, freq_dim=128, max_period=10000):
super().__init__()
self.freq_dim = freq_dim
self.max_period = max_period
self.mlp = nn.Sequential(
nn.Linear(2*freq_dim, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
def forward(self, t): # t: [B] integers (float tensor ok)
# standard "timestep_embedding" (like ADM/DiT)
half = self.freq_dim
device = t.device
# geometric frequency spectrum (standard ADM/DiT timestep embedding)
freqs = torch.exp(
-torch.arange(half, device=device).float() * np.log(self.max_period) / half
)
args = t.float()[:, None] * freqs[None] # [B, half]
emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) # [B, 2*half]
return self.mlp(emb)
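# Illustrative sketch: unlike TimeEmbedding above, this embedder builds
# 2*freq_dim sinusoidal features and projects them to hidden_size with an
# MLP. The sizes below are arbitrary and only demonstrate the shapes.
def _demo_dit_timestep_embedder() -> torch.Tensor:
    embedder = DiTTimestepEmbedder(hidden_size=256, freq_dim=128)
    t = torch.randint(0, 1000, (4,))   # integer timesteps; float tensors also work
    out = embedder(t)                  # -> [4, 256]
    assert out.shape == (4, 256)
    return out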
# =============================================================================
# OUTPUT CONVERTER (for heterogeneous objectives)
# =============================================================================
class OutputConverter(nn.Module):
def __init__(self, schedule_type: str = 'linear_interp', use_latents: bool = False, derivative_eps: float = 1e-4):
super().__init__()
from schedules import NoiseSchedule
self.schedule = NoiseSchedule(schedule_type)
self.schedule_type = schedule_type
self.use_latents = use_latents
self.derivative_eps = derivative_eps # For finite difference derivatives
# Set clamping range based on data type
# VAE latents have larger range than pixel-space images
self.clamp_range = 20.0 if use_latents else 5.0
def _get_schedule_with_derivatives(self, t: torch.Tensor):
"""
Compute schedule coefficients and their derivatives.
Essential for correct velocity computation with any schedule.
"""
# Get coefficients at current time
alpha_t, sigma_t = self.schedule.get_schedule(t)
# Compute derivatives using finite differences
h = torch.full_like(t, self.derivative_eps)
t_plus = (t + h).clamp(0.0, 1.0)
t_minus = (t - h).clamp(0.0, 1.0)
alpha_plus, sigma_plus = self.schedule.get_schedule(t_plus)
alpha_minus, sigma_minus = self.schedule.get_schedule(t_minus)
# Derivatives
dt = (t_plus - t_minus).clamp(min=1e-6)
d_alpha_dt = (alpha_plus - alpha_minus) / dt
d_sigma_dt = (sigma_plus - sigma_minus) / dt
return alpha_t, sigma_t, d_alpha_dt, d_sigma_dt
def epsilon_to_velocity(self, epsilon_pred: torch.Tensor, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
"""
Correct ε→v conversion for ANY schedule using proper derivatives.
From ODE: dx_t/dt = d(alpha_t)/dt * x_0 + d(sigma_t)/dt * ε
This is the TRUE velocity for the schedule!
"""
# Get schedule coefficients AND their derivatives
alpha_t, sigma_t, d_alpha_dt, d_sigma_dt = self._get_schedule_with_derivatives(t)
# Reshape for broadcasting
alpha_t = alpha_t.view(-1, 1, 1, 1)
sigma_t = sigma_t.view(-1, 1, 1, 1)
d_alpha_dt = d_alpha_dt.view(-1, 1, 1, 1)
d_sigma_dt = d_sigma_dt.view(-1, 1, 1, 1)
# Numerical stability: handle small alpha_t
alpha_safe = torch.clamp(alpha_t, min=0.01)
# Step 1: Recover x_0 using Tweedie's formula
x_0_pred = (x_t - sigma_t * epsilon_pred) / alpha_safe
# Step 2: Clamp x_0 to reasonable range (prevents blow-up)
# Use adaptive clamping: larger range for VAE latents, tighter for pixel space
x_0_pred = torch.clamp(x_0_pred, -self.clamp_range, self.clamp_range)
# Step 3: Compute velocity based on schedule type
if self.schedule_type == 'linear_interp':
# For linear interpolation: x_t = (1-t)*x_0 + t*ε
# Velocity is simply: v = ε - x_0
v = epsilon_pred - x_0_pred
else:
# For cosine and other schedules: use proper derivatives
# v = d(alpha_t)/dt * x_0 + d(sigma_t)/dt * ε
v = d_alpha_dt * x_0_pred + d_sigma_dt * epsilon_pred
# Adaptive velocity scaling for cosine schedule
# Derivatives vary dramatically with timestep - need adaptive dampening
if self.schedule_type == 'cosine':
t_val = t[0].item() if t.numel() > 0 else 0.5
if t_val > 0.85:
# Very high noise: derivatives are large, need dampening
scale = 0.88
elif t_val > 0.6:
# Medium-high noise: moderate dampening
scale = 0.93
else:
# Low to medium noise: slight dampening
scale = 0.96
v = v * scale
# Per-channel bias correction to prevent color drift
# The model has inherent channel bias that gets amplified by integration
# Remove per-channel mean to prevent accumulation
# Only apply to color channels (1,2,3), preserve luminance channel (0)
for c in range(1, 4):
v[:, c] = v[:, c] - v[:, c].mean()
return v
def convert(self, prediction: torch.Tensor, objective_type: str, x_t: torch.Tensor, t: torch.Tensor):
"""
Convert any prediction to velocity space.
Args:
prediction: expert output
objective_type: 'ddpm' | 'fm' | 'rf'
x_t: current noisy state
t: current timesteps
Returns:
v: velocity representation
"""
if objective_type == "ddpm":
# Proper ε→v conversion for unified integration
return self.epsilon_to_velocity(prediction, x_t, t)
elif objective_type in ["fm", "rf"]:
return prediction # Already velocity
else:
raise ValueError(f"Unknown objective type: {objective_type}")
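# Illustrative sketch of the converter's role at inference: a DDPM expert's
# epsilon prediction is mapped into velocity space so all experts can share
# one ODE integrator. Assumes the repo's `schedules` module is importable;
# shapes and values are arbitrary.
def _demo_output_converter() -> torch.Tensor:
    converter = OutputConverter(schedule_type='linear_interp', use_latents=True)
    x_t = torch.randn(2, 4, 32, 32)        # noisy VAE latents
    eps_pred = torch.randn(2, 4, 32, 32)   # epsilon prediction from a DDPM expert
    t = torch.full((2,), 0.7)              # current timestep per sample
    v = converter.convert(eps_pred, "ddpm", x_t, t)   # converted to velocity
    v_fm = converter.convert(eps_pred, "fm", x_t, t)  # FM/RF outputs pass through
    assert v.shape == x_t.shape and torch.equal(v_fm, eps_pred)
    return v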
# =============================================================================
# EXPERT MODELS
# =============================================================================
class UNetExpert(nn.Module):
"""UNet expert using diffusers"""
def __init__(self, config) -> None:
super().__init__()
# Default UNet params
default_params = {
"sample_size": config.image_size,
"in_channels": config.num_channels,
"out_channels": config.num_channels,
"layers_per_block": 2,
"block_out_channels": [64, 128, 256, 256],
"attention_head_dim": 8,
}
# Override with config params
params = {**default_params, **config.expert_params}
# Store objective type for heterogeneous training (and remove from params)
self.objective_type = params.pop("objective_type", "fm")
# Store and initialize schedule (NEW)
schedule_type = params.pop("schedule_type", "linear_interp")
from schedules import NoiseSchedule
self.schedule = NoiseSchedule(schedule_type)
self.unet = UNet2DModel(**params)
def forward(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
# Map normalized timesteps in [0, 1] to integer steps in [0, 999] for diffusers
t_scaled = (t * 999).round().long().clamp(0, 999)
return self.unet(xt, t_scaled).sample
def compute_loss(self, x0: torch.Tensor) -> torch.Tensor:
"""Unified loss computation based on objective type"""
if self.objective_type == "ddpm":
return self.ddpm_loss(x0)
elif self.objective_type == "fm":
return self.flow_matching_loss(x0)
elif self.objective_type == "rf":
return self.rectified_flow_loss(x0)
else:
raise ValueError(f"Unknown objective type: {self.objective_type}")
def ddpm_loss(self, x0: torch.Tensor) -> torch.Tensor:
"""DDPM: predict noise ε"""
batch_size = x0.shape[0]
device = x0.device
t = torch.rand(batch_size, device=device)
# Use proper schedule (NEW)
alpha_t, sigma_t = self.schedule.get_schedule(t)
noise = torch.randn_like(x0)
xt = alpha_t.view(-1, 1, 1, 1) * x0 + sigma_t.view(-1, 1, 1, 1) * noise
pred_eps = self.forward(xt, t)
return F.mse_loss(pred_eps, noise)
def rectified_flow_loss(self, x0: torch.Tensor) -> torch.Tensor:
"""Rectified Flow: predict velocity v = x_1 - x_0"""
batch_size = x0.shape[0]
device = x0.device
t = torch.rand(batch_size, device=device)
x1 = torch.randn_like(x0)
xt = (1 - t).view(-1, 1, 1, 1) * x0 + t.view(-1, 1, 1, 1) * x1
pred_v = self.forward(xt, t)
true_v = x1 - x0
return F.mse_loss(pred_v, true_v)
def flow_matching_loss(self, x0: torch.Tensor) -> torch.Tensor:
"""Flow matching loss for training"""
batch_size = x0.shape[0]
device = x0.device
# Sample random timesteps
t = torch.rand(batch_size, device=device)
# Use proper schedule (NEW)
alpha_t, sigma_t = self.schedule.get_schedule(t)
# Add noise
noise = torch.randn_like(x0)
xt = alpha_t.view(-1, 1, 1, 1) * x0 + sigma_t.view(-1, 1, 1, 1) * noise
# Predict velocity
pred_v = self.forward(xt, t)
# True velocity for the linear path x_t = (1-t)*x0 + t*noise is dx_t/dt = noise - x0
true_v = noise - x0
return F.mse_loss(pred_v, true_v)
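# Minimal training-step sketch for a single expert (illustrative): the real
# config class lives elsewhere in the repo, so a small stand-in namespace
# with just the fields read above is assumed, along with an importable
# `schedules` module.
def _demo_unet_expert_step() -> torch.Tensor:
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        image_size=32,
        num_channels=3,
        expert_params={"objective_type": "fm", "schedule_type": "linear_interp"},
    )
    expert = UNetExpert(cfg)
    x0 = torch.randn(2, 3, 32, 32)    # tiny batch of clean images
    loss = expert.compute_loss(x0)    # flow-matching MSE on the velocity target
    loss.backward()                   # gradients flow into the diffusers UNet
    return loss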
class SimpleCNNExpert(nn.Module):
"""Simple CNN expert for fast training"""
def __init__(self, config) -> None:
super().__init__()
# Default params
default_params = {
"hidden_dims": [64, 128, 256],
"time_dim": 64,
}
params = {**default_params, **config.expert_params}
# Store objective type for heterogeneous training
self.objective_type = params.get("objective_type", "fm")
# Store and initialize schedule (NEW)
schedule_type = params.get("schedule_type", "linear_interp")
from schedules import NoiseSchedule
self.schedule = NoiseSchedule(schedule_type)
self.time_embedding = TimeEmbedding(params["time_dim"])
self.target_size = config.image_size
# Simple encoder-decoder
self.encoder = self._build_encoder(config.num_channels, params["hidden_dims"])
self.decoder = self._build_decoder(params["hidden_dims"], config.num_channels)
# Time conditioning
self.time_mlp = nn.Sequential(
nn.Linear(params["time_dim"], params["hidden_dims"][-1]),
nn.SiLU(),
nn.Linear(params["hidden_dims"][-1], params["hidden_dims"][-1])
)
def _build_encoder(self, in_channels: int, hidden_dims: List[int]) -> nn.Sequential:
layers = []
prev_dim = in_channels
for dim in hidden_dims:
layers.extend([
nn.Conv2d(prev_dim, dim, 3, padding=1),
nn.GroupNorm(8, dim),
nn.SiLU(),
nn.Conv2d(dim, dim, 3, padding=1),
nn.GroupNorm(8, dim),
nn.SiLU(),
nn.MaxPool2d(2)
])
prev_dim = dim
return nn.Sequential(*layers)
def _build_decoder(self, hidden_dims: List[int], out_channels: int) -> nn.Sequential:
layers = []
reversed_dims = list(reversed(hidden_dims))
for i, dim in enumerate(reversed_dims[:-1]):
next_dim = reversed_dims[i + 1]
layers.extend([
nn.ConvTranspose2d(dim, next_dim, 4, stride=2, padding=1),
nn.GroupNorm(8, next_dim),
nn.SiLU(),
nn.Conv2d(next_dim, next_dim, 3, padding=1),
nn.GroupNorm(8, next_dim),
nn.SiLU(),
])
# Final layer
layers.append(nn.Conv2d(reversed_dims[-1], out_channels, 3, padding=1))
return nn.Sequential(*layers)
def forward(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
# Time embedding
time_emb = self.time_embedding(t)
time_features = self.time_mlp(time_emb)
# Encode
encoded = self.encoder(xt)
# Add time conditioning
time_features = time_features.view(time_features.shape[0], -1, 1, 1)
time_features = time_features.expand(-1, -1, encoded.shape[2], encoded.shape[3])
conditioned = encoded + time_features
# Decode
output = self.decoder(conditioned)
# Ensure output matches target size
output = F.interpolate(output, size=xt.shape[-2:], mode='bilinear', align_corners=False)
return output
def compute_loss(self, x0: torch.Tensor) -> torch.Tensor:
"""Unified loss computation based on objective type"""
if self.objective_type == "ddpm":
return self.ddpm_loss(x0)
elif self.objective_type == "fm":
return self.flow_matching_loss(x0)
elif self.objective_type == "rf":
return self.rectified_flow_loss(x0)
else:
raise ValueError(f"Unknown objective type: {self.objective_type}")
def ddpm_loss(self, x0: torch.Tensor) -> torch.Tensor:
"""DDPM: predict noise ε"""
batch_size = x0.shape[0]
device = x0.device
t = torch.rand(batch_size, device=device)
# Use proper schedule (NEW)
alpha_t, sigma_t = self.schedule.get_schedule(t)
noise = torch.randn_like(x0)
xt = alpha_t.view(-1, 1, 1, 1) * x0 + sigma_t.view(-1, 1, 1, 1) * noise
pred_eps = self.forward(xt, t)
# Ensure pred_eps matches noise shape
if pred_eps.shape != noise.shape:
pred_eps = F.interpolate(pred_eps, size=noise.shape[-2:], mode='bilinear', align_corners=False)
return F.mse_loss(pred_eps, noise)
def rectified_flow_loss(self, x0: torch.Tensor) -> torch.Tensor:
"""Rectified Flow: predict velocity v = x_1 - x_0"""
batch_size = x0.shape[0]
device = x0.device
t = torch.rand(batch_size, device=device)
x1 = torch.randn_like(x0)
xt = (1 - t).view(-1, 1, 1, 1) * x0 + t.view(-1, 1, 1, 1) * x1
pred_v = self.forward(xt, t)
true_v = x1 - x0
# Ensure pred_v matches true_v shape
if pred_v.shape != true_v.shape:
pred_v = F.interpolate(pred_v, size=true_v.shape[-2:], mode='bilinear', align_corners=False)
return F.mse_loss(pred_v, true_v)
def flow_matching_loss(self, x0: torch.Tensor) -> torch.Tensor:
"""Flow matching loss"""
batch_size = x0.shape[0]
device = x0.device
t = torch.rand(batch_size, device=device)
# Use proper schedule (NEW)
alpha_t, sigma_t = self.schedule.get_schedule(t)
noise = torch.randn_like(x0)
xt = alpha_t.view(-1, 1, 1, 1) * x0 + sigma_t.view(-1, 1, 1, 1) * noise
pred_v = self.forward(xt, t)
# For the linear path, the target velocity is noise - x0
true_v = noise - x0
# Ensure pred_v matches true_v shape
if pred_v.shape != true_v.shape:
pred_v = F.interpolate(pred_v, size=true_v.shape[-2:], mode='bilinear', align_corners=False)
return F.mse_loss(pred_v, true_v)
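# Shape sketch for the CNN expert (illustrative, using a stand-in config
# namespace and assuming the `schedules` module is importable): the encoder
# halves the resolution per stage and the final interpolation restores the
# input size, so predictions always match the target tensor.
def _demo_simple_cnn_expert() -> torch.Tensor:
    from types import SimpleNamespace
    cfg = SimpleNamespace(image_size=32, num_channels=3,
                          expert_params={"objective_type": "rf"})
    expert = SimpleCNNExpert(cfg)
    x0 = torch.randn(2, 3, 32, 32)
    out = expert(x0, torch.rand(2))   # -> [2, 3, 32, 32] after interpolation
    assert out.shape == x0.shape
    return out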
# Helper function from original DiT
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
# Fixed 2D sin-cos positional embedding (from the original DiT implementation)
def get_2d_sincos_pos_embed(embed_dim, grid_size):
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h)
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
assert embed_dim % 2 == 0
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
emb = np.concatenate([emb_h, emb_w], axis=1)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.
omega = 1. / 10000**omega
pos = pos.reshape(-1)
out = np.einsum('m,d->md', pos, omega)
emb_sin = np.sin(out)
emb_cos = np.cos(out)
emb = np.concatenate([emb_sin, emb_cos], axis=1)
return emb
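# Quick shape sketch (illustrative values): for a 16x16 patch grid and a
# 768-dim model, the helper returns one fixed sin-cos vector per patch.
def _demo_pos_embed() -> np.ndarray:
    pe = get_2d_sincos_pos_embed(embed_dim=768, grid_size=16)
    assert pe.shape == (16 * 16, 768)   # [num_patches, embed_dim]
    return pe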
# Timestep Embedder
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
super().__init__()
self.frequency_embedding_size = frequency_embedding_size
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(0, half, dtype=torch.float32, device=t.device) / half)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t: torch.Tensor) -> torch.Tensor:
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
return self.mlp(t_freq)
# DiTBlock with proper AdaLN-Zero
class DiTBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, use_text: bool = False, use_adaln_single: bool = False):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = nn.MultiheadAttention(hidden_size, num_heads, dropout=0.1, batch_first=True)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim),
nn.GELU(approximate="tanh"), # Match original
nn.Linear(mlp_hidden_dim, hidden_size),
)
# AdaLN modulation - either per-block MLP or AdaLN-Single embeddings
self.use_adaln_single = use_adaln_single
if use_adaln_single:
# AdaLN-Single: use learnable per-block embeddings instead of MLP
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
self.adaLN_modulation = None # No MLP needed
else:
# Original AdaLN with per-block MLP
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
self.scale_shift_table = None
# Optional text cross-attention
self.use_text = use_text
if use_text:
# Note: PixArt uses xformers which may handle unnormalized queries differently
# We add a simple norm for stability with PyTorch's MultiheadAttention
self.norm_cross = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.cross_attn = nn.MultiheadAttention(hidden_size, num_heads, dropout=0.1, batch_first=True)
def forward(self, x: torch.Tensor, c: torch.Tensor, text_emb: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None):
# Get modulation parameters
if self.use_adaln_single:
# AdaLN-Single: combine global time embedding with per-block parameters
# c should be pre-computed from global t_block with shape [B, 6*hidden_size]
B = x.shape[0]
# Chunk and squeeze to get [B, hidden_size] tensors, matching what modulate() and the gates expect
temp = (self.scale_shift_table[None] + c.reshape(B, 6, -1)).chunk(6, dim=1)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [t.squeeze(1) for t in temp]
else:
# Original AdaLN: compute modulation from per-block MLP
# Chunks are already [B, hidden_size]; the squeeze below is a no-op kept for symmetry with the branch above
temp = self.adaLN_modulation(c).chunk(6, dim=1)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [t.squeeze(1) for t in temp]
# Self-attention with modulation
# Both paths now use modulate function for consistency
x_norm = modulate(self.norm1(x), shift_msa, scale_msa)
attn_out, _ = self.attn(x_norm, x_norm, x_norm)
x = x + gate_msa.unsqueeze(1) * attn_out
# Optional cross-attention
if self.use_text and text_emb is not None:
if text_emb.dim() == 2:
text_emb = text_emb.unsqueeze(1)
# Convert attention mask to key_padding_mask format (True = ignore)
# attention_mask: shape [B, T]; either bool (True=keep) or 0/1 numeric (1=keep)
key_padding_mask = None
if attention_mask is not None:
if attention_mask.dtype is not torch.bool:
# Convert 0/1 (or >=1) to bool keep-mask first
keep_mask = attention_mask > 0
else:
keep_mask = attention_mask
# key_padding_mask semantics: True = ignore, False = keep
key_padding_mask = ~keep_mask # logical NOT, not arithmetic subtraction
# Normalize queries for stability (PixArt uses xformers which may differ)
x_norm = self.norm_cross(x)
cross_out, _ = self.cross_attn(x_norm, text_emb, text_emb, key_padding_mask=key_padding_mask)
x = x + cross_out
# MLP with modulation
# Both paths now use modulate function for consistency
x_norm = modulate(self.norm2(x), shift_mlp, scale_mlp)
mlp_out = self.mlp(x_norm)
x = x + gate_mlp.unsqueeze(1) * mlp_out
return x
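# Conditioning-shape sketch for the two AdaLN modes (illustrative sizes):
# the per-block-MLP path takes c of shape [B, hidden_size], while the
# AdaLN-Single path expects the precomputed global modulation [B, 6*hidden_size].
def _demo_dit_block_modes() -> torch.Tensor:
    B, N, H = 2, 16, 256
    x = torch.randn(B, N, H)
    block = DiTBlock(hidden_size=H, num_heads=4)            # per-block AdaLN MLP
    y = block(x, c=torch.randn(B, H))
    single = DiTBlock(hidden_size=H, num_heads=4, use_adaln_single=True)
    y2 = single(x, c=torch.randn(B, 6 * H))                 # e.g. a global t_block output
    assert y.shape == y2.shape == (B, N, H)
    return y2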
# FinalLayer with AdaLN modulation
class FinalLayer(nn.Module):
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
)
def forward(self, x: torch.Tensor, c: torch.Tensor):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
# T2IFinalLayer with AdaLN-Single for parameter efficiency
class T2IFinalLayer(nn.Module):
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
# AdaLN-Single: use learnable embeddings instead of MLP
self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size ** 0.5)
self.hidden_size = hidden_size
def forward(self, x: torch.Tensor, t: torch.Tensor):
# t should be the original time embedding with shape [B, hidden_size]
# Following PixArt implementation exactly
shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
# shift and scale are [B, 1, hidden_size], use t2i_modulate style
x = self.norm_final(x) * (1 + scale) + shift
x = self.linear(x)
return x
# DiTExpert
class DiTExpert(nn.Module):
def __init__(self, config):
super().__init__()
default_params = {
"hidden_size": 768,
"num_layers": 12,
"num_heads": 12,
"patch_size": 2,
"in_channels": 4,
"out_channels": 4,
"use_text_conditioning": False,
"use_class_conditioning": False,
"num_classes": 1000, # ImageNet classes
"mlp_ratio": 4.0,
"text_embed_dim": 768,
"use_dit_time_embed": False,
}
params = {**default_params, **config.expert_params}
self.patch_size = params["patch_size"]
self.in_channels = params["in_channels"]
self.out_channels = params["out_channels"]
self.hidden_size = params["hidden_size"]
self.num_heads = params["num_heads"]
self.use_text = params.get("use_text_conditioning", False)
self.use_class = params.get("use_class_conditioning", False)
self.cfg_dropout_prob = params.get("cfg_dropout_prob", 0.1) # 10% dropout for CFG
self.text_embed_dim = params.get("text_embed_dim", 768)
self.use_adaln_single = params.get("use_adaln_single", False) # AdaLN-Single for parameter efficiency
self.depth = params["num_layers"]
# Store objective type for heterogeneous training
self.objective_type = params.get("objective_type", "fm")
# Store and initialize schedule (NEW)
schedule_type = params.get("schedule_type", "linear_interp")
from schedules import NoiseSchedule
self.schedule = NoiseSchedule(schedule_type)
# Validation: cannot use both text and class conditioning simultaneously
assert not (self.use_text and self.use_class), "Cannot use both text and class conditioning simultaneously"
# Patch embedding
self.patch_embed = nn.Conv2d(self.in_channels, self.hidden_size,
kernel_size=self.patch_size, stride=self.patch_size)
# Fixed sin-cos positional embedding
latent_size = getattr(config, 'image_size', 32)
self.num_patches = (latent_size // self.patch_size) ** 2
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, self.hidden_size), requires_grad=False)
# Time embedding
self.use_dit_time_embed = params.get("use_dit_time_embed", False)
if self.use_dit_time_embed:
self.time_embed = DiTTimestepEmbedder(self.hidden_size)
else:
self.time_embed = TimestepEmbedder(self.hidden_size)
# Global time block for AdaLN-Single
if self.use_adaln_single:
self.t_block = nn.Sequential(
nn.SiLU(),
nn.Linear(self.hidden_size, 6 * self.hidden_size, bias=True)
)
# Optional text conditioning
if self.use_text:
self.text_proj = nn.Linear(self.text_embed_dim, self.hidden_size)
self.text_norm = nn.LayerNorm(self.hidden_size, elementwise_affine=False, eps=1e-6)
# Note: null text embedding will be provided by empty string encoding from CLIP
# This is handled in the training loop, not as a learnable parameter
# Optional class conditioning (ImageNet style)
if self.use_class:
# Add 1 extra embedding for null/unconditional class
self.class_embed = nn.Embedding(params["num_classes"] + 1, self.hidden_size)
self.null_class_id = params["num_classes"] # Use last index as null class
# Transformer blocks
self.layers = nn.ModuleList([
DiTBlock(self.hidden_size, self.num_heads, params.get("mlp_ratio", 4.0),
self.use_text, use_adaln_single=self.use_adaln_single)
for _ in range(self.depth)
])
# Final layer with modulation
if self.use_adaln_single:
self.final_layer = T2IFinalLayer(self.hidden_size, self.patch_size, self.out_channels)
else:
self.final_layer = FinalLayer(self.hidden_size, self.patch_size, self.out_channels)
# Initialize weights
self.initialize_weights()
def initialize_weights(self):
# Initialize transformer layers
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize positional embedding with sin-cos
grid_size = int(self.num_patches ** 0.5)
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], grid_size)
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
# Initialize patch_embed like nn.Linear
w = self.patch_embed.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
if self.patch_embed.bias is not None:
nn.init.constant_(self.patch_embed.bias, 0)
# Initialize timestep embedding MLP
nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
# Zero-out adaLN modulation layers in DiT blocks (from DiT paper)
for block in self.layers:
if block.adaLN_modulation is not None:
# Original AdaLN mode
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# AdaLN-Single mode: scale_shift_table is already initialized with randn/sqrt(hidden_size)
# Zero-out cross-attention output projection (from PixArt-Alpha)
if self.use_text and hasattr(block, 'cross_attn'):
nn.init.constant_(block.cross_attn.out_proj.weight, 0)
nn.init.constant_(block.cross_attn.out_proj.bias, 0)
# Initialize text projection layer (analogous to PixArt's caption embedding)
if self.use_text and hasattr(self, 'text_proj'):
nn.init.normal_(self.text_proj.weight, std=0.02)
if self.text_proj.bias is not None:
nn.init.constant_(self.text_proj.bias, 0)
# Initialize class embedding layer (similar to DiT paper)
if self.use_class and hasattr(self, 'class_embed'):
nn.init.normal_(self.class_embed.weight, std=0.02)
# Initialize global t_block for AdaLN-Single
if self.use_adaln_single and hasattr(self, 't_block'):
nn.init.normal_(self.t_block[1].weight, std=0.02)
# Zero the bias for stability; the weight keeps a small normal init
nn.init.constant_(self.t_block[1].bias, 0)
# Zero-out output layers
if hasattr(self.final_layer, 'adaLN_modulation') and self.final_layer.adaLN_modulation is not None:
# Original FinalLayer
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
# T2IFinalLayer scale_shift_table is already initialized with randn/sqrt(hidden_size)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def forward(self, xt: torch.Tensor, t: torch.Tensor, text_embeds: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
B, C, H, W = xt.shape
# Handle timestep scaling - DiT expects timesteps in [0, 999] range
# If t is normalized (in [0, 1]), scale it to [0, 999]
if t.max() <= 1.0 and t.min() >= 0.0:
# Normalized timesteps, scale to DiT range
t = t * 999.0
# Ensure t is in correct range for DiT
t = t.clamp(0, 999)
# Patchify
x = self.patch_embed(xt) # [B, hidden_size, H//p, W//p]
x = x.flatten(2).transpose(1, 2) # [B, num_patches, hidden_size]
x = x + self.pos_embed # Add positional embedding
# Prepare conditioning
time_emb = self.time_embed(t) # [B, hidden_size]
# Add class conditioning to time embedding if using class conditioning
if self.use_class and class_labels is not None:
class_emb = self.class_embed(class_labels) # [B, hidden_size]
time_emb = time_emb + class_emb # Additive combination following DiT paper
# Process conditioning based on AdaLN mode
if self.use_adaln_single:
# AdaLN-Single: compute global modulation once
c = self.t_block(time_emb) # [B, 6*hidden_size]
else:
# Original AdaLN: pass time embedding to each block
c = time_emb
# Prepare text tokens for cross-attention (not fused with time)
text_tokens = None
if self.use_text and text_embeds is not None:
if text_embeds.dim() == 3:
text_tokens = self.text_proj(text_embeds) # [B, T, hidden_size]
text_tokens = self.text_norm(text_tokens)
else:
text_tokens = self.text_proj(text_embeds).unsqueeze(1) # [B, 1, hidden_size]
text_tokens = self.text_norm(text_tokens)
if attention_mask is not None:
# cast to bool and truncate to the text token length
attention_mask = attention_mask[:, :text_tokens.shape[1]].to(torch.bool)
# safety: avoid all-false rows (would yield NaNs in softmax)
all_false = attention_mask.sum(dim=1) == 0
if all_false.any():
attention_mask[all_false, 0] = True
# Apply transformer blocks
for layer in self.layers:
x = layer(x, c, text_tokens, attention_mask)
# Final projection
if self.use_adaln_single:
# T2IFinalLayer expects original time embedding, not global modulation
x = self.final_layer(x, time_emb) # [B, num_patches, patch_size^2 * out_channels]
else:
# Original FinalLayer expects conditioning
x = self.final_layer(x, c) # [B, num_patches, patch_size^2 * out_channels]
# Unpatchify
patch_h = patch_w = int(self.num_patches ** 0.5)
x = x.view(B, patch_h, patch_w, self.patch_size, self.patch_size, self.out_channels)
x = x.permute(0, 5, 1, 3, 2, 4).contiguous()
x = x.view(B, self.out_channels, H, W)
return x
def compute_loss(self, x0: torch.Tensor, text_embeds: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None,
null_text_embeds: Optional[torch.Tensor] = None, null_attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Unified loss computation based on objective type"""
if self.objective_type == "ddpm":
return self.ddpm_loss(x0, text_embeds, attention_mask, class_labels, null_text_embeds, null_attention_mask)
elif self.objective_type == "fm":
return self.flow_matching_loss(x0, text_embeds, attention_mask, class_labels, null_text_embeds, null_attention_mask)
elif self.objective_type == "rf":
return self.rectified_flow_loss(x0, text_embeds, attention_mask, class_labels, null_text_embeds, null_attention_mask)
else:
raise ValueError(f"Unknown objective type: {self.objective_type}")
def ddpm_loss(self, x0: torch.Tensor, text_embeds: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None,
null_text_embeds: Optional[torch.Tensor] = None, null_attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""DDPM: predict noise ε"""
B = x0.shape[0]
device = x0.device
# Sample time uniformly
t = torch.rand(B, device=device)
# Use proper schedule (NEW)
alpha_t, sigma_t = self.schedule.get_schedule(t)
noise = torch.randn_like(x0)
xt = alpha_t.view(-1, 1, 1, 1) * x0 + sigma_t.view(-1, 1, 1, 1) * noise
# Apply CFG dropout during training
if self.training and self.cfg_dropout_prob > 0:
if self.use_text and text_embeds is not None:
keep = (torch.rand(B, device=device) > self.cfg_dropout_prob) # True = keep text
if null_text_embeds is not None:
# Use provided null text embeddings (from empty string CLIP encoding)
if null_text_embeds.shape[0] == 1:
null_text_embeds = null_text_embeds.expand(B, -1, -1)
# Replace dropped samples with null text embeddings
dropped = ~keep
if dropped.any():
text_embeds = text_embeds.clone()
text_embeds[dropped] = null_text_embeds[dropped]
# Use provided null attention mask or create default for empty string
if attention_mask is not None:
attention_mask = attention_mask.clone()
if null_attention_mask is not None:
if null_attention_mask.shape[0] == 1:
null_attention_mask = null_attention_mask.expand(B, -1)
attention_mask[dropped] = null_attention_mask[dropped]
else:
attention_mask[dropped] = 0
attention_mask[dropped, 0] = 1
else:
# Fallback to old zeroing approach if null_text_embeds not provided
if text_embeds.dim() == 3: # [B, T, D]
text_embeds = text_embeds * keep[:, None, None].to(text_embeds.dtype)
else: # [B, D]
text_embeds = text_embeds * keep[:, None].to(text_embeds.dtype)
if attention_mask is not None:
attention_mask = attention_mask.clone()
dropped = ~keep
if dropped.any():
attention_mask[dropped, 0] = 1
elif self.use_class and class_labels is not None:
# Apply CFG dropout to class labels using null class embedding
keep = (torch.rand(B, device=device) > self.cfg_dropout_prob)
null_class = torch.full_like(class_labels, self.null_class_id)
class_labels = torch.where(keep, class_labels, null_class)
# Predict noise
pred_eps = self.forward(xt, t, text_embeds, attention_mask, class_labels)
return F.mse_loss(pred_eps, noise)
def rectified_flow_loss(self, x0: torch.Tensor, text_embeds: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None,
null_text_embeds: Optional[torch.Tensor] = None, null_attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Rectified Flow: predict velocity v = x_1 - x_0 (straight paths)"""
B = x0.shape[0]
device = x0.device
# Sample time uniformly
t = torch.rand(B, device=device)
# Straight-line interpolation
x1 = torch.randn_like(x0) # Gaussian noise as x_1
xt = (1 - t).view(-1, 1, 1, 1) * x0 + t.view(-1, 1, 1, 1) * x1
# Apply CFG dropout during training
if self.training and self.cfg_dropout_prob > 0:
if self.use_text and text_embeds is not None:
keep = (torch.rand(B, device=device) > self.cfg_dropout_prob) # True = keep text
if null_text_embeds is not None:
# Use provided null text embeddings (from empty string CLIP encoding)
if null_text_embeds.shape[0] == 1:
null_text_embeds = null_text_embeds.expand(B, -1, -1)
# Replace dropped samples with null text embeddings
dropped = ~keep
if dropped.any():
text_embeds = text_embeds.clone()
text_embeds[dropped] = null_text_embeds[dropped]
# Use provided null attention mask or create default for empty string
if attention_mask is not None:
attention_mask = attention_mask.clone()
if null_attention_mask is not None:
if null_attention_mask.shape[0] == 1:
null_attention_mask = null_attention_mask.expand(B, -1)
attention_mask[dropped] = null_attention_mask[dropped]
else:
attention_mask[dropped] = 0
attention_mask[dropped, 0] = 1
else:
# Fallback to old zeroing approach if null_text_embeds not provided
if text_embeds.dim() == 3: # [B, T, D]
text_embeds = text_embeds * keep[:, None, None].to(text_embeds.dtype)
else: # [B, D]
text_embeds = text_embeds * keep[:, None].to(text_embeds.dtype)
if attention_mask is not None:
attention_mask = attention_mask.clone()
dropped = ~keep
if dropped.any():
attention_mask[dropped, 0] = 1
elif self.use_class and class_labels is not None:
# Apply CFG dropout to class labels using null class embedding
keep = (torch.rand(B, device=device) > self.cfg_dropout_prob)
null_class = torch.full_like(class_labels, self.null_class_id)
class_labels = torch.where(keep, class_labels, null_class)
# Predict velocity (x_1 - x_0)
pred_v = self.forward(xt, t, text_embeds, attention_mask, class_labels)
true_v = x1 - x0
return F.mse_loss(pred_v, true_v)
def flow_matching_loss(self, x0: torch.Tensor, text_embeds: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None,
null_text_embeds: Optional[torch.Tensor] = None, null_attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Flow matching loss for latent space training with CFG dropout."""
B = x0.shape[0]
device = x0.device
# Sample time uniformly
t = torch.rand(B, device=device)
# Use proper schedule (NEW)
alpha_t, sigma_t = self.schedule.get_schedule(t)
noise = torch.randn_like(x0)
xt = alpha_t.view(-1, 1, 1, 1) * x0 + sigma_t.view(-1, 1, 1, 1) * noise
# Apply CFG dropout during training
if self.training and self.cfg_dropout_prob > 0:
if self.use_text and text_embeds is not None:
keep = (torch.rand(B, device=device) > self.cfg_dropout_prob) # True = keep text
if null_text_embeds is not None:
# Use provided null text embeddings (from empty string CLIP encoding)
# Ensure null_text_embeds matches the batch size
if null_text_embeds.shape[0] == 1:
null_text_embeds = null_text_embeds.expand(B, -1, -1)
# Replace dropped samples with null text embeddings
dropped = ~keep
if dropped.any():
text_embeds = text_embeds.clone()
text_embeds[dropped] = null_text_embeds[dropped]
# Use provided null attention mask or create default for empty string
if attention_mask is not None:
attention_mask = attention_mask.clone()
if null_attention_mask is not None:
# Ensure null_attention_mask matches batch size
if null_attention_mask.shape[0] == 1:
null_attention_mask = null_attention_mask.expand(B, -1)
attention_mask[dropped] = null_attention_mask[dropped]
else:
# Default: For null text (empty string), typically only the first token is valid
attention_mask[dropped] = 0
attention_mask[dropped, 0] = 1 # Keep only first token for empty string
else:
# Fallback to old zeroing approach if null_text_embeds not provided
if text_embeds.dim() == 3: # [B, T, D]
text_embeds = text_embeds * keep[:, None, None].to(text_embeds.dtype)
else: # [B, D]
text_embeds = text_embeds * keep[:, None].to(text_embeds.dtype)
# Handle attention mask for fallback approach
if attention_mask is not None:
attention_mask = attention_mask.clone()
dropped = ~keep
if dropped.any():
attention_mask[dropped, 0] = 1
elif self.use_class and class_labels is not None:
# Apply CFG dropout to class labels using null class embedding
keep = (torch.rand(B, device=device) > self.cfg_dropout_prob) # True = keep class
# Use the dedicated null class embedding for unconditional generation
null_class = torch.full_like(class_labels, self.null_class_id)
class_labels = torch.where(keep, class_labels, null_class)
# Predict velocity
pred_v = self.forward(xt, t, text_embeds, attention_mask, class_labels)
true_v = noise - x0
return F.mse_loss(pred_v, true_v)
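# End-to-end shape sketch for a small class-conditional DiT expert
# (illustrative: a stand-in config namespace replaces the repo's real config
# class, and the `schedules` module must be importable).
def _demo_dit_expert() -> torch.Tensor:
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        image_size=32,   # latent resolution
        expert_params={
            "hidden_size": 256, "num_layers": 2, "num_heads": 4,
            "use_class_conditioning": True, "num_classes": 10,
        },
    )
    expert = DiTExpert(cfg)
    xt = torch.randn(2, 4, 32, 32)            # 4-channel VAE latents
    t = torch.rand(2)                         # normalized timesteps, rescaled internally
    labels = torch.randint(0, 10, (2,))
    out = expert(xt, t, class_labels=labels)  # -> [2, 4, 32, 32]
    assert out.shape == xt.shape
    return out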
# =============================================================================
# ROUTER MODELS
# =============================================================================
class ViTRouter(nn.Module):
"""ViT-based router for cluster classification"""
def __init__(self, config) -> None:
super().__init__()
# Default params
default_params = {
"hidden_size": 384,
"num_layers": 6,
"num_heads": 6,
"patch_size": 8,
"use_dit_time_embed": False, # Whether to use DiT-style time embedding
}
params = {**default_params, **config.router_params}
if config.router_pretrained:
# Use pretrained ViT and adapt
self.vit = ViTForImageClassification.from_pretrained(
"google/vit-base-patch16-224"
)
self._adapt_pretrained(config, params)
else:
# Build from scratch
vit_config = ViTConfig(
image_size=config.image_size,
patch_size=params["patch_size"],
num_channels=config.num_channels,
hidden_size=params["hidden_size"],
num_hidden_layers=params["num_layers"],
num_attention_heads=params["num_heads"],
num_labels=config.num_clusters
)
self.vit = ViTForImageClassification(vit_config)
# Time conditioning - support both embedding styles
self.use_dit_time_embed = params.get("use_dit_time_embed", False)
if self.use_dit_time_embed:
# Use DiT-style timestep embedding for consistency
self.time_embedding = DiTTimestepEmbedder(params["hidden_size"])
else:
# Original simple time embedding
self.time_embedding = nn.Sequential(
nn.Linear(1, params["hidden_size"]),
nn.SiLU(),
nn.Linear(params["hidden_size"], params["hidden_size"])
)
# Combined classifier
self.classifier = nn.Sequential(
nn.Linear(params["hidden_size"] * 2, params["hidden_size"]),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(params["hidden_size"], config.num_clusters)
)
def _adapt_pretrained(self, config, params) -> ViTForImageClassification:
"""Adapt pretrained ViT for our task"""
# Modify patch embeddings if needed
if config.image_size != 224 or config.num_channels != 3:
self.vit.vit.embeddings.patch_embeddings.projection = nn.Conv2d(
config.num_channels,
self.vit.config.hidden_size,
kernel_size=params["patch_size"],
stride=params["patch_size"]
)
def forward(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
# Process image through ViT
vit_outputs = self.vit.vit(xt)
image_features = vit_outputs.last_hidden_state[:, 0] # CLS token
# Time conditioning
if self.use_dit_time_embed:
# DiT embedder expects raw timesteps
time_features = self.time_embedding(t)
else:
# Original embedding needs unsqueeze
time_features = self.time_embedding(t.unsqueeze(-1))
# Combine and classify
combined = torch.cat([image_features, time_features], dim=1)
return self.classifier(combined)
class CNNRouter(nn.Module):
"""Simple CNN router for cluster classification"""
def __init__(self, config) -> None:
super().__init__()
# Default params
default_params = {
"hidden_dims": [64, 128, 256],
"use_dit_time_embed": False, # Whether to use DiT-style time embedding
}
params = {**default_params, **config.router_params}
# CNN backbone
self.backbone = self._build_cnn(config.num_channels, params["hidden_dims"])
# Time embedding - support both styles
self.use_dit_time_embed = params.get("use_dit_time_embed", False)
if self.use_dit_time_embed:
# Use DiT-style timestep embedding, output to 128 dims for CNN
self.time_embedding = DiTTimestepEmbedder(128)
else:
# Original simple time embedding
self.time_embedding = nn.Sequential(
nn.Linear(1, 128),
nn.SiLU(),
nn.Linear(128, 128)
)
# Classifier
self.classifier = nn.Sequential(
nn.Linear(params["hidden_dims"][-1] + 128, 256),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(256, config.num_clusters)
)
def _build_cnn(self, in_channels: int, hidden_dims: List[int]) -> nn.Sequential:
layers = []
prev_dim = in_channels
for dim in hidden_dims:
layers.extend([
nn.Conv2d(prev_dim, dim, 3, padding=1),
nn.ReLU(),
nn.Conv2d(dim, dim, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2)
])
prev_dim = dim
layers.append(nn.AdaptiveAvgPool2d(1))
layers.append(nn.Flatten())
return nn.Sequential(*layers)
def forward(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
# CNN features
img_features = self.backbone(xt)
# Time features
if self.use_dit_time_embed:
# DiT embedder expects raw timesteps
time_features = self.time_embedding(t)
else:
# Original embedding needs unsqueeze
time_features = self.time_embedding(t.unsqueeze(-1))
# Combine and classify
combined = torch.cat([img_features, time_features], dim=1)
return self.classifier(combined)
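# Routing sketch (illustrative, stand-in config): the router maps a noisy
# batch plus its timesteps to per-cluster logits; argmax (or a softmax over
# logits) picks the expert for each sample.
def _demo_cnn_router() -> torch.Tensor:
    from types import SimpleNamespace
    cfg = SimpleNamespace(num_channels=3, num_clusters=4, router_params={})
    router = CNNRouter(cfg)
    xt = torch.randn(2, 3, 32, 32)
    logits = router(xt, torch.rand(2))   # -> [2, 4]
    expert_ids = logits.argmax(dim=-1)   # hard routing decision per sample
    assert logits.shape == (2, 4) and expert_ids.shape == (2,)
    return logits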
class DiTRouter(nn.Module):
"""DiT B/2 router for cluster classification"""
def __init__(self, config):
super().__init__()
# DiT B/2 specifications
default_params = {
"hidden_size": 768, # DiT-B uses 768
"num_layers": 12, # DiT-B uses 12 layers
"num_heads": 12, # DiT-B uses 12 heads
"patch_size": 2, # For latent space (32x32 -> 16x16 patches)
"in_channels": 4, # VAE latent channels
"mlp_ratio": 4.0,
"use_dit_time_embed": False, # Whether to use DiT-style time embedding
}
params = {**default_params, **config.router_params}
self.patch_size = params["patch_size"]
self.in_channels = params["in_channels"]
self.hidden_size = params["hidden_size"]
self.num_heads = params["num_heads"]
self.num_clusters = config.num_clusters
# Patch embedding (same as expert)
self.patch_embed = nn.Conv2d(
self.in_channels, self.hidden_size,
kernel_size=self.patch_size, stride=self.patch_size
)
# Calculate number of patches
latent_size = getattr(config, 'image_size', 32) # Assuming 256/8=32 for VAE
self.num_patches = (latent_size // self.patch_size) ** 2
# Fixed sin-cos positional embedding (same as expert)
self.pos_embed = nn.Parameter(
torch.zeros(1, self.num_patches, self.hidden_size),
requires_grad=False
)
# CLS token (KEY ADDITION from paper)
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_size))
# Time embedding - match expert's choice
self.use_dit_time_embed = params.get("use_dit_time_embed", False)
if self.use_dit_time_embed:
self.time_embed = DiTTimestepEmbedder(self.hidden_size)
else:
self.time_embed = TimestepEmbedder(self.hidden_size)
# DiT blocks with AdaLN (reuse DiTBlock from expert)
# Note: Router doesn't need text conditioning
self.layers = nn.ModuleList([
DiTBlock(self.hidden_size, self.num_heads, params["mlp_ratio"], use_text=False)
for _ in range(params["num_layers"])
])
# Final layer norm
self.norm_final = nn.LayerNorm(self.hidden_size, elementwise_affine=False, eps=1e-6)
# Classification head on the CLS token (MLP variant of the paper's plain linear classifier, kept commented below)
# self.head = nn.Linear(self.hidden_size, self.num_clusters)
self.head = nn.Sequential(
nn.Linear(self.hidden_size, self.hidden_size),
nn.GELU(),
nn.LayerNorm(self.hidden_size),
nn.Dropout(0.1),
nn.Linear(self.hidden_size, self.num_clusters)
)
# Initialize weights
self.initialize_weights()
def initialize_weights(self):
# Initialize transformer layers
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize CLS token
nn.init.normal_(self.cls_token, std=0.02)
# Initialize positional embedding with sin-cos (same as expert)
grid_size = int(self.num_patches ** 0.5)
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], grid_size)
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
# Initialize patch_embed like nn.Linear
w = self.patch_embed.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
if self.patch_embed.bias is not None:
nn.init.constant_(self.patch_embed.bias, 0)
# Initialize timestep embedding MLP
if hasattr(self.time_embed, 'mlp'):
nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
# Zero-out adaLN modulation in blocks (following expert initialization)
for block in self.layers:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# # Init for the simpler linear head (kept commented for reference)
# nn.init.constant_(self.head.weight, 0)
# nn.init.constant_(self.head.bias, 0)
# Initialize classification head (Sequential)
# Initialize intermediate layers normally, zero-out final layer
nn.init.normal_(self.head[0].weight, std=0.02) # First linear layer
if self.head[0].bias is not None:
nn.init.constant_(self.head[0].bias, 0)
# Zero-out final classification layer (following DiT paper)
nn.init.constant_(self.head[-1].weight, 0) # Last linear layer
if self.head[-1].bias is not None:
nn.init.constant_(self.head[-1].bias, 0)
def forward(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
B, C, H, W = xt.shape
# Match expert's timestep interpretation
if t.max() <= 1.0 and t.min() >= 0.0:
t = t * 999.0
t = t.clamp(0, 999)
# Patchify
x = self.patch_embed(xt) # [B, hidden_size, H//p, W//p]
x = x.flatten(2).transpose(1, 2) # [B, num_patches, hidden_size]
# Add positional embedding
x = x + self.pos_embed
# Prepend CLS token
cls_tokens = self.cls_token.expand(B, -1, -1) # [B, 1, hidden_size]
x = torch.cat([cls_tokens, x], dim=1) # [B, 1 + num_patches, hidden_size]
# Time conditioning
c = self.time_embed(t) # [B, hidden_size]
# Apply DiT blocks with AdaLN modulation
for layer in self.layers:
x = layer(x, c, text_emb=None)
# Extract CLS token and apply final norm
cls_output = x[:, 0] # [B, hidden_size]
cls_output = self.norm_final(cls_output)
# Linear classification head
logits = self.head(cls_output) # [B, num_clusters]
return logits
# =============================================================================
# DETERMINISTIC ROUTER (for controlled experiments)
# =============================================================================
class DeterministicTimestepRouter(nn.Module):
"""
Deterministic router that assigns experts based on timestep.
Useful for controlled experiments where you want to test specific routing strategies,
such as: "high noise → DDPM expert, low noise → FM expert"
Args:
config: Config object with router_params containing:
- timestep_threshold: t value to switch experts (default: 0.5)
- high_noise_expert: Expert ID for t > threshold (default: 0, typically DDPM)
- low_noise_expert: Expert ID for t <= threshold (default: 1, typically FM)
Example config:
router_architecture: "deterministic_timestep"
router_params:
timestep_threshold: 0.5
high_noise_expert: 0 # DDPM for high noise
low_noise_expert: 1 # FM for low noise
"""
def __init__(self, config):
super().__init__()
self.num_experts = config.num_experts
self.threshold = config.router_params.get('timestep_threshold', 0.5)
self.high_noise_expert = config.router_params.get('high_noise_expert', 0)
self.low_noise_expert = config.router_params.get('low_noise_expert', 1)
# Validate expert IDs
assert 0 <= self.high_noise_expert < self.num_experts, \
f"high_noise_expert {self.high_noise_expert} out of range [0, {self.num_experts})"
assert 0 <= self.low_noise_expert < self.num_experts, \
f"low_noise_expert {self.low_noise_expert} out of range [0, {self.num_experts})"
# Validate threshold
assert 0.0 <= self.threshold <= 1.0, \
f"timestep_threshold {self.threshold} must be in [0, 1]"
# This router has no trainable parameters
# Register threshold as buffer (not trained, but saved with model)
self.register_buffer('_threshold', torch.tensor(self.threshold))
print(f"DeterministicTimestepRouter initialized:")
print(f" Threshold: {self.threshold}")
print(f" High noise (t > {self.threshold}) → Expert {self.high_noise_expert}")
print(f" Low noise (t <= {self.threshold}) → Expert {self.low_noise_expert}")
def forward(self, x: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Returns one-hot routing probabilities based on timestep.
Args:
x: Input tensor (unused, but kept for API compatibility with other routers)
t: Timesteps, shape (B,)
Returns:
routing_probs: Shape (B, num_experts), one-hot encoded
"""
B = t.shape[0]
device = t.device
# Initialize routing probabilities (all zeros)
routing_probs = torch.zeros(B, self.num_experts, device=device)
# High noise (t > threshold) → high_noise_expert
# Low noise (t <= threshold) → low_noise_expert
high_noise_mask = t > self.threshold
routing_probs[high_noise_mask, self.high_noise_expert] = 1.0
routing_probs[~high_noise_mask, self.low_noise_expert] = 1.0
return routing_probs
def train(self, mode: bool = True):
"""Override train() - this router is never trained, always in eval mode"""
return super(DeterministicTimestepRouter, self).train(False)
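# Threshold-routing sketch (illustrative, stand-in config): with the split at
# t=0.5, high-noise samples are sent to expert 0 and low-noise samples to
# expert 1, as one-hot probabilities.
def _demo_deterministic_router() -> torch.Tensor:
    from types import SimpleNamespace
    cfg = SimpleNamespace(num_experts=2,
                          router_params={"timestep_threshold": 0.5,
                                         "high_noise_expert": 0,
                                         "low_noise_expert": 1})
    router = DeterministicTimestepRouter(cfg)
    t = torch.tensor([0.9, 0.2])
    probs = router(x=None, t=t)          # -> one-hot [2, 2]
    assert probs.argmax(dim=-1).tolist() == [0, 1]
    return probs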
# =============================================================================
# ADAPTIVE VIDEO ROUTER (for Video DDM)
# =============================================================================
class AdaptiveVideoRouter(nn.Module):
"""
Time-adaptive router for video DDM.
Key innovation: Learns optimal weighting of information sources
at each noise level, solving the "motion invisible at t=1" problem.
Information availability is time-dependent:
t ~ 1.0: Only text/first_frame informative → Route on conditioning
t ~ 0.5: Structure emerging → Latent becomes useful
t ~ 0.1: Near clean → Full information available
Expected learned behavior:
| Noise Level | Text | Frame | Latent | Behavior |
|-------------|------|-------|--------|-----------------------------|
| t ~ 1.0 | ~0.7 | ~0.2 | ~0.1 | Routes on text semantics |
| t ~ 0.5 | ~0.4 | ~0.3 | ~0.3 | Balanced; emerging structure|
| t ~ 0.1 | ~0.2 | ~0.2 | ~0.6 | Trusts latent; fine-grained |
Enhancements:
- Masked mean pooling for text (handles variable-length prompts)
- Temporal-aware latent encoder (captures motion patterns)
- Temperature scaling for inference control
"""
def __init__(self, config):
super().__init__()
# Default params
default_params = {
"hidden_dim": 512,
"text_embed_dim": 768, # CLIP-L text embedding dimension
"frame_embed_dim": 768, # DINOv2-B (base) feature dimension
"latent_channels": 16, # VAE latent channels (CogVideoX uses 16)
"latent_conv_dim": 64, # Intermediate conv channels for latent encoder
"dropout": 0.1,
"temporal_pool_mode": "attention", # "attention", "avg", or "max"
"normalize_inputs": True, # L2-normalize text/frame inputs (match clustering)
}
params = {**default_params, **getattr(config, 'router_params', {})}
self.hidden_dim = params["hidden_dim"]
self.num_experts = getattr(config, 'num_experts', config.num_clusters)
self.latent_channels = params["latent_channels"]
self.latent_conv_dim = params["latent_conv_dim"]
self.temporal_pool_mode = params["temporal_pool_mode"]
self.normalize_inputs = params.get("normalize_inputs", True)
# === Information Source Encoders ===
# Text pathway (always available, primary signal at high t)
self.text_encoder = nn.Sequential(
nn.Linear(params["text_embed_dim"], self.hidden_dim),
nn.LayerNorm(self.hidden_dim),
nn.GELU(),
nn.Linear(self.hidden_dim, self.hidden_dim)
)
# First frame pathway (available for I2V tasks)
# Uses DINOv2 features extracted from the first frame
self.frame_encoder = nn.Sequential(
nn.Linear(params["frame_embed_dim"], self.hidden_dim),
nn.LayerNorm(self.hidden_dim),
nn.GELU(),
nn.Linear(self.hidden_dim, self.hidden_dim)
)
# === Temporal-Aware Latent Encoder ===
# Captures both spatial content and temporal motion patterns
# Spatial feature extraction (per-frame)
self.spatial_conv = nn.Sequential(
nn.Conv3d(params["latent_channels"], params["latent_conv_dim"],
kernel_size=(1, 3, 3), padding=(0, 1, 1)), # Spatial only
nn.GroupNorm(8, params["latent_conv_dim"]),
nn.GELU(),
)
# Temporal feature extraction (motion patterns)
self.temporal_conv = nn.Sequential(
nn.Conv3d(params["latent_conv_dim"], params["latent_conv_dim"],
kernel_size=(3, 1, 1), padding=(1, 0, 0)), # Temporal only
nn.GroupNorm(8, params["latent_conv_dim"]),
nn.GELU(),
)
# Combined spatio-temporal processing
self.st_conv = nn.Sequential(
nn.Conv3d(params["latent_conv_dim"], params["latent_conv_dim"],
kernel_size=3, padding=1), # Full 3D
nn.GroupNorm(8, params["latent_conv_dim"]),
nn.GELU(),
)
# Spatial pooling (keep temporal dimension)
self.spatial_pool = nn.AdaptiveAvgPool3d((None, 1, 1)) # [B, C, T, 1, 1]
# Temporal attention pooling (learns which frames matter for routing)
if self.temporal_pool_mode == "attention":
self.temporal_attn = nn.Sequential(
nn.Linear(params["latent_conv_dim"], params["latent_conv_dim"] // 4),
nn.Tanh(),
nn.Linear(params["latent_conv_dim"] // 4, 1),
)
# Motion feature extractor (frame differences)
self.motion_encoder = nn.Sequential(
nn.Linear(params["latent_conv_dim"], params["latent_conv_dim"]),
nn.GELU(),
nn.Linear(params["latent_conv_dim"], self.hidden_dim // 2),
)
# Content feature projector
self.content_proj = nn.Linear(params["latent_conv_dim"], self.hidden_dim // 2)
# Final latent projection (combines content + motion)
self.latent_proj = nn.Sequential(
nn.Linear(self.hidden_dim, self.hidden_dim),
nn.LayerNorm(self.hidden_dim),
)
# === Time-Dependent Weighting ===
# Time embedding using existing infrastructure
self.time_embed = TimestepEmbedder(self.hidden_dim)
self.time_mlp = nn.Sequential(
nn.Linear(self.hidden_dim, self.hidden_dim),
nn.GELU(),
nn.Linear(self.hidden_dim, self.hidden_dim)
)
# Learns adaptive weighting: at high t → trust text; at low t → trust latent
self.source_weighting = nn.Sequential(
nn.Linear(self.hidden_dim, 128),
nn.GELU(),
nn.Linear(128, 3), # [text, frame, latent] weights
nn.Softmax(dim=-1)
)
# === Routing Head ===
self.router_head = nn.Sequential(
nn.Linear(self.hidden_dim, self.hidden_dim),
nn.GELU(),
nn.LayerNorm(self.hidden_dim),
nn.Dropout(params["dropout"]),
nn.Linear(self.hidden_dim, self.num_experts)
)
# Initialize weights
self.initialize_weights()
def initialize_weights(self):
"""Initialize weights following DiT conventions."""
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
elif isinstance(module, nn.Conv3d):
# Flatten spatial dims for xavier init
w = module.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize timestep embedding MLP (following DiT)
if hasattr(self.time_embed, 'mlp'):
nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
# Small non-zero initialization for final routing layer
# (pure zeros cause uniform outputs that break temperature scaling)
nn.init.normal_(self.router_head[-1].weight, std=0.01)
nn.init.constant_(self.router_head[-1].bias, 0)
# Initialize source weighting to start roughly uniform
# The softmax will make [0, 0, 0] → [0.33, 0.33, 0.33]
nn.init.constant_(self.source_weighting[-2].weight, 0)
nn.init.constant_(self.source_weighting[-2].bias, 0)
# Initialize temporal attention to uniform attention
if self.temporal_pool_mode == "attention":
nn.init.constant_(self.temporal_attn[-1].weight, 0)
nn.init.constant_(self.temporal_attn[-1].bias, 0)
def _masked_mean_pool(self, embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Compute masked mean pooling over sequence dimension.
Args:
embeddings: [B, seq_len, embed_dim] - Token embeddings
attention_mask: [B, seq_len] - 1 for real tokens, 0 for padding
Returns:
pooled: [B, embed_dim] - Pooled representation
"""
if attention_mask is None:
# No mask provided, use simple mean
return embeddings.mean(dim=1)
# Expand mask for broadcasting: [B, seq_len] -> [B, seq_len, 1]
mask = attention_mask.unsqueeze(-1).to(embeddings.dtype)
# Masked sum
masked_sum = (embeddings * mask).sum(dim=1) # [B, embed_dim]
# Count of valid tokens (avoid division by zero)
token_counts = mask.sum(dim=1).clamp(min=1.0) # [B, 1]
return masked_sum / token_counts
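# Minimal usage sketch for _masked_mean_pool (the tensors and the `router` instance
# below are illustrative placeholders, not values from the repo):
#   emb  = torch.randn(2, 4, 768)                      # [B=2, seq_len=4, embed_dim=768]
#   mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 1]])  # first sequence has 2 padded tokens
#   pooled = router._masked_mean_pool(emb, mask)       # [2, 768]
# The first row averages only its two real tokens: padded positions are zeroed before
# the sum and excluded from the token count, so they cannot dilute the mean.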
def _encode_latent_temporal(self, x_t: torch.Tensor) -> torch.Tensor:
"""
Encode video latent with temporal awareness.
Extracts both:
- Content features: What is in the video (spatial)
- Motion features: How things move (temporal differences)
Args:
x_t: [B, C, T, H, W] - Noisy video latent
Returns:
latent_feat: [B, hidden_dim] - Combined latent features
"""
B, C, T, H, W = x_t.shape
# 1. Spatial feature extraction
spatial_feat = self.spatial_conv(x_t) # [B, conv_dim, T, H, W]
# 2. Temporal feature extraction (captures local motion)
temporal_feat = self.temporal_conv(spatial_feat) # [B, conv_dim, T, H, W]
# 3. Combined spatio-temporal processing
st_feat = self.st_conv(temporal_feat) # [B, conv_dim, T, H, W]
# 4. Pool spatially, keep temporal: [B, conv_dim, T, 1, 1] -> [B, T, conv_dim]
pooled = self.spatial_pool(st_feat).squeeze(-1).squeeze(-1) # [B, conv_dim, T]
pooled = pooled.permute(0, 2, 1) # [B, T, conv_dim]
# 5. Temporal pooling with optional attention
if self.temporal_pool_mode == "attention" and T > 1:
# Learn which frames matter for routing
attn_scores = self.temporal_attn(pooled).squeeze(-1) # [B, T]
attn_weights = F.softmax(attn_scores, dim=-1) # [B, T]
content_feat = (pooled * attn_weights.unsqueeze(-1)).sum(dim=1) # [B, conv_dim]
elif self.temporal_pool_mode == "max":
content_feat = pooled.max(dim=1)[0] # [B, conv_dim]
else: # "avg"
content_feat = pooled.mean(dim=1) # [B, conv_dim]
# 6. Extract motion features (frame differences)
if T > 1:
# Compute frame-to-frame differences
frame_diffs = pooled[:, 1:] - pooled[:, :-1] # [B, T-1, conv_dim]
# Motion magnitude and direction encoding
motion_feat = self.motion_encoder(frame_diffs.mean(dim=1)) # [B, hidden_dim//2]
else:
# Single frame: no motion signal (match pooled dtype so the concat below is safe under autocast)
motion_feat = torch.zeros(B, self.hidden_dim // 2, device=x_t.device, dtype=pooled.dtype)
# 7. Project content features
content_proj = self.content_proj(content_feat) # [B, hidden_dim//2]
# 8. Combine content + motion
combined = torch.cat([content_proj, motion_feat], dim=-1) # [B, hidden_dim]
latent_feat = self.latent_proj(combined) # [B, hidden_dim]
return latent_feat
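# Shape walk-through for _encode_latent_temporal, assuming hidden_dim=512 and
# latent_conv_dim=64 (illustrative values, not necessarily the configured ones):
#   x_t [B, C, T, H, W] --spatial/temporal/st convs--> [B, 64, T, H, W]
#   --AdaptiveAvgPool3d((None, 1, 1)) + squeeze--> [B, 64, T] --permute--> [B, T, 64]
#   attention / max / avg over T --> content_feat [B, 64] --content_proj--> [B, 256]
#   frame diffs averaged over T-1 --motion_encoder--> motion_feat [B, 256]
#   cat(content, motion) --> [B, 512] --latent_proj--> latent_feat [B, 512]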
def forward(self, x_t: torch.Tensor, t: torch.Tensor,
text_embed: torch.Tensor,
first_frame_feat: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temperature: float = 1.0) -> torch.Tensor:
"""
Compute routing logits with time-adaptive information weighting.
Args:
x_t: Noisy video latent [B, C, T, H, W]
t: Noise level [B] in [0, 1] or [0, 999]
text_embed: CLIP text embedding [B, text_embed_dim] or [B, seq_len, text_embed_dim]
first_frame_feat: Optional DINOv2 features [B, frame_embed_dim]
attention_mask: Optional [B, seq_len] mask for text (1=valid, 0=padding)
temperature: Softmax temperature for sharper/softer routing (default: 1.0)
Returns:
logits: Expert selection logits [B, num_experts] (scaled by temperature)
"""
B = x_t.shape[0]
device = x_t.device
# === Encode each information source ===
# Handle both pooled [B, D] and sequence [B, seq_len, D] text embeddings
if text_embed.dim() == 3:
# Use masked mean pooling for sequence embeddings
text_embed_pooled = self._masked_mean_pool(text_embed, attention_mask)
else:
# Already pooled
text_embed_pooled = text_embed
# L2-normalize inputs to match clustering preprocessing
if self.normalize_inputs:
text_embed_pooled = F.normalize(text_embed_pooled, p=2, dim=-1)
text_feat = self.text_encoder(text_embed_pooled) # [B, hidden_dim]
# Frame features (optional for T2V, required for I2V)
if first_frame_feat is not None:
# L2-normalize to match clustering preprocessing
if self.normalize_inputs:
first_frame_feat = F.normalize(first_frame_feat, p=2, dim=-1)
frame_feat = self.frame_encoder(first_frame_feat) # [B, hidden_dim]
else:
frame_feat = torch.zeros(B, self.hidden_dim, device=device, dtype=text_feat.dtype)
# Latent features from noisy video (temporal-aware encoding)
latent_feat = self._encode_latent_temporal(x_t) # [B, hidden_dim]
# === Time-dependent weighting ===
# Normalize timesteps to [0, 999] for TimestepEmbedder
if t.max() <= 1.0:
t_scaled = t * 999.0
else:
t_scaled = t
t_scaled = t_scaled.clamp(0, 999)
# Get time features
time_emb = self.time_embed(t_scaled) # [B, hidden_dim]
time_feat = self.time_mlp(time_emb) # [B, hidden_dim]
# Compute adaptive weights based on noise level
# Network learns: high t → high text weight; low t → high latent weight
weights = self.source_weighting(time_feat) # [B, 3]
# === Adaptive combination ===
combined = (
weights[:, 0:1] * text_feat + # Text contribution
weights[:, 1:2] * frame_feat + # Frame contribution
weights[:, 2:3] * latent_feat # Latent contribution
)
# Final routing decision (incorporate time context)
logits = self.router_head(combined + time_feat)
# Apply temperature scaling (lower temp = sharper routing)
if temperature != 1.0:
logits = logits / temperature
return logits
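# Hedged example of a routing call; shapes follow the docstring above, but the
# concrete sizes (C=4, T=8, 32x32 latents, 77x768 text tokens) are placeholders:
#   x_t        = torch.randn(2, 4, 8, 32, 32)             # [B, C, T, H, W] noisy latents
#   t          = torch.rand(2)                            # noise levels in [0, 1]
#   text_embed = torch.randn(2, 77, 768)                  # token-level text embeddings
#   logits = router(x_t, t, text_embed, temperature=0.5)  # [2, num_experts], sharpened
#   expert_idx = logits.argmax(dim=-1)                    # hard top-1 routing decision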
def get_source_weights(self, t: torch.Tensor) -> torch.Tensor:
"""
Get the learned source weights for given timesteps.
Useful for debugging and visualization.
Args:
t: Noise levels [B] in [0, 1] or [0, 999]
Returns:
weights: Source weights [B, 3] for [text, frame, latent]
"""
# Normalize timesteps
if t.max() <= 1.0:
t_scaled = t * 999.0
else:
t_scaled = t
t_scaled = t_scaled.clamp(0, 999)
time_emb = self.time_embed(t_scaled)
time_feat = self.time_mlp(time_emb)
weights = self.source_weighting(time_feat)
return weights
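# Debugging sketch: sweep timesteps to inspect how the learned weighting shifts
# between text, frame, and latent sources (`router` is a hypothetical instance):
#   ts = torch.linspace(0, 999, steps=5)
#   w = router.get_source_weights(ts)                  # [5, 3], each row sums to 1
#   for ti, wi in zip(ts.tolist(), w.tolist()):
#       print(f"t={ti:6.1f}  text={wi[0]:.2f}  frame={wi[1]:.2f}  latent={wi[2]:.2f}")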
# =============================================================================
# MODEL FACTORY FUNCTIONS
# =============================================================================
def create_expert(config, expert_id: Optional[int] = None) -> nn.Module:
"""
Factory function to create an expert model.
Args:
config: Config object
expert_id: Optional expert ID for per-expert schedule/objective configuration
"""
# Make a copy of config to avoid modifying the original
import copy
config = copy.copy(config)
config.expert_params = config.expert_params.copy()
# Inject schedule_type into expert_params if not already present
if "schedule_type" not in config.expert_params:
# Check for per-expert schedule first (with backward compatibility)
if (hasattr(config, 'expert_schedule_types') and
config.expert_schedule_types and
expert_id is not None and
expert_id in config.expert_schedule_types):
config.expert_params["schedule_type"] = config.expert_schedule_types[expert_id]
else:
# Use default schedule_type (with fallback for old configs)
config.expert_params["schedule_type"] = getattr(config, 'schedule_type', 'linear_interp')
# Inject objective_type into expert_params if not already present
if "objective_type" not in config.expert_params:
# Check for per-expert objectives (with backward compatibility)
if (hasattr(config, 'expert_objectives') and
config.expert_objectives and
expert_id is not None and
expert_id in config.expert_objectives):
config.expert_params["objective_type"] = config.expert_objectives[expert_id]
else:
# Use default objective (with fallback for old configs)
config.expert_params["objective_type"] = getattr(config, 'default_objective', 'fm')
if config.expert_architecture == "unet":
return UNetExpert(config)
elif config.expert_architecture == "simple_cnn":
return SimpleCNNExpert(config)
elif config.expert_architecture == "dit":
return DiTExpert(config)
else:
raise ValueError(f"Unknown expert architecture: {config.expert_architecture}")
def create_router(config) -> Optional[nn.Module]:
"""Factory function to create router model"""
if config.router_architecture == "none" or config.is_monolithic:
return None
elif config.router_architecture == "deterministic_timestep":
return DeterministicTimestepRouter(config)
elif config.router_architecture == "vit":
return ViTRouter(config)
elif config.router_architecture == "cnn":
return CNNRouter(config)
elif config.router_architecture == "dit":
return DiTRouter(config)
elif config.router_architecture == "adaptive_video":
return AdaptiveVideoRouter(config)
else:
raise ValueError(f"Unknown router architecture: {config.router_architecture}")