ele-sage's picture
Upload folder using huggingface_hub
34f7798 verified
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence
class BiLSTMClassifier(nn.Module):
    """Bidirectional LSTM encoder followed by a two-layer MLP scoring head.

    Token ids are embedded, run through a stacked bidirectional LSTM (padding
    is skipped via sequence packing), and the final forward/backward hidden
    states of the top LSTM layer are concatenated and mapped to one output
    value per sequence.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers=2, dropout=0.3):
        """Build the embedding table, the LSTM encoder and the classification head.

        Args:
            vocab_size: Number of rows in the embedding table. Token id 0 is
                treated as padding (``padding_idx=0``): its embedding is not
                updated by gradients.
            embedding_dim: Size of each token embedding.
            hidden_dim: Hidden size of each LSTM direction; also the width of
                the intermediate fully connected layer.
            num_layers: Number of stacked LSTM layers.
            dropout: Dropout probability, used between stacked LSTM layers
                (only when ``num_layers > 1``) and inside the classification
                head.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            # nn.LSTM applies dropout only *between* stacked layers. With a
            # single layer a non-zero value does nothing except emit a
            # UserWarning, so pass 0.0 in that case.
            dropout=dropout if num_layers > 1 else 0.0,
        )
        self.dropout = nn.Dropout(dropout)
        # Input is the concatenation of the forward and backward final states.
        self.fc1 = nn.Linear(hidden_dim * 2, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, x, lengths):
        """Score a padded batch of token-id sequences.

        Args:
            x: Integer token ids of shape (batch, seq_len) (the LSTM and the
                packing both use ``batch_first=True``), right-padded with 0.
            lengths: True length of each sequence, as a 1-D tensor (on any
                device) or a list of ints. Every length must be in
                ``[1, seq_len]``; ``pack_padded_sequence`` rejects zeros.

        Returns:
            Tensor of shape (batch, 1) holding the raw output of ``fc2``.
            No sigmoid is applied here (presumably the caller uses a
            logits-based loss such as ``BCEWithLogitsLoss`` — confirm).
        """
        # pack_padded_sequence requires CPU lengths. as_tensor also lets
        # callers pass a plain list of ints instead of a tensor.
        lengths_cpu = torch.as_tensor(lengths, dtype=torch.int64).cpu()

        embedded = self.embedding(x)
        # enforce_sorted=False: the batch need not be sorted by length.
        # PyTorch restores h_n to the original batch order.
        packed = pack_padded_sequence(
            embedded, lengths_cpu, batch_first=True, enforce_sorted=False
        )
        _, (hidden, _) = self.lstm(packed)

        # h_n is (num_layers * 2, batch, hidden). Split it into
        # (num_layers, direction, batch, hidden) and keep the top layer.
        hidden = hidden.view(self.lstm.num_layers, 2, -1, self.lstm.hidden_size)
        last_layer_hidden = hidden[-1]
        # Concatenate forward (index 0) and backward (index 1) final states.
        cat_hidden = torch.cat((last_layer_hidden[0], last_layer_hidden[1]), dim=1)

        # Classification head.
        out = self.dropout(cat_hidden)
        out = self.fc1(out)
        out = self.relu(out)
        out = self.dropout(out)
        return self.fc2(out)