import torch
import torch.nn as nn
class BiLSTMClassifier(nn.Module):
    """Bidirectional LSTM sequence classifier.

    Embeds integer token ids, runs a bidirectional LSTM over the sequence,
    and projects the concatenation of the last layer's final forward and
    backward hidden states to classification logits.

    Args:
        vocab_size: Size of the embedding vocabulary.
        embedding_dim: Dimensionality of the token embeddings.
        hidden_size: Hidden size of each LSTM direction.
        num_layers: Number of stacked LSTM layers. Defaults to 1.
        dropout: Inter-layer LSTM dropout probability. Only meaningful when
            ``num_layers > 1``; zeroed for a single layer because PyTorch
            warns that dropout has no effect there.
        num_classes: Number of output logits per example. Defaults to 1
            (a single logit, e.g. for binary classification with BCE).
        **kwargs: Ignored; accepted so surplus config-dict keys don't raise.
    """

    def __init__(
        self,
        vocab_size,
        embedding_dim,
        hidden_size,
        num_layers=1,
        dropout=0.2,
        num_classes=1,
        **kwargs,
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            # Inter-layer dropout is undefined for a single layer; PyTorch
            # emits a warning if it is nonzero, so suppress it explicitly.
            dropout=dropout if num_layers > 1 else 0.0,
        )
        # Forward and backward final states are concatenated -> 2 * hidden.
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        """Compute classification logits for a batch of token sequences.

        Args:
            x: LongTensor of token ids, shape ``(batch, seq_len)``.

        Returns:
            Tensor of raw logits, shape ``(batch, num_classes)``.
        """
        embedded = self.embedding(x)  # (batch, seq_len, embedding_dim)
        # h_n has shape (num_layers * 2, batch, hidden_size); the last two
        # rows are the final layer's forward (-2) and backward (-1) states.
        _, (h_n, _) = self.lstm(embedded)
        h_final = torch.cat((h_n[-2, :, :], h_n[-1, :, :]), dim=1)
        return self.fc(h_final)