|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from diffusers import UNet2DModel |
|
|
from transformers import ViTForImageClassification, ViTConfig |
|
|
import math |
|
|
from typing import Optional, List |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TimeEmbedding(nn.Module): |
|
|
def __init__(self, dim: int) -> None: |
|
|
super().__init__() |
|
|
self.dim = dim |
|
|
|
|
|
    def forward(self, t: torch.Tensor) -> torch.Tensor:
        # Standard sinusoidal timestep embedding: log-spaced frequencies from
        # 1 down to 1/10000, returned as concatenated [sin | cos] halves.
        device = t.device
        half_dim = self.dim // 2
        scale = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -scale)
        args = t[:, None] * freqs[None, :]
        return torch.cat((args.sin(), args.cos()), dim=-1)
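

# Illustrative only: a minimal shape check for TimeEmbedding. Batch size and
# embedding dim below are arbitrary; `t` is assumed to be a float tensor of
# continuous timesteps in [0, 1], matching how the losses in this file sample it.
def _demo_time_embedding() -> None:
    emb = TimeEmbedding(dim=64)
    t = torch.rand(8)        # (B,)
    out = emb(t)             # (B, 64): [sin | cos] halves
    assert out.shape == (8, 64)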
|
|
|
|
|
class DiTTimestepEmbedder(nn.Module): |
|
|
def __init__(self, hidden_size, freq_dim=128, max_period=10000): |
|
|
super().__init__() |
|
|
self.freq_dim = freq_dim |
|
|
self.max_period = max_period |
|
|
self.mlp = nn.Sequential( |
|
|
nn.Linear(2*freq_dim, hidden_size, bias=True), |
|
|
nn.SiLU(), |
|
|
nn.Linear(hidden_size, hidden_size, bias=True), |
|
|
) |
|
|
    def forward(self, t):
        # Log-spaced frequencies, [cos | sin] features of size 2*freq_dim,
        # then a learned MLP projection (mirrors the DiT timestep embedder).
        half = self.freq_dim
        device = t.device
        freqs = torch.exp(
            -torch.arange(half, device=device).float() * np.log(self.max_period) / half
        )
        args = t.float()[:, None] * freqs[None]
        emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        return self.mlp(emb)
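

# Illustrative only: unlike TimeEmbedding, DiTTimestepEmbedder returns a
# learned projection, so its output dim is `hidden_size` rather than the raw
# frequency dim. Values below are arbitrary.
def _demo_dit_timestep_embedder() -> None:
    embedder = DiTTimestepEmbedder(hidden_size=256, freq_dim=128)
    t = torch.rand(4)
    out = embedder(t)        # (4, 256) after the MLP
    assert out.shape == (4, 256)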
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OutputConverter(nn.Module): |
|
|
def __init__(self, schedule_type: str = 'linear_interp', use_latents: bool = False, derivative_eps: float = 1e-4): |
|
|
super().__init__() |
|
|
from schedules import NoiseSchedule |
|
|
self.schedule = NoiseSchedule(schedule_type) |
|
|
self.schedule_type = schedule_type |
|
|
self.use_latents = use_latents |
|
|
self.derivative_eps = derivative_eps |
|
|
|
|
|
|
|
|
|
|
|
self.clamp_range = 20.0 if use_latents else 5.0 |
|
|
|
|
|
    def _get_schedule_with_derivatives(self, t: torch.Tensor):
        """
        Compute schedule coefficients (alpha_t, sigma_t) and their time
        derivatives via a central finite difference. The derivatives are
        needed to form the schedule-consistent velocity.
        """
        alpha_t, sigma_t = self.schedule.get_schedule(t)

        # Central difference with step h, clamped so t ± h stays in [0, 1].
        h = torch.full_like(t, self.derivative_eps)
        t_plus = (t + h).clamp(0.0, 1.0)
        t_minus = (t - h).clamp(0.0, 1.0)

        alpha_plus, sigma_plus = self.schedule.get_schedule(t_plus)
        alpha_minus, sigma_minus = self.schedule.get_schedule(t_minus)

        # Near the boundaries the effective step shrinks, so divide by the
        # actual interval rather than 2h.
        dt = (t_plus - t_minus).clamp(min=1e-6)
        d_alpha_dt = (alpha_plus - alpha_minus) / dt
        d_sigma_dt = (sigma_plus - sigma_minus) / dt

        return alpha_t, sigma_t, d_alpha_dt, d_sigma_dt
|
|
|
|
|
    def epsilon_to_velocity(self, epsilon_pred: torch.Tensor, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Convert an ε-prediction to a velocity prediction for an arbitrary
        schedule. Differentiating x_t = alpha_t * x_0 + sigma_t * ε gives

            dx_t/dt = d(alpha_t)/dt * x_0 + d(sigma_t)/dt * ε,

        which is the velocity field induced by the schedule.
        """
|
|
|
|
|
alpha_t, sigma_t, d_alpha_dt, d_sigma_dt = self._get_schedule_with_derivatives(t) |
|
|
|
|
|
|
|
|
alpha_t = alpha_t.view(-1, 1, 1, 1) |
|
|
sigma_t = sigma_t.view(-1, 1, 1, 1) |
|
|
d_alpha_dt = d_alpha_dt.view(-1, 1, 1, 1) |
|
|
d_sigma_dt = d_sigma_dt.view(-1, 1, 1, 1) |
|
|
|
|
|
|
|
|
        # Avoid division blow-up as alpha_t -> 0 near pure noise.
        alpha_safe = torch.clamp(alpha_t, min=0.01)

        # Recover x_0 from the ε-parameterization of x_t.
        x_0_pred = (x_t - sigma_t * epsilon_pred) / alpha_safe
|
|
|
|
|
|
|
|
|
|
|
        # Keep the x_0 estimate in a plausible range (wider for VAE latents).
        x_0_pred = torch.clamp(x_0_pred, -self.clamp_range, self.clamp_range)
|
|
|
|
|
|
|
|
        if self.schedule_type == 'linear_interp':
            # For alpha = 1 - t, sigma = t the general formula reduces exactly
            # to v = ε - x_0, so use the closed form.
            v = epsilon_pred - x_0_pred
        else:
            # General case: schedule-consistent velocity from the derivatives.
            v = d_alpha_dt * x_0_pred + d_sigma_dt * epsilon_pred
|
|
|
|
|
|
|
|
|
|
|
        if self.schedule_type == 'cosine':
            # Empirical damping factors for the cosine schedule's steep
            # high-noise region. Note: t[0] stands in for the whole batch,
            # which assumes a shared timestep per batch (true during sampling).
            t_val = t[0].item() if t.numel() > 0 else 0.5
            if t_val > 0.85:
                scale = 0.88
            elif t_val > 0.6:
                scale = 0.93
            else:
                scale = 0.96
            v = v * scale
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
        # Zero-mean the non-first channels to suppress per-channel DC drift.
        # Bounded by the channel count so 3-channel inputs don't index out of range.
        for c in range(1, min(4, v.shape[1])):
            v[:, c] = v[:, c] - v[:, c].mean()
|
|
|
|
|
return v |
|
|
|
|
|
def convert(self, prediction: torch.Tensor, objective_type: str, x_t: torch.Tensor, t: torch.Tensor): |
|
|
""" |
|
|
Convert any prediction to velocity space. |
|
|
|
|
|
Args: |
|
|
prediction: expert output |
|
|
objective_type: 'ddpm' | 'fm' | 'rf' |
|
|
x_t: current noisy state |
|
|
t: current timesteps |
|
|
|
|
|
Returns: |
|
|
v: velocity representation |
|
|
""" |
|
|
if objective_type == "ddpm": |
|
|
|
|
|
return self.epsilon_to_velocity(prediction, x_t, t) |
|
|
elif objective_type in ["fm", "rf"]: |
|
|
return prediction |
|
|
else: |
|
|
raise ValueError(f"Unknown objective type: {objective_type}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class UNetExpert(nn.Module): |
|
|
"""UNet expert using diffusers""" |
|
|
|
|
|
def __init__(self, config) -> None: |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
default_params = { |
|
|
"sample_size": config.image_size, |
|
|
"in_channels": config.num_channels, |
|
|
"out_channels": config.num_channels, |
|
|
"layers_per_block": 2, |
|
|
"block_out_channels": [64, 128, 256, 256], |
|
|
"attention_head_dim": 8, |
|
|
} |
|
|
|
|
|
|
|
|
params = {**default_params, **config.expert_params} |
|
|
|
|
|
|
|
|
self.objective_type = params.pop("objective_type", "fm") |
|
|
|
|
|
|
|
|
schedule_type = params.pop("schedule_type", "linear_interp") |
|
|
from schedules import NoiseSchedule |
|
|
self.schedule = NoiseSchedule(schedule_type) |
|
|
|
|
|
self.unet = UNet2DModel(**params) |
|
|
|
|
|
def forward(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor: |
|
|
|
|
|
|
|
|
        # Map continuous t in [0, 1] to the discrete 0..999 timestep bins
        # expected by UNet2DModel's embedding.
        t_scaled = (t * 999).round().long().clamp(0, 999)
|
|
return self.unet(xt, t_scaled).sample |
|
|
|
|
|
def compute_loss(self, x0: torch.Tensor) -> torch.Tensor: |
|
|
"""Unified loss computation based on objective type""" |
|
|
if self.objective_type == "ddpm": |
|
|
return self.ddpm_loss(x0) |
|
|
elif self.objective_type == "fm": |
|
|
return self.flow_matching_loss(x0) |
|
|
elif self.objective_type == "rf": |
|
|
return self.rectified_flow_loss(x0) |
|
|
else: |
|
|
raise ValueError(f"Unknown objective type: {self.objective_type}") |
|
|
|
|
|
def ddpm_loss(self, x0: torch.Tensor) -> torch.Tensor: |
|
|
"""DDPM: predict noise ε""" |
|
|
batch_size = x0.shape[0] |
|
|
device = x0.device |
|
|
|
|
|
t = torch.rand(batch_size, device=device) |
|
|
|
|
|
|
|
|
alpha_t, sigma_t = self.schedule.get_schedule(t) |
|
|
|
|
|
noise = torch.randn_like(x0) |
|
|
xt = alpha_t.view(-1, 1, 1, 1) * x0 + sigma_t.view(-1, 1, 1, 1) * noise |
|
|
|
|
|
pred_eps = self.forward(xt, t) |
|
|
return F.mse_loss(pred_eps, noise) |
|
|
|
|
|
def rectified_flow_loss(self, x0: torch.Tensor) -> torch.Tensor: |
|
|
"""Rectified Flow: predict velocity v = x_1 - x_0""" |
|
|
batch_size = x0.shape[0] |
|
|
device = x0.device |
|
|
|
|
|
t = torch.rand(batch_size, device=device) |
|
|
x1 = torch.randn_like(x0) |
|
|
xt = (1 - t).view(-1, 1, 1, 1) * x0 + t.view(-1, 1, 1, 1) * x1 |
|
|
|
|
|
pred_v = self.forward(xt, t) |
|
|
true_v = x1 - x0 |
|
|
return F.mse_loss(pred_v, true_v) |
|
|
|
|
|
def flow_matching_loss(self, x0: torch.Tensor) -> torch.Tensor: |
|
|
"""Flow matching loss for training""" |
|
|
batch_size = x0.shape[0] |
|
|
device = x0.device |
|
|
|
|
|
|
|
|
t = torch.rand(batch_size, device=device) |
|
|
|
|
|
|
|
|
alpha_t, sigma_t = self.schedule.get_schedule(t) |
|
|
|
|
|
|
|
|
noise = torch.randn_like(x0) |
|
|
xt = alpha_t.view(-1, 1, 1, 1) * x0 + sigma_t.view(-1, 1, 1, 1) * noise |
|
|
|
|
|
|
|
|
pred_v = self.forward(xt, t) |
|
|
|
|
|
|
|
|
|
|
|
        # Velocity target for the linear-interpolation path x_t = (1-t)·x0 + t·ε;
        # other schedules would need the derivative-based target (see OutputConverter).
        true_v = noise - x0
|
|
|
|
|
return F.mse_loss(pred_v, true_v) |
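

# Minimal training sketch (illustrative; not part of the module). Assumes a
# `config` exposing `image_size`, `num_channels`, and an `expert_params` dict
# as consumed by UNetExpert, and a dataloader yielding image batches x0.
def _train_unet_expert_sketch(config, dataloader, steps: int = 100) -> None:
    expert = UNetExpert(config)
    opt = torch.optim.AdamW(expert.parameters(), lr=1e-4)
    for _, x0 in zip(range(steps), dataloader):
        loss = expert.compute_loss(x0)   # dispatches on objective_type
        opt.zero_grad()
        loss.backward()
        opt.step()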
|
|
|
|
|
class SimpleCNNExpert(nn.Module): |
|
|
"""Simple CNN expert for fast training""" |
|
|
|
|
|
def __init__(self, config) -> None: |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
default_params = { |
|
|
"hidden_dims": [64, 128, 256], |
|
|
"time_dim": 64, |
|
|
} |
|
|
params = {**default_params, **config.expert_params} |
|
|
|
|
|
|
|
|
self.objective_type = params.get("objective_type", "fm") |
|
|
|
|
|
|
|
|
schedule_type = params.get("schedule_type", "linear_interp") |
|
|
from schedules import NoiseSchedule |
|
|
self.schedule = NoiseSchedule(schedule_type) |
|
|
|
|
|
self.time_embedding = TimeEmbedding(params["time_dim"]) |
|
|
self.target_size = config.image_size |
|
|
|
|
|
|
|
|
self.encoder = self._build_encoder(config.num_channels, params["hidden_dims"]) |
|
|
self.decoder = self._build_decoder(params["hidden_dims"], config.num_channels) |
|
|
|
|
|
|
|
|
self.time_mlp = nn.Sequential( |
|
|
nn.Linear(params["time_dim"], params["hidden_dims"][-1]), |
|
|
nn.SiLU(), |
|
|
nn.Linear(params["hidden_dims"][-1], params["hidden_dims"][-1]) |
|
|
) |
|
|
|
|
|
def _build_encoder(self, in_channels: int, hidden_dims: List[int]) -> nn.Sequential: |
|
|
layers = [] |
|
|
prev_dim = in_channels |
|
|
|
|
|
for dim in hidden_dims: |
|
|
layers.extend([ |
|
|
nn.Conv2d(prev_dim, dim, 3, padding=1), |
|
|
nn.GroupNorm(8, dim), |
|
|
nn.SiLU(), |
|
|
nn.Conv2d(dim, dim, 3, padding=1), |
|
|
nn.GroupNorm(8, dim), |
|
|
nn.SiLU(), |
|
|
nn.MaxPool2d(2) |
|
|
]) |
|
|
prev_dim = dim |
|
|
|
|
|
return nn.Sequential(*layers) |
|
|
|
|
|
def _build_decoder(self, hidden_dims: List[int], out_channels: int) -> nn.Sequential: |
|
|
layers = [] |
|
|
reversed_dims = list(reversed(hidden_dims)) |
|
|
|
|
|
for i, dim in enumerate(reversed_dims[:-1]): |
|
|
next_dim = reversed_dims[i + 1] |
|
|
layers.extend([ |
|
|
nn.ConvTranspose2d(dim, next_dim, 4, stride=2, padding=1), |
|
|
nn.GroupNorm(8, next_dim), |
|
|
nn.SiLU(), |
|
|
nn.Conv2d(next_dim, next_dim, 3, padding=1), |
|
|
nn.GroupNorm(8, next_dim), |
|
|
nn.SiLU(), |
|
|
]) |
|
|
|
|
|
|
|
|
layers.append(nn.Conv2d(reversed_dims[-1], out_channels, 3, padding=1)) |
|
|
|
|
|
return nn.Sequential(*layers) |
|
|
|
|
|
def forward(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor: |
|
|
|
|
|
time_emb = self.time_embedding(t) |
|
|
time_features = self.time_mlp(time_emb) |
|
|
|
|
|
|
|
|
encoded = self.encoder(xt) |
|
|
|
|
|
|
|
|
time_features = time_features.view(time_features.shape[0], -1, 1, 1) |
|
|
time_features = time_features.expand(-1, -1, encoded.shape[2], encoded.shape[3]) |
|
|
conditioned = encoded + time_features |
|
|
|
|
|
|
|
|
output = self.decoder(conditioned) |
|
|
|
|
|
|
|
|
output = F.interpolate(output, size=xt.shape[-2:], mode='bilinear', align_corners=False) |
|
|
|
|
|
return output |
|
|
|
|
|
def compute_loss(self, x0: torch.Tensor) -> torch.Tensor: |
|
|
"""Unified loss computation based on objective type""" |
|
|
if self.objective_type == "ddpm": |
|
|
return self.ddpm_loss(x0) |
|
|
elif self.objective_type == "fm": |
|
|
return self.flow_matching_loss(x0) |
|
|
elif self.objective_type == "rf": |
|
|
return self.rectified_flow_loss(x0) |
|
|
else: |
|
|
raise ValueError(f"Unknown objective type: {self.objective_type}") |
|
|
|
|
|
def ddpm_loss(self, x0: torch.Tensor) -> torch.Tensor: |
|
|
"""DDPM: predict noise ε""" |
|
|
batch_size = x0.shape[0] |
|
|
device = x0.device |
|
|
|
|
|
t = torch.rand(batch_size, device=device) |
|
|
|
|
|
|
|
|
alpha_t, sigma_t = self.schedule.get_schedule(t) |
|
|
|
|
|
noise = torch.randn_like(x0) |
|
|
xt = alpha_t.view(-1, 1, 1, 1) * x0 + sigma_t.view(-1, 1, 1, 1) * noise |
|
|
|
|
|
pred_eps = self.forward(xt, t) |
|
|
|
|
|
|
|
|
if pred_eps.shape != noise.shape: |
|
|
pred_eps = F.interpolate(pred_eps, size=noise.shape[-2:], mode='bilinear', align_corners=False) |
|
|
|
|
|
return F.mse_loss(pred_eps, noise) |
|
|
|
|
|
def rectified_flow_loss(self, x0: torch.Tensor) -> torch.Tensor: |
|
|
"""Rectified Flow: predict velocity v = x_1 - x_0""" |
|
|
batch_size = x0.shape[0] |
|
|
device = x0.device |
|
|
|
|
|
t = torch.rand(batch_size, device=device) |
|
|
x1 = torch.randn_like(x0) |
|
|
xt = (1 - t).view(-1, 1, 1, 1) * x0 + t.view(-1, 1, 1, 1) * x1 |
|
|
|
|
|
pred_v = self.forward(xt, t) |
|
|
true_v = x1 - x0 |
|
|
|
|
|
|
|
|
if pred_v.shape != true_v.shape: |
|
|
pred_v = F.interpolate(pred_v, size=true_v.shape[-2:], mode='bilinear', align_corners=False) |
|
|
|
|
|
return F.mse_loss(pred_v, true_v) |
|
|
|
|
|
def flow_matching_loss(self, x0: torch.Tensor) -> torch.Tensor: |
|
|
"""Flow matching loss""" |
|
|
batch_size = x0.shape[0] |
|
|
device = x0.device |
|
|
|
|
|
t = torch.rand(batch_size, device=device) |
|
|
|
|
|
|
|
|
alpha_t, sigma_t = self.schedule.get_schedule(t) |
|
|
|
|
|
noise = torch.randn_like(x0) |
|
|
xt = alpha_t.view(-1, 1, 1, 1) * x0 + sigma_t.view(-1, 1, 1, 1) * noise |
|
|
|
|
|
pred_v = self.forward(xt, t) |
|
|
|
|
|
        # Linear-interpolation velocity target (see note in UNetExpert.flow_matching_loss).
        true_v = noise - x0
|
|
|
|
|
|
|
|
if pred_v.shape != true_v.shape: |
|
|
pred_v = F.interpolate(pred_v, size=true_v.shape[-2:], mode='bilinear', align_corners=False) |
|
|
|
|
|
return F.mse_loss(pred_v, true_v) |
|
|
|
|
|
|
|
|
def modulate(x, shift, scale): |
|
|
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) |
|
|
|
|
|
|
|
|
def get_2d_sincos_pos_embed(embed_dim, grid_size): |
|
|
grid_h = np.arange(grid_size, dtype=np.float32) |
|
|
grid_w = np.arange(grid_size, dtype=np.float32) |
|
|
grid = np.meshgrid(grid_w, grid_h) |
|
|
grid = np.stack(grid, axis=0) |
|
|
grid = grid.reshape([2, 1, grid_size, grid_size]) |
|
|
|
|
|
assert embed_dim % 2 == 0 |
|
|
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) |
|
|
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) |
|
|
emb = np.concatenate([emb_h, emb_w], axis=1) |
|
|
return emb |
|
|
|
|
|
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): |
|
|
assert embed_dim % 2 == 0 |
|
|
omega = np.arange(embed_dim // 2, dtype=np.float64) |
|
|
omega /= embed_dim / 2. |
|
|
omega = 1. / 10000**omega |
|
|
pos = pos.reshape(-1) |
|
|
out = np.einsum('m,d->md', pos, omega) |
|
|
emb_sin = np.sin(out) |
|
|
emb_cos = np.cos(out) |
|
|
emb = np.concatenate([emb_sin, emb_cos], axis=1) |
|
|
return emb |
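

# Illustrative only: the 2-D embedding concatenates the height and width 1-D
# embeddings, so for a g×g patch grid the result has shape (g*g, embed_dim).
def _demo_pos_embed() -> None:
    emb = get_2d_sincos_pos_embed(embed_dim=768, grid_size=16)
    assert emb.shape == (16 * 16, 768)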
|
|
|
|
|
|
|
|
class TimestepEmbedder(nn.Module): |
|
|
def __init__(self, hidden_size: int, frequency_embedding_size: int = 256): |
|
|
super().__init__() |
|
|
self.frequency_embedding_size = frequency_embedding_size |
|
|
self.mlp = nn.Sequential( |
|
|
nn.Linear(frequency_embedding_size, hidden_size, bias=True), |
|
|
nn.SiLU(), |
|
|
nn.Linear(hidden_size, hidden_size, bias=True), |
|
|
) |
|
|
|
|
|
@staticmethod |
|
|
def timestep_embedding(t, dim, max_period=10000): |
|
|
half = dim // 2 |
|
|
freqs = torch.exp(-math.log(max_period) * torch.arange(0, half, dtype=torch.float32, device=t.device) / half) |
|
|
args = t[:, None].float() * freqs[None] |
|
|
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) |
|
|
if dim % 2: |
|
|
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) |
|
|
return embedding |
|
|
|
|
|
def forward(self, t: torch.Tensor) -> torch.Tensor: |
|
|
t_freq = self.timestep_embedding(t, self.frequency_embedding_size) |
|
|
return self.mlp(t_freq) |
|
|
|
|
|
|
|
|
class DiTBlock(nn.Module): |
|
|
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, use_text: bool = False, use_adaln_single: bool = False): |
|
|
super().__init__() |
|
|
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) |
|
|
self.attn = nn.MultiheadAttention(hidden_size, num_heads, dropout=0.1, batch_first=True) |
|
|
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) |
|
|
|
|
|
mlp_hidden_dim = int(hidden_size * mlp_ratio) |
|
|
self.mlp = nn.Sequential( |
|
|
nn.Linear(hidden_size, mlp_hidden_dim), |
|
|
nn.GELU(approximate="tanh"), |
|
|
nn.Linear(mlp_hidden_dim, hidden_size), |
|
|
) |
|
|
|
|
|
|
|
|
self.use_adaln_single = use_adaln_single |
|
|
if use_adaln_single: |
|
|
|
|
|
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5) |
|
|
self.adaLN_modulation = None |
|
|
else: |
|
|
|
|
|
self.adaLN_modulation = nn.Sequential( |
|
|
nn.SiLU(), |
|
|
nn.Linear(hidden_size, 6 * hidden_size, bias=True) |
|
|
) |
|
|
self.scale_shift_table = None |
|
|
|
|
|
|
|
|
self.use_text = use_text |
|
|
if use_text: |
|
|
|
|
|
|
|
|
self.norm_cross = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) |
|
|
self.cross_attn = nn.MultiheadAttention(hidden_size, num_heads, dropout=0.1, batch_first=True) |
|
|
|
|
|
def forward(self, x: torch.Tensor, c: torch.Tensor, text_emb: Optional[torch.Tensor] = None, |
|
|
attention_mask: Optional[torch.Tensor] = None): |
|
|
|
|
|
if self.use_adaln_single: |
|
|
|
|
|
|
|
|
B = x.shape[0] |
|
|
|
|
|
temp = (self.scale_shift_table[None] + c.reshape(B, 6, -1)).chunk(6, dim=1) |
|
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [t.squeeze(1) for t in temp] |
|
|
else: |
|
|
|
|
|
|
|
|
temp = self.adaLN_modulation(c).chunk(6, dim=1) |
|
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [t.squeeze(1) for t in temp] |
|
|
|
|
|
|
|
|
|
|
|
x_norm = modulate(self.norm1(x), shift_msa, scale_msa) |
|
|
attn_out, _ = self.attn(x_norm, x_norm, x_norm) |
|
|
x = x + gate_msa.unsqueeze(1) * attn_out |
|
|
|
|
|
|
|
|
if self.use_text and text_emb is not None: |
|
|
if text_emb.dim() == 2: |
|
|
text_emb = text_emb.unsqueeze(1) |
|
|
|
|
|
|
|
|
key_padding_mask = None |
|
|
if attention_mask is not None: |
|
|
if attention_mask.dtype is not torch.bool: |
|
|
|
|
|
keep_mask = attention_mask > 0 |
|
|
else: |
|
|
keep_mask = attention_mask |
|
|
|
|
|
key_padding_mask = ~keep_mask |
|
|
|
|
|
|
|
|
x_norm = self.norm_cross(x) |
|
|
cross_out, _ = self.cross_attn(x_norm, text_emb, text_emb, key_padding_mask=key_padding_mask) |
|
|
x = x + cross_out |
|
|
|
|
|
|
|
|
|
|
|
x_norm = modulate(self.norm2(x), shift_mlp, scale_mlp) |
|
|
mlp_out = self.mlp(x_norm) |
|
|
x = x + gate_mlp.unsqueeze(1) * mlp_out |
|
|
|
|
|
return x |
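

# Illustrative only: a single DiTBlock pass with adaLN-Zero conditioning.
# Token count, batch size, and hidden size are arbitrary.
def _demo_dit_block() -> None:
    block = DiTBlock(hidden_size=256, num_heads=4)
    x = torch.randn(2, 64, 256)      # (B, tokens, hidden)
    c = torch.randn(2, 256)          # conditioning (time [+ class]) embedding
    out = block(x, c)
    assert out.shape == x.shape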
|
|
|
|
|
|
|
|
class FinalLayer(nn.Module): |
|
|
def __init__(self, hidden_size: int, patch_size: int, out_channels: int): |
|
|
super().__init__() |
|
|
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) |
|
|
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) |
|
|
self.adaLN_modulation = nn.Sequential( |
|
|
nn.SiLU(), |
|
|
nn.Linear(hidden_size, 2 * hidden_size, bias=True) |
|
|
) |
|
|
|
|
|
def forward(self, x: torch.Tensor, c: torch.Tensor): |
|
|
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) |
|
|
x = modulate(self.norm_final(x), shift, scale) |
|
|
x = self.linear(x) |
|
|
return x |
|
|
|
|
|
|
|
|
class T2IFinalLayer(nn.Module): |
|
|
def __init__(self, hidden_size: int, patch_size: int, out_channels: int): |
|
|
super().__init__() |
|
|
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) |
|
|
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) |
|
|
|
|
|
self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size ** 0.5) |
|
|
self.hidden_size = hidden_size |
|
|
|
|
|
def forward(self, x: torch.Tensor, t: torch.Tensor): |
|
|
|
|
|
|
|
|
shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1) |
|
|
|
|
|
x = self.norm_final(x) * (1 + scale) + shift |
|
|
x = self.linear(x) |
|
|
return x |
|
|
|
|
|
|
|
|
class DiTExpert(nn.Module): |
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
default_params = { |
|
|
"hidden_size": 768, |
|
|
"num_layers": 12, |
|
|
"num_heads": 12, |
|
|
"patch_size": 2, |
|
|
"in_channels": 4, |
|
|
"out_channels": 4, |
|
|
"use_text_conditioning": False, |
|
|
"use_class_conditioning": False, |
|
|
"num_classes": 1000, |
|
|
"mlp_ratio": 4.0, |
|
|
"text_embed_dim": 768, |
|
|
"use_dit_time_embed": False, |
|
|
} |
|
|
params = {**default_params, **config.expert_params} |
|
|
|
|
|
self.patch_size = params["patch_size"] |
|
|
self.in_channels = params["in_channels"] |
|
|
self.out_channels = params["out_channels"] |
|
|
self.hidden_size = params["hidden_size"] |
|
|
self.num_heads = params["num_heads"] |
|
|
self.use_text = params.get("use_text_conditioning", False) |
|
|
self.use_class = params.get("use_class_conditioning", False) |
|
|
self.cfg_dropout_prob = params.get("cfg_dropout_prob", 0.1) |
|
|
self.text_embed_dim = params.get("text_embed_dim", 768) |
|
|
self.use_adaln_single = params.get("use_adaln_single", False) |
|
|
self.depth = params["num_layers"] |
|
|
|
|
|
|
|
|
self.objective_type = params.get("objective_type", "fm") |
|
|
|
|
|
|
|
|
schedule_type = params.get("schedule_type", "linear_interp") |
|
|
from schedules import NoiseSchedule |
|
|
self.schedule = NoiseSchedule(schedule_type) |
|
|
|
|
|
|
|
|
assert not (self.use_text and self.use_class), "Cannot use both text and class conditioning simultaneously" |
|
|
|
|
|
|
|
|
self.patch_embed = nn.Conv2d(self.in_channels, self.hidden_size, |
|
|
kernel_size=self.patch_size, stride=self.patch_size) |
|
|
|
|
|
|
|
|
latent_size = getattr(config, 'image_size', 32) |
|
|
self.num_patches = (latent_size // self.patch_size) ** 2 |
|
|
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, self.hidden_size), requires_grad=False) |
|
|
|
|
|
|
|
|
self.use_dit_time_embed = params.get("use_dit_time_embed", False) |
|
|
if self.use_dit_time_embed: |
|
|
self.time_embed = DiTTimestepEmbedder(self.hidden_size) |
|
|
else: |
|
|
self.time_embed = TimestepEmbedder(self.hidden_size) |
|
|
|
|
|
|
|
|
if self.use_adaln_single: |
|
|
self.t_block = nn.Sequential( |
|
|
nn.SiLU(), |
|
|
nn.Linear(self.hidden_size, 6 * self.hidden_size, bias=True) |
|
|
) |
|
|
|
|
|
|
|
|
if self.use_text: |
|
|
self.text_proj = nn.Linear(self.text_embed_dim, self.hidden_size) |
|
|
self.text_norm = nn.LayerNorm(self.hidden_size, elementwise_affine=False, eps=1e-6) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.use_class: |
|
|
|
|
|
self.class_embed = nn.Embedding(params["num_classes"] + 1, self.hidden_size) |
|
|
self.null_class_id = params["num_classes"] |
|
|
|
|
|
|
|
|
self.layers = nn.ModuleList([ |
|
|
DiTBlock(self.hidden_size, self.num_heads, params.get("mlp_ratio", 4.0), |
|
|
self.use_text, use_adaln_single=self.use_adaln_single) |
|
|
for _ in range(self.depth) |
|
|
]) |
|
|
|
|
|
|
|
|
if self.use_adaln_single: |
|
|
self.final_layer = T2IFinalLayer(self.hidden_size, self.patch_size, self.out_channels) |
|
|
else: |
|
|
self.final_layer = FinalLayer(self.hidden_size, self.patch_size, self.out_channels) |
|
|
|
|
|
|
|
|
self.initialize_weights() |
|
|
|
|
|
def initialize_weights(self): |
|
|
|
|
|
def _basic_init(module): |
|
|
if isinstance(module, nn.Linear): |
|
|
torch.nn.init.xavier_uniform_(module.weight) |
|
|
if module.bias is not None: |
|
|
nn.init.constant_(module.bias, 0) |
|
|
self.apply(_basic_init) |
|
|
|
|
|
|
|
|
grid_size = int(self.num_patches ** 0.5) |
|
|
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], grid_size) |
|
|
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) |
|
|
|
|
|
|
|
|
w = self.patch_embed.weight.data |
|
|
nn.init.xavier_uniform_(w.view([w.shape[0], -1])) |
|
|
if self.patch_embed.bias is not None: |
|
|
nn.init.constant_(self.patch_embed.bias, 0) |
|
|
|
|
|
|
|
|
nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02) |
|
|
nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02) |
|
|
|
|
|
|
|
|
for block in self.layers: |
|
|
if block.adaLN_modulation is not None: |
|
|
|
|
|
nn.init.constant_(block.adaLN_modulation[-1].weight, 0) |
|
|
nn.init.constant_(block.adaLN_modulation[-1].bias, 0) |
|
|
|
|
|
|
|
|
|
|
|
if self.use_text and hasattr(block, 'cross_attn'): |
|
|
nn.init.constant_(block.cross_attn.out_proj.weight, 0) |
|
|
nn.init.constant_(block.cross_attn.out_proj.bias, 0) |
|
|
|
|
|
|
|
|
if self.use_text and hasattr(self, 'text_proj'): |
|
|
nn.init.normal_(self.text_proj.weight, std=0.02) |
|
|
if self.text_proj.bias is not None: |
|
|
nn.init.constant_(self.text_proj.bias, 0) |
|
|
|
|
|
|
|
|
if self.use_class and hasattr(self, 'class_embed'): |
|
|
nn.init.normal_(self.class_embed.weight, std=0.02) |
|
|
|
|
|
|
|
|
if self.use_adaln_single and hasattr(self, 't_block'): |
|
|
nn.init.normal_(self.t_block[1].weight, std=0.02) |
|
|
|
|
|
nn.init.constant_(self.t_block[1].bias, 0) |
|
|
|
|
|
|
|
|
if hasattr(self.final_layer, 'adaLN_modulation') and self.final_layer.adaLN_modulation is not None: |
|
|
|
|
|
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) |
|
|
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) |
|
|
|
|
|
nn.init.constant_(self.final_layer.linear.weight, 0) |
|
|
nn.init.constant_(self.final_layer.linear.bias, 0) |
|
|
|
|
|
def forward(self, xt: torch.Tensor, t: torch.Tensor, text_embeds: Optional[torch.Tensor] = None, |
|
|
attention_mask: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: |
|
|
B, C, H, W = xt.shape |
|
|
|
|
|
|
|
|
|
|
|
        # Heuristic rescale: continuous t in [0, 1] maps to [0, 999]; inputs
        # already on the [0, 999] scale pass through unchanged (ambiguous only
        # if every t in the batch happens to lie in [0, 1]).
        if t.max() <= 1.0 and t.min() >= 0.0:
            t = t * 999.0
        t = t.clamp(0, 999)
|
|
|
|
|
|
|
|
x = self.patch_embed(xt) |
|
|
x = x.flatten(2).transpose(1, 2) |
|
|
x = x + self.pos_embed |
|
|
|
|
|
|
|
|
time_emb = self.time_embed(t) |
|
|
|
|
|
|
|
|
if self.use_class and class_labels is not None: |
|
|
class_emb = self.class_embed(class_labels) |
|
|
time_emb = time_emb + class_emb |
|
|
|
|
|
|
|
|
if self.use_adaln_single: |
|
|
|
|
|
c = self.t_block(time_emb) |
|
|
else: |
|
|
|
|
|
c = time_emb |
|
|
|
|
|
|
|
|
text_tokens = None |
|
|
if self.use_text and text_embeds is not None: |
|
|
if text_embeds.dim() == 3: |
|
|
text_tokens = self.text_proj(text_embeds) |
|
|
text_tokens = self.text_norm(text_tokens) |
|
|
else: |
|
|
text_tokens = self.text_proj(text_embeds).unsqueeze(1) |
|
|
text_tokens = self.text_norm(text_tokens) |
|
|
|
|
|
if attention_mask is not None: |
|
|
|
|
|
attention_mask = attention_mask[:, :text_tokens.shape[1]].to(torch.bool) |
|
|
|
|
|
all_false = attention_mask.sum(dim=1) == 0 |
|
|
if all_false.any(): |
|
|
attention_mask[all_false, 0] = True |
|
|
|
|
|
|
|
|
for layer in self.layers: |
|
|
x = layer(x, c, text_tokens, attention_mask) |
|
|
|
|
|
|
|
|
if self.use_adaln_single: |
|
|
|
|
|
x = self.final_layer(x, time_emb) |
|
|
else: |
|
|
|
|
|
x = self.final_layer(x, c) |
|
|
|
|
|
|
|
|
patch_h = patch_w = int(self.num_patches ** 0.5) |
|
|
x = x.view(B, patch_h, patch_w, self.patch_size, self.patch_size, self.out_channels) |
|
|
x = x.permute(0, 5, 1, 3, 2, 4).contiguous() |
|
|
x = x.view(B, self.out_channels, H, W) |
|
|
|
|
|
return x |
|
|
|
|
|
def compute_loss(self, x0: torch.Tensor, text_embeds: Optional[torch.Tensor] = None, |
|
|
attention_mask: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None, |
|
|
null_text_embeds: Optional[torch.Tensor] = None, null_attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: |
|
|
"""Unified loss computation based on objective type""" |
|
|
if self.objective_type == "ddpm": |
|
|
return self.ddpm_loss(x0, text_embeds, attention_mask, class_labels, null_text_embeds, null_attention_mask) |
|
|
elif self.objective_type == "fm": |
|
|
return self.flow_matching_loss(x0, text_embeds, attention_mask, class_labels, null_text_embeds, null_attention_mask) |
|
|
elif self.objective_type == "rf": |
|
|
return self.rectified_flow_loss(x0, text_embeds, attention_mask, class_labels, null_text_embeds, null_attention_mask) |
|
|
else: |
|
|
raise ValueError(f"Unknown objective type: {self.objective_type}") |
|
|
|
|
|
def ddpm_loss(self, x0: torch.Tensor, text_embeds: Optional[torch.Tensor] = None, |
|
|
attention_mask: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None, |
|
|
null_text_embeds: Optional[torch.Tensor] = None, null_attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: |
|
|
"""DDPM: predict noise ε""" |
|
|
B = x0.shape[0] |
|
|
device = x0.device |
|
|
|
|
|
|
|
|
t = torch.rand(B, device=device) |
|
|
|
|
|
|
|
|
alpha_t, sigma_t = self.schedule.get_schedule(t) |
|
|
|
|
|
noise = torch.randn_like(x0) |
|
|
xt = alpha_t.view(-1, 1, 1, 1) * x0 + sigma_t.view(-1, 1, 1, 1) * noise |
|
|
|
|
|
|
|
|
        # Classifier-free guidance dropout: randomly swap conditioning for the
        # null embedding/class so the model also learns the unconditional path.
        if self.training and self.cfg_dropout_prob > 0:
|
|
if self.use_text and text_embeds is not None: |
|
|
keep = (torch.rand(B, device=device) > self.cfg_dropout_prob) |
|
|
|
|
|
if null_text_embeds is not None: |
|
|
|
|
|
if null_text_embeds.shape[0] == 1: |
|
|
null_text_embeds = null_text_embeds.expand(B, -1, -1) |
|
|
|
|
|
|
|
|
dropped = ~keep |
|
|
if dropped.any(): |
|
|
text_embeds = text_embeds.clone() |
|
|
text_embeds[dropped] = null_text_embeds[dropped] |
|
|
|
|
|
|
|
|
if attention_mask is not None: |
|
|
attention_mask = attention_mask.clone() |
|
|
if null_attention_mask is not None: |
|
|
if null_attention_mask.shape[0] == 1: |
|
|
null_attention_mask = null_attention_mask.expand(B, -1) |
|
|
attention_mask[dropped] = null_attention_mask[dropped] |
|
|
else: |
|
|
attention_mask[dropped] = 0 |
|
|
attention_mask[dropped, 0] = 1 |
|
|
else: |
|
|
|
|
|
if text_embeds.dim() == 3: |
|
|
text_embeds = text_embeds * keep[:, None, None].to(text_embeds.dtype) |
|
|
else: |
|
|
text_embeds = text_embeds * keep[:, None].to(text_embeds.dtype) |
|
|
|
|
|
if attention_mask is not None: |
|
|
attention_mask = attention_mask.clone() |
|
|
dropped = ~keep |
|
|
if dropped.any(): |
|
|
attention_mask[dropped, 0] = 1 |
|
|
|
|
|
elif self.use_class and class_labels is not None: |
|
|
|
|
|
keep = (torch.rand(B, device=device) > self.cfg_dropout_prob) |
|
|
null_class = torch.full_like(class_labels, self.null_class_id) |
|
|
class_labels = torch.where(keep, class_labels, null_class) |
|
|
|
|
|
|
|
|
pred_eps = self.forward(xt, t, text_embeds, attention_mask, class_labels) |
|
|
|
|
|
return F.mse_loss(pred_eps, noise) |
|
|
|
|
|
def rectified_flow_loss(self, x0: torch.Tensor, text_embeds: Optional[torch.Tensor] = None, |
|
|
attention_mask: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None, |
|
|
null_text_embeds: Optional[torch.Tensor] = None, null_attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: |
|
|
"""Rectified Flow: predict velocity v = x_1 - x_0 (straight paths)""" |
|
|
B = x0.shape[0] |
|
|
device = x0.device |
|
|
|
|
|
|
|
|
t = torch.rand(B, device=device) |
|
|
|
|
|
|
|
|
x1 = torch.randn_like(x0) |
|
|
xt = (1 - t).view(-1, 1, 1, 1) * x0 + t.view(-1, 1, 1, 1) * x1 |
|
|
|
|
|
|
|
|
if self.training and self.cfg_dropout_prob > 0: |
|
|
if self.use_text and text_embeds is not None: |
|
|
keep = (torch.rand(B, device=device) > self.cfg_dropout_prob) |
|
|
|
|
|
if null_text_embeds is not None: |
|
|
|
|
|
if null_text_embeds.shape[0] == 1: |
|
|
null_text_embeds = null_text_embeds.expand(B, -1, -1) |
|
|
|
|
|
|
|
|
dropped = ~keep |
|
|
if dropped.any(): |
|
|
text_embeds = text_embeds.clone() |
|
|
text_embeds[dropped] = null_text_embeds[dropped] |
|
|
|
|
|
|
|
|
if attention_mask is not None: |
|
|
attention_mask = attention_mask.clone() |
|
|
if null_attention_mask is not None: |
|
|
if null_attention_mask.shape[0] == 1: |
|
|
null_attention_mask = null_attention_mask.expand(B, -1) |
|
|
attention_mask[dropped] = null_attention_mask[dropped] |
|
|
else: |
|
|
attention_mask[dropped] = 0 |
|
|
attention_mask[dropped, 0] = 1 |
|
|
else: |
|
|
|
|
|
if text_embeds.dim() == 3: |
|
|
text_embeds = text_embeds * keep[:, None, None].to(text_embeds.dtype) |
|
|
else: |
|
|
text_embeds = text_embeds * keep[:, None].to(text_embeds.dtype) |
|
|
|
|
|
if attention_mask is not None: |
|
|
attention_mask = attention_mask.clone() |
|
|
dropped = ~keep |
|
|
if dropped.any(): |
|
|
attention_mask[dropped, 0] = 1 |
|
|
|
|
|
elif self.use_class and class_labels is not None: |
|
|
|
|
|
keep = (torch.rand(B, device=device) > self.cfg_dropout_prob) |
|
|
null_class = torch.full_like(class_labels, self.null_class_id) |
|
|
class_labels = torch.where(keep, class_labels, null_class) |
|
|
|
|
|
|
|
|
pred_v = self.forward(xt, t, text_embeds, attention_mask, class_labels) |
|
|
true_v = x1 - x0 |
|
|
|
|
|
return F.mse_loss(pred_v, true_v) |
|
|
|
|
|
def flow_matching_loss(self, x0: torch.Tensor, text_embeds: Optional[torch.Tensor] = None, |
|
|
attention_mask: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None, |
|
|
null_text_embeds: Optional[torch.Tensor] = None, null_attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: |
|
|
"""Flow matching loss for latent space training with CFG dropout.""" |
|
|
B = x0.shape[0] |
|
|
device = x0.device |
|
|
|
|
|
|
|
|
t = torch.rand(B, device=device) |
|
|
|
|
|
|
|
|
alpha_t, sigma_t = self.schedule.get_schedule(t) |
|
|
|
|
|
noise = torch.randn_like(x0) |
|
|
xt = alpha_t.view(-1, 1, 1, 1) * x0 + sigma_t.view(-1, 1, 1, 1) * noise |
|
|
|
|
|
|
|
|
if self.training and self.cfg_dropout_prob > 0: |
|
|
if self.use_text and text_embeds is not None: |
|
|
keep = (torch.rand(B, device=device) > self.cfg_dropout_prob) |
|
|
|
|
|
if null_text_embeds is not None: |
|
|
|
|
|
|
|
|
if null_text_embeds.shape[0] == 1: |
|
|
null_text_embeds = null_text_embeds.expand(B, -1, -1) |
|
|
|
|
|
|
|
|
dropped = ~keep |
|
|
if dropped.any(): |
|
|
text_embeds = text_embeds.clone() |
|
|
text_embeds[dropped] = null_text_embeds[dropped] |
|
|
|
|
|
|
|
|
if attention_mask is not None: |
|
|
attention_mask = attention_mask.clone() |
|
|
if null_attention_mask is not None: |
|
|
|
|
|
if null_attention_mask.shape[0] == 1: |
|
|
null_attention_mask = null_attention_mask.expand(B, -1) |
|
|
attention_mask[dropped] = null_attention_mask[dropped] |
|
|
else: |
|
|
|
|
|
attention_mask[dropped] = 0 |
|
|
attention_mask[dropped, 0] = 1 |
|
|
else: |
|
|
|
|
|
if text_embeds.dim() == 3: |
|
|
text_embeds = text_embeds * keep[:, None, None].to(text_embeds.dtype) |
|
|
else: |
|
|
text_embeds = text_embeds * keep[:, None].to(text_embeds.dtype) |
|
|
|
|
|
|
|
|
if attention_mask is not None: |
|
|
attention_mask = attention_mask.clone() |
|
|
dropped = ~keep |
|
|
if dropped.any(): |
|
|
attention_mask[dropped, 0] = 1 |
|
|
|
|
|
elif self.use_class and class_labels is not None: |
|
|
|
|
|
keep = (torch.rand(B, device=device) > self.cfg_dropout_prob) |
|
|
|
|
|
null_class = torch.full_like(class_labels, self.null_class_id) |
|
|
class_labels = torch.where(keep, class_labels, null_class) |
|
|
|
|
|
|
|
|
pred_v = self.forward(xt, t, text_embeds, attention_mask, class_labels) |
|
|
        # Linear-interpolation velocity target (see note in UNetExpert.flow_matching_loss).
        true_v = noise - x0
|
|
|
|
|
return F.mse_loss(pred_v, true_v) |
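

# Hedged sampling sketch (not used elsewhere in this file): once every expert
# output is converted to a velocity, generation is plain Euler integration of
# dx/dt = v from t=1 (noise) to t=0. Assumes a model whose forward returns a
# velocity for (x_t, t); CFG and conditioning are omitted for brevity.
@torch.no_grad()
def _euler_sample_sketch(model: nn.Module, shape, steps: int = 50, device: str = "cpu") -> torch.Tensor:
    x = torch.randn(*shape, device=device)
    ts = torch.linspace(1.0, 0.0, steps + 1, device=device)
    for i in range(steps):
        t = ts[i].expand(shape[0])
        v = model(x, t)
        x = x + (ts[i + 1] - ts[i]) * v   # dt is negative: integrate toward data
    return x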
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ViTRouter(nn.Module): |
|
|
"""ViT-based router for cluster classification""" |
|
|
|
|
|
def __init__(self, config) -> None: |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
default_params = { |
|
|
"hidden_size": 384, |
|
|
"num_layers": 6, |
|
|
"num_heads": 6, |
|
|
"patch_size": 8, |
|
|
"use_dit_time_embed": False, |
|
|
} |
|
|
params = {**default_params, **config.router_params} |
|
|
|
|
|
if config.router_pretrained: |
|
|
|
|
|
self.vit = ViTForImageClassification.from_pretrained( |
|
|
"google/vit-base-patch16-224" |
|
|
) |
|
|
self._adapt_pretrained(config, params) |
|
|
else: |
|
|
|
|
|
vit_config = ViTConfig( |
|
|
image_size=config.image_size, |
|
|
patch_size=params["patch_size"], |
|
|
num_channels=config.num_channels, |
|
|
hidden_size=params["hidden_size"], |
|
|
num_hidden_layers=params["num_layers"], |
|
|
num_attention_heads=params["num_heads"], |
|
|
num_labels=config.num_clusters |
|
|
) |
|
|
self.vit = ViTForImageClassification(vit_config) |
|
|
|
|
|
|
|
|
self.use_dit_time_embed = params.get("use_dit_time_embed", False) |
|
|
if self.use_dit_time_embed: |
|
|
|
|
|
self.time_embedding = DiTTimestepEmbedder(params["hidden_size"]) |
|
|
else: |
|
|
|
|
|
self.time_embedding = nn.Sequential( |
|
|
nn.Linear(1, params["hidden_size"]), |
|
|
nn.SiLU(), |
|
|
nn.Linear(params["hidden_size"], params["hidden_size"]) |
|
|
) |
|
|
|
|
|
|
|
|
self.classifier = nn.Sequential( |
|
|
nn.Linear(params["hidden_size"] * 2, params["hidden_size"]), |
|
|
nn.ReLU(), |
|
|
nn.Dropout(0.1), |
|
|
nn.Linear(params["hidden_size"], config.num_clusters) |
|
|
) |
|
|
|
|
|
    def _adapt_pretrained(self, config, params) -> None:
        """Adapt a pretrained ViT to our image size / channel count (in place).

        Only the patch projection is swapped; position embeddings are left
        untouched, and the classifier head built in __init__ assumes
        params["hidden_size"] matches the pretrained ViT's hidden size.
        """
        if config.image_size != 224 or config.num_channels != 3:
            self.vit.vit.embeddings.patch_embeddings.projection = nn.Conv2d(
                config.num_channels,
                self.vit.config.hidden_size,
                kernel_size=params["patch_size"],
                stride=params["patch_size"]
            )
|
|
|
|
|
def forward(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor: |
|
|
|
|
|
vit_outputs = self.vit.vit(xt) |
|
|
image_features = vit_outputs.last_hidden_state[:, 0] |
|
|
|
|
|
|
|
|
if self.use_dit_time_embed: |
|
|
|
|
|
time_features = self.time_embedding(t) |
|
|
else: |
|
|
|
|
|
time_features = self.time_embedding(t.unsqueeze(-1)) |
|
|
|
|
|
|
|
|
combined = torch.cat([image_features, time_features], dim=1) |
|
|
return self.classifier(combined) |
|
|
|
|
|
class CNNRouter(nn.Module): |
|
|
"""Simple CNN router for cluster classification""" |
|
|
|
|
|
def __init__(self, config) -> None: |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
default_params = { |
|
|
"hidden_dims": [64, 128, 256], |
|
|
"use_dit_time_embed": False, |
|
|
} |
|
|
params = {**default_params, **config.router_params} |
|
|
|
|
|
|
|
|
self.backbone = self._build_cnn(config.num_channels, params["hidden_dims"]) |
|
|
|
|
|
|
|
|
self.use_dit_time_embed = params.get("use_dit_time_embed", False) |
|
|
if self.use_dit_time_embed: |
|
|
|
|
|
self.time_embedding = DiTTimestepEmbedder(128) |
|
|
else: |
|
|
|
|
|
self.time_embedding = nn.Sequential( |
|
|
nn.Linear(1, 128), |
|
|
nn.SiLU(), |
|
|
nn.Linear(128, 128) |
|
|
) |
|
|
|
|
|
|
|
|
self.classifier = nn.Sequential( |
|
|
nn.Linear(params["hidden_dims"][-1] + 128, 256), |
|
|
nn.ReLU(), |
|
|
nn.Dropout(0.1), |
|
|
nn.Linear(256, config.num_clusters) |
|
|
) |
|
|
|
|
|
def _build_cnn(self, in_channels: int, hidden_dims: List[int]) -> nn.Sequential: |
|
|
layers = [] |
|
|
prev_dim = in_channels |
|
|
|
|
|
for dim in hidden_dims: |
|
|
layers.extend([ |
|
|
nn.Conv2d(prev_dim, dim, 3, padding=1), |
|
|
nn.ReLU(), |
|
|
nn.Conv2d(dim, dim, 3, padding=1), |
|
|
nn.ReLU(), |
|
|
nn.MaxPool2d(2) |
|
|
]) |
|
|
prev_dim = dim |
|
|
|
|
|
layers.append(nn.AdaptiveAvgPool2d(1)) |
|
|
layers.append(nn.Flatten()) |
|
|
|
|
|
return nn.Sequential(*layers) |
|
|
|
|
|
def forward(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor: |
|
|
|
|
|
img_features = self.backbone(xt) |
|
|
|
|
|
|
|
|
if self.use_dit_time_embed: |
|
|
|
|
|
time_features = self.time_embedding(t) |
|
|
else: |
|
|
|
|
|
time_features = self.time_embedding(t.unsqueeze(-1)) |
|
|
|
|
|
|
|
|
combined = torch.cat([img_features, time_features], dim=1) |
|
|
return self.classifier(combined) |
|
|
|
|
|
class DiTRouter(nn.Module): |
|
|
"""DiT B/2 router for cluster classification""" |
|
|
|
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
default_params = { |
|
|
"hidden_size": 768, |
|
|
"num_layers": 12, |
|
|
"num_heads": 12, |
|
|
"patch_size": 2, |
|
|
"in_channels": 4, |
|
|
"mlp_ratio": 4.0, |
|
|
"use_dit_time_embed": False, |
|
|
} |
|
|
params = {**default_params, **config.router_params} |
|
|
|
|
|
self.patch_size = params["patch_size"] |
|
|
self.in_channels = params["in_channels"] |
|
|
self.hidden_size = params["hidden_size"] |
|
|
self.num_heads = params["num_heads"] |
|
|
self.num_clusters = config.num_clusters |
|
|
|
|
|
|
|
|
self.patch_embed = nn.Conv2d( |
|
|
self.in_channels, self.hidden_size, |
|
|
kernel_size=self.patch_size, stride=self.patch_size |
|
|
) |
|
|
|
|
|
|
|
|
latent_size = getattr(config, 'image_size', 32) |
|
|
self.num_patches = (latent_size // self.patch_size) ** 2 |
|
|
|
|
|
|
|
|
self.pos_embed = nn.Parameter( |
|
|
torch.zeros(1, self.num_patches, self.hidden_size), |
|
|
requires_grad=False |
|
|
) |
|
|
|
|
|
|
|
|
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_size)) |
|
|
|
|
|
|
|
|
self.use_dit_time_embed = params.get("use_dit_time_embed", False) |
|
|
if self.use_dit_time_embed: |
|
|
self.time_embed = DiTTimestepEmbedder(self.hidden_size) |
|
|
else: |
|
|
self.time_embed = TimestepEmbedder(self.hidden_size) |
|
|
|
|
|
|
|
|
|
|
|
self.layers = nn.ModuleList([ |
|
|
DiTBlock(self.hidden_size, self.num_heads, params["mlp_ratio"], use_text=False) |
|
|
for _ in range(params["num_layers"]) |
|
|
]) |
|
|
|
|
|
|
|
|
self.norm_final = nn.LayerNorm(self.hidden_size, elementwise_affine=False, eps=1e-6) |
|
|
|
|
|
|
|
|
|
|
|
self.head = nn.Sequential( |
|
|
nn.Linear(self.hidden_size, self.hidden_size), |
|
|
nn.GELU(), |
|
|
nn.LayerNorm(self.hidden_size), |
|
|
nn.Dropout(0.1), |
|
|
nn.Linear(self.hidden_size, self.num_clusters) |
|
|
) |
|
|
|
|
|
|
|
|
self.initialize_weights() |
|
|
|
|
|
def initialize_weights(self): |
|
|
|
|
|
def _basic_init(module): |
|
|
if isinstance(module, nn.Linear): |
|
|
torch.nn.init.xavier_uniform_(module.weight) |
|
|
if module.bias is not None: |
|
|
nn.init.constant_(module.bias, 0) |
|
|
self.apply(_basic_init) |
|
|
|
|
|
|
|
|
nn.init.normal_(self.cls_token, std=0.02) |
|
|
|
|
|
|
|
|
grid_size = int(self.num_patches ** 0.5) |
|
|
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], grid_size) |
|
|
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) |
|
|
|
|
|
|
|
|
w = self.patch_embed.weight.data |
|
|
nn.init.xavier_uniform_(w.view([w.shape[0], -1])) |
|
|
if self.patch_embed.bias is not None: |
|
|
nn.init.constant_(self.patch_embed.bias, 0) |
|
|
|
|
|
|
|
|
if hasattr(self.time_embed, 'mlp'): |
|
|
nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02) |
|
|
nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02) |
|
|
|
|
|
|
|
|
for block in self.layers: |
|
|
nn.init.constant_(block.adaLN_modulation[-1].weight, 0) |
|
|
nn.init.constant_(block.adaLN_modulation[-1].bias, 0) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
nn.init.normal_(self.head[0].weight, std=0.02) |
|
|
if self.head[0].bias is not None: |
|
|
nn.init.constant_(self.head[0].bias, 0) |
|
|
|
|
|
|
|
|
nn.init.constant_(self.head[-1].weight, 0) |
|
|
if self.head[-1].bias is not None: |
|
|
nn.init.constant_(self.head[-1].bias, 0) |
|
|
|
|
|
def forward(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor: |
|
|
B, C, H, W = xt.shape |
|
|
|
|
|
|
|
|
        # Same [0, 1] → [0, 999] rescaling heuristic as DiTExpert.forward.
        if t.max() <= 1.0 and t.min() >= 0.0:
            t = t * 999.0
        t = t.clamp(0, 999)
|
|
|
|
|
|
|
|
x = self.patch_embed(xt) |
|
|
x = x.flatten(2).transpose(1, 2) |
|
|
|
|
|
|
|
|
x = x + self.pos_embed |
|
|
|
|
|
|
|
|
cls_tokens = self.cls_token.expand(B, -1, -1) |
|
|
x = torch.cat([cls_tokens, x], dim=1) |
|
|
|
|
|
|
|
|
c = self.time_embed(t) |
|
|
|
|
|
|
|
|
for layer in self.layers: |
|
|
x = layer(x, c, text_emb=None) |
|
|
|
|
|
|
|
|
cls_output = x[:, 0] |
|
|
cls_output = self.norm_final(cls_output) |
|
|
|
|
|
|
|
|
logits = self.head(cls_output) |
|
|
|
|
|
return logits |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DeterministicTimestepRouter(nn.Module): |
|
|
""" |
|
|
Deterministic router that assigns experts based on timestep. |
|
|
|
|
|
Useful for controlled experiments where you want to test specific routing strategies, |
|
|
such as: "high noise → DDPM expert, low noise → FM expert" |
|
|
|
|
|
Args: |
|
|
config: Config object with router_params containing: |
|
|
- timestep_threshold: t value to switch experts (default: 0.5) |
|
|
- high_noise_expert: Expert ID for t > threshold (default: 0, typically DDPM) |
|
|
- low_noise_expert: Expert ID for t <= threshold (default: 1, typically FM) |
|
|
|
|
|
Example config: |
|
|
router_architecture: "deterministic_timestep" |
|
|
router_params: |
|
|
timestep_threshold: 0.5 |
|
|
high_noise_expert: 0 # DDPM for high noise |
|
|
low_noise_expert: 1 # FM for low noise |
|
|
""" |
|
|
|
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
self.num_experts = config.num_experts |
|
|
self.threshold = config.router_params.get('timestep_threshold', 0.5) |
|
|
self.high_noise_expert = config.router_params.get('high_noise_expert', 0) |
|
|
self.low_noise_expert = config.router_params.get('low_noise_expert', 1) |
|
|
|
|
|
|
|
|
assert 0 <= self.high_noise_expert < self.num_experts, \ |
|
|
f"high_noise_expert {self.high_noise_expert} out of range [0, {self.num_experts})" |
|
|
assert 0 <= self.low_noise_expert < self.num_experts, \ |
|
|
f"low_noise_expert {self.low_noise_expert} out of range [0, {self.num_experts})" |
|
|
|
|
|
|
|
|
assert 0.0 <= self.threshold <= 1.0, \ |
|
|
f"timestep_threshold {self.threshold} must be in [0, 1]" |
|
|
|
|
|
|
|
|
|
|
|
self.register_buffer('_threshold', torch.tensor(self.threshold)) |
|
|
|
|
|
print(f"DeterministicTimestepRouter initialized:") |
|
|
print(f" Threshold: {self.threshold}") |
|
|
print(f" High noise (t > {self.threshold}) → Expert {self.high_noise_expert}") |
|
|
print(f" Low noise (t <= {self.threshold}) → Expert {self.low_noise_expert}") |
|
|
|
|
|
def forward(self, x: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor: |
|
|
""" |
|
|
Returns one-hot routing probabilities based on timestep. |
|
|
|
|
|
Args: |
|
|
x: Input tensor (unused, but kept for API compatibility with other routers) |
|
|
t: Timesteps, shape (B,) |
|
|
|
|
|
Returns: |
|
|
routing_probs: Shape (B, num_experts), one-hot encoded |
|
|
""" |
|
|
B = t.shape[0] |
|
|
device = t.device |
|
|
|
|
|
|
|
|
routing_probs = torch.zeros(B, self.num_experts, device=device) |
|
|
|
|
|
|
|
|
|
|
|
high_noise_mask = t > self.threshold |
|
|
routing_probs[high_noise_mask, self.high_noise_expert] = 1.0 |
|
|
routing_probs[~high_noise_mask, self.low_noise_expert] = 1.0 |
|
|
|
|
|
return routing_probs |
|
|
|
|
|
    def train(self, mode: bool = True):
        """Override train(): this router has no trainable state, so it stays
        pinned to eval mode regardless of the requested mode."""
        return super().train(False)
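

# Illustrative only: the router emits one-hot probabilities per sample based
# purely on t. The lightweight config stub below is hypothetical.
def _demo_deterministic_router() -> None:
    class _Cfg:
        num_experts = 2
        router_params = {"timestep_threshold": 0.5,
                         "high_noise_expert": 0,
                         "low_noise_expert": 1}
    router = DeterministicTimestepRouter(_Cfg())
    t = torch.tensor([0.9, 0.2])
    probs = router(x=None, t=t)
    assert probs.tolist() == [[1.0, 0.0], [0.0, 1.0]]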
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AdaptiveVideoRouter(nn.Module): |
|
|
""" |
|
|
Time-adaptive router for video DDM. |
|
|
|
|
|
Key innovation: Learns optimal weighting of information sources |
|
|
at each noise level, solving the "motion invisible at t=1" problem. |
|
|
|
|
|
Information availability is time-dependent: |
|
|
t ~ 1.0: Only text/first_frame informative → Route on conditioning |
|
|
t ~ 0.5: Structure emerging → Latent becomes useful |
|
|
t ~ 0.1: Near clean → Full information available |
|
|
|
|
|
Expected learned behavior: |
|
|
| Noise Level | Text | Frame | Latent | Behavior | |
|
|
|-------------|------|-------|--------|-----------------------------| |
|
|
| t ~ 1.0 | ~0.7 | ~0.2 | ~0.1 | Routes on text semantics | |
|
|
| t ~ 0.5 | ~0.4 | ~0.3 | ~0.3 | Balanced; emerging structure| |
|
|
| t ~ 0.1 | ~0.2 | ~0.2 | ~0.6 | Trusts latent; fine-grained | |
|
|
|
|
|
Enhancements: |
|
|
- Masked mean pooling for text (handles variable-length prompts) |
|
|
- Temporal-aware latent encoder (captures motion patterns) |
|
|
- Temperature scaling for inference control |
|
|
""" |
|
|
|
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
default_params = { |
|
|
"hidden_dim": 512, |
|
|
"text_embed_dim": 768, |
|
|
"frame_embed_dim": 768, |
|
|
"latent_channels": 16, |
|
|
"latent_conv_dim": 64, |
|
|
"dropout": 0.1, |
|
|
"temporal_pool_mode": "attention", |
|
|
"normalize_inputs": True, |
|
|
} |
|
|
params = {**default_params, **getattr(config, 'router_params', {})} |
|
|
|
|
|
self.hidden_dim = params["hidden_dim"] |
|
|
self.num_experts = getattr(config, 'num_experts', config.num_clusters) |
|
|
self.latent_channels = params["latent_channels"] |
|
|
self.latent_conv_dim = params["latent_conv_dim"] |
|
|
self.temporal_pool_mode = params["temporal_pool_mode"] |
|
|
self.normalize_inputs = params.get("normalize_inputs", True) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.text_encoder = nn.Sequential( |
|
|
nn.Linear(params["text_embed_dim"], self.hidden_dim), |
|
|
nn.LayerNorm(self.hidden_dim), |
|
|
nn.GELU(), |
|
|
nn.Linear(self.hidden_dim, self.hidden_dim) |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
self.frame_encoder = nn.Sequential( |
|
|
nn.Linear(params["frame_embed_dim"], self.hidden_dim), |
|
|
nn.LayerNorm(self.hidden_dim), |
|
|
nn.GELU(), |
|
|
nn.Linear(self.hidden_dim, self.hidden_dim) |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.spatial_conv = nn.Sequential( |
|
|
nn.Conv3d(params["latent_channels"], params["latent_conv_dim"], |
|
|
kernel_size=(1, 3, 3), padding=(0, 1, 1)), |
|
|
nn.GroupNorm(8, params["latent_conv_dim"]), |
|
|
nn.GELU(), |
|
|
) |
|
|
|
|
|
|
|
|
self.temporal_conv = nn.Sequential( |
|
|
nn.Conv3d(params["latent_conv_dim"], params["latent_conv_dim"], |
|
|
kernel_size=(3, 1, 1), padding=(1, 0, 0)), |
|
|
nn.GroupNorm(8, params["latent_conv_dim"]), |
|
|
nn.GELU(), |
|
|
) |
|
|
|
|
|
|
|
|
self.st_conv = nn.Sequential( |
|
|
nn.Conv3d(params["latent_conv_dim"], params["latent_conv_dim"], |
|
|
kernel_size=3, padding=1), |
|
|
nn.GroupNorm(8, params["latent_conv_dim"]), |
|
|
nn.GELU(), |
|
|
) |
|
|
|
|
|
|
|
|
self.spatial_pool = nn.AdaptiveAvgPool3d((None, 1, 1)) |
|
|
|
|
|
|
|
|
if self.temporal_pool_mode == "attention": |
|
|
self.temporal_attn = nn.Sequential( |
|
|
nn.Linear(params["latent_conv_dim"], params["latent_conv_dim"] // 4), |
|
|
nn.Tanh(), |
|
|
nn.Linear(params["latent_conv_dim"] // 4, 1), |
|
|
) |
|
|
|
|
|
|
|
|
self.motion_encoder = nn.Sequential( |
|
|
nn.Linear(params["latent_conv_dim"], params["latent_conv_dim"]), |
|
|
nn.GELU(), |
|
|
nn.Linear(params["latent_conv_dim"], self.hidden_dim // 2), |
|
|
) |
|
|
|
|
|
|
|
|
self.content_proj = nn.Linear(params["latent_conv_dim"], self.hidden_dim // 2) |
|
|
|
|
|
|
|
|
self.latent_proj = nn.Sequential( |
|
|
nn.Linear(self.hidden_dim, self.hidden_dim), |
|
|
nn.LayerNorm(self.hidden_dim), |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.time_embed = TimestepEmbedder(self.hidden_dim) |
|
|
|
|
|
self.time_mlp = nn.Sequential( |
|
|
nn.Linear(self.hidden_dim, self.hidden_dim), |
|
|
nn.GELU(), |
|
|
nn.Linear(self.hidden_dim, self.hidden_dim) |
|
|
) |
|
|
|
|
|
|
|
|
self.source_weighting = nn.Sequential( |
|
|
nn.Linear(self.hidden_dim, 128), |
|
|
nn.GELU(), |
|
|
nn.Linear(128, 3), |
|
|
nn.Softmax(dim=-1) |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
self.router_head = nn.Sequential( |
|
|
nn.Linear(self.hidden_dim, self.hidden_dim), |
|
|
nn.GELU(), |
|
|
nn.LayerNorm(self.hidden_dim), |
|
|
nn.Dropout(params["dropout"]), |
|
|
nn.Linear(self.hidden_dim, self.num_experts) |
|
|
) |
|
|
|
|
|
|
|
|
self.initialize_weights() |
|
|
|
|
|
def initialize_weights(self): |
|
|
"""Initialize weights following DiT conventions.""" |
|
|
def _basic_init(module): |
|
|
if isinstance(module, nn.Linear): |
|
|
torch.nn.init.xavier_uniform_(module.weight) |
|
|
if module.bias is not None: |
|
|
nn.init.constant_(module.bias, 0) |
|
|
elif isinstance(module, nn.Conv3d): |
|
|
|
|
|
w = module.weight.data |
|
|
nn.init.xavier_uniform_(w.view([w.shape[0], -1])) |
|
|
if module.bias is not None: |
|
|
nn.init.constant_(module.bias, 0) |
|
|
self.apply(_basic_init) |
|
|
|
|
|
|
|
|
if hasattr(self.time_embed, 'mlp'): |
|
|
nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02) |
|
|
nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02) |
|
|
|
|
|
|
|
|
|
|
|
nn.init.normal_(self.router_head[-1].weight, std=0.01) |
|
|
nn.init.constant_(self.router_head[-1].bias, 0) |
|
|
|
|
|
|
|
|
|
|
|
nn.init.constant_(self.source_weighting[-2].weight, 0) |
|
|
nn.init.constant_(self.source_weighting[-2].bias, 0) |
|
|
|
|
|
|
|
|
if self.temporal_pool_mode == "attention": |
|
|
nn.init.constant_(self.temporal_attn[-1].weight, 0) |
|
|
nn.init.constant_(self.temporal_attn[-1].bias, 0) |
|
|
|
|
|
def _masked_mean_pool(self, embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: |
|
|
""" |
|
|
Compute masked mean pooling over sequence dimension. |
|
|
|
|
|
Args: |
|
|
embeddings: [B, seq_len, embed_dim] - Token embeddings |
|
|
attention_mask: [B, seq_len] - 1 for real tokens, 0 for padding |
|
|
|
|
|
Returns: |
|
|
pooled: [B, embed_dim] - Pooled representation |
|
|
""" |
|
|
if attention_mask is None: |
|
|
|
|
|
return embeddings.mean(dim=1) |
|
|
|
|
|
|
|
|
mask = attention_mask.unsqueeze(-1).to(embeddings.dtype) |
|
|
|
|
|
|
|
|
masked_sum = (embeddings * mask).sum(dim=1) |
|
|
|
|
|
|
|
|
token_counts = mask.sum(dim=1).clamp(min=1.0) |
|
|
|
|
|
return masked_sum / token_counts |
|
|
|
|
|
def _encode_latent_temporal(self, x_t: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Encode video latent with temporal awareness. |
|
|
|
|
|
Extracts both: |
|
|
- Content features: What is in the video (spatial) |
|
|
- Motion features: How things move (temporal differences) |
|
|
|
|
|
Args: |
|
|
x_t: [B, C, T, H, W] - Noisy video latent |
|
|
|
|
|
Returns: |
|
|
latent_feat: [B, hidden_dim] - Combined latent features |
|
|
""" |
|
|
B, C, T, H, W = x_t.shape |
|
|
|
|
|
|
|
|
spatial_feat = self.spatial_conv(x_t) |
|
|
|
|
|
|
|
|
temporal_feat = self.temporal_conv(spatial_feat) |
|
|
|
|
|
|
|
|
st_feat = self.st_conv(temporal_feat) |
|
|
|
|
|
|
|
|
pooled = self.spatial_pool(st_feat).squeeze(-1).squeeze(-1) |
|
|
pooled = pooled.permute(0, 2, 1) |
|
|
|
|
|
|
|
|
if self.temporal_pool_mode == "attention" and T > 1: |
|
|
|
|
|
attn_scores = self.temporal_attn(pooled).squeeze(-1) |
|
|
attn_weights = F.softmax(attn_scores, dim=-1) |
|
|
content_feat = (pooled * attn_weights.unsqueeze(-1)).sum(dim=1) |
|
|
elif self.temporal_pool_mode == "max": |
|
|
content_feat = pooled.max(dim=1)[0] |
|
|
else: |
|
|
content_feat = pooled.mean(dim=1) |
|
|
|
|
|
|
|
|
if T > 1: |
|
|
|
|
|
frame_diffs = pooled[:, 1:] - pooled[:, :-1] |
|
|
|
|
|
|
|
|
motion_feat = self.motion_encoder(frame_diffs.mean(dim=1)) |
|
|
else: |
|
|
|
|
|
motion_feat = torch.zeros(B, self.hidden_dim // 2, device=x_t.device) |
|
|
|
|
|
|
|
|
content_proj = self.content_proj(content_feat) |
|
|
|
|
|
|
|
|
combined = torch.cat([content_proj, motion_feat], dim=-1) |
|
|
latent_feat = self.latent_proj(combined) |
|
|
|
|
|
return latent_feat |
|
|
|
|
|
def forward(self, x_t: torch.Tensor, t: torch.Tensor, |
|
|
text_embed: torch.Tensor, |
|
|
first_frame_feat: Optional[torch.Tensor] = None, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
temperature: float = 1.0) -> torch.Tensor: |
|
|
""" |
|
|
Compute routing logits with time-adaptive information weighting. |
|
|
|
|
|
Args: |
|
|
x_t: Noisy video latent [B, C, T, H, W] |
|
|
t: Noise level [B] in [0, 1] or [0, 999] |
|
|
text_embed: CLIP text embedding [B, text_embed_dim] or [B, seq_len, text_embed_dim] |
|
|
first_frame_feat: Optional DINOv2 features [B, frame_embed_dim] |
|
|
attention_mask: Optional [B, seq_len] mask for text (1=valid, 0=padding) |
|
|
temperature: Softmax temperature for sharper/softer routing (default: 1.0) |
|
|
|
|
|
Returns: |
|
|
logits: Expert selection logits [B, num_experts] (scaled by temperature) |
|
|
""" |
|
|
B = x_t.shape[0] |
|
|
device = x_t.device |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if text_embed.dim() == 3: |
|
|
|
|
|
text_embed_pooled = self._masked_mean_pool(text_embed, attention_mask) |
|
|
else: |
|
|
|
|
|
text_embed_pooled = text_embed |
|
|
|
|
|
|
|
|
if self.normalize_inputs: |
|
|
text_embed_pooled = F.normalize(text_embed_pooled, p=2, dim=-1) |
|
|
|
|
|
text_feat = self.text_encoder(text_embed_pooled) |
|
|
|
|
|
|
|
|
if first_frame_feat is not None: |
|
|
|
|
|
if self.normalize_inputs: |
|
|
first_frame_feat = F.normalize(first_frame_feat, p=2, dim=-1) |
|
|
frame_feat = self.frame_encoder(first_frame_feat) |
|
|
else: |
|
|
frame_feat = torch.zeros(B, self.hidden_dim, device=device) |
|
|
|
|
|
|
|
|
latent_feat = self._encode_latent_temporal(x_t) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if t.max() <= 1.0: |
|
|
t_scaled = t * 999.0 |
|
|
else: |
|
|
t_scaled = t |
|
|
t_scaled = t_scaled.clamp(0, 999) |
|
|
|
|
|
|
|
|
time_emb = self.time_embed(t_scaled) |
|
|
time_feat = self.time_mlp(time_emb) |
|
|
|
|
|
|
|
|
|
|
|
weights = self.source_weighting(time_feat) |
|
|
|
|
|
|
|
|
|
|
|
combined = ( |
|
|
weights[:, 0:1] * text_feat + |
|
|
weights[:, 1:2] * frame_feat + |
|
|
weights[:, 2:3] * latent_feat |
|
|
) |
|
|
|
|
|
|
|
|
logits = self.router_head(combined + time_feat) |
|
|
|
|
|
|
|
|
if temperature != 1.0: |
|
|
logits = logits / temperature |
|
|
|
|
|
return logits |
|
|
|
|
|
def get_source_weights(self, t: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Get the learned source weights for given timesteps. |
|
|
Useful for debugging and visualization. |
|
|
|
|
|
Args: |
|
|
t: Noise levels [B] in [0, 1] or [0, 999] |
|
|
|
|
|
Returns: |
|
|
weights: Source weights [B, 3] for [text, frame, latent] |
|
|
""" |
|
|
|
|
|
if t.max() <= 1.0: |
|
|
t_scaled = t * 999.0 |
|
|
else: |
|
|
t_scaled = t |
|
|
t_scaled = t_scaled.clamp(0, 999) |
|
|
|
|
|
time_emb = self.time_embed(t_scaled) |
|
|
time_feat = self.time_mlp(time_emb) |
|
|
weights = self.source_weighting(time_feat) |
|
|
|
|
|
return weights |
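

# Debugging sketches (illustrative only). The first verifies the masked mean
# pool on a toy sequence; the second sweeps t to inspect how the learned
# text/frame/latent weights move, e.g. the expected high-noise → text-dominant
# behavior from the class docstring.
def _demo_masked_mean_pool(router: AdaptiveVideoRouter) -> None:
    emb = torch.arange(12, dtype=torch.float32).view(1, 3, 4)  # (B=1, L=3, D=4)
    mask = torch.tensor([[1, 1, 0]])                           # last token is padding
    pooled = router._masked_mean_pool(emb, mask)
    assert torch.allclose(pooled, emb[0, :2].mean(dim=0, keepdim=True))


@torch.no_grad()
def _sweep_source_weights(router: AdaptiveVideoRouter) -> None:
    for t_val in (0.95, 0.5, 0.1):
        t = torch.full((1,), t_val)
        w = router.get_source_weights(t)[0]
        print(f"t={t_val:.2f}  text={w[0]:.2f}  frame={w[1]:.2f}  latent={w[2]:.2f}")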
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_expert(config, expert_id: Optional[int] = None) -> nn.Module: |
|
|
""" |
|
|
Factory function to create expert model |
|
|
|
|
|
Args: |
|
|
config: Config object |
|
|
expert_id: Optional expert ID for per-expert schedule/objective configuration |
|
|
""" |
|
|
|
|
|
import copy |
|
|
config = copy.copy(config) |
|
|
config.expert_params = config.expert_params.copy() |
|
|
|
|
|
|
|
|
if "schedule_type" not in config.expert_params: |
|
|
|
|
|
if (hasattr(config, 'expert_schedule_types') and |
|
|
config.expert_schedule_types and |
|
|
expert_id is not None and |
|
|
expert_id in config.expert_schedule_types): |
|
|
config.expert_params["schedule_type"] = config.expert_schedule_types[expert_id] |
|
|
else: |
|
|
|
|
|
config.expert_params["schedule_type"] = getattr(config, 'schedule_type', 'linear_interp') |
|
|
|
|
|
|
|
|
if "objective_type" not in config.expert_params: |
|
|
|
|
|
if (hasattr(config, 'expert_objectives') and |
|
|
config.expert_objectives and |
|
|
expert_id is not None and |
|
|
expert_id in config.expert_objectives): |
|
|
config.expert_params["objective_type"] = config.expert_objectives[expert_id] |
|
|
else: |
|
|
|
|
|
config.expert_params["objective_type"] = getattr(config, 'default_objective', 'fm') |
|
|
|
|
|
if config.expert_architecture == "unet": |
|
|
return UNetExpert(config) |
|
|
elif config.expert_architecture == "simple_cnn": |
|
|
return SimpleCNNExpert(config) |
|
|
elif config.expert_architecture == "dit": |
|
|
return DiTExpert(config) |
|
|
else: |
|
|
raise ValueError(f"Unknown expert architecture: {config.expert_architecture}") |
|
|
|
|
|
def create_router(config) -> Optional[nn.Module]: |
|
|
"""Factory function to create router model""" |
|
|
|
|
|
if config.router_architecture == "none" or config.is_monolithic: |
|
|
return None |
|
|
elif config.router_architecture == "deterministic_timestep": |
|
|
return DeterministicTimestepRouter(config) |
|
|
elif config.router_architecture == "vit": |
|
|
return ViTRouter(config) |
|
|
elif config.router_architecture == "cnn": |
|
|
return CNNRouter(config) |
|
|
elif config.router_architecture == "dit": |
|
|
return DiTRouter(config) |
|
|
elif config.router_architecture == "adaptive_video": |
|
|
return AdaptiveVideoRouter(config) |
|
|
else: |
|
|
raise ValueError(f"Unknown router architecture: {config.router_architecture}") |