# ECG/models/resnet1d.py
# Uploaded by cuongnx2001 ("Upload 34 files", commit 264b4c4, verified).
import torch.nn as nn
import torch.nn.functional as F
# Standard resnet
from models.basicconv1d import create_head1d
def conv(in_planes, out_planes, stride=1, kernel_size=3):
    """Return a bias-free ``nn.Conv1d`` padded with ``(kernel_size - 1) // 2``.

    For odd kernel sizes at stride 1 this keeps the sequence length unchanged.
    """
    same_pad = (kernel_size - 1) // 2
    return nn.Conv1d(
        in_planes,
        out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=same_pad,
        bias=False,
    )
class BasicBlock1d(nn.Module):
    """Residual block of two convolutions with batch norm (1d take on torchvision's BasicBlock).

    Main path: conv(kernel_size[0], stride) -> BN -> ReLU -> conv(kernel_size[1]) -> BN.
    The shortcut is the input itself, or ``down_sample(x)`` when a module is given. The
    two are summed and passed through a final ReLU. The output has ``planes`` channels
    (``expansion`` is 1).

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        stride: stride of the first convolution.
        kernel_size: ``None`` (both kernels 3), an int ``k`` (kernels ``k`` and
            ``k // 2 + 1``), or an indexable pair of kernel sizes.
        down_sample: optional module applied to the input to build the shortcut.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, kernel_size=None, down_sample=None):
        super().__init__()
        if kernel_size is None:
            kernel_size = (3, 3)
        elif isinstance(kernel_size, int):
            # An int sets the first kernel; the second one is derived from it.
            kernel_size = (kernel_size, kernel_size // 2 + 1)
        # Submodules are registered in this exact order (affects state_dict key order).
        self.conv1 = conv(inplanes, planes, stride=stride, kernel_size=kernel_size[0])
        self.bn1 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv(planes, planes, kernel_size=kernel_size[1])
        self.bn2 = nn.BatchNorm1d(planes)
        self.down_sample = down_sample
        self.stride = stride

    def forward(self, x):
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        shortcut = x if self.down_sample is None else self.down_sample(x)
        y += shortcut
        return self.relu(y)
class Bottleneck1d(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> strided k-conv -> 1x1 expand.

    Every convolution is followed by batch norm. The first two are also followed by a
    ReLU. The input (or ``down_sample(x)``) is added to the result, and a final ReLU
    is applied. The output has ``planes * expansion`` channels.

    Args:
        inplanes: number of input channels.
        planes: bottleneck width; the block outputs ``planes * expansion`` channels.
        stride: stride of the middle convolution.
        kernel_size: kernel size of the middle convolution. Either an int, or a
            list/tuple whose first entry is used. The sequence form is what
            ``ResNet1d``/``BasicBlock1d`` accept, e.g. ``[5, 3]``.
        down_sample: optional module applied to the input to form the shortcut. It must
            be given when the main path changes the shape (``stride != 1`` or
            ``inplanes != planes * expansion``); otherwise the residual sum fails.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, kernel_size=3, down_sample=None):
        super().__init__()
        # Accept the [k_first, k_second] form used elsewhere in this file; only the
        # middle convolution has a configurable kernel here.
        if isinstance(kernel_size, (list, tuple)):
            kernel_size = kernel_size[0]
        out_planes = planes * self.expansion
        # Submodules are registered in the original order (affects state_dict key order).
        self.conv1 = nn.Conv1d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(planes)
        self.conv2 = nn.Conv1d(planes, planes, kernel_size=kernel_size, stride=stride,
                               padding=(kernel_size - 1) // 2, bias=False)
        self.bn2 = nn.BatchNorm1d(planes)
        self.conv3 = nn.Conv1d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm1d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.down_sample = down_sample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.down_sample is not None:
            residual = self.down_sample(x)
        out += residual
        out = self.relu(out)
        return out
class ResNet1d(nn.Sequential):
    """1d adaptation of the torchvision resnet, built as a flat ``nn.Sequential``.

    Layout: a stem (conv, batch norm, ReLU and an optional max pool), one residual stage
    per entry of ``layers``, then the head returned by ``create_head1d``.

    Args:
        block: residual block class (e.g. ``BasicBlock1d`` or ``Bottleneck1d``). It must
            expose an ``expansion`` class attribute.
        layers: number of blocks in each stage.
        kernel_size: kernel size passed to the first block of every stage. Either an
            int or a list/tuple; for a sequence, its first entry is also the default
            stem kernel.
        num_classes: number of outputs of the head.
        input_channels: channels of the input signal.
        inplanes: channels produced by the stem.
        fix_feature_dim: if True every stage has ``inplanes`` planes; otherwise stage
            ``i`` (i >= 1) has ``2**i * inplanes`` planes.
        kernel_size_stem: kernel size of the stem convolution (default: derived from
            ``kernel_size``).
        stride_stem: stride of the stem convolution.
        pooling_stem: append a max pool (kernel 3, stride 2) to the stem when ``True``.
        stride: stride of the first block of every stage except the first one.
        lin_ftrs_head, ps_head, bn_final_head, bn_head, act_head, concat_pooling:
            forwarded to ``create_head1d``.
    """

    def __init__(self, block, layers, kernel_size=3, num_classes=2, input_channels=3, inplanes=64, fix_feature_dim=True,
                 kernel_size_stem=None, stride_stem=2, pooling_stem=True, stride=2, lin_ftrs_head=None, ps_head=0.5,
                 bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True):
        # Running channel count, read and advanced by ``_make_layer``. A plain int may be
        # assigned before ``nn.Module.__init__``; only parameters/submodules may not.
        self.inplanes = inplanes
        if kernel_size_stem is None:
            kernel_size_stem = kernel_size[0] if isinstance(kernel_size, (list, tuple)) else kernel_size

        # stem
        modules = [
            nn.Conv1d(input_channels, inplanes, kernel_size=kernel_size_stem, stride=stride_stem,
                      padding=(kernel_size_stem - 1) // 2, bias=False),
            nn.BatchNorm1d(inplanes),
            nn.ReLU(inplace=True),
        ]
        if pooling_stem is True:
            modules.append(nn.MaxPool1d(kernel_size=3, stride=2, padding=1))

        # backbone: the first stage uses stride 1 and ``inplanes`` planes. Later stages
        # use ``stride`` and, unless ``fix_feature_dim``, double the planes per stage.
        for i, n_blocks in enumerate(layers):
            planes = inplanes if (i == 0 or fix_feature_dim) else (2 ** i) * inplanes
            modules.append(self._make_layer(block, planes, n_blocks,
                                            stride=1 if i == 0 else stride,
                                            kernel_size=kernel_size))

        # head: ``self.inplanes`` now equals the channel count the last stage produces
        # (planes * block.expansion), or the stem width if ``layers`` is empty.
        head = create_head1d(
            self.inplanes, nc=num_classes,
            lin_ftrs=lin_ftrs_head, ps=ps_head, bn_final=bn_final_head, bn=bn_head, act=act_head,
            concat_pooling=concat_pooling)
        modules.append(head)
        # Passing the modules here is what registers them in the Sequential.
        super().__init__(*modules)

    def _make_layer(self, block, planes, blocks, stride=1, kernel_size=3):
        """Build one stage of ``blocks`` residual blocks and advance ``self.inplanes``.

        The first block receives ``stride`` and ``kernel_size``. It also gets a
        1x1-conv + batch-norm shortcut whenever the stride or the channel count changes.
        The remaining blocks are built with the block class's default stride and
        kernel size.
        """
        down_sample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            down_sample = nn.Sequential(
                nn.Conv1d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(planes * block.expansion),
            )
        layers = [block(self.inplanes, planes, stride, kernel_size, down_sample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def get_layer_groups(self):
        """Return ``(self[6], self[-1])``; the last element is the head.

        NOTE(review): with the default 4-module stem (``pooling_stem=True``), index 6 is
        the third residual stage. With other stem/stage configurations it points
        elsewhere or may not exist — confirm before relying on it.
        """
        return self[6], self[-1]

    def get_output_layer(self):
        """Return the last module of the head (presumably its final layer; built by create_head1d)."""
        return self[-1][-1]

    def set_output_layer(self, x):
        """Replace the last module of the head with ``x``."""
        self[-1][-1] = x
def resnet1d18(**kwargs):
    """Build a ResNet-18: ``BasicBlock1d`` stages of 2, 2, 2 and 2 blocks.

    All keyword arguments are forwarded to ``ResNet1d``.
    """
    stage_depths = [2, 2, 2, 2]
    return ResNet1d(BasicBlock1d, stage_depths, **kwargs)
def resnet1d34(**kwargs):
    """Build a ResNet-34: ``BasicBlock1d`` stages of 3, 4, 6 and 3 blocks.

    All keyword arguments are forwarded to ``ResNet1d``.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet1d(BasicBlock1d, stage_depths, **kwargs)
def resnet1d50(**kwargs):
    """Build a ResNet-50: ``Bottleneck1d`` stages of 3, 4, 6 and 3 blocks.

    All keyword arguments are forwarded to ``ResNet1d``.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet1d(Bottleneck1d, stage_depths, **kwargs)
def resnet1d101(**kwargs):
    """Build a ResNet-101: ``Bottleneck1d`` stages of 3, 4, 23 and 3 blocks.

    All keyword arguments are forwarded to ``ResNet1d``.
    """
    stage_depths = [3, 4, 23, 3]
    return ResNet1d(Bottleneck1d, stage_depths, **kwargs)
def resnet1d152(**kwargs):
    """Build a ResNet-152: ``Bottleneck1d`` stages of 3, 8, 36 and 3 blocks.

    All keyword arguments are forwarded to ``ResNet1d``.
    """
    stage_depths = [3, 8, 36, 3]
    return ResNet1d(Bottleneck1d, stage_depths, **kwargs)
# Note: the original implementation used kernel_size_stem=8; resnet1d_wang below defaults to 7.
def resnet1d_wang(**kwargs):
    """Build a 3-stage ``ResNet1d`` with one ``BasicBlock1d`` per stage and preset defaults.

    Defaults: ``kernel_size=[5, 3]``, ``kernel_size_stem=7``, ``stride_stem=1``,
    ``pooling_stem=False``, ``inplanes=128``. Any of these passed explicitly in
    ``kwargs`` overrides the default. All other keyword arguments are forwarded to
    ``ResNet1d`` unchanged.
    """
    # Built fresh on every call, so the list default is never shared between models.
    defaults = {
        "kernel_size": [5, 3],
        "kernel_size_stem": 7,
        "stride_stem": 1,
        "pooling_stem": False,
        "inplanes": 128,
    }
    return ResNet1d(BasicBlock1d, [1, 1, 1], **{**defaults, **kwargs})
def resnet1d(**kwargs):
    """Build a custom ``ResNet1d`` made of ``BasicBlock1d`` blocks.

    ``layers`` (blocks per stage) must be passed as a keyword argument, together with
    any other ``ResNet1d`` option.
    """
    block_cls = BasicBlock1d
    return ResNet1d(block_cls, **kwargs)
# Wide ResNet, adapted from fastai's WRN implementation
def noop(x): return x
def conv1d(ni: int, nf: int, ks: int = 3, stride: int = 1, padding: int = None, bias=False) -> nn.Conv1d:
    """Create an ``nn.Conv1d`` from ``ni`` to ``nf`` channels with kernel size ``ks``.

    ``padding`` falls back to ``ks // 2`` when not given; ``bias`` is off by default.
    """
    pad = ks // 2 if padding is None else padding
    return nn.Conv1d(ni, nf, kernel_size=ks, stride=stride, padding=pad, bias=bias)
def _bn1d(ni, init_zero=False):
"Batchnorm layer with 0 initialization"
m = nn.BatchNorm1d(ni)
m.weight.data.fill_(0 if init_zero else 1)
m.bias.data.zero_()
return m
def bn_relu_conv1d(ni, nf, ks, stride, init_zero=False):
    """Return ``Sequential(batch norm, in-place ReLU, conv1d)`` mapping ``ni`` -> ``nf`` channels.

    The batch norm comes from ``_bn1d``; its weight starts at 0 when ``init_zero``.
    """
    stages = [
        _bn1d(ni, init_zero=init_zero),
        nn.ReLU(inplace=True),
        conv1d(ni, nf, ks, stride),
    ]
    return nn.Sequential(*stages)
class BasicBlock1dwrn(nn.Module):
    """Wide-ResNet residual block for 1d signals.

    The input goes through batch norm and ReLU. From that activation, a shortcut is
    taken: a ``conv1d`` when channels or stride change, otherwise the identity
    ``noop``. The main path is conv1 -> optional dropout -> (BN, ReLU, conv) scaled by
    0.2. The shortcut is then added in place to the main path.

    Args:
        ni: input channels.
        nf: output channels.
        stride: stride of conv1 and of the projection shortcut.
        drop_p: dropout probability after conv1 (falsy disables dropout).
        ks: int ``k`` (kernels ``k`` and ``k // 2 + 1``) or an indexable of kernel sizes.
            ``ks[0]`` is used for both main-path convs. ``ks[1]`` is used for the
            projection shortcut and is only read when a projection is needed.
    """

    def __init__(self, ni, nf, stride, drop_p=0.0, ks=3):
        super().__init__()
        kernel_sizes = [ks, ks // 2 + 1] if isinstance(ks, int) else ks
        # Submodules are registered in the original order (affects state_dict key order).
        self.bn = nn.BatchNorm1d(ni)
        self.conv1 = conv1d(ni, nf, kernel_sizes[0], stride)
        self.conv2 = bn_relu_conv1d(nf, nf, kernel_sizes[0], 1)
        self.drop = nn.Dropout(drop_p, inplace=True) if drop_p else None
        # Projection shortcut also when only the stride changes
        # (adapted to make it work for fix_feature_dim=True).
        needs_projection = ni != nf or stride > 1
        if needs_projection:
            self.shortcut = conv1d(ni, nf, kernel_sizes[1], stride)
        else:
            self.shortcut = noop

    def forward(self, x):
        activated = F.relu(self.bn(x), inplace=True)
        skip = self.shortcut(activated)
        out = self.conv1(activated)
        if self.drop is not None:
            out = self.drop(out)
        out = self.conv2(out) * 0.2
        return out.add_(skip)
def _make_group(N, ni, nf, block, stride, drop_p, ks=3):
return [block(ni if i == 0 else nf, nf, stride if i == 0 else 1, drop_p, ks=ks) for i in range(N)]
class WideResNet1d(nn.Sequential):
    """1d wide ResNet (adapted from fastai's WRN), built as a flat ``nn.Sequential``.

    Layout: a kernel-3 stem conv to ``start_nf`` channels, then ``num_groups`` groups of
    ``N`` ``BasicBlock1dwrn`` blocks, then the head from ``create_head1d``. The first
    group uses stride 1 and the others stride 2.

    Args:
        input_channels: channels of the input signal.
        num_groups: number of block groups.
        N: blocks per group.
        num_classes: number of outputs of the head.
        k: width multiplier; only used when ``fix_feature_dim`` is False.
        drop_p: dropout probability inside each block (0 disables dropout).
        start_nf: channels of the stem, and of every group if ``fix_feature_dim``.
        fix_feature_dim: keep ``start_nf`` channels in every group; otherwise group ``i``
            has ``start_nf * 2**i * k`` channels.
        kernel_size: kernel size (int or indexable) passed to every block.
        lin_ftrs_head, ps_head, bn_final_head, bn_head, act_head, concat_pooling:
            forwarded to ``create_head1d``.
    """

    def __init__(self, input_channels: int, num_groups: int, N: int, num_classes: int, k: int = 1, drop_p: float = 0.0,
                 start_nf: int = 16, fix_feature_dim=True, kernel_size=5, lin_ftrs_head=None, ps_head=0.5,
                 bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True):
        # Channels of the stem followed by the output channels of each group.
        n_channels = [start_nf] + [start_nf if fix_feature_dim else start_nf * (2 ** i) * k
                                   for i in range(num_groups)]
        layers = [conv1d(input_channels, n_channels[0], 3, 1)]  # conv1 stem
        for i in range(num_groups):
            layers += _make_group(N, n_channels[i], n_channels[i + 1], BasicBlock1dwrn,
                                  1 if i == 0 else 2, drop_p, ks=kernel_size)
        head = create_head1d(n_channels[-1], nc=num_classes, lin_ftrs=lin_ftrs_head, ps=ps_head,
                             bn_final=bn_final_head, bn=bn_head, act=act_head,
                             concat_pooling=concat_pooling)
        layers.append(head)
        # Passing the modules here is what registers them in the Sequential.
        super().__init__(*layers)

    def get_layer_groups(self):
        """Return ``(self[6], self[-1])``; the last element is the head.

        NOTE(review): which block index 6 refers to depends on ``num_groups``/``N`` —
        confirm before relying on it.
        """
        return self[6], self[-1]

    def get_output_layer(self):
        """Return the last module of the head (presumably its final layer; built by create_head1d)."""
        return self[-1][-1]

    def set_output_layer(self, x):
        """Replace the last module of the head with ``x``."""
        self[-1][-1] = x
def wrn1d_22(**kwargs):
    """Build a ``WideResNet1d`` with 3 groups of 3 blocks, ``k=6`` and no dropout.

    The remaining arguments (``input_channels``, ``num_classes``, ...) are forwarded.
    Note that ``k`` only widens the groups when ``fix_feature_dim=False``.
    """
    preset = dict(num_groups=3, N=3, k=6, drop_p=0.0)
    return WideResNet1d(**preset, **kwargs)