# ECG/models/rnn1d.py
# Uploaded by cuongnx2001 (commit 264b4c4).
import torch
import torch.nn as nn
from fastai.layers import *
from fastai.data.core import *
class AdaptiveConcatPoolRNN(nn.Module):
    """Pool an RNN output sequence into a fixed-size feature vector.

    Given a ``(bs, ch, ts)`` tensor, concatenates along the channel axis:
    the average over time, the maximum over time, and the final RNN state.
    The result has shape ``(bs, 3 * ch)``.

    For a bidirectional RNN, PyTorch's ``nn.LSTM``/``nn.GRU`` place the forward
    features in the first half of the feature axis and the reverse features in
    the second half. The forward direction finishes at the last time step and
    the reverse direction finishes at the first one. The final state is
    therefore taken from ``ts = -1`` for the first half of the channels and
    from ``ts = 0`` for the second half.

    Args:
        bidirectional: whether the incoming features were produced by a
            bidirectional RNN.
    """

    def __init__(self, bidirectional):
        super().__init__()
        self.bidirectional = bidirectional

    def forward(self, x):
        # input shape bs, ch, ts
        avg = nn.functional.adaptive_avg_pool1d(x, 1).squeeze(-1)
        mx = nn.functional.adaptive_max_pool1d(x, 1).squeeze(-1)
        if not self.bidirectional:
            last = x[:, :, -1]
        else:
            # Split between forward (first half) and reverse (second half)
            # features; each direction contributes its own final time step.
            half = x.size(1) // 2
            last = torch.cat([x[:, :half, -1], x[:, half:, 0]], 1)
        return torch.cat([avg, mx, last], 1)  # output shape bs, 3*ch
class RNN1d(nn.Sequential):
    """Recurrent (LSTM or GRU) classifier for 1-D multichannel signals.

    Module pipeline:
    input ``(bs, ch, ts)`` -> ``(ts, bs, ch)`` -> stacked LSTM/GRU ->
    ``(bs, hidden_dim * num_directions, ts)`` -> ``AdaptiveConcatPoolRNN`` ->
    head of BatchNorm/Dropout/Linear/activation blocks -> ``(bs, num_classes)``.

    Args:
        input_channels: number of input channels (the RNN ``input_size``).
        num_classes: number of outputs of the final Linear layer.
        lstm: use ``nn.LSTM`` if True, otherwise ``nn.GRU``.
        hidden_dim: RNN hidden size per direction.
        num_layers: number of stacked RNN layers.
        bidirectional: run the RNN in both directions.
        ps_head: dropout probability, or a list with one probability per head
            Linear layer. A single value ``p`` is used before the final layer,
            and ``p / 2`` before each hidden head layer.
        act_head: ``"relu"`` selects ReLU between head layers; any other
            value selects ELU.
        lin_ftrs_head: optional sizes of hidden head layers.
        bn: insert ``BatchNorm1d`` before each head Linear layer.

    Raises:
        ValueError: if ``ps_head`` provides fewer dropout values than there
            are head Linear layers.
    """

    def __init__(self, input_channels, num_classes, lstm=True, hidden_dim=256, num_layers=2,
                 bidirectional=False, ps_head=0.5, act_head="relu", lin_ftrs_head=None, bn=True):
        # bs, ch, ts -> ts, bs, ch (the RNN below is built without batch_first)
        layers_tmp = [Lambda(lambda x: x.transpose(1, 2)), Lambda(lambda x: x.transpose(0, 1))]
        # recurrent core
        rnn_cls = nn.LSTM if lstm else nn.GRU
        layers_tmp.append(rnn_cls(input_size=input_channels, hidden_size=hidden_dim,
                                  num_layers=num_layers, bidirectional=bidirectional))
        # the RNN returns (output, hidden); keep output: ts, bs, H -> bs, ts, H -> bs, H, ts
        layers_tmp.append(Lambda(lambda x: x[0].transpose(0, 1)))
        layers_tmp.append(Lambda(lambda x: x.transpose(1, 2)))

        # pooling
        layers_head = [AdaptiveConcatPoolRNN(bidirectional)]

        # classifier: pooling yields avg + max + final-state features,
        # each hidden_dim wide per RNN direction
        num_directions = 2 if bidirectional else 1
        nf = 3 * hidden_dim * num_directions
        if lin_ftrs_head is None:
            lin_ftrs_head = [nf, num_classes]
        else:
            lin_ftrs_head = [nf, *lin_ftrs_head, num_classes]
        n_lin = len(lin_ftrs_head) - 1  # number of Linear layers in the head

        ps_head = self._as_list(ps_head)
        if len(ps_head) == 1:
            ps_head = [ps_head[0] / 2] * (n_lin - 1) + ps_head
        if len(ps_head) < n_lin:
            # zip() below would otherwise silently drop trailing layers,
            # including the final projection to num_classes
            raise ValueError(
                f"ps_head needs {n_lin} dropout value(s), got {len(ps_head)}")

        act = nn.ReLU(inplace=True) if act_head == "relu" else nn.ELU(inplace=True)
        actns = [act] * (n_lin - 1) + [None]  # no activation after the output layer
        for ni, no, p, actn in zip(lin_ftrs_head[:-1], lin_ftrs_head[1:], ps_head, actns):
            layers_head += self._bn_drop_lin(ni, no, bn, p, actn)
        layers_tmp.append(nn.Sequential(*layers_head))

        # register all stages; without them the Sequential would be empty
        super().__init__(*layers_tmp)

    @staticmethod
    def _as_list(value):
        """Return *value* as a list: ``None`` -> ``[]``, list/tuple -> list, else ``[value]``."""
        if value is None:
            return []
        if isinstance(value, (list, tuple)):
            return list(value)
        return [value]

    @staticmethod
    def _bn_drop_lin(n_in, n_out, bn=True, p=0.0, actn=None):
        """Return ``[BatchNorm1d?, Dropout?, Linear, actn?]`` for one head layer.

        Uses the same layer order as fastai v1's ``bn_drop_lin``, which is not
        exported by fastai v2, so head parameter indices in saved state_dicts
        are unchanged.
        """
        layers = [nn.BatchNorm1d(n_in)] if bn else []
        if p != 0:
            layers.append(nn.Dropout(p))
        layers.append(nn.Linear(n_in, n_out))
        if actn is not None:
            layers.append(actn)
        return layers

    def get_layer_groups(self):
        """Return a 1-tuple containing the classifier head (the last module)."""
        return (self[-1],)

    def get_output_layer(self):
        """Return the last module of the classifier head."""
        return self[-1][-1]

    def set_output_layer(self, x):
        """Replace the last module of the classifier head with *x*."""
        self[-1][-1] = x