Spaces:
Runtime error
Runtime error
| """ | |
| Net1D: 1D CNN with Squeeze-and-Excitation for ECG classification. | |
| From PKUDigitalHealth/ECGFounder (MIT License). | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| class MyConv1dPadSame(nn.Module): | |
| def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1): | |
| super().__init__() | |
| self.in_channels = in_channels | |
| self.out_channels = out_channels | |
| self.kernel_size = kernel_size | |
| self.stride = stride | |
| self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, groups=groups) | |
| def forward(self, x): | |
| in_dim = x.shape[-1] | |
| out_dim = (in_dim + self.stride - 1) // self.stride | |
| p = max(0, (out_dim - 1) * self.stride + self.kernel_size - in_dim) | |
| pad_left = p // 2 | |
| pad_right = p - pad_left | |
| x = F.pad(x, (pad_left, pad_right), "constant", 0) | |
| return self.conv(x) | |
| class MyMaxPool1dPadSame(nn.Module): | |
| def __init__(self, kernel_size): | |
| super().__init__() | |
| self.kernel_size = kernel_size | |
| self.max_pool = nn.MaxPool1d(kernel_size=kernel_size) | |
| def forward(self, x): | |
| p = max(0, self.kernel_size - 1) | |
| pad_left = p // 2 | |
| pad_right = p - pad_left | |
| x = F.pad(x, (pad_left, pad_right), "constant", 0) | |
| return self.max_pool(x) | |
| class Swish(nn.Module): | |
| def forward(self, x): | |
| return x * torch.sigmoid(x) | |
| class BasicBlock(nn.Module): | |
| def __init__(self, in_channels, out_channels, ratio, kernel_size, stride, | |
| groups, downsample, is_first_block=False, use_bn=True, use_do=True): | |
| super().__init__() | |
| self.in_channels = in_channels | |
| self.out_channels = out_channels | |
| self.downsample = downsample | |
| self.stride = stride if downsample else 1 | |
| self.is_first_block = is_first_block | |
| self.use_bn = use_bn | |
| self.use_do = use_do | |
| middle = int(out_channels * ratio) | |
| self.bn1 = nn.BatchNorm1d(in_channels) | |
| self.activation1 = Swish() | |
| self.do1 = nn.Dropout(p=0.5) | |
| self.conv1 = MyConv1dPadSame(in_channels, middle, 1, 1, 1) | |
| self.bn2 = nn.BatchNorm1d(middle) | |
| self.activation2 = Swish() | |
| self.do2 = nn.Dropout(p=0.5) | |
| self.conv2 = MyConv1dPadSame(middle, middle, kernel_size, self.stride, groups) | |
| self.bn3 = nn.BatchNorm1d(middle) | |
| self.activation3 = Swish() | |
| self.do3 = nn.Dropout(p=0.5) | |
| self.conv3 = MyConv1dPadSame(middle, out_channels, 1, 1, 1) | |
| # Squeeze-and-Excitation | |
| r = 2 | |
| self.se_fc1 = nn.Linear(out_channels, out_channels // r) | |
| self.se_fc2 = nn.Linear(out_channels // r, out_channels) | |
| self.se_activation = Swish() | |
| if self.downsample: | |
| self.max_pool = MyMaxPool1dPadSame(kernel_size=self.stride) | |
| def forward(self, x): | |
| identity = x | |
| out = x | |
| if not self.is_first_block: | |
| if self.use_bn: | |
| out = self.bn1(out) | |
| out = self.activation1(out) | |
| if self.use_do: | |
| out = self.do1(out) | |
| out = self.conv1(out) | |
| if self.use_bn: | |
| out = self.bn2(out) | |
| out = self.activation2(out) | |
| if self.use_do: | |
| out = self.do2(out) | |
| out = self.conv2(out) | |
| if self.use_bn: | |
| out = self.bn3(out) | |
| out = self.activation3(out) | |
| if self.use_do: | |
| out = self.do3(out) | |
| out = self.conv3(out) | |
| # SE attention | |
| se = out.mean(-1) | |
| se = self.se_fc1(se) | |
| se = self.se_activation(se) | |
| se = self.se_fc2(se) | |
| se = torch.sigmoid(se) | |
| out = torch.einsum('abc,ab->abc', out, se) | |
| if self.downsample: | |
| identity = self.max_pool(identity) | |
| if self.out_channels != self.in_channels: | |
| identity = identity.transpose(-1, -2) | |
| ch1 = (self.out_channels - self.in_channels) // 2 | |
| ch2 = self.out_channels - self.in_channels - ch1 | |
| identity = F.pad(identity, (ch1, ch2), "constant", 0) | |
| identity = identity.transpose(-1, -2) | |
| out += identity | |
| return out | |
| class BasicStage(nn.Module): | |
| def __init__(self, in_channels, out_channels, ratio, kernel_size, stride, | |
| groups, i_stage, m_blocks, use_bn=True, use_do=True): | |
| super().__init__() | |
| self.block_list = nn.ModuleList() | |
| for i_block in range(m_blocks): | |
| is_first = (i_stage == 0 and i_block == 0) | |
| if i_block == 0: | |
| tmp_block = BasicBlock( | |
| in_channels, out_channels, ratio, kernel_size, | |
| stride, groups, downsample=True, | |
| is_first_block=is_first, use_bn=use_bn, use_do=use_do) | |
| else: | |
| tmp_block = BasicBlock( | |
| out_channels, out_channels, ratio, kernel_size, | |
| 1, groups, downsample=False, | |
| is_first_block=False, use_bn=use_bn, use_do=use_do) | |
| self.block_list.append(tmp_block) | |
| def forward(self, x): | |
| for block in self.block_list: | |
| x = block(x) | |
| return x | |
| class Net1D(nn.Module): | |
| """ | |
| 1D CNN for ECG classification. | |
| Input: (batch, in_channels, length) | |
| Output: (batch, n_classes) | |
| """ | |
| def __init__(self, in_channels, base_filters, ratio, filter_list, | |
| m_blocks_list, kernel_size, stride, groups_width, | |
| n_classes, use_bn=True, use_do=True, verbose=False): | |
| super().__init__() | |
| self.n_stages = len(filter_list) | |
| self.use_bn = use_bn | |
| self.first_conv = MyConv1dPadSame(in_channels, base_filters, kernel_size, stride=2) | |
| self.first_bn = nn.BatchNorm1d(base_filters) | |
| self.first_activation = Swish() | |
| self.stage_list = nn.ModuleList() | |
| in_ch = base_filters | |
| for i_stage in range(self.n_stages): | |
| out_ch = filter_list[i_stage] | |
| self.stage_list.append(BasicStage( | |
| in_ch, out_ch, ratio, kernel_size, stride, | |
| out_ch // groups_width, i_stage, m_blocks_list[i_stage], | |
| use_bn=use_bn, use_do=use_do)) | |
| in_ch = out_ch | |
| self.dense = nn.Linear(in_ch, n_classes) | |
| def forward(self, x): | |
| out = self.first_conv(x) | |
| if self.use_bn: | |
| out = self.first_bn(out) | |
| out = self.first_activation(out) | |
| for stage in self.stage_list: | |
| out = stage(out) | |
| features = out.mean(-1) # Global Average Pooling | |
| out = self.dense(features) | |
| return out | |