File size: 2,586 Bytes
264b4c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import torch.nn as nn
from fastai.layers import *
from fastai.data.core import *


class AdaptiveConcatPoolRNN(nn.Module):
    """Pool a recurrent output sequence into a fixed-size feature vector.

    The result concatenates, along the channel axis:
    the mean over time, the maximum over time, and the "last" state of the
    sequence. For a bidirectional RNN, the forward direction's last state is
    at the final time step, while the backward direction's last state is at
    the first time step.
    """

    def __init__(self, bidirectional):
        super().__init__()
        self.bidirectional = bidirectional

    def forward(self, x):
        """Pool ``x`` of shape (bs, ch, ts) into features of shape (bs, 3 * ch).

        When ``bidirectional`` is set, the first ``ch // 2`` channels are
        treated as the forward direction and the remaining channels as the
        backward direction. This is the order in which PyTorch's bidirectional
        ``nn.LSTM``/``nn.GRU`` concatenate their outputs.
        """
        avg = nn.functional.adaptive_avg_pool1d(x, 1).squeeze(-1)
        mx = nn.functional.adaptive_max_pool1d(x, 1).squeeze(-1)

        if not self.bidirectional:
            last = x[:, :, -1]
        else:
            # Split the channels in half: the forward half has read the whole
            # sequence at the last step, the backward half at the first step.
            half = x.size(1) // 2
            last = torch.cat([x[:, :half, -1], x[:, half:, 0]], 1)
        return torch.cat([avg, mx, last], 1)  # output shape bs, 3*ch


class RNN1d(nn.Sequential):
    """LSTM/GRU classifier for inputs of shape (bs, ch, ts).

    Layout: two transposes to sequence-first, the recurrent layer, two
    transposes back to (bs, ch, ts), then a head Sequential made of
    ``AdaptiveConcatPoolRNN`` followed by BatchNorm/Dropout/Linear(/activation)
    blocks. The head is the last element of this Sequential.
    """

    def __init__(self, input_channels, num_classes, lstm=True, hidden_dim=256, num_layers=2, bidirectional=False,
                 ps_head=0.5, act_head="relu", lin_ftrs_head=None, bn=True):
        """Assemble the recurrent body and the pooling/classifier head.

        Args:
            input_channels: channel count of the input (``input_size`` of the RNN).
            num_classes: output size of the final linear layer.
            lstm: use ``nn.LSTM`` when true, otherwise ``nn.GRU``.
            hidden_dim: hidden size of the recurrent layers.
            num_layers: number of stacked recurrent layers.
            bidirectional: run the recurrent layers in both directions.
            ps_head: dropout probability. A single value is applied before
                the final linear layer and halved for the hidden head layers.
                A list gives one value per linear layer.
            act_head: ``"relu"`` selects ReLU between hidden head layers;
                any other value selects ELU.
            lin_ftrs_head: optional list of hidden linear layer sizes in the head.
            bn: insert ``BatchNorm1d`` before each head linear layer.

        Raises:
            ValueError: if ``ps_head`` gives fewer dropout values than there
                are linear layers in the head.
        """
        # bs, ch, ts -> ts, bs, ch (the RNN is built without batch_first)
        layers = [Lambda(lambda x: x.transpose(1, 2)), Lambda(lambda x: x.transpose(0, 1))]
        rnn_cls = nn.LSTM if lstm else nn.GRU
        layers.append(rnn_cls(input_size=input_channels, hidden_size=hidden_dim, num_layers=num_layers,
                              bidirectional=bidirectional))
        # Keep only the per-step outputs (element 0 of the RNN's return value)
        # and go back to bs, ch, ts for pooling.
        layers.append(Lambda(lambda x: x[0].transpose(0, 1)))
        layers.append(Lambda(lambda x: x.transpose(1, 2)))

        layers_head = [AdaptiveConcatPoolRNN(bidirectional)]

        # classifier: pooling yields avg + max + last features, i.e. 3x the
        # RNN output width, which is doubled when bidirectional.
        nf = 3 * hidden_dim if not bidirectional else 6 * hidden_dim
        lin_ftrs_head = [nf, num_classes] if lin_ftrs_head is None else [nf] + lin_ftrs_head + [num_classes]
        n_lin = len(lin_ftrs_head) - 1
        ps_head = listify(ps_head)
        if len(ps_head) == 1:
            ps_head = [ps_head[0] / 2] * (n_lin - 1) + ps_head
        if len(ps_head) < n_lin:
            raise ValueError(f"ps_head needs {n_lin} dropout values, got {len(ps_head)}")
        act_cls = nn.ReLU if act_head == "relu" else nn.ELU
        # One activation instance per hidden layer; the final layer has none.
        actns = [act_cls(inplace=True) for _ in range(n_lin - 1)] + [None]

        # Build BatchNorm -> Dropout -> Linear -> activation directly from
        # torch.nn. fastai v2's fastai.layers (imported above) no longer
        # exports the v1 helper bn_drop_lin; LinBnDrop replaced it.
        for ni, no, p, actn in zip(lin_ftrs_head[:-1], lin_ftrs_head[1:], ps_head, actns):
            if bn:
                layers_head.append(nn.BatchNorm1d(ni))
            if p != 0:
                layers_head.append(nn.Dropout(p))
            layers_head.append(nn.Linear(ni, no))
            if actn is not None:
                layers_head.append(actn)
        layers.append(nn.Sequential(*layers_head))

        # The modules must be handed to nn.Sequential to be registered;
        # otherwise the model would be empty.
        super().__init__(*layers)

    def get_layer_groups(self):
        """Return a 1-tuple holding the head Sequential (the last module)."""
        return self[-1],

    def get_output_layer(self):
        """Return the last module of the head (the final linear layer)."""
        return self[-1][-1]

    def set_output_layer(self, x):
        """Replace the last module of the head with ``x``."""
        self[-1][-1] = x