File size: 830 Bytes
128cb34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
"""Conv-based MLP with GELU activation for DiCo blocks."""

from __future__ import annotations

import torch.nn.functional as F
from torch import Tensor, nn

from .norms import ChannelWiseRMSNorm


class ConvMLP(nn.Module):
    """Pointwise MLP implemented with 1x1 convolutions.

    Data flows through ``ChannelWiseRMSNorm`` -> 1x1 conv
    (``channels`` -> ``hidden_channels``) -> GELU -> 1x1 conv
    (``hidden_channels`` -> ``channels``).

    Args:
        channels: Channel count of the input; the output has the same count.
        hidden_channels: Channel count between the two convolutions.
        norm_eps: Passed to ``ChannelWiseRMSNorm`` as its ``eps`` argument.
    """

    def __init__(
        self,
        channels: int,
        hidden_channels: int,
        norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()
        # Built with affine=False; presumably this means the norm carries no
        # learnable scale -- confirm against the definition in ``.norms``.
        self.norm = ChannelWiseRMSNorm(channels, eps=norm_eps, affine=False)
        # kernel_size=1 means each spatial position is transformed on its own;
        # only channels are mixed.
        self.conv_in = nn.Conv2d(
            in_channels=channels,
            out_channels=hidden_channels,
            kernel_size=1,
            bias=True,
        )
        self.conv_out = nn.Conv2d(
            in_channels=hidden_channels,
            out_channels=channels,
            kernel_size=1,
            bias=True,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Apply norm, first conv, GELU, and second conv to ``x``.

        Args:
            x: Input accepted by ``nn.Conv2d`` with ``channels`` channels
                (e.g. a ``(N, C, H, W)`` tensor).

        Returns:
            Tensor with the same channel count as ``x``.
        """
        normed = self.norm(x)
        # F.gelu is called with its default arguments.
        hidden = F.gelu(self.conv_in(normed))
        return self.conv_out(hidden)