File size: 912 Bytes
1481eeb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
import torch
import torch.nn as nn
class BiLSTMClassifier(nn.Module):
    """Binary sequence classifier: embedding -> bidirectional LSTM -> linear head.

    Emits one raw logit per input sequence; pair with a sigmoid or
    ``BCEWithLogitsLoss`` downstream.
    """

    def __init__(
        self,
        vocab_size,
        embedding_dim,
        hidden_size,
        num_layers=1,
        dropout=0.2,
        **kwargs,
    ):
        super().__init__()
        # Token-id -> dense vector lookup table.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # Inter-layer dropout only applies between stacked LSTM layers;
        # PyTorch warns if dropout > 0 with a single layer, so zero it out.
        inter_layer_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=inter_layer_dropout,
        )
        # Forward and backward final states are concatenated -> 2 * hidden_size.
        self.fc = nn.Linear(hidden_size * 2, 1)

    def forward(self, x):
        """Map a batch of token ids (batch, seq_len) to logits (batch, 1)."""
        embedded = self.embedding(x)
        _, (hidden, _) = self.lstm(embedded)
        # hidden has shape (num_layers * 2, batch, hidden_size); its last two
        # rows are the top layer's forward and backward final states.
        last_layer = torch.cat((hidden[-2], hidden[-1]), dim=1)
        return self.fc(last_layer)
|