"""iRDiffAE decoder: conditioned DiCoBlocks with AdaLN + skip connection."""

from __future__ import annotations
import torch
from torch import Tensor, nn

from .adaln import AdaLNZeroLowRankDelta, AdaLNZeroProjector
from .dico_block import DiCoBlock
from .norms import ChannelWiseRMSNorm
from .straight_through_encoder import Patchify
from .time_embed import SinusoidalTimeEmbeddingMLP


class Decoder(nn.Module):
"""VP diffusion decoder conditioned on encoder latents and timestep.
Architecture:
Patchify x_t -> Norm -> Fuse with upsampled z
-> Start blocks (2) -> Middle blocks (depth-4) -> Skip fuse -> End blocks (2)
-> Norm -> Conv1x1 -> PixelShuffle
Middle blocks support path-drop for PDG (inference-time guidance).
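
    Shape sketch (illustrative, inferred from the layer definitions; assumes
    H and W are divisible by patch_size p and that latents arrive at patch
    resolution, which the channel concatenation in forward() requires):
        x_t:     [B, C, H, W]
        latents: [B, bottleneck_dim, H // p, W // p]
        output:  [B, C, H, W]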
"""

    def __init__(
self,
in_channels: int,
patch_size: int,
model_dim: int,
depth: int,
bottleneck_dim: int,
mlp_ratio: float,
depthwise_kernel_size: int,
adaln_low_rank_rank: int,
) -> None:
super().__init__()
self.patch_size = int(patch_size)
self.model_dim = int(model_dim)
# Input processing
self.patchify = Patchify(in_channels, patch_size, model_dim)
self.norm_in = ChannelWiseRMSNorm(model_dim, eps=1e-6, affine=True)
# Latent conditioning path
self.latent_up = nn.Conv2d(bottleneck_dim, model_dim, kernel_size=1, bias=True)
self.latent_norm = ChannelWiseRMSNorm(model_dim, eps=1e-6, affine=True)
self.fuse_in = nn.Conv2d(2 * model_dim, model_dim, kernel_size=1, bias=True)
# Time embedding
self.time_embed = SinusoidalTimeEmbeddingMLP(model_dim)
# AdaLN: shared base projector + per-block low-rank deltas
self.adaln_base = AdaLNZeroProjector(d_model=model_dim, d_cond=model_dim)
self.adaln_deltas = nn.ModuleList(
[
AdaLNZeroLowRankDelta(
d_model=model_dim, d_cond=model_dim, rank=adaln_low_rank_rank
)
for _ in range(depth)
]
)
        # Block layout: start(2) + middle(depth - 4) + end(2)
        start_count = 2
        end_count = 2
        middle_count = depth - start_count - end_count
        if middle_count < 1:
            raise ValueError(f"depth must be >= 5 to leave middle blocks, got {depth}")
self._middle_start_idx = start_count
self._end_start_idx = start_count + middle_count
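        # Illustration (not from the source): with depth=8, the global AdaLN
        # layer indices are start=[0, 1], middle=[2..5], end=[6, 7].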

        def _make_blocks(count: int) -> nn.ModuleList:
return nn.ModuleList(
[
DiCoBlock(
model_dim,
mlp_ratio,
depthwise_kernel_size=depthwise_kernel_size,
use_external_adaln=True,
)
for _ in range(count)
]
)

        self.start_blocks = _make_blocks(start_count)
self.middle_blocks = _make_blocks(middle_count)
self.fuse_skip = nn.Conv2d(2 * model_dim, model_dim, kernel_size=1, bias=True)
self.end_blocks = _make_blocks(end_count)
# Learned mask feature for path-drop guidance
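        # When drop_middle_blocks=True, this zero-initialized constant
        # replaces the middle-block output (see forward).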
self.mask_feature = nn.Parameter(torch.zeros((1, model_dim, 1, 1)))
# Output head
self.norm_out = ChannelWiseRMSNorm(model_dim, eps=1e-6, affine=True)
self.out_proj = nn.Conv2d(
model_dim, in_channels * (patch_size**2), kernel_size=1, bias=True
)
self.unpatchify = nn.PixelShuffle(patch_size)

    def _adaln_m_for_layer(self, cond: Tensor, layer_idx: int) -> Tensor:
"""Compute packed AdaLN modulation = shared_base + per-layer delta."""
act = self.adaln_base.act(cond)
base_m = self.adaln_base.forward_activated(act)
delta_m = self.adaln_deltas[layer_idx](act)
return base_m + delta_m

    def _run_blocks(
self, blocks: nn.ModuleList, x: Tensor, cond: Tensor, start_index: int
) -> Tensor:
"""Run a group of decoder blocks with per-block AdaLN modulation."""
for local_idx, block in enumerate(blocks):
adaln_m = self._adaln_m_for_layer(cond, layer_idx=start_index + local_idx)
x = block(x, adaln_m=adaln_m)
return x

    def forward(
self,
x_t: Tensor,
t: Tensor,
latents: Tensor,
*,
drop_middle_blocks: bool = False,
) -> Tensor:
"""Single decoder forward pass.
Args:
x_t: Noised image [B, C, H, W].
t: Timestep [B] in [0, 1].
latents: Encoder latents [B, bottleneck_dim, h, w].
drop_middle_blocks: If True, replace middle block output with mask_feature (for PDG).
Returns:
x0 prediction [B, C, H, W].
"""
# Patchify and normalize x_t
x_feat = self.patchify(x_t)
x_feat = self.norm_in(x_feat)
        # Project latents to model_dim, normalize, and fuse with x_feat
z_up = self.latent_up(latents)
z_up = self.latent_norm(z_up)
fused = torch.cat([x_feat, z_up], dim=1)
fused = self.fuse_in(fused)
# Time conditioning
        cond = self.time_embed(t.to(device=x_t.device, dtype=torch.float32))
# Start blocks
start_out = self._run_blocks(self.start_blocks, fused, cond, start_index=0)
# Middle blocks (or mask feature for PDG)
if drop_middle_blocks:
middle_out = self.mask_feature.to(
device=x_t.device, dtype=x_t.dtype
).expand_as(start_out)
else:
middle_out = self._run_blocks(
self.middle_blocks,
start_out,
cond,
start_index=self._middle_start_idx,
)
# Skip fusion
skip_fused = torch.cat([start_out, middle_out], dim=1)
skip_fused = self.fuse_skip(skip_fused)
# End blocks
end_out = self._run_blocks(
self.end_blocks, skip_fused, cond, start_index=self._end_start_idx
)
# Output head
end_out = self.norm_out(end_out)
patches = self.out_proj(end_out)
return self.unpatchify(patches)
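

if __name__ == "__main__":
    # Minimal shape smoke test (an illustrative sketch, not part of the
    # original file; all constructor values below are arbitrary). Run as a
    # module, e.g. `python -m <package>.decoder`, so the relative imports
    # resolve.
    dec = Decoder(
        in_channels=3,
        patch_size=2,
        model_dim=64,
        depth=6,
        bottleneck_dim=16,
        mlp_ratio=4.0,
        depthwise_kernel_size=7,
        adaln_low_rank_rank=8,
    )
    x_t = torch.randn(2, 3, 32, 32)
    t = torch.rand(2)
    latents = torch.randn(2, 16, 16, 16)  # patch resolution: 32 // 2 = 16
    # Conditional pass and PDG "dropped" pass; a guidance scheme would combine
    # the two x0 predictions (the exact combination is not defined in this file).
    x0_full = dec(x_t, t, latents)
    x0_drop = dec(x_t, t, latents, drop_middle_blocks=True)
    assert x0_full.shape == x_t.shape == x0_drop.shape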