puneeth1 committed on
Commit
8ea164e
·
verified ·
1 Parent(s): 939e633

Update model_folder/modeling_bilstm.py

Browse files
Files changed (1) hide show
  1. model_folder/modeling_bilstm.py +24 -13
model_folder/modeling_bilstm.py CHANGED
@@ -2,23 +2,34 @@ import torch.nn as nn
2
  import torch
3
 
4
class BiLSTM(nn.Module):
    """Bidirectional LSTM text classifier over frozen pretrained embeddings.

    Pools the LSTM sequence output with both mean- and max-pooling over the
    time axis, concatenates the two pooled vectors, and maps them through a
    small MLP head to per-class logits.
    """

    def __init__(self, embedding_matrix, embed_size, num_labels, hidden_size=64, dropout=0.3):
        super(BiLSTM, self).__init__()
        # Lookup table built from the supplied pretrained matrix; frozen so
        # the embeddings are not updated during training.
        self.embedding = nn.Embedding.from_pretrained(
            torch.tensor(embedding_matrix, dtype=torch.float32), freeze=True)
        self.lstm = nn.LSTM(embed_size, hidden_size, bidirectional=True, batch_first=True)
        # Head input width is hidden_size * 4: the bidirectional output is
        # 2 * hidden_size wide, and mean- and max-pool are concatenated.
        self.fc = nn.Sequential(
            nn.Linear(hidden_size * 4, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_labels),
        )

    def forward(self, x):
        """Map a (batch, seq_len) LongTensor of token ids to (batch, num_labels) logits."""
        embedded = self.embedding(x)
        seq_out, _ = self.lstm(embedded)
        mean_pool = seq_out.mean(dim=1)
        max_pool = seq_out.max(dim=1).values
        return self.fc(torch.cat((mean_pool, max_pool), dim=1))
 
 
 
 
 
 
 
 
2
  import torch
3
 
4
class BiLSTM(nn.Module):
    """Bidirectional LSTM text classifier over frozen pretrained embeddings.

    The sequence output of the LSTM is summarized by mean- and max-pooling
    over time; the concatenated summary passes through a 64-unit ReLU layer
    with dropout before the final projection to per-class logits.
    """

    def __init__(self, embedding_matrix, embed_size, num_labels, hidden_size=64, dropout=0.1):
        super(BiLSTM, self).__init__()
        self.hidden_size = hidden_size
        # Frozen lookup table built from the supplied pretrained matrix.
        self.embedding = nn.Embedding.from_pretrained(
            torch.tensor(embedding_matrix, dtype=torch.float32), freeze=True)
        self.lstm = nn.LSTM(embed_size, self.hidden_size, bidirectional=True, batch_first=True)
        # hidden_size * 4 = bidirectional width (2x) times two pooling branches.
        self.linear = nn.Linear(self.hidden_size * 4, 64)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(64, num_labels)

    def forward(self, x):
        """Map a (batch, seq_len) LongTensor of token ids to (batch, num_labels) logits."""
        # Embed, then run the bidirectional LSTM over the full sequence.
        seq_out, _ = self.lstm(self.embedding(x))

        # Summarize the time axis two ways and stack the summaries side by side.
        pooled = torch.cat((seq_out.mean(dim=1), seq_out.max(dim=1).values), dim=1)

        # Fully connected head: linear -> ReLU -> dropout -> output projection.
        hidden = self.dropout(self.relu(self.linear(pooled)))
        return self.out(hidden)