File size: 11,597 Bytes
264b4c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from fastai.layers import *
from fastai.data.core import *
from typing import Optional, Collection, Union
from collections.abc import Iterable

'''
This layer creates a convolution kernel that is convolved with the layer input
over a single spatial (or temporal) dimension to produce a tensor of outputs.
If use_bias is True, a bias vector is created and added to the outputs.
Finally, if activation is not None, it is applied to the outputs as well.
https://keras.io/api/layers/convolution_layers/convolution1d/
'''
def listify(o):
    """Coerce `o` into a list: None -> [], list passes through, a str stays
    wrapped as a single element, any other iterable is expanded, and a scalar
    is wrapped."""
    if o is None:
        return []
    if isinstance(o, list):
        return o
    # Strings are iterable but must be treated as one atom, not split into chars.
    if isinstance(o, str):
        return [o]
    return list(o) if isinstance(o, Iterable) else [o]
import torch.nn as nn

def bn_drop_lin(ni, no, bn=True, p=0., actn=None):
    """Return a list of layers: optional BatchNorm1d(ni), optional Dropout(p),
    Linear(ni, no), optional activation module `actn`."""
    prefix = ([nn.BatchNorm1d(ni)] if bn else []) + ([nn.Dropout(p)] if p != 0. else [])
    suffix = [] if actn is None else [actn]
    return prefix + [nn.Linear(ni, no)] + suffix

def _conv1d(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, act="relu", bn=True, drop_p=0):
    lst = []
    if (drop_p > 0):
        lst.append(nn.Dropout(drop_p))
    lst.append(nn.Conv1d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2,
                         dilation=dilation, bias=not bn))
    if bn:
        lst.append(nn.BatchNorm1d(out_planes))
    if act == "relu":
        lst.append(nn.ReLU(True))
    if act == "elu":
        lst.append(nn.ELU(True))
    if act == "prelu":
        lst.append(nn.PReLU(True))
    return nn.Sequential(*lst)


def _fc(in_planes, out_planes, act="relu", bn=True):
    lst = [nn.Linear(in_planes, out_planes, bias=not (bn))]
    if bn:
        lst.append(nn.BatchNorm1d(out_planes))
    if act == "relu":
        lst.append(nn.ReLU(True))
    if act == "elu":
        lst.append(nn.ELU(True))
    if act == "prelu":
        lst.append(nn.PReLU(True))
    return nn.Sequential(*lst)


def cd_adaptive_concat_pool(relevant, irrelevant, module):
    """Propagate contextual-decomposition (relevant, irrelevant) score pairs
    through an AdaptiveConcatPool1d-like `module`, mirroring its forward:
    max-pool attribution first, then avg-pool, concatenated on dim 1."""
    rel_max, irrel_max = module.mp.attrib(relevant, irrelevant)
    rel_avg, irrel_avg = module.ap.attrib(relevant, irrelevant)
    rel_out = torch.cat([rel_max, rel_avg], 1)
    irrel_out = torch.cat([irrel_max, irrel_avg], 1)
    return rel_out, irrel_out


def attrib_adaptive_concat_pool(self, relevant, irrelevant):
    """Method-style adapter: delegate CD attribution for this pooling module
    to the free function `cd_adaptive_concat_pool`."""
    return cd_adaptive_concat_pool(relevant, irrelevant, module=self)


class AdaptiveConcatPool1d(nn.Module):
    """Concatenate `AdaptiveMaxPool1d` and `AdaptiveAvgPool1d` outputs along
    the channel dimension, doubling the channel count."""

    def __init__(self, sz: Optional[int] = None):
        """Target output length is `sz` (defaults to 1); output channels = 2 * input channels."""
        super().__init__()
        target = sz or 1
        self.ap = nn.AdaptiveAvgPool1d(target)
        self.mp = nn.AdaptiveMaxPool1d(target)

    def forward(self, x):
        # Max-pooled features first, then average-pooled — order matters for
        # any downstream layer indexing into the concatenated channels.
        pooled = [self.mp(x), self.ap(x)]
        return torch.cat(pooled, dim=1)

    def attrib(self, relevant, irrelevant):
        """Contextual-decomposition attribution hook for this layer."""
        return attrib_adaptive_concat_pool(self, relevant, irrelevant)


class SqueezeExcite1d(nn.Module):
    """Squeeze-and-excitation block for (bs, ch, seq) inputs, as used e.g. in LSTM-FCN.

    Squeeze: global average over the sequence dim. Excite: two per-channel
    linear maps (bottleneck of `channels // reduction`) and a sigmoid gate
    that rescales each channel of the input.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        channels_reduced = channels // reduction
        # Leading singleton dim lets the weights broadcast-matmul against (bs, ch, 1).
        self.w1 = torch.nn.Parameter(torch.randn(channels_reduced, channels).unsqueeze(0))
        self.w2 = torch.nn.Parameter(torch.randn(channels, channels_reduced).unsqueeze(0))

    def forward(self, x):
        # x: (bs, ch, seq) -> squeeze to (bs, ch, 1)
        z = torch.mean(x, dim=2, keepdim=True)
        # (1, ch_red, ch) @ (bs, ch, 1) -> (bs, ch_red, 1)
        intermed = F.relu(torch.matmul(self.w1, z))
        # (1, ch, ch_red) @ (bs, ch_red, 1) -> (bs, ch, 1).
        # torch.sigmoid replaces the deprecated F.sigmoid (identical result).
        s = torch.sigmoid(torch.matmul(self.w2, intermed))
        # Broadcast the per-channel gate over the sequence: (bs, ch, seq)
        return s * x


def weight_init(m):
    """Weight initialization for model n; apply via n.apply(weight_init).

    Conv1d/Linear: Kaiming-normal weights, zero bias.
    BatchNorm1d: weight 1, bias 0.
    SqueezeExcite1d: zero-mean normal scaled by each weight's fan-in.
    """
    if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    if isinstance(m, nn.BatchNorm1d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
    if isinstance(m, SqueezeExcite1d):
        # BUG FIX: Tensor.size is a method, so the original `m.w1.size[0]` /
        # `m.w2.size[1]` raised TypeError whenever this branch ran. Scale by
        # each matrix's fan-in, i.e. its last dim: w1 is (1, ch_red, ch) ->
        # fan_in ch; w2 is (1, ch, ch_red) -> fan_in ch_red.
        stdv1 = math.sqrt(2. / m.w1.size(-1))
        nn.init.normal_(m.w1, 0., stdv1)
        stdv2 = math.sqrt(1. / m.w2.size(-1))
        nn.init.normal_(m.w2, 0., stdv2)


def create_head1d(nf: int, nc: int, lin_ftrs: Optional[Collection[int]] = None, ps: Union[float, Collection[float]] = 0.5,
                  bn_final: bool = False, bn: bool = True, act="relu", concat_pooling=True):
    """Model head: pooling, flatten, then BN/dropout/linear blocks ending in `nc` classes.

    nf: channel count coming out of the backbone (doubled when concat_pooling,
        since AdaptiveConcatPool1d concatenates max- and avg-pooled features).
    lin_ftrs: hidden widths between the pooled features and the output; None
        means a single linear layer straight to `nc` (was [nf, 512, nc]).
    ps: dropout probability/probabilities; a single value is halved for all
        but the final linear block.
    """
    first = 2 * nf if concat_pooling else nf
    lin_ftrs = [first, nc] if lin_ftrs is None else [first] + list(lin_ftrs) + [nc]
    ps = listify(ps)
    if len(ps) == 1:
        ps = [ps[0] / 2] * (len(lin_ftrs) - 2) + ps
    # One activation per hidden block; the final linear layer gets none.
    actns = [nn.ReLU(inplace=True) if act == "relu" else nn.ELU(inplace=True)] * (len(lin_ftrs) - 2) + [None]
    # nn.Flatten() replaces the star-imported fastai Flatten — both flatten
    # everything after the batch dimension.
    layers = [AdaptiveConcatPool1d() if concat_pooling else nn.MaxPool1d(2), nn.Flatten()]
    for ni, no, p, actn in zip(lin_ftrs[:-1], lin_ftrs[1:], ps, actns):
        layers += bn_drop_lin(ni, no, bn, p, actn)
    if bn_final:
        layers.append(nn.BatchNorm1d(lin_ftrs[-1], momentum=0.01))
    return nn.Sequential(*layers)


# basic convolutional architecture

class BasicConv1d(nn.Sequential):
    """Configurable 1d CNN: a stack of conv stages followed by a classification head.

    Each of the len(filters) stages is a _conv1d block (optionally followed by
    max-pooling and/or squeeze-excitation). The head is either pooling+flatten
    ("headless", for use as a feature extractor / FCN) or a create_head1d
    classifier ending in `num_classes` logits. Input: (batch, input_channels, seq).
    """

    def __init__(self, filters=None, kernel_size=3, stride=2, dilation=1, pool=0, pool_stride=1,
                 squeeze_excite_reduction=0, num_classes=2, input_channels=8, act="relu", bn=True, headless=False,
                 split_first_layer=False, drop_p=0., lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True,
                 act_head="relu", concat_pooling=True):
        if filters is None:
            filters = [128, 128, 128, 128]
        layers = []
        # A scalar kernel_size applies to every stage.
        if isinstance(kernel_size, int):
            kernel_size = [kernel_size] * len(filters)
        for i in range(len(filters)):
            # Stage i conv block. Special cases:
            # - split_first_layer & i==0: stride 1 and no activation here (a 1x1
            #   conv with the activation follows below);
            # - headless & last stage: no activation and no BN, so raw features
            #   come out;
            # - dropout is never applied before the first conv.
            layers_tmp = [_conv1d(input_channels if i == 0 else filters[i - 1], filters[i], kernel_size=kernel_size[i],
                                  stride=(1 if (split_first_layer is True and i == 0) else stride), dilation=dilation,
                                  act="none" if ((headless is True and i == len(filters) - 1) or (
                                          split_first_layer is True and i == 0)) else act,
                                  bn=False if (headless is True and i == len(filters) - 1) else bn,
                                  drop_p=(0. if i == 0 else drop_p))]

            if split_first_layer is True and i == 0:
                # Pointwise conv carries the activation for the split first stage.
                layers_tmp.append(_conv1d(filters[0], filters[0], kernel_size=1, stride=1, act=act, bn=bn, drop_p=0.))
                # layers_tmp.append(nn.Linear(filters[0],filters[0],bias=not(bn)))
                # layers_tmp.append(_fc(filters[0],filters[0],act=act,bn=bn))
            # Max-pool after every stage except the last.
            if pool > 0 and i < len(filters) - 1:
                layers_tmp.append(nn.MaxPool1d(pool, stride=pool_stride, padding=(pool - 1) // 2))
            if squeeze_excite_reduction > 0:
                layers_tmp.append(SqueezeExcite1d(filters[i], squeeze_excite_reduction))
            layers.append(nn.Sequential(*layers_tmp))

        # head layers.append(nn.AdaptiveAvgPool1d(1)) layers.append(nn.Linear(filters[-1],num_classes)) head
        # #inplace=True leads to a runtime error see ReLU+ dropout
        # https://discuss.pytorch.org/t/relu-dropout-inplace/13467/5
        self.headless = headless
        if headless is True:
            head = nn.Sequential(nn.AdaptiveAvgPool1d(1), Flatten())
        else:
            head = create_head1d(filters[-1], nc=num_classes, lin_ftrs=lin_ftrs_head, ps=ps_head,
                                 bn_final=bn_final_head, bn=bn_head, act=act_head, concat_pooling=concat_pooling)
        layers.append(head)

        super().__init__(*layers)

    def get_layer_groups(self):
        """Parameter groups for discriminative learning rates: (third stage, head)."""
        return self[2], self[-1]

    def get_output_layer(self):
        """Return the final linear layer of the head, or None when headless."""
        if self.headless is False:
            return self[-1][-1]
        else:
            return None

    def set_output_layer(self, x):
        """Replace the final linear layer of the head (no-op when headless)."""
        if self.headless is False:
            self[-1][-1] = x


# convenience functions for basic convolutional architectures
def fcn(filters=None, num_classes=2, input_channels=8):
    """Fully-convolutional network: headless BasicConv1d whose final stage has
    `num_classes` filters (defaults to five 128-filter stages before it)."""
    stage_filters = [128] * 5 if filters is None else filters
    return BasicConv1d(filters=stage_filters + [num_classes], kernel_size=3, stride=1, pool=2,
                       pool_stride=2, input_channels=input_channels, act="relu", bn=True,
                       headless=True)


def fcn_wang(num_classes=2, input_channels=8, lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True,
             act_head="relu", concat_pooling=True):
    """FCN of Wang et al.: three conv stages (128/256/128 filters, kernels 8/5/3,
    stride 1) plus a standard classification head."""
    head_cfg = dict(lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head,
                    bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling)
    return BasicConv1d(filters=[128, 256, 128], kernel_size=[8, 5, 3], stride=1, pool=0,
                       pool_stride=2, num_classes=num_classes, input_channels=input_channels,
                       act="relu", bn=True, **head_cfg)


def schirrmeister(num_classes=2, input_channels=8, lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True,
                  act_head="relu", concat_pooling=True):
    """Deep4-style net (Schirrmeister et al.): four stages (25/50/100/200 filters,
    kernel 10, stride 3), split first layer, pooling, and dropout 0.5."""
    head_cfg = dict(lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head,
                    bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling)
    return BasicConv1d(filters=[25, 50, 100, 200], kernel_size=10, stride=3, pool=3,
                       pool_stride=1, num_classes=num_classes, input_channels=input_channels,
                       act="relu", bn=True, headless=False, split_first_layer=True,
                       drop_p=0.5, **head_cfg)


def sen(filters=None, num_classes=2, input_channels=8, squeeze_excite_reduction=16, drop_p=0., lin_ftrs_head=None,
        ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True):
    """Squeeze-excitation network: BasicConv1d with an SE block after every
    stage (defaults to five 128-filter stages, stride 2, no pooling)."""
    stage_filters = [128] * 5 if filters is None else filters
    head_cfg = dict(lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head,
                    bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling)
    return BasicConv1d(filters=stage_filters, kernel_size=3, stride=2, pool=0, pool_stride=0,
                       input_channels=input_channels, act="relu", bn=True, num_classes=num_classes,
                       squeeze_excite_reduction=squeeze_excite_reduction, drop_p=drop_p, **head_cfg)


def basic1d(filters=None, kernel_size=3, stride=2, dilation=1, pool=0, pool_stride=1, squeeze_excite_reduction=0,
            num_classes=2, input_channels=8, act="relu", bn=True, headless=False, drop_p=0., lin_ftrs_head=None,
            ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True):
    """Fully parameterized BasicConv1d factory; defaults to five 128-filter stages."""
    if filters is None:
        filters = [128] * 5
    cfg = dict(filters=filters, kernel_size=kernel_size, stride=stride, dilation=dilation,
               pool=pool, pool_stride=pool_stride, squeeze_excite_reduction=squeeze_excite_reduction,
               num_classes=num_classes, input_channels=input_channels, act=act, bn=bn,
               headless=headless, drop_p=drop_p, lin_ftrs_head=lin_ftrs_head, ps_head=ps_head,
               bn_final_head=bn_final_head, bn_head=bn_head, act_head=act_head,
               concat_pooling=concat_pooling)
    return BasicConv1d(**cfg)