|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
class EncoderBlock(torch.nn.Module): |
|
|
""" |
|
|
Transformer encoder block with configurable Pre-LayerNorm or Post-LayerNorm |
|
|
architecture. |
|
|
|
|
|
The block consists of a multi-head self-attention sublayer followed by a |
|
|
position-wise feed-forward network, each wrapped with a residual connection. |
|
|
Layer normalization can be applied either before each sublayer (Pre-LN) or |
|
|
after each residual addition (Post-LN). |
|
|
|
|
|
This design allows stable training of deep Transformer stacks while retaining |
|
|
compatibility with the original Transformer formulation. |
|
|
""" |
|
|
def __init__( |
|
|
self, |
|
|
feature_dim: int, |
|
|
attention_heads: int = 8, |
|
|
feed_forward_multiplier: float = 4, |
|
|
dropout: float = 0.0, |
|
|
valid_padding: bool = False, |
|
|
pre_normalize: bool = True, |
|
|
**kwargs |
|
|
): |
|
|
""" |
|
|
Initializes a Transformer encoder block. |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
feature_dim : int |
|
|
Dimensionality of the input and output feature representations. |
|
|
attention_heads : int, optional |
|
|
Number of attention heads used in the multi-head self-attention layer. |
|
|
Default is 8. |
|
|
feed_forward_multiplier : float, optional |
|
|
Expansion factor for the hidden dimension of the feed-forward network. |
|
|
The intermediate dimension is computed as |
|
|
`feed_forward_multiplier * feature_dim`. |
|
|
Default is 4. |
|
|
dropout : float, optional |
|
|
Dropout probability applied to the feed-forward residual connection. |
|
|
Default is 0.0. |
|
|
valid_padding : bool, optional |
|
|
If True, the provided mask marks valid (non-padded) positions. |
|
|
If False, the mask marks padded (invalid) positions directly. |
|
|
Default is False. |
|
|
pre_normalize : bool, optional |
|
|
If True, uses the Pre-LayerNorm Transformer variant, applying layer |
|
|
normalization before each sublayer (self-attention and feed-forward). |
|
|
If False, uses the Post-LayerNorm variant, applying normalization after |
|
|
each residual connection. |
|
|
Default is True. |
|
|
**kwargs |
|
|
Additional keyword arguments passed to the parent `torch.nn.Module`. |
|
|
""" |
|
|
|
|
|
super().__init__(**kwargs) |
|
|
|
|
|
|
|
|
self.valid_padding = valid_padding |
|
|
self.pre_normalize = pre_normalize |
|
|
|
|
|
|
|
|
self.norm_in = torch.nn.LayerNorm(feature_dim) |
|
|
self.norm_out = torch.nn.LayerNorm(feature_dim) |
|
|
|
|
|
|
|
|
self.dropout = torch.nn.Dropout(dropout) |
|
|
|
|
|
|
|
|
self.attention = torch.nn.MultiheadAttention( |
|
|
embed_dim=feature_dim, |
|
|
num_heads=attention_heads, |
|
|
dropout=0.0, |
|
|
batch_first=True |
|
|
) |
|
|
|
|
|
|
|
|
self.feed_forward = torch.nn.Sequential( |
|
|
torch.nn.Linear(feature_dim, int(feed_forward_multiplier * feature_dim)), |
|
|
torch.nn.GELU(), |
|
|
torch.nn.Linear(int(feed_forward_multiplier * feature_dim), feature_dim), |
|
|
) |
|
|
|
|
|
def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor: |
|
|
""" |
|
|
Forward pass of a Transformer encoder block. |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
x : torch.Tensor |
|
|
Input tensor of shape (batch_size, sequence_length, feature_dim). |
|
|
mask : torch.Tensor or None, optional |
|
|
Boolean mask indicating valid sequence positions. |
|
|
Shape: (batch_size, sequence_length). |
|
|
If `valid_padding` is True, True values denote valid tokens. |
|
|
Otherwise, True values denote masked (invalid) positions. |
|
|
|
|
|
Returns |
|
|
------- |
|
|
x : torch.Tensor |
|
|
Output tensor of the same shape as the input |
|
|
(batch_size, sequence_length, feature_dim). |
|
|
""" |
|
|
|
|
|
|
|
|
if mask is not None and self.valid_padding: |
|
|
key_padding_mask = ~mask.bool() |
|
|
valid_mask = mask.bool() |
|
|
elif mask is not None: |
|
|
key_padding_mask = mask.bool() |
|
|
valid_mask = ~mask.bool() |
|
|
else: |
|
|
key_padding_mask = None |
|
|
valid_mask = None |
|
|
|
|
|
|
|
|
if valid_mask is not None: |
|
|
all_pad = ~valid_mask.any(dim=-1) |
|
|
else: |
|
|
all_pad = None |
|
|
|
|
|
|
|
|
if self.pre_normalize: |
|
|
h = self.norm_in(x) |
|
|
else: |
|
|
h = x |
|
|
|
|
|
|
|
|
if all_pad is not None and all_pad.any(): |
|
|
h_attn = h.clone() |
|
|
h_attn[all_pad] = 0.0 |
|
|
|
|
|
if key_padding_mask is not None: |
|
|
key_padding_mask = key_padding_mask.clone() |
|
|
key_padding_mask[all_pad] = False |
|
|
else: |
|
|
h_attn = h |
|
|
|
|
|
attn_out, _ = self.attention( |
|
|
h_attn, h_attn, h_attn, |
|
|
key_padding_mask=key_padding_mask, |
|
|
need_weights=False, |
|
|
) |
|
|
x = x + attn_out |
|
|
|
|
|
|
|
|
if not self.pre_normalize: |
|
|
z = self.norm_in(x) |
|
|
else: |
|
|
z = self.norm_out(x) |
|
|
|
|
|
|
|
|
z = self.feed_forward(z) |
|
|
x = x + self.dropout(z) |
|
|
|
|
|
if not self.pre_normalize: |
|
|
x = self.norm_out(x) |
|
|
|
|
|
|
|
|
if all_pad is not None: |
|
|
x = x.masked_fill(all_pad[:, None, None], 0.0) |
|
|
|
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
|