Spaces:
Sleeping
Sleeping
Upload modeling_bilstm.py
Browse files
model_folder/modeling_bilstm.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
class BiLSTM(nn.Module):
    """Bidirectional LSTM text classifier over frozen pretrained embeddings.

    Pipeline: token ids -> embedding lookup -> BiLSTM -> concatenation of
    mean-pooling and max-pooling over the time axis -> two-layer MLP head
    producing ``num_labels`` logits.
    """

    def __init__(self, embedding_matrix, embed_size, num_labels, hidden_size=64, dropout=0.3):
        super().__init__()
        # Frozen lookup table built from the caller-supplied pretrained vectors.
        weights = torch.tensor(embedding_matrix, dtype=torch.float32)
        self.embedding = nn.Embedding.from_pretrained(weights, freeze=True)
        self.lstm = nn.LSTM(embed_size, hidden_size, bidirectional=True, batch_first=True)
        # BiLSTM emits 2*hidden_size features per step; concatenating the
        # mean-pool and max-pool doubles that again, hence 4*hidden_size.
        self.fc = nn.Sequential(
            nn.Linear(hidden_size * 4, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_labels),
        )

    def forward(self, x):
        """Map a (batch, seq_len) tensor of token ids to (batch, num_labels) logits."""
        embedded = self.embedding(x)
        lstm_out, _ = self.lstm(embedded)
        mean_pooled = lstm_out.mean(dim=1)
        max_pooled = lstm_out.max(dim=1).values
        pooled = torch.cat((mean_pooled, max_pooled), dim=1)
        return self.fc(pooled)
|