"""
paper: https://arxiv.org/abs/1605.06211
ref: https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/voc-fcn8s/net.py
"""
import torch
import torch.nn as nn
class FCN(nn.Module):
    """1-D fully convolutional network for sequence segmentation (FCN-8s style).

    paper: https://arxiv.org/abs/1605.06211
    ref: https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/voc-fcn8s/net.py

    The encoder is ``num_layers`` conv(+pool) stages followed by a pool-free
    4096-channel "fc6/fc7"-style conv block.  Pool outputs from stage
    ``combine_until`` onwards are scored with 1x1 convs and fused with
    upsampled coarser scores (the FCN skip architecture); a final transposed
    conv upsamples back past the input resolution and the result is
    center-cropped to the exact input length.
    """

    def __init__(self, config):
        """Build the network from *config*.

        Expected config attributes (coerced to int unless noted):
        ``kernel_size``, ``last_layer_kernel_size``, ``inplanes``,
        ``num_convs``, ``dilation``, ``dropout`` (float), ``output_size``
        (number of output classes, e.g. 3 for p/qrs/t), and
        ``combine_conf``: a dict with ``"num_layers"`` (must be 4, 5 or 6)
        and ``"combine_until"``.

        Raises:
            ValueError: if ``combine_until`` is not smaller than ``num_layers``.
            KeyError: if ``num_layers`` is not one of 4, 5, 6.
        """
        super().__init__()
        self.config = config
        self.kernel_size = int(config.kernel_size)
        last_layer_kernel_size = int(config.last_layer_kernel_size)
        inplanes = int(config.inplanes)
        combine_conf: dict = config.combine_conf
        self.num_layers = int(combine_conf["num_layers"])
        # Large first-conv padding tied to the network depth, so the final
        # upsampled score map is guaranteed to cover the whole input
        # (analogous to the pad-100 trick in the original 2-D FCN).
        self.first_padding = {6: 240, 5: 130, 4: 80}[self.num_layers]
        self.num_convs = int(config.num_convs)
        self.dilation = int(config.dilation)
        self.combine_until = int(combine_conf["combine_until"])
        if self.combine_until >= self.num_layers:
            # Was a bare `assert`; raise so the check survives `python -O`.
            raise ValueError("combine_until must be smaller than num_layers")
        dropout = float(config.dropout)
        output_size = int(config.output_size)  # e.g. 3 classes: p, qrs, t
        fc_channels = 4096  # width of the pool-free "fc" conv block

        # Encoder: num_layers conv(+pool) stages; channels double per stage.
        self.layers = nn.ModuleList()
        for i in range(self.num_layers):
            self.layers.append(
                self._make_layer(
                    1 if i == 0 else inplanes * 2 ** (i - 1),
                    inplanes * 2 ** i,
                    is_first=(i == 0),
                )
            )
        # Channel count of the deepest encoder stage (the original relied on
        # the loop variable leaking out of the for-loop here).
        last_channels = inplanes * 2 ** (self.num_layers - 1)
        # Final pool-free conv block: unlike the stages above, it has a fixed
        # number of convs (2), a fixed channel width, and uses dropout.
        self.layers.append(
            nn.Sequential(
                nn.Conv1d(last_channels, fc_channels, last_layer_kernel_size),
                nn.BatchNorm1d(fc_channels),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Conv1d(fc_channels, fc_channels, 1),
                nn.BatchNorm1d(fc_channels),
                nn.ReLU(),
                nn.Dropout(dropout),
            )
        )

        # One score conv + one 2x up conv per pool output that gets combined.
        score_convs = []
        up_convs = []
        for i in range(self.combine_until, self.num_layers - 1):
            score_convs.append(
                nn.Conv1d(inplanes * 2 ** i, output_size, kernel_size=1, bias=False)
            )
            up_convs.append(
                nn.ConvTranspose1d(output_size, output_size, kernel_size=4, stride=2)
            )
        # Score conv for the pool-free final block; score_convs therefore
        # always holds exactly one more module than up_convs.
        score_convs.append(
            nn.Conv1d(fc_channels, output_size, kernel_size=1, bias=False)
        )
        # Reverse so index 0 scores the coarsest (deepest) feature map.
        score_convs.reverse()
        self.score_convs = nn.ModuleList(score_convs)
        self.up_convs = nn.ModuleList(up_convs)

        # Final upsampling back toward the input resolution; kernel is twice
        # the stride so adjacent output windows overlap smoothly.
        stride = 2 ** (self.combine_until + 1)
        self.last_up_convs = nn.ConvTranspose1d(
            output_size,
            output_size,
            kernel_size=stride * 2,
            stride=stride,
        )

    def _make_layer(
        self,
        in_channel: int,
        out_channel: int,
        is_first: bool = False,
    ) -> nn.Sequential:
        """Return ``num_convs`` conv-BN-ReLU units followed by a /2 max-pool.

        The very first conv of the network receives the large
        ``first_padding``; every other conv is "same"-padded for its dilation.
        """
        same_padding = (self.dilation * (self.kernel_size - 1)) // 2
        layer = []
        plane = in_channel
        for idx in range(self.num_convs):
            padding = self.first_padding if (idx == 0 and is_first) else same_padding
            layer.append(
                nn.Conv1d(
                    plane,
                    out_channel,
                    kernel_size=self.kernel_size,
                    padding=padding,
                    dilation=self.dilation,
                    bias=False,
                )
            )
            layer.append(nn.BatchNorm1d(out_channel))
            layer.append(nn.ReLU())
            plane = out_channel
        # ceil_mode so odd-length maps are not truncated when halving.
        layer.append(nn.MaxPool1d(2, 2, ceil_mode=True))
        return nn.Sequential(*layer)

    @staticmethod
    def _center_crop(tensor: torch.Tensor, length: int) -> torch.Tensor:
        """Symmetrically crop a (N, C, L) tensor along dim 2 to ``length``."""
        offset = (tensor.shape[2] - length) // 2
        return tensor[:, :, offset : offset + length]

    def forward(self, input: torch.Tensor, y=None) -> torch.Tensor:
        """Compute per-position class scores.

        Args:
            input: tensor of shape (N, 1, L) — assumes a single-channel
                1-D signal (first encoder conv has in_channels=1).
            y: accepted for interface compatibility and ignored.

        Returns:
            Tensor of shape (N, output_size, L), center-cropped to match
            the input length exactly.
        """
        output: torch.Tensor = input
        pools = []
        for idx, layer in enumerate(self.layers):
            output = layer(output)
            # Keep the pool outputs that will be fused with upsampled scores.
            if self.combine_until <= idx < (self.num_layers - 1):
                pools.append(output)
        pools.reverse()  # fuse from coarsest to finest

        output = self.score_convs[0](output)
        if pools:
            output = self.up_convs[0](output)
        for i, pool in enumerate(pools):
            score_pool = self.score_convs[i + 1](pool)
            # Skip connection: crop the finer score map to the current
            # resolution and add it in.
            output = torch.add(self._center_crop(score_pool, output.shape[2]), output)
            if i < len(pools) - 1:  # the final upsampling uses last_up_convs
                output = self.up_convs[i + 1](output)
        output = self.last_up_convs(output)
        # Crop back to exactly the input length.
        return self._center_crop(output, input.shape[2])