# src/models.py import torch import torch.nn as nn import torch.nn.functional as F from diffusers import UNet2DModel from transformers import ViTForImageClassification, ViTConfig import math from typing import Optional, List import numpy as np # ============================================================================= # TIME EMBEDDING (shared utility) # ============================================================================= class TimeEmbedding(nn.Module): def __init__(self, dim: int) -> None: super().__init__() self.dim = dim def forward(self, t: torch.Tensor) -> torch.Tensor: device = t.device half_dim = self.dim // 2 embeddings = math.log(10000) / (half_dim - 1) embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) embeddings = t[:, None] * embeddings[None, :] embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) return embeddings class DiTTimestepEmbedder(nn.Module): def __init__(self, hidden_size, freq_dim=128, max_period=10000): super().__init__() self.freq_dim = freq_dim self.max_period = max_period self.mlp = nn.Sequential( nn.Linear(2*freq_dim, hidden_size, bias=True), nn.SiLU(), nn.Linear(hidden_size, hidden_size, bias=True), ) def forward(self, t): # t: [B] integers (float tensor ok) # standard "timestep_embedding" (like ADM/DiT) half = self.freq_dim device = t.device # positions in radians freqs = torch.exp( -torch.arange(half, device=device).float() * np.log(self.max_period) / half ) args = t.float()[:, None] * freqs[None] # [B, half] emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) # [B, 2*half] return self.mlp(emb) # ============================================================================= # OUTPUT CONVERTER (for heterogeneous objectives) # ============================================================================= class OutputConverter(nn.Module): def __init__(self, schedule_type: str = 'linear_interp', use_latents: bool = False, derivative_eps: float = 1e-4): super().__init__() from schedules import NoiseSchedule self.schedule = NoiseSchedule(schedule_type) self.schedule_type = schedule_type self.use_latents = use_latents self.derivative_eps = derivative_eps # For finite difference derivatives # Set clamping range based on data type # VAE latents have larger range than pixel-space images self.clamp_range = 20.0 if use_latents else 5.0 def _get_schedule_with_derivatives(self, t: torch.Tensor): """ Compute schedule coefficients and their derivatives. Essential for correct velocity computation with any schedule. """ # Get coefficients at current time alpha_t, sigma_t = self.schedule.get_schedule(t) # Compute derivatives using finite differences h = torch.full_like(t, self.derivative_eps) t_plus = (t + h).clamp(0.0, 1.0) t_minus = (t - h).clamp(0.0, 1.0) alpha_plus, sigma_plus = self.schedule.get_schedule(t_plus) alpha_minus, sigma_minus = self.schedule.get_schedule(t_minus) # Derivatives dt = (t_plus - t_minus).clamp(min=1e-6) d_alpha_dt = (alpha_plus - alpha_minus) / dt d_sigma_dt = (sigma_plus - sigma_minus) / dt return alpha_t, sigma_t, d_alpha_dt, d_sigma_dt def epsilon_to_velocity(self, epsilon_pred: torch.Tensor, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor: """ Correct ε→v conversion for ANY schedule using proper derivatives. From ODE: dx_t/dt = d(alpha_t)/dt * x_0 + d(sigma_t)/dt * ε This is the TRUE velocity for the schedule! """ # Get schedule coefficients AND their derivatives alpha_t, sigma_t, d_alpha_dt, d_sigma_dt = self._get_schedule_with_derivatives(t) # Reshape for broadcasting alpha_t = alpha_t.view(-1, 1, 1, 1) sigma_t = sigma_t.view(-1, 1, 1, 1) d_alpha_dt = d_alpha_dt.view(-1, 1, 1, 1) d_sigma_dt = d_sigma_dt.view(-1, 1, 1, 1) # Numerical stability: handle small alpha_t alpha_safe = torch.clamp(alpha_t, min=0.01) # Step 1: Recover x_0 using Tweedie's formula x_0_pred = (x_t - sigma_t * epsilon_pred) / alpha_safe # Step 2: Clamp x_0 to reasonable range (prevents blow-up) # Use adaptive clamping: larger range for VAE latents, tighter for pixel space x_0_pred = torch.clamp(x_0_pred, -self.clamp_range, self.clamp_range) # Step 3: Compute velocity based on schedule type if self.schedule_type == 'linear_interp': # For linear interpolation: x_t = (1-t)*x_0 + t*ε # Velocity is simply: v = ε - x_0 v = epsilon_pred - x_0_pred else: # For cosine and other schedules: use proper derivatives # v = d(alpha_t)/dt * x_0 + d(sigma_t)/dt * ε v = d_alpha_dt * x_0_pred + d_sigma_dt * epsilon_pred # Adaptive velocity scaling for cosine schedule # Derivatives vary dramatically with timestep - need adaptive dampening if self.schedule_type == 'cosine': t_val = t[0].item() if t.numel() > 0 else 0.5 if t_val > 0.85: # Very high noise: derivatives are large, need dampening scale = 0.88 elif t_val > 0.6: # Medium-high noise: moderate dampening scale = 0.93 else: # Low to medium noise: slight dampening scale = 0.96 v = v * scale # Per-channel bias correction to prevent color drift # The model has inherent channel bias that gets amplified by integration # Remove per-channel mean to prevent accumulation # Only apply to color channels (1,2,3), preserve luminance channel (0) for c in range(1, 4): v[:, c] = v[:, c] - v[:, c].mean() return v def convert(self, prediction: torch.Tensor, objective_type: str, x_t: torch.Tensor, t: torch.Tensor): """ Convert any prediction to velocity space. Args: prediction: expert output objective_type: 'ddpm' | 'fm' | 'rf' x_t: current noisy state t: current timesteps Returns: v: velocity representation """ if objective_type == "ddpm": # Proper ε→v conversion for unified integration return self.epsilon_to_velocity(prediction, x_t, t) elif objective_type in ["fm", "rf"]: return prediction # Already velocity else: raise ValueError(f"Unknown objective type: {objective_type}") # ============================================================================= # EXPERT MODELS # ============================================================================= class UNetExpert(nn.Module): """UNet expert using diffusers""" def __init__(self, config) -> None: super().__init__() # Default UNet params default_params = { "sample_size": config.image_size, "in_channels": config.num_channels, "out_channels": config.num_channels, "layers_per_block": 2, "block_out_channels": [64, 128, 256, 256], "attention_head_dim": 8, } # Override with config params params = {**default_params, **config.expert_params} # Store objective type for heterogeneous training (and remove from params) self.objective_type = params.pop("objective_type", "fm") # Store and initialize schedule (NEW) schedule_type = params.pop("schedule_type", "linear_interp") from schedules import NoiseSchedule self.schedule = NoiseSchedule(schedule_type) self.unet = UNet2DModel(**params) def forward(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor: # Scale timesteps for diffusers (expects 0-1000) # t_scaled = (t * 1000).long() t_scaled = (t * 999).round().long().clamp(0, 999) return self.unet(xt, t_scaled).sample def compute_loss(self, x0: torch.Tensor) -> torch.Tensor: """Unified loss computation based on objective type""" if self.objective_type == "ddpm": return self.ddpm_loss(x0) elif self.objective_type == "fm": return self.flow_matching_loss(x0) elif self.objective_type == "rf": return self.rectified_flow_loss(x0) else: raise ValueError(f"Unknown objective type: {self.objective_type}") def ddpm_loss(self, x0: torch.Tensor) -> torch.Tensor: """DDPM: predict noise ε""" batch_size = x0.shape[0] device = x0.device t = torch.rand(batch_size, device=device) # Use proper schedule (NEW) alpha_t, sigma_t = self.schedule.get_schedule(t) noise = torch.randn_like(x0) xt = alpha_t.view(-1, 1, 1, 1) * x0 + sigma_t.view(-1, 1, 1, 1) * noise pred_eps = self.forward(xt, t) return F.mse_loss(pred_eps, noise) def rectified_flow_loss(self, x0: torch.Tensor) -> torch.Tensor: """Rectified Flow: predict velocity v = x_1 - x_0""" batch_size = x0.shape[0] device = x0.device t = torch.rand(batch_size, device=device) x1 = torch.randn_like(x0) xt = (1 - t).view(-1, 1, 1, 1) * x0 + t.view(-1, 1, 1, 1) * x1 pred_v = self.forward(xt, t) true_v = x1 - x0 return F.mse_loss(pred_v, true_v) def flow_matching_loss(self, x0: torch.Tensor) -> torch.Tensor: """Flow matching loss for training""" batch_size = x0.shape[0] device = x0.device # Sample random timesteps t = torch.rand(batch_size, device=device) # Use proper schedule (NEW) alpha_t, sigma_t = self.schedule.get_schedule(t) # Add noise noise = torch.randn_like(x0) xt = alpha_t.view(-1, 1, 1, 1) * x0 + sigma_t.view(-1, 1, 1, 1) * noise # Predict velocity pred_v = self.forward(xt, t) # True velocity for flow matching # true_v = x0 - xt true_v = noise - x0 return F.mse_loss(pred_v, true_v) class SimpleCNNExpert(nn.Module): """Simple CNN expert for fast training""" def __init__(self, config) -> None: super().__init__() # Default params default_params = { "hidden_dims": [64, 128, 256], "time_dim": 64, } params = {**default_params, **config.expert_params} # Store objective type for heterogeneous training self.objective_type = params.get("objective_type", "fm") # Store and initialize schedule (NEW) schedule_type = params.get("schedule_type", "linear_interp") from schedules import NoiseSchedule self.schedule = NoiseSchedule(schedule_type) self.time_embedding = TimeEmbedding(params["time_dim"]) self.target_size = config.image_size # Simple encoder-decoder self.encoder = self._build_encoder(config.num_channels, params["hidden_dims"]) self.decoder = self._build_decoder(params["hidden_dims"], config.num_channels) # Time conditioning self.time_mlp = nn.Sequential( nn.Linear(params["time_dim"], params["hidden_dims"][-1]), nn.SiLU(), nn.Linear(params["hidden_dims"][-1], params["hidden_dims"][-1]) ) def _build_encoder(self, in_channels: int, hidden_dims: List[int]) -> nn.Sequential: layers = [] prev_dim = in_channels for dim in hidden_dims: layers.extend([ nn.Conv2d(prev_dim, dim, 3, padding=1), nn.GroupNorm(8, dim), nn.SiLU(), nn.Conv2d(dim, dim, 3, padding=1), nn.GroupNorm(8, dim), nn.SiLU(), nn.MaxPool2d(2) ]) prev_dim = dim return nn.Sequential(*layers) def _build_decoder(self, hidden_dims: List[int], out_channels: int) -> nn.Sequential: layers = [] reversed_dims = list(reversed(hidden_dims)) for i, dim in enumerate(reversed_dims[:-1]): next_dim = reversed_dims[i + 1] layers.extend([ nn.ConvTranspose2d(dim, next_dim, 4, stride=2, padding=1), nn.GroupNorm(8, next_dim), nn.SiLU(), nn.Conv2d(next_dim, next_dim, 3, padding=1), nn.GroupNorm(8, next_dim), nn.SiLU(), ]) # Final layer layers.append(nn.Conv2d(reversed_dims[-1], out_channels, 3, padding=1)) return nn.Sequential(*layers) def forward(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor: # Time embedding time_emb = self.time_embedding(t) time_features = self.time_mlp(time_emb) # Encode encoded = self.encoder(xt) # Add time conditioning time_features = time_features.view(time_features.shape[0], -1, 1, 1) time_features = time_features.expand(-1, -1, encoded.shape[2], encoded.shape[3]) conditioned = encoded + time_features # Decode output = self.decoder(conditioned) # Ensure output matches target size output = F.interpolate(output, size=xt.shape[-2:], mode='bilinear', align_corners=False) return output def compute_loss(self, x0: torch.Tensor) -> torch.Tensor: """Unified loss computation based on objective type""" if self.objective_type == "ddpm": return self.ddpm_loss(x0) elif self.objective_type == "fm": return self.flow_matching_loss(x0) elif self.objective_type == "rf": return self.rectified_flow_loss(x0) else: raise ValueError(f"Unknown objective type: {self.objective_type}") def ddpm_loss(self, x0: torch.Tensor) -> torch.Tensor: """DDPM: predict noise ε""" batch_size = x0.shape[0] device = x0.device t = torch.rand(batch_size, device=device) # Use proper schedule (NEW) alpha_t, sigma_t = self.schedule.get_schedule(t) noise = torch.randn_like(x0) xt = alpha_t.view(-1, 1, 1, 1) * x0 + sigma_t.view(-1, 1, 1, 1) * noise pred_eps = self.forward(xt, t) # Ensure pred_eps matches noise shape if pred_eps.shape != noise.shape: pred_eps = F.interpolate(pred_eps, size=noise.shape[-2:], mode='bilinear', align_corners=False) return F.mse_loss(pred_eps, noise) def rectified_flow_loss(self, x0: torch.Tensor) -> torch.Tensor: """Rectified Flow: predict velocity v = x_1 - x_0""" batch_size = x0.shape[0] device = x0.device t = torch.rand(batch_size, device=device) x1 = torch.randn_like(x0) xt = (1 - t).view(-1, 1, 1, 1) * x0 + t.view(-1, 1, 1, 1) * x1 pred_v = self.forward(xt, t) true_v = x1 - x0 # Ensure pred_v matches true_v shape if pred_v.shape != true_v.shape: pred_v = F.interpolate(pred_v, size=true_v.shape[-2:], mode='bilinear', align_corners=False) return F.mse_loss(pred_v, true_v) def flow_matching_loss(self, x0: torch.Tensor) -> torch.Tensor: """Flow matching loss""" batch_size = x0.shape[0] device = x0.device t = torch.rand(batch_size, device=device) # Use proper schedule (NEW) alpha_t, sigma_t = self.schedule.get_schedule(t) noise = torch.randn_like(x0) xt = alpha_t.view(-1, 1, 1, 1) * x0 + sigma_t.view(-1, 1, 1, 1) * noise pred_v = self.forward(xt, t) # true_v = x0 - xt true_v = noise - x0 # Ensure pred_v matches true_v shape if pred_v.shape != true_v.shape: pred_v = F.interpolate(pred_v, size=true_v.shape[-2:], mode='bilinear', align_corners=False) return F.mse_loss(pred_v, true_v) # Helper function from original DiT def modulate(x, shift, scale): return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) # Fixed sin-cos position embedding from original def get_2d_sincos_pos_embed(embed_dim, grid_size): grid_h = np.arange(grid_size, dtype=np.float32) grid_w = np.arange(grid_size, dtype=np.float32) grid = np.meshgrid(grid_w, grid_h) grid = np.stack(grid, axis=0) grid = grid.reshape([2, 1, grid_size, grid_size]) assert embed_dim % 2 == 0 emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) emb = np.concatenate([emb_h, emb_w], axis=1) return emb def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): assert embed_dim % 2 == 0 omega = np.arange(embed_dim // 2, dtype=np.float64) omega /= embed_dim / 2. omega = 1. / 10000**omega pos = pos.reshape(-1) out = np.einsum('m,d->md', pos, omega) emb_sin = np.sin(out) emb_cos = np.cos(out) emb = np.concatenate([emb_sin, emb_cos], axis=1) return emb # Timestep Embedder class TimestepEmbedder(nn.Module): def __init__(self, hidden_size: int, frequency_embedding_size: int = 256): super().__init__() self.frequency_embedding_size = frequency_embedding_size self.mlp = nn.Sequential( nn.Linear(frequency_embedding_size, hidden_size, bias=True), nn.SiLU(), nn.Linear(hidden_size, hidden_size, bias=True), ) @staticmethod def timestep_embedding(t, dim, max_period=10000): half = dim // 2 freqs = torch.exp(-math.log(max_period) * torch.arange(0, half, dtype=torch.float32, device=t.device) / half) args = t[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) return embedding def forward(self, t: torch.Tensor) -> torch.Tensor: t_freq = self.timestep_embedding(t, self.frequency_embedding_size) return self.mlp(t_freq) # DiTBlock with proper AdaLN-Zero class DiTBlock(nn.Module): def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, use_text: bool = False, use_adaln_single: bool = False): super().__init__() self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.attn = nn.MultiheadAttention(hidden_size, num_heads, dropout=0.1, batch_first=True) self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) mlp_hidden_dim = int(hidden_size * mlp_ratio) self.mlp = nn.Sequential( nn.Linear(hidden_size, mlp_hidden_dim), nn.GELU(approximate="tanh"), # Match original nn.Linear(mlp_hidden_dim, hidden_size), ) # AdaLN modulation - either per-block MLP or AdaLN-Single embeddings self.use_adaln_single = use_adaln_single if use_adaln_single: # AdaLN-Single: use learnable per-block embeddings instead of MLP self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5) self.adaLN_modulation = None # No MLP needed else: # Original AdaLN with per-block MLP self.adaLN_modulation = nn.Sequential( nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True) ) self.scale_shift_table = None # Optional text cross-attention self.use_text = use_text if use_text: # Note: PixArt uses xformers which may handle unnormalized queries differently # We add a simple norm for stability with PyTorch's MultiheadAttention self.norm_cross = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.cross_attn = nn.MultiheadAttention(hidden_size, num_heads, dropout=0.1, batch_first=True) def forward(self, x: torch.Tensor, c: torch.Tensor, text_emb: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None): # Get modulation parameters if self.use_adaln_single: # AdaLN-Single: combine global time embedding with per-block parameters # c should be pre-computed from global t_block with shape [B, 6*hidden_size] B = x.shape[0] # Chunk and squeeze to get [B, hidden_size] tensors for compatibility with PyTorch's MultiheadAttention temp = (self.scale_shift_table[None] + c.reshape(B, 6, -1)).chunk(6, dim=1) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [t.squeeze(1) for t in temp] else: # Original AdaLN: compute modulation from per-block MLP # Also squeeze after chunk to get [B, hidden_size] for consistency temp = self.adaLN_modulation(c).chunk(6, dim=1) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [t.squeeze(1) for t in temp] # Self-attention with modulation # Both paths now use modulate function for consistency x_norm = modulate(self.norm1(x), shift_msa, scale_msa) attn_out, _ = self.attn(x_norm, x_norm, x_norm) x = x + gate_msa.unsqueeze(1) * attn_out # Optional cross-attention if self.use_text and text_emb is not None: if text_emb.dim() == 2: text_emb = text_emb.unsqueeze(1) # Convert attention mask to key_padding_mask format (True = ignore) # attention_mask: shape [B, T]; either bool (True=keep) or 0/1 numeric (1=keep) key_padding_mask = None if attention_mask is not None: if attention_mask.dtype is not torch.bool: # Convert 0/1 (or >=1) to bool keep-mask first keep_mask = attention_mask > 0 else: keep_mask = attention_mask # key_padding_mask semantics: True = ignore, False = keep key_padding_mask = ~keep_mask # logical NOT, not arithmetic subtraction # Normalize queries for stability (PixArt uses xformers which may differ) x_norm = self.norm_cross(x) cross_out, _ = self.cross_attn(x_norm, text_emb, text_emb, key_padding_mask=key_padding_mask) x = x + cross_out # MLP with modulation # Both paths now use modulate function for consistency x_norm = modulate(self.norm2(x), shift_mlp, scale_mlp) mlp_out = self.mlp(x_norm) x = x + gate_mlp.unsqueeze(1) * mlp_out return x # FinalLayer with AdaLN modulation class FinalLayer(nn.Module): def __init__(self, hidden_size: int, patch_size: int, out_channels: int): super().__init__() self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) self.adaLN_modulation = nn.Sequential( nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True) ) def forward(self, x: torch.Tensor, c: torch.Tensor): shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) x = modulate(self.norm_final(x), shift, scale) x = self.linear(x) return x # T2IFinalLayer with AdaLN-Single for parameter efficiency class T2IFinalLayer(nn.Module): def __init__(self, hidden_size: int, patch_size: int, out_channels: int): super().__init__() self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) # AdaLN-Single: use learnable embeddings instead of MLP self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size ** 0.5) self.hidden_size = hidden_size def forward(self, x: torch.Tensor, t: torch.Tensor): # t should be the original time embedding with shape [B, hidden_size] # Following PixArt implementation exactly shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1) # shift and scale are [B, 1, hidden_size], use t2i_modulate style x = self.norm_final(x) * (1 + scale) + shift x = self.linear(x) return x # DiTExpert class DiTExpert(nn.Module): def __init__(self, config): super().__init__() default_params = { "hidden_size": 768, "num_layers": 12, "num_heads": 12, "patch_size": 2, "in_channels": 4, "out_channels": 4, "use_text_conditioning": False, "use_class_conditioning": False, "num_classes": 1000, # ImageNet classes "mlp_ratio": 4.0, "text_embed_dim": 768, "use_dit_time_embed": False, } params = {**default_params, **config.expert_params} self.patch_size = params["patch_size"] self.in_channels = params["in_channels"] self.out_channels = params["out_channels"] self.hidden_size = params["hidden_size"] self.num_heads = params["num_heads"] self.use_text = params.get("use_text_conditioning", False) self.use_class = params.get("use_class_conditioning", False) self.cfg_dropout_prob = params.get("cfg_dropout_prob", 0.1) # 10% dropout for CFG self.text_embed_dim = params.get("text_embed_dim", 768) self.use_adaln_single = params.get("use_adaln_single", False) # AdaLN-Single for parameter efficiency self.depth = params["num_layers"] # Store objective type for heterogeneous training self.objective_type = params.get("objective_type", "fm") # Store and initialize schedule (NEW) schedule_type = params.get("schedule_type", "linear_interp") from schedules import NoiseSchedule self.schedule = NoiseSchedule(schedule_type) # Validation: cannot use both text and class conditioning simultaneously assert not (self.use_text and self.use_class), "Cannot use both text and class conditioning simultaneously" # Patch embedding self.patch_embed = nn.Conv2d(self.in_channels, self.hidden_size, kernel_size=self.patch_size, stride=self.patch_size) # Fixed sin-cos positional embedding latent_size = getattr(config, 'image_size', 32) self.num_patches = (latent_size // self.patch_size) ** 2 self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, self.hidden_size), requires_grad=False) # Time embedding self.use_dit_time_embed = params.get("use_dit_time_embed", False) if self.use_dit_time_embed: self.time_embed = DiTTimestepEmbedder(self.hidden_size) else: self.time_embed = TimestepEmbedder(self.hidden_size) # Global time block for AdaLN-Single if self.use_adaln_single: self.t_block = nn.Sequential( nn.SiLU(), nn.Linear(self.hidden_size, 6 * self.hidden_size, bias=True) ) # Optional text conditioning if self.use_text: self.text_proj = nn.Linear(self.text_embed_dim, self.hidden_size) self.text_norm = nn.LayerNorm(self.hidden_size, elementwise_affine=False, eps=1e-6) # Note: null text embedding will be provided by empty string encoding from CLIP # This is handled in the training loop, not as a learnable parameter # Optional class conditioning (ImageNet style) if self.use_class: # Add 1 extra embedding for null/unconditional class self.class_embed = nn.Embedding(params["num_classes"] + 1, self.hidden_size) self.null_class_id = params["num_classes"] # Use last index as null class # Transformer blocks self.layers = nn.ModuleList([ DiTBlock(self.hidden_size, self.num_heads, params.get("mlp_ratio", 4.0), self.use_text, use_adaln_single=self.use_adaln_single) for _ in range(self.depth) ]) # Final layer with modulation if self.use_adaln_single: self.final_layer = T2IFinalLayer(self.hidden_size, self.patch_size, self.out_channels) else: self.final_layer = FinalLayer(self.hidden_size, self.patch_size, self.out_channels) # Initialize weights self.initialize_weights() def initialize_weights(self): # Initialize transformer layers def _basic_init(module): if isinstance(module, nn.Linear): torch.nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) self.apply(_basic_init) # Initialize positional embedding with sin-cos grid_size = int(self.num_patches ** 0.5) pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], grid_size) self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) # Initialize patch_embed like nn.Linear w = self.patch_embed.weight.data nn.init.xavier_uniform_(w.view([w.shape[0], -1])) if self.patch_embed.bias is not None: nn.init.constant_(self.patch_embed.bias, 0) # Initialize timestep embedding MLP nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02) nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02) # Zero-out adaLN modulation layers in DiT blocks (from DiT paper) for block in self.layers: if block.adaLN_modulation is not None: # Original AdaLN mode nn.init.constant_(block.adaLN_modulation[-1].weight, 0) nn.init.constant_(block.adaLN_modulation[-1].bias, 0) # AdaLN-Single mode: scale_shift_table is already initialized with randn/sqrt(hidden_size) # Zero-out cross-attention output projection (from PixArt-Alpha) if self.use_text and hasattr(block, 'cross_attn'): nn.init.constant_(block.cross_attn.out_proj.weight, 0) nn.init.constant_(block.cross_attn.out_proj.bias, 0) # Initialize text projection layer (analogous to PixArt's caption embedding) if self.use_text and hasattr(self, 'text_proj'): nn.init.normal_(self.text_proj.weight, std=0.02) if self.text_proj.bias is not None: nn.init.constant_(self.text_proj.bias, 0) # Initialize class embedding layer (similar to DiT paper) if self.use_class and hasattr(self, 'class_embed'): nn.init.normal_(self.class_embed.weight, std=0.02) # Initialize global t_block for AdaLN-Single if self.use_adaln_single and hasattr(self, 't_block'): nn.init.normal_(self.t_block[1].weight, std=0.02) # Zero-out t_block initially for stability nn.init.constant_(self.t_block[1].bias, 0) # Zero-out output layers if hasattr(self.final_layer, 'adaLN_modulation') and self.final_layer.adaLN_modulation is not None: # Original FinalLayer nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) # T2IFinalLayer scale_shift_table is already initialized with randn/sqrt(hidden_size) nn.init.constant_(self.final_layer.linear.weight, 0) nn.init.constant_(self.final_layer.linear.bias, 0) def forward(self, xt: torch.Tensor, t: torch.Tensor, text_embeds: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: B, C, H, W = xt.shape # Handle timestep scaling - DiT expects timesteps in [0, 999] range # If t is normalized (in [0, 1]), scale it to [0, 999] if t.max() <= 1.0 and t.min() >= 0.0: # Normalized timesteps, scale to DiT range t = t * 999.0 # Ensure t is in correct range for DiT t = t.clamp(0, 999) # Patchify x = self.patch_embed(xt) # [B, hidden_size, H//p, W//p] x = x.flatten(2).transpose(1, 2) # [B, num_patches, hidden_size] x = x + self.pos_embed # Add positional embedding # Prepare conditioning time_emb = self.time_embed(t) # [B, hidden_size] # Add class conditioning to time embedding if using class conditioning if self.use_class and class_labels is not None: class_emb = self.class_embed(class_labels) # [B, hidden_size] time_emb = time_emb + class_emb # Additive combination following DiT paper # Process conditioning based on AdaLN mode if self.use_adaln_single: # AdaLN-Single: compute global modulation once c = self.t_block(time_emb) # [B, 6*hidden_size] else: # Original AdaLN: pass time embedding to each block c = time_emb # Prepare text tokens for cross-attention (not fused with time) text_tokens = None if self.use_text and text_embeds is not None: if text_embeds.dim() == 3: text_tokens = self.text_proj(text_embeds) # [B, T, hidden_size] text_tokens = self.text_norm(text_tokens) else: text_tokens = self.text_proj(text_embeds).unsqueeze(1) # [B, 1, hidden_size] text_tokens = self.text_norm(text_tokens) if attention_mask is not None: # cast to bool, clamp shapes to text_tokens length attention_mask = attention_mask[:, :text_tokens.shape[1]].to(torch.bool) # safety: avoid all-false rows (would yield NaNs in softmax) all_false = attention_mask.sum(dim=1) == 0 if all_false.any(): attention_mask[all_false, 0] = True # Apply transformer blocks for layer in self.layers: x = layer(x, c, text_tokens, attention_mask) # Final projection if self.use_adaln_single: # T2IFinalLayer expects original time embedding, not global modulation x = self.final_layer(x, time_emb) # [B, num_patches, patch_size^2 * out_channels] else: # Original FinalLayer expects conditioning x = self.final_layer(x, c) # [B, num_patches, patch_size^2 * out_channels] # Unpatchify patch_h = patch_w = int(self.num_patches ** 0.5) x = x.view(B, patch_h, patch_w, self.patch_size, self.patch_size, self.out_channels) x = x.permute(0, 5, 1, 3, 2, 4).contiguous() x = x.view(B, self.out_channels, H, W) return x def compute_loss(self, x0: torch.Tensor, text_embeds: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None, null_text_embeds: Optional[torch.Tensor] = None, null_attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: """Unified loss computation based on objective type""" if self.objective_type == "ddpm": return self.ddpm_loss(x0, text_embeds, attention_mask, class_labels, null_text_embeds, null_attention_mask) elif self.objective_type == "fm": return self.flow_matching_loss(x0, text_embeds, attention_mask, class_labels, null_text_embeds, null_attention_mask) elif self.objective_type == "rf": return self.rectified_flow_loss(x0, text_embeds, attention_mask, class_labels, null_text_embeds, null_attention_mask) else: raise ValueError(f"Unknown objective type: {self.objective_type}") def ddpm_loss(self, x0: torch.Tensor, text_embeds: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None, null_text_embeds: Optional[torch.Tensor] = None, null_attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: """DDPM: predict noise ε""" B = x0.shape[0] device = x0.device # Sample time uniformly t = torch.rand(B, device=device) # Use proper schedule (NEW) alpha_t, sigma_t = self.schedule.get_schedule(t) noise = torch.randn_like(x0) xt = alpha_t.view(-1, 1, 1, 1) * x0 + sigma_t.view(-1, 1, 1, 1) * noise # Apply CFG dropout during training if self.training and self.cfg_dropout_prob > 0: if self.use_text and text_embeds is not None: keep = (torch.rand(B, device=device) > self.cfg_dropout_prob) # True = keep text if null_text_embeds is not None: # Use provided null text embeddings (from empty string CLIP encoding) if null_text_embeds.shape[0] == 1: null_text_embeds = null_text_embeds.expand(B, -1, -1) # Replace dropped samples with null text embeddings dropped = ~keep if dropped.any(): text_embeds = text_embeds.clone() text_embeds[dropped] = null_text_embeds[dropped] # Use provided null attention mask or create default for empty string if attention_mask is not None: attention_mask = attention_mask.clone() if null_attention_mask is not None: if null_attention_mask.shape[0] == 1: null_attention_mask = null_attention_mask.expand(B, -1) attention_mask[dropped] = null_attention_mask[dropped] else: attention_mask[dropped] = 0 attention_mask[dropped, 0] = 1 else: # Fallback to old zeroing approach if null_text_embeds not provided if text_embeds.dim() == 3: # [B, T, D] text_embeds = text_embeds * keep[:, None, None].to(text_embeds.dtype) else: # [B, D] text_embeds = text_embeds * keep[:, None].to(text_embeds.dtype) if attention_mask is not None: attention_mask = attention_mask.clone() dropped = ~keep if dropped.any(): attention_mask[dropped, 0] = 1 elif self.use_class and class_labels is not None: # Apply CFG dropout to class labels using null class embedding keep = (torch.rand(B, device=device) > self.cfg_dropout_prob) null_class = torch.full_like(class_labels, self.null_class_id) class_labels = torch.where(keep, class_labels, null_class) # Predict noise pred_eps = self.forward(xt, t, text_embeds, attention_mask, class_labels) return F.mse_loss(pred_eps, noise) def rectified_flow_loss(self, x0: torch.Tensor, text_embeds: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None, null_text_embeds: Optional[torch.Tensor] = None, null_attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: """Rectified Flow: predict velocity v = x_1 - x_0 (straight paths)""" B = x0.shape[0] device = x0.device # Sample time uniformly t = torch.rand(B, device=device) # Straight-line interpolation x1 = torch.randn_like(x0) # Gaussian noise as x_1 xt = (1 - t).view(-1, 1, 1, 1) * x0 + t.view(-1, 1, 1, 1) * x1 # Apply CFG dropout during training if self.training and self.cfg_dropout_prob > 0: if self.use_text and text_embeds is not None: keep = (torch.rand(B, device=device) > self.cfg_dropout_prob) # True = keep text if null_text_embeds is not None: # Use provided null text embeddings (from empty string CLIP encoding) if null_text_embeds.shape[0] == 1: null_text_embeds = null_text_embeds.expand(B, -1, -1) # Replace dropped samples with null text embeddings dropped = ~keep if dropped.any(): text_embeds = text_embeds.clone() text_embeds[dropped] = null_text_embeds[dropped] # Use provided null attention mask or create default for empty string if attention_mask is not None: attention_mask = attention_mask.clone() if null_attention_mask is not None: if null_attention_mask.shape[0] == 1: null_attention_mask = null_attention_mask.expand(B, -1) attention_mask[dropped] = null_attention_mask[dropped] else: attention_mask[dropped] = 0 attention_mask[dropped, 0] = 1 else: # Fallback to old zeroing approach if null_text_embeds not provided if text_embeds.dim() == 3: # [B, T, D] text_embeds = text_embeds * keep[:, None, None].to(text_embeds.dtype) else: # [B, D] text_embeds = text_embeds * keep[:, None].to(text_embeds.dtype) if attention_mask is not None: attention_mask = attention_mask.clone() dropped = ~keep if dropped.any(): attention_mask[dropped, 0] = 1 elif self.use_class and class_labels is not None: # Apply CFG dropout to class labels using null class embedding keep = (torch.rand(B, device=device) > self.cfg_dropout_prob) null_class = torch.full_like(class_labels, self.null_class_id) class_labels = torch.where(keep, class_labels, null_class) # Predict velocity (x_1 - x_0) pred_v = self.forward(xt, t, text_embeds, attention_mask, class_labels) true_v = x1 - x0 return F.mse_loss(pred_v, true_v) def flow_matching_loss(self, x0: torch.Tensor, text_embeds: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None, null_text_embeds: Optional[torch.Tensor] = None, null_attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: """Flow matching loss for latent space training with CFG dropout.""" B = x0.shape[0] device = x0.device # Sample time uniformly t = torch.rand(B, device=device) # Use proper schedule (NEW) alpha_t, sigma_t = self.schedule.get_schedule(t) noise = torch.randn_like(x0) xt = alpha_t.view(-1, 1, 1, 1) * x0 + sigma_t.view(-1, 1, 1, 1) * noise # Apply CFG dropout during training if self.training and self.cfg_dropout_prob > 0: if self.use_text and text_embeds is not None: keep = (torch.rand(B, device=device) > self.cfg_dropout_prob) # True = keep text if null_text_embeds is not None: # Use provided null text embeddings (from empty string CLIP encoding) # Ensure null_text_embeds matches the batch size if null_text_embeds.shape[0] == 1: null_text_embeds = null_text_embeds.expand(B, -1, -1) # Replace dropped samples with null text embeddings dropped = ~keep if dropped.any(): text_embeds = text_embeds.clone() text_embeds[dropped] = null_text_embeds[dropped] # Use provided null attention mask or create default for empty string if attention_mask is not None: attention_mask = attention_mask.clone() if null_attention_mask is not None: # Ensure null_attention_mask matches batch size if null_attention_mask.shape[0] == 1: null_attention_mask = null_attention_mask.expand(B, -1) attention_mask[dropped] = null_attention_mask[dropped] else: # Default: For null text (empty string), typically only the first token is valid attention_mask[dropped] = 0 attention_mask[dropped, 0] = 1 # Keep only first token for empty string else: # Fallback to old zeroing approach if null_text_embeds not provided if text_embeds.dim() == 3: # [B, T, D] text_embeds = text_embeds * keep[:, None, None].to(text_embeds.dtype) else: # [B, D] text_embeds = text_embeds * keep[:, None].to(text_embeds.dtype) # Handle attention mask for fallback approach if attention_mask is not None: attention_mask = attention_mask.clone() dropped = ~keep if dropped.any(): attention_mask[dropped, 0] = 1 elif self.use_class and class_labels is not None: # Apply CFG dropout to class labels using null class embedding keep = (torch.rand(B, device=device) > self.cfg_dropout_prob) # True = keep class # Use the dedicated null class embedding for unconditional generation null_class = torch.full_like(class_labels, self.null_class_id) class_labels = torch.where(keep, class_labels, null_class) # Predict velocity pred_v = self.forward(xt, t, text_embeds, attention_mask, class_labels) true_v = noise - x0 return F.mse_loss(pred_v, true_v) # ============================================================================= # ROUTER MODELS # ============================================================================= class ViTRouter(nn.Module): """ViT-based router for cluster classification""" def __init__(self, config) -> None: super().__init__() # Default params default_params = { "hidden_size": 384, "num_layers": 6, "num_heads": 6, "patch_size": 8, "use_dit_time_embed": False, # Whether to use DiT-style time embedding } params = {**default_params, **config.router_params} if config.router_pretrained: # Use pretrained ViT and adapt self.vit = ViTForImageClassification.from_pretrained( "google/vit-base-patch16-224" ) self._adapt_pretrained(config, params) else: # Build from scratch vit_config = ViTConfig( image_size=config.image_size, patch_size=params["patch_size"], num_channels=config.num_channels, hidden_size=params["hidden_size"], num_hidden_layers=params["num_layers"], num_attention_heads=params["num_heads"], num_labels=config.num_clusters ) self.vit = ViTForImageClassification(vit_config) # Time conditioning - support both embedding styles self.use_dit_time_embed = params.get("use_dit_time_embed", False) if self.use_dit_time_embed: # Use DiT-style timestep embedding for consistency self.time_embedding = DiTTimestepEmbedder(params["hidden_size"]) else: # Original simple time embedding self.time_embedding = nn.Sequential( nn.Linear(1, params["hidden_size"]), nn.SiLU(), nn.Linear(params["hidden_size"], params["hidden_size"]) ) # Combined classifier self.classifier = nn.Sequential( nn.Linear(params["hidden_size"] * 2, params["hidden_size"]), nn.ReLU(), nn.Dropout(0.1), nn.Linear(params["hidden_size"], config.num_clusters) ) def _adapt_pretrained(self, config, params) -> ViTForImageClassification: """Adapt pretrained ViT for our task""" # Modify patch embeddings if needed if config.image_size != 224 or config.num_channels != 3: self.vit.vit.embeddings.patch_embeddings.projection = nn.Conv2d( config.num_channels, self.vit.config.hidden_size, kernel_size=params["patch_size"], stride=params["patch_size"] ) def forward(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor: # Process image through ViT vit_outputs = self.vit.vit(xt) image_features = vit_outputs.last_hidden_state[:, 0] # CLS token # Time conditioning if self.use_dit_time_embed: # DiT embedder expects raw timesteps time_features = self.time_embedding(t) else: # Original embedding needs unsqueeze time_features = self.time_embedding(t.unsqueeze(-1)) # Combine and classify combined = torch.cat([image_features, time_features], dim=1) return self.classifier(combined) class CNNRouter(nn.Module): """Simple CNN router for cluster classification""" def __init__(self, config) -> None: super().__init__() # Default params default_params = { "hidden_dims": [64, 128, 256], "use_dit_time_embed": False, # Whether to use DiT-style time embedding } params = {**default_params, **config.router_params} # CNN backbone self.backbone = self._build_cnn(config.num_channels, params["hidden_dims"]) # Time embedding - support both styles self.use_dit_time_embed = params.get("use_dit_time_embed", False) if self.use_dit_time_embed: # Use DiT-style timestep embedding, output to 128 dims for CNN self.time_embedding = DiTTimestepEmbedder(128) else: # Original simple time embedding self.time_embedding = nn.Sequential( nn.Linear(1, 128), nn.SiLU(), nn.Linear(128, 128) ) # Classifier self.classifier = nn.Sequential( nn.Linear(params["hidden_dims"][-1] + 128, 256), nn.ReLU(), nn.Dropout(0.1), nn.Linear(256, config.num_clusters) ) def _build_cnn(self, in_channels: int, hidden_dims: List[int]) -> nn.Sequential: layers = [] prev_dim = in_channels for dim in hidden_dims: layers.extend([ nn.Conv2d(prev_dim, dim, 3, padding=1), nn.ReLU(), nn.Conv2d(dim, dim, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2) ]) prev_dim = dim layers.append(nn.AdaptiveAvgPool2d(1)) layers.append(nn.Flatten()) return nn.Sequential(*layers) def forward(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor: # CNN features img_features = self.backbone(xt) # Time features if self.use_dit_time_embed: # DiT embedder expects raw timesteps time_features = self.time_embedding(t) else: # Original embedding needs unsqueeze time_features = self.time_embedding(t.unsqueeze(-1)) # Combine and classify combined = torch.cat([img_features, time_features], dim=1) return self.classifier(combined) class DiTRouter(nn.Module): """DiT B/2 router for cluster classification""" def __init__(self, config): super().__init__() # DiT B/2 specifications default_params = { "hidden_size": 768, # DiT-B uses 768 "num_layers": 12, # DiT-B uses 12 layers "num_heads": 12, # DiT-B uses 12 heads "patch_size": 2, # For latent space (32x32 -> 16x16 patches) "in_channels": 4, # VAE latent channels "mlp_ratio": 4.0, "use_dit_time_embed": False, # Whether to use DiT-style time embedding } params = {**default_params, **config.router_params} self.patch_size = params["patch_size"] self.in_channels = params["in_channels"] self.hidden_size = params["hidden_size"] self.num_heads = params["num_heads"] self.num_clusters = config.num_clusters # Patch embedding (same as expert) self.patch_embed = nn.Conv2d( self.in_channels, self.hidden_size, kernel_size=self.patch_size, stride=self.patch_size ) # Calculate number of patches latent_size = getattr(config, 'image_size', 32) # Assuming 256/8=32 for VAE self.num_patches = (latent_size // self.patch_size) ** 2 # Fixed sin-cos positional embedding (same as expert) self.pos_embed = nn.Parameter( torch.zeros(1, self.num_patches, self.hidden_size), requires_grad=False ) # CLS token (KEY ADDITION from paper) self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_size)) # Time embedding - match expert's choice self.use_dit_time_embed = params.get("use_dit_time_embed", False) if self.use_dit_time_embed: self.time_embed = DiTTimestepEmbedder(self.hidden_size) else: self.time_embed = TimestepEmbedder(self.hidden_size) # DiT blocks with AdaLN (reuse DiTBlock from expert) # Note: Router doesn't need text conditioning self.layers = nn.ModuleList([ DiTBlock(self.hidden_size, self.num_heads, params["mlp_ratio"], use_text=False) for _ in range(params["num_layers"]) ]) # Final layer norm self.norm_final = nn.LayerNorm(self.hidden_size, elementwise_affine=False, eps=1e-6) # Linear classifier on CLS token (as specified in paper) # self.head = nn.Linear(self.hidden_size, self.num_clusters) self.head = nn.Sequential( nn.Linear(self.hidden_size, self.hidden_size), nn.GELU(), nn.LayerNorm(self.hidden_size), nn.Dropout(0.1), nn.Linear(self.hidden_size, self.num_clusters) ) # Initialize weights self.initialize_weights() def initialize_weights(self): # Initialize transformer layers def _basic_init(module): if isinstance(module, nn.Linear): torch.nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) self.apply(_basic_init) # Initialize CLS token nn.init.normal_(self.cls_token, std=0.02) # Initialize positional embedding with sin-cos (same as expert) grid_size = int(self.num_patches ** 0.5) pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], grid_size) self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) # Initialize patch_embed like nn.Linear w = self.patch_embed.weight.data nn.init.xavier_uniform_(w.view([w.shape[0], -1])) if self.patch_embed.bias is not None: nn.init.constant_(self.patch_embed.bias, 0) # Initialize timestep embedding MLP if hasattr(self.time_embed, 'mlp'): nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02) nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02) # Zero-out adaLN modulation in blocks (following expert initialization) for block in self.layers: nn.init.constant_(block.adaLN_modulation[-1].weight, 0) nn.init.constant_(block.adaLN_modulation[-1].bias, 0) # # Initialize classification head (simpler version for classification head) # nn.init.constant_(self.head.weight, 0) # nn.init.constant_(self.head.bias, 0) # Initialize classification head (Sequential) # Initialize intermediate layers normally, zero-out final layer nn.init.normal_(self.head[0].weight, std=0.02) # First linear layer if self.head[0].bias is not None: nn.init.constant_(self.head[0].bias, 0) # Zero-out final classification layer (following DiT paper) nn.init.constant_(self.head[-1].weight, 0) # Last linear layer if self.head[-1].bias is not None: nn.init.constant_(self.head[-1].bias, 0) def forward(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor: B, C, H, W = xt.shape # Match expert's timestep interpretation if t.max() <= 1.0 and t.min() >= 0.0: t = t * 999.0 t = t.clamp(0, 999) # Patchify x = self.patch_embed(xt) # [B, hidden_size, H//p, W//p] x = x.flatten(2).transpose(1, 2) # [B, num_patches, hidden_size] # Add positional embedding x = x + self.pos_embed # Prepend CLS token cls_tokens = self.cls_token.expand(B, -1, -1) # [B, 1, hidden_size] x = torch.cat([cls_tokens, x], dim=1) # [B, 1 + num_patches, hidden_size] # Time conditioning c = self.time_embed(t) # [B, hidden_size] # Apply DiT blocks with AdaLN modulation for layer in self.layers: x = layer(x, c, text_emb=None) # Extract CLS token and apply final norm cls_output = x[:, 0] # [B, hidden_size] cls_output = self.norm_final(cls_output) # Linear classification head logits = self.head(cls_output) # [B, num_clusters] return logits # ============================================================================= # DETERMINISTIC ROUTER (for controlled experiments) # ============================================================================= class DeterministicTimestepRouter(nn.Module): """ Deterministic router that assigns experts based on timestep. Useful for controlled experiments where you want to test specific routing strategies, such as: "high noise → DDPM expert, low noise → FM expert" Args: config: Config object with router_params containing: - timestep_threshold: t value to switch experts (default: 0.5) - high_noise_expert: Expert ID for t > threshold (default: 0, typically DDPM) - low_noise_expert: Expert ID for t <= threshold (default: 1, typically FM) Example config: router_architecture: "deterministic_timestep" router_params: timestep_threshold: 0.5 high_noise_expert: 0 # DDPM for high noise low_noise_expert: 1 # FM for low noise """ def __init__(self, config): super().__init__() self.num_experts = config.num_experts self.threshold = config.router_params.get('timestep_threshold', 0.5) self.high_noise_expert = config.router_params.get('high_noise_expert', 0) self.low_noise_expert = config.router_params.get('low_noise_expert', 1) # Validate expert IDs assert 0 <= self.high_noise_expert < self.num_experts, \ f"high_noise_expert {self.high_noise_expert} out of range [0, {self.num_experts})" assert 0 <= self.low_noise_expert < self.num_experts, \ f"low_noise_expert {self.low_noise_expert} out of range [0, {self.num_experts})" # Validate threshold assert 0.0 <= self.threshold <= 1.0, \ f"timestep_threshold {self.threshold} must be in [0, 1]" # This router has no trainable parameters # Register threshold as buffer (not trained, but saved with model) self.register_buffer('_threshold', torch.tensor(self.threshold)) print(f"DeterministicTimestepRouter initialized:") print(f" Threshold: {self.threshold}") print(f" High noise (t > {self.threshold}) → Expert {self.high_noise_expert}") print(f" Low noise (t <= {self.threshold}) → Expert {self.low_noise_expert}") def forward(self, x: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor: """ Returns one-hot routing probabilities based on timestep. Args: x: Input tensor (unused, but kept for API compatibility with other routers) t: Timesteps, shape (B,) Returns: routing_probs: Shape (B, num_experts), one-hot encoded """ B = t.shape[0] device = t.device # Initialize routing probabilities (all zeros) routing_probs = torch.zeros(B, self.num_experts, device=device) # High noise (t > threshold) → high_noise_expert # Low noise (t <= threshold) → low_noise_expert high_noise_mask = t > self.threshold routing_probs[high_noise_mask, self.high_noise_expert] = 1.0 routing_probs[~high_noise_mask, self.low_noise_expert] = 1.0 return routing_probs def train(self, mode: bool = True): """Override train() - this router is never trained, always in eval mode""" return super(DeterministicTimestepRouter, self).train(False) # ============================================================================= # ADAPTIVE VIDEO ROUTER (for Video DDM) # ============================================================================= class AdaptiveVideoRouter(nn.Module): """ Time-adaptive router for video DDM. Key innovation: Learns optimal weighting of information sources at each noise level, solving the "motion invisible at t=1" problem. Information availability is time-dependent: t ~ 1.0: Only text/first_frame informative → Route on conditioning t ~ 0.5: Structure emerging → Latent becomes useful t ~ 0.1: Near clean → Full information available Expected learned behavior: | Noise Level | Text | Frame | Latent | Behavior | |-------------|------|-------|--------|-----------------------------| | t ~ 1.0 | ~0.7 | ~0.2 | ~0.1 | Routes on text semantics | | t ~ 0.5 | ~0.4 | ~0.3 | ~0.3 | Balanced; emerging structure| | t ~ 0.1 | ~0.2 | ~0.2 | ~0.6 | Trusts latent; fine-grained | Enhancements: - Masked mean pooling for text (handles variable-length prompts) - Temporal-aware latent encoder (captures motion patterns) - Temperature scaling for inference control """ def __init__(self, config): super().__init__() # Default params default_params = { "hidden_dim": 512, "text_embed_dim": 768, # CLIP-L text embedding dimension "frame_embed_dim": 768, # DINOv2-B (base) feature dimension "latent_channels": 16, # VAE latent channels (CogVideoX uses 16) "latent_conv_dim": 64, # Intermediate conv channels for latent encoder "dropout": 0.1, "temporal_pool_mode": "attention", # "attention", "avg", or "max" "normalize_inputs": True, # L2-normalize text/frame inputs (match clustering) } params = {**default_params, **getattr(config, 'router_params', {})} self.hidden_dim = params["hidden_dim"] self.num_experts = getattr(config, 'num_experts', config.num_clusters) self.latent_channels = params["latent_channels"] self.latent_conv_dim = params["latent_conv_dim"] self.temporal_pool_mode = params["temporal_pool_mode"] self.normalize_inputs = params.get("normalize_inputs", True) # === Information Source Encoders === # Text pathway (always available, primary signal at high t) self.text_encoder = nn.Sequential( nn.Linear(params["text_embed_dim"], self.hidden_dim), nn.LayerNorm(self.hidden_dim), nn.GELU(), nn.Linear(self.hidden_dim, self.hidden_dim) ) # First frame pathway (available for I2V tasks) # Uses DINOv2 features extracted from the first frame self.frame_encoder = nn.Sequential( nn.Linear(params["frame_embed_dim"], self.hidden_dim), nn.LayerNorm(self.hidden_dim), nn.GELU(), nn.Linear(self.hidden_dim, self.hidden_dim) ) # === Temporal-Aware Latent Encoder === # Captures both spatial content and temporal motion patterns # Spatial feature extraction (per-frame) self.spatial_conv = nn.Sequential( nn.Conv3d(params["latent_channels"], params["latent_conv_dim"], kernel_size=(1, 3, 3), padding=(0, 1, 1)), # Spatial only nn.GroupNorm(8, params["latent_conv_dim"]), nn.GELU(), ) # Temporal feature extraction (motion patterns) self.temporal_conv = nn.Sequential( nn.Conv3d(params["latent_conv_dim"], params["latent_conv_dim"], kernel_size=(3, 1, 1), padding=(1, 0, 0)), # Temporal only nn.GroupNorm(8, params["latent_conv_dim"]), nn.GELU(), ) # Combined spatio-temporal processing self.st_conv = nn.Sequential( nn.Conv3d(params["latent_conv_dim"], params["latent_conv_dim"], kernel_size=3, padding=1), # Full 3D nn.GroupNorm(8, params["latent_conv_dim"]), nn.GELU(), ) # Spatial pooling (keep temporal dimension) self.spatial_pool = nn.AdaptiveAvgPool3d((None, 1, 1)) # [B, C, T, 1, 1] # Temporal attention pooling (learns which frames matter for routing) if self.temporal_pool_mode == "attention": self.temporal_attn = nn.Sequential( nn.Linear(params["latent_conv_dim"], params["latent_conv_dim"] // 4), nn.Tanh(), nn.Linear(params["latent_conv_dim"] // 4, 1), ) # Motion feature extractor (frame differences) self.motion_encoder = nn.Sequential( nn.Linear(params["latent_conv_dim"], params["latent_conv_dim"]), nn.GELU(), nn.Linear(params["latent_conv_dim"], self.hidden_dim // 2), ) # Content feature projector self.content_proj = nn.Linear(params["latent_conv_dim"], self.hidden_dim // 2) # Final latent projection (combines content + motion) self.latent_proj = nn.Sequential( nn.Linear(self.hidden_dim, self.hidden_dim), nn.LayerNorm(self.hidden_dim), ) # === Time-Dependent Weighting === # Time embedding using existing infrastructure self.time_embed = TimestepEmbedder(self.hidden_dim) self.time_mlp = nn.Sequential( nn.Linear(self.hidden_dim, self.hidden_dim), nn.GELU(), nn.Linear(self.hidden_dim, self.hidden_dim) ) # Learns adaptive weighting: at high t → trust text; at low t → trust latent self.source_weighting = nn.Sequential( nn.Linear(self.hidden_dim, 128), nn.GELU(), nn.Linear(128, 3), # [text, frame, latent] weights nn.Softmax(dim=-1) ) # === Routing Head === self.router_head = nn.Sequential( nn.Linear(self.hidden_dim, self.hidden_dim), nn.GELU(), nn.LayerNorm(self.hidden_dim), nn.Dropout(params["dropout"]), nn.Linear(self.hidden_dim, self.num_experts) ) # Initialize weights self.initialize_weights() def initialize_weights(self): """Initialize weights following DiT conventions.""" def _basic_init(module): if isinstance(module, nn.Linear): torch.nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) elif isinstance(module, nn.Conv3d): # Flatten spatial dims for xavier init w = module.weight.data nn.init.xavier_uniform_(w.view([w.shape[0], -1])) if module.bias is not None: nn.init.constant_(module.bias, 0) self.apply(_basic_init) # Initialize timestep embedding MLP (following DiT) if hasattr(self.time_embed, 'mlp'): nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02) nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02) # Small non-zero initialization for final routing layer # (pure zeros cause uniform outputs that break temperature scaling) nn.init.normal_(self.router_head[-1].weight, std=0.01) nn.init.constant_(self.router_head[-1].bias, 0) # Initialize source weighting to start roughly uniform # The softmax will make [0, 0, 0] → [0.33, 0.33, 0.33] nn.init.constant_(self.source_weighting[-2].weight, 0) nn.init.constant_(self.source_weighting[-2].bias, 0) # Initialize temporal attention to uniform attention if self.temporal_pool_mode == "attention": nn.init.constant_(self.temporal_attn[-1].weight, 0) nn.init.constant_(self.temporal_attn[-1].bias, 0) def _masked_mean_pool(self, embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: """ Compute masked mean pooling over sequence dimension. Args: embeddings: [B, seq_len, embed_dim] - Token embeddings attention_mask: [B, seq_len] - 1 for real tokens, 0 for padding Returns: pooled: [B, embed_dim] - Pooled representation """ if attention_mask is None: # No mask provided, use simple mean return embeddings.mean(dim=1) # Expand mask for broadcasting: [B, seq_len] -> [B, seq_len, 1] mask = attention_mask.unsqueeze(-1).to(embeddings.dtype) # Masked sum masked_sum = (embeddings * mask).sum(dim=1) # [B, embed_dim] # Count of valid tokens (avoid division by zero) token_counts = mask.sum(dim=1).clamp(min=1.0) # [B, 1] return masked_sum / token_counts def _encode_latent_temporal(self, x_t: torch.Tensor) -> torch.Tensor: """ Encode video latent with temporal awareness. Extracts both: - Content features: What is in the video (spatial) - Motion features: How things move (temporal differences) Args: x_t: [B, C, T, H, W] - Noisy video latent Returns: latent_feat: [B, hidden_dim] - Combined latent features """ B, C, T, H, W = x_t.shape # 1. Spatial feature extraction spatial_feat = self.spatial_conv(x_t) # [B, conv_dim, T, H, W] # 2. Temporal feature extraction (captures local motion) temporal_feat = self.temporal_conv(spatial_feat) # [B, conv_dim, T, H, W] # 3. Combined spatio-temporal processing st_feat = self.st_conv(temporal_feat) # [B, conv_dim, T, H, W] # 4. Pool spatially, keep temporal: [B, conv_dim, T, 1, 1] -> [B, T, conv_dim] pooled = self.spatial_pool(st_feat).squeeze(-1).squeeze(-1) # [B, conv_dim, T] pooled = pooled.permute(0, 2, 1) # [B, T, conv_dim] # 5. Temporal pooling with optional attention if self.temporal_pool_mode == "attention" and T > 1: # Learn which frames matter for routing attn_scores = self.temporal_attn(pooled).squeeze(-1) # [B, T] attn_weights = F.softmax(attn_scores, dim=-1) # [B, T] content_feat = (pooled * attn_weights.unsqueeze(-1)).sum(dim=1) # [B, conv_dim] elif self.temporal_pool_mode == "max": content_feat = pooled.max(dim=1)[0] # [B, conv_dim] else: # "avg" content_feat = pooled.mean(dim=1) # [B, conv_dim] # 6. Extract motion features (frame differences) if T > 1: # Compute frame-to-frame differences frame_diffs = pooled[:, 1:] - pooled[:, :-1] # [B, T-1, conv_dim] # Motion magnitude and direction encoding motion_feat = self.motion_encoder(frame_diffs.mean(dim=1)) # [B, hidden_dim//2] else: # Single frame, no motion motion_feat = torch.zeros(B, self.hidden_dim // 2, device=x_t.device) # 7. Project content features content_proj = self.content_proj(content_feat) # [B, hidden_dim//2] # 8. Combine content + motion combined = torch.cat([content_proj, motion_feat], dim=-1) # [B, hidden_dim] latent_feat = self.latent_proj(combined) # [B, hidden_dim] return latent_feat def forward(self, x_t: torch.Tensor, t: torch.Tensor, text_embed: torch.Tensor, first_frame_feat: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, temperature: float = 1.0) -> torch.Tensor: """ Compute routing logits with time-adaptive information weighting. Args: x_t: Noisy video latent [B, C, T, H, W] t: Noise level [B] in [0, 1] or [0, 999] text_embed: CLIP text embedding [B, text_embed_dim] or [B, seq_len, text_embed_dim] first_frame_feat: Optional DINOv2 features [B, frame_embed_dim] attention_mask: Optional [B, seq_len] mask for text (1=valid, 0=padding) temperature: Softmax temperature for sharper/softer routing (default: 1.0) Returns: logits: Expert selection logits [B, num_experts] (scaled by temperature) """ B = x_t.shape[0] device = x_t.device # === Encode each information source === # Handle both pooled [B, D] and sequence [B, seq_len, D] text embeddings if text_embed.dim() == 3: # Use masked mean pooling for sequence embeddings text_embed_pooled = self._masked_mean_pool(text_embed, attention_mask) else: # Already pooled text_embed_pooled = text_embed # L2-normalize inputs to match clustering preprocessing if self.normalize_inputs: text_embed_pooled = F.normalize(text_embed_pooled, p=2, dim=-1) text_feat = self.text_encoder(text_embed_pooled) # [B, hidden_dim] # Frame features (optional for T2V, required for I2V) if first_frame_feat is not None: # L2-normalize to match clustering preprocessing if self.normalize_inputs: first_frame_feat = F.normalize(first_frame_feat, p=2, dim=-1) frame_feat = self.frame_encoder(first_frame_feat) # [B, hidden_dim] else: frame_feat = torch.zeros(B, self.hidden_dim, device=device) # Latent features from noisy video (temporal-aware encoding) latent_feat = self._encode_latent_temporal(x_t) # [B, hidden_dim] # === Time-dependent weighting === # Normalize timesteps to [0, 999] for TimestepEmbedder if t.max() <= 1.0: t_scaled = t * 999.0 else: t_scaled = t t_scaled = t_scaled.clamp(0, 999) # Get time features time_emb = self.time_embed(t_scaled) # [B, hidden_dim] time_feat = self.time_mlp(time_emb) # [B, hidden_dim] # Compute adaptive weights based on noise level # Network learns: high t → high text weight; low t → high latent weight weights = self.source_weighting(time_feat) # [B, 3] # === Adaptive combination === combined = ( weights[:, 0:1] * text_feat + # Text contribution weights[:, 1:2] * frame_feat + # Frame contribution weights[:, 2:3] * latent_feat # Latent contribution ) # Final routing decision (incorporate time context) logits = self.router_head(combined + time_feat) # Apply temperature scaling (lower temp = sharper routing) if temperature != 1.0: logits = logits / temperature return logits def get_source_weights(self, t: torch.Tensor) -> torch.Tensor: """ Get the learned source weights for given timesteps. Useful for debugging and visualization. Args: t: Noise levels [B] in [0, 1] or [0, 999] Returns: weights: Source weights [B, 3] for [text, frame, latent] """ # Normalize timesteps if t.max() <= 1.0: t_scaled = t * 999.0 else: t_scaled = t t_scaled = t_scaled.clamp(0, 999) time_emb = self.time_embed(t_scaled) time_feat = self.time_mlp(time_emb) weights = self.source_weighting(time_feat) return weights # ============================================================================= # MODEL FACTORY FUNCTIONS # ============================================================================= def create_expert(config, expert_id: Optional[int] = None) -> nn.Module: """ Factory function to create expert model Args: config: Config object expert_id: Optional expert ID for per-expert schedule/objective configuration """ # Make a copy of config to avoid modifying the original import copy config = copy.copy(config) config.expert_params = config.expert_params.copy() # Inject schedule_type into expert_params if not already present if "schedule_type" not in config.expert_params: # Check for per-expert schedule first (with backward compatibility) if (hasattr(config, 'expert_schedule_types') and config.expert_schedule_types and expert_id is not None and expert_id in config.expert_schedule_types): config.expert_params["schedule_type"] = config.expert_schedule_types[expert_id] else: # Use default schedule_type (with fallback for old configs) config.expert_params["schedule_type"] = getattr(config, 'schedule_type', 'linear_interp') # Inject objective_type into expert_params if not already present if "objective_type" not in config.expert_params: # Check for per-expert objectives (with backward compatibility) if (hasattr(config, 'expert_objectives') and config.expert_objectives and expert_id is not None and expert_id in config.expert_objectives): config.expert_params["objective_type"] = config.expert_objectives[expert_id] else: # Use default objective (with fallback for old configs) config.expert_params["objective_type"] = getattr(config, 'default_objective', 'fm') if config.expert_architecture == "unet": return UNetExpert(config) elif config.expert_architecture == "simple_cnn": return SimpleCNNExpert(config) elif config.expert_architecture == "dit": return DiTExpert(config) else: raise ValueError(f"Unknown expert architecture: {config.expert_architecture}") def create_router(config) -> Optional[nn.Module]: """Factory function to create router model""" if config.router_architecture == "none" or config.is_monolithic: return None elif config.router_architecture == "deterministic_timestep": return DeterministicTimestepRouter(config) elif config.router_architecture == "vit": return ViTRouter(config) elif config.router_architecture == "cnn": return CNNRouter(config) elif config.router_architecture == "dit": return DiTRouter(config) elif config.router_architecture == "adaptive_video": return AdaptiveVideoRouter(config) else: raise ValueError(f"Unknown router architecture: {config.router_architecture}")