| import math
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
|
|
| '''
|
| [Model Overview]
|
|
|
| Input: (B, 4, 64, 64) - VAE latent
|
|
|
| 1. PatchEmbedding
|
| Conv2d(patch_size=2) → flatten → transpose
|
| (B, 4, 64, 64) → (B, 1024 tokens, 1024 d_model)
|
|
|
| 2. Condition Embedding
|
├── Sigma → SinusoidalPosEmb → MLP → LayerNorm → sigma_emb   # noise level (timestep)
├── Text → Linear → MLP → LayerNorm → text_token_emb   # tokens for cross attention
└── Text → (masked) mean pooling → MLP → LayerNorm → pooled_text   # global condition for adaLN

cond_emb = LayerNorm(Linear(sigma_emb + pooled_text)) → adaLN modulation coefficients
|
|
|
| 3. DiT Block × num_layers
|
| each block receives shift/scale modulation from cond_emb (adaLN):
|
| ├── Self Attention + RoPE 2D # spatial relationships between patches
|
| ├── Text Cross Attention # text tokens ↔ image patches
|
| └── FFN # feature transformation
|
|
|
| 4. Final modulation + Output projection
|
| LayerNorm → adaLN shift/scale → Linear → unpatchify
|
|
|
| Output: pred_velocity (B, 4, 64, 64) - direction vector from noise → clean
|
| '''
|
|
|
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of one scalar per sample (e.g. the noise level).

    Frequencies form a geometric progression starting at 1 and (for
    ``dim >= 4``) ending at ``1 / sinusoid_rope_hz``. The output is
    ``[sin(x * f), cos(x * f)]`` concatenated along the last axis.

    Args:
        dim: Output feature size. When ``dim`` is odd, a zero column is
            appended so the output always has exactly ``dim`` features.
        sinusoid_rope_hz: Base of the frequency progression.
    """

    def __init__(self, dim, sinusoid_rope_hz):
        super().__init__()
        self.sinusoid_rope_hz = sinusoid_rope_hz
        self.dim = dim

    def forward(self, x):
        """Embed ``x`` of shape (B,); a 0-d tensor is treated as B = 1.

        Returns:
            float32 tensor of shape (B, dim).
        """
        if x.dim() == 0:
            # ``x[:, None]`` below needs at least one axis.
            x = x[None]

        half_dim = self.dim // 2
        # max(..., 1) avoids a division by zero when half_dim == 1.
        exponent = math.log(self.sinusoid_rope_hz) / max(half_dim - 1, 1)
        freqs = torch.exp(
            torch.arange(half_dim, device=x.device, dtype=torch.float32) * -exponent
        )

        angles = x[:, None].float() * freqs[None, :]
        emb = torch.cat((angles.sin(), angles.cos()), dim=-1)

        if self.dim % 2 == 1:
            # sin/cos only fill 2 * (dim // 2) columns; pad so downstream
            # layers built for ``dim`` inputs receive the advertised width.
            emb = torch.cat((emb, emb.new_zeros(*emb.shape[:-1], 1)), dim=-1)
        return emb
|
|
|
class RotaryPositionalEmbedding2D(nn.Module):
    """2D rotary position embedding for a row-major grid of patch tokens.

    The last dimension of ``q``/``k`` is split into two halves. The first half
    is rotated by angles derived from each token's row index, the second half
    by angles derived from its column index.

    Args:
        dim: Per-head feature size. Must be divisible by 4: two coordinate
            halves, each made of (i, i + half) rotation pairs.
        base: RoPE frequency base.

    Raises:
        ValueError: if ``dim`` is not divisible by 4.
    """

    def __init__(self, dim, base):
        super().__init__()
        if dim % 4 != 0:
            raise ValueError(
                f"RotaryPositionalEmbedding2D requires dim divisible by 4, got dim={dim}"
            )
        self.dim = dim
        self.rope_dim_per_coord = dim // 2

        exponents = torch.arange(0, self.rope_dim_per_coord, 2).float() / self.rope_dim_per_coord
        # Rows and columns share the same frequency table. Both buffers are
        # kept so that existing state_dicts continue to load.
        self.register_buffer('inv_freq_h', 1.0 / (base ** exponents))
        self.register_buffer('inv_freq_w', 1.0 / (base ** exponents))

    def forward(self, q, k, H_p, W_p):
        """Rotate ``q`` and ``k`` according to their 2D grid positions.

        Args:
            q, k: Tensors whose last two dims are (H_p * W_p, dim), e.g.
                (B, heads, T, dim).
            H_p, W_p: Grid height and width in patches. Token ``t`` sits at
                row ``t // W_p`` and column ``t % W_p``.

        Returns:
            (q_rot, k_rot) with the same shapes and dtypes as ``q`` and ``k``.
        """
        cos_h, sin_h, cos_w, sin_w = self._cos_sin(H_p, W_p, q.device)

        q_h, q_w = q.chunk(2, dim=-1)
        k_h, k_w = k.chunk(2, dim=-1)

        q_rot = torch.cat(
            (self._apply_rotary(q_h, cos_h, sin_h), self._apply_rotary(q_w, cos_w, sin_w)),
            dim=-1,
        )
        k_rot = torch.cat(
            (self._apply_rotary(k_h, cos_h, sin_h), self._apply_rotary(k_w, cos_w, sin_w)),
            dim=-1,
        )
        return q_rot, k_rot

    def _cos_sin(self, H_p, W_p, device):
        """Return float32 (cos_h, sin_h, cos_w, sin_w), each (1, 1, H_p*W_p, dim // 2)."""
        t_idx = torch.arange(H_p * W_p, device=device)
        h_pos = (t_idx // W_p).float()
        w_pos = (t_idx % W_p).float()

        # Always compute angles in float32. Position * frequency loses a lot of
        # precision if the buffers were cast to bf16/fp16 along with the model.
        freqs_h = torch.outer(h_pos, self.inv_freq_h.float())
        freqs_w = torch.outer(w_pos, self.inv_freq_w.float())

        # Duplicate so element i and i + half share an angle (rotate_half layout).
        freqs_h = torch.cat((freqs_h, freqs_h), dim=-1)
        freqs_w = torch.cat((freqs_w, freqs_w), dim=-1)

        shape = (1, 1, H_p * W_p, self.rope_dim_per_coord)
        return (
            freqs_h.cos().view(shape),
            freqs_h.sin().view(shape),
            freqs_w.cos().view(shape),
            freqs_w.sin().view(shape),
        )

    def _apply_rotary(self, x, cos, sin):
        """Rotate ``x`` by the given tables, keeping ``x``'s dtype."""
        # Without this cast a half-precision x would be promoted to float32,
        # mismatching the (unrotated) V tensor in attention.
        cos = cos.to(x.dtype)
        sin = sin.to(x.dtype)
        return (x * cos) + (self._rotate_half(x) * sin)

    def _rotate_half(self, x):
        """Map (x1, x2) halves of the last dim to (-x2, x1)."""
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
|
|
|
class PatchEmbedding(nn.Module):
    """Cut a (B, C, H, W) tensor into non-overlapping p x p patches.

    ``forward`` embeds the patches with a strided convolution. ``patchify`` and
    ``unpatchify`` convert between images and raw per-patch vectors laid out
    as (C, p, p), without any learned weights.
    """

    def __init__(self, in_channels: int, patch_size: int, d_model: int):
        super().__init__()
        self.patch_size = patch_size
        self.in_channels = in_channels
        # kernel_size == stride, so every output location covers one patch.
        self.proj = nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return patch tokens of shape (B, (H // p) * (W // p), d_model)."""
        grid = self.proj(x)                   # (B, d_model, H // p, W // p)
        tokens = grid.flatten(start_dim=2)    # (B, d_model, num_patches)
        return tokens.transpose(1, 2).contiguous()

    def patchify(self, x: torch.Tensor) -> torch.Tensor:
        """(B, C, H, W) -> (B, num_patches, C * p * p), row-major over patches."""
        batch, channels, height, width = x.shape
        p = self.patch_size
        rows, cols = height // p, width // p
        patches = x.reshape(batch, channels, rows, p, cols, p)
        # -> (B, rows, cols, C, p, p)
        patches = patches.permute(0, 2, 4, 1, 3, 5).contiguous()
        return patches.reshape(batch, -1, p * p * channels)

    def unpatchify(self, x: torch.Tensor, H, W):
        """Inverse of ``patchify``: (B, num_patches, C * p * p) -> (B, C, H, W)."""
        p = self.patch_size
        channels = self.in_channels
        rows, cols = H // p, W // p
        batch = x.shape[0]
        grid = x.reshape(batch, rows, cols, channels, p, p)
        # (B, rows, cols, C, p, p) -> (B, C, rows, p, cols, p)
        grid = grid.permute(0, 3, 1, 4, 2, 5).contiguous()
        return grid.reshape(batch, channels, rows * p, cols * p)
|
|
|
|
|
class Block(nn.Module):
    """One DiT block: self-attention with 2D RoPE, text cross-attention, and an FFN.

    Every sub-layer is pre-norm with adaLN modulation. ``LayerNorm(x)`` is
    transformed as ``x * (1 + scale) + shift``, using shift/scale produced from
    ``cond_emb`` by ``adaLN_modulation``. The result is fed to the sub-layer,
    whose output is added back to ``x`` after dropout.

    Args:
        d_model: Token width; must be divisible by ``nhead``.
        nhead: Number of heads for both self- and cross-attention.
        dim_feedforward: Hidden width of the FFN.
        dropout: Dropout applied to each sub-layer output before the residual add.
        rope_hz: Frequency base of the 2D RoPE used in self-attention.
    """

    def __init__(self, d_model, nhead, dim_feedforward, dropout, rope_hz):
        super().__init__()

        self.nhead = nhead
        self.d_model = d_model

        # 6 * d_model outputs: (shift, scale) for each of the three sub-layers.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(d_model, d_model * 6),
        )

        # Self-attention.
        self.self_norm = nn.LayerNorm(d_model)
        self.qkv = nn.Linear(d_model, d_model * 3)
        self.out_proj = nn.Linear(d_model, d_model)
        self.rope = RotaryPositionalEmbedding2D(d_model // nhead, rope_hz)

        # Text cross-attention.
        self.text_norm = nn.LayerNorm(d_model)
        self.text_cross_q = nn.Linear(d_model, d_model)
        self.text_cross_kv = nn.Linear(d_model, d_model * 2)
        self.text_cross_out = nn.Linear(d_model, d_model)

        # Feed-forward network.
        self.ffn_norm = nn.LayerNorm(d_model)
        self.ff1 = nn.Linear(d_model, dim_feedforward)
        self.ff2 = nn.Linear(dim_feedforward, d_model)

        self.dropout = nn.Dropout(dropout)
        # Per-head QK normalization (applied in self-attention only).
        self.q_norm = nn.LayerNorm(d_model // nhead)
        self.k_norm = nn.LayerNorm(d_model // nhead)

    @staticmethod
    def _modulate(x_norm, shift, scale):
        """adaLN: ``x_norm * (1 + scale) + shift``; (B, D) shift/scale broadcast over tokens."""
        return x_norm * (1 + scale[:, None, :]) + shift[:, None, :]

    def self_attention(self, x_norm, H, W):
        """Multi-head self-attention over patch tokens with QK-norm and 2D RoPE.

        Args:
            x_norm: (B, T, D) normalized/modulated tokens, with T == H * W.
            H, W: Patch-grid height and width used for the rotary positions.

        Returns:
            (B, T, D) output after ``out_proj``.
        """
        B, T, D = x_norm.shape
        N = self.nhead
        d_k = D // N

        Q, K, V = self.qkv(x_norm).chunk(3, dim=-1)
        # (B, T, D) -> (B, N, T, d_k)
        Q = Q.reshape(B, T, N, d_k).transpose(1, 2)
        K = K.reshape(B, T, N, d_k).transpose(1, 2)
        V = V.reshape(B, T, N, d_k).transpose(1, 2)

        Q = self.q_norm(Q)
        K = self.k_norm(K)
        Q_rot, K_rot = self.rope(Q, K, H, W)

        attn_out = F.scaled_dot_product_attention(
            Q_rot, K_rot, V,
            dropout_p=0.0,
            is_causal=False,
        )
        attn_out = attn_out.transpose(1, 2).reshape(B, T, D)
        return self.out_proj(attn_out)

    def _cross_attention_impl(self, x_norm, cond, q_proj, kv_proj, out_proj, H, W, attn_mask=None):
        """Multi-head cross-attention from image tokens to condition tokens.

        Args:
            x_norm: (B, T, D) query tokens (already normalized/modulated).
            cond: (B, L, D) condition tokens providing keys and values.
            q_proj: Projection producing D query features.
            kv_proj: Projection producing 2 * D features (keys, then values).
            out_proj: Output projection applied to the merged heads.
            H, W: Unused; accepted for signature symmetry with self-attention.
            attn_mask: Optional (B, L) mask, converted to bool (True = attend).
                A sample whose mask is entirely False gets a zero attention
                output, so only ``out_proj``'s bias remains.

        Returns:
            (B, T, D) tensor.
        """
        B, T, D = x_norm.shape
        L = cond.shape[1]
        N = self.nhead
        d_k = D // N

        Q = q_proj(x_norm)
        K, V = kv_proj(cond).chunk(2, dim=-1)

        Q = Q.reshape(B, T, N, d_k).transpose(1, 2)
        K = K.reshape(B, L, N, d_k).transpose(1, 2)
        V = V.reshape(B, L, N, d_k).transpose(1, 2)

        has_tokens = None
        if attn_mask is not None:
            attn_mask = attn_mask.to(device=Q.device, dtype=torch.bool)
            has_tokens = attn_mask.any(dim=-1)  # (B,)
            # A fully masked row is a softmax over only -inf, which can yield
            # NaN (backend/version dependent). Let such rows attend everywhere
            # so the result is finite, then zero them out below.
            attn_mask = attn_mask | ~has_tokens[:, None]
            attn_mask = attn_mask[:, None, None, :]  # broadcast over heads and queries

        out = F.scaled_dot_product_attention(
            Q, K, V,
            attn_mask=attn_mask,
            dropout_p=0.0,
            is_causal=False,
        )

        if has_tokens is not None:
            out = out * has_tokens[:, None, None, None].to(out.dtype)

        out = out.transpose(1, 2).reshape(B, T, D)
        return out_proj(out)

    def text_cross_attention(self, x_norm, text_emb, text_mask, H, W):
        """Cross-attend image tokens to text tokens using the text_cross_* projections."""
        return self._cross_attention_impl(
            x_norm=x_norm,
            cond=text_emb,
            q_proj=self.text_cross_q,
            kv_proj=self.text_cross_kv,
            out_proj=self.text_cross_out,
            H=H,
            W=W,
            attn_mask=text_mask,
        )

    def forward(self, x, cond_emb, text_emb, text_mask=None, H=None, W=None, key=None):
        """Apply self-attention, text cross-attention and FFN, each adaLN-modulated.

        Args:
            x: (B, T, D) patch tokens.
            cond_emb: (B, 1, D) global condition used for adaLN modulation.
            text_emb: (B, L, D) text tokens for cross-attention.
            text_mask: Optional (B, L) mask, truthy = valid token.
            H, W: Patch-grid size; required by the RoPE in self-attention.
            key: Opaque identifier copied into the returned state dict.

        Returns:
            (x, state). ``state`` holds ``key`` and the standard deviations of
            the three sub-layer outputs as Python floats.
        """
        c = cond_emb.squeeze(1)
        (shift_msa, scale_msa,
         shift_cross, scale_cross,
         shift_mlp, scale_mlp) = self.adaLN_modulation(c).chunk(6, dim=-1)

        # Self-attention sub-layer.
        x_norm = self._modulate(self.self_norm(x), shift_msa, scale_msa)
        self_out = self.self_attention(x_norm=x_norm, H=H, W=W)
        x = x + self.dropout(self_out)

        # Text cross-attention sub-layer.
        x_norm = self._modulate(self.text_norm(x), shift_cross, scale_cross)
        text_cross_out = self.text_cross_attention(
            x_norm=x_norm,
            text_emb=text_emb,
            text_mask=text_mask,
            H=H,
            W=W,
        )
        x = x + self.dropout(text_cross_out)

        # Feed-forward sub-layer.
        x_norm = self._modulate(self.ffn_norm(x), shift_mlp, scale_mlp)
        ffn = self.ff2(F.gelu(self.ff1(x_norm), approximate="tanh"))
        x = x + self.dropout(ffn)

        # Diagnostics only. NOTE: each .item() forces a host/device sync.
        with torch.no_grad():
            state = {
                "key": key,
                "self_out": self_out.float().std().item(),
                "text_out": text_cross_out.float().std().item(),
                "ffn_out": ffn.float().std().item(),
            }

        return x, state
|
|
|
|
|
class Model(nn.Module):
    """DiT-style velocity predictor conditioned on a noise level and text.

    Pipeline:
    1. Patch embedding.
    2. ``num_layers`` ``Block``s (self-attention with 2D RoPE, text
       cross-attention, FFN). All are adaLN-modulated by a global condition
       built from sigma and pooled text.
    3. adaLN-modulated final LayerNorm.
    4. Linear projection back to patches, then unpatchify.

    Args:
        d_model, nhead, num_layers, dropout: Transformer size; the FFN width is
            4 * d_model.
        sigma_emb_hz: Base for the sinusoidal sigma embedding.
        in_channels: Channel count of the input (and output).
        patch_size: Side length of the square patches.
        text_dim: Feature size of the incoming text embeddings.
        rope_hz: RoPE frequency base.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, d_model, nhead, num_layers, dropout, sigma_emb_hz, in_channels, patch_size, text_dim, rope_hz, **kwargs):
        super().__init__()
        dim_feedforward = 4 * d_model
        self.patch_size = patch_size
        self.in_channels = in_channels

        # Image patches -> tokens.
        self.patch_embedding = PatchEmbedding(
            in_channels=in_channels,
            patch_size=patch_size,
            d_model=d_model,
        )
        self.patch_norm = nn.LayerNorm(d_model)

        # Noise level -> sinusoidal features -> MLP.
        self.sigma_proj = SinusoidalPosEmb(d_model, sigma_emb_hz)
        self.sigma_embed = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )
        self.sigma_norm = nn.LayerNorm(d_model)

        # Per-token text features for cross-attention.
        self.text_proj = nn.Linear(text_dim, d_model)
        self.text_embed = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )
        self.text_norm = nn.LayerNorm(d_model)

        # Pooled text -> global condition.
        self.pooled_text_proj = nn.Sequential(
            nn.Linear(text_dim, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )
        self.pooled_text_norm = nn.LayerNorm(d_model)

        self.blocks = nn.ModuleList([
            Block(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                rope_hz=rope_hz,
            )
            for _ in range(num_layers)
        ])

        self.cond_norm = nn.LayerNorm(d_model)
        self.cond_proj = nn.Linear(d_model, d_model)

        self.norm = nn.LayerNorm(d_model)

        # Final adaLN: (shift, scale) for the output LayerNorm.
        self.final_mod = nn.Sequential(
            nn.SiLU(),
            nn.Linear(d_model, d_model * 2),
        )

        self.output_proj = nn.Linear(
            d_model,
            patch_size * patch_size * in_channels,
        )

    def forward(self, x, sigma, text_emb, text_mask=None):
        """Predict the velocity for a batch of inputs.

        Args:
            x: (B, C, H, W) input; H and W must be divisible by ``patch_size``.
            sigma: Noise level per sample. Accepts shape (B,) or (B, 1), a 0-d
                tensor, or a Python number; a single value is broadcast to B.
            text_emb: (B, L, text_dim) text token features.
            text_mask: Optional (B, L) mask, truthy = valid token.

        Returns:
            (pred_velocity, layer_states). ``pred_velocity`` has shape
            (B, C, H, W); ``layer_states`` is the list of per-block state dicts.

        Raises:
            ValueError: if H or W is not divisible by ``patch_size``, or if
                ``sigma`` has neither 1 nor B elements.
        """
        B, _, H, W = x.shape
        p = self.patch_size
        if H % p != 0 or W % p != 0:
            # Conv2d would silently drop the remainder, and the prediction
            # would come out smaller than the input.
            raise ValueError(
                f"Spatial size ({H}, {W}) must be divisible by patch_size={p}"
            )
        H_p, W_p = H // p, W // p

        # Normalize sigma to shape (B,) so the embedding never mis-broadcasts.
        sigma = torch.as_tensor(sigma, device=x.device, dtype=torch.float32).reshape(-1)
        if sigma.numel() == 1:
            sigma = sigma.expand(B)
        elif sigma.numel() != B:
            raise ValueError(f"sigma must have 1 or {B} elements, got {sigma.numel()}")

        x_emb = self.patch_norm(self.patch_embedding(x))

        sigma_emb = self.sigma_proj(sigma)
        # The sinusoidal features are float32. Follow the MLP's parameter dtype
        # so a model cast to half precision does not hit a dtype mismatch.
        sigma_emb = sigma_emb.to(self.sigma_embed[0].weight.dtype)
        sigma_emb = self.sigma_norm(self.sigma_embed(sigma_emb))

        text_token_emb = self.text_norm(self.text_embed(self.text_proj(text_emb)))

        # Masked mean of the raw text features.
        if text_mask is not None:
            mask = text_mask.to(device=text_emb.device, dtype=torch.float32).unsqueeze(-1)
            # clamp(min=1): a sample without valid tokens pools to zeros, not NaN.
            pooled_text = (text_emb * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
            # The float32 mask promotes the result; restore the input dtype.
            pooled_text = pooled_text.to(text_emb.dtype)
        else:
            pooled_text = text_emb.mean(dim=1)
        pooled_text = self.pooled_text_norm(self.pooled_text_proj(pooled_text))

        cond_emb = self.cond_norm(self.cond_proj(sigma_emb + pooled_text))
        cond_emb = cond_emb[:, None, :]  # (B, 1, D)

        layer_states = []
        for i, block in enumerate(self.blocks):
            x_emb, state = block(
                x=x_emb,
                cond_emb=cond_emb,
                text_emb=text_token_emb,
                text_mask=text_mask,
                H=H_p,
                W=W_p,
                key=i,
            )
            layer_states.append(state)

        shift, scale = self.final_mod(cond_emb.squeeze(1)).chunk(2, dim=-1)
        x_final = self.norm(x_emb)
        x_final = x_final * (1 + scale[:, None, :]) + shift[:, None, :]

        pred_velocity = self.output_proj(x_final)
        pred_velocity = self.patch_embedding.unpatchify(pred_velocity, H, W)
        return pred_velocity, layer_states