|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import math |
|
|
from fastai.layers import * |
|
|
from fastai.data.core import * |
|
|
from typing import Optional, Collection, Union |
|
|
from collections.abc import Iterable |
|
|
|
|
|
''' |
|
|
This layer creates a convolution kernel that is convolved with the layer input |
|
|
over a single spatial (or temporal) dimension to produce a tensor of outputs. |
|
|
If use_bias is True, a bias vector is created and added to the outputs. |
|
|
Finally, if activation is not None, it is applied to the outputs as well. |
|
|
https://keras.io/api/layers/convolution_layers/convolution1d/ |
|
|
''' |
|
|
def listify(o):
    """Coerce *o* into a plain list.

    None -> [], list -> unchanged (same object), str -> single-element list,
    other iterables -> list(o), anything else -> [o].
    """
    if o is None:
        return []
    if isinstance(o, list):
        return o
    if isinstance(o, str):
        return [o]
    return list(o) if isinstance(o, Iterable) else [o]
|
|
import torch.nn as nn |
|
|
|
|
|
def bn_drop_lin(ni, no, bn=True, p=0., actn=None):
    """Return a list of layers: optional BatchNorm1d, optional Dropout, Linear, optional activation.

    ni/no: input/output features of the Linear layer.
    bn: prepend BatchNorm1d(ni) when True.
    p: Dropout probability (omitted when 0).
    actn: activation module appended when not None.
    """
    modules = [nn.BatchNorm1d(ni)] if bn else []
    if p != 0.:
        modules += [nn.Dropout(p)]
    modules += [nn.Linear(ni, no)]
    if actn is not None:
        modules += [actn]
    return modules
|
|
|
|
|
def _conv1d(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, act="relu", bn=True, drop_p=0): |
|
|
lst = [] |
|
|
if (drop_p > 0): |
|
|
lst.append(nn.Dropout(drop_p)) |
|
|
lst.append(nn.Conv1d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2, |
|
|
dilation=dilation, bias=not bn)) |
|
|
if bn: |
|
|
lst.append(nn.BatchNorm1d(out_planes)) |
|
|
if act == "relu": |
|
|
lst.append(nn.ReLU(True)) |
|
|
if act == "elu": |
|
|
lst.append(nn.ELU(True)) |
|
|
if act == "prelu": |
|
|
lst.append(nn.PReLU(True)) |
|
|
return nn.Sequential(*lst) |
|
|
|
|
|
|
|
|
def _fc(in_planes, out_planes, act="relu", bn=True): |
|
|
lst = [nn.Linear(in_planes, out_planes, bias=not (bn))] |
|
|
if bn: |
|
|
lst.append(nn.BatchNorm1d(out_planes)) |
|
|
if act == "relu": |
|
|
lst.append(nn.ReLU(True)) |
|
|
if act == "elu": |
|
|
lst.append(nn.ELU(True)) |
|
|
if act == "prelu": |
|
|
lst.append(nn.PReLU(True)) |
|
|
return nn.Sequential(*lst) |
|
|
|
|
|
|
|
|
def cd_adaptive_concat_pool(relevant, irrelevant, module):
    """Propagate CD (contextual decomposition) attributions through an AdaptiveConcatPool1d.

    `module` is expected to expose `mp`/`ap` sub-modules whose `attrib`
    method returns (relevant, irrelevant) pairs; the pairs are concatenated
    along the channel dimension, matching the forward pass.
    """
    rel_max, irrel_max = module.mp.attrib(relevant, irrelevant)
    rel_avg, irrel_avg = module.ap.attrib(relevant, irrelevant)
    rel = torch.cat([rel_max, rel_avg], 1)
    irrel = torch.cat([irrel_max, irrel_avg], 1)
    return rel, irrel
|
|
|
|
|
|
|
|
def attrib_adaptive_concat_pool(self, relevant, irrelevant):
    """Method-style wrapper delegating CD attribution to `cd_adaptive_concat_pool`."""
    return cd_adaptive_concat_pool(relevant=relevant, irrelevant=irrelevant, module=self)
|
|
|
|
|
|
|
|
class AdaptiveConcatPool1d(nn.Module):
    """Concatenate `AdaptiveMaxPool1d` and `AdaptiveAvgPool1d` along the channel axis."""

    def __init__(self, sz: Optional[int] = None):
        """Pool to `sz` time steps (1 when None); output channels double."""
        super().__init__()
        target = sz or 1
        self.ap = nn.AdaptiveAvgPool1d(target)
        self.mp = nn.AdaptiveMaxPool1d(target)

    def forward(self, x):
        # max-pooled channels first, then average-pooled channels
        pooled = [self.mp(x), self.ap(x)]
        return torch.cat(pooled, dim=1)

    def attrib(self, relevant, irrelevant):
        """CD attribution hook; delegates to the module-level helper."""
        return attrib_adaptive_concat_pool(self, relevant, irrelevant)
|
|
|
|
|
|
|
|
class SqueezeExcite1d(nn.Module):
    """Squeeze-and-excitation block (as used e.g. in LSTM-FCN).

    Global-average-pools over time ("squeeze"), passes the channel vector
    through a two-layer bottleneck (ReLU then sigmoid), and rescales the
    input channels by the resulting gates.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        channels_reduced = channels // reduction
        # weights carry a leading broadcast dim of 1 so matmul broadcasts
        # over the batch: w1 (1, c_red, c), w2 (1, c, c_red)
        self.w1 = torch.nn.Parameter(torch.randn(channels_reduced, channels).unsqueeze(0))
        self.w2 = torch.nn.Parameter(torch.randn(channels, channels_reduced).unsqueeze(0))

    def forward(self, x):
        # squeeze: (batch, channels, time) -> (batch, channels, 1)
        z = torch.mean(x, dim=2, keepdim=True)
        intermed = F.relu(torch.matmul(self.w1, z))
        # bugfix: F.sigmoid is deprecated -> use torch.sigmoid
        s = torch.sigmoid(torch.matmul(self.w2, intermed))
        # excite: per-channel gates in (0, 1) broadcast over time
        return s * x
|
|
|
|
|
|
|
|
def weight_init(m):
    """Weight initialization for a model `n`, applied via `n.apply(weight_init)`.

    - Conv1d / Linear: Kaiming-normal weights, zero bias (if present).
    - BatchNorm1d: weight = 1, bias = 0.
    - SqueezeExcite1d: normal init of the two bottleneck weight tensors.
    """
    if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    if isinstance(m, nn.BatchNorm1d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
    if isinstance(m, SqueezeExcite1d):
        # bugfix: Tensor.size is a method, so the original `m.w1.size[0]` /
        # `m.w2.size[1]` raised TypeError. w1 is (1, c_red, c) and w2 is
        # (1, c, c_red); size(-1) is the fan-in of each matmul, matching the
        # Kaiming-style sqrt(gain / fan_in) scaling used here.
        # NOTE(review): original subscripts suggest 2-D shape indexing was
        # intended -- confirm the chosen fan-in dims against upstream.
        stdv1 = math.sqrt(2. / m.w1.size(-1))
        nn.init.normal_(m.w1, 0., stdv1)
        stdv2 = math.sqrt(1. / m.w2.size(-1))
        nn.init.normal_(m.w2, 0., stdv2)
|
|
|
|
|
|
|
|
def create_head1d(nf: int, nc: int, lin_ftrs: Optional[Collection[int]] = None, ps: Union[float, Collection[float]] = 0.5,
                  bn_final: bool = False, bn: bool = True, act="relu", concat_pooling=True):
    """Model head taking `nf` features through `lin_ftrs` hidden sizes down to `nc` classes.

    Starts with AdaptiveConcatPool1d (doubling nf) or MaxPool1d, then Flatten,
    then BN/Dropout/Linear/activation blocks; optional final BatchNorm1d.
    """
    n_in = 2 * nf if concat_pooling else nf
    if lin_ftrs is None:
        lin_ftrs = [n_in, nc]
    else:
        lin_ftrs = [n_in] + lin_ftrs + [nc]
    ps = listify(ps)
    if len(ps) == 1:
        # hidden layers get half the final dropout probability
        ps = [ps[0] / 2] * (len(lin_ftrs) - 2) + ps
    hidden_act = nn.ReLU(inplace=True) if act == "relu" else nn.ELU(inplace=True)
    actns = [hidden_act] * (len(lin_ftrs) - 2) + [None]
    layers = [AdaptiveConcatPool1d() if concat_pooling else nn.MaxPool1d(2), Flatten()]
    for ni, no, p, actn in zip(lin_ftrs[:-1], lin_ftrs[1:], ps, actns):
        layers += bn_drop_lin(ni, no, bn, p, actn)
    if bn_final:
        layers.append(nn.BatchNorm1d(lin_ftrs[-1], momentum=0.01))
    return nn.Sequential(*layers)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BasicConv1d(nn.Sequential):
    """Configurable stack of `_conv1d` stages followed by a pooling or classification head.

    Each stage is a Sequential of: a `_conv1d` block (optionally followed by a
    1x1 conv when `split_first_layer`), an optional MaxPool1d, and an optional
    SqueezeExcite1d. The final element is either AdaptiveAvgPool1d+Flatten
    (headless feature extractor) or a `create_head1d` classifier head.
    """

    def __init__(self, filters=None, kernel_size=3, stride=2, dilation=1, pool=0, pool_stride=1,
                 squeeze_excite_reduction=0, num_classes=2, input_channels=8, act="relu", bn=True, headless=False,
                 split_first_layer=False, drop_p=0., lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True,
                 act_head="relu", concat_pooling=True):
        """Build the network.

        filters: output channels per stage (default [128, 128, 128, 128]).
        kernel_size: int shared by all stages, or a per-stage list.
        pool/pool_stride: MaxPool1d after every stage except the last (pool > 0).
        squeeze_excite_reduction: append SqueezeExcite1d per stage when > 0.
        headless: end with AdaptiveAvgPool1d+Flatten instead of a classifier head;
            the last conv stage then also drops its activation and batch norm.
        split_first_layer: first stage runs at stride 1 with no activation and is
            followed by an activated 1x1 conv.
        drop_p: dropout inside each `_conv1d` except the first stage.
        Remaining *_head parameters are forwarded to `create_head1d`.
        """
        if filters is None:
            filters = [128, 128, 128, 128]
        layers = []
        if isinstance(kernel_size, int):
            # broadcast a single kernel size to every stage
            kernel_size = [kernel_size] * len(filters)
        for i in range(len(filters)):
            # last stage loses act/bn when headless; first stage loses
            # stride/act when split_first_layer; first stage never gets dropout
            layers_tmp = [_conv1d(input_channels if i == 0 else filters[i - 1], filters[i], kernel_size=kernel_size[i],
                                  stride=(1 if (split_first_layer is True and i == 0) else stride), dilation=dilation,
                                  act="none" if ((headless is True and i == len(filters) - 1) or (
                                          split_first_layer is True and i == 0)) else act,
                                  bn=False if (headless is True and i == len(filters) - 1) else bn,
                                  drop_p=(0. if i == 0 else drop_p))]

            if split_first_layer is True and i == 0:
                # pointwise conv carries the activation skipped above
                layers_tmp.append(_conv1d(filters[0], filters[0], kernel_size=1, stride=1, act=act, bn=bn, drop_p=0.))

            if pool > 0 and i < len(filters) - 1:
                layers_tmp.append(nn.MaxPool1d(pool, stride=pool_stride, padding=(pool - 1) // 2))
            if squeeze_excite_reduction > 0:
                layers_tmp.append(SqueezeExcite1d(filters[i], squeeze_excite_reduction))
            layers.append(nn.Sequential(*layers_tmp))

        # NOTE(review): plain attribute set before super().__init__(); works for
        # non-module values under nn.Module.__setattr__ -- confirm intentional
        self.headless = headless
        if headless is True:
            head = nn.Sequential(nn.AdaptiveAvgPool1d(1), Flatten())
        else:
            head = create_head1d(filters[-1], nc=num_classes, lin_ftrs=lin_ftrs_head, ps=ps_head,
                                 bn_final=bn_final_head, bn=bn_head, act=act_head, concat_pooling=concat_pooling)
        layers.append(head)

        super().__init__(*layers)

    def get_layer_groups(self):
        # (third stage, head) -- presumably for discriminative learning rates;
        # assumes at least three stages exist -- TODO confirm against callers
        return self[2], self[-1]

    def get_output_layer(self):
        """Return the final layer of the head, or None when headless."""
        if self.headless is False:
            return self[-1][-1]
        else:
            return None

    def set_output_layer(self, x):
        """Replace the final layer of the head (no-op when headless)."""
        if self.headless is False:
            self[-1][-1] = x
|
|
|
|
|
|
|
|
|
|
|
def fcn(filters=None, num_classes=2, input_channels=8):
    """Headless FCN: stride-1 convs with max pooling; the final stage outputs `num_classes` feature maps."""
    stage_filters = ([128] * 5 if filters is None else filters) + [num_classes]
    return BasicConv1d(filters=stage_filters, kernel_size=3, stride=1, pool=2, pool_stride=2,
                       input_channels=input_channels, act="relu", bn=True, headless=True)
|
|
|
|
|
|
|
|
def fcn_wang(num_classes=2, input_channels=8, lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True,
             act_head="relu", concat_pooling=True):
    """FCN variant with three conv stages [128, 256, 128] and kernel sizes [8, 5, 3]."""
    head_kwargs = dict(lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head,
                       bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling)
    return BasicConv1d(filters=[128, 256, 128], kernel_size=[8, 5, 3], stride=1, pool=0, pool_stride=2,
                       num_classes=num_classes, input_channels=input_channels, act="relu", bn=True, **head_kwargs)
|
|
|
|
|
|
|
|
def schirrmeister(num_classes=2, input_channels=8, lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True,
                  act_head="relu", concat_pooling=True):
    """Deep4-style network: four stages [25, 50, 100, 200], kernel 10, stride 3, split first layer, dropout 0.5."""
    head_kwargs = dict(lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head,
                       bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling)
    return BasicConv1d(filters=[25, 50, 100, 200], kernel_size=10, stride=3, pool=3, pool_stride=1,
                       num_classes=num_classes, input_channels=input_channels, act="relu", bn=True,
                       headless=False, split_first_layer=True, drop_p=0.5, **head_kwargs)
|
|
|
|
|
|
|
|
def sen(filters=None, num_classes=2, input_channels=8, squeeze_excite_reduction=16, drop_p=0., lin_ftrs_head=None,
        ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True):
    """Squeeze-and-excitation network: stride-2 conv stages, each followed by an SE block."""
    stage_filters = [128] * 5 if filters is None else filters
    head_kwargs = dict(lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head,
                       bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling)
    return BasicConv1d(filters=stage_filters, kernel_size=3, stride=2, pool=0, pool_stride=0,
                       input_channels=input_channels, act="relu", bn=True, num_classes=num_classes,
                       squeeze_excite_reduction=squeeze_excite_reduction, drop_p=drop_p, **head_kwargs)
|
|
|
|
|
|
|
|
def basic1d(filters=None, kernel_size=3, stride=2, dilation=1, pool=0, pool_stride=1, squeeze_excite_reduction=0,
            num_classes=2, input_channels=8, act="relu", bn=True, headless=False, drop_p=0., lin_ftrs_head=None,
            ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True):
    """Generic entry point exposing every `BasicConv1d` option as a keyword."""
    stage_filters = [128] * 5 if filters is None else filters
    return BasicConv1d(filters=stage_filters, kernel_size=kernel_size, stride=stride, dilation=dilation,
                       pool=pool, pool_stride=pool_stride, squeeze_excite_reduction=squeeze_excite_reduction,
                       num_classes=num_classes, input_channels=input_channels, act=act, bn=bn,
                       headless=headless, drop_p=drop_p, lin_ftrs_head=lin_ftrs_head, ps_head=ps_head,
                       bn_final_head=bn_final_head, bn_head=bn_head, act_head=act_head,
                       concat_pooling=concat_pooling)