"""Feed-forward sublayer for DiCo blocks built from 1x1 convolutions and GELU."""

from __future__ import annotations

import torch.nn.functional as F
from torch import Tensor, nn

from .norms import ChannelWiseRMSNorm


class ConvMLP(nn.Module):
    """Pointwise-convolution MLP: ChannelWiseRMSNorm -> Conv1x1 -> GELU -> Conv1x1.

    The input is normalized by a ``ChannelWiseRMSNorm`` constructed with
    ``affine=False``. A biased 1x1 convolution then expands it from
    ``channels`` to ``hidden_channels``. It passes through ``F.gelu`` (PyTorch
    default, i.e. the exact erf-based form), and a second biased 1x1
    convolution projects it back to ``channels``. No residual connection is
    added inside this module.

    Args:
        channels: Number of input and output channels.
        hidden_channels: Width of the intermediate (expanded) representation.
        norm_eps: Epsilon forwarded to ``ChannelWiseRMSNorm``.
    """

    def __init__(
        self,
        channels: int,
        hidden_channels: int,
        norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()
        # Submodule attribute names and creation order are kept fixed: the
        # names are state_dict keys, and the order determines parameter order
        # and RNG consumption during initialization.
        self.norm = ChannelWiseRMSNorm(channels, eps=norm_eps, affine=False)
        self.conv_in = nn.Conv2d(
            in_channels=channels,
            out_channels=hidden_channels,
            kernel_size=1,
            bias=True,
        )
        self.conv_out = nn.Conv2d(
            in_channels=hidden_channels,
            out_channels=channels,
            kernel_size=1,
            bias=True,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Normalize, expand, apply GELU, and project back to ``channels``.

        Args:
            x: Feature map in a layout accepted by ``nn.Conv2d`` with
                ``channels`` channels (e.g. ``(N, channels, H, W)``).

        Returns:
            Tensor with ``channels`` channels. Both convolutions are 1x1 with
            default stride/padding, so spatial size is unchanged. This assumes
            ``ChannelWiseRMSNorm`` preserves shape (TODO confirm in ``.norms``).
        """
        hidden = F.gelu(self.conv_in(self.norm(x)))
        return self.conv_out(hidden)