ecroatt commited on
Commit
902cf61
·
verified ·
1 Parent(s): d65b5f6

Create lstm_model.py

Browse files
Files changed (1) hide show
  1. lstm_model.py +28 -0
lstm_model.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.utils.rnn as rnn_utils
4
+
5
+ class LSTMClassifier(nn.Module):
6
+ def __init__(self, vocab_size, embed_dim=256, hidden_dim=256,
7
+ n_layers=2, dropout=0.3, bidirectional=True):
8
+ super().__init__()
9
+ self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
10
+ self.lstm = nn.LSTM(embed_dim,
11
+ hidden_dim,
12
+ n_layers,
13
+ batch_first=True,
14
+ dropout=dropout,
15
+ bidirectional=bidirectional)
16
+ self.bi = 2 if bidirectional else 1
17
+ self.fc = nn.Linear(hidden_dim * self.bi, 1)
18
+
19
+ def forward(self, x, lengths):
20
+ x = self.embedding(x)
21
+ packed = rnn_utils.pack_padded_sequence(
22
+ x, lengths.cpu(), batch_first=True, enforce_sorted=False)
23
+ _, (h, _) = self.lstm(packed)
24
+ if self.bi == 2:
25
+ h = torch.cat((h[-2], h[-1]), dim=1) # concat fwd+rev
26
+ else:
27
+ h = h[-1]
28
+ return self.fc(h).squeeze(1)