# File size: 7,920 Bytes
# 5ce8761
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
from torch import nn

from .multihead_custom_attention import MultiheadCustomAttention


class AdaLN(nn.Module):
    """Signal-conditioned scale-and-shift of a sequence.

    A conditioning vector is passed through SiLU + Linear to produce a
    per-channel scale and shift, which are then broadcast over the
    sequence dimension. No normalization is computed inside this module;
    it only applies the affine modulation.
    """

    def __init__(self, d_model):
        super().__init__()
        projection = nn.Linear(d_model, 2 * d_model)
        # Zero-init the projection so that scale == shift == 0 at the
        # start, i.e. the module initially returns its input unchanged.
        for param in (projection.weight, projection.bias):
            nn.init.constant_(param, 0)
        self.modulation = nn.Sequential(nn.SiLU(), projection)

    def forward(self, x, t):
        """
        Modulate x with the scale/shift predicted from t.

        Args:
            x: tensor (B, S, C)
            t: tensor (B, C)

        Returns:
            tensor (B, S, C)
        """
        # The projection emits 2*C channels: first half scale, second shift
        scale, shift = self.modulation(t).chunk(2, dim=-1)  # (B, C) each
        # Insert a sequence axis so both broadcast over S
        gain = 1 + scale.unsqueeze(1)
        return x * gain + shift.unsqueeze(1)


class DummyLayer(nn.Module):
    """Base class with shared helpers for the layers in this file.

    Provides optional normalization (used for pre-/post-norm), additive
    positional embeddings and an optional adaptive-modulation call.
    """

    def __init__(self, pre_norm=False):
        super().__init__()
        # Read by subclasses to decide whether to normalize before or after
        self.pre_norm = pre_norm

    def _norm(self, x, layer, normalize=True):
        """Return layer(x) if requested and a layer exists, else x."""
        if not normalize or layer is None:
            return x
        return layer(x)

    def with_pos_embed(self, tensor, pos=None):
        """Return tensor + pos, or tensor untouched when pos is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def _adaln(self, x, layer, ada_sgnl):
        """Return layer(x, ada_sgnl) when both are given, else x."""
        if layer is None or ada_sgnl is None:
            return x
        return layer(x, ada_sgnl)

    def forward(self):
        """Placeholder, subclasses implement the actual computation."""
        return None


class FFWLayer(DummyLayer):
    """Residual MLP block (Linear-ReLU-Dropout-Linear-Dropout) with norm."""

    def __init__(self, d_model, dim_fw=None, dropout=0.1, use_adaln=False,
                 pre_norm=False):
        super().__init__(pre_norm=pre_norm)
        # Hidden width defaults to 4x the model width
        hidden = dim_fw if dim_fw is not None else 4 * d_model
        mlp = [
            nn.Linear(d_model, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, d_model),
            nn.Dropout(dropout),
        ]
        self.ffn = nn.Sequential(*mlp)
        self.norm = nn.LayerNorm(d_model)

        # Xavier init for everything created so far
        self._reset_parameters()

        # AdaLN is created only after the Xavier pass, so its own
        # initialization is not overwritten
        self.adaln = AdaLN(d_model) if use_adaln else None

    def _reset_parameters(self):
        """Apply Xavier-uniform init to every parameter with dim > 1."""
        for weight in filter(lambda p: p.dim() > 1, self.parameters()):
            nn.init.xavier_uniform_(weight)

    def forward(self, x, ada_sgnl=None):
        """
        Run the feed-forward block.

        Args:
            x: tensor (B, S, C)
            ada_sgnl: tensor (B, C)

        Returns:
            tensor (B, S, C)
        """
        # Pre-norm (if enabled) followed by optional adaptive modulation
        hidden = self._adaln(
            self._norm(x, self.norm, self.pre_norm), self.adaln, ada_sgnl
        )
        # Residual is taken around the (normalized/modulated) MLP input
        hidden = hidden + self.ffn(hidden)
        # Post-norm (if pre-norm is disabled)
        return self._norm(hidden, self.norm, not self.pre_norm)


class AttentionLayer(DummyLayer):
    """Attention layer, for self-/cross-attention.

    Wraps ``MultiheadCustomAttention`` with a residual connection, dropout
    on the attention output, optional pre-/post-LayerNorm, additive or
    rotary positional embeddings and optional AdaLN modulation.
    """

    def __init__(self, d_model=256, dropout=0.1, n_heads=8, pre_norm=False,
                 rotary_pe=False, use_adaln=False, is_self=False):
        """Initialize layers, d_model is the encoder dimension.

        Args:
            d_model: feature size C of the sequences
            dropout: dropout probability, passed to the attention module
                and also applied to its output before the residual sum
            n_heads: number of attention heads
            pre_norm: if True normalize the inputs before attention,
                otherwise normalize the residual sum after attention
            rotary_pe: if True positions are handed to the attention
                module as ``rotary_pe`` instead of being added to q/k
            use_adaln: if True create an AdaLN modulation layer
            is_self: self-attention; keys/values then reuse the query
                norm and receive the same AdaLN modulation as the query
        """
        super().__init__(pre_norm=pre_norm)
        self.rotary_pe = rotary_pe
        self.is_self = is_self  # self-attention, different normalization

        # Normalization and attention layers
        self.adaln = None
        if use_adaln:
            self.adaln = AdaLN(d_model)
        self.attention = MultiheadCustomAttention(
            d_model, n_heads, dropout=dropout
        )
        self.dropout = nn.Dropout(dropout)
        self.norm_q = nn.LayerNorm(d_model)
        # A key/value norm is only needed for pre-norm; for self-attention
        # it is the very same module as norm_q (shared weights)
        self.norm_kv = None
        if pre_norm:
            self.norm_kv = self.norm_q if is_self else nn.LayerNorm(d_model)

    def forward(self, seq1, seq2,
                seq2_key_padding_mask=None,
                seq1_pos=None, seq2_pos=None,
                seq1_sem_pos=None, seq2_sem_pos=None,
                ada_sgnl=None):
        """
        Attend from seq1 (queries) to seq2 (keys/values).

        Args:
            seq1: tensor (B, S1, C)
            seq1_pos: (B, S1, C) if not rotary, else (B, S1, C, 2)
            seq1_sem_pos: (B, S1, C), semantic embedding
            seq2: tensor (B, S2, C)
            seq2_key_padding_mask: tensor (B, S2)
            seq2_pos: (B, S2, C) if not rotary, else (B, S2, C, 2)
            seq2_sem_pos: (B, S2, C), semantic embedding
            ada_sgnl: tensor (B, C)

        Returns:
            tensor (B, S1, C)
        """
        # Normalize if pre-norm
        q1 = self._norm(seq1, self.norm_q, self.pre_norm)
        kv_norm = self.norm_q if self.is_self else self.norm_kv
        k2 = v2 = self._norm(seq2, kv_norm, self.pre_norm)
        # Add positional embeddings if not rotary - rotary are handled later.
        # The embeddings are added to the (possibly pre-normalized) q1/k2;
        # adding them to the raw seq1/seq2 would silently discard the
        # pre-norm for queries and keys.
        if not self.rotary_pe:
            q1 = self.with_pos_embed(q1, seq1_pos)
            k2 = self.with_pos_embed(k2, seq2_pos)
        # Add semantic embeddings, e.g. ids of each token
        q1 = self.with_pos_embed(q1, seq1_sem_pos)
        k2 = self.with_pos_embed(k2, seq2_sem_pos)
        # Adaptive normalization if applicable; keys/values are modulated
        # only for self-attention
        kv_adaln = self.adaln if self.is_self else None
        q1 = self._adaln(q1, self.adaln, ada_sgnl)
        k2 = self._adaln(k2, kv_adaln, ada_sgnl)
        v2 = self._adaln(v2, kv_adaln, ada_sgnl)
        # Main attention code. Inputs are transposed to (S, B, C) and the
        # first returned element is transposed back.
        # NOTE(review): presumably the attention module is sequence-first
        # and returns (output, weights) - confirm against its definition.
        seq1b = self.attention(
            query=q1.transpose(0, 1),
            key=k2.transpose(0, 1),
            value=v2.transpose(0, 1),
            attn_mask=None,
            key_padding_mask=seq2_key_padding_mask,  # (B, S2)
            rotary_pe=(seq1_pos, seq2_pos) if self.rotary_pe else None
        )[0].transpose(0, 1)
        # Residual around the un-normalized input
        seq1 = seq1 + self.dropout(seq1b)
        # Normalize if post-norm
        seq1 = self._norm(seq1, self.norm_q, not self.pre_norm)
        return seq1


class AttentionModule(nn.Module):
    """Stack of alternating attention and feed-forward layers."""

    def __init__(self, num_layers, d_model=256, dim_fw=None,
                 dropout=0.1, n_heads=8, pre_norm=False,
                 rotary_pe=False, use_adaln=False, is_self=False):
        super().__init__()
        self.num_layers = num_layers
        self.is_self = is_self
        attn_stack, ffw_stack = [], []
        # Layers are built pairwise (attention, then feed-forward)
        for _ in range(num_layers):
            attn_stack.append(AttentionLayer(
                d_model=d_model, dropout=dropout, n_heads=n_heads,
                pre_norm=pre_norm, rotary_pe=rotary_pe,
                use_adaln=use_adaln, is_self=is_self
            ))
            # The feed-forward layers are always built with pre_norm=False
            ffw_stack.append(FFWLayer(
                d_model=d_model, dim_fw=dim_fw, dropout=dropout,
                use_adaln=use_adaln, pre_norm=False
            ))
        self.attn_layers = nn.ModuleList(attn_stack)
        self.ffw_layers = nn.ModuleList(ffw_stack)

    def forward(self, seq1, seq2,
                seq2_key_padding_mask=None,
                seq1_pos=None, seq2_pos=None,
                seq1_sem_pos=None, seq2_sem_pos=None,
                ada_sgnl=None):
        """
        Run seq1 through every (attention, feed-forward) pair.

        Args:
            seq1: tensor (B, S1, C)
            seq2: tensor (B, S2, C)
            seq2_key_padding_mask: tensor (B, S2)
            seq1_pos: (B, S1, C) if not rotary, else (B, S1, C, 2)
            seq2_pos: (B, S2, C) if not rotary, else (B, S2, C, 2)
            seq1_sem_pos: (B, S1, C), semantic embedding
            seq2_sem_pos: (B, S2, C), semantic embedding
            ada_sgnl: tensor (B, C)

        Returns:
            list with one tensor (B, S1, C) per layer, in layer order
        """
        per_layer = []
        hidden = seq1
        for attn, ffw in zip(self.attn_layers, self.ffw_layers):
            # In self-attention the keys/values follow the updated queries
            context = hidden if self.is_self else seq2
            attended = attn(
                hidden, context,
                seq2_key_padding_mask,
                seq1_pos, seq2_pos,
                seq1_sem_pos, seq2_sem_pos,
                ada_sgnl
            )
            hidden = ffw(attended, ada_sgnl)
            per_layer.append(hidden)
        return per_layer