Sanuka0523's picture
Deploy cardiac monitor FastAPI backend
11e9a40
"""
Net1D: 1D CNN with Squeeze-and-Excitation for ECG classification.
From PKUDigitalHealth/ECGFounder (MIT License).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class MyConv1dPadSame(nn.Module):
    """Conv1d wrapper that zero-pads the input itself ("same"-style padding).

    The wrapped ``nn.Conv1d`` is created without any padding of its own;
    ``forward`` pads the last axis with zeros so that the output length is
    ``ceil(input_length / stride)``. When the required padding is odd, the
    extra sample goes on the right.

    Args:
        in_channels: channels of the input signal.
        out_channels: channels produced by the convolution.
        kernel_size: width of the convolution kernel.
        stride: stride of the convolution.
        groups: number of channel groups forwarded to ``nn.Conv1d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            groups=groups,
        )

    def forward(self, x):
        """Pad the last axis of ``x`` with zeros and convolve it."""
        length = x.shape[-1]
        stride = self.stride
        # Ceil division: number of output samples a "same" convolution yields.
        target_len = -(-length // stride)
        # Samples needed so the last kernel window still fits; never negative.
        total_pad = max(0, (target_len - 1) * stride + self.kernel_size - length)
        left = total_pad // 2
        right = total_pad - left
        padded = F.pad(x, (left, right), "constant", 0)
        return self.conv(padded)
class MyMaxPool1dPadSame(nn.Module):
    """Max pooling preceded by explicit zero padding of the last axis.

    ``kernel_size - 1`` zeros are added in total (the extra one on the right
    when that number is odd) before ``nn.MaxPool1d`` is applied. The pooling
    layer is built with only ``kernel_size``, so it uses its default stride.

    Note that the padding value is 0, not -inf: a border window whose real
    samples are all negative therefore yields 0.

    Args:
        kernel_size: pooling window width.
    """

    def __init__(self, kernel_size):
        super().__init__()
        self.kernel_size = kernel_size
        self.max_pool = nn.MaxPool1d(kernel_size=kernel_size)

    def forward(self, x):
        """Zero-pad the last axis of ``x`` and max-pool it."""
        total_pad = max(0, self.kernel_size - 1)
        left = total_pad // 2
        right = total_pad - left
        return self.max_pool(F.pad(x, (left, right), "constant", 0))
class Swish(nn.Module):
    """Swish activation: the input multiplied element-wise by its sigmoid."""

    def forward(self, x):
        """Return ``x * sigmoid(x)`` with the same shape as ``x``."""
        gate = torch.sigmoid(x)
        return torch.mul(x, gate)
class BasicBlock(nn.Module):
    """Bottleneck residual block with squeeze-and-excitation.

    The main path is three (BatchNorm -> Swish -> Dropout -> conv) stages:
    a 1x1 conv to ``int(out_channels * ratio)`` channels, a grouped conv with
    the given kernel size (strided when ``downsample`` is set), and a 1x1 conv
    to ``out_channels``. The result is rescaled per channel by a
    squeeze-and-excitation gate and added to the shortcut.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels of the block output.
        ratio: multiplier applied to ``out_channels`` for the bottleneck width.
        kernel_size: kernel width of the middle convolution.
        stride: stride of the middle convolution; used only when
            ``downsample`` is truthy, otherwise the block uses stride 1.
        groups: group count of the middle convolution.
        downsample: whether the block reduces the temporal length.
        is_first_block: when true, the BatchNorm/Swish/Dropout in front of the
            first convolution is skipped.
        use_bn: apply the BatchNorm layers in ``forward``.
        use_do: apply the Dropout layers in ``forward``.

    All layers are always constructed, regardless of ``use_bn``/``use_do``;
    the flags only control whether ``forward`` calls them.
    """

    def __init__(self, in_channels, out_channels, ratio, kernel_size, stride,
                 groups, downsample, is_first_block=False, use_bn=True, use_do=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.downsample = downsample
        if downsample:
            self.stride = stride
        else:
            self.stride = 1
        self.is_first_block = is_first_block
        self.use_bn = use_bn
        self.use_do = use_do

        bottleneck = int(out_channels * ratio)

        # Stage 1: pointwise conv into the bottleneck width.
        self.bn1 = nn.BatchNorm1d(in_channels)
        self.activation1 = Swish()
        self.do1 = nn.Dropout(p=0.5)
        self.conv1 = MyConv1dPadSame(in_channels, bottleneck, 1, 1, 1)

        # Stage 2: grouped conv, carries the stride.
        self.bn2 = nn.BatchNorm1d(bottleneck)
        self.activation2 = Swish()
        self.do2 = nn.Dropout(p=0.5)
        self.conv2 = MyConv1dPadSame(
            bottleneck, bottleneck, kernel_size, self.stride, groups)

        # Stage 3: pointwise conv out of the bottleneck.
        self.bn3 = nn.BatchNorm1d(bottleneck)
        self.activation3 = Swish()
        self.do3 = nn.Dropout(p=0.5)
        self.conv3 = MyConv1dPadSame(bottleneck, out_channels, 1, 1, 1)

        # Squeeze-and-Excitation: two linear layers with a reduction factor of 2.
        se_hidden = out_channels // 2
        self.se_fc1 = nn.Linear(out_channels, se_hidden)
        self.se_fc2 = nn.Linear(se_hidden, out_channels)
        self.se_activation = Swish()

        # The shortcut is pooled with the same factor the main path strides by.
        if self.downsample:
            self.max_pool = MyMaxPool1dPadSame(kernel_size=self.stride)

    def forward(self, x):
        """Run the block on a 3-D tensor (batch, channels, length)."""
        shortcut = x
        out = x

        stages = (
            (self.bn1, self.activation1, self.do1, self.conv1),
            (self.bn2, self.activation2, self.do2, self.conv2),
            (self.bn3, self.activation3, self.do3, self.conv3),
        )
        for index, (norm, act, drop, conv) in enumerate(stages):
            # Only the very first stage of a "first block" skips its
            # normalisation/activation/dropout prelude.
            skip_prelude = index == 0 and self.is_first_block
            if not skip_prelude:
                if self.use_bn:
                    out = norm(out)
                out = act(out)
                if self.use_do:
                    out = drop(out)
            out = conv(out)

        # SE attention: average over length, then compute one gate per channel.
        gate = out.mean(-1)
        gate = self.se_fc2(self.se_activation(self.se_fc1(gate)))
        gate = torch.sigmoid(gate)
        out = torch.einsum('abc,ab->abc', out, gate)

        # Bring the shortcut to the same length as the main path.
        if self.downsample:
            shortcut = self.max_pool(shortcut)

        # Bring the shortcut to the same channel count by zero-padding the
        # channel axis (second-to-last), split as evenly as possible.
        if self.out_channels != self.in_channels:
            extra = self.out_channels - self.in_channels
            before = extra // 2
            after = extra - before
            shortcut = F.pad(shortcut, (0, 0, before, after), "constant", 0)

        out += shortcut
        return out
class BasicStage(nn.Module):
    """A sequence of ``m_blocks`` BasicBlock modules applied one after another.

    The first block of the stage maps ``in_channels`` to ``out_channels``,
    receives the given ``stride`` and is created with ``downsample=True``.
    Every following block maps ``out_channels`` to ``out_channels`` with
    stride 1 and ``downsample=False``.

    Args:
        in_channels: channels entering the stage.
        out_channels: channels produced by every block of the stage.
        ratio: bottleneck ratio forwarded to each block.
        kernel_size: kernel width forwarded to each block.
        stride: stride handed to the first block only.
        groups: group count forwarded to each block.
        i_stage: index of this stage; the first block of stage 0 is flagged
            with ``is_first_block=True``.
        m_blocks: number of blocks in the stage.
        use_bn: forwarded to each block.
        use_do: forwarded to each block.
    """

    def __init__(self, in_channels, out_channels, ratio, kernel_size, stride,
                 groups, i_stage, m_blocks, use_bn=True, use_do=True):
        super().__init__()
        self.block_list = nn.ModuleList()
        for i_block in range(m_blocks):
            opens_stage = i_block == 0
            self.block_list.append(BasicBlock(
                in_channels if opens_stage else out_channels,
                out_channels,
                ratio,
                kernel_size,
                stride if opens_stage else 1,
                groups,
                downsample=opens_stage,
                is_first_block=opens_stage and i_stage == 0,
                use_bn=use_bn,
                use_do=use_do,
            ))

    def forward(self, x):
        """Feed ``x`` through every block in order and return the result."""
        out = x
        for layer in self.block_list:
            out = layer(out)
        return out
class Net1D(nn.Module):
    """
    1D CNN for ECG classification.
    Input: (batch, in_channels, length)
    Output: (batch, n_classes)

    The network is a stem (stride-2 convolution, BatchNorm, Swish), followed
    by one BasicStage per entry of ``filter_list``, global average pooling
    over the length axis, and a linear classifier.

    Args:
        in_channels: channels of the input signal.
        base_filters: output channels of the stem convolution.
        ratio: bottleneck ratio forwarded to every stage.
        filter_list: output channels of each stage, in order.
        m_blocks_list: number of blocks in each stage; must have the same
            length as ``filter_list``.
        kernel_size: kernel width of the stem and of the stages.
        stride: stride forwarded to every stage (the stem always uses 2).
        groups_width: each stage uses ``out_channels // groups_width`` groups.
        n_classes: size of the output layer.
        use_bn: apply BatchNorm in the stem and in the stages.
        use_do: forwarded to the stages.
        verbose: accepted for interface compatibility; not used here.

    Raises:
        ValueError: if ``filter_list`` and ``m_blocks_list`` differ in length.
    """

    def __init__(self, in_channels, base_filters, ratio, filter_list,
                 m_blocks_list, kernel_size, stride, groups_width,
                 n_classes, use_bn=True, use_do=True, verbose=False):
        super().__init__()
        # Fail early on a mismatched configuration: otherwise surplus
        # m_blocks_list entries would be ignored silently, and a short list
        # would surface as an IndexError halfway through construction.
        if len(filter_list) != len(m_blocks_list):
            raise ValueError(
                "filter_list and m_blocks_list must have the same length, "
                f"got {len(filter_list)} and {len(m_blocks_list)}"
            )
        self.n_stages = len(filter_list)
        self.use_bn = use_bn

        # Stem: the stride is fixed at 2 regardless of the ``stride`` argument.
        self.first_conv = MyConv1dPadSame(in_channels, base_filters, kernel_size, stride=2)
        self.first_bn = nn.BatchNorm1d(base_filters)
        self.first_activation = Swish()

        self.stage_list = nn.ModuleList()
        in_ch = base_filters
        for i_stage, (out_ch, m_blocks) in enumerate(zip(filter_list, m_blocks_list)):
            self.stage_list.append(BasicStage(
                in_ch, out_ch, ratio, kernel_size, stride,
                out_ch // groups_width, i_stage, m_blocks,
                use_bn=use_bn, use_do=use_do))
            # The next stage consumes what this one produces.
            in_ch = out_ch

        self.dense = nn.Linear(in_ch, n_classes)

    def forward(self, x):
        """Return the output of the final linear layer for ``x``."""
        out = self.first_conv(x)
        if self.use_bn:
            out = self.first_bn(out)
        out = self.first_activation(out)
        for stage in self.stage_list:
            out = stage(out)
        features = out.mean(-1)  # Global Average Pooling
        return self.dense(features)