"""mDiffAE decoder: flat sequential DiCoBlocks with token-level PDG masking.""" from __future__ import annotations import math import torch from torch import Tensor, nn from .adaln import AdaLNZeroLowRankDelta, AdaLNZeroProjector from .dico_block import DiCoBlock from .norms import ChannelWiseRMSNorm from .straight_through_encoder import Patchify from .time_embed import SinusoidalTimeEmbeddingMLP class Decoder(nn.Module): """VP diffusion decoder conditioned on encoder latents and timestep. Architecture: Patchify x_t -> Norm -> Fuse with upsampled z -> Blocks (flat sequential, depth blocks) -> Norm -> Conv1x1 -> PixelShuffle Token-level PDG: at inference, a fraction of spatial tokens in the fused input are replaced with a learned mask_feature before the decoder blocks. Comparing the masked vs unmasked outputs provides guidance signal. """ def __init__( self, in_channels: int, patch_size: int, model_dim: int, depth: int, bottleneck_dim: int, mlp_ratio: float, depthwise_kernel_size: int, adaln_low_rank_rank: int, pdg_mask_ratio: float = 0.75, ) -> None: super().__init__() self.patch_size = int(patch_size) self.model_dim = int(model_dim) self.pdg_mask_ratio = float(pdg_mask_ratio) # Input processing self.patchify = Patchify(in_channels, patch_size, model_dim) self.norm_in = ChannelWiseRMSNorm(model_dim, eps=1e-6, affine=True) # Latent conditioning path self.latent_up = nn.Conv2d(bottleneck_dim, model_dim, kernel_size=1, bias=True) self.latent_norm = ChannelWiseRMSNorm(model_dim, eps=1e-6, affine=True) self.fuse_in = nn.Conv2d(2 * model_dim, model_dim, kernel_size=1, bias=True) # Time embedding self.time_embed = SinusoidalTimeEmbeddingMLP(model_dim) # AdaLN: shared base projector + per-block low-rank deltas self.adaln_base = AdaLNZeroProjector(d_model=model_dim, d_cond=model_dim) self.adaln_deltas = nn.ModuleList( [ AdaLNZeroLowRankDelta( d_model=model_dim, d_cond=model_dim, rank=adaln_low_rank_rank ) for _ in range(depth) ] ) # Flat sequential blocks (no start/middle/end split, no skip connections) self.blocks = nn.ModuleList( [ DiCoBlock( model_dim, mlp_ratio, depthwise_kernel_size=depthwise_kernel_size, use_external_adaln=True, ) for _ in range(depth) ] ) # Learned mask feature for token-level PDG self.mask_feature = nn.Parameter(torch.zeros((1, model_dim, 1, 1))) # Output head self.norm_out = ChannelWiseRMSNorm(model_dim, eps=1e-6, affine=True) self.out_proj = nn.Conv2d( model_dim, in_channels * (patch_size**2), kernel_size=1, bias=True ) self.unpatchify = nn.PixelShuffle(patch_size) def _adaln_m_for_layer(self, cond: Tensor, layer_idx: int) -> Tensor: """Compute packed AdaLN modulation = shared_base + per-layer delta.""" act = self.adaln_base.act(cond) base_m = self.adaln_base.forward_activated(act) delta_m = self.adaln_deltas[layer_idx](act) return base_m + delta_m def _apply_token_mask(self, fused: Tensor) -> Tensor: """Replace a fraction of spatial tokens with mask_feature (2x2 groupwise). Divides the spatial grid into 2x2 groups. Within each group, masks floor(ratio * 4) tokens deterministically (lowest random scores). Args: fused: [B, C, H, W] fused decoder input. Returns: Masked tensor with same shape, where masked positions contain mask_feature. """ b, c, h, w = fused.shape # Pad to even dims if needed h_pad = (2 - h % 2) % 2 w_pad = (2 - w % 2) % 2 if h_pad > 0 or w_pad > 0: fused = torch.nn.functional.pad(fused, (0, w_pad, 0, h_pad)) _, _, h, w = fused.shape # Reshape into 2x2 groups: [B, C, H/2, 2, W/2, 2] -> [B, C, H/2, W/2, 4] x = fused.reshape(b, c, h // 2, 2, w // 2, 2) x = x.permute(0, 1, 2, 4, 3, 5).reshape(b, c, h // 2, w // 2, 4) # Random scores for each token in each group scores = torch.rand(b, 1, h // 2, w // 2, 4, device=fused.device) # Mask the floor(ratio * 4) lowest-scoring tokens per group num_mask = math.floor(self.pdg_mask_ratio * 4) if num_mask > 0: # argsort ascending, mask the first num_mask _, indices = scores.sort(dim=-1) mask = torch.zeros_like(scores, dtype=torch.bool) mask.scatter_(-1, indices[..., :num_mask], True) else: mask = torch.zeros_like(scores, dtype=torch.bool) # Apply mask: replace masked tokens with mask_feature mask_feat = self.mask_feature.to(device=fused.device, dtype=fused.dtype) mask_feat = mask_feat.squeeze(-1).squeeze(-1) # [1, C] mask_feat = mask_feat.view(1, c, 1, 1, 1).expand_as(x) mask_expanded = mask.expand_as(x) x = torch.where(mask_expanded, mask_feat, x) # Reshape back to [B, C, H, W] x = x.reshape(b, c, h // 2, w // 2, 2, 2) x = x.permute(0, 1, 2, 4, 3, 5).reshape(b, c, h, w) # Remove padding if applied if h_pad > 0 or w_pad > 0: x = x[:, :, : h - h_pad, : w - w_pad] return x def forward( self, x_t: Tensor, t: Tensor, latents: Tensor, *, mask_tokens: bool = False, ) -> Tensor: """Single decoder forward pass. Args: x_t: Noised image [B, C, H, W]. t: Timestep [B] in [0, 1]. latents: Encoder latents [B, bottleneck_dim, h, w]. mask_tokens: If True, apply token-level masking to decoder input (for PDG). Returns: x0 prediction [B, C, H, W]. """ # Patchify and normalize x_t x_feat = self.patchify(x_t) x_feat = self.norm_in(x_feat) # Upsample and normalize latents, fuse with x_feat z_up = self.latent_up(latents) z_up = self.latent_norm(z_up) fused = torch.cat([x_feat, z_up], dim=1) fused = self.fuse_in(fused) # Token masking for PDG (replaces tokens with mask_feature) if mask_tokens: fused = self._apply_token_mask(fused) # Time conditioning cond = self.time_embed(t.to(torch.float32).to(device=x_t.device)) # Run all blocks sequentially x = fused for layer_idx, block in enumerate(self.blocks): adaln_m = self._adaln_m_for_layer(cond, layer_idx=layer_idx) x = block(x, adaln_m=adaln_m) # Output head x = self.norm_out(x) patches = self.out_proj(x) return self.unpatchify(patches)