Spaces:
Sleeping
Sleeping
| """ | |
| paper: https://arxiv.org/abs/1904.04514 | |
| ref: https://github.com/HRNet/HRNet-Semantic-Segmentation/blob/HRNet-OCR/lib/models/seg_hrnet.py | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| from torch.functional import F | |
| import math | |
def _gen_same_length_conv(in_channel, out_channel, kernel_size=1, dilation=1):
    """Create a stride-1 Conv1d whose output length equals its input length.

    Used for the feature-extracting convolutions inside residual blocks, where
    the conv output is added to a residual of the original length.

    Args:
        in_channel: number of input channels.
        out_channel: number of output channels.
        kernel_size: convolution kernel size.
        dilation: convolution dilation.

    Returns:
        A bias-free ``nn.Conv1d`` with stride 1.
    """
    total_padding = dilation * (kernel_size - 1)
    if total_padding % 2 == 0:
        # Symmetric integer padding keeps the length exactly.
        padding = total_padding // 2
    else:
        # Odd total padding (e.g. even kernel, dilation 1): symmetric integer
        # padding would shorten the output by one and break the residual add,
        # so let PyTorch pad asymmetrically.
        padding = "same"
    return nn.Conv1d(
        in_channel,
        out_channel,
        kernel_size=kernel_size,
        stride=1,
        padding=padding,
        dilation=dilation,
        bias=False,
    )
def _gen_downsample(in_channel, out_channel):
    """Build a bias-free Conv1d that halves the sequence length.

    Fixed geometry: kernel_size=3, stride=2, padding=1, so an input of length
    L becomes (L - 1) // 2 + 1, i.e. ceil(L / 2).
    """
    conv_kwargs = dict(kernel_size=3, stride=2, padding=1, bias=False)
    return nn.Conv1d(in_channel, out_channel, **conv_kwargs)
def _gen_channel_change_conv(in_channel, out_channel):
    """Build a bias-free pointwise Conv1d (kernel 1, stride 1) that only remaps channels."""
    return nn.Conv1d(
        in_channels=in_channel,
        out_channels=out_channel,
        kernel_size=1,
        stride=1,
        bias=False,
    )
class BasicBlock(nn.Module):
    """ResNet basic block: two same-length convs with a residual connection.

    Channels go from ``inplanes`` to ``planes`` (``expansion`` is 1). When the
    channel counts differ, the residual path is a bias-free 1x1 conv;
    otherwise it is an identity.
    """

    expansion = 1

    def __init__(self, inplanes, planes, kernel_size=3, dilation=1):
        super().__init__()
        # Submodules are created in this exact order so parameter
        # initialization order and state_dict keys stay stable.
        self.conv1 = _gen_same_length_conv(inplanes, planes, kernel_size, dilation)
        self.bn1 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU()
        self.conv2 = _gen_same_length_conv(planes, planes, kernel_size, dilation)
        self.bn2 = nn.BatchNorm1d(planes)
        if inplanes == planes:
            self.make_residual = nn.Identity()
        else:
            self.make_residual = _gen_channel_change_conv(inplanes, planes)

    def forward(self, x):
        """conv-bn-relu, conv-bn, add the (channel-matched) input, relu."""
        shortcut = self.make_residual(x)
        hidden = self.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return self.relu(hidden + shortcut)
class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 reduce, kxk conv, 1x1 expand, plus residual.

    Channels go from ``inplanes`` to ``planes * expansion`` (``expansion`` is 4).
    When ``inplanes`` differs from the expanded width, the residual path is a
    bias-free 1x1 conv; otherwise it is an identity.
    """

    expansion = 4

    def __init__(self, inplanes, planes, kernel_size=3, dilation=1):
        super().__init__()
        expanded = planes * self.expansion
        # Creation order matches the attribute order below so parameter
        # initialization order and state_dict keys stay stable.
        self.conv1 = _gen_same_length_conv(inplanes, planes)
        self.bn1 = nn.BatchNorm1d(planes)
        self.conv2 = _gen_same_length_conv(planes, planes, kernel_size, dilation)
        self.bn2 = nn.BatchNorm1d(planes)
        self.conv3 = _gen_same_length_conv(planes, expanded)
        self.bn3 = nn.BatchNorm1d(expanded)
        self.relu = nn.ReLU()
        if inplanes == expanded:
            self.make_residual = nn.Identity()
        else:
            self.make_residual = _gen_channel_change_conv(inplanes, expanded)

    def forward(self, x):
        """Two conv-bn-relu stages, conv-bn, add the residual, relu."""
        shortcut = self.make_residual(x)
        hidden = x
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            hidden = self.relu(bn(conv(hidden)))
        hidden = self.bn3(self.conv3(hidden))
        return self.relu(hidden + shortcut)
class HRModule(nn.Module):
    """One multi-resolution module of an HRNet stage.

    A module at ``stage_idx`` has ``stage_idx + 1`` parallel branches. Branch
    ``i`` runs ``num_blocks`` residual blocks starting from
    ``in_channels_by_stage[stage_idx][i]`` channels, and ``fusions[i][j]`` maps
    the output of branch ``j`` onto branch ``i``'s length and input channels:

    * ``i < j`` (lower-resolution source): 1x1 conv + BN + upsample to
      ``data_len_by_branch[i]``.
    * ``i == j``: 1x1 conv + BN + ReLU if the channel counts differ, else identity.
    * ``i > j`` (higher-resolution source): ``i - j`` stride-2 downsampling convs.

    Args:
        stage_idx: index of the stage this module belongs to.
        num_blocks: residual blocks per branch.
        block_type_by_stage: block class (``BasicBlock``/``Bottleneck``) per stage.
        in_channels_by_stage: per stage, the input channel count of each branch.
        out_channels_by_stage: per stage, the output channel count of each branch.
        data_len_by_branch: sequence length of each branch; upsampling targets
            these fixed lengths.
        kernel_size: forwarded to every residual block.
        dilation: forwarded to every residual block.
        interpolate_mode: ``mode`` passed to ``nn.Upsample``.
    """

    def __init__(
        self,
        stage_idx,
        num_blocks,
        block_type_by_stage,
        in_channels_by_stage,
        out_channels_by_stage,
        data_len_by_branch,
        kernel_size,
        dilation,
        interpolate_mode,
    ):
        super().__init__()
        self.branches = nn.ModuleList()
        self.fusions = nn.ModuleList()
        block_type: BasicBlock | Bottleneck = block_type_by_stage[stage_idx]
        in_channels = in_channels_by_stage[stage_idx]
        num_branches = stage_idx + 1

        # Build branches. The first block maps c -> c * expansion; later blocks
        # take that expanded width back in and again produce c * expansion.
        for i in range(num_branches):
            channels = in_channels[i]
            blocks_by_branch = [block_type(channels, channels, kernel_size, dilation)]
            for _ in range(1, num_blocks):
                blocks_by_branch.append(
                    block_type(
                        channels * block_type.expansion,
                        channels,
                        kernel_size,
                        dilation,
                    )
                )
            self.branches.append(nn.Sequential(*blocks_by_branch))

        # Build fusions: fusions[i][j] maps branch j's output to branch i.
        out_channels = out_channels_by_stage[stage_idx]
        for i in range(num_branches):  # target branch
            fusion_by_branch = nn.ModuleList()
            for j in range(num_branches):  # source branch
                if i < j:
                    # Lower-resolution source: match channels, then upsample
                    # to branch i's (fixed) length.
                    fusion_by_branch.append(
                        nn.Sequential(
                            _gen_channel_change_conv(out_channels[j], in_channels[i]),
                            nn.BatchNorm1d(in_channels[i]),
                            nn.Upsample(
                                size=data_len_by_branch[i], mode=interpolate_mode
                            ),
                        )
                    )
                elif i == j:
                    if out_channels[i] != in_channels[j]:
                        fusion_by_branch.append(
                            nn.Sequential(
                                _gen_channel_change_conv(
                                    out_channels[i], in_channels[j]
                                ),
                                nn.BatchNorm1d(in_channels[j]),
                                nn.ReLU(),
                            )
                        )
                    else:
                        fusion_by_branch.append(nn.Identity())
                else:
                    # Higher-resolution source: downsample 2x once per branch
                    # level of difference; channels are matched to the target
                    # branch's input channels by the first conv.
                    downsamples = [
                        _gen_downsample(out_channels[j], in_channels[i]),
                        nn.BatchNorm1d(in_channels[i]),
                    ]
                    for _ in range(1, i - j):
                        downsamples.extend(
                            [
                                nn.ReLU(),
                                _gen_downsample(in_channels[i], in_channels[i]),
                                nn.BatchNorm1d(in_channels[i]),
                            ]
                        )
                    fusion_by_branch.append(nn.Sequential(*downsamples))
            self.fusions.append(fusion_by_branch)

    def forward(self, xs):
        """Run each branch on its input, then fuse the results across branches.

        Args:
            xs: one tensor per branch (``len(self.branches)`` of them), each at
                that branch's input channel count and length.

        Returns:
            A list with one tensor per branch; entry ``i`` is the sum over all
            ``j`` of ``fusions[i][j](branches[j](xs[j]))``.

        Raises:
            ValueError: if the number of inputs differs from the number of branches.
        """
        if len(xs) != len(self.branches):
            raise ValueError(
                f"expected {len(self.branches)} branch inputs, got {len(xs)}"
            )
        branch_outputs = [branch(x) for branch, x in zip(self.branches, xs)]
        # NOTE(review): no ReLU is applied to the fused sum (the referenced
        # seg_hrnet.py appears to apply one) — confirm this is intentional.
        return [
            sum(fuse(out) for fuse, out in zip(fusion_row, branch_outputs))
            for fusion_row in self.fusions
        ]
class HRNetV2(nn.Module):
    """HRNetV2 for 1-D signals (paper: https://arxiv.org/abs/1904.04514).

    Layout:
    * a stem of two stride-2 convs;
    * ``num_stages`` stages of ``HRModule``s, with transitions between stages
      (each transition adds one lower-resolution branch);
    * a head that interpolates every branch to the highest-resolution length,
      concatenates them, and maps them to ``output_size`` channels with a
      bias-free 1x1 conv.

    ``config`` attributes read: ``data_len``, ``kernel_size``, ``dilation``,
    ``num_stages``, ``num_blocks``, ``num_modules`` (modules per stage),
    ``use_bottleneck`` (per stage; ``1`` selects ``Bottleneck``, anything else
    ``BasicBlock``), ``stage1_channels``, ``num_channels_init``,
    ``interpolate_mode`` and ``output_size``.

    Fusion upsampling targets fixed lengths derived from ``config.data_len``,
    so inputs of a different length will generally fail at the fusion sum.

    Raises:
        ValueError: if ``config.num_stages`` is less than 1.
        AssertionError: if ``num_modules`` or ``use_bottleneck`` has fewer than
            ``num_stages`` entries.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # NOTE(review): data_len presumably matches the dataset window
        # (ECGPQRSTDataset seconds * hz) — confirm against the caller.
        data_len = int(config.data_len)
        kernel_size = int(config.kernel_size)
        dilation = int(config.dilation)
        num_stages = int(config.num_stages)
        num_blocks = int(config.num_blocks)
        self.num_modules = config.num_modules  # e.g. [1, 1, 4, 3, ...]
        assert num_stages <= len(self.num_modules)
        use_bottleneck = config.use_bottleneck  # e.g. [1, 0, 0, 0, ...]
        assert num_stages <= len(use_bottleneck)
        stage1_channels = int(config.stage1_channels)  # e.g. 64, 128
        num_channels_init = int(config.num_channels_init)  # e.g. 18, 32, 48
        self.interpolate_mode = config.interpolate_mode
        output_size = config.output_size  # e.g. 3 (p, qrs, t)
        if num_stages < 1:
            # The head needs at least one stage's channel list.
            raise ValueError(f"config.num_stages must be >= 1, got {num_stages}")

        # Stem: two stride-2 convs (k=3, p=1), each mapping L -> (L - 1) // 2 + 1.
        # NOTE(review): there is no ReLU between the two convs (the referenced
        # seg_hrnet.py appears to have one). Inserting one would shift the
        # Sequential indices and therefore the state_dict keys, so it is left as-is.
        self.stem = nn.Sequential(
            nn.Conv1d(
                1, stage1_channels, kernel_size=3, stride=2, padding=1, bias=False
            ),
            nn.BatchNorm1d(stage1_channels),
            nn.Conv1d(
                stage1_channels,
                stage1_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm1d(stage1_channels),
            nn.ReLU(),
        )
        for _ in range(2):  # data length after the stem
            data_len = (data_len - 1) // 2 + 1

        # Meta: compute per-stage block type and per-branch channels and
        # lengths before building any modules.
        in_channels_by_stage = []
        out_channels_by_stage = []
        block_type_by_stage = []
        data_len_by_branch = [data_len]
        for stage_idx in range(num_stages):
            block_type = Bottleneck if use_bottleneck[stage_idx] == 1 else BasicBlock
            if stage_idx == 0:
                in_channels = [stage1_channels]
            else:
                in_channels = [
                    num_channels_init * 2**idx for idx in range(stage_idx + 1)
                ]
                # Each newly added branch runs at half the previous branch's length.
                data_len_by_branch.append((data_len_by_branch[-1] - 1) // 2 + 1)
            block_type_by_stage.append(block_type)
            in_channels_by_stage.append(in_channels)
            out_channels_by_stage.append(
                [channels * block_type.expansion for channels in in_channels]
            )

        # Stages: each stage is a ModuleList of num_modules[stage_idx] HRModules.
        self.stages = nn.ModuleList()
        for stage_idx in range(num_stages):
            modules_by_stage = nn.ModuleList()
            for _ in range(self.num_modules[stage_idx]):
                modules_by_stage.append(
                    HRModule(
                        stage_idx,
                        num_blocks,
                        block_type_by_stage,
                        in_channels_by_stage,
                        out_channels_by_stage,
                        data_len_by_branch,
                        kernel_size,
                        dilation,
                        self.interpolate_mode,
                    )
                )
            self.stages.append(modules_by_stage)

        # Transitions: transitions[s] sits between stage s and stage s + 1.
        # Its first s + 1 entries adapt existing branches' channels (or are
        # Identity). Its last entry is a ModuleList that builds the new branch
        # from downsampled copies of every existing branch.
        self.transitions = nn.ModuleList()
        for stage_idx in range(num_stages - 1):
            transition_by_stage = nn.ModuleList()
            psc = in_channels_by_stage[stage_idx]  # prev-stage channels
            nsc = in_channels_by_stage[stage_idx + 1]  # next-stage channels
            for nsbi in range(stage_idx + 2):  # next-stage branch index
                if nsbi < stage_idx + 1:  # existing branch level
                    if psc[nsbi] != nsc[nsbi]:
                        transition_by_stage.append(
                            nn.Sequential(
                                _gen_channel_change_conv(psc[nsbi], nsc[nsbi]),
                                nn.BatchNorm1d(nsc[nsbi]),
                                nn.ReLU(),
                            )
                        )
                    else:
                        transition_by_stage.append(nn.Identity())
                else:  # new branch, built from all existing branches
                    transition_from_branches = nn.ModuleList()
                    for psbi in range(nsbi):  # prev-stage branch index
                        transition_from_one_branch = [
                            _gen_downsample(psc[psbi], nsc[nsbi]),
                            nn.BatchNorm1d(nsc[nsbi]),
                        ]
                        for _ in range(1, nsbi - psbi):
                            transition_from_one_branch.extend(
                                [
                                    nn.ReLU(),
                                    _gen_downsample(nsc[nsbi], nsc[nsbi]),
                                    nn.BatchNorm1d(nsc[nsbi]),
                                ]
                            )
                        transition_from_branches.append(
                            nn.Sequential(*transition_from_one_branch)
                        )
                    transition_by_stage.append(transition_from_branches)
            self.transitions.append(transition_by_stage)

        # Head: the last stage's fusion outputs carry that stage's input channels.
        self.cls = nn.Conv1d(
            sum(in_channels_by_stage[-1]), output_size, 1, bias=False
        )

    def forward(self, input: torch.Tensor, y=None):
        """Run the network and return per-position logits.

        Args:
            input: signal with 1 channel (the stem's first conv takes 1 input
                channel); presumably shaped (batch, 1, length) — the final
                1-D interpolation requires a 3-D tensor.
            y: accepted for interface compatibility and ignored.

        Returns:
            Tensor with ``output_size`` channels and the same length as ``input``.
        """
        outputs = [self.stem(input)]
        last_stage_idx = len(self.stages) - 1
        for stage_idx, stage in enumerate(self.stages):
            for hr_module in stage:
                branch_outputs = [
                    branch(x) for branch, x in zip(hr_module.branches, outputs)
                ]
                # NOTE(review): the fused sum is not passed through a ReLU
                # (the referenced seg_hrnet.py appears to apply one) — confirm.
                outputs = [
                    sum(fuse(out) for fuse, out in zip(fusion_row, branch_outputs))
                    for fusion_row in hr_module.fusions
                ]
            if stage_idx < last_stage_idx:
                transition = self.transitions[stage_idx]
                num_existing = stage_idx + 1
                next_outputs = [
                    transition[branch_idx](outputs[branch_idx])
                    for branch_idx in range(num_existing)
                ]
                new_branch = transition[num_existing]
                next_outputs.append(
                    sum(down(out) for down, out in zip(new_branch, outputs))
                )
                outputs = next_outputs

        # HRNetV2 head: bring every branch to the highest-resolution length,
        # concatenate channels, classify, and resize to the input length.
        target_len = outputs[0].shape[-1]
        merged = torch.cat(
            [
                F.interpolate(out, size=target_len, mode=self.interpolate_mode)
                for out in outputs
            ],
            dim=1,
        )
        return F.interpolate(
            self.cls(merged), size=input.shape[-1], mode=self.interpolate_mode
        )