File size: 1,196 Bytes
d9f0436
 
 
 
8ea164e
d9f0436
8ea164e
d9f0436
 
8ea164e
 
 
 
 
d9f0436
 
8ea164e
 
 
 
 
 
 
d9f0436
 
8ea164e
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch.nn as nn
import torch

class BiLSTM(nn.Module):
    """Bidirectional-LSTM classifier with mean + max pooling over time.

    Pipeline: pretrained embedding lookup -> single-layer bidirectional LSTM
    (batch-first) -> concatenation of the time-mean and time-max of the LSTM
    outputs -> Linear + ReLU -> Dropout -> Linear producing ``num_labels``
    logits.

    Args:
        embedding_matrix: Pretrained embedding weights, 2-D
            ``(vocab_size, embed_size)``; array-like or ``torch.Tensor``.
            The weights are always copied, never aliased.
        embed_size: Embedding width, used as the LSTM input size. Must match
            the second dimension of ``embedding_matrix``.
        num_labels: Number of output logits.
        hidden_size: LSTM hidden size per direction.
        dropout: Dropout probability applied before the output layer.
        fc_size: Width of the intermediate fully connected layer.
        freeze: If True (default), the embedding weights are not trained.
    """

    def __init__(self, embedding_matrix, embed_size, num_labels, hidden_size=64,
                 dropout=0.1, fc_size=64, freeze=True):
        super().__init__()
        self.hidden_size = hidden_size
        # torch.tensor() on an existing tensor triggers a copy-construct
        # warning; copy tensors explicitly instead.
        if isinstance(embedding_matrix, torch.Tensor):
            weights = embedding_matrix.detach().clone().to(torch.float32)
        else:
            weights = torch.tensor(embedding_matrix, dtype=torch.float32)
        self.embedding = nn.Embedding.from_pretrained(weights, freeze=freeze)
        self.lstm = nn.LSTM(embed_size, self.hidden_size, bidirectional=True, batch_first=True)
        # 2 directions x 2 pooling ops (mean, max) -> 4 * hidden_size features.
        self.linear = nn.Linear(self.hidden_size * 4, fc_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(fc_size, num_labels)

    def forward(self, x, lengths=None):
        """Compute class logits for a batch of token-index sequences.

        Args:
            x: Integer token indices, shape ``(batch, seq_len)``.
            lengths: Optional valid length of each sequence (1-D, ``batch``
                entries, each in ``1..seq_len``). When given, padded
                positions are skipped by the LSTM and excluded from pooling.
                When None, every position is treated as real input.

        Returns:
            Logits of shape ``(batch, num_labels)``.
        """
        h_embedding = self.embedding(x)

        if lengths is None:
            h_lstm, _ = self.lstm(h_embedding)
            avg_pool = torch.mean(h_lstm, 1)
            max_pool, _ = torch.max(h_lstm, 1)
        else:
            avg_pool, max_pool = self._packed_pool(h_embedding, lengths)

        conc = torch.cat((avg_pool, max_pool), 1)
        conc = self.relu(self.linear(conc))
        conc = self.dropout(conc)
        return self.out(conc)

    def _packed_pool(self, h_embedding, lengths):
        """Run the LSTM on packed sequences and mean/max-pool valid steps only."""
        seq_len = h_embedding.size(1)
        # pack_padded_sequence requires int64 lengths on the CPU.
        lengths = torch.as_tensor(lengths, dtype=torch.int64).cpu()
        packed = nn.utils.rnn.pack_padded_sequence(
            h_embedding, lengths, batch_first=True, enforce_sorted=False)
        packed_out, _ = self.lstm(packed)
        # Padded positions come back as 0, so a plain sum covers only valid steps.
        h_lstm, _ = nn.utils.rnn.pad_packed_sequence(
            packed_out, batch_first=True, total_length=seq_len)

        lengths = lengths.to(h_lstm.device)
        valid = (torch.arange(seq_len, device=h_lstm.device).unsqueeze(0)
                 < lengths.unsqueeze(1))  # (batch, seq_len)
        avg_pool = h_lstm.sum(1) / lengths.unsqueeze(1).to(h_lstm.dtype)
        # -inf at padded steps so they can never win the max.
        max_pool, _ = h_lstm.masked_fill(~valid.unsqueeze(-1), float("-inf")).max(1)
        return avg_pool, max_pool