Spaces:
Sleeping
Sleeping
| """ | |
| paper: https://arxiv.org/abs/1612.01105 | |
| ref: | |
| - https://github.com/hszhao/PSPNet | |
| """ | |
| import torch | |
| from torch import nn | |
| from torch.functional import F | |
class PPM(nn.Module):
    """Pyramid Pooling Module for 1D feature maps.

    One branch is built per entry in ``bins``. Each branch adaptively
    average-pools the input down to that many positions, projects the
    channels to ``reduction_dim`` with a 1x1 convolution followed by batch
    normalisation and ReLU, and is then resized back to the input length.

    Args:
        in_dim: Number of channels of the incoming feature map.
        reduction_dim: Number of channels produced by every pooling branch.
        bins: Iterable of output lengths for the adaptive average pooling.
        interplate_mode: ``mode`` argument forwarded to ``F.interpolate``
            when resizing each branch back to the input length.
    """

    def __init__(self, in_dim, reduction_dim, bins, interplate_mode):
        super().__init__()

        def make_branch(pool_size):
            # pool -> 1x1 conv -> BN -> ReLU, identical for every pyramid level
            return nn.Sequential(
                nn.AdaptiveAvgPool1d(pool_size),
                nn.Conv1d(in_dim, reduction_dim, kernel_size=1, bias=False),
                nn.BatchNorm1d(reduction_dim),
                nn.ReLU(),
            )

        self.features = nn.ModuleList([make_branch(size) for size in bins])
        self.interplate_mode = interplate_mode

    def forward(self, x: torch.Tensor):
        """Concatenate ``x`` with every resized pooling branch.

        Args:
            x: Feature map indexed as (batch, channels, length); the size of
                dimension 2 is used as the target length for resizing.

        Returns:
            Tensor with ``in_dim + len(bins) * reduction_dim`` channels
            (concatenation along dimension 1) and the same length as ``x``.
        """
        target_length = x.size(2)
        pyramid = [
            F.interpolate(branch(x), target_length, mode=self.interplate_mode)
            for branch in self.features
        ]
        return torch.cat([x, *pyramid], dim=1)
class Bottleneck(nn.Module):
    """Residual bottleneck block built from 1D convolutions.

    The main path is a 1x1 convolution to ``planes`` channels, a
    ``kernel_size`` convolution (carrying the stride, dilation and padding),
    and a 1x1 convolution to ``planes * expansion`` channels. Each
    convolution is followed by batch normalisation. The result is added to
    the shortcut and passed through ReLU.

    Args:
        inplanes: Number of input channels.
        planes: Number of channels inside the bottleneck.
        expansion: Multiplier applied to ``planes`` for the output channels.
        kernel_size: Kernel size of the middle convolution.
        stride: Stride of the middle convolution.
        dilation: Dilation of the middle convolution.
        padding: Padding of the middle convolution.
        downsample: Optional module applied to the input to produce the
            shortcut; when ``None`` the input itself is the shortcut, so it
            must already match the main path's output shape.
    """

    def __init__(
        self,
        inplanes,
        planes,
        expansion=4,
        kernel_size=3,
        stride=1,
        dilation=1,
        padding=1,
        downsample=None,
    ):
        super().__init__()
        out_planes = planes * expansion

        # 1x1 channel reduction
        self.conv1 = nn.Conv1d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(planes)

        # the only convolution with spatial extent / stride / dilation
        self.conv2 = nn.Conv1d(
            planes,
            planes,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding,
            bias=False,
        )
        self.bn2 = nn.BatchNorm1d(planes)

        # 1x1 channel expansion
        self.conv3 = nn.Conv1d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm1d(out_planes)

        self.relu = nn.ReLU()
        self.downsample = downsample

    def forward(self, x):
        """Run the main path, add the shortcut, and apply the final ReLU."""
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        shortcut = x if self.downsample is None else self.downsample(x)
        y += shortcut
        return self.relu(y)
class PSPNet(nn.Module):
    """1D PSPNet: a residual encoder, pyramid pooling, and a per-position head.

    The network consists of a stem, ``num_layers`` stages of ``Bottleneck``
    blocks, a ``PPM`` over the last stage's features, and a classification
    head whose output is resized back to the input length. An auxiliary head
    is attached to the output of stage ``aux_idx``; in training mode its
    prediction is blended into the main prediction with weight ``aux_ratio``.

    Args:
        config: Object exposing the attributes ``kernel_size``, ``expansion``,
            ``inplanes``, ``num_layers``, ``num_bottlenecks``,
            ``interpolate_mode``, ``dilation``, ``ppm_bins``, ``aux_idx``,
            ``aux_ratio``, ``dropout`` and ``output_size``.

    Raises:
        ValueError: If ``aux_idx`` is not in ``[0, num_layers)`` or
            ``ppm_bins`` is empty.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.kernel_size = int(config.kernel_size)
        self.padding = (self.kernel_size - 1) // 2
        self.expansion = int(config.expansion)
        self.inplanes = int(config.inplanes)
        num_layers = int(config.num_layers)
        self.num_bottlenecks = int(config.num_bottlenecks)
        self.interpolate_mode = str(config.interpolate_mode)
        self.dilation = int(config.dilation)
        ppm_bins = list(config.ppm_bins)
        self.aux_idx = int(config.aux_idx)
        # Explicit check instead of `assert`: asserts vanish under `python -O`,
        # and a negative index would never be matched by the loop in forward().
        if not 0 <= self.aux_idx < num_layers:
            raise ValueError(
                f"aux_idx must be in [0, {num_layers}), got {self.aux_idx}"
            )
        if not ppm_bins:
            raise ValueError("ppm_bins must contain at least one bin size")
        self.aux_ratio = float(config.aux_ratio)
        dropout = float(config.dropout)
        output_size = config.output_size  # 3 (p, qrs, t)

        # The stem reduces the length to roughly 1/4 before the first stage
        # (stride-2 convolution followed by stride-2 max pooling).
        self.stem = nn.Sequential(
            nn.Conv1d(
                1,
                self.inplanes,
                self.kernel_size,
                stride=2,
                padding=self.padding,
                bias=False,
            ),
            nn.BatchNorm1d(self.inplanes),
            nn.ReLU(),
            nn.MaxPool1d(self.kernel_size, stride=2, padding=self.padding),
        )

        # Stage i uses `plane * 2**i` bottleneck channels. `_make_layer`
        # advances `self.inplanes`, so the base width is captured first.
        plane = self.inplanes
        self.layers = nn.ModuleList(
            [self._make_layer(plane * (2**i)) for i in range(num_layers)]
        )

        # After the last stage, `self.inplanes` is the encoder output width.
        encode_dim = self.inplanes
        reduction_dim = encode_dim // len(ppm_bins)
        self.ppm = PPM(
            encode_dim,
            reduction_dim,
            ppm_bins,
            self.interpolate_mode,
        )
        # PPM concatenates its input with one `reduction_dim`-wide branch per
        # bin. This equals `2 * encode_dim` only when `encode_dim` is divisible
        # by the number of bins, so the real width is computed here.
        encode_dim += reduction_dim * len(ppm_bins)

        self.cls = nn.Sequential(
            nn.Conv1d(
                encode_dim,
                512,
                kernel_size=self.kernel_size,
                padding=self.padding,
                bias=False,
            ),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout1d(dropout),
            nn.Conv1d(512, output_size, kernel_size=1),
        )
        self.aux_branch = nn.Sequential(
            # Input width must match the channel count produced by the stage
            # at index `aux_idx`.
            nn.Conv1d(
                plane * self.expansion * (2**self.aux_idx),
                256,
                kernel_size=self.kernel_size,
                padding=self.padding,
                bias=False,
            ),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout1d(0.1),
            nn.Conv1d(256, output_size, kernel_size=1),
        )

    def _make_layer(self, planes: int):
        """Build one encoder stage of ``self.num_bottlenecks`` bottlenecks.

        The first bottleneck downsamples by 2 (stride-2 convolution plus a
        matching stride-2 1x1 projection on the shortcut). Every following
        bottleneck keeps the length and uses a convolution dilated by
        ``self.dilation``.

        Side effect: ``self.inplanes`` is set to ``planes * self.expansion``
        so that the next stage is wired to this stage's output width.

        Args:
            planes: Bottleneck width of this stage.

        Returns:
            ``nn.Sequential`` holding the stage's bottleneck blocks.
        """
        out_planes = planes * self.expansion
        downsample = nn.Sequential(
            nn.Conv1d(
                self.inplanes,
                out_planes,
                kernel_size=1,
                stride=2,
                bias=False,
            ),
            nn.BatchNorm1d(out_planes),
        )
        bottlenecks = [
            Bottleneck(
                self.inplanes,
                planes,
                expansion=self.expansion,
                kernel_size=self.kernel_size,
                stride=2,
                dilation=1,
                padding=self.padding,
                downsample=downsample,
            )
        ]
        self.inplanes = out_planes
        # Padding that keeps the length unchanged for the dilated kernel.
        dilated_padding = (self.dilation * (self.kernel_size - 1)) // 2
        for _ in range(1, self.num_bottlenecks):
            bottlenecks.append(
                Bottleneck(
                    self.inplanes,
                    planes,
                    expansion=self.expansion,
                    kernel_size=self.kernel_size,
                    stride=1,
                    dilation=self.dilation,
                    padding=dilated_padding,
                )
            )
        return nn.Sequential(*bottlenecks)

    def forward(self, input: torch.Tensor, y=None):
        """Predict per-position scores for ``input``.

        Args:
            input: Tensor indexed as (batch, 1, length); the stem convolution
                takes a single input channel and ``input.shape[2]`` is used as
                the output length.
            y: Unused; accepted so callers may pass targets positionally.

        Returns:
            Tensor with ``output_size`` channels and the same length as
            ``input``. In training mode this is
            ``(1 - aux_ratio) * main + aux_ratio * aux``; in eval mode it is
            the main head's output only.
        """
        output = self.stem(input)
        aux = None
        for i, stage in enumerate(self.layers):
            output = stage(output)
            if i == self.aux_idx:
                # Keep this stage's features for the auxiliary head.
                aux = output
        output = self.ppm(output)
        output = self.cls(output)
        output = F.interpolate(
            output,
            input.shape[2],
            mode=self.interpolate_mode,
        )
        if not self.training:
            return output

        aux = self.aux_branch(aux)
        aux = F.interpolate(
            aux,
            input.shape[2],
            mode=self.interpolate_mode,
        )
        return torch.add(output * (1 - self.aux_ratio), aux * self.aux_ratio)