puneeth1 commited on
Commit
d9f0436
·
verified ·
1 Parent(s): a0279ff

Upload modeling_bilstm.py

Browse files
Files changed (1) hide show
  1. model_folder/modeling_bilstm.py +24 -0
model_folder/modeling_bilstm.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+
4
+ class BiLSTM(nn.Module):
5
+ def __init__(self, embedding_matrix, embed_size, num_labels, hidden_size=64, dropout=0.3):
6
+ super(BiLSTM, self).__init__()
7
+ self.embedding = nn.Embedding.from_pretrained(
8
+ torch.tensor(embedding_matrix, dtype=torch.float32), freeze=True)
9
+ self.lstm = nn.LSTM(embed_size, hidden_size, bidirectional=True, batch_first=True)
10
+ self.fc = nn.Sequential(
11
+ nn.Linear(hidden_size * 4, 128),
12
+ nn.ReLU(),
13
+ nn.Dropout(dropout),
14
+ nn.Linear(128, num_labels)
15
+ )
16
+
17
+ def forward(self, x):
18
+ x = self.embedding(x)
19
+ h_lstm, _ = self.lstm(x)
20
+ avg_pool = torch.mean(h_lstm, 1)
21
+ max_pool, _ = torch.max(h_lstm, 1)
22
+ x = torch.cat((avg_pool, max_pool), dim=1)
23
+ x = self.fc(x)
24
+ return x