Commit ·
251e9cd
1
Parent(s): 466b8a2
Add layer normalization to LSTM model
Browse files- models/lstm_model.py +25 -17
models/lstm_model.py
CHANGED
|
@@ -2,10 +2,13 @@ import torch
|
|
| 2 |
import torch.nn as nn
|
| 3 |
import torch.nn.functional as F
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
class DocumentBiLSTM(nn.Module):
|
| 6 |
"""
|
| 7 |
-
|
| 8 |
-
Good for getting started quickly
|
| 9 |
"""
|
| 10 |
def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim,
|
| 11 |
n_layers=2, dropout=0.5, pad_idx=0):
|
|
@@ -20,6 +23,9 @@ class DocumentBiLSTM(nn.Module):
|
|
| 20 |
dropout=dropout if n_layers > 1 else 0,
|
| 21 |
batch_first=True)
|
| 22 |
|
|
|
|
|
|
|
|
|
|
| 23 |
self.fc = nn.Linear(hidden_dim * 2, output_dim)
|
| 24 |
|
| 25 |
self.dropout = nn.Dropout(dropout)
|
|
@@ -33,44 +39,46 @@ class DocumentBiLSTM(nn.Module):
|
|
| 33 |
# Apply dropout to embeddings
|
| 34 |
embedded = self.dropout(embedded)
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
if attention_mask is not None:
|
| 37 |
# Convert attention mask to sequence lengths
|
| 38 |
-
# First, get the length of each sequence by summing the attention mask
|
| 39 |
seq_lengths = attention_mask.sum(dim=1).to(torch.int64).cpu()
|
| 40 |
|
| 41 |
-
# Sort sequences by decreasing length
|
| 42 |
seq_lengths, indices = torch.sort(seq_lengths, descending=True)
|
| 43 |
-
|
| 44 |
|
| 45 |
# Pack the embedded sequences
|
| 46 |
packed_embedded = nn.utils.rnn.pack_padded_sequence(
|
| 47 |
-
|
| 48 |
)
|
| 49 |
|
| 50 |
-
# Pass
|
| 51 |
packed_output, (hidden, cell) = self.lstm(packed_embedded)
|
| 52 |
|
| 53 |
# Unpack the sequence
|
| 54 |
output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
|
| 55 |
|
| 56 |
-
#
|
| 57 |
_, restore_indices = torch.sort(indices)
|
|
|
|
| 58 |
else:
|
| 59 |
# Standard processing without masking
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
# output = [batch size, seq len, hid dim * num directions]
|
| 63 |
-
# hidden = [n layers * num directions, batch size, hid dim]
|
| 64 |
-
# cell = [n layers * num directions, batch size, hid dim]
|
| 65 |
-
output, (hidden, cell) = self.lstm(embedded)
|
| 66 |
|
| 67 |
# Concatenate the final forward and backward hidden states
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
# Apply dropout to hidden state
|
| 71 |
-
|
| 72 |
|
| 73 |
# prediction = [batch size, output dim]
|
| 74 |
-
prediction = self.fc(
|
| 75 |
|
| 76 |
return prediction
|
|
|
|
| 2 |
import torch.nn as nn
|
| 3 |
import torch.nn.functional as F
|
| 4 |
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
class DocumentBiLSTM(nn.Module):
|
| 10 |
"""
|
| 11 |
+
BiLSTM implementation with stability improvements inspired by DocBERT
|
|
|
|
| 12 |
"""
|
| 13 |
def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim,
|
| 14 |
n_layers=2, dropout=0.5, pad_idx=0):
|
|
|
|
| 23 |
dropout=dropout if n_layers > 1 else 0,
|
| 24 |
batch_first=True)
|
| 25 |
|
| 26 |
+
# Add layer normalization for stability (like in DocBERT)
|
| 27 |
+
self.layer_norm = nn.LayerNorm(hidden_dim * 2)
|
| 28 |
+
|
| 29 |
self.fc = nn.Linear(hidden_dim * 2, output_dim)
|
| 30 |
|
| 31 |
self.dropout = nn.Dropout(dropout)
|
|
|
|
| 39 |
# Apply dropout to embeddings
|
| 40 |
embedded = self.dropout(embedded)
|
| 41 |
|
| 42 |
+
# Initialize hidden and cell variables
|
| 43 |
+
hidden = None
|
| 44 |
+
cell = None
|
| 45 |
+
|
| 46 |
if attention_mask is not None:
|
| 47 |
# Convert attention mask to sequence lengths
|
|
|
|
| 48 |
seq_lengths = attention_mask.sum(dim=1).to(torch.int64).cpu()
|
| 49 |
|
| 50 |
+
# Sort sequences by decreasing length
|
| 51 |
seq_lengths, indices = torch.sort(seq_lengths, descending=True)
|
| 52 |
+
sorted_embedded = embedded[indices]
|
| 53 |
|
| 54 |
# Pack the embedded sequences
|
| 55 |
packed_embedded = nn.utils.rnn.pack_padded_sequence(
|
| 56 |
+
sorted_embedded, seq_lengths, batch_first=True, enforce_sorted=True
|
| 57 |
)
|
| 58 |
|
| 59 |
+
# Pass through LSTM
|
| 60 |
packed_output, (hidden, cell) = self.lstm(packed_embedded)
|
| 61 |
|
| 62 |
# Unpack the sequence
|
| 63 |
output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
|
| 64 |
|
| 65 |
+
# Get the hidden states in correct order
|
| 66 |
_, restore_indices = torch.sort(indices)
|
| 67 |
+
hidden = hidden[:, restore_indices]
|
| 68 |
else:
|
| 69 |
# Standard processing without masking
|
| 70 |
+
_, (hidden, cell) = self.lstm(embedded)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
# Concatenate the final forward and backward hidden states
|
| 73 |
+
hidden_cat = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1)
|
| 74 |
+
|
| 75 |
+
# Apply layer normalization (improves stability)
|
| 76 |
+
normalized = self.layer_norm(hidden_cat)
|
| 77 |
|
| 78 |
# Apply dropout to hidden state
|
| 79 |
+
dropped = self.dropout(normalized)
|
| 80 |
|
| 81 |
# prediction = [batch size, output dim]
|
| 82 |
+
prediction = self.fc(dropped)
|
| 83 |
|
| 84 |
return prediction
|