# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# #
# This file was created by: Alberto Palomo Alonso #
# Universidad de Alcalá - Escuela Politécnica Superior #
# #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
import torch
class EncoderBlock(torch.nn.Module):
"""
Transformer encoder block with configurable Pre-LayerNorm or Post-LayerNorm
architecture.
The block consists of a multi-head self-attention sublayer followed by a
position-wise feed-forward network, each wrapped with a residual connection.
Layer normalization can be applied either before each sublayer (Pre-LN) or
after each residual addition (Post-LN).
This design allows stable training of deep Transformer stacks while retaining
compatibility with the original Transformer formulation.
"""
def __init__(
self,
feature_dim: int,
attention_heads: int = 8,
feed_forward_multiplier: float = 4,
dropout: float = 0.0,
valid_padding: bool = False,
pre_normalize: bool = True,
**kwargs
):
"""
Initializes a Transformer encoder block.
Parameters
----------
feature_dim : int
Dimensionality of the input and output feature representations.
attention_heads : int, optional
Number of attention heads used in the multi-head self-attention layer.
Default is 8.
feed_forward_multiplier : float, optional
Expansion factor for the hidden dimension of the feed-forward network.
The intermediate dimension is computed as
`feed_forward_multiplier * feature_dim`.
Default is 4.
dropout : float, optional
Dropout probability applied to the feed-forward residual connection.
Default is 0.0.
valid_padding : bool, optional
If True, the provided mask marks valid (non-padded) positions.
If False, the mask marks padded (invalid) positions directly.
Default is False.
pre_normalize : bool, optional
If True, uses the Pre-LayerNorm Transformer variant, applying layer
normalization before each sublayer (self-attention and feed-forward).
If False, uses the Post-LayerNorm variant, applying normalization after
each residual connection.
Default is True.
**kwargs
Additional keyword arguments passed to the parent `torch.nn.Module`.
"""
# Module init via kwargs:
super().__init__(**kwargs)
# Store params:
self.valid_padding = valid_padding
self.pre_normalize = pre_normalize
# Norm layers:
self.norm_in = torch.nn.LayerNorm(feature_dim)
self.norm_out = torch.nn.LayerNorm(feature_dim)
# Dropout layer:
self.dropout = torch.nn.Dropout(dropout)
# Attention layer:
self.attention = torch.nn.MultiheadAttention(
embed_dim=feature_dim,
num_heads=attention_heads,
dropout=0.0,
batch_first=True
)
# Feed-forward layer:
self.feed_forward = torch.nn.Sequential(
torch.nn.Linear(feature_dim, int(feed_forward_multiplier * feature_dim)),
torch.nn.GELU(),
torch.nn.Linear(int(feed_forward_multiplier * feature_dim), feature_dim),
)
def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
"""
Forward pass of a Transformer encoder block.
Parameters
----------
x : torch.Tensor
Input tensor of shape (batch_size, sequence_length, feature_dim).
mask : torch.Tensor or None, optional
Boolean mask indicating valid sequence positions.
Shape: (batch_size, sequence_length).
If `valid_padding` is True, True values denote valid tokens.
Otherwise, True values denote masked (invalid) positions.
Returns
-------
x : torch.Tensor
Output tensor of the same shape as the input
(batch_size, sequence_length, feature_dim).
"""
# Convert mask:
if mask is not None and self.valid_padding:
key_padding_mask = ~mask.bool() # True = pad
valid_mask = mask.bool()
elif mask is not None:
key_padding_mask = mask.bool()
valid_mask = ~mask.bool()
else:
key_padding_mask = None
valid_mask = None
# Detect fully padded sequences:
if valid_mask is not None:
all_pad = ~valid_mask.any(dim=-1) # (B,)
else:
all_pad = None
# Pre-normalization:
if self.pre_normalize:
h = self.norm_in(x)
else:
h = x
# Attention (guard against fully padded sequences):
if all_pad is not None and all_pad.any():
h_attn = h.clone()
h_attn[all_pad] = 0.0
if key_padding_mask is not None:
key_padding_mask = key_padding_mask.clone()
key_padding_mask[all_pad] = False
else:
h_attn = h
attn_out, _ = self.attention(
h_attn, h_attn, h_attn,
key_padding_mask=key_padding_mask,
need_weights=False,
)
x = x + attn_out
# Post-attention normalization:
if not self.pre_normalize:
z = self.norm_in(x)
else:
z = self.norm_out(x)
# Feed-forward:
z = self.feed_forward(z)
x = x + self.dropout(z)
if not self.pre_normalize:
x = self.norm_out(x)
# Re-pad fully padded sequences:
if all_pad is not None:
x = x.masked_fill(all_pad[:, None, None], 0.0)
return x
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# END OF FILE #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #