import torch
import torch.nn as nn

from typing import Type


class MLPBlock(nn.Module):
    """Transformer-style MLP: a linear projection up to mlp_dim, an
    activation, then a linear projection back down to embedding_dim."""

    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Expand, apply the activation, then project back down.
        return self.lin2(self.act(self.lin1(x)))


class LayerNorm2d(nn.Module):
    """LayerNorm over the channel dimension of NCHW tensors, with a
    learnable per-channel scale (weight) and shift (bias)."""

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Mean and (biased) variance over the channel dimension.
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        # Normalize, then apply the per-channel affine transform; weight and
        # bias are reshaped to (C, 1, 1) so they broadcast over H and W.
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x
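

# A minimal usage sketch, not part of the original module: it exercises both
# blocks under assumed shapes -- (batch, tokens, embedding_dim) for MLPBlock
# and (batch, channels, height, width) for LayerNorm2d. All dimensions below
# are illustrative.
if __name__ == "__main__":
    mlp = MLPBlock(embedding_dim=256, mlp_dim=1024)
    tokens = torch.randn(2, 16, 256)           # (batch, tokens, embedding_dim)
    assert mlp(tokens).shape == tokens.shape   # the MLP preserves input shape

    norm = LayerNorm2d(num_channels=64)
    fmap = torch.randn(2, 64, 32, 32)          # (batch, channels, H, W)
    out = norm(fmap)
    assert out.shape == fmap.shape             # normalization preserves shape
    # Each spatial position now has ~zero mean, ~unit variance across channels.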