| import torch |
| import torch.nn as nn |
|
|
class RNNClassifier(nn.Module):
    """Binary sequence classifier: a recurrent encoder plus a small MLP head.

    Variable-length sequences are packed, encoded with a GRU (default) or
    LSTM, and the final hidden state of the top layer is fed through a
    two-layer MLP to produce a per-sequence probability.

    Args:
        hidden_dim: Size of the recurrent hidden state and the MLP width.
        rnn_type: 'LSTM' (case-insensitive) selects an LSTM; any other
            value falls back to a GRU (matching the original permissive
            behavior — no error is raised for unknown types).
        input_size: Number of features per timestep (default 1, as before).
        num_layers: Number of stacked recurrent layers (default 1, as before).
    """

    def __init__(self, hidden_dim=256, rnn_type='GRU', input_size=1, num_layers=1):
        super().__init__()
        self.rnn_type = rnn_type.upper()
        rnn_cls = nn.LSTM if self.rnn_type == 'LSTM' else nn.GRU
        self.rnn = rnn_cls(
            input_size=input_size,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x, lengths):
        """Classify a batch of padded sequences.

        Args:
            x: Padded input of shape (batch, max_len, input_size).
            lengths: 1-D tensor of true sequence lengths; may live on any
                device (moved to CPU, as required by pack_padded_sequence).

        Returns:
            Tensor of shape (batch,) with sigmoid probabilities in (0, 1).
        """
        packed = nn.utils.rnn.pack_padded_sequence(
            x, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        # LSTM returns (h_n, c_n); GRU returns h_n only. The packed output
        # sequence and the cell state are not needed for classification.
        if self.rnn_type == 'LSTM':
            _, (hn, _) = self.rnn(packed)
        else:
            _, hn = self.rnn(packed)
        # hn has shape (num_layers, batch, hidden_dim); take the top layer's
        # final state as the sequence representation.
        logits = self.classifier(hn[-1])
        return torch.sigmoid(logits).squeeze(-1)