File size: 830 Bytes
128cb34 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 | """Conv-based MLP with GELU activation for DiCo blocks."""
from __future__ import annotations
import torch.nn.functional as F
from torch import Tensor, nn
from .norms import ChannelWiseRMSNorm
class ConvMLP(nn.Module):
    """Pointwise MLP built from 1x1 convolutions, used by DiCo blocks.

    Stages, in order:
      1. ``ChannelWiseRMSNorm`` over ``channels`` (non-affine, ``affine=False``),
      2. 1x1 ``Conv2d`` expanding ``channels`` -> ``hidden_channels`` (with bias),
      3. exact GELU,
      4. 1x1 ``Conv2d`` projecting ``hidden_channels`` -> ``channels`` (with bias).

    Both convolutions use ``kernel_size=1``, so each one acts on every spatial
    location independently. The exact reduction behaviour of the norm is
    defined in ``.norms`` and is not visible here.

    Args:
        channels: Channel count of both the input and the output.
        hidden_channels: Channel count of the intermediate (expanded) features.
        norm_eps: Passed as ``eps`` to the channel-wise RMS norm.
    """

    def __init__(
        self, channels: int, hidden_channels: int, norm_eps: float = 1e-6
    ) -> None:
        super().__init__()

        def pointwise(in_ch: int, out_ch: int) -> nn.Conv2d:
            # A biased 1x1 convolution: a per-location linear map over channels.
            return nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=True)

        # Keep the registration order norm -> conv_in -> conv_out. It fixes the
        # state_dict key order and the order in which the convolutions draw
        # from the RNG when their weights are initialised.
        self.norm = ChannelWiseRMSNorm(channels, eps=norm_eps, affine=False)
        self.conv_in = pointwise(channels, hidden_channels)
        self.conv_out = pointwise(hidden_channels, channels)

    def forward(self, x: Tensor) -> Tensor:
        """Normalize ``x``, expand it, apply GELU, and project back.

        ``x`` is assumed to be a ``Conv2d``-compatible tensor with ``channels``
        channels (e.g. ``(N, channels, H, W)``); confirm this against callers.
        The output has ``channels`` channels again. No residual connection is
        added here, so any skip path is the caller's responsibility.
        """
        hidden = F.gelu(self.conv_in(self.norm(x)))
        return self.conv_out(hidden)
|