| """ | |
| paper: https://arxiv.org/abs/2105.15203 | |
| - ref: | |
| - encoder: | |
| - https://github.com/NVlabs/SegFormer/blob/master/mmseg/models/backbones/mix_transformer.py | |
| - https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/backbones/mit.py | |
| - decoder: | |
| - https://github.com/NVlabs/SegFormer/blob/master/mmseg/models/decode_heads/segformer_head.py | |
| - https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/segformer_head.py | |
| """ | |

import math

import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange


class MixFFN(nn.Module):
    def __init__(self, embed_dim, channels, dropout=0.0):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv1d(  # fc1 (1x1 conv == per-token linear)
                in_channels=embed_dim, out_channels=channels, kernel_size=1, stride=1
            ),
            nn.Conv1d(  # 3x3 depthwise conv; provides positional information
                in_channels=channels,
                out_channels=channels,
                kernel_size=3,
                stride=1,
                padding=1,
                groups=channels,
            ),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv1d(  # fc2
                in_channels=channels, out_channels=embed_dim, kernel_size=1
            ),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        out = x.transpose(1, 2)  # (B, N, C) -> (B, C, N) for Conv1d
        out = self.layers(out)
        out = out.transpose(1, 2)  # back to (B, N, C)
        return out
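

# A minimal usage sketch (added for illustration, not part of the original
# module): MixFFN is shape-preserving on (batch, tokens, embed_dim); the 1x1
# convs act as per-token linear layers and the depthwise conv mixes neighbors.
# Call manually; nothing runs at import time.
def _demo_mixffn():
    ffn = MixFFN(embed_dim=64, channels=256)
    x = torch.randn(2, 100, 64)  # (batch, tokens, embed_dim)
    assert ffn(x).shape == (2, 100, 64)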


class EfficientMultiheadAttention(nn.Module):
    """
    Adopts the Spatial-Reduction Attention used in PVT (Pyramid Vision
    Transformer). In the variable names below, `sr` is short for
    Spatial-Reduction.
    """
    def __init__(
        self, embed_dim, num_heads=8, attn_drop=0.0, proj_drop=0.0, sr_ratio=1
    ):
        super().__init__()
        assert (
            embed_dim % num_heads == 0
        ), f"dim {embed_dim} should be divided by num_heads {num_heads}."
        self.num_heads = num_heads
        head_dim = embed_dim // num_heads
        self.scale = head_dim**-0.5
        self.q = nn.Linear(embed_dim, embed_dim, bias=False)
        self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=False)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            # Strided conv shortens the K/V sequence by a factor of sr_ratio.
            self.sr = nn.Conv1d(
                embed_dim, embed_dim, kernel_size=sr_ratio, stride=sr_ratio
            )
            self.norm = nn.LayerNorm(embed_dim)
    def forward(self, x):
        B, N, C = x.shape
        q = self.q(x)
        q = rearrange(q, "b n (h c) -> b h n c", h=self.num_heads)
        if self.sr_ratio > 1:
            # Reduce the sequence length before computing K/V.
            x_ = self.sr(x.transpose(1, 2)).transpose(1, 2)
            x_ = self.norm(x_)
        else:
            x_ = x
        kv = self.kv(x_)
        kv = rearrange(
            kv,
            "b n (two_heads h c) -> two_heads b h n c",
            two_heads=2,
            h=self.num_heads,
        )
        k, v = kv[0], kv[1]
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, h, N, N / sr_ratio)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2)
        x = x.reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
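

# A minimal sketch (added for illustration): with sr_ratio=4 the attention
# matrix is (N x N/4) instead of (N x N), while the output keeps the input
# shape. Call manually; nothing runs at import time.
def _demo_sr_attention():
    attn = EfficientMultiheadAttention(embed_dim=64, num_heads=8, sr_ratio=4)
    x = torch.randn(2, 128, 64)  # (batch, tokens, embed_dim)
    assert attn(x).shape == (2, 128, 64)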


class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ffn_channels, dropout=0.2, sr_ratio=1):
        super().__init__()
        # Pre-norm residual sublayers: x + Attn(LN(x)), then x + FFN(LN(x)).
        self.attn = nn.Sequential(
            nn.LayerNorm(embed_dim),
            EfficientMultiheadAttention(
                embed_dim=embed_dim,
                num_heads=num_heads,
                attn_drop=dropout,
                proj_drop=dropout,
                sr_ratio=sr_ratio,
            ),
        )
        self.ffn = nn.Sequential(
            nn.LayerNorm(embed_dim),
            MixFFN(embed_dim=embed_dim, channels=ffn_channels, dropout=dropout),
        )

    def forward(self, x):
        x = x + self.attn(x)
        x = x + self.ffn(x)
        return x


class PatchEmbed(nn.Module):
    """Overlapping patch embedding: a strided Conv1d mapping (B, C_in, L)
    to a token sequence of shape (B, L', embed_dim)."""

    def __init__(
        self,
        in_channels=1,
        embed_dim=1024,
        kernel_size=7,
        stride=4,
        padding=3,
        bias=False,
    ):
        super().__init__()
        self.projection = nn.Conv1d(
            in_channels=in_channels,
            out_channels=embed_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, x: torch.Tensor):
        return self.projection(x).transpose(1, 2)  # (B, L', embed_dim)


class MiT(nn.Module):
    """MixVisionTransformer (encoder)."""

    def __init__(
        self,
        embed_dim=512,
        num_blocks=[2, 2, 6, 2],
        num_heads=[1, 2, "ceil"],
        sr_ratios=[1, 2, "ceil"],
        mlp_ratio=4,
        dropout=0.2,
    ):
        super().__init__()
        num_stages = len(num_blocks)
        # num_heads / sr_ratios are given as [base, growth, rounding]:
        # stage i uses round_func(base * growth**i).
        round_func = getattr(math, num_heads[2])  # math.ceil or math.floor
        num_heads = [
            round_func(num_heads[0] * math.pow(num_heads[1], itr))
            for itr in range(num_stages)
        ]
        round_func = getattr(math, sr_ratios[2])  # math.ceil or math.floor
        sr_ratios = [
            round_func(sr_ratios[0] * math.pow(sr_ratios[1], itr))
            for itr in range(num_stages)
        ]
        # Reversed so the strongest spatial reduction is applied at the
        # earliest (longest-sequence) stage.
        sr_ratios.reverse()
        self.embed_dims = [embed_dim * num_head for num_head in num_heads]
        patch_kernel_sizes = [7]  # [7, 3, 3, ...]
        patch_kernel_sizes.extend([3] * (num_stages - 1))
        patch_strides = [4]  # [4, 2, 2, ...]
        patch_strides.extend([2] * (num_stages - 1))
        patch_paddings = [3]  # [3, 1, 1, ...]
        patch_paddings.extend([1] * (num_stages - 1))
        in_channels = 1
        self.stages = nn.ModuleList()
        for i, num_block in enumerate(num_blocks):
            patch_embed = PatchEmbed(
                in_channels=in_channels,
                embed_dim=self.embed_dims[i],
                kernel_size=patch_kernel_sizes[i],
                stride=patch_strides[i],
                padding=patch_paddings[i],
            )
            blocks = nn.ModuleList(
                [
                    TransformerBlock(
                        embed_dim=self.embed_dims[i],
                        num_heads=num_heads[i],
                        ffn_channels=mlp_ratio * self.embed_dims[i],
                        dropout=dropout,
                        sr_ratio=sr_ratios[i],
                    )
                    for _ in range(num_block)
                ]
            )
            in_channels = self.embed_dims[i]
            norm = nn.LayerNorm(self.embed_dims[i])
            self.stages.append(nn.ModuleList([patch_embed, blocks, norm]))

    def forward(self, x):
        outs = []
        for stage in self.stages:
            x = stage[0](x)  # patch embed: (B, C, L) -> (B, L', C')
            for block in stage[1]:  # transformer blocks
                x = block(x)
            x = stage[2](x)  # norm
            x = x.transpose(1, 2)  # channel-first for the next stage / decoder
            outs.append(x)
        return outs
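

# A minimal sketch (added for illustration): the stages downsample by strides
# 4, 2, 2, 2, so a length-256 input yields a pyramid of lengths 64/32/16/8,
# each channel-first as (B, C_i, L_i). Call manually; nothing runs at import.
def _demo_mit():
    mit = MiT(
        embed_dim=8,
        num_blocks=[1, 1, 1, 1],
        num_heads=[1, 2, "ceil"],
        sr_ratios=[1, 2, "ceil"],
    )
    outs = mit(torch.randn(2, 1, 256))  # (batch, in_channels=1, length)
    assert [out.shape[2] for out in outs] == [64, 32, 16, 8]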


class SegFormer(nn.Module):
    def __init__(self, config):
        super().__init__()
        embed_dim = int(config.embed_dim)
        num_blocks = config.num_blocks
        num_heads = config.num_heads
        assert len(num_heads) == 3 and num_heads[2] in ["floor", "ceil"]
        sr_ratios = config.sr_ratios
        assert len(sr_ratios) == 3 and sr_ratios[2] in ["floor", "ceil"]
        mlp_ratio = int(config.mlp_ratio)
        dropout = float(config.dropout)
        decoder_channels = int(config.decoder_channels)
        self.interpolate_mode = str(config.interpolate_mode)
        output_size = int(config.output_size)
        self.MiT = MiT(embed_dim, num_blocks, num_heads, sr_ratios, mlp_ratio, dropout)
        num_stages = len(num_blocks)
        # All-MLP decode head: per-stage 1x1 projections to a shared width,
        # followed by fusion and a per-position classifier.
        self.decode_mlps = nn.ModuleList(
            [
                nn.Conv1d(self.MiT.embed_dims[i], decoder_channels, 1, bias=False)
                for i in range(num_stages)
            ]
        )
        self.decode_fusion = nn.Conv1d(
            decoder_channels * num_stages, decoder_channels, 1, bias=False
        )
        self.cls = nn.Conv1d(decoder_channels, output_size, 1, bias=False)

    def forward(self, input: torch.Tensor, y=None):
        output = self.MiT(input)  # multi-scale features, channel-first
        for i, (_output, decode_mlp) in enumerate(zip(output, self.decode_mlps)):
            _output = decode_mlp(_output)
            if i != 0:
                # Upsample every stage to the stage-1 (finest) resolution.
                _output = F.interpolate(
                    _output, size=output[0].shape[2], mode=self.interpolate_mode
                )
            output[i] = _output
        output = torch.cat(output, dim=1)
        output = self.decode_fusion(output)
        output = self.cls(output)
        return F.interpolate(output, size=input.shape[2], mode=self.interpolate_mode)
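

if __name__ == "__main__":
    # A minimal smoke test (added for illustration). The config field names
    # below are inferred from the attribute reads in SegFormer.__init__ and
    # are assumptions, not an original config file.
    from types import SimpleNamespace

    config = SimpleNamespace(
        embed_dim=8,
        num_blocks=[1, 1, 1, 1],
        num_heads=[1, 2, "ceil"],
        sr_ratios=[1, 2, "ceil"],
        mlp_ratio=4,
        dropout=0.1,
        decoder_channels=32,
        interpolate_mode="linear",  # 1D interpolation over (B, C, L)
        output_size=5,  # number of classes predicted per position
    )
    model = SegFormer(config)
    x = torch.randn(2, 1, 256)  # (batch, in_channels=1, length)
    print(model(x).shape)  # expected: torch.Size([2, 5, 256])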