"""mDiffAE decoder: flat sequential DiCoBlocks with token-level PDG masking."""
from __future__ import annotations
import math
import torch
from torch import Tensor, nn
from .adaln import AdaLNZeroLowRankDelta, AdaLNZeroProjector
from .dico_block import DiCoBlock
from .norms import ChannelWiseRMSNorm
from .straight_through_encoder import Patchify
from .time_embed import SinusoidalTimeEmbeddingMLP
class Decoder(nn.Module):
"""VP diffusion decoder conditioned on encoder latents and timestep.
Architecture:
Patchify x_t -> Norm -> Fuse with upsampled z
-> Blocks (flat sequential, depth blocks) -> Norm -> Conv1x1 -> PixelShuffle
Token-level PDG: at inference, a fraction of the spatial tokens in the fused input
is replaced with a learned mask_feature before the decoder blocks. Comparing the
masked and unmasked outputs provides the guidance signal.
"""
def __init__(
self,
in_channels: int,
patch_size: int,
model_dim: int,
depth: int,
bottleneck_dim: int,
mlp_ratio: float,
depthwise_kernel_size: int,
adaln_low_rank_rank: int,
pdg_mask_ratio: float = 0.75,
) -> None:
super().__init__()
self.patch_size = int(patch_size)
self.model_dim = int(model_dim)
self.pdg_mask_ratio = float(pdg_mask_ratio)
# Input processing
self.patchify = Patchify(in_channels, patch_size, model_dim)
self.norm_in = ChannelWiseRMSNorm(model_dim, eps=1e-6, affine=True)
# Latent conditioning path
self.latent_up = nn.Conv2d(bottleneck_dim, model_dim, kernel_size=1, bias=True)
self.latent_norm = ChannelWiseRMSNorm(model_dim, eps=1e-6, affine=True)
self.fuse_in = nn.Conv2d(2 * model_dim, model_dim, kernel_size=1, bias=True)
# Time embedding
self.time_embed = SinusoidalTimeEmbeddingMLP(model_dim)
# AdaLN: shared base projector + per-block low-rank deltas
self.adaln_base = AdaLNZeroProjector(d_model=model_dim, d_cond=model_dim)
self.adaln_deltas = nn.ModuleList(
[
AdaLNZeroLowRankDelta(
d_model=model_dim, d_cond=model_dim, rank=adaln_low_rank_rank
)
for _ in range(depth)
]
)
# Flat sequential blocks (no start/middle/end split, no skip connections)
self.blocks = nn.ModuleList(
[
DiCoBlock(
model_dim,
mlp_ratio,
depthwise_kernel_size=depthwise_kernel_size,
use_external_adaln=True,
)
for _ in range(depth)
]
)
# Learned mask feature for token-level PDG
self.mask_feature = nn.Parameter(torch.zeros((1, model_dim, 1, 1)))
# Output head
self.norm_out = ChannelWiseRMSNorm(model_dim, eps=1e-6, affine=True)
self.out_proj = nn.Conv2d(
model_dim, in_channels * (patch_size**2), kernel_size=1, bias=True
)
self.unpatchify = nn.PixelShuffle(patch_size)
def _adaln_m_for_layer(self, cond: Tensor, layer_idx: int) -> Tensor:
"""Compute packed AdaLN modulation = shared_base + per-layer delta."""
act = self.adaln_base.act(cond)
base_m = self.adaln_base.forward_activated(act)
delta_m = self.adaln_deltas[layer_idx](act)
return base_m + delta_m
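# Design note (an inference from the structure above, not stated in this file):
# sharing one full AdaLN projector across all blocks and giving each block only a
# low-rank delta means per-block conditioning parameters scale with
# adaln_low_rank_rank rather than with a full d_cond x d_model projection per block.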
def _apply_token_mask(self, fused: Tensor) -> Tensor:
"""Replace a fraction of spatial tokens with mask_feature (2x2 groupwise).
Divides the spatial grid into 2x2 groups. Within each group, masks
floor(ratio * 4) tokens deterministically (lowest random scores).
Args:
fused: [B, C, H, W] fused decoder input.
Returns:
Masked tensor with same shape, where masked positions contain mask_feature.
"""
b, c, h, w = fused.shape
# Pad to even dims if needed
h_pad = (2 - h % 2) % 2
w_pad = (2 - w % 2) % 2
if h_pad > 0 or w_pad > 0:
fused = torch.nn.functional.pad(fused, (0, w_pad, 0, h_pad))
_, _, h, w = fused.shape
# Reshape into 2x2 groups: [B, C, H/2, 2, W/2, 2] -> [B, C, H/2, W/2, 4]
x = fused.reshape(b, c, h // 2, 2, w // 2, 2)
x = x.permute(0, 1, 2, 4, 3, 5).reshape(b, c, h // 2, w // 2, 4)
# Random scores for each token in each group
scores = torch.rand(b, 1, h // 2, w // 2, 4, device=fused.device)
# Mask the floor(ratio * 4) lowest-scoring tokens per group
num_mask = math.floor(self.pdg_mask_ratio * 4)
if num_mask > 0:
# argsort ascending, mask the first num_mask
_, indices = scores.sort(dim=-1)
mask = torch.zeros_like(scores, dtype=torch.bool)
mask.scatter_(-1, indices[..., :num_mask], True)
else:
mask = torch.zeros_like(scores, dtype=torch.bool)
# Apply mask: replace masked tokens with mask_feature
mask_feat = self.mask_feature.to(device=fused.device, dtype=fused.dtype)
mask_feat = mask_feat.squeeze(-1).squeeze(-1) # [1, C]
mask_feat = mask_feat.view(1, c, 1, 1, 1).expand_as(x)
mask_expanded = mask.expand_as(x)
x = torch.where(mask_expanded, mask_feat, x)
# Reshape back to [B, C, H, W]
x = x.reshape(b, c, h // 2, w // 2, 2, 2)
x = x.permute(0, 1, 2, 4, 3, 5).reshape(b, c, h, w)
# Remove padding if applied
if h_pad > 0 or w_pad > 0:
x = x[:, :, : h - h_pad, : w - w_pad]
return x
def forward(
self,
x_t: Tensor,
t: Tensor,
latents: Tensor,
*,
mask_tokens: bool = False,
) -> Tensor:
"""Single decoder forward pass.
Args:
x_t: Noised image [B, C, H, W].
t: Timestep [B] in [0, 1].
latents: Encoder latents [B, bottleneck_dim, h, w].
mask_tokens: If True, apply token-level masking to decoder input (for PDG).
Returns:
x0 prediction [B, C, H, W].
"""
# Patchify and normalize x_t
x_feat = self.patchify(x_t)
x_feat = self.norm_in(x_feat)
# Upsample and normalize latents, fuse with x_feat
z_up = self.latent_up(latents)
z_up = self.latent_norm(z_up)
fused = torch.cat([x_feat, z_up], dim=1)
fused = self.fuse_in(fused)
# Token masking for PDG (replaces tokens with mask_feature)
if mask_tokens:
fused = self._apply_token_mask(fused)
# Time conditioning
cond = self.time_embed(t.to(torch.float32).to(device=x_t.device))
# Run all blocks sequentially
x = fused
for layer_idx, block in enumerate(self.blocks):
adaln_m = self._adaln_m_for_layer(cond, layer_idx=layer_idx)
x = block(x, adaln_m=adaln_m)
# Output head
x = self.norm_out(x)
patches = self.out_proj(x)
return self.unpatchify(patches)
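# Minimal smoke test (a hedged sketch: the hyperparameters and shapes below are
# illustrative assumptions, and the relative imports mean this only runs as a
# module within the package, e.g. `python -m <package>.decoder`).
if __name__ == "__main__":
    decoder = Decoder(
        in_channels=3,
        patch_size=2,
        model_dim=128,
        depth=4,
        bottleneck_dim=32,
        mlp_ratio=4.0,
        depthwise_kernel_size=7,
        adaln_low_rank_rank=8,
    )
    x_t = torch.randn(2, 3, 32, 32)
    t = torch.rand(2)
    latents = torch.randn(2, 32, 16, 16)  # token-grid resolution: 32 // patch_size
    x0 = decoder(x_t, t, latents)
    x0_masked = decoder(x_t, t, latents, mask_tokens=True)  # PDG branch
    print(x0.shape, x0_masked.shape)  # both expected to be [2, 3, 32, 32]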