""" Transformer Decoder with Pre-Layer Normalization """ import torch import torch.nn as nn from typing import Optional from .attention import MultiHeadAttention from .feed_forward import FeedForward from .layer_norm import LayerNorm from .embeddings import ScaledEmbedding from .positional_encoding import PositionalEncoding class DecoderLayer(nn.Module): """ Transformer Decoder Layer with Pre-Layer Normalization Structure: 1. LayerNorm -> Masked Self-Attention -> Dropout -> Residual 2. LayerNorm -> Cross-Attention -> Dropout -> Residual 3. LayerNorm -> Feed-Forward -> Dropout -> Residual """ def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1, attention_dropout: float = 0.1, activation_dropout: float = 0.0): """ Args: d_model: Model dimension n_heads: Number of attention heads d_ff: Feed-forward dimension dropout: Dropout rate attention_dropout: Dropout rate for attention activation_dropout: Dropout rate for FFN activation """ super().__init__() # Pre-layer normalization self.norm1 = LayerNorm(d_model) self.norm2 = LayerNorm(d_model) self.norm3 = LayerNorm(d_model) # Masked self-attention self.self_attn = MultiHeadAttention( d_model, n_heads, dropout, attention_dropout ) # Cross-attention self.cross_attn = MultiHeadAttention( d_model, n_heads, dropout, attention_dropout ) # Feed-forward network self.ffn = FeedForward(d_model, d_ff, dropout, activation_dropout) # Dropout for residual connections self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) self.dropout3 = nn.Dropout(dropout) def forward(self, x, enc_output, src_mask: Optional[torch.Tensor] = None, tgt_mask: Optional[torch.Tensor] = None, return_attention: bool = False): """ Args: x: [batch_size, tgt_len, d_model] enc_output: [batch_size, src_len, d_model] src_mask: [batch_size, 1, 1, src_len] tgt_mask: [batch_size, 1, tgt_len, tgt_len] return_attention: Whether to return attention weights Returns: output: [batch_size, tgt_len, d_model] self_attn: [batch_size, n_heads, tgt_len, tgt_len] (if return_attention) cross_attn: [batch_size, n_heads, tgt_len, src_len] (if return_attention) """ # Masked self-attention block with pre-norm residual = x x = self.norm1(x) if return_attention: self_attn_out, self_attn = self.self_attn( x, x, x, tgt_mask, return_attention=True ) x = residual + self.dropout1(self_attn_out) else: x = residual + self.dropout1(self.self_attn(x, x, x, tgt_mask)) self_attn = None # Cross-attention block with pre-norm residual = x x = self.norm2(x) if return_attention: cross_attn_out, cross_attn = self.cross_attn( x, enc_output, enc_output, src_mask, return_attention=True ) x = residual + self.dropout2(cross_attn_out) else: x = residual + self.dropout2( self.cross_attn(x, enc_output, enc_output, src_mask) ) cross_attn = None # Feed-forward block with pre-norm residual = x x = self.norm3(x) x = residual + self.dropout3(self.ffn(x)) if return_attention: return x, self_attn, cross_attn return x class TransformerDecoder(nn.Module): """ Complete Transformer Decoder """ def __init__(self, vocab_size: int, d_model: int, n_heads: int, d_ff: int, n_layers: int, max_len: int = 5000, dropout: float = 0.1, attention_dropout: float = 0.1, activation_dropout: float = 0.0, pad_idx: int = 0, scale_embedding: bool = True): """ Args: vocab_size: Target vocabulary size d_model: Model dimension n_heads: Number of attention heads d_ff: Feed-forward dimension n_layers: Number of decoder layers max_len: Maximum sequence length dropout: Dropout rate attention_dropout: Dropout rate for attention activation_dropout: Dropout rate for FFN activation pad_idx: Padding token index scale_embedding: Whether to scale embeddings """ super().__init__() self.d_model = d_model self.pad_idx = pad_idx # Embedding layer self.embedding = ScaledEmbedding( vocab_size, d_model, pad_idx, scale=scale_embedding, dropout=0.0 ) # Positional encoding self.pos_encoding = PositionalEncoding(d_model, max_len, dropout) # Stack of decoder layers self.layers = nn.ModuleList([ DecoderLayer( d_model, n_heads, d_ff, dropout, attention_dropout, activation_dropout ) for _ in range(n_layers) ]) # Final layer norm (important for Pre-LN) self.final_norm = LayerNorm(d_model) # Output projection self.fc_out = nn.Linear(d_model, vocab_size) self.dropout = nn.Dropout(dropout) def forward(self, tgt, enc_output, src_mask: Optional[torch.Tensor] = None, tgt_mask: Optional[torch.Tensor] = None, return_attention: bool = False): """ Args: tgt: [batch_size, tgt_len] enc_output: [batch_size, src_len, d_model] src_mask: [batch_size, 1, 1, src_len] tgt_mask: [batch_size, 1, tgt_len, tgt_len] return_attention: Whether to return attention weights Returns: output: [batch_size, tgt_len, vocab_size] self_attentions: List of self-attention weights (if return_attention) cross_attentions: List of cross-attention weights (if return_attention) """ # Embedding + positional encoding x = self.embedding(tgt) x = self.pos_encoding(x) # Pass through decoder layers self_attentions = [] if return_attention else None cross_attentions = [] if return_attention else None for layer in self.layers: if return_attention: x, self_attn, cross_attn = layer( x, enc_output, src_mask, tgt_mask, return_attention=True ) self_attentions.append(self_attn) cross_attentions.append(cross_attn) else: x = layer(x, enc_output, src_mask, tgt_mask, return_attention=False) # Final layer normalization x = self.final_norm(x) # Project to vocabulary output = self.fc_out(x) if return_attention: return output, self_attentions, cross_attentions return output def init_weights(self, init_std: float = 0.02): """Initialize model weights""" # Initialize embeddings self.embedding.init_weights(init_std) # Initialize linear layers for module in self.modules(): if isinstance(module, nn.Linear): nn.init.normal_(module.weight, mean=0.0, std=init_std) if module.bias is not None: nn.init.zeros_(module.bias)