import torch import torch.nn as nn import torch.nn.functional as F import math from fastai.layers import * from fastai.data.core import * from typing import Optional, Collection, Union from collections.abc import Iterable ''' This layer creates a convolution kernel that is convolved with the layer input over a single spatial (or temporal) dimension to produce a tensor of outputs. If use_bias is True, a bias vector is created and added to the outputs. Finally, if activation is not None, it is applied to the outputs as well. https://keras.io/api/layers/convolution_layers/convolution1d/ ''' def listify(o): if o is None: return [] if isinstance(o, list): return o if isinstance(o, str): return [o] if isinstance(o, Iterable): return list(o) return [o] import torch.nn as nn def bn_drop_lin(ni, no, bn=True, p=0., actn=None): layers = [] if bn: layers.append(nn.BatchNorm1d(ni)) if p != 0.: layers.append(nn.Dropout(p)) layers.append(nn.Linear(ni, no)) if actn is not None: layers.append(actn) return layers def _conv1d(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, act="relu", bn=True, drop_p=0): lst = [] if (drop_p > 0): lst.append(nn.Dropout(drop_p)) lst.append(nn.Conv1d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2, dilation=dilation, bias=not bn)) if bn: lst.append(nn.BatchNorm1d(out_planes)) if act == "relu": lst.append(nn.ReLU(True)) if act == "elu": lst.append(nn.ELU(True)) if act == "prelu": lst.append(nn.PReLU(True)) return nn.Sequential(*lst) def _fc(in_planes, out_planes, act="relu", bn=True): lst = [nn.Linear(in_planes, out_planes, bias=not (bn))] if bn: lst.append(nn.BatchNorm1d(out_planes)) if act == "relu": lst.append(nn.ReLU(True)) if act == "elu": lst.append(nn.ELU(True)) if act == "prelu": lst.append(nn.PReLU(True)) return nn.Sequential(*lst) def cd_adaptive_concat_pool(relevant, irrelevant, module): mpr, mpi = module.mp.attrib(relevant, irrelevant) apr, api = module.ap.attrib(relevant, irrelevant) return torch.cat([mpr, apr], 1), torch.cat([mpi, api], 1) def attrib_adaptive_concat_pool(self, relevant, irrelevant): return cd_adaptive_concat_pool(relevant, irrelevant, self) class AdaptiveConcatPool1d(nn.Module): """Layer that concat `AdaptiveAvgPool1d` and `AdaptiveMaxPool1d`.""" def __init__(self, sz: Optional[int] = None): """Output will be 2*sz or 2 if sz is None""" super().__init__() sz = sz or 1 self.ap, self.mp = nn.AdaptiveAvgPool1d(sz), nn.AdaptiveMaxPool1d(sz) def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1) def attrib(self, relevant, irrelevant): return attrib_adaptive_concat_pool(self, relevant, irrelevant) class SqueezeExcite1d(nn.Module): """squeeze excite block as used for example in LSTM FCN""" def __init__(self, channels, reduction=16): super().__init__() channels_reduced = channels // reduction self.w1 = torch.nn.Parameter(torch.randn(channels_reduced, channels).unsqueeze(0)) self.w2 = torch.nn.Parameter(torch.randn(channels, channels_reduced).unsqueeze(0)) def forward(self, x): # input is bs,ch,seq z = torch.mean(x, dim=2, keepdim=True) # bs,ch intermed = F.relu(torch.matmul(self.w1, z)) # (1,ch_red,ch * bs,ch,1) = (bs, ch_red, 1) s = F.sigmoid(torch.matmul(self.w2, intermed)) # (1,ch,ch_red * bs, ch_red, 1=bs, ch, 1 return s * x # bs,ch,seq * bs, ch,1 = bs,ch,seq def weight_init(m): """call weight initialization for model n via n.apply(weight_init)""" if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear): nn.init.kaiming_normal_(m.weight) if m.bias is not None: nn.init.zeros_(m.bias) if isinstance(m, nn.BatchNorm1d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) if isinstance(m, SqueezeExcite1d): stdv1 = math.sqrt(2. / m.w1.size[0]) nn.init.normal_(m.w1, 0., stdv1) stdv2 = math.sqrt(1. / m.w2.size[1]) nn.init.normal_(m.w2, 0., stdv2) def create_head1d(nf: int, nc: int, lin_ftrs: Optional[Collection[int]] = None, ps: Union[float, Collection[float]] = 0.5, bn_final: bool = False, bn: bool = True, act="relu", concat_pooling=True): """Model head that takes `nf` features, runs through `lin_ftrs`, and about `nc` classes; added bn and act here""" lin_ftrs = [2 * nf if concat_pooling else nf, nc] if lin_ftrs is None else [ 2 * nf if concat_pooling else nf] + lin_ftrs + [ nc] # was [nf, 512,nc] ps = listify(ps) if len(ps) == 1: ps = [ps[0] / 2] * (len(lin_ftrs) - 2) + ps actns = [nn.ReLU(inplace=True) if act == "relu" else nn.ELU(inplace=True)] * (len(lin_ftrs) - 2) + [None] layers = [AdaptiveConcatPool1d() if concat_pooling else nn.MaxPool1d(2), Flatten()] for ni, no, p, actn in zip(lin_ftrs[:-1], lin_ftrs[1:], ps, actns): layers += bn_drop_lin(ni, no, bn, p, actn) if bn_final: layers.append(nn.BatchNorm1d(lin_ftrs[-1], momentum=0.01)) return nn.Sequential(*layers) # basic convolutional architecture class BasicConv1d(nn.Sequential): """basic conv1d""" def __init__(self, filters=None, kernel_size=3, stride=2, dilation=1, pool=0, pool_stride=1, squeeze_excite_reduction=0, num_classes=2, input_channels=8, act="relu", bn=True, headless=False, split_first_layer=False, drop_p=0., lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True): if filters is None: filters = [128, 128, 128, 128] layers = [] if isinstance(kernel_size, int): kernel_size = [kernel_size] * len(filters) for i in range(len(filters)): layers_tmp = [_conv1d(input_channels if i == 0 else filters[i - 1], filters[i], kernel_size=kernel_size[i], stride=(1 if (split_first_layer is True and i == 0) else stride), dilation=dilation, act="none" if ((headless is True and i == len(filters) - 1) or ( split_first_layer is True and i == 0)) else act, bn=False if (headless is True and i == len(filters) - 1) else bn, drop_p=(0. if i == 0 else drop_p))] if split_first_layer is True and i == 0: layers_tmp.append(_conv1d(filters[0], filters[0], kernel_size=1, stride=1, act=act, bn=bn, drop_p=0.)) # layers_tmp.append(nn.Linear(filters[0],filters[0],bias=not(bn))) # layers_tmp.append(_fc(filters[0],filters[0],act=act,bn=bn)) if pool > 0 and i < len(filters) - 1: layers_tmp.append(nn.MaxPool1d(pool, stride=pool_stride, padding=(pool - 1) // 2)) if squeeze_excite_reduction > 0: layers_tmp.append(SqueezeExcite1d(filters[i], squeeze_excite_reduction)) layers.append(nn.Sequential(*layers_tmp)) # head layers.append(nn.AdaptiveAvgPool1d(1)) layers.append(nn.Linear(filters[-1],num_classes)) head # #inplace=True leads to a runtime error see ReLU+ dropout # https://discuss.pytorch.org/t/relu-dropout-inplace/13467/5 self.headless = headless if headless is True: head = nn.Sequential(nn.AdaptiveAvgPool1d(1), Flatten()) else: head = create_head1d(filters[-1], nc=num_classes, lin_ftrs=lin_ftrs_head, ps=ps_head, bn_final=bn_final_head, bn=bn_head, act=act_head, concat_pooling=concat_pooling) layers.append(head) super().__init__(*layers) def get_layer_groups(self): return self[2], self[-1] def get_output_layer(self): if self.headless is False: return self[-1][-1] else: return None def set_output_layer(self, x): if self.headless is False: self[-1][-1] = x # convenience functions for basic convolutional architectures def fcn(filters=None, num_classes=2, input_channels=8): if filters is None: filters = [128] * 5 filters_in = filters + [num_classes] return BasicConv1d(filters=filters_in, kernel_size=3, stride=1, pool=2, pool_stride=2, input_channels=input_channels, act="relu", bn=True, headless=True) def fcn_wang(num_classes=2, input_channels=8, lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True): return BasicConv1d(filters=[128, 256, 128], kernel_size=[8, 5, 3], stride=1, pool=0, pool_stride=2, num_classes=num_classes, input_channels=input_channels, act="relu", bn=True, lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head, bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling) def schirrmeister(num_classes=2, input_channels=8, lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True): return BasicConv1d(filters=[25, 50, 100, 200], kernel_size=10, stride=3, pool=3, pool_stride=1, num_classes=num_classes, input_channels=input_channels, act="relu", bn=True, headless=False, split_first_layer=True, drop_p=0.5, lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head, bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling) def sen(filters=None, num_classes=2, input_channels=8, squeeze_excite_reduction=16, drop_p=0., lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True): if filters is None: filters = [128] * 5 return BasicConv1d(filters=filters, kernel_size=3, stride=2, pool=0, pool_stride=0, input_channels=input_channels, act="relu", bn=True, num_classes=num_classes, squeeze_excite_reduction=squeeze_excite_reduction, drop_p=drop_p, lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head, bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling) def basic1d(filters=None, kernel_size=3, stride=2, dilation=1, pool=0, pool_stride=1, squeeze_excite_reduction=0, num_classes=2, input_channels=8, act="relu", bn=True, headless=False, drop_p=0., lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True): if filters is None: filters = [128] * 5 return BasicConv1d(filters=filters, kernel_size=kernel_size, stride=stride, dilation=dilation, pool=pool, pool_stride=pool_stride, squeeze_excite_reduction=squeeze_excite_reduction, num_classes=num_classes, input_channels=input_channels, act=act, bn=bn, headless=headless, drop_p=drop_p, lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head, bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling)