mdiffae-v1 / m_diffae /decoder.py
data-archetype's picture
Upload folder using huggingface_hub
128cb34 verified
"""mDiffAE decoder: flat sequential DiCoBlocks with token-level PDG masking."""
from __future__ import annotations
import math
import torch
from torch import Tensor, nn
from .adaln import AdaLNZeroLowRankDelta, AdaLNZeroProjector
from .dico_block import DiCoBlock
from .norms import ChannelWiseRMSNorm
from .straight_through_encoder import Patchify
from .time_embed import SinusoidalTimeEmbeddingMLP
class Decoder(nn.Module):
    """VP diffusion decoder conditioned on encoder latents and timestep.

    Architecture:
        Patchify x_t -> Norm -> Fuse with channel-projected z
        -> Blocks (flat sequential, ``depth`` blocks) -> Norm -> Conv1x1 -> PixelShuffle

    Token-level PDG: at inference, a fraction of spatial tokens in the fused input
    are replaced with a learned ``mask_feature`` before the decoder blocks. Comparing
    the masked vs unmasked outputs provides a guidance signal.
    """

    def __init__(
        self,
        in_channels: int,
        patch_size: int,
        model_dim: int,
        depth: int,
        bottleneck_dim: int,
        mlp_ratio: float,
        depthwise_kernel_size: int,
        adaln_low_rank_rank: int,
        pdg_mask_ratio: float = 0.75,
    ) -> None:
        """Build the decoder.

        Args:
            in_channels: Channels of the noised image ``x_t`` (also the output channels).
            patch_size: Patch size used by Patchify and the output PixelShuffle.
            model_dim: Channel width of the decoder trunk.
            depth: Number of DiCoBlocks (and of per-block AdaLN low-rank deltas).
            bottleneck_dim: Channel count of the encoder latents.
            mlp_ratio: MLP expansion ratio forwarded to each DiCoBlock.
            depthwise_kernel_size: Depthwise kernel size forwarded to each DiCoBlock.
            adaln_low_rank_rank: Rank of each per-block AdaLN delta.
            pdg_mask_ratio: Fraction of tokens per 2x2 group replaced by
                ``mask_feature`` when ``forward(..., mask_tokens=True)``.
        """
        super().__init__()
        self.patch_size = int(patch_size)
        self.model_dim = int(model_dim)
        self.pdg_mask_ratio = float(pdg_mask_ratio)

        # Input processing
        self.patchify = Patchify(in_channels, patch_size, model_dim)
        self.norm_in = ChannelWiseRMSNorm(model_dim, eps=1e-6, affine=True)

        # Latent conditioning path: a 1x1 conv only projects channels
        # (bottleneck_dim -> model_dim); it does not resample spatially.
        self.latent_up = nn.Conv2d(bottleneck_dim, model_dim, kernel_size=1, bias=True)
        self.latent_norm = ChannelWiseRMSNorm(model_dim, eps=1e-6, affine=True)
        self.fuse_in = nn.Conv2d(2 * model_dim, model_dim, kernel_size=1, bias=True)

        # Time embedding
        self.time_embed = SinusoidalTimeEmbeddingMLP(model_dim)

        # AdaLN: shared base projector + per-block low-rank deltas
        self.adaln_base = AdaLNZeroProjector(d_model=model_dim, d_cond=model_dim)
        self.adaln_deltas = nn.ModuleList(
            [
                AdaLNZeroLowRankDelta(
                    d_model=model_dim, d_cond=model_dim, rank=adaln_low_rank_rank
                )
                for _ in range(depth)
            ]
        )

        # Flat sequential blocks (no start/middle/end split, no skip connections)
        self.blocks = nn.ModuleList(
            [
                DiCoBlock(
                    model_dim,
                    mlp_ratio,
                    depthwise_kernel_size=depthwise_kernel_size,
                    use_external_adaln=True,
                )
                for _ in range(depth)
            ]
        )

        # Learned mask feature for token-level PDG (zero-initialised)
        self.mask_feature = nn.Parameter(torch.zeros((1, model_dim, 1, 1)))

        # Output head
        self.norm_out = ChannelWiseRMSNorm(model_dim, eps=1e-6, affine=True)
        self.out_proj = nn.Conv2d(
            model_dim, in_channels * (patch_size**2), kernel_size=1, bias=True
        )
        self.unpatchify = nn.PixelShuffle(patch_size)

    def _adaln_m_for_layer(self, cond: Tensor, layer_idx: int) -> Tensor:
        """Compute packed AdaLN modulation = shared_base + per-layer delta.

        Standalone helper for a single layer. ``forward`` computes the shared
        base once per call instead of once per layer, producing the same values.
        """
        act = self.adaln_base.act(cond)
        base_m = self.adaln_base.forward_activated(act)
        delta_m = self.adaln_deltas[layer_idx](act)
        return base_m + delta_m

    def _apply_token_mask(
        self, fused: Tensor, *, generator: torch.Generator | None = None
    ) -> Tensor:
        """Replace a fraction of spatial tokens with ``mask_feature`` (2x2 groupwise).

        The spatial grid is split into 2x2 groups. Within each group exactly
        ``floor(pdg_mask_ratio * 4)`` tokens (the ones with the lowest random
        scores) are masked. A ratio >= 1 masks every token; a ratio < 0.25 masks
        none. Odd H/W are zero-padded to even before grouping and cropped
        afterwards, so edge groups may "spend" masked slots on padding.

        Args:
            fused: [B, C, H, W] fused decoder input (C must equal model_dim).
            generator: Optional RNG for the random scores. ``None`` uses the
                global torch RNG. A supplied generator must live on
                ``fused.device``.

        Returns:
            Tensor with the same shape as ``fused``. Masked positions hold
            ``mask_feature``.
        """
        b, c, h0, w0 = fused.shape
        h_pad, w_pad = h0 % 2, w0 % 2
        if h_pad or w_pad:
            fused = torch.nn.functional.pad(fused, (0, w_pad, 0, h_pad))
        h, w = h0 + h_pad, w0 + w_pad
        gh, gw = h // 2, w // 2

        # [B, C, H, W] -> [B, C, gh, 2, gw, 2] -> [B, C, gh, gw, 4]
        groups = (
            fused.reshape(b, c, gh, 2, gw, 2)
            .permute(0, 1, 2, 4, 3, 5)
            .reshape(b, c, gh, gw, 4)
        )

        # One score per token, shared across channels, so a token is masked in full.
        scores = torch.rand(
            b, 1, gh, gw, 4, device=fused.device, generator=generator
        )
        num_mask = math.floor(self.pdg_mask_ratio * 4)
        mask = torch.zeros_like(scores, dtype=torch.bool)
        if num_mask > 0:
            lowest = scores.argsort(dim=-1)[..., :num_mask]
            mask.scatter_(-1, lowest, True)

        # Broadcast [1, C, 1, 1, 1] feature against [B, 1, gh, gw, 4] mask.
        mask_feat = self.mask_feature.to(device=fused.device, dtype=fused.dtype)
        groups = torch.where(mask, mask_feat.view(1, c, 1, 1, 1), groups)

        # Inverse of the grouping above: [B, C, gh, gw, 4] -> [B, C, H, W]
        out = (
            groups.reshape(b, c, gh, gw, 2, 2)
            .permute(0, 1, 2, 4, 3, 5)
            .reshape(b, c, h, w)
        )
        if h_pad or w_pad:
            out = out[:, :, :h0, :w0]
        return out

    def forward(
        self,
        x_t: Tensor,
        t: Tensor,
        latents: Tensor,
        *,
        mask_tokens: bool = False,
        mask_generator: torch.Generator | None = None,
    ) -> Tensor:
        """Single decoder forward pass.

        Args:
            x_t: Noised image [B, C, H, W].
            t: Timestep [B] in [0, 1]. Cast to float32 on ``x_t``'s device before embedding.
            latents: Encoder latents [B, bottleneck_dim, h, w]. Only a 1x1 conv is
                applied (no spatial resampling), so (h, w) must equal the spatial
                size of ``self.patchify(x_t)`` for the channel concat to succeed.
                This is presumably (H / patch_size, W / patch_size), depending on Patchify.
            mask_tokens: If True, apply token-level masking to the fused decoder input (for PDG).
            mask_generator: Optional RNG used for the PDG mask scores, for
                reproducible masks. Ignored when ``mask_tokens`` is False.

        Returns:
            x0 prediction [B, C, H, W].
        """
        # Patchify and normalize x_t
        x_feat = self.norm_in(self.patchify(x_t))

        # Project and normalize latents, fuse with x_feat
        z_feat = self.latent_norm(self.latent_up(latents))
        fused = self.fuse_in(torch.cat([x_feat, z_feat], dim=1))

        # Token masking for PDG (replaces tokens with mask_feature)
        if mask_tokens:
            fused = self._apply_token_mask(fused, generator=mask_generator)

        # Time conditioning
        cond = self.time_embed(t.to(device=x_t.device, dtype=torch.float32))

        # The shared AdaLN base depends only on `cond`: compute it once rather
        # than once per layer. Only the low-rank delta varies per block.
        act = self.adaln_base.act(cond)
        base_m = self.adaln_base.forward_activated(act)

        x = fused
        for block, delta in zip(self.blocks, self.adaln_deltas):
            x = block(x, adaln_m=base_m + delta(act))

        # Output head
        patches = self.out_proj(self.norm_out(x))
        return self.unpatchify(patches)