# lstm.py — PII-Detection model definitions
# (upstream commit f53fac9, "add comments", by AmitHirpara)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, PackedSequence
class LSTMCell(nn.Module):
    """A single LSTM cell implemented from scratch.

    Holds a separate (input->hidden, hidden->hidden, bias) parameter
    triple for each of the four gates and computes one recurrence step
    per call.
    """
    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Register one parameter triple per gate, in a fixed order:
        # i = input gate, f = forget gate, n = candidate, o = output gate.
        for gate in ("i", "f", "n", "o"):
            setattr(self, f"W_i{gate}", nn.Parameter(torch.Tensor(input_size, hidden_size)))
            setattr(self, f"W_h{gate}", nn.Parameter(torch.Tensor(hidden_size, hidden_size)))
            setattr(self, f"b_{gate}", nn.Parameter(torch.Tensor(hidden_size)))
        # Xavier init for weight matrices, zeros for biases.
        for name, param in self.named_parameters():
            if 'W_' in name:
                nn.init.xavier_uniform_(param)
            elif 'b_' in name:
                nn.init.zeros_(param)
    def forward(self, input: torch.Tensor, states: tuple[torch.Tensor, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        """Run one recurrence step; returns (new_hidden, new_cell)."""
        hidden, cell = states
        def pre_act(W_x, W_h, b):
            # Affine combination of current input and previous hidden state.
            return torch.mm(input, W_x) + torch.mm(hidden, W_h) + b
        # Forget gate: how much of the old cell state to keep.
        f = torch.sigmoid(pre_act(self.W_if, self.W_hf, self.b_f))
        # Input gate: how much new information to admit.
        i = torch.sigmoid(pre_act(self.W_ii, self.W_hi, self.b_i))
        # Candidate values: the new content that could be written.
        g = torch.tanh(pre_act(self.W_in, self.W_hn, self.b_n))
        # Output gate: which parts of the cell state to expose.
        o = torch.sigmoid(pre_act(self.W_io, self.W_ho, self.b_o))
        # Blend old and new content, then emit the filtered cell state.
        cell_next = f * cell + i * g
        hidden_next = o * torch.tanh(cell_next)
        return hidden_next, cell_next
class BidirectionalLSTM(nn.Module):
    """
    Multi-layer bidirectional LSTM built from custom LSTMCell modules.

    Accepts either a padded tensor or a PackedSequence.  When sequence
    lengths are known (from a packed input, or the ``lengths`` argument),
    state updates at padding timesteps are masked out so padding never
    pollutes the recurrent state: the final states reflect each sequence's
    last real token, and the backward direction effectively starts from
    that token instead of from padding.

    Args:
        input_size: number of input features per timestep.
        hidden_size: hidden state size per direction.
        num_layers: number of stacked bidirectional layers.
        batch_first: if True, tensors are [batch, seq, feature].
        dropout: dropout probability applied between stacked layers
            (ignored when num_layers == 1).
    """
    def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1,
                 batch_first: bool = True, dropout: float = 0.0):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        # Inter-layer dropout is meaningless with a single layer.
        self.dropout = dropout if num_layers > 1 else 0.0
        # One forward and one backward cell per layer.
        self.forward_cells = nn.ModuleList()
        self.backward_cells = nn.ModuleList()
        self.dropout_layers = nn.ModuleList() if self.dropout > 0 else None
        for layer in range(num_layers):
            # First layer sees the raw input; deeper layers see the
            # concatenated forward+backward output of the layer below.
            layer_input_size = input_size if layer == 0 else hidden_size * 2
            self.forward_cells.append(LSTMCell(layer_input_size, hidden_size))
            self.backward_cells.append(LSTMCell(layer_input_size, hidden_size))
            # Dropout between layers (none after the last layer).
            if self.dropout > 0 and layer < num_layers - 1:
                self.dropout_layers.append(nn.Dropout(dropout))
    def forward(self, input, states=None, lengths=None):
        """Run the stacked BiLSTM over a padded tensor or PackedSequence.

        Args:
            input: [batch, seq, feat] tensor ([seq, batch, feat] when
                batch_first=False), or a PackedSequence.
            states: optional (h_0, c_0), each [num_layers * 2, batch, hidden],
                interleaved as [fwd_0, bwd_0, fwd_1, bwd_1, ...].
            lengths: optional per-sequence lengths for a padded input.

        Returns:
            (output, (h_n, c_n)); output has the same container type as input.
        """
        if isinstance(input, PackedSequence):
            # Unpack, run on the padded batch (lengths mask out the padding),
            # then re-pack so callers get back the container they passed in.
            padded, lengths = pad_packed_sequence(input, batch_first=self.batch_first)
            outputs, (h_n, c_n) = self._forward_unpacked(padded, states, lengths)
            packed_out = pack_padded_sequence(
                outputs, lengths,
                batch_first=self.batch_first,
                enforce_sorted=False
            )
            return packed_out, (h_n, c_n)
        return self._forward_unpacked(input, states, lengths)
    def _forward_unpacked(self, input: torch.Tensor, states, lengths=None):
        # Work in batch-first layout internally.
        if not self.batch_first:
            input = input.transpose(0, 1)
        batch_size, seq_len, _ = input.size()
        # Initialize per-layer, per-direction recurrent states.
        if states is None:
            h_t_forward = [input.new_zeros(batch_size, self.hidden_size)
                           for _ in range(self.num_layers)]
            c_t_forward = [input.new_zeros(batch_size, self.hidden_size)
                           for _ in range(self.num_layers)]
            h_t_backward = [input.new_zeros(batch_size, self.hidden_size)
                            for _ in range(self.num_layers)]
            c_t_backward = [input.new_zeros(batch_size, self.hidden_size)
                            for _ in range(self.num_layers)]
        else:
            # Provided states are interleaved per layer: forward then backward.
            h0, c0 = states
            h_t_forward = [h0[layer * 2] for layer in range(self.num_layers)]
            c_t_forward = [c0[layer * 2] for layer in range(self.num_layers)]
            h_t_backward = [h0[layer * 2 + 1] for layer in range(self.num_layers)]
            c_t_backward = [c0[layer * 2 + 1] for layer in range(self.num_layers)]
        # valid[b, t] is True when timestep t of sequence b is a real token.
        # BUG FIX: ``lengths`` used to be accepted but ignored, so padding
        # timesteps corrupted the forward final states and the backward
        # direction started from padding instead of the last real token.
        if lengths is not None:
            lengths_t = torch.as_tensor(lengths, dtype=torch.long, device=input.device)
            valid = (torch.arange(seq_len, device=input.device).unsqueeze(0)
                     < lengths_t.unsqueeze(1))
        else:
            valid = None
        # Process through each layer.
        layer_input = input
        for layer_idx in range(self.num_layers):
            # Forward direction: left to right.
            forward_output = input.new_zeros(batch_size, seq_len, self.hidden_size)
            for t in range(seq_len):
                x = layer_input[:, t, :]
                h, c = self.forward_cells[layer_idx](x, (h_t_forward[layer_idx], c_t_forward[layer_idx]))
                if valid is not None:
                    # Freeze the state across padding positions.
                    m = valid[:, t].unsqueeze(1).to(h.dtype)
                    h = m * h + (1 - m) * h_t_forward[layer_idx]
                    c = m * c + (1 - m) * c_t_forward[layer_idx]
                h_t_forward[layer_idx] = h
                c_t_forward[layer_idx] = c
                forward_output[:, t, :] = h
            # Backward direction: right to left.
            backward_output = input.new_zeros(batch_size, seq_len, self.hidden_size)
            for t in reversed(range(seq_len)):
                x = layer_input[:, t, :]
                h, c = self.backward_cells[layer_idx](x, (h_t_backward[layer_idx], c_t_backward[layer_idx]))
                if valid is not None:
                    # State stays at its initial value until the sequence's
                    # true last token is reached.
                    m = valid[:, t].unsqueeze(1).to(h.dtype)
                    h = m * h + (1 - m) * h_t_backward[layer_idx]
                    c = m * c + (1 - m) * c_t_backward[layer_idx]
                h_t_backward[layer_idx] = h
                c_t_backward[layer_idx] = c
                backward_output[:, t, :] = h
            # Concatenate the two directions along the feature axis.
            layer_output = torch.cat([forward_output, backward_output], dim=2)
            # Apply dropout between layers (never after the last one).
            if self.dropout > 0 and layer_idx < self.num_layers - 1:
                layer_output = self.dropout_layers[layer_idx](layer_output)
            # This layer's output is the next layer's input.
            layer_input = layer_output
        # Final output is the last layer's output.
        outputs = layer_output
        # Stack final states as [num_layers * 2, batch, hidden].
        h_n = []
        c_n = []
        for layer in range(self.num_layers):
            h_n.extend([h_t_forward[layer], h_t_backward[layer]])
            c_n.extend([c_t_forward[layer], c_t_backward[layer]])
        h_n = torch.stack(h_n, dim=0)
        c_n = torch.stack(c_n, dim=0)
        # Convert back if not batch-first.
        if not self.batch_first:
            outputs = outputs.transpose(0, 1)
        return outputs, (h_n, c_n)
class LSTM(nn.Module):
    """
    Bidirectional LSTM model for PII detection (sequence labeling).

    Pipeline: token ids -> embedding (+dropout) -> stacked BiLSTM ->
    dropout -> per-token linear classifier.

    Args:
        vocab_size: embedding vocabulary size (index 0 is padding).
        num_classes: number of PII label classes.
        embed_size: token embedding dimension.
        hidden_size: LSTM hidden size per direction.
        num_layers: number of stacked BiLSTM layers.
        dropout: dropout used on embeddings, between LSTM layers, and
            before the classifier.
        max_len: accepted for factory-signature compatibility; unused here.
    """
    def __init__(self, vocab_size: int, num_classes: int, embed_size: int = 128,
                 hidden_size: int = 256, num_layers: int = 2, dropout: float = 0.1,
                 max_len: int = 512):
        super().__init__()
        self.vocab_size = vocab_size
        self.num_classes = num_classes
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # Embedding layer to convert tokens to vectors (index 0 = padding).
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.embed_dropout = nn.Dropout(dropout)
        # Bidirectional LSTM layers.
        self.lstm = BidirectionalLSTM(
            input_size=embed_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0
        )
        # Output layer to predict PII labels; input doubled for bidirectional.
        lstm_output_size = hidden_size * 2
        self.fc = nn.Linear(lstm_output_size, num_classes)
        self.output_dropout = nn.Dropout(dropout)
    def forward(self, input_ids, lengths=None):
        """
        Forward pass.

        Args:
            input_ids: token ids [batch_size, seq_len]
            lengths: actual lengths of sequences (optional; list or tensor)

        Returns:
            logits: class predictions [batch_size, seq_len, num_classes]
        """
        # Convert token ids to embeddings.
        embedded = self.embedding(input_ids)
        embedded = self.embed_dropout(embedded)
        # Pack sequences for efficient processing if lengths provided.
        if lengths is not None:
            # pack_padded_sequence requires lengths on CPU; accept either
            # a list or a tensor on any device.
            lengths_cpu = torch.as_tensor(lengths).cpu()
            packed_embedded = pack_padded_sequence(
                embedded, lengths_cpu,
                batch_first=True,
                enforce_sorted=False
            )
            lstm_out, _ = self.lstm(packed_embedded)
            # BUG FIX: without total_length the unpacked output is truncated
            # to max(lengths), which can be shorter than seq_len and break
            # alignment with per-token labels padded to the full length.
            lstm_out, _ = pad_packed_sequence(
                lstm_out, batch_first=True, total_length=input_ids.size(1)
            )
        else:
            # Process without packing.
            lstm_out, _ = self.lstm(embedded)
        # Apply dropout and get final per-token predictions.
        lstm_out = self.output_dropout(lstm_out)
        logits = self.fc(lstm_out)
        return logits
def create_lstm_pii_model(vocab_size: int, num_classes: int, d_model: int = 256,
                          num_heads: int = 8, d_ff: int = 512, num_layers: int = 4,
                          dropout: float = 0.1, max_len: int = 512):
    """Create a Bidirectional LSTM model for PII detection.

    Note: ``num_heads`` and ``d_ff`` are accepted only for signature
    parity with the transformer model factory and are not used here.
    """
    # Map the generic d_model budget onto LSTM dimensions: half for the
    # embedding, the full width for each direction's hidden state.
    config = dict(
        vocab_size=vocab_size,
        num_classes=num_classes,
        embed_size=d_model // 2,
        hidden_size=d_model,
        num_layers=num_layers,
        dropout=dropout,
        max_len=max_len,
    )
    return LSTM(**config)