import torch
import torch.nn as nn


class BiLSTM(nn.Module):
    """Bidirectional LSTM classifier over pretrained token embeddings.

    Token ids are looked up in a pretrained embedding table and run through a
    single-layer bidirectional LSTM. The LSTM outputs are then summarised by
    mean- and max-pooling over the sequence dimension. The two pooled vectors
    are concatenated and passed through a ReLU + dropout hidden layer, followed
    by a linear output layer producing one logit per label.

    Args:
        embedding_matrix: Pretrained embedding table with shape
            ``(vocab_size, embed_size)``. It may be a nested sequence, a NumPy
            array or a ``torch.Tensor``. It is always copied into a new
            float32 tensor, so later changes to the source do not leak into
            the model.
        embed_size: Width of each embedding vector. Must equal the second
            dimension of ``embedding_matrix``.
        num_labels: Number of output logits.
        hidden_size: LSTM hidden size per direction.
        dropout: Dropout probability applied after the hidden FC layer.
        fc_size: Width of the hidden fully connected layer (default 64).
        freeze_embeddings: If True (default), embedding weights are not
            trained.

    Raises:
        ValueError: If ``embed_size`` does not match the embedding width.
    """

    def __init__(self, embedding_matrix, embed_size, num_labels,
                 hidden_size=64, dropout=0.1, fc_size=64,
                 freeze_embeddings=True):
        super().__init__()
        self.hidden_size = hidden_size

        # torch.tensor() on an existing tensor emits a copy-construct warning;
        # clone it explicitly instead, keeping the same "always copy" semantics.
        if isinstance(embedding_matrix, torch.Tensor):
            weight = embedding_matrix.detach().clone().to(torch.float32)
        else:
            weight = torch.tensor(embedding_matrix, dtype=torch.float32)
        self.embedding = nn.Embedding.from_pretrained(
            weight, freeze=freeze_embeddings)

        # Fail fast: a mismatch would otherwise only surface as an LSTM
        # input-size error on the first forward pass.
        if self.embedding.embedding_dim != embed_size:
            raise ValueError(
                f"embed_size={embed_size} does not match embedding_matrix "
                f"width {self.embedding.embedding_dim}")

        self.lstm = nn.LSTM(embed_size, self.hidden_size,
                            bidirectional=True, batch_first=True)
        # Each timestep yields 2 * hidden_size features (two directions);
        # mean-pool and max-pool are concatenated -> 4 * hidden_size.
        self.linear = nn.Linear(self.hidden_size * 4, fc_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(fc_size, num_labels)

    def forward(self, x):
        """Compute label logits for a batch of token-id sequences.

        Args:
            x: Integer token ids laid out batch-first (the LSTM uses
                ``batch_first=True``), i.e. ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, num_labels)`` with unnormalised logits.

        NOTE(review): pooling runs over every position, so padding tokens (if
        any) contribute to both the mean and the max. Confirm this is intended
        or add masking.
        """
        # Embedding layer: (batch, seq_len) -> (batch, seq_len, embed_size)
        h_embedding = self.embedding(x)

        # LSTM layer: (batch, seq_len, 2 * hidden_size)
        h_lstm, _ = self.lstm(h_embedding)

        # Pool over the sequence dimension (dim 1).
        avg_pool = torch.mean(h_lstm, 1)
        max_pool, _ = torch.max(h_lstm, 1)

        # (batch, 4 * hidden_size)
        conc = torch.cat((avg_pool, max_pool), 1)

        # Fully connected head
        conc = self.relu(self.linear(conc))
        conc = self.dropout(conc)
        return self.out(conc)