import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def scaled_dot_product_attention(q, k, v, mask=None, dropout=None):
    """Compute scaled dot-product attention.

    Args:
        q: Query tensor; its last dimension is the per-head size ``d_k``.
        k: Key tensor; its last two dims are transposed to form ``q @ k^T``.
        v: Value tensor, combined using the attention weights.
        mask: Optional tensor broadcastable to the score tensor. Positions
            where ``mask == 0`` receive zero attention weight.
        dropout: Optional callable (e.g. ``nn.Dropout``) applied to the
            attention weights.

    Returns:
        Tuple ``(output, attention_weights)`` where ``output`` is
        ``attention_weights @ v``.
    """
    d_k = q.size(-1)
    # Scale by sqrt(d_k) so the dot products do not grow with the head size.
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        # Use the most negative finite value instead of -inf. Masked keys
        # still get exactly zero weight after softmax. A row whose keys are
        # ALL masked (e.g. an all-padding sequence in the batch) becomes a
        # uniform distribution instead of NaN. A NaN there would flow into
        # the logits and from there into the loss.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

    # Normalize scores over the key dimension.
    attention_weights = F.softmax(scores, dim=-1)

    if dropout is not None:
        attention_weights = dropout(attention_weights)

    output = torch.matmul(attention_weights, v)
    return output, attention_weights


class MultiHeadAttention(nn.Module):
    """Multi-head attention: project Q/K/V, attend per head, then recombine."""

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # dimension per head

        # Linear projections for queries, keys and values.
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        # Final projection applied after the heads are concatenated.
        self.w_o = nn.Linear(d_model, d_model)

        # Applied to the attention weights inside scaled_dot_product_attention.
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: (batch_size, seq_len_q, d_model)
            key: (batch_size, seq_len_k, d_model)
            value: (batch_size, seq_len_v, d_model)
            mask: (batch_size, 1, 1, seq_len_k) or None; entries equal to 0
                are masked out.

        Returns:
            Tuple ``(output, attention_weights)``. ``output`` has shape
            (batch_size, seq_len_q, d_model). ``attention_weights`` has shape
            (batch_size, num_heads, seq_len_q, seq_len_k).
        """
        batch_size = query.size(0)
        seq_len_q = query.size(1)
        seq_len_k = key.size(1)
        seq_len_v = value.size(1)

        # Project, then split d_model into (num_heads, d_k) and move the head
        # axis before the sequence axis: (batch, heads, seq, d_k).
        q = self.w_q(query).view(batch_size, seq_len_q, self.num_heads, self.d_k).transpose(1, 2)
        k = self.w_k(key).view(batch_size, seq_len_k, self.num_heads, self.d_k).transpose(1, 2)
        v = self.w_v(value).view(batch_size, seq_len_v, self.num_heads, self.d_k).transpose(1, 2)

        attention_output, attention_weights = scaled_dot_product_attention(
            q, k, v, mask=mask, dropout=self.dropout
        )

        # Undo the head split. transpose() produces a non-contiguous view, so
        # contiguous() is required before view().
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len_q, self.d_model
        )
        output = self.w_o(attention_output)
        return output, attention_weights


class PositionwiseFeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, x):
        """Apply the two-layer MLP independently at every position of ``x``."""
        return self.w_2(self.dropout(self.activation(self.w_1(x))))


class EncoderLayer(nn.Module):
    """Single post-norm encoder layer: self-attention then feed-forward.

    Each sublayer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)

        # One LayerNorm per sublayer.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Dropout on each sublayer's output before the residual add.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run self-attention and feed-forward sublayers with residuals."""
        attn_output, _ = self.self_attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))

        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x


class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` encoder layers followed by a final LayerNorm."""

    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        """Pass ``x`` through every layer in order, then normalize."""
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to the input embeddings.

    The table covers positions ``0 .. max_len - 1``. Inputs longer than
    ``max_len`` are not supported (the add in ``forward`` will fail).
    """

    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Frequencies 10000^(-2i / d_model) for i = 0, 1, ..., computed in
        # log space for numerical stability.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        # Sine on even columns, cosine on odd columns. For odd d_model there
        # is one fewer odd column than even column, so the cosine uses only
        # the first d_model // 2 frequencies.
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Shape (1, max_len, d_model) so it broadcasts over the batch. Stored
        # as a buffer so it moves with .to()/.cuda() but is not trained.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add encodings for the first ``x.size(1)`` positions, then dropout.

        ``x`` is assumed to be (batch, seq_len, d_model); the ``pe`` buffer is
        sliced along dim 1 to match.
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


class TransformerPII(nn.Module):
    """Transformer encoder for PII detection (token classification).

    Maps token IDs of shape (batch_size, seq_len) to per-token class logits
    of shape (batch_size, seq_len, num_classes).
    """

    def __init__(self, vocab_size, num_classes, d_model=256, num_heads=8,
                 d_ff=512, num_layers=4, dropout=0.1, max_len=512, pad_idx=0):
        super().__init__()
        self.d_model = d_model
        self.pad_idx = pad_idx

        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.positional_encoding = PositionalEncoding(d_model, max_len, dropout)
        self.encoder = TransformerEncoder(num_layers, d_model, num_heads, d_ff, dropout)
        # Token-level classification head.
        self.classifier = nn.Linear(d_model, num_classes)
        # Applied to the encoder output before the classifier.
        self.dropout = nn.Dropout(dropout)

        self._init_weights()

    def _init_weights(self):
        """Initialize the embedding and classifier weights."""
        # std = d_model**-0.5 together with the sqrt(d_model) scaling in
        # forward() gives embeddings of roughly unit scale.
        nn.init.normal_(self.embedding.weight, mean=0, std=self.d_model**-0.5)
        # Keep the padding embedding at zero after the normal re-init.
        if self.pad_idx is not None:
            nn.init.constant_(self.embedding.weight[self.pad_idx], 0)

        nn.init.xavier_uniform_(self.classifier.weight)
        if self.classifier.bias is not None:
            nn.init.constant_(self.classifier.bias, 0)

    def create_padding_mask(self, x):
        """Build an attention mask that hides padding keys.

        Args:
            x: Token IDs of shape (batch_size, seq_len).

        Returns:
            Float tensor of shape (batch_size, 1, 1, seq_len) with 1.0 for
            real tokens and 0.0 for padding. Returns ``None`` (no masking)
            when the model has no padding index.
        """
        if self.pad_idx is None:
            return None
        mask = (x != self.pad_idx).unsqueeze(1).unsqueeze(2)
        return mask.float()

    def forward(self, x, mask=None):
        """Compute per-token class logits.

        Args:
            x: Token IDs of shape (batch_size, seq_len).
            mask: Optional attention mask broadcastable to
                (batch_size, num_heads, seq_len, seq_len), e.g. the
                (batch_size, 1, 1, seq_len) mask from ``create_padding_mask``.
                Built from ``pad_idx`` when omitted.

        Returns:
            Logits of shape (batch_size, seq_len, num_classes).

        Raises:
            ValueError: If ``x`` is not two-dimensional.
        """
        if x.dim() != 2:
            raise ValueError(f"Expected input to have 2 dimensions [batch_size, seq_len], got {x.dim()}")

        if mask is None:
            mask = self.create_padding_mask(x)

        # Scale embeddings by sqrt(d_model) before adding positions.
        x = self.embedding(x) * math.sqrt(self.d_model)
        x = self.positional_encoding(x)

        encoder_output = self.encoder(x, mask)
        encoder_output = self.dropout(encoder_output)
        return self.classifier(encoder_output)

    def predict(self, x):
        """Return the argmax class per token, with dropout disabled.

        The model is temporarily switched to eval mode. Its previous
        training/eval mode is restored afterwards, so calling this during
        training does not silently turn dropout off for later steps.

        Args:
            x: Token IDs of shape (batch_size, seq_len).

        Returns:
            Long tensor of shape (batch_size, seq_len).
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                logits = self.forward(x)
                predictions = torch.argmax(logits, dim=-1)
        finally:
            self.train(was_training)
        return predictions


def create_transformer_pii_model(vocab_size, num_classes, d_model=256, num_heads=8,
                                 d_ff=512, num_layers=4, dropout=0.1, max_len=512,
                                 pad_idx=0):
    """Build a ``TransformerPII`` model for PII token classification.

    ``pad_idx`` defaults to 0 (the previously hard-coded value). Pass
    ``None`` to disable padding handling entirely.
    """
    return TransformerPII(
        vocab_size=vocab_size,
        num_classes=num_classes,
        d_model=d_model,
        num_heads=num_heads,
        d_ff=d_ff,
        num_layers=num_layers,
        dropout=dropout,
        max_len=max_len,
        pad_idx=pad_idx,
    )