# NOTE(review): the following three lines were Hugging Face Spaces page residue
# ("Spaces / Sleeping / Sleeping") left over from extraction; kept as a comment.
| """ | |
| paper: https://arxiv.org/abs/2012.15840 | |
| - ref | |
| - encoder: | |
| - https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/backbones/vit.py | |
| - https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit_1d.py | |
| - decoder: | |
| - https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/models/decode_heads/setr_up_head.py | |
| - https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/models/decode_heads/setr_mla_head.py | |
| - encoder: ViT ์ ๊ตฌ์กฐ๊ฐ ๋์ผํ๋ฉฐ, PatchEmbed ์ ๊ฒฝ์ฐ patch_size๋ฅผ kernel_size์ stride ๋ก ํ๋ Conv1d๋ฅผ ์ฌ์ฉ | |
| - decoder: upsample ํ๋ ๋ฐฉ์์ผ๋ก ๋ค์ ๋๊ฐ์ง๋ฅผ ์ฌ์ฉ (scale_factor: ํน์ ๋ฐฐ์๋งํผ upsample / size: ํน์ ํฌ๊ธฐ์ ๋์ผํ ํฌ๊ธฐ๋ก upsample) | |
| - naive: ์๋ณธ ๊ธธ์ด๋ก size ๋ฐฉ์ upsample | |
| - pup: scale_factor ๋ฐฉ์์ผ๋ก ์ํํ๋ค๊ฐ ๋ง์ง๋ง์ ์๋ณธ ๊ธธ์ด๋ก size ๋ฐฉ์์ผ๋ก upsample | |
| - mla: ์ด ๋ ๋จ๊ณ๋ก ์ํํ๋ฉฐ, ์ฒซ๋ฒ์งธ ๋จ๊ณ์์ transformer block ์ ๊ฒฐ๊ณผ๋ค์ scale_factor ๋ฐฉ์์ผ๋ก ์ํํ๊ณ ๋๋ฒ์งธ ๋จ๊ณ์์ ์ฒซ๋ฒ์งธ ๊ฒฐ๊ณผ๋ค์ concat ํ ํ size ๋ฐฉ์์ผ๋ก upsample | |
| """ | |
| import math | |
| import torch | |
| from torch import nn | |
| from einops import rearrange | |
| class FeedForward(nn.Module): | |
| def __init__(self, dim, hidden_dim, dropout=0.0): | |
| super().__init__() | |
| self.net = nn.Sequential( | |
| nn.LayerNorm(dim), | |
| nn.Linear(dim, hidden_dim), | |
| nn.GELU(), | |
| nn.Dropout(dropout), | |
| nn.Linear(hidden_dim, dim), | |
| nn.Dropout(dropout), | |
| ) | |
| def forward(self, x): | |
| return self.net(x) | |
class Attention(nn.Module):
    """Pre-norm multi-head self-attention: softmax(Q K^T / sqrt(d)) V."""

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        # With a single head whose width equals the model dim, the output
        # projection would be redundant, so it is skipped entirely.
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head**-0.5

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = (
            nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
            if project_out
            else nn.Identity()
        )

    def forward(self, x):
        x = self.norm(x)
        b, n, _ = x.shape

        # Project to Q, K, V, then split heads: (b, n, h*d) -> (b, h, n, d).
        q, k, v = (
            t.reshape(b, n, self.heads, -1).transpose(1, 2)
            for t in self.to_qkv(x).chunk(3, dim=-1)
        )

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))

        out = torch.matmul(weights, v)
        # Merge heads back: (b, h, n, d) -> (b, n, h*d).
        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)
# ========== Code above adapted from https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit_1d.py ==========
# ========== Code below references the original SETR: https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/backbones/vit.py ==========
class TransformerBlock(nn.Module):
    """One pre-norm transformer layer: self-attention then MLP, each with a residual."""

    def __init__(
        self,
        dim,
        num_attn_heads,
        attn_head_dim,
        mlp_dim,
        attn_dropout=0.0,
        ffn_dropout=0.0,
    ):
        super().__init__()
        # Sub-modules normalize their own input (pre-norm), so no extra
        # LayerNorm is needed here.
        self.attn = Attention(
            dim, heads=num_attn_heads, dim_head=attn_head_dim, dropout=attn_dropout
        )
        self.ffn = FeedForward(dim, mlp_dim, dropout=ffn_dropout)

    def forward(self, x):
        # Residual around attention, then residual around the feed-forward net.
        attended = self.attn(x)
        x = attended + x
        return self.ffn(x) + x
class PatchEmbed(nn.Module):
    """Split a single-channel 1D signal into non-overlapping patches and embed each.

    Implemented as a Conv1d whose kernel_size equals its stride (the patch
    size), so every output position is the linear embedding of one patch.
    """

    def __init__(self, embed_dim=1024, kernel_size=16, bias=False):
        super().__init__()
        self.projection = nn.Conv1d(
            in_channels=1,
            out_channels=embed_dim,
            kernel_size=kernel_size,
            stride=kernel_size,
            bias=bias,
        )

    def forward(self, x: torch.Tensor):
        # (batch, 1, length) -> (batch, embed_dim, num_patches)
        embedded = self.projection(x)
        # -> (batch, num_patches, embed_dim) for the transformer encoder
        return embedded.transpose(1, 2)
class SETR(nn.Module):
    """SETR (SEgmentation TRansformer, arXiv:2012.15840) adapted to 1D signals.

    A ViT-style encoder (PatchEmbed + positional embedding + transformer
    blocks) followed by one of three upsampling decoders selected by
    ``config.dec_conf`` (a one-key dict: mode name -> mode parameters):

    - ``naive``: a single size-mode upsample back to the input length.
    - ``pup``:   LayerNorm, then conv + scale_factor upsampling stages,
                 finishing with a size-mode upsample to the input length.
    - ``mla``:   conv + scale_factor upsampling of intermediate transformer
                 outputs, channel-wise concat, then a final size-mode upsample.

    A 1x1 Conv1d head maps decoder channels to ``config.output_size`` classes.
    """

    def __init__(self, config):
        super().__init__()
        embed_dim = int(config.embed_dim)
        data_len = int(config.data_len)  # must match ECGPQRSTDataset.second * hz
        patch_size = int(config.patch_size)
        assert data_len % patch_size == 0
        num_patches = data_len // patch_size
        patch_bias = bool(config.patch_bias)
        dropout = float(config.dropout)
        # pos_dropout_p: float = config.pos_dropout_p  # too many hyperparameters, so a single shared dropout value is used for now
        num_layers = int(config.num_layers)  # number of transformer blocks
        num_attn_heads = int(config.num_attn_heads)
        attn_head_dim = int(config.attn_head_dim)
        mlp_dim = int(config.mlp_dim)
        # attn_dropout: float = config.attn_dropout
        # ffn_dropout: float = config.ffn_dropout
        interpolate_mode = str(config.interpolate_mode)
        dec_conf: dict = config.dec_conf
        assert len(dec_conf) == 1
        self.dec_mode: str = list(dec_conf.keys())[0]
        assert self.dec_mode in ["naive", "pup", "mla"]
        self.dec_param: dict = dec_conf[self.dec_mode]
        output_size = int(config.output_size)
        # patch embedding
        self.patch_embed = PatchEmbed(
            embed_dim=embed_dim,
            kernel_size=patch_size,
            bias=patch_bias,
        )
        # positional embedding (learned, one vector per patch position)
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, embed_dim))
        self.pos_dropout = nn.Dropout(p=dropout)
        # transformer encoder
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                TransformerBlock(
                    dim=embed_dim,
                    num_attn_heads=num_attn_heads,
                    attn_head_dim=attn_head_dim,
                    mlp_dim=mlp_dim,
                    attn_dropout=dropout,
                    ffn_dropout=dropout,
                )
            )
        # decoder
        self.dec_layers = nn.ModuleList()
        if self.dec_mode == "naive":
            # single size-mode upsample straight back to the input length
            self.dec_layers.append(nn.Upsample(size=data_len, mode=interpolate_mode))
            dec_out_channel = embed_dim
        elif self.dec_mode == "pup":
            # dec_layers[0] is a LayerNorm; forward() applies it specially (transposed)
            self.dec_layers.append(nn.LayerNorm(embed_dim))
            dec_up_scale = int(self.dec_param["up_scale"])
            available_up_count = int(
                math.log(data_len // num_patches, dec_up_scale)
            )  # number of stages that can upsample by scale_factor; the remainder is covered by a final size-mode upsample
            pup_channels = int(self.dec_param["channels"])
            dec_in_channel = embed_dim
            dec_out_channel = pup_channels
            dec_kernel_size = int(self.dec_param["kernel_size"])
            dec_num_convs_by_layer = int(self.dec_param["num_convs_by_layer"])
            assert dec_kernel_size in [1, 3]  # same restriction as the original SETR code
            for i in range(available_up_count + 1):
                # conv stack, then one upsample per stage
                for _ in range(dec_num_convs_by_layer):
                    self.dec_layers.append(
                        nn.Conv1d(
                            dec_in_channel,
                            dec_out_channel,
                            kernel_size=dec_kernel_size,
                            stride=1,
                            padding=(dec_kernel_size - 1) // 2,
                        )
                    )
                    dec_in_channel = dec_out_channel
                if i < available_up_count:
                    self.dec_layers.append(
                        nn.Upsample(scale_factor=dec_up_scale, mode=interpolate_mode)
                    )
                else:  # last upsample: go exactly to the input length
                    self.dec_layers.append(
                        nn.Upsample(size=data_len, mode=interpolate_mode)
                    )
        else:  # mla
            dec_up_scale = int(self.dec_param["up_scale"])
            assert (
                data_len >= dec_up_scale * num_patches
            )  # intermediate features upsampled by up_scale must still be no longer than the input, otherwise the final upsample is meaningless
            dec_output_step = int(self.dec_param["output_step"])
            assert num_layers % dec_output_step == 0
            dec_num_convs_by_layer = int(self.dec_param["num_convs_by_layer"])
            dec_kernel_size = int(self.dec_param["kernel_size"])
            mid_feature_cnt = num_layers // dec_output_step
            mla_channel = int(self.dec_param["channels"])
            for _ in range(mid_feature_cnt):
                # conv + upsample branch applied to the feature map extracted every output_step transformer blocks
                dec_in_channel = embed_dim
                dec_layers_each_upsample = []
                for _ in range(dec_num_convs_by_layer):
                    dec_layers_each_upsample.append(
                        nn.Conv1d(
                            dec_in_channel,
                            mla_channel,
                            kernel_size=dec_kernel_size,
                            stride=1,
                            padding=(dec_kernel_size - 1) // 2,
                        )
                    )
                    dec_in_channel = mla_channel
                dec_layers_each_upsample.append(
                    nn.Upsample(scale_factor=dec_up_scale, mode=interpolate_mode)
                )
                self.dec_layers.append(nn.Sequential(*dec_layers_each_upsample))
            # last decoder layer: upsample after concatenating the intermediate feature maps
            self.dec_layers.append(nn.Upsample(size=data_len, mode=interpolate_mode))
            dec_out_channel = (
                mla_channel * mid_feature_cnt
            )  # the branch outputs (one per mid_feature_cnt intermediate feature map) are concatenated channel-wise, so the grown channel count becomes the in_channel of self.cls
        # classification head: per-position class logits
        self.cls = nn.Conv1d(dec_out_channel, output_size, 1, bias=False)

    def forward(self, input: torch.Tensor, y=None):
        """Run encoder + decoder.

        Args:
            input: assumed (batch, 1, data_len) — TODO confirm against caller.
            y: unused here (presumably accepted for a trainer interface).

        Returns:
            (batch, output_size, data_len) per-position class logits.
        """
        output = input
        # patch embedding: -> (batch, num_patches, embed_dim)
        output = self.patch_embed(output)
        # positional embedding (in-place add), then dropout
        output += self.pos_embed
        output = self.pos_dropout(output)
        outputs = []
        # transformer encoder; outputs holds (batch, embed_dim, num_patches) features
        for i, layer in enumerate(self.layers):
            output = layer(output)
            if self.dec_mode == "mla":
                # collect every output_step-th intermediate feature map
                if (i + 1) % int(self.dec_param["output_step"]) == 0:
                    outputs.append(output.transpose(1, 2))
        if self.dec_mode != "mla":  # for mla, outputs were already collected above
            outputs.append(output.transpose(1, 2))
        # decoder
        if self.dec_mode == "naive":
            assert len(outputs) == 1
            output = outputs[0]
            output = self.dec_layers[0](output)
        elif self.dec_mode == "pup":
            assert len(outputs) == 1
            output = outputs[0]
            # dec_layers[0] is the LayerNorm; it normalizes over embed_dim,
            # so transpose to (batch, num_patches, embed_dim) and back
            pup_norm = self.dec_layers[0]
            output = pup_norm(output.transpose(1, 2)).transpose(1, 2)
            for i, dec_layer in enumerate(self.dec_layers[1:]):
                output = dec_layer(output)
        else:  # mla
            dec_output_step = int(self.dec_param["output_step"])
            mid_feature_cnt = len(self.layers) // dec_output_step
            assert len(outputs) == mid_feature_cnt
            # one conv-upsample branch per intermediate feature map
            for i in range(len(outputs)):
                outputs[i] = self.dec_layers[i](outputs[i])
            # channel-wise concat, then the final size-mode upsample
            output = torch.cat(outputs, dim=1)
            output = self.dec_layers[-1](output)
        return self.cls(output)