import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Fused attention kernel; its default 1/sqrt(head_dim) scaling matches self.scale.
        # Attention dropout is only applied during training.
        x = F.scaled_dot_product_attention(
            q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0
        )

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
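

# Usage sketch (illustrative sizes, not from the training config): the attention
# module preserves the input shape; internally q/k/v are reshaped to
# [B, num_heads, N, head_dim] before the fused attention call.
#
#     attn = Attention(dim=128, num_heads=4, qkv_bias=True)
#     x = torch.randn(2, 13, 128)   # [B, N, C]
#     y = attn(x)                   # [2, 13, 128]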


class Patch1D(nn.Module):
    """
    [B, L, D] -> [B, ceil(L / P), D * P]

    Zero-pads the sequence length to a multiple of the patch size before patching.
    """
    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def forward(self, x):
        B, L, D = x.shape

        if L % self.patch_size != 0:
            pad = self.patch_size - (L % self.patch_size)
            x = F.pad(x, (0, 0, 0, pad))

        B, L_new, D = x.shape
        return x.view(B, L_new // self.patch_size, D * self.patch_size)


class Unpatch1D(nn.Module):
    """
    [B, L/P, D*P] -> [B, L, D]
    """
    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def forward(self, x):
        B, L_new, DP = x.shape
        return x.view(B, L_new * self.patch_size, DP // self.patch_size)
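

# Usage sketch for the patching pair (illustrative shapes only). Because Patch1D
# zero-pads, callers that round-trip through Unpatch1D must slice back to the
# original length, as PatchedFlowDiT.forward does:
#
#     patch, unpatch = Patch1D(4), Unpatch1D(4)
#     x = torch.randn(2, 10, 16)             # [B, L, D], L not divisible by P
#     p = patch(x)                           # [2, 3, 64] -- L zero-padded 10 -> 12
#     y = unpatch(p)[:, :x.shape[1], :]      # [2, 10, 16], padding sliced off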


class TimestepEmbedder(nn.Module):
    """Sinusoidal Time Embeddings"""
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        if t.ndim > 1:
            t = t.view(-1)

        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb
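

# Usage sketch (illustrative sizes): a 1-D batch of flow-matching times maps to a
# [B, hidden_size] conditioning vector consumed by the DiT blocks below.
#
#     embedder = TimestepEmbedder(hidden_size=128)
#     t = torch.rand(8)        # one (possibly fractional) timestep per sample
#     emb = embedder(t)        # [8, 128]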


def modulate(x, shift, scale):
    # Broadcast per-sample shift/scale ([B, D]) over the sequence dimension of x ([B, N, D]).
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


class DiTBlock(nn.Module):
    """Transformer Block with Adaptive Layer Norm (adaLN)"""
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.hidden_size = hidden_size
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        adaLN_out = self.adaLN_modulation(c)
        if adaLN_out.shape[-1] != 6 * self.hidden_size:
            raise ValueError(
                f"adaLN output dimension mismatch: conditioning of shape {c.shape} produced "
                f"{adaLN_out.shape}, expected last dim {6 * self.hidden_size}"
            )

        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = adaLN_out.chunk(6, dim=1)
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x
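

# Usage sketch (illustrative sizes): each block mixes a token sequence with a
# per-sample conditioning vector (here, the timestep embedding) via adaLN.
#
#     block = DiTBlock(hidden_size=128, num_heads=4)
#     x = torch.randn(2, 13, 128)   # [B, N, hidden] token sequence
#     c = torch.randn(2, 128)       # [B, hidden] conditioning vector
#     y = block(x, c)               # [2, 13, 128]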


class PatchedFlowDiT(nn.Module):
    """
    Main DiT architecture for flow matching.
    Input: z_t (noisy latent) + t (time) + condition (original latent)
    Output: velocity vector
    """
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.patcher = Patch1D(cfg.patch_size)
        self.unpatcher = Unpatch1D(cfg.patch_size)

        # Each patch carries cfg.patch_size latent frames of cfg.latent_dim features.
        input_feat_dim = cfg.latent_dim * cfg.patch_size

        # The noisy latent and the condition are patched separately and concatenated.
        self.input_proj = nn.Linear(input_feat_dim * 2, cfg.dit_hidden)

        self.time_embed = TimestepEmbedder(cfg.dit_hidden)
        patched_len = (cfg.max_seq_len + cfg.patch_size - 1) // cfg.patch_size
        self.pos_embed = nn.Parameter(torch.zeros(1, patched_len, cfg.dit_hidden))

        self.blocks = nn.ModuleList([
            DiTBlock(cfg.dit_hidden, cfg.dit_heads) for _ in range(cfg.dit_layers)
        ])

        self.final_layer = nn.Linear(cfg.dit_hidden, input_feat_dim)
        self.initialize_weights()

    def initialize_weights(self):
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        nn.init.normal_(self.pos_embed, std=0.02)

        # Zero-init the adaLN modulation so every block starts as an identity mapping.
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        nn.init.xavier_uniform_(self.final_layer.weight)
        nn.init.constant_(self.final_layer.bias, 0)

    def forward(self, z_t, t, condition):
        """
        z_t:       [B, L, D] noisy latent
        t:         [B] flow-matching timesteps
        condition: [B, L, D] conditioning latent
        """
        z_p = self.patcher(z_t)
        c_p = self.patcher(condition)

        x = torch.cat([z_p, c_p], dim=-1)
        x = self.input_proj(x)

        t_emb = self.time_embed(t)

        L_curr = x.shape[1]
        x = x + self.pos_embed[:, :L_curr, :]

        for block in self.blocks:
            x = block(x, t_emb)

        v_p = self.final_layer(x)
        v = self.unpatcher(v_p)

        # Drop any padding introduced by the patcher so the velocity matches z_t.
        return v[:, :z_t.shape[1], :]

    def forward_with_cfg(self, x, t, condition, cfg_scale):
        """
        Forward pass with classifier-free guidance (CFG).
        """
        cond_out = self.forward(x, t, condition)

        # Unconditional branch: forward() cannot take condition=None, so an all-zero
        # condition is used as the null signal here (a common CFG convention; swap in
        # the model's actual null embedding if it was trained with a different one).
        uncond_out = self.forward(x, t, torch.zeros_like(condition))

        return uncond_out + cfg_scale * (cond_out - uncond_out)
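

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the training pipeline): a
    # SimpleNamespace stands in for the project's config object, with field names
    # taken from PatchedFlowDiT.__init__ and arbitrary illustrative values.
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        latent_dim=16,
        patch_size=4,
        max_seq_len=64,
        dit_hidden=128,
        dit_heads=4,
        dit_layers=2,
    )
    model = PatchedFlowDiT(cfg)

    B, L = 2, 50
    z_t = torch.randn(B, L, cfg.latent_dim)        # noisy latent
    condition = torch.randn(B, L, cfg.latent_dim)  # conditioning latent
    t = torch.rand(B)                              # flow-matching time in [0, 1]

    v = model(z_t, t, condition)
    print(v.shape)  # torch.Size([2, 50, 16]) -- velocity matches z_t's shape

    v_cfg = model.forward_with_cfg(z_t, t, condition, cfg_scale=2.0)
    print(v_cfg.shape)  # torch.Size([2, 50, 16])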