Spaces:
Runtime error
Runtime error
Create lstm_model.py
Browse files- lstm_model.py +28 -0
lstm_model.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.utils.rnn as rnn_utils
|
| 4 |
+
|
| 5 |
+
class LSTMClassifier(nn.Module):
    """Binary sequence classifier: embedding -> (bi)LSTM -> linear head.

    Emits one raw logit per sequence in the batch (no sigmoid applied;
    pair with ``BCEWithLogitsLoss``).
    """

    def __init__(self, vocab_size, embed_dim=256, hidden_dim=256,
                 n_layers=2, dropout=0.3, bidirectional=True):
        super().__init__()
        # Number of directions; the head consumes the concatenated
        # forward+backward final hidden state when bidirectional.
        self.bi = 2 if bidirectional else 1
        # Token id 0 is the padding slot and stays a zero vector.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            embed_dim,
            hidden_dim,
            n_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=bidirectional,
        )
        self.fc = nn.Linear(hidden_dim * self.bi, 1)

    def forward(self, x, lengths):
        """Return a ``(batch,)`` tensor of logits.

        x: ``(batch, seq)`` int tensor of token ids, 0 = padding.
        lengths: true per-sequence lengths (any device; moved to CPU,
            as required by ``pack_padded_sequence``).
        """
        embedded = self.embedding(x)
        # Pack so the LSTM never runs over pad positions;
        # enforce_sorted=False accepts batches in arbitrary length order.
        packed = rnn_utils.pack_padded_sequence(
            embedded, lengths.cpu(), batch_first=True, enforce_sorted=False)
        _, (hidden, _) = self.lstm(packed)
        # hidden: (n_layers * directions, batch, hidden_dim). Keep only the
        # top layer — both directions concatenated when bidirectional.
        final = (torch.cat((hidden[-2], hidden[-1]), dim=1)
                 if self.bi == 2
                 else hidden[-1])
        return self.fc(final).squeeze(1)
|