| | """mDiffAE decoder: flat sequential DiCoBlocks with token-level PDG masking.""" |
| |
|
| | from __future__ import annotations |
| |
|
| | import math |
| |
|
| | import torch |
| | from torch import Tensor, nn |
| |
|
| | from .adaln import AdaLNZeroLowRankDelta, AdaLNZeroProjector |
| | from .dico_block import DiCoBlock |
| | from .norms import ChannelWiseRMSNorm |
| | from .straight_through_encoder import Patchify |
| | from .time_embed import SinusoidalTimeEmbeddingMLP |
| |
|
| |
|
class Decoder(nn.Module):
    """VP diffusion decoder conditioned on encoder latents and timestep.

    Architecture:
        Patchify x_t -> Norm -> Fuse with upsampled z
        -> Blocks (flat sequential, depth blocks) -> Norm -> Conv1x1 -> PixelShuffle

    Token-level PDG: at inference, a fraction of spatial tokens in the fused input
    are replaced with a learned mask_feature before the decoder blocks. Comparing
    the masked vs unmasked outputs provides guidance signal.
    """

    def __init__(
        self,
        in_channels: int,
        patch_size: int,
        model_dim: int,
        depth: int,
        bottleneck_dim: int,
        mlp_ratio: float,
        depthwise_kernel_size: int,
        adaln_low_rank_rank: int,
        pdg_mask_ratio: float = 0.75,
    ) -> None:
        """Build the decoder.

        Args:
            in_channels: Image channel count (input x_t and output x0).
            patch_size: Spatial patch size for patchify/unpatchify.
            model_dim: Decoder feature width.
            depth: Number of DiCoBlocks.
            bottleneck_dim: Channel count of the encoder latents.
            mlp_ratio: MLP expansion ratio inside each block.
            depthwise_kernel_size: Kernel size of the depthwise conv in each block.
            adaln_low_rank_rank: Rank of the per-layer AdaLN low-rank deltas.
            pdg_mask_ratio: Fraction of each 2x2 token group replaced by
                mask_feature when mask_tokens=True (floor(ratio * 4) tokens).
        """
        super().__init__()
        self.patch_size = int(patch_size)
        self.model_dim = int(model_dim)
        self.pdg_mask_ratio = float(pdg_mask_ratio)

        # Input path: patchify the noised image, then normalize.
        self.patchify = Patchify(in_channels, patch_size, model_dim)
        self.norm_in = ChannelWiseRMSNorm(model_dim, eps=1e-6, affine=True)

        # Latent path: lift encoder latents to model_dim, normalize, then fuse
        # with the patchified input via channel concat + 1x1 conv.
        self.latent_up = nn.Conv2d(bottleneck_dim, model_dim, kernel_size=1, bias=True)
        self.latent_norm = ChannelWiseRMSNorm(model_dim, eps=1e-6, affine=True)
        self.fuse_in = nn.Conv2d(2 * model_dim, model_dim, kernel_size=1, bias=True)

        # Timestep conditioning embedding.
        self.time_embed = SinusoidalTimeEmbeddingMLP(model_dim)

        # AdaLN-Zero: one shared base projector plus a low-rank delta per block.
        self.adaln_base = AdaLNZeroProjector(d_model=model_dim, d_cond=model_dim)
        self.adaln_deltas = nn.ModuleList(
            [
                AdaLNZeroLowRankDelta(
                    d_model=model_dim, d_cond=model_dim, rank=adaln_low_rank_rank
                )
                for _ in range(depth)
            ]
        )

        # Flat sequential stack of decoder blocks driven by external AdaLN.
        self.blocks = nn.ModuleList(
            [
                DiCoBlock(
                    model_dim,
                    mlp_ratio,
                    depthwise_kernel_size=depthwise_kernel_size,
                    use_external_adaln=True,
                )
                for _ in range(depth)
            ]
        )

        # Learned feature substituted at masked token positions for PDG.
        self.mask_feature = nn.Parameter(torch.zeros((1, model_dim, 1, 1)))

        # Output path: norm -> 1x1 conv to patch pixels -> PixelShuffle to image.
        self.norm_out = ChannelWiseRMSNorm(model_dim, eps=1e-6, affine=True)
        self.out_proj = nn.Conv2d(
            model_dim, in_channels * (patch_size**2), kernel_size=1, bias=True
        )
        self.unpatchify = nn.PixelShuffle(patch_size)

    def _adaln_m_for_layer(self, cond: Tensor, layer_idx: int) -> Tensor:
        """Compute packed AdaLN modulation = shared_base + per-layer delta."""
        # The base projector's activation is computed once and reused by the
        # per-layer delta so the two stay consistent for the same cond.
        act = self.adaln_base.act(cond)
        base_m = self.adaln_base.forward_activated(act)
        delta_m = self.adaln_deltas[layer_idx](act)
        return base_m + delta_m

    def _apply_token_mask(self, fused: Tensor) -> Tensor:
        """Replace a fraction of spatial tokens with mask_feature (2x2 groupwise).

        Divides the spatial grid into 2x2 groups. Within each group, masks
        floor(pdg_mask_ratio * 4) tokens: a fixed count per group, with the
        masked positions chosen by random scores (lowest scores masked).

        Args:
            fused: [B, C, H, W] fused decoder input.

        Returns:
            Masked tensor with same shape, where masked positions contain
            mask_feature. If the ratio yields zero masked tokens per group,
            the input is returned unchanged.
        """
        num_mask = math.floor(self.pdg_mask_ratio * 4)
        if num_mask <= 0:
            # Nothing would be masked: skip the pad/reshape/RNG round trip.
            return fused

        b, c, h, w = fused.shape

        # Pad right/bottom so the grid divides evenly into 2x2 groups.
        h_pad = (2 - h % 2) % 2
        w_pad = (2 - w % 2) % 2
        if h_pad > 0 or w_pad > 0:
            fused = torch.nn.functional.pad(fused, (0, w_pad, 0, h_pad))
            _, _, h, w = fused.shape

        # [B, C, H, W] -> [B, C, H/2, W/2, 4]: last dim enumerates each 2x2 group.
        x = fused.reshape(b, c, h // 2, 2, w // 2, 2)
        x = x.permute(0, 1, 2, 4, 3, 5).reshape(b, c, h // 2, w // 2, 4)

        # Random scores shared across channels (dim 1 has size 1); the num_mask
        # lowest-scoring positions in each group are masked.
        scores = torch.rand(b, 1, h // 2, w // 2, 4, device=fused.device)
        _, indices = scores.sort(dim=-1)
        mask = torch.zeros_like(scores, dtype=torch.bool)
        mask.scatter_(-1, indices[..., :num_mask], True)

        # Substitute the learned mask feature. torch.where broadcasts, so no
        # expanded copies of the mask or the feature are materialized.
        mask_feat = self.mask_feature.to(device=fused.device, dtype=fused.dtype)
        x = torch.where(mask, mask_feat.view(1, c, 1, 1, 1), x)

        # Invert the grouping back to [B, C, H, W].
        x = x.reshape(b, c, h // 2, w // 2, 2, 2)
        x = x.permute(0, 1, 2, 4, 3, 5).reshape(b, c, h, w)

        # Drop any padding added above.
        if h_pad > 0 or w_pad > 0:
            x = x[:, :, : h - h_pad, : w - w_pad]

        return x

    def forward(
        self,
        x_t: Tensor,
        t: Tensor,
        latents: Tensor,
        *,
        mask_tokens: bool = False,
    ) -> Tensor:
        """Single decoder forward pass.

        Args:
            x_t: Noised image [B, C, H, W].
            t: Timestep [B] in [0, 1].
            latents: Encoder latents [B, bottleneck_dim, h, w].
            mask_tokens: If True, apply token-level masking to decoder input (for PDG).

        Returns:
            x0 prediction [B, C, H, W].
        """
        # Patchify and normalize the noised input.
        x_feat = self.patchify(x_t)
        x_feat = self.norm_in(x_feat)

        # Lift latents to model_dim, normalize, and fuse with input features.
        z_up = self.latent_up(latents)
        z_up = self.latent_norm(z_up)
        fused = self.fuse_in(torch.cat([x_feat, z_up], dim=1))

        # Token-level PDG masking (inference-time guidance).
        if mask_tokens:
            fused = self._apply_token_mask(fused)

        # Timestep conditioning; one .to() call moves device and casts dtype.
        cond = self.time_embed(t.to(device=x_t.device, dtype=torch.float32))

        # Flat sequential blocks, each with its own AdaLN modulation.
        x = fused
        for layer_idx, block in enumerate(self.blocks):
            adaln_m = self._adaln_m_for_layer(cond, layer_idx=layer_idx)
            x = block(x, adaln_m=adaln_m)

        # Project back to pixel space.
        x = self.norm_out(x)
        patches = self.out_proj(x)
        return self.unpatchify(patches)
| |
|