# irdiffae-v1 / ir_diffae/decoder.py
# Uploaded by data-archetype: "Initial upload: iRDiffAE v1.0 (p16_c128, EMA weights)"
# Commit 1ed770c (verified).
"""iRDiffAE decoder: conditioned DiCoBlocks with AdaLN + skip connection."""
from __future__ import annotations
import torch
from torch import Tensor, nn
from .adaln import AdaLNZeroLowRankDelta, AdaLNZeroProjector
from .dico_block import DiCoBlock
from .norms import ChannelWiseRMSNorm
from .straight_through_encoder import Patchify
from .time_embed import SinusoidalTimeEmbeddingMLP
class Decoder(nn.Module):
    """VP diffusion decoder conditioned on encoder latents and timestep.

    Architecture:
        Patchify x_t -> Norm -> Fuse with channel-projected z
        -> Start blocks (2) -> Middle blocks (depth-4) -> Skip fuse -> End blocks (2)
        -> Norm -> Conv1x1 -> PixelShuffle

    Middle blocks support path-drop for PDG (inference-time guidance).

    Every block gets its own AdaLN modulation: a projection shared by all
    blocks (``adaln_base``) plus a per-block low-rank delta (``adaln_deltas``).
    """

    def __init__(
        self,
        in_channels: int,
        patch_size: int,
        model_dim: int,
        depth: int,
        bottleneck_dim: int,
        mlp_ratio: float,
        depthwise_kernel_size: int,
        adaln_low_rank_rank: int,
    ) -> None:
        """Build the decoder.

        Args:
            in_channels: Channel count of ``x_t`` and of the returned prediction.
            patch_size: Patchify size; also the PixelShuffle upscale factor.
            model_dim: Channel width used by every internal block.
            depth: Total number of DiCoBlocks (2 start + middle + 2 end).
                Must be at least 4.
            bottleneck_dim: Channel count of the encoder latents.
            mlp_ratio: Forwarded to each DiCoBlock.
            depthwise_kernel_size: Forwarded to each DiCoBlock.
            adaln_low_rank_rank: Rank of each per-block AdaLN delta.

        Raises:
            ValueError: If ``depth`` is smaller than the fixed start + end
                block count (4).
        """
        super().__init__()
        # Block layout: start(2) + middle(depth-4) + end(2)
        start_count = 2
        end_count = 2
        depth = int(depth)
        # Validate before building anything. With depth < 4, middle_count would
        # be negative and the end blocks would silently reuse (or negatively
        # index) the AdaLN deltas belonging to the start blocks.
        if depth < start_count + end_count:
            raise ValueError(
                f"Decoder depth must be >= {start_count + end_count} "
                f"(start + end blocks), got {depth}"
            )
        middle_count = depth - start_count - end_count

        self.patch_size = int(patch_size)
        self.model_dim = int(model_dim)

        # Input processing
        self.patchify = Patchify(in_channels, patch_size, model_dim)
        self.norm_in = ChannelWiseRMSNorm(model_dim, eps=1e-6, affine=True)

        # Latent conditioning path: a 1x1 conv projects bottleneck_dim -> model_dim
        # channels. It does not change the spatial size.
        self.latent_up = nn.Conv2d(bottleneck_dim, model_dim, kernel_size=1, bias=True)
        self.latent_norm = ChannelWiseRMSNorm(model_dim, eps=1e-6, affine=True)
        self.fuse_in = nn.Conv2d(2 * model_dim, model_dim, kernel_size=1, bias=True)

        # Time embedding
        self.time_embed = SinusoidalTimeEmbeddingMLP(model_dim)

        # AdaLN: shared base projector + per-block low-rank deltas (one per block).
        self.adaln_base = AdaLNZeroProjector(d_model=model_dim, d_cond=model_dim)
        self.adaln_deltas = nn.ModuleList(
            [
                AdaLNZeroLowRankDelta(
                    d_model=model_dim, d_cond=model_dim, rank=adaln_low_rank_rank
                )
                for _ in range(depth)
            ]
        )

        # Global block indices where each group starts. These select the
        # matching entries of adaln_deltas.
        self._middle_start_idx = start_count
        self._end_start_idx = start_count + middle_count

        def _make_blocks(count: int) -> nn.ModuleList:
            """Create ``count`` identically configured DiCoBlocks."""
            return nn.ModuleList(
                [
                    DiCoBlock(
                        model_dim,
                        mlp_ratio,
                        depthwise_kernel_size=depthwise_kernel_size,
                        use_external_adaln=True,
                    )
                    for _ in range(count)
                ]
            )

        self.start_blocks = _make_blocks(start_count)
        self.middle_blocks = _make_blocks(middle_count)
        self.fuse_skip = nn.Conv2d(2 * model_dim, model_dim, kernel_size=1, bias=True)
        self.end_blocks = _make_blocks(end_count)

        # Learned mask feature for path-drop guidance. It replaces the
        # middle-block output when drop_middle_blocks=True.
        self.mask_feature = nn.Parameter(torch.zeros((1, model_dim, 1, 1)))

        # Output head: project to in_channels * p^2, then PixelShuffle back to
        # full resolution.
        self.norm_out = ChannelWiseRMSNorm(model_dim, eps=1e-6, affine=True)
        self.out_proj = nn.Conv2d(
            model_dim, in_channels * (patch_size**2), kernel_size=1, bias=True
        )
        self.unpatchify = nn.PixelShuffle(patch_size)

    def _adaln_base_for_cond(self, cond: Tensor) -> tuple[Tensor, Tensor]:
        """Return ``(activated_cond, shared_base_modulation)`` for ``cond``.

        Both values depend only on ``cond``, so they are computed once per
        forward pass and reused by every block.
        """
        act = self.adaln_base.act(cond)
        return act, self.adaln_base.forward_activated(act)

    def _adaln_m_for_layer(
        self,
        cond: Tensor,
        layer_idx: int,
        base: tuple[Tensor, Tensor] | None = None,
    ) -> Tensor:
        """Compute packed AdaLN modulation = shared_base + per-layer delta.

        Args:
            cond: Time-conditioning embedding.
            layer_idx: Global block index selecting the per-layer delta.
            base: Optional precomputed result of ``_adaln_base_for_cond(cond)``.
                Computed here if omitted.
        """
        if base is None:
            base = self._adaln_base_for_cond(cond)
        act, base_m = base
        return base_m + self.adaln_deltas[layer_idx](act)

    def _run_blocks(
        self,
        blocks: nn.ModuleList,
        x: Tensor,
        cond: Tensor,
        start_index: int,
        base: tuple[Tensor, Tensor] | None = None,
    ) -> Tensor:
        """Run a group of decoder blocks with per-block AdaLN modulation.

        Args:
            blocks: The block group to run, in order.
            x: Input features.
            cond: Time-conditioning embedding.
            start_index: Global index of the group's first block, used to
                select its AdaLN delta.
            base: Optional precomputed shared AdaLN base; see
                ``_adaln_base_for_cond``.

        Returns:
            The features after the last block (``x`` unchanged for an empty group).
        """
        if len(blocks) == 0:
            return x
        if base is None:
            base = self._adaln_base_for_cond(cond)
        for offset, block in enumerate(blocks):
            adaln_m = self._adaln_m_for_layer(
                cond, layer_idx=start_index + offset, base=base
            )
            x = block(x, adaln_m=adaln_m)
        return x

    def forward(
        self,
        x_t: Tensor,
        t: Tensor,
        latents: Tensor,
        *,
        drop_middle_blocks: bool = False,
    ) -> Tensor:
        """Single decoder forward pass.

        Args:
            x_t: Noised image [B, C, H, W].
            t: Timestep [B] in [0, 1]. A 0-dim tensor is broadcast to the batch.
            latents: Encoder latents [B, bottleneck_dim, h, w]. The 1x1 projection
                keeps (h, w), so they must match the patchified x_t grid for the
                channel concat to succeed.
            drop_middle_blocks: If True, replace middle block output with
                mask_feature (for PDG).

        Returns:
            x0 prediction [B, C, H, W].
        """
        # Patchify and normalize x_t
        x_feat = self.norm_in(self.patchify(x_t))

        # Project latents to model_dim, normalize, then fuse with x_feat
        z_feat = self.latent_norm(self.latent_up(latents))
        fused = self.fuse_in(torch.cat([x_feat, z_feat], dim=1))

        # Time conditioning (float32 on the input's device)
        t = t.to(device=x_t.device, dtype=torch.float32)
        if t.ndim == 0:
            t = t.expand(x_t.shape[0]).contiguous()
        cond = self.time_embed(t)
        # The shared AdaLN base depends only on cond, so compute it once for all blocks.
        adaln_base = self._adaln_base_for_cond(cond)

        # Start blocks
        start_out = self._run_blocks(
            self.start_blocks, fused, cond, start_index=0, base=adaln_base
        )

        # Middle blocks (or mask feature for PDG)
        if drop_middle_blocks:
            # Match the tensor it is concatenated with, as the middle-block path
            # would. This keeps fuse_skip's input dtype consistent even when
            # x_t's dtype differs from the activations.
            middle_out = self.mask_feature.to(
                device=start_out.device, dtype=start_out.dtype
            ).expand_as(start_out)
        else:
            middle_out = self._run_blocks(
                self.middle_blocks,
                start_out,
                cond,
                start_index=self._middle_start_idx,
                base=adaln_base,
            )

        # Skip fusion
        skip_fused = self.fuse_skip(torch.cat([start_out, middle_out], dim=1))

        # End blocks
        end_out = self._run_blocks(
            self.end_blocks,
            skip_fused,
            cond,
            start_index=self._end_start_idx,
            base=adaln_base,
        )

        # Output head
        patches = self.out_proj(self.norm_out(end_out))
        return self.unpatchify(patches)