"""DiCo block: conv path (1x1 -> depthwise -> SiLU -> CCA -> 1x1) + GELU MLP."""

from __future__ import annotations

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from .compact_channel_attention import CompactChannelAttention
from .conv_mlp import ConvMLP
from .norms import ChannelWiseRMSNorm


class DiCoBlock(nn.Module):
    """DiCo-style conv block with optional external AdaLN conditioning.

    Each call applies two pre-norm residual branches to a (B, C, H, W) input:

    1. Conv path: RMSNorm -> 1x1 conv -> depthwise kxk conv -> SiLU ->
       compact channel attention -> 1x1 conv.
    2. MLP path: RMSNorm -> ConvMLP.

    Two modes:

    - Unconditioned (encoder): each branch is scaled by a learned per-channel
      residual gate. The gates are initialized to zero, so both branches
      contribute nothing at initialization.
    - External AdaLN (decoder): ``adaln_m`` carries packed modulation
      ``[scale_attn, gate_attn, scale_mlp, gate_mlp]`` along its last
      dimension (size ``4 * C``). Each scale multiplies the normalized branch
      input as ``(1 + scale)`` (there is no shift term). Each gate is passed
      through ``tanh`` before scaling its branch output.

    Args:
        channels: Number of input/output channels ``C``.
        mlp_ratio: MLP hidden width relative to ``channels``. It is rounded
            to the nearest integer and clamped to at least 1.
        depthwise_kernel_size: Kernel size of the depthwise conv. Padding is
            ``k // 2``, so odd sizes preserve spatial resolution.
        use_external_adaln: If True, select the externally-conditioned mode.
        norm_eps: Epsilon forwarded to both RMS norms and to the MLP.
    """

    def __init__(
        self,
        channels: int,
        mlp_ratio: float,
        *,
        depthwise_kernel_size: int = 7,
        use_external_adaln: bool = False,
        norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.channels = int(channels)
        self.use_external_adaln = bool(use_external_adaln)

        # Pre-norms for the conv and MLP paths. They have no affine params;
        # in AdaLN mode the per-sample scales take over that role.
        self.norm1 = ChannelWiseRMSNorm(self.channels, eps=norm_eps, affine=False)
        self.norm2 = ChannelWiseRMSNorm(self.channels, eps=norm_eps, affine=False)

        # Conv path: 1x1 -> depthwise kxk -> SiLU -> CCA -> 1x1.
        self.conv1 = nn.Conv2d(self.channels, self.channels, kernel_size=1, bias=True)
        self.conv2 = nn.Conv2d(
            self.channels,
            self.channels,
            kernel_size=depthwise_kernel_size,
            padding=depthwise_kernel_size // 2,
            groups=self.channels,  # depthwise: one filter per channel
            bias=True,
        )
        self.conv3 = nn.Conv2d(self.channels, self.channels, kernel_size=1, bias=True)
        self.cca = CompactChannelAttention(self.channels)

        # MLP path (GELU per the module description).
        hidden_channels = max(int(round(float(self.channels) * mlp_ratio)), 1)
        self.mlp = ConvMLP(self.channels, hidden_channels, norm_eps=norm_eps)

        # Conditioning: learned gates (encoder) or external adaln_m (decoder).
        # Zero init makes each residual branch a no-op at the start of training.
        if not self.use_external_adaln:
            self.gate_attn = nn.Parameter(torch.zeros(self.channels))
            self.gate_mlp = nn.Parameter(torch.zeros(self.channels))

    def _conv_path(self, h: Tensor) -> Tensor:
        """Run 1x1 conv -> depthwise conv -> SiLU -> CCA gating -> 1x1 conv."""
        h = self.conv1(h)
        h = self.conv2(h)
        h = F.silu(h)
        # The CCA output multiplies the activations (presumably per-channel
        # weights broadcastable to h; see CompactChannelAttention).
        h = h * self.cca(h)
        return self.conv3(h)

    @staticmethod
    def _learned_gate(gate: Tensor, like: Tensor) -> Tensor:
        """Reshape a per-channel gate to (1, C, 1, 1) on ``like``'s dtype/device."""
        return gate.view(1, -1, 1, 1).to(dtype=like.dtype, device=like.device)

    def _split_modulation(
        self, adaln_m: Tensor, x: Tensor
    ) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        """Validate ``adaln_m`` and split it into four (N, C, 1, 1) tensors.

        N is either 1 (shared across the batch) or B (per-sample). The
        returned order is ``scale_attn, gate_attn, scale_mlp, gate_mlp``.
        """
        packed = 4 * self.channels
        if adaln_m.ndim == 0 or adaln_m.shape[-1] != packed:
            raise ValueError(
                f"adaln_m must have last dimension 4*C={packed}, "
                f"got shape {tuple(adaln_m.shape)}"
            )
        mod = adaln_m.to(device=x.device, dtype=x.dtype)
        # Flatten leading dims, so (4C,), (1, 4C), (B, 4C) and (B, 1, 4C)
        # all become (N, 4C).
        mod = mod.reshape(-1, packed)
        if mod.shape[0] not in (1, x.shape[0]):
            raise ValueError(
                f"adaln_m batch size {mod.shape[0]} does not match input "
                f"batch size {x.shape[0]} (expected 1 or {x.shape[0]})"
            )
        scale_a, gate_a, scale_m, gate_m = mod.chunk(4, dim=-1)
        # (N, C) -> (N, C, 1, 1) so they broadcast over the spatial dims.
        return (
            scale_a[:, :, None, None],
            gate_a[:, :, None, None],
            scale_m[:, :, None, None],
            gate_m[:, :, None, None],
        )

    def forward(self, x: Tensor, *, adaln_m: Tensor | None = None) -> Tensor:
        """Apply the conv and MLP residual branches.

        Args:
            x: Input feature map of shape (B, C, H, W).
            adaln_m: Packed modulation. Required in external-AdaLN mode and
                must be None otherwise. Expected shape is (B, 4*C); (1, 4*C)
                or (4*C,) are broadcast over the batch. It is cast to ``x``'s
                device and dtype.

        Returns:
            Tensor of the same shape as ``x``; both branches are added
            residually.

        Raises:
            ValueError: If ``adaln_m`` is missing or present contrary to the
                mode, or if its shape does not match the block's channel
                count or ``x``'s batch size.
        """
        if self.use_external_adaln:
            if adaln_m is None:
                raise ValueError(
                    "adaln_m required for externally-conditioned DiCoBlock"
                )
            scale_a, gate_a, scale_m, gate_m = self._split_modulation(adaln_m, x)

            # Conv branch: modulated pre-norm, tanh-gated residual.
            h = self.norm1(x) * (1.0 + scale_a)
            x = x + torch.tanh(gate_a) * self._conv_path(h)

            # MLP branch: modulated pre-norm, tanh-gated residual.
            h = self.norm2(x) * (1.0 + scale_m)
            x = x + torch.tanh(gate_m) * self.mlp(h)
            return x

        if adaln_m is not None:
            raise ValueError("adaln_m must be None for unconditioned DiCoBlock")

        # Conv branch with a learned per-channel residual gate.
        y = self._conv_path(self.norm1(x))
        x = x + self._learned_gate(self.gate_attn, y) * y

        # MLP branch with a learned per-channel residual gate.
        y = self.mlp(self.norm2(x))
        return x + self._learned_gate(self.gate_mlp, y) * y