"""iRDiffAE decoder: conditioned DiCoBlocks with AdaLN + skip connection."""

from __future__ import annotations

import torch
from torch import Tensor, nn

from .adaln import AdaLNZeroLowRankDelta, AdaLNZeroProjector
from .dico_block import DiCoBlock
from .norms import ChannelWiseRMSNorm
from .straight_through_encoder import Patchify
from .time_embed import SinusoidalTimeEmbeddingMLP


class Decoder(nn.Module):
    """VP diffusion decoder conditioned on encoder latents and timestep.

    Architecture:
        Patchify x_t -> Norm -> Fuse with upsampled z ->
        Start blocks (2) -> Middle blocks (depth-4) -> Skip fuse -> End blocks (2) ->
        Norm -> Conv1x1 -> PixelShuffle

    Middle blocks support path-drop for PDG (inference-time guidance).
    """

    def __init__(
        self,
        in_channels: int,
        patch_size: int,
        model_dim: int,
        depth: int,
        bottleneck_dim: int,
        mlp_ratio: float,
        depthwise_kernel_size: int,
        adaln_low_rank_rank: int,
    ) -> None:
        """Build the decoder.

        Args:
            in_channels: Channels of the (noised) image x_t and of the output.
            patch_size: Patchify / PixelShuffle factor.
            model_dim: Channel width of all internal feature maps.
            depth: Total number of DiCoBlocks (2 start + depth-4 middle + 2 end).
            bottleneck_dim: Channel count of the encoder latents.
            mlp_ratio: MLP expansion ratio forwarded to each DiCoBlock.
            depthwise_kernel_size: Depthwise conv kernel size for each DiCoBlock.
            adaln_low_rank_rank: Rank of the per-block AdaLN low-rank deltas.

        Raises:
            ValueError: If ``depth`` is smaller than the 4 fixed start/end blocks.
        """
        super().__init__()

        # Block layout: start(2) + middle(depth-4) + end(2)
        start_count = 2
        end_count = 2
        min_depth = start_count + end_count
        if depth < min_depth:
            # Without this check a negative middle count silently produces an
            # empty middle group and makes the end blocks reuse (or negatively
            # index) the AdaLN deltas that belong to the start blocks.
            raise ValueError(
                f"depth must be >= {min_depth} "
                f"({start_count} start + {end_count} end blocks), got {depth}"
            )
        middle_count = depth - start_count - end_count

        self.patch_size = int(patch_size)
        self.model_dim = int(model_dim)

        # Input processing
        self.patchify = Patchify(in_channels, patch_size, model_dim)
        self.norm_in = ChannelWiseRMSNorm(model_dim, eps=1e-6, affine=True)

        # Latent conditioning path
        self.latent_up = nn.Conv2d(bottleneck_dim, model_dim, kernel_size=1, bias=True)
        self.latent_norm = ChannelWiseRMSNorm(model_dim, eps=1e-6, affine=True)
        self.fuse_in = nn.Conv2d(2 * model_dim, model_dim, kernel_size=1, bias=True)

        # Time embedding
        self.time_embed = SinusoidalTimeEmbeddingMLP(model_dim)

        # AdaLN: shared base projector + per-block low-rank deltas
        self.adaln_base = AdaLNZeroProjector(d_model=model_dim, d_cond=model_dim)
        self.adaln_deltas = nn.ModuleList(
            [
                AdaLNZeroLowRankDelta(
                    d_model=model_dim, d_cond=model_dim, rank=adaln_low_rank_rank
                )
                for _ in range(depth)
            ]
        )

        # Global block indices where the middle / end groups begin; used to pick
        # each block's AdaLN delta from self.adaln_deltas.
        self._middle_start_idx = start_count
        self._end_start_idx = start_count + middle_count

        def _make_blocks(count: int) -> nn.ModuleList:
            return nn.ModuleList(
                [
                    DiCoBlock(
                        model_dim,
                        mlp_ratio,
                        depthwise_kernel_size=depthwise_kernel_size,
                        use_external_adaln=True,
                    )
                    for _ in range(count)
                ]
            )

        self.start_blocks = _make_blocks(start_count)
        self.middle_blocks = _make_blocks(middle_count)
        self.fuse_skip = nn.Conv2d(2 * model_dim, model_dim, kernel_size=1, bias=True)
        self.end_blocks = _make_blocks(end_count)

        # Learned mask feature for path-drop guidance
        self.mask_feature = nn.Parameter(torch.zeros((1, model_dim, 1, 1)))

        # Output head
        self.norm_out = ChannelWiseRMSNorm(model_dim, eps=1e-6, affine=True)
        self.out_proj = nn.Conv2d(
            model_dim, in_channels * (patch_size**2), kernel_size=1, bias=True
        )
        self.unpatchify = nn.PixelShuffle(patch_size)

    def _adaln_base_terms(self, cond: Tensor) -> tuple[Tensor, Tensor]:
        """Return ``(activated cond, shared base modulation)`` for ``cond``.

        Both terms depend only on ``cond``, not on the layer, so they can be
        computed once per forward pass and reused by every block.
        """
        act = self.adaln_base.act(cond)
        base_m = self.adaln_base.forward_activated(act)
        return act, base_m

    def _adaln_m_for_layer(self, cond: Tensor, layer_idx: int) -> Tensor:
        """Compute packed AdaLN modulation = shared_base + per-layer delta."""
        act, base_m = self._adaln_base_terms(cond)
        return base_m + self.adaln_deltas[layer_idx](act)

    def _run_blocks(
        self,
        blocks: nn.ModuleList,
        x: Tensor,
        cond: Tensor,
        start_index: int,
        *,
        base_terms: tuple[Tensor, Tensor] | None = None,
    ) -> Tensor:
        """Run a group of decoder blocks with per-block AdaLN modulation.

        Args:
            blocks: The block group to run in order.
            x: Input feature map for the first block.
            cond: Time conditioning vector.
            start_index: Global index of the first block, used to select its
                AdaLN delta.
            base_terms: Optional precomputed ``_adaln_base_terms(cond)``; passed
                by ``forward`` so the shared projection runs only once.

        Returns:
            Output feature map of the last block (``x`` unchanged if empty).
        """
        if base_terms is None:
            base_terms = self._adaln_base_terms(cond)
        act, base_m = base_terms
        for local_idx, block in enumerate(blocks):
            delta_m = self.adaln_deltas[start_index + local_idx](act)
            x = block(x, adaln_m=base_m + delta_m)
        return x

    def forward(
        self,
        x_t: Tensor,
        t: Tensor,
        latents: Tensor,
        *,
        drop_middle_blocks: bool = False,
    ) -> Tensor:
        """Single decoder forward pass.

        Args:
            x_t: Noised image [B, C, H, W].
            t: Timestep [B] in [0, 1].
            latents: Encoder latents [B, bottleneck_dim, h, w].
            drop_middle_blocks: If True, replace middle block output with
                mask_feature (for PDG).

        Returns:
            x0 prediction [B, C, H, W].
        """
        # Patchify and normalize x_t
        x_feat = self.norm_in(self.patchify(x_t))

        # Upsample and normalize latents, fuse with x_feat
        z_up = self.latent_norm(self.latent_up(latents))
        fused = self.fuse_in(torch.cat([x_feat, z_up], dim=1))

        # Time conditioning; the shared AdaLN base projection is computed once
        # here instead of once per block.
        cond = self.time_embed(t.to(device=x_t.device, dtype=torch.float32))
        base_terms = self._adaln_base_terms(cond)

        # Start blocks
        start_out = self._run_blocks(
            self.start_blocks, fused, cond, start_index=0, base_terms=base_terms
        )

        # Middle blocks (or mask feature for PDG)
        if drop_middle_blocks:
            # Match the tensor the mask stands in for (start_out), which may
            # differ in dtype from x_t under autocast.
            middle_out = self.mask_feature.to(
                device=start_out.device, dtype=start_out.dtype
            ).expand_as(start_out)
        else:
            middle_out = self._run_blocks(
                self.middle_blocks,
                start_out,
                cond,
                start_index=self._middle_start_idx,
                base_terms=base_terms,
            )

        # Skip fusion
        skip_fused = self.fuse_skip(torch.cat([start_out, middle_out], dim=1))

        # End blocks
        end_out = self._run_blocks(
            self.end_blocks,
            skip_fused,
            cond,
            start_index=self._end_start_idx,
            base_terms=base_terms,
        )

        # Output head
        patches = self.out_proj(self.norm_out(end_out))
        return self.unpatchify(patches)