Spaces:
Sleeping
Sleeping
| """ | |
| paper: https://arxiv.org/abs/1802.02611 | |
| ref: | |
| - https://github.com/tensorflow/models/tree/master/research/deeplab | |
| - https://github.com/VainF/DeepLabV3Plus-Pytorch | |
| - https://github.com/Hyunjulie/KR-Reading-Computer-Vision-Papers/blob/master/DeepLabv3%2B/deeplabv3p.py | |
| """ | |
| import math | |
| import torch | |
| from torch import nn | |
| from torch.functional import F | |
class AtrousSeparableConv1d(nn.Module):
    """Depthwise-separable 1-D convolution with optional dilation.

    A per-channel (depthwise) convolution is followed by a 1x1 (pointwise)
    convolution that mixes channels. Zero padding is applied explicitly in
    ``forward`` rather than through the convolution's own ``padding``
    argument.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels.
        kernel_size: Kernel size of the depthwise convolution.
        stride: Stride of the depthwise convolution.
        dilation: Dilation (atrous rate) of the depthwise convolution.
        bias: Whether both convolutions carry a bias term.
    """

    def __init__(
        self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False
    ):
        super().__init__()
        # groups == in_channels: each input channel is filtered on its own.
        self.depthwise = nn.Conv1d(
            in_channels=inplanes,
            out_channels=inplanes,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=inplanes,
            bias=bias,
        )
        # 1x1 convolution that recombines the depthwise outputs.
        self.pointwise = nn.Conv1d(
            in_channels=inplanes,
            out_channels=planes,
            kernel_size=1,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=bias,
        )

    def forward(self, x):
        """Pad ``x``, then apply the depthwise and pointwise convolutions."""
        padded = self.apply_fixed_padding(
            x, self.depthwise.kernel_size[0], rate=self.depthwise.dilation[0]
        )
        return self.pointwise(self.depthwise(padded))

    def apply_fixed_padding(self, inputs, kernel_size, rate):
        """Zero-pad the last dimension of ``inputs`` for a dilated convolution.

        The amount of padding depends only on ``kernel_size`` and the
        dilation ``rate`` and is chosen so that a stride-1 convolution
        returns an output as long as its input. With a stride of 2 or more
        the output is still shorter than the input; the padding then only
        keeps the sizes as close as possible, and the final upsampling step
        of the full network restores the original length.

        Args:
            inputs: Tensor whose last dimension is padded.
            kernel_size: Kernel size of the convolution that follows.
            rate: Dilation rate of the convolution that follows.

        Returns:
            The padded tensor. When the total padding is odd, the extra
            element goes to the end.
        """
        effective_kernel = kernel_size + (kernel_size - 1) * (rate - 1)
        total = effective_kernel - 1
        left = total // 2
        right = total - left
        return F.pad(inputs, (left, right))
class Block(nn.Module):
    """Residual stack of ReLU -> separable conv -> BatchNorm units.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels.
        reps: Number of (ReLU, separable conv, BatchNorm) units.
        kernel_size: Kernel size of every separable convolution.
        stride: 1 or 2. With 2, a trailing stride-2 separable convolution
            is appended to the stack.
        dilation: Dilation used by the repeated units. The trailing
            convolution, if any, uses the default dilation of 1.
        start_with_relu: If False, the first ReLU of the stack is dropped.
        grow_first: If True, the channel change ``inplanes -> planes``
            happens in the first unit, otherwise in the last one.
        is_last: With ``stride == 1``, append a trailing stride-1 separable
            convolution.

    Raises:
        NotImplementedError: If ``stride`` is neither 1 nor 2.
    """

    def __init__(
        self,
        inplanes,
        planes,
        reps,
        kernel_size=3,
        stride=1,
        dilation=1,
        start_with_relu=True,
        grow_first=True,
        is_last=False,
    ):
        super().__init__()

        # The identity shortcut only works when neither the channel count
        # nor the length changes; otherwise project with a strided 1x1 conv.
        identity_shortcut = planes == inplanes and stride == 1
        if identity_shortcut:
            self.skip = None
        else:
            self.skip = nn.Conv1d(inplanes, planes, 1, stride=stride, bias=False)
            self.skipbn = nn.BatchNorm1d(planes)

        # NOTE(review): this single in-place ReLU is reused for every unit.
        # When the stack starts with it, it overwrites the block input
        # before ``forward`` reads that same tensor for the skip branch.
        self.relu = nn.ReLU(inplace=True)

        # (in_channels, out_channels) of each unit, in execution order.
        if grow_first:
            channel_plan = [(inplanes, planes)] + [(planes, planes)] * (reps - 1)
        else:
            channel_plan = [(inplanes, inplanes)] * (reps - 1) + [(inplanes, planes)]

        layers = []
        for c_in, c_out in channel_plan:
            layers.append(self.relu)
            layers.append(
                AtrousSeparableConv1d(
                    c_in, c_out, kernel_size, stride=1, dilation=dilation
                )
            )
            layers.append(nn.BatchNorm1d(c_out))

        if not start_with_relu:
            del layers[:1]

        # Decide whether a trailing separable convolution is appended.
        if stride == 2:
            tail_stride = 2
        elif stride == 1:
            tail_stride = 1 if is_last else None
        else:
            raise NotImplementedError("stride must be 1 or 2 in Block.")
        if tail_stride is not None:
            layers.append(
                AtrousSeparableConv1d(planes, planes, kernel_size, stride=tail_stride)
            )

        self.rep = nn.Sequential(*layers)

    def forward(self, inp):
        """Return ``rep(inp)`` plus the (optionally projected) shortcut."""
        out = self.rep(inp)
        if self.skip is None:
            shortcut = inp
        else:
            shortcut = self.skipbn(self.skip(inp))
        out += shortcut
        return out
class Xception(nn.Module):
    """Modified Aligned Xception backbone for 1-D inputs.

    Args:
        inplanes: Number of input channels.
        output_stride: 16 or 8. Selects the stride of the third entry
            block (2 for 16, 1 for 8).
        kernel_size: Kernel size shared by all non-1x1 convolutions.
        middle_repeat: Number of blocks in the middle flow.
        middle_block_rate: Dilation used inside the middle-flow blocks.
        exit_block_rates: Pair of dilations; the first is used by the exit
            block, the second by the three final separable convolutions.

    Raises:
        NotImplementedError: If ``output_stride`` is neither 16 nor 8.
    """

    def __init__(
        self,
        inplanes=1,
        output_stride=16,
        kernel_size=3,
        middle_repeat=16,
        middle_block_rate=1,
        exit_block_rates=(1, 2),
    ):
        super().__init__()

        if output_stride not in (16, 8):
            raise NotImplementedError
        entry3_stride = 2 if output_stride == 16 else 1

        stem_padding = (kernel_size - 1) // 2
        exit_block_rate, final_rate = exit_block_rates[0], exit_block_rates[1]

        # Stem: two plain convolutions, the first one halving the length.
        self.conv1 = nn.Conv1d(
            inplanes,
            32,
            kernel_size,
            stride=2,
            padding=stem_padding,
            bias=False,
        )
        self.bn1 = nn.BatchNorm1d(32)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(
            32, 64, kernel_size, stride=1, padding=stem_padding, bias=False
        )
        self.bn2 = nn.BatchNorm1d(64)

        # Entry flow.
        self.entry1 = Block(
            64, 128, reps=2, kernel_size=kernel_size, stride=2, start_with_relu=False
        )
        self.entry2 = Block(
            128,
            256,
            reps=2,
            kernel_size=kernel_size,
            stride=2,
            start_with_relu=True,
            grow_first=True,
        )
        self.entry3 = Block(
            256,
            728,
            reps=2,
            kernel_size=kernel_size,
            stride=entry3_stride,
            start_with_relu=True,
            grow_first=True,
            is_last=True,
        )

        # Middle flow: ``middle_repeat`` identical residual blocks.
        middle_blocks = []
        for _ in range(middle_repeat):
            middle_blocks.append(
                Block(
                    728,
                    728,
                    reps=3,
                    kernel_size=kernel_size,
                    stride=1,
                    dilation=middle_block_rate,
                    start_with_relu=True,
                    grow_first=True,
                )
            )
        self.middle = nn.Sequential(*middle_blocks)

        # Exit flow.
        self.exit = Block(
            728,
            1024,
            reps=2,
            kernel_size=kernel_size,
            stride=1,
            dilation=exit_block_rate,
            start_with_relu=True,
            grow_first=False,
            is_last=True,
        )
        self.conv3 = AtrousSeparableConv1d(
            1024, 1536, kernel_size, stride=1, dilation=final_rate
        )
        self.bn3 = nn.BatchNorm1d(1536)
        self.conv4 = AtrousSeparableConv1d(
            1536, 1536, kernel_size, stride=1, dilation=final_rate
        )
        self.bn4 = nn.BatchNorm1d(1536)
        self.conv5 = AtrousSeparableConv1d(
            1536, 2048, kernel_size, stride=1, dilation=final_rate
        )
        self.bn5 = nn.BatchNorm1d(2048)

    def forward(self, x: torch.Tensor):
        """Run the backbone.

        Returns:
            A tuple ``(features, low_level)``: the 2048-channel output of
            the exit flow and the 128-channel output of the first entry
            block.
        """
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = self.relu(bn(conv(x)))

        x = self.entry1(x)
        # NOTE(review): ``low_level`` aliases the tensor handed to entry2,
        # whose stack starts with an in-place ReLU.
        low_level = x
        x = self.entry2(x)
        x = self.entry3(x)
        x = self.middle(x)
        x = self.exit(x)

        final_stages = (
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
            (self.conv5, self.bn5),
        )
        for conv, bn in final_stages:
            x = self.relu(bn(conv(x)))
        return x, low_level
class ASPP(nn.Module):
    """One Atrous Spatial Pyramid Pooling branch: dilated conv -> BN -> ReLU.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels.
        rate: Dilation of the convolution. A rate of 1 turns the branch
            into a plain 1x1 convolution regardless of ``kernel_size``.
        kernel_size: Kernel size used when ``rate`` is not 1.
    """

    def __init__(self, inplanes, planes, rate, kernel_size=3):
        super().__init__()
        kernel_size, padding = (
            (1, 0) if rate == 1 else (kernel_size, rate * (kernel_size - 1) // 2)
        )
        self.atrous_convolution = nn.Conv1d(
            inplanes,
            planes,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            dilation=rate,
            bias=False,
        )
        self.bn = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply the dilated convolution, batch norm and ReLU to ``x``."""
        return self.relu(self.bn(self.atrous_convolution(x)))
class DeepLabV3Plus(nn.Module):
    """1-D DeepLabV3+ network: Xception backbone, ASPP head and decoder.

    Args:
        config: Object exposing the following attributes:
            ``output_stride``, ``kernel_size``, ``middle_block_rate``,
            ``exit_block_rates``, ``middle_repeat`` (forwarded to the
            backbone); ``interpolate_mode`` (mode string passed to
            ``F.interpolate``); ``aspp_channel`` (channels per ASPP
            branch); ``aspp_rate`` (sequence whose first four entries are
            the ASPP dilation rates); ``output_size`` (number of output
            channels, e.g. 3 for (p, qrs, t) per the original comment).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # output_stride: (input's spatial resolution / output's resolution)
        output_stride = int(config.output_stride)
        kernel_size = int(config.kernel_size)
        middle_block_rate = int(config.middle_block_rate)
        exit_block_rates: list = config.exit_block_rates
        middle_repeat = int(config.middle_repeat)
        self.interpolate_mode = str(config.interpolate_mode)
        aspp_channel = int(config.aspp_channel)
        aspp_rate: list = config.aspp_rate
        output_size = config.output_size  # 3(p, qrs, t)

        self.xception_features = Xception(
            output_stride=output_stride,
            kernel_size=kernel_size,
            middle_repeat=middle_repeat,
            middle_block_rate=middle_block_rate,
            exit_block_rates=exit_block_rates,
        )

        # ASPP: four parallel dilated branches on the 2048-channel features.
        self.aspp1 = ASPP(
            2048, aspp_channel, rate=aspp_rate[0], kernel_size=kernel_size
        )
        self.aspp2 = ASPP(
            2048, aspp_channel, rate=aspp_rate[1], kernel_size=kernel_size
        )
        self.aspp3 = ASPP(
            2048, aspp_channel, rate=aspp_rate[2], kernel_size=kernel_size
        )
        self.aspp4 = ASPP(
            2048, aspp_channel, rate=aspp_rate[3], kernel_size=kernel_size
        )
        self.relu = nn.ReLU()

        # Fifth branch: global average pooling followed by a 1x1 projection.
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(2048, aspp_channel, 1, stride=1, bias=False),
            nn.BatchNorm1d(aspp_channel),
            nn.ReLU(),
        )

        # Fuse the five concatenated branches back to ``aspp_channel``.
        self.conv1 = nn.Conv1d(aspp_channel * 5, aspp_channel, 1, bias=False)
        self.bn1 = nn.BatchNorm1d(aspp_channel)

        # adopt [1x1, 48] for channel reduction.
        self.conv2 = nn.Conv1d(128, 48, 1, bias=False)
        self.bn2 = nn.BatchNorm1d(48)

        same_padding = (kernel_size - 1) // 2
        self.last_conv = nn.Sequential(
            nn.Conv1d(
                aspp_channel + 48,
                256,
                kernel_size=kernel_size,
                stride=1,
                padding=same_padding,
                bias=False,
            ),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Conv1d(
                256,
                256,
                kernel_size=kernel_size,
                stride=1,
                padding=same_padding,
                bias=False,
            ),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Conv1d(256, output_size, kernel_size=1, stride=1),
        )

    def forward(self, input):
        """Compute per-position outputs with the same length as ``input``.

        Args:
            input: Tensor fed to the backbone; its trailing dimension(s)
                (``input.shape[2:]``) define the output length.

        Returns:
            Tensor with ``output_size`` channels, interpolated back to
            ``input.shape[2:]``.
        """
        x, low_level_features = self.xception_features(input)

        # Encoder head: four ASPP branches plus the pooled global branch.
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        x5 = F.interpolate(x5, size=x4.shape[2:], mode=self.interpolate_mode)

        x = torch.cat((x1, x2, x3, x4, x5), dim=1)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        # Upsample to the actual length of the low-level features instead
        # of a hard-coded ceil(L / 4). The two agree for odd kernel sizes,
        # but reading the real length keeps the concatenation below valid
        # however the backbone rounds its output length.
        x = F.interpolate(
            x, size=low_level_features.shape[2:], mode=self.interpolate_mode
        )

        low_level_features = self.conv2(low_level_features)
        low_level_features = self.bn2(low_level_features)
        low_level_features = self.relu(low_level_features)

        # Decoder: fuse both feature streams, refine, restore input length.
        x = torch.cat((x, low_level_features), dim=1)
        x = self.last_conv(x)
        x = F.interpolate(x, size=input.shape[2:], mode=self.interpolate_mode)
        return x