import torch
import torch.nn as nn
import torch.nn.functional as F


class DocumentBiLSTM(nn.Module):
    """
    BiLSTM document classifier with stability improvements
    (layer normalization and dropout) inspired by DocBERT.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim,
                 n_layers=2, dropout=0.5, pad_idx=0):
        super().__init__()

        # Token embeddings; the pad_idx row is zeroed and receives no gradient.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)

        # Bidirectional LSTM; PyTorch only applies inter-layer dropout when
        # there is more than one layer, so pass 0 for a single-layer model.
        self.lstm = nn.LSTM(embedding_dim,
                            hidden_dim,
                            num_layers=n_layers,
                            bidirectional=True,
                            dropout=dropout if n_layers > 1 else 0,
                            batch_first=True)

        # LayerNorm over the concatenated forward/backward final hidden states.
        self.layer_norm = nn.LayerNorm(hidden_dim * 2)

        # Classification head; input is 2 * hidden_dim due to bidirectionality.
        self.fc = nn.Linear(hidden_dim * 2, output_dim)

        self.dropout = nn.Dropout(dropout)
    def forward(self, input_ids, attention_mask=None, **kwargs):
        # Extra keyword arguments are accepted for interface compatibility
        # with transformer-style models and ignored.

        # (batch, seq_len) -> (batch, seq_len, embedding_dim)
        embedded = self.embedding(input_ids)
        embedded = self.dropout(embedded)

        if attention_mask is not None:
            # True sequence lengths from the mask; pack_padded_sequence
            # requires lengths on the CPU. Clamp to at least 1 so a row of
            # all padding does not crash packing.
            seq_lengths = attention_mask.sum(dim=1).to(torch.int64).cpu().clamp(min=1)

            # Sort the batch by length, longest first, as required by
            # enforce_sorted=True.
            seq_lengths, indices = torch.sort(seq_lengths, descending=True)
            sorted_embedded = embedded[indices]

            packed_embedded = nn.utils.rnn.pack_padded_sequence(
                sorted_embedded, seq_lengths, batch_first=True, enforce_sorted=True
            )

            # Only the final hidden states are needed for classification,
            # so the packed output sequence is discarded without unpacking.
            _, (hidden, cell) = self.lstm(packed_embedded)

            # Undo the length sort so hidden lines up with the original batch.
            _, restore_indices = torch.sort(indices)
            hidden = hidden[:, restore_indices]
        else:
            # No mask available: run the LSTM over the full padded batch.
            _, (hidden, cell) = self.lstm(embedded)

        # hidden has shape (n_layers * 2, batch, hidden_dim); the last two
        # slices are the top layer's forward and backward final states.
        hidden_cat = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)

        normalized = self.layer_norm(hidden_cat)
        dropped = self.dropout(normalized)
        prediction = self.fc(dropped)

        return prediction
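

# A minimal smoke test, assuming integer token ids and a 0/1 attention mask
# as inputs; the hyperparameters and batch below are illustrative values,
# not settings prescribed by DocBERT.
if __name__ == "__main__":
    torch.manual_seed(0)

    model = DocumentBiLSTM(vocab_size=1000, embedding_dim=64,
                           hidden_dim=128, output_dim=5)

    # Batch of 4 documents padded to 12 tokens; document 2 has 8 real tokens.
    input_ids = torch.randint(1, 1000, (4, 12))
    attention_mask = torch.ones(4, 12, dtype=torch.long)
    input_ids[2, 8:] = 0
    attention_mask[2, 8:] = 0

    logits = model(input_ids, attention_mask=attention_mask)
    print(logits.shape)  # torch.Size([4, 5])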