# irdiffae-v1 / ir_diffae/dico_block.py
# Initial upload: iRDiffAE v1.0 (p16_c128, EMA weights) — commit 1ed770c (verified)
"""DiCo block: conv path (1x1 -> depthwise -> SiLU -> CCA -> 1x1) + GELU MLP."""
from __future__ import annotations
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from .compact_channel_attention import CompactChannelAttention
from .conv_mlp import ConvMLP
from .norms import ChannelWiseRMSNorm
class DiCoBlock(nn.Module):
    """DiCo-style convolutional block with two conditioning modes.

    - Unconditioned (encoder): each residual branch is scaled by a learned
      per-channel gate, zero-initialised so the block starts as identity.
    - External AdaLN (decoder): a packed modulation tensor ``adaln_m`` of
      shape [B, 4*C] supplies (scale_attn, gate_attn, scale_mlp, gate_mlp);
      scales modulate the pre-normed input, gates are squashed with tanh.

    Conv path: 1x1 -> depthwise kxk -> SiLU -> CCA -> 1x1.
    MLP path: GELU ConvMLP. Both paths use affine-free channel-wise
    RMSNorm pre-normalisation and are added back residually.
    """

    def __init__(
        self,
        channels: int,
        mlp_ratio: float,
        *,
        depthwise_kernel_size: int = 7,
        use_external_adaln: bool = False,
        norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.channels = int(channels)
        self.use_external_adaln = bool(use_external_adaln)
        # Affine-free pre-norms: any scaling comes from AdaLN modulation.
        self.norm1 = ChannelWiseRMSNorm(self.channels, eps=norm_eps, affine=False)
        self.norm2 = ChannelWiseRMSNorm(self.channels, eps=norm_eps, affine=False)
        # Conv path: pointwise -> depthwise kxk -> SiLU -> CCA -> pointwise.
        self.conv1 = nn.Conv2d(self.channels, self.channels, kernel_size=1, bias=True)
        self.conv2 = nn.Conv2d(
            self.channels,
            self.channels,
            kernel_size=depthwise_kernel_size,
            padding=depthwise_kernel_size // 2,
            groups=self.channels,
            bias=True,
        )
        self.conv3 = nn.Conv2d(self.channels, self.channels, kernel_size=1, bias=True)
        self.cca = CompactChannelAttention(self.channels)
        # MLP hidden width: round(mlp_ratio * C), never below 1.
        mlp_hidden = max(int(round(float(self.channels) * mlp_ratio)), 1)
        self.mlp = ConvMLP(self.channels, mlp_hidden, norm_eps=norm_eps)
        # Learned residual gates are only needed without external AdaLN.
        if not self.use_external_adaln:
            self.gate_attn = nn.Parameter(torch.zeros(self.channels))
            self.gate_mlp = nn.Parameter(torch.zeros(self.channels))

    def _conv_branch(self, h: Tensor) -> Tensor:
        """Run the conv path (1x1 -> depthwise -> SiLU -> CCA -> 1x1)."""
        h = self.conv2(self.conv1(h))
        h = F.silu(h)
        h = h * self.cca(h)
        return self.conv3(h)

    def _learned_gate(self, gate: Tensor, like: Tensor) -> Tensor:
        """Broadcast a per-channel gate parameter, matched to `like`."""
        return gate.view(1, self.channels, 1, 1).to(
            dtype=like.dtype, device=like.device
        )

    def forward(self, x: Tensor, *, adaln_m: Tensor | None = None) -> Tensor:
        """Apply conv and MLP residual branches, optionally AdaLN-modulated.

        Args:
            x: Input feature map of shape [B, C, H, W].
            adaln_m: Packed modulation [B, 4*C]; required iff the block was
                built with ``use_external_adaln=True``.

        Raises:
            ValueError: if ``adaln_m`` presence does not match the mode.
        """
        batch, ch = x.shape[0], x.shape[1]
        if self.use_external_adaln:
            if adaln_m is None:
                raise ValueError(
                    "adaln_m required for externally-conditioned DiCoBlock"
                )
            # Cast once, then split into the four per-channel modulators,
            # pre-shaped for broadcasting over [B, C, H, W].
            packed = adaln_m.to(device=x.device, dtype=x.dtype)
            scale_a, gate_a, scale_m, gate_m = (
                part.view(batch, ch, 1, 1) for part in packed.chunk(4, dim=-1)
            )
        elif adaln_m is not None:
            raise ValueError("adaln_m must be None for unconditioned DiCoBlock")

        # Conv branch with residual connection.
        h = self.norm1(x)
        if self.use_external_adaln:
            h = h * (1.0 + scale_a)  # type: ignore[possibly-undefined]
        h = self._conv_branch(h)
        if self.use_external_adaln:
            x = x + torch.tanh(gate_a) * h  # type: ignore[possibly-undefined]
        else:
            x = x + self._learned_gate(self.gate_attn, h) * h

        # MLP branch with residual connection.
        h = self.norm2(x)
        if self.use_external_adaln:
            h = h * (1.0 + scale_m)  # type: ignore[possibly-undefined]
        h = self.mlp(h)
        if self.use_external_adaln:
            x = x + torch.tanh(gate_m) * h  # type: ignore[possibly-undefined]
        else:
            x = x + self._learned_gate(self.gate_mlp, h) * h
        return x