Spaces:
Sleeping
Sleeping
Update model_folder/modeling_bilstm.py
Browse files- model_folder/modeling_bilstm.py +24 -13
model_folder/modeling_bilstm.py
CHANGED
|
@@ -2,23 +2,34 @@ import torch.nn as nn
|
|
| 2 |
import torch
|
| 3 |
|
| 4 |
class BiLSTM(nn.Module):
|
| 5 |
-
def __init__(self, embedding_matrix, embed_size, num_labels, hidden_size=64, dropout=0.
|
| 6 |
super(BiLSTM, self).__init__()
|
|
|
|
| 7 |
self.embedding = nn.Embedding.from_pretrained(
|
| 8 |
torch.tensor(embedding_matrix, dtype=torch.float32), freeze=True)
|
| 9 |
-
self.lstm = nn.LSTM(embed_size, hidden_size, bidirectional=True, batch_first=True)
|
| 10 |
-
self.
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
nn.Linear(128, num_labels)
|
| 15 |
-
)
|
| 16 |
|
| 17 |
def forward(self, x):
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
avg_pool = torch.mean(h_lstm, 1)
|
| 21 |
max_pool, _ = torch.max(h_lstm, 1)
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import torch
|
| 3 |
|
| 4 |
class BiLSTM(nn.Module):
|
| 5 |
+
def __init__(self, embedding_matrix, embed_size, num_labels, hidden_size=64, dropout=0.1):
|
| 6 |
super(BiLSTM, self).__init__()
|
| 7 |
+
self.hidden_size = hidden_size
|
| 8 |
self.embedding = nn.Embedding.from_pretrained(
|
| 9 |
torch.tensor(embedding_matrix, dtype=torch.float32), freeze=True)
|
| 10 |
+
self.lstm = nn.LSTM(embed_size, self.hidden_size, bidirectional=True, batch_first=True)
|
| 11 |
+
self.linear = nn.Linear(self.hidden_size * 4, 64)
|
| 12 |
+
self.relu = nn.ReLU()
|
| 13 |
+
self.dropout = nn.Dropout(dropout)
|
| 14 |
+
self.out = nn.Linear(64, num_labels)
|
|
|
|
|
|
|
| 15 |
|
| 16 |
def forward(self, x):
|
| 17 |
+
# Embedding layer
|
| 18 |
+
h_embedding = self.embedding(x)
|
| 19 |
+
|
| 20 |
+
# LSTM layer
|
| 21 |
+
h_lstm, _ = self.lstm(h_embedding)
|
| 22 |
+
|
| 23 |
+
# Pooling layers
|
| 24 |
avg_pool = torch.mean(h_lstm, 1)
|
| 25 |
max_pool, _ = torch.max(h_lstm, 1)
|
| 26 |
+
|
| 27 |
+
# Concatenate pooling outputs
|
| 28 |
+
conc = torch.cat((avg_pool, max_pool), 1)
|
| 29 |
+
|
| 30 |
+
# Fully connected layers
|
| 31 |
+
conc = self.relu(self.linear(conc))
|
| 32 |
+
conc = self.dropout(conc)
|
| 33 |
+
out = self.out(conc)
|
| 34 |
+
|
| 35 |
+
return out
|