| from typing import Tuple |
|
|
| import torch |
| from torch import nn |
|
|
|
|
class SiLU(nn.Module):
    """Sigmoid Linear Unit activation: ``x * sigmoid(x)``.

    Hand-written so the model also runs on PyTorch versions older than 1.8;
    intended to be deprecated once 1.8+ can be assumed.

    Paper: https://arxiv.org/abs/1702.03118
    """
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = torch.sigmoid(x)
        return x * gate
|
|
|
|
def conv_1x1_bn(inp: int, oup: int) -> nn.Module:
    """Pointwise (1x1) convolution followed by BatchNorm and SiLU.

    Args:
        inp: number of input channels.
        oup: number of output channels.

    Returns:
        ``nn.Sequential(Conv2d(1x1, stride 1, no padding, no bias), BatchNorm2d, SiLU)``.
    """
    pointwise = nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False)
    return nn.Sequential(pointwise, nn.BatchNorm2d(oup), SiLU())
|
|
|
|
def conv_nxn_bn(inp: int, oup: int, kernal_size: int = 3, stride: int = 1) -> nn.Module:
    """Square n x n convolution followed by BatchNorm and SiLU.

    Padding is ``kernal_size // 2``. For odd kernel sizes with stride 1 this
    preserves the spatial size. A fixed padding of 1 only does that for 3x3
    kernels, which is the default, so default behaviour is unchanged.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        kernal_size: convolution kernel size (misspelled name kept for
            callers passing it by keyword). Defaults to 3.
        stride: convolution stride, defaults to 1.

    Returns:
        ``nn.Sequential(Conv2d(no bias), BatchNorm2d, SiLU)``.
    """
    padding = kernal_size // 2
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernal_size, stride, padding, bias=False),
        nn.BatchNorm2d(oup),
        SiLU(),
    )
|
|
|
|
class PreNorm(nn.Module):
    """Apply LayerNorm to the input, then the wrapped module ``fn``.

    Args:
        dim: size of the last input dimension, normalized by ``nn.LayerNorm``.
        fn: module applied to the normalized tensor.
    """
    def __init__(self, dim: int, fn: nn.Module) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        # Extra keyword arguments are forwarded to ``fn`` unchanged.
        return self.fn(self.norm(x), **kwargs)
|
|
|
|
class FeedForward(nn.Module):
    """Two-layer MLP used inside each Transformer layer.

    Linear(dim -> hidden_dim) -> SiLU -> Dropout -> Linear(hidden_dim -> dim) -> Dropout.

    Args:
        dim: input and output feature size.
        hidden_dim: width of the hidden layer.
        dropout: dropout probability, defaults to 0.
    """
    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.) -> None:
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),
            SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected = self.net(x)
        return projected
|
|
|
|
class Attention(nn.Module):
    """Multi-head self-attention over the token axis of a 4-D tensor.

    Input is ``(B, P, N, dim)``. Attention is computed over the ``N`` axis
    independently for every ``(B, P)`` pair, and the output has the same shape.

    Args:
        dim: input/output feature size.
        heads: number of attention heads, defaults to 8.
        dim_head: per-head feature size, defaults to 64.
        dropout: dropout applied after the output projection, defaults to 0.
    """
    def __init__(self, dim: int, heads: int = 8, dim_head: int = 64, dropout: float = 0.) -> None:
        super().__init__()
        inner_dim = dim_head * heads
        # A single head already as wide as ``dim`` needs no output projection.
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q, k, v = (self._split_heads(t) for t in self.to_qkv(x).chunk(3, dim=-1))

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        out = torch.matmul(attn, v)  # (B, P, heads, N, dim_head)
        return self.to_out(self._merge_heads(out))

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        """(B, P, N, heads * dim_head) -> (B, P, heads, N, dim_head)."""
        # The last axis is laid out as (heads, dim_head), so split it first and
        # only then move the head axis in front of N. A direct
        # view(B, P, heads, N, d) would mix tokens and heads.
        b, p, n, hd = t.shape
        return t.reshape(b, p, n, self.heads, hd // self.heads).transpose(2, 3)

    @staticmethod
    def _merge_heads(t: torch.Tensor) -> torch.Tensor:
        """(B, P, heads, N, dim_head) -> (B, P, N, heads * dim_head)."""
        b, p, h, n, d = t.shape
        return t.transpose(2, 3).reshape(b, p, n, h * d)
|
|
|
|
class Transformer(nn.Module):
    """Stack of pre-norm Transformer layers as described in ViT.

    Each layer applies ``x = attention(norm(x)) + x`` and then
    ``x = mlp(norm(x)) + x``.

    Paper: https://arxiv.org/abs/2010.11929
    Based on: https://github.com/lucidrains/vit-pytorch

    Args:
        dim: input dimension.
        depth: number of stacked layers.
        heads: number of heads in the multi-head attention layer.
        dim_head: size of each attention head.
        mlp_dim: hidden dimension of the FeedForward layer.
        dropout: dropout ratio, defaults to 0.
    """
    def __init__(self, dim: int, depth: int, heads: int, dim_head: int, mlp_dim: int, dropout: float = 0.) -> None:
        super().__init__()
        # One [attention, feed-forward] pair per layer, built in layer order.
        self.layers = nn.ModuleList(
            nn.ModuleList([
                PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout)),
            ])
            for _ in range(depth)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for attention_block, mlp_block in self.layers:
            attended = attention_block(x)
            x = attended + x
            mixed = mlp_block(x)
            x = mixed + x
        return x
|
|
|
|
class MV2Block(nn.Module):
    """Inverted-residual (MV2) block from MobileNetV2.

    Optional 1x1 expansion, then a 3x3 depthwise conv carrying the stride, then
    a 1x1 projection with no trailing activation. A residual connection is added
    when the stride is 1 and the input and output channel counts match.

    Paper: https://arxiv.org/pdf/1801.04381
    Based on: https://github.com/tonylins/pytorch-mobilenet-v2

    Args:
        inp: input channels.
        oup: output channels.
        stride: stride of the depthwise conv, defaults to 1 (use 2 to down-sample).
        expansion: hidden-dimension expansion ratio, defaults to 4. With 1 the
            expansion conv is omitted.
    """
    def __init__(self, inp: int, oup: int, stride: int = 1, expansion: int = 4) -> None:
        super().__init__()
        self.stride = stride

        hidden_dim = int(inp * expansion)
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expansion != 1:
            # 1x1 expansion: inp -> hidden_dim channels.
            layers.extend([
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                SiLU(),
            ])
        layers.extend([
            # 3x3 depthwise conv (one group per channel).
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            SiLU(),
            # 1x1 linear projection to oup channels.
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        branch = self.conv(x)
        return x + branch if self.use_res_connect else branch
|
|
|
|
class MobileViTBlock(nn.Module):
    """MobileViT block described in MobileViT.

    The block has three stages:

    - Local representation: n x n conv, then a 1x1 projection to ``dim``.
    - Global representation: unfold into patches, run the Transformer, fold back.
    - Fusion: 1x1 projection back to ``channel``, concatenation with the block
      input, then an n x n conv.

    Paper: https://arxiv.org/abs/2110.02178

    Args:
        dim: input dimension of Transformer.
        depth: depth of Transformer.
        channel: input channel.
        kernel_size: kernel size.
        patch_size: (ph, pw) patch size for unfolding and folding; the feature
            map height/width must be divisible by it.
        mlp_dim: dimension of the FeedForward layer in Transformer.
        dropout: dropout ratio, defaults to 0.
        heads: number of attention heads in Transformer, defaults to 4.
        dim_head: per-head dimension in Transformer, defaults to 8.
    """
    def __init__(
        self,
        dim: int,
        depth: int,
        channel: int,
        kernel_size: int,
        patch_size: Tuple[int, int],
        mlp_dim: int,
        dropout: float = 0.,
        heads: int = 4,
        dim_head: int = 8,
    ) -> None:
        super().__init__()
        self.ph, self.pw = patch_size

        # Local representation.
        self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
        self.conv2 = conv_1x1_bn(channel, dim)

        # Global representation.
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        # Fusion.
        self.conv3 = conv_1x1_bn(dim, channel)
        self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Every stage below is out-of-place, so the input can be kept by
        # reference instead of cloned.
        residual = x

        # Local representation.
        x = self.conv1(x)
        x = self.conv2(x)

        # Global representation.
        b, d, h, w = x.shape
        ph, pw = self.ph, self.pw
        if h % ph or w % pw:
            raise ValueError(
                f"Feature map size {h}x{w} is not divisible by patch size {ph}x{pw}."
            )
        nh, nw = h // ph, w // pw
        # Unfold: b d (nh ph) (nw pw) -> b (ph pw) (nh nw) d.
        # Each pixel offset inside a patch becomes one group of nh*nw tokens.
        x = x.reshape(b, d, nh, ph, nw, pw).permute(0, 3, 5, 2, 4, 1)
        x = x.reshape(b, ph * pw, nh * nw, d)
        x = self.transformer(x)
        # Fold, the exact inverse: b (ph pw) (nh nw) d -> b d (nh ph) (nw pw).
        x = x.reshape(b, ph, pw, nh, nw, d).permute(0, 5, 3, 1, 4, 2)
        x = x.reshape(b, d, h, w)

        # Fusion.
        x = self.conv3(x)
        x = torch.cat((x, residual), 1)
        x = self.conv4(x)
        return x
|
|
|
|
class MobileViT(nn.Module):
    """Module MobileViT. The default arguments build MobileViT-XXS.

    Paper: https://arxiv.org/abs/2110.02178
    Based on: https://github.com/chinhsuanwu/mobilevit-pytorch

    Args:
        mode: 'xxs', 'xs' or 's', defaults to 'xxs'.
        in_channels: the number of channels for the input image.
        patch_size: image_size must be divisible by patch_size.
        dropout: dropout ratio in Transformer.

    Raises:
        ValueError: if ``mode`` is not one of 'xxs', 'xs' or 's'.

    Example:
        >>> img = torch.rand(1, 3, 256, 256)
        >>> mvit = MobileViT(mode='xxs')
        >>> mvit(img).shape
        torch.Size([1, 320, 8, 8])
    """
    def __init__(
        self,
        mode: str = 'xxs',
        in_channels: int = 3,
        patch_size: Tuple[int, int] = (2, 2),
        dropout: float = 0.
    ) -> None:
        super().__init__()
        if mode == 'xxs':
            expansion = 2
            dims = [64, 80, 96]
            channels = [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 320]
        elif mode == 'xs':
            expansion = 4
            dims = [96, 120, 144]
            channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384]
        elif mode == 's':
            expansion = 4
            dims = [144, 192, 240]
            channels = [16, 32, 64, 64, 96, 96, 128, 128, 160, 160, 640]
        else:
            # Fail early instead of hitting an unbound ``expansion`` below.
            raise ValueError(
                f"Unsupported MobileViT mode {mode!r}; expected 'xxs', 'xs' or 's'."
            )

        kernel_size = 3
        depth = [2, 4, 3]

        # Stride-2 stem.
        self.conv1 = conv_nxn_bn(in_channels, channels[0], stride=2)

        self.mv2 = nn.ModuleList([])
        self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))
        self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))
        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
        # Repeated block; channels[2] == channels[3] in every mode.
        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
        self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))
        self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))
        self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))

        self.mvit = nn.ModuleList([])
        self.mvit.append(MobileViTBlock(dims[0], depth[0], channels[5],
                                        kernel_size, patch_size, int(dims[0] * 2), dropout=dropout))
        self.mvit.append(MobileViTBlock(dims[1], depth[1], channels[7],
                                        kernel_size, patch_size, int(dims[1] * 4), dropout=dropout))
        self.mvit.append(MobileViTBlock(dims[2], depth[2], channels[9],
                                        kernel_size, patch_size, int(dims[2] * 4), dropout=dropout))

        # Final 1x1 projection to the output feature width.
        self.conv2 = conv_1x1_bn(channels[-2], channels[-1])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
        x = self.mv2[0](x)

        # Stage with one stride-2 MV2 block and two stride-1 blocks.
        x = self.mv2[1](x)
        x = self.mv2[2](x)
        x = self.mv2[3](x)

        # Each remaining stage: stride-2 MV2 block followed by a MobileViT block.
        x = self.mv2[4](x)
        x = self.mvit[0](x)

        x = self.mv2[5](x)
        x = self.mvit[1](x)

        x = self.mv2[6](x)
        x = self.mvit[2](x)
        x = self.conv2(x)
        return x
|
|