File size: 6,307 Bytes
dbd79bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                                                           #
#   This file was created by: Alberto Palomo Alonso         #
# Universidad de Alcalá - Escuela Politécnica Superior      #
#                                                           #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
import torch


class EncoderBlock(torch.nn.Module):
    """
        Transformer encoder block with configurable Pre-LayerNorm or Post-LayerNorm
        architecture.

        The block consists of a multi-head self-attention sublayer followed by a
        position-wise feed-forward network, each wrapped with a residual connection.
        Layer normalization can be applied either before each sublayer (Pre-LN) or
        after each residual addition (Post-LN).

        This design allows stable training of deep Transformer stacks while retaining
        compatibility with the original Transformer formulation.
    """
    def __init__(
            self,
            feature_dim: int,
            attention_heads: int = 8,
            feed_forward_multiplier: float = 4,
            dropout: float = 0.0,
            valid_padding: bool = False,
            pre_normalize: bool = True,
            **kwargs
    ):
        """
        Initializes a Transformer encoder block.

        Parameters
        ----------
        feature_dim : int
            Dimensionality of the input and output feature representations.
        attention_heads : int, optional
            Number of attention heads used in the multi-head self-attention layer.
            Default is 8.
        feed_forward_multiplier : float, optional
            Expansion factor for the hidden dimension of the feed-forward network.
            The intermediate dimension is computed as
            `feed_forward_multiplier * feature_dim`.
            Default is 4.
        dropout : float, optional
            Dropout probability applied to the feed-forward residual connection.
            Default is 0.0.
        valid_padding : bool, optional
            If True, the provided mask marks valid (non-padded) positions.
            If False, the mask marks padded (invalid) positions directly.
            Default is False.
        pre_normalize : bool, optional
            If True, uses the Pre-LayerNorm Transformer variant, applying layer
            normalization before each sublayer (self-attention and feed-forward).
            If False, uses the Post-LayerNorm variant, applying normalization after
            each residual connection.
            Default is True.
        **kwargs
            Additional keyword arguments passed to the parent `torch.nn.Module`.
        """
        # Module init via kwargs:
        super().__init__(**kwargs)

        # Store params:
        self.valid_padding = valid_padding
        self.pre_normalize = pre_normalize

        # Norm layers:
        self.norm_in = torch.nn.LayerNorm(feature_dim)
        self.norm_out = torch.nn.LayerNorm(feature_dim)

        # Dropout layer:
        self.dropout = torch.nn.Dropout(dropout)

        # Attention layer:
        self.attention = torch.nn.MultiheadAttention(
            embed_dim=feature_dim,
            num_heads=attention_heads,
            dropout=0.0,
            batch_first=True
        )

        # Feed-forward layer:
        self.feed_forward = torch.nn.Sequential(
            torch.nn.Linear(feature_dim, int(feed_forward_multiplier * feature_dim)),
            torch.nn.GELU(),
            torch.nn.Linear(int(feed_forward_multiplier * feature_dim), feature_dim),
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        """
        Forward pass of a Transformer encoder block.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, sequence_length, feature_dim).
        mask : torch.Tensor or None, optional
            Boolean mask indicating valid sequence positions.
            Shape: (batch_size, sequence_length).
            If `valid_padding` is True, True values denote valid tokens.
            Otherwise, True values denote masked (invalid) positions.

        Returns
        -------
        x : torch.Tensor
            Output tensor of the same shape as the input
            (batch_size, sequence_length, feature_dim).
        """

        # Convert mask:
        if mask is not None and self.valid_padding:
            key_padding_mask = ~mask.bool()  # True = pad
            valid_mask = mask.bool()
        elif mask is not None:
            key_padding_mask = mask.bool()
            valid_mask = ~mask.bool()
        else:
            key_padding_mask = None
            valid_mask = None

        # Detect fully padded sequences:
        if valid_mask is not None:
            all_pad = ~valid_mask.any(dim=-1)  # (B,)
        else:
            all_pad = None

        # Pre-normalization:
        if self.pre_normalize:
            h = self.norm_in(x)
        else:
            h = x

        # Attention (guard against fully padded sequences):
        if all_pad is not None and all_pad.any():
            h_attn = h.clone()
            h_attn[all_pad] = 0.0

            if key_padding_mask is not None:
                key_padding_mask = key_padding_mask.clone()
                key_padding_mask[all_pad] = False
        else:
            h_attn = h

        attn_out, _ = self.attention(
            h_attn, h_attn, h_attn,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        x = x + attn_out

        # Post-attention normalization:
        if not self.pre_normalize:
            z = self.norm_in(x)
        else:
            z = self.norm_out(x)

        # Feed-forward:
        z = self.feed_forward(z)
        x = x + self.dropout(z)

        if not self.pre_normalize:
            x = self.norm_out(x)

        # Re-pad fully padded sequences:
        if all_pad is not None:
            x = x.masked_fill(all_pad[:, None, None], 0.0)

        return x
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                        END OF FILE                        #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #