import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Any, Optional, Type
class MLPBlock(nn.Module):
    """Two-layer feed-forward block: Linear -> activation -> Linear."""

    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lin2(self.act(self.lin1(x)))
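# Usage sketch (shapes are illustrative assumptions, not taken from this file):
#
#   block = MLPBlock(embedding_dim=256, mlp_dim=1024)
#   tokens = torch.randn(2, 196, 256)  # (batch, num_tokens, embedding_dim)
#   out = block(tokens)                # shape preserved: (2, 196, 256)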
class LayerNorm2d(nn.Module):
    """Channels-first LayerNorm: normalizes each spatial position of an NCHW tensor across the channel dimension."""

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Per-position mean and variance over the channel dimension (dim 1).
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        # Broadcast the learned per-channel affine parameters over H and W.
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x
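# Usage sketch (shapes assumed for illustration):
#
#   norm = LayerNorm2d(num_channels=64)
#   feat = torch.randn(2, 64, 32, 32)  # (N, C, H, W)
#   out = norm(feat)                   # each (n, :, h, w) vector normalized over C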
def val2list(x: Any, repeat_time: int = 1) -> list:
    """Return x as a list; a scalar is repeated `repeat_time` times."""
    if isinstance(x, (list, tuple)):
        return list(x)
    return [x for _ in range(repeat_time)]
def val2tuple(x: Any, min_len: int = 1, idx_repeat: int = -1) -> tuple:
    """Return x as a tuple of length >= `min_len`, padding by repeating the element at `idx_repeat`."""
    x = val2list(x)

    # Pad by inserting copies of x[idx_repeat] until min_len is reached.
    if len(x) > 0:
        x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]

    return tuple(x)
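# Behavior sketch (illustrative values):
#
#   val2tuple(3, min_len=2)       -> (3, 3)
#   val2tuple([1, 2], min_len=4)  -> (1, 2, 2, 2)   # pads by repeating x[-1]
#   val2tuple((4, 5))             -> (4, 5)         # already long enough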
def list_sum(x: list) -> Any:
    """Recursively sum a non-empty list; works for any type supporting `+`."""
    return x[0] if len(x) == 1 else x[0] + list_sum(x[1:])
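# list_sum works for any type supporting "+", e.g. accumulating tensors:
#
#   total = list_sum([a, b, c])  # equivalent to a + b + c (a, b, c assumed tensors)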
def resize(
    x: torch.Tensor,
    size: Optional[Any] = None,
    scale_factor=None,
    mode: str = "bicubic",
    align_corners: Optional[bool] = False,
) -> torch.Tensor:
    if mode in ["bilinear", "bicubic"]:
        return F.interpolate(
            x,
            size=size,
            scale_factor=scale_factor,
            mode=mode,
            align_corners=align_corners,
        )
    elif mode in ["nearest", "area"]:
        # "nearest" and "area" modes do not accept align_corners.
        return F.interpolate(x, size=size, scale_factor=scale_factor, mode=mode)
    else:
        raise NotImplementedError(f"resize(mode={mode}) not implemented.")
class UpSampleLayer(nn.Module):
    def __init__(
        self,
        mode="bicubic",
        size=None,
        factor=2,
        align_corners=False,
    ):
        super().__init__()
        self.mode = mode
        # A fixed output size takes precedence over a scale factor;
        # val2list(size, 2) expands a scalar size to [size, size].
        self.size = val2list(size, 2) if size is not None else None
        self.factor = None if self.size is not None else factor
        self.align_corners = align_corners

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return resize(x, self.size, self.factor, self.mode, self.align_corners)
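# Usage sketch (shapes assumed):
#
#   up = UpSampleLayer(mode="bicubic", factor=2)
#   x = torch.randn(1, 16, 32, 32)
#   y = up(x)  # -> (1, 16, 64, 64)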
class OpSequential(nn.Module):
    """Like nn.Sequential, but silently skips None entries in op_list."""

    def __init__(self, op_list):
        super().__init__()
        valid_op_list = [op for op in op_list if op is not None]
        self.op_list = nn.ModuleList(valid_op_list)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for op in self.op_list:
            x = op(x)
        return x
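# Usage sketch; the None entry stands in for an optionally disabled stage
# (module choices are illustrative assumptions):
#
#   ops = OpSequential([
#       nn.Conv2d(256, 256, 3, padding=1),
#       None,                    # e.g. a norm layer that is switched off
#       UpSampleLayer(factor=2),
#   ])
#   y = ops(torch.randn(1, 256, 64, 64))  # -> (1, 256, 128, 128)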