# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                                                           #
#   This file was created by: Alberto Palomo Alonso         #
#   Universidad de Alcalá - Escuela Politécnica Superior    #
#                                                           #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
import torch


class EncoderBlock(torch.nn.Module):
    """
    Transformer encoder block with configurable Pre-LayerNorm or
    Post-LayerNorm architecture.

    The block consists of a multi-head self-attention sublayer followed by a
    position-wise feed-forward network, each wrapped with a residual
    connection. Layer normalization can be applied either before each
    sublayer (Pre-LN) or after each residual addition (Post-LN). This design
    allows stable training of deep Transformer stacks while retaining
    compatibility with the original Transformer formulation.
    """
    def __init__(
            self,
            feature_dim: int,
            attention_heads: int = 8,
            feed_forward_multiplier: float = 4,
            dropout: float = 0.0,
            valid_padding: bool = False,
            pre_normalize: bool = True,
            **kwargs
    ):
        """
        Initializes a Transformer encoder block.

        Parameters
        ----------
        feature_dim : int
            Dimensionality of the input and output feature representations.
        attention_heads : int, optional
            Number of attention heads used in the multi-head self-attention
            layer. Default is 8.
        feed_forward_multiplier : float, optional
            Expansion factor for the hidden dimension of the feed-forward
            network. The intermediate dimension is computed as
            `feed_forward_multiplier * feature_dim`. Default is 4.
        dropout : float, optional
            Dropout probability applied to the feed-forward residual
            connection. Default is 0.0.
        valid_padding : bool, optional
            If True, the provided mask marks valid (non-padded) positions.
            If False, the mask marks padded (invalid) positions directly.
            Default is False.
        pre_normalize : bool, optional
            If True, uses the Pre-LayerNorm Transformer variant, applying
            layer normalization before each sublayer (self-attention and
            feed-forward).
            If False, uses the Post-LayerNorm variant, applying
            normalization after each residual connection. Default is True.
        **kwargs
            Additional keyword arguments passed to the parent
            `torch.nn.Module`.
        """
        # Module init via kwargs:
        super().__init__(**kwargs)

        # Store params:
        self.valid_padding = valid_padding
        self.pre_normalize = pre_normalize

        # Norm layers:
        self.norm_in = torch.nn.LayerNorm(feature_dim)
        self.norm_out = torch.nn.LayerNorm(feature_dim)

        # Dropout layer:
        self.dropout = torch.nn.Dropout(dropout)

        # Attention layer:
        self.attention = torch.nn.MultiheadAttention(
            embed_dim=feature_dim,
            num_heads=attention_heads,
            dropout=0.0,
            batch_first=True
        )

        # Feed-forward layer:
        self.feed_forward = torch.nn.Sequential(
            torch.nn.Linear(feature_dim,
                            int(feed_forward_multiplier * feature_dim)),
            torch.nn.GELU(),
            torch.nn.Linear(int(feed_forward_multiplier * feature_dim),
                            feature_dim),
        )

    def forward(self, x: torch.Tensor,
                mask: torch.Tensor | None = None) -> torch.Tensor:
        """
        Forward pass of a Transformer encoder block.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, sequence_length, feature_dim).
        mask : torch.Tensor or None, optional
            Boolean mask indicating valid sequence positions. Shape:
            (batch_size, sequence_length). If `valid_padding` is True, True
            values denote valid tokens. Otherwise, True values denote masked
            (invalid) positions.

        Returns
        -------
        x : torch.Tensor
            Output tensor of the same shape as the input
            (batch_size, sequence_length, feature_dim).
""" # Convert mask: if mask is not None and self.valid_padding: key_padding_mask = ~mask.bool() # True = pad valid_mask = mask.bool() elif mask is not None: key_padding_mask = mask.bool() valid_mask = ~mask.bool() else: key_padding_mask = None valid_mask = None # Detect fully padded sequences: if valid_mask is not None: all_pad = ~valid_mask.any(dim=-1) # (B,) else: all_pad = None # Pre-normalization: if self.pre_normalize: h = self.norm_in(x) else: h = x # Attention (guard against fully padded sequences): if all_pad is not None and all_pad.any(): h_attn = h.clone() h_attn[all_pad] = 0.0 if key_padding_mask is not None: key_padding_mask = key_padding_mask.clone() key_padding_mask[all_pad] = False else: h_attn = h attn_out, _ = self.attention( h_attn, h_attn, h_attn, key_padding_mask=key_padding_mask, need_weights=False, ) x = x + attn_out # Post-attention normalization: if not self.pre_normalize: z = self.norm_in(x) else: z = self.norm_out(x) # Feed-forward: z = self.feed_forward(z) x = x + self.dropout(z) if not self.pre_normalize: x = self.norm_out(x) # Re-pad fully padded sequences: if all_pad is not None: x = x.masked_fill(all_pad[:, None, None], 0.0) return x # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - # # END OF FILE # # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #