# Provenance (was raw scrape residue, not valid Python):
# commit dbd79bd by alverciito — "upload safetensors and refactor research files"
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# #
# This file was created by: Alberto Palomo Alonso #
# Universidad de Alcalá - Escuela Politécnica Superior #
# #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
import torch
class EncoderBlock(torch.nn.Module):
"""
Transformer encoder block with configurable Pre-LayerNorm or Post-LayerNorm
architecture.
The block consists of a multi-head self-attention sublayer followed by a
position-wise feed-forward network, each wrapped with a residual connection.
Layer normalization can be applied either before each sublayer (Pre-LN) or
after each residual addition (Post-LN).
This design allows stable training of deep Transformer stacks while retaining
compatibility with the original Transformer formulation.
"""
def __init__(
self,
feature_dim: int,
attention_heads: int = 8,
feed_forward_multiplier: float = 4,
dropout: float = 0.0,
valid_padding: bool = False,
pre_normalize: bool = True,
**kwargs
):
"""
Initializes a Transformer encoder block.
Parameters
----------
feature_dim : int
Dimensionality of the input and output feature representations.
attention_heads : int, optional
Number of attention heads used in the multi-head self-attention layer.
Default is 8.
feed_forward_multiplier : float, optional
Expansion factor for the hidden dimension of the feed-forward network.
The intermediate dimension is computed as
`feed_forward_multiplier * feature_dim`.
Default is 4.
dropout : float, optional
Dropout probability applied to the feed-forward residual connection.
Default is 0.0.
valid_padding : bool, optional
If True, the provided mask marks valid (non-padded) positions.
If False, the mask marks padded (invalid) positions directly.
Default is False.
pre_normalize : bool, optional
If True, uses the Pre-LayerNorm Transformer variant, applying layer
normalization before each sublayer (self-attention and feed-forward).
If False, uses the Post-LayerNorm variant, applying normalization after
each residual connection.
Default is True.
**kwargs
Additional keyword arguments passed to the parent `torch.nn.Module`.
"""
# Module init via kwargs:
super().__init__(**kwargs)
# Store params:
self.valid_padding = valid_padding
self.pre_normalize = pre_normalize
# Norm layers:
self.norm_in = torch.nn.LayerNorm(feature_dim)
self.norm_out = torch.nn.LayerNorm(feature_dim)
# Dropout layer:
self.dropout = torch.nn.Dropout(dropout)
# Attention layer:
self.attention = torch.nn.MultiheadAttention(
embed_dim=feature_dim,
num_heads=attention_heads,
dropout=0.0,
batch_first=True
)
# Feed-forward layer:
self.feed_forward = torch.nn.Sequential(
torch.nn.Linear(feature_dim, int(feed_forward_multiplier * feature_dim)),
torch.nn.GELU(),
torch.nn.Linear(int(feed_forward_multiplier * feature_dim), feature_dim),
)
def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
"""
Forward pass of a Transformer encoder block.
Parameters
----------
x : torch.Tensor
Input tensor of shape (batch_size, sequence_length, feature_dim).
mask : torch.Tensor or None, optional
Boolean mask indicating valid sequence positions.
Shape: (batch_size, sequence_length).
If `valid_padding` is True, True values denote valid tokens.
Otherwise, True values denote masked (invalid) positions.
Returns
-------
x : torch.Tensor
Output tensor of the same shape as the input
(batch_size, sequence_length, feature_dim).
"""
# Convert mask:
if mask is not None and self.valid_padding:
key_padding_mask = ~mask.bool() # True = pad
valid_mask = mask.bool()
elif mask is not None:
key_padding_mask = mask.bool()
valid_mask = ~mask.bool()
else:
key_padding_mask = None
valid_mask = None
# Detect fully padded sequences:
if valid_mask is not None:
all_pad = ~valid_mask.any(dim=-1) # (B,)
else:
all_pad = None
# Pre-normalization:
if self.pre_normalize:
h = self.norm_in(x)
else:
h = x
# Attention (guard against fully padded sequences):
if all_pad is not None and all_pad.any():
h_attn = h.clone()
h_attn[all_pad] = 0.0
if key_padding_mask is not None:
key_padding_mask = key_padding_mask.clone()
key_padding_mask[all_pad] = False
else:
h_attn = h
attn_out, _ = self.attention(
h_attn, h_attn, h_attn,
key_padding_mask=key_padding_mask,
need_weights=False,
)
x = x + attn_out
# Post-attention normalization:
if not self.pre_normalize:
z = self.norm_in(x)
else:
z = self.norm_out(x)
# Feed-forward:
z = self.feed_forward(z)
x = x + self.dropout(z)
if not self.pre_normalize:
x = self.norm_out(x)
# Re-pad fully padded sequences:
if all_pad is not None:
x = x.masked_fill(all_pad[:, None, None], 0.0)
return x
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# END OF FILE #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #