import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence


class BiLSTMClassifier(nn.Module):
    """Bidirectional LSTM for binary sequence classification on padded batches."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers=2, dropout=0.3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout,  # applied between stacked LSTM layers (needs num_layers > 1)
        )
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(hidden_dim * 2, hidden_dim)  # * 2 for the two directions
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, 1)  # single logit for binary classification

    def forward(self, x, lengths):
        # x: (batch, seq_len) padded token ids; lengths: true lengths per sequence
        embedded = self.embedding(x)
        # Pack so the LSTM skips padded positions; lengths must live on the CPU
        packed_embedded = pack_padded_sequence(
            embedded, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_output, (hidden, cell) = self.lstm(packed_embedded)
        # hidden: (num_layers * 2, batch, hidden_dim) -> separate layers and directions
        hidden = hidden.view(self.lstm.num_layers, 2, -1, self.lstm.hidden_size)
        last_layer_hidden = hidden[-1]  # (2, batch, hidden_dim) for the top layer
        # Concat forward + backward final hidden states -> (batch, hidden_dim * 2)
        cat_hidden = torch.cat((last_layer_hidden[0], last_layer_hidden[1]), dim=1)
        # Classification head
        x = self.dropout(cat_hidden)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        return self.fc2(x)  # raw logits; pair with nn.BCEWithLogitsLoss
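

# Usage sketch (illustrative, not part of the original module): a minimal smoke
# test assuming zero-padded integer token batches and a binary target trained
# with nn.BCEWithLogitsLoss. The vocabulary size, dimensions, and batch below
# are made-up placeholder values.
if __name__ == "__main__":
    model = BiLSTMClassifier(vocab_size=10_000, embedding_dim=128, hidden_dim=256)
    model.eval()  # disable dropout so the smoke test is deterministic

    # Four sequences padded to length 20 with pad index 0
    lengths = torch.tensor([20, 17, 12, 5])
    batch = torch.randint(1, 10_000, (4, 20))
    for i, n in enumerate(lengths.tolist()):
        batch[i, n:] = 0  # zero out positions past each true length

    with torch.no_grad():
        logits = model(batch, lengths)        # (4, 1)
    probs = torch.sigmoid(logits).squeeze(1)  # per-example P(positive class)
    print(logits.shape, probs)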