import torch
import torch.nn as nn


class LSTMClassifier(nn.Module):
    """Embedding + (bi)LSTM encoder with a single-output linear head.

    The final hidden state of the top LSTM layer is projected to one score per
    sequence. When ``bidirectional`` is set, the forward and backward states
    are concatenated first. No sigmoid is applied, so the output is a raw
    logit. NOTE(review): presumably paired with ``nn.BCEWithLogitsLoss``;
    confirm against the training code.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Embedding width.
        hidden_dim: LSTM hidden size per direction.
        n_layers: Number of stacked LSTM layers.
        dropout: Dropout applied between stacked LSTM layers. ``nn.LSTM``
            only uses it between layers, so it is set to 0 when
            ``n_layers == 1``.
        bidirectional: Run the LSTM in both directions.
        padding_idx: Token id used for padding. Its embedding row starts at
            zero and receives no gradient. Defaults to 0.
    """

    def __init__(self, vocab_size, embed_dim=256, hidden_dim=256,
                 n_layers=2, dropout=0.3, bidirectional=True, padding_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim,
                                      padding_idx=padding_idx)
        # nn.LSTM applies dropout only between stacked layers. It warns when a
        # non-zero value is given with a single layer, where it is a no-op.
        self.lstm = nn.LSTM(embed_dim,
                            hidden_dim,
                            n_layers,
                            batch_first=True,
                            dropout=dropout if n_layers > 1 else 0.0,
                            bidirectional=bidirectional)
        # Number of directions. The head sees hidden_dim features per direction.
        self.bi = 2 if bidirectional else 1
        self.fc = nn.Linear(hidden_dim * self.bi, 1)

    def forward(self, x, lengths):
        """Return one unnormalized score per sequence.

        Args:
            x: Padded token ids of shape ``(batch, seq_len)``; the LSTM is
                ``batch_first``.
            lengths: True length of each sequence. Either a tensor on any
                device or a list of ints. Each length must be >= 1, because
                ``pack_padded_sequence`` rejects empty sequences.

        Returns:
            Tensor of shape ``(batch,)``.
        """
        x = self.embedding(x)
        # pack_padded_sequence requires CPU lengths. as_tensor also lets
        # callers pass a plain list.
        lengths = torch.as_tensor(lengths, dtype=torch.int64).cpu()
        # enforce_sorted=False accepts the batch in any order. The LSTM
        # returns the hidden state in the original batch order.
        packed = nn.utils.rnn.pack_padded_sequence(
            x, lengths, batch_first=True, enforce_sorted=False)
        # h has shape (n_layers * num_directions, batch, hidden_dim). Padding
        # steps are skipped, so each entry is the state at the last real token.
        _, (h, _) = self.lstm(packed)
        if self.bi == 2:
            # Top layer: forward-direction final state (h[-2]) and
            # backward-direction final state (h[-1]).
            h = torch.cat((h[-2], h[-1]), dim=1)
        else:
            h = h[-1]
        return self.fc(h).squeeze(1)