import torch
import torch.nn as nn


class BiLSTMClassifier(nn.Module):
    """Bidirectional LSTM sequence classifier producing one score per sequence.

    Token ids are embedded, run through a (possibly stacked) bidirectional
    LSTM, and the final hidden states of the top layer's forward and backward
    directions are concatenated and projected to a single output unit.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Size of each token embedding (the LSTM input size).
        hidden_size: Hidden size of each LSTM direction; the classifier head
            therefore sees ``2 * hidden_size`` features.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout between stacked LSTM layers. Only used when
            ``num_layers > 1``; with a single layer it has no effect.
        **kwargs: Accepted and ignored.
            NOTE(review): this silently swallows misspelled arguments.
    """

    def __init__(
        self,
        vocab_size,
        embedding_dim,
        hidden_size,
        num_layers=1,
        dropout=0.2,
        **kwargs,
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            # nn.LSTM only applies dropout between layers; passing a non-zero
            # value with a single layer triggers a PyTorch warning.
            dropout=dropout if num_layers > 1 else 0.0,
        )
        self.fc = nn.Linear(hidden_size * 2, 1)

    def forward(self, x, lengths=None):
        """Score a batch of token-id sequences.

        Args:
            x: Integer token ids laid out as (batch, seq_len), matching the
                LSTM's ``batch_first=True``.
            lengths: Optional 1-D sequence (list or tensor) giving the true,
                unpadded length of each row of ``x``. When provided, the batch
                is packed so that padding positions do not influence the final
                hidden states. Every length must be in ``1..seq_len``. When
                ``None`` (default), every position of every row is fed to the
                LSTM.

        Returns:
            Tensor of shape (batch, 1) holding unnormalised scores; no
            sigmoid/softmax is applied here.
        """
        embedded = self.embedding(x)
        if lengths is not None:
            # pack_padded_sequence requires lengths as a CPU int64 tensor.
            lengths = torch.as_tensor(lengths, dtype=torch.int64).cpu()
            # enforce_sorted=False lets callers pass rows in any order; the
            # LSTM restores the original batch order in h_n.
            embedded = nn.utils.rnn.pack_padded_sequence(
                embedded, lengths, batch_first=True, enforce_sorted=False
            )
        _, (h_n, _) = self.lstm(embedded)
        # h_n is (num_layers * num_directions, batch, hidden_size); the last
        # two entries are the top layer's forward and backward final states.
        h_fwd = h_n[-2]
        h_bwd = h_n[-1]
        h_final = torch.cat((h_fwd, h_bwd), dim=1)
        return self.fc(h_final)