| """Conv-based MLP with GELU activation for DiCo blocks.""" | |
| from __future__ import annotations | |
| import torch.nn.functional as F | |
| from torch import Tensor, nn | |
| from .norms import ChannelWiseRMSNorm | |


class ConvMLP(nn.Module):
    """1x1 Conv-based MLP: RMSNorm -> Conv1x1 -> GELU -> Conv1x1.

    Both convolutions use 1x1 kernels, so the block mixes channels at each
    spatial position independently, without mixing across positions.
    """

    def __init__(
        self, channels: int, hidden_channels: int, norm_eps: float = 1e-6
    ) -> None:
        super().__init__()
        # Channel-wise RMSNorm without learnable affine parameters.
        self.norm = ChannelWiseRMSNorm(channels, eps=norm_eps, affine=False)
        self.conv_in = nn.Conv2d(channels, hidden_channels, kernel_size=1, bias=True)
        self.conv_out = nn.Conv2d(hidden_channels, channels, kernel_size=1, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        # x: (batch, channels, height, width); the shape is preserved end to end.
        y = self.norm(x)  # RMS-normalize each position's channel vector
        y = self.conv_in(y)  # expand: channels -> hidden_channels (per-pixel linear)
        y = F.gelu(y)
        return self.conv_out(y)  # project back: hidden_channels -> channels
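

# Minimal sanity-check sketch. The shapes and the residual wiring below are
# illustrative assumptions, not part of this module's API. Because of the
# relative import above, run this as a module inside its package
# (e.g. `python -m <package>.<this_module>`), not as a standalone script.
if __name__ == "__main__":
    import torch

    mlp = ConvMLP(channels=64, hidden_channels=256)
    x = torch.randn(2, 64, 32, 32)  # (batch, channels, height, width)
    out = x + mlp(x)  # residual add, as transformer-style MLP blocks commonly use
    assert out.shape == x.shape

    # A 1x1 conv is a per-pixel linear layer over channels: copying a Linear's
    # weights into a Conv2d(kernel_size=1) yields (numerically) identical outputs.
    lin = nn.Linear(64, 256)
    conv = nn.Conv2d(64, 256, kernel_size=1)
    with torch.no_grad():
        conv.weight.copy_(lin.weight[..., None, None])
        conv.bias.copy_(lin.bias)
    y_conv = conv(x)
    y_lin = lin(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
    assert torch.allclose(y_conv, y_lin, atol=1e-5)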